桁DPの典型(?) – CSA Round#78 Banned Digits

今日の問題

今日の問題はこれ。

CSA Round#78 Banned Digits

 

問題文を要約するとこんな感じ。

\(0 \)以上\(N \)未満の数字のうち, 与えられる数字の組\(S\)のみで構成される数の総数を求めよ。

\(1 \leq N \leq 10^{18} \)

桁に関する縛りがある上, 制約が非常に大きいので, いわゆる桁DPと呼ばれる動的計画法を行うのが良さそう。

 


桁DPの基本

桁DPでは, [上からi桁目まで見たときに][なんらかの縛りを満たすか?] のような感じの添え字を持ったテーブルを用いる。(ここでは, 桁は上から1,2, …という具合に数えることにしよう)

今回の縛りとしては,

  • 0以上N未満であること
  • 使える数字が限られている

である。すると, 桁DPテーブルは, 以下のように定まっていく。

  1. dp[0] = 1, 他を 0 として初期化する。
  2. i桁目までの数がNと全く同じ場合, i+1 桁目で使える数は N[i+1桁目] 以下になる。そうでないなら, すべてのS内の数字を選択できる(上位桁が0で埋まっている場合も許す。9 は “009” のような形で扱われる)。
  3. 使える数字の数のぶん, dp[i+1]にdp[i]の値を加算する。

実際にやる際は, 「Nギリギリか?」をboolの添え字として持つことで, 2の場合分けを行える。初期化はdp[0][1] = 1

 

例えば, S = {0, 1, 2, 4} , N= 21 として考えてみよう。

 

1桁目では, 選択可能な数は 0, 1, 2の3つ。

0, 1を選択すると, i桁目までの数がNと全く同じ(Nギリギリ)状態でなくなる。

よって, dp[1][0] += dp[0][1] を2回繰り返し, dp[1][0] = 2

一方, 2を選択すると, Nギリギリの状態のままなので,

dp[1][1] += dp[0][1] を1回繰り返し, dp[1][1] = 1

 

2桁目では, まずNギリギリでない場合, 選択可能な数は任意のS内の数 4種類になる。

よって, dp[2][0] += dp[1][0]を4回繰り返し, dp[2][0] = 8 となる。

一方, Nギリギリの場合, 選択可能な数は0, 1の2つ。

0を選択すると, ギリギリが続くので dp[2][1] += dp[1][1] ×1で dp[2][1] = 1

1を選択すると, ギリギリでなくなるので dp[2][0] += dp[1][1] ×1で dp[2][0] = 8+1 = 9

 

よって答えは, dp[2][1]dp[2][0] の和で 10 ということになる。列挙すると, 0, 1, 2, 4, 10, 11, 12, 14, 20, 21 と確かに10個である(このやり方ではNを含んでしまうので, 問題ではN-1してからやるとよい)。

 

「上位桁0を許す」の罠

先述したように, 桁DPは上位桁0を許して数え上げている(これによって, Nがx桁でもx-1桁以下の数を数え上げられる)。

Sに0が含まれないときはどうだろう? 0が使えないので, 上位桁0が許されなくなってしまう。ということは, 場合分けが必要になってしまう。

ここでもうひとつ, [上位桁が0のみか] という添字を増やす。こうすることで, このフラグが立っているときに, 次の桁の値として0が選択できるようになる。

 

解答例

#include <bits/stdc++.h>
#define REP(i,n) for (int i=0;i<(n);i++)
#define FOR(i,s,e) for (int i=s;i<(e);i++)
#define All(v) (v).begin(),(v).end()
#define mp(a,b) make_pair(a,b)
#define pb(a) push_back(a)
#define chmax(x, y) x = max(x, y)
#define chmin(x, y) x = min(x, y)
#define int long long
using namespace std;
typedef long long llint;
typedef pair<int, int> P;
const int MOD = (int)1e9 + 7;
const int INF = (int)1e18 * 5;

int dp[25][2][2];

signed main(){
    cin.tie(0);
    ios::sync_with_stdio(false);

    vector<int> digit;
    REP(i, 10){
        int x;
        cin >> x;
        if(!x) digit.pb(i);
    }
    int Nll;
    cin >> Nll;
    
    string N = to_string(--Nll);
    int len = N.size();
    memset(dp, 0, sizeof(dp));
    // dp[何桁目まで見たか][直前の桁までギリギリか][上位桁が0のみか]
    dp[0][1][1] = 1;
    REP(i, len)REP(j, 2)REP(k, 2){
        if(dp[i][j][k] == 0) continue;
        int lim = (j ? N[i]-'0' : 9);
        if(digit[0] != 0 && k && i+1 != len){
           dp[i+1][j && 0==lim][k] += dp[i][j][k];
        }
        for(auto d : digit){
            if(d > lim) break;
            dp[i+1][j && d==lim][0] += dp[i][j][k];
        }
    }
    int ans = 0;
    REP(j, 2)REP(k, 2){
        ans += dp[len][j][k];
    }
    cout << ans << "\n";
    
    return 0;
}

 

桁DP, わかればやるだけでは(いいえ)

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です