今日の問題
今日の問題はこれ。
問題文を要約するとこんな感じ。
\(0 \)以上\(N \)未満の数字のうち, 与えられる数字の組\(S\)のみで構成される数の総数を求めよ。
\(1 \leq N \leq 10^{18} \)
桁に関する縛りがある上, 制約が非常に大きいので, いわゆる桁DPと呼ばれる動的計画法を行うのが良さそう。
桁DPの基本
桁DPでは, [上からi桁目まで見たときに][なんらかの縛りを満たすか?] のような感じの添え字を持ったテーブルを用いる。(ここでは, 桁は上から1,2, …という具合に数えることにしよう)
今回の縛りとしては,
- 0以上N未満であること
- 使える数字が限られている
である。すると, 桁DPテーブルは, 以下のように定まっていく。
dp[0] = 1
, 他を 0 として初期化する。- i桁目までの数がNと全く同じ場合, i+1 桁目で使える数は N[i+1桁目] 以下になる。そうでないなら, すべてのS内の数字を選択できる(上位桁が0で埋まっている場合も許す。9 は “009” のような形で扱われる)。
- 使える数字の数のぶん, dp[i+1]にdp[i]の値を加算する。
実際にやる際は, 「Nギリギリか?」をboolの添え字として持つことで, 2の場合分けを行える。初期化はdp[0][1] = 1
。
例えば, S = {0, 1, 2, 4} , N= 21 として考えてみよう。
1桁目では, 選択可能な数は 0, 1, 2の3つ。
0, 1を選択すると, i桁目までの数がNと全く同じ(Nギリギリ)状態でなくなる。
よって, dp[1][0] += dp[0][1]
を2回繰り返し, dp[1][0] = 2
。
一方, 2を選択すると, Nギリギリの状態のままなので,
dp[1][1] += dp[0][1]
を1回繰り返し, dp[1][1] = 1
。
2桁目では, まずNギリギリでない場合, 選択可能な数は任意のS内の数 4種類になる。
よって, dp[2][0] += dp[1][0]
を4回繰り返し, dp[2][0] = 8
となる。
一方, Nギリギリの場合, 選択可能な数は0, 1の2つ。
0を選択すると, ギリギリが続くので dp[2][1] += dp[1][1]
×1で dp[2][1] = 1
。
1を選択すると, ギリギリでなくなるので dp[2][0] += dp[1][1]
×1で dp[2][0] = 8+1 = 9
。
よって答えは, dp[2][1]
と dp[2][0]
の和で 10 ということになる。列挙すると, 0, 1, 2, 4, 10, 11, 12, 14, 20, 21 と確かに10個である(このやり方ではNを含んでしまうので, 問題ではN-1してからやるとよい)。
「上位桁0を許す」の罠
先述したように, 桁DPは上位桁0を許して数え上げている(これによって, Nがx桁でもx-1桁以下の数を数え上げられる)。
Sに0が含まれないときはどうだろう? 0が使えないので, 上位桁0が許されなくなってしまう。ということは, 場合分けが必要になってしまう。
ここでもうひとつ, [上位桁が0のみか] という添字を増やす。こうすることで, このフラグが立っているときに, 次の桁の値として0が選択できるようになる。
解答例
#include <bits/stdc++.h> #define REP(i,n) for (int i=0;i<(n);i++) #define FOR(i,s,e) for (int i=s;i<(e);i++) #define All(v) (v).begin(),(v).end() #define mp(a,b) make_pair(a,b) #define pb(a) push_back(a) #define chmax(x, y) x = max(x, y) #define chmin(x, y) x = min(x, y) #define int long long using namespace std; typedef long long llint; typedef pair<int, int> P; const int MOD = (int)1e9 + 7; const int INF = (int)1e18 * 5; int dp[25][2][2]; signed main(){ cin.tie(0); ios::sync_with_stdio(false); vector<int> digit; REP(i, 10){ int x; cin >> x; if(!x) digit.pb(i); } int Nll; cin >> Nll; string N = to_string(--Nll); int len = N.size(); memset(dp, 0, sizeof(dp)); // dp[何桁目まで見たか][直前の桁までギリギリか][上位桁が0のみか] dp[0][1][1] = 1; REP(i, len)REP(j, 2)REP(k, 2){ if(dp[i][j][k] == 0) continue; int lim = (j ? N[i]-'0' : 9); if(digit[0] != 0 && k && i+1 != len){ dp[i+1][j && 0==lim][k] += dp[i][j][k]; } for(auto d : digit){ if(d > lim) break; dp[i+1][j && d==lim][0] += dp[i][j][k]; } } int ans = 0; REP(j, 2)REP(k, 2){ ans += dp[len][j][k]; } cout << ans << "\n"; return 0; }
桁DP, わかればやるだけでは(いいえ)