今日の問題
今日の問題はこれ。
問題文を要約するとこんな感じ。
\(0 \)以上\(N \)未満の数字のうち, 与えられる数字の組\(S\)のみで構成される数の総数を求めよ。
\(1 \leq N \leq 10^{18} \)
桁に関する縛りがある上, 制約が非常に大きいので, いわゆる桁DPと呼ばれる動的計画法を行うのが良さそう。
桁DPの基本
桁DPでは, [上からi桁目まで見たときに][なんらかの縛りを満たすか?] のような感じの添え字を持ったテーブルを用いる。(ここでは, 桁は上から1,2, …という具合に数えることにしよう)
今回の縛りとしては,
- 0以上N未満であること
- 使える数字が限られている
である。すると, 桁DPテーブルは, 以下のように定まっていく。
- dp[0] = 1, 他を 0 として初期化する。
- i桁目までの数がNと全く同じ場合, i+1 桁目で使える数は N[i+1桁目] 以下になる。そうでないなら, すべてのS内の数字を選択できる(上位桁が0で埋まっている場合も許す。9 は “009” のような形で扱われる)。
- 使える数字の数のぶん, dp[i+1]にdp[i]の値を加算する。
実際にやる際は, 「Nギリギリか?」をboolの添え字として持つことで, 2の場合分けを行える。初期化はdp[0][1] = 1。
例えば, S = {0, 1, 2, 4} , N= 21 として考えてみよう。
1桁目では, 選択可能な数は 0, 1, 2の3つ。
0, 1を選択すると, i桁目までの数がNと全く同じ(Nギリギリ)状態でなくなる。
よって, dp[1][0] += dp[0][1] を2回繰り返し, dp[1][0] = 2 。
一方, 2を選択すると, Nギリギリの状態のままなので,
dp[1][1] += dp[0][1] を1回繰り返し, dp[1][1] = 1 。
2桁目では, まずNギリギリでない場合, 選択可能な数は任意のS内の数 4種類になる。
よって, dp[2][0] += dp[1][0]を4回繰り返し, dp[2][0] = 8 となる。
一方, Nギリギリの場合, 選択可能な数は0, 1の2つ。
0を選択すると, ギリギリが続くので dp[2][1] += dp[1][1] ×1で dp[2][1] = 1 。
1を選択すると, ギリギリでなくなるので dp[2][0] += dp[1][1] ×1で dp[2][0] = 8+1 = 9 。
よって答えは, dp[2][1] と dp[2][0] の和で 10 ということになる。列挙すると, 0, 1, 2, 4, 10, 11, 12, 14, 20, 21 と確かに10個である(このやり方ではNを含んでしまうので, 問題ではN-1してからやるとよい)。
「上位桁0を許す」の罠
先述したように, 桁DPは上位桁0を許して数え上げている(これによって, Nがx桁でもx-1桁以下の数を数え上げられる)。
Sに0が含まれないときはどうだろう? 0が使えないので, 上位桁0が許されなくなってしまう。ということは, 場合分けが必要になってしまう。
ここでもうひとつ, [上位桁が0のみか] という添字を増やす。こうすることで, このフラグが立っているときに, 次の桁の値として0が選択できるようになる。
解答例
- #include <bits/stdc++.h>
- #define REP(i,n) for (int i=0;i<(n);i++)
- #define FOR(i,s,e) for (int i=s;i<(e);i++)
- #define All(v) (v).begin(),(v).end()
- #define mp(a,b) make_pair(a,b)
- #define pb(a) push_back(a)
- #define chmax(x, y) x = max(x, y)
- #define chmin(x, y) x = min(x, y)
- #define int long long
- using namespace std;
- typedef long long llint;
- typedef pair<int, int> P;
- const int MOD = (int)1e9 + 7;
- const int INF = (int)1e18 * 5;
- int dp[25][2][2];
- signed main(){
- cin.tie(0);
- ios::sync_with_stdio(false);
- vector<int> digit;
- REP(i, 10){
- int x;
- cin >> x;
- if(!x) digit.pb(i);
- }
- int Nll;
- cin >> Nll;
- string N = to_string(--Nll);
- int len = N.size();
- memset(dp, 0, sizeof(dp));
- // dp[何桁目まで見たか][直前の桁までギリギリか][上位桁が0のみか]
- dp[0][1][1] = 1;
- REP(i, len)REP(j, 2)REP(k, 2){
- if(dp[i][j][k] == 0) continue;
- int lim = (j ? N[i]-'0' : 9);
- if(digit[0] != 0 && k && i+1 != len){
- dp[i+1][j && 0==lim][k] += dp[i][j][k];
- }
- for(auto d : digit){
- if(d > lim) break;
- dp[i+1][j && d==lim][0] += dp[i][j][k];
- }
- }
- int ans = 0;
- REP(j, 2)REP(k, 2){
- ans += dp[len][j][k];
- }
- cout << ans << "\n";
- return 0;
- }
桁DP, わかればやるだけでは(いいえ)