気ままに実装する機械学習

機械学習に興味のある大学院生によるブログです. 機械学習以外のトピック多めです.

MENU

AtCoderの Typical DP Contestを解いてみた (E 数)

今回もTypical DP Contestのアウトプットに関する記事です.

今回解いた問題はE 数です

前回までの記事は以下の通りです.

  1. A コンテスト
  2. B ゲーム
  3. C トーナメント
  4. D サイコロ

E 数

今回は「E 数」という問題を解きました.

問題の概要

 N以下の正整数であって、10進数表記したときの各桁の数の和が Dの倍数であるものの個数を mod 1,000,000,007で求めよ.

制約

  •  1 \leq N \leq 10^{10000}

  •  1 \leq D \leq 100

解法

愚直にループを回すと、N10^{10000}のケースが考えられるので到底間に合いません.

したがってDPで解くことを考えます.

試行錯誤して漸化式をたてようと試みたのですが、うまくたてられませんでした.

なので、(他の人が解いたこの問題の解法は避けるように)ググってみた結果、桁DPというものが使えそうなので、そちらの勉強から始めました.

(桁DPはよく使う手法らしいのですが恥ずかしながら初めて聞きました... 勉強不足ですね...)

寄り道 (桁DPの勉強)

私は、こちらの記事を参考にさせていただき、一通り実装することで、桁DPの考え方を学びました.

pekempey.hatenablog.com

非常にわかりやすくまとまっているので、桁DPってなんぞやって方は、本記事の続きを見る前に一読することをお勧めします.

(桁DPについての解説などは本記事では省きます.)

解法 (本題に戻る)

(ほとんど上記の参考記事に記載されているプログラムの書き方で、アルゴリズムを含め自分で考えたのはほんの少しの部分です.)

桁DPの紹介記事にあった通り、 $$ dp[i][j] : \begin{cases} i : 上からi桁目まで参照している \\ j : N未満であることが確定しているかどうか\\ (j = 0 で確定、j = 1で確定していない) \end{cases} $$

というところから考え始めます.

本問題では各桁の和がDである数の総数を求めなければいけないので、追加要素として、Dで割った余りという状態を持たせます.

つまり、 $$ dp[i][j][k] : \begin{cases} i : 上からi桁目まで参照している \\ j : N未満であることが確定しているかどうか\\ (j = 0 で確定、j = 1で確定していない)\\ k : 各桁の和をDで割った余り \end{cases} $$

とします.

漸化式は、まず桁DPの基本のところから考えると、dp[i][j]に対して、 $$ dp[i + 1][j \,|| \, d < lim] += dp[i][j] $$

となります.

ただし、limという数字はそれまでにNを超えていたら次の桁の数に制限が加わるので、その次の桁に使える数の上限を表しています.

例えば、 N = 123に対して、現在12?とみていたら、?には0~9の全ての数を使えません.

 Nを超えていはいけないので、?には0 ~ 3のいずれかしか入れることができないからです.

したがって、この例で言うと lim = 3です.

次に本問題のために拡張したDPの漸化式をたてます. $$ dp[i + 1][j \,||\, d \,< \, lim][(k + d) \, \% \, D] += dp[i][j][k] $$

3つめの要素では1桁前の総和に使用可能な数を足し合わせ、それをDで割った余りで分類しています.

最終出力は、dp[N][j][0] - 1\, (j = 0, 1)です.

(なぜなら、今回は各桁の総和をDで割った余りが0となる正整数の個数を求められているからです.)

ちなみに最後にマイナス1をしているのは0Dの倍数であるとして数え上げているためです.

ソース

使用言語はC++です.

#include <iostream>
#include <string>
#include <vector>
#include <cstdlib>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <map>
#include <stack>
#include <queue>
#include <set>
#include <cstring>

using namespace std;
// ascending order
#define vsort(v) sort(v.begin(), v.end())
// descending order
#define vsort_r(v) sort(v.begin(), v.end(), greater<int>())
#define vunique(v) unique(v.begin(), v.end())
#define mp make_pair
#define ts(x) to_string(x)
#define rep(i, a, b) for(int i = (int)a; i < (int)b; i++)
#define repm(i, a, b) for(int i = (int)a; i > (int)b; i--)
#define bit(a) bitset<8>(a)
typedef long long ll;
typedef pair<int, int> P;
const ll INF = 1e18;

int main(){
    cin.tie(0);
    ios::sync_with_stdio(false);
    int D;
    string N;
    cin >> D >> N;
    ll n = N.length();
    ll mod = 1e9 + 7;

    int dp[n + 1][2][D];
    memset(dp, 0, sizeof(dp));
    dp[0][0][0] = 1;
    
    rep(i, 0, n) rep(j, 0, 2) rep(k, 0, D) {
        int lim = j ? 9 : N[i] - '0';
        rep(d, 0, lim + 1) (dp[i + 1][j || d < lim][(k + d) % D] += dp[i][j][k]) %= mod;
    }

    int ans = 0;
    rep(j, 0, 2) (ans += dp[n][j][0]) %= mod;
    cout << (ans - 1) << endl;
}