Contest de seleccion IOI 2022: C. Producto de Digitos

Problem Statement: Given integers \(n\) and \(b\), count how many integers \(x\) exist such that \(1 \leq x \leq n\) and \(x - P(x) = b\), where \(P(x)\) is the product of the digits of \(x\).
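
Before looking for structure, a brute-force counter makes the statement concrete and can serve as a reference for small inputs. This is only a sketch: the names digit_product and brute_force are illustrative, and the loop is linear in \(n\), so it is far too slow for large limits.

```python
from functools import reduce
from operator import mul


def digit_product(x):
    # P(x): product of the decimal digits of x.
    return reduce(mul, map(int, str(x)))


def brute_force(n, b):
    # Try every x in [1, n] and count those with x - P(x) = b.
    return sum(1 for x in range(1, n + 1) if x - digit_product(x) == b)


print(brute_force(49, 13))  # 2, from x = 27 and x = 49
```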

A straightforward observation is that each \(x\) is counted for exactly one \(b\): the value \(x - P(x)\) depends only on \(x\), not on the input. Is there any interesting property of \(x - P(x)\) that we can use to solve the problem? Let's take a look at \(x - P(x)\) for \(0 \leq x \leq 50\) using the following snippet.

from functools import reduce
from operator import mul

def p(n):
    return reduce(mul, map(int, str(n)))

# Print a header row, then one row per x.
print("x", "px", "x - p(x)")
for x in range(51):
    print(x, p(x), x - p(x))
x px x - p(x)
0 0 0
1 1 0
2 2 0
3 3 0
4 4 0
5 5 0
6 6 0
7 7 0
8 8 0
9 9 0
10 0 10
11 1 10
12 2 10
13 3 10
14 4 10
15 5 10
16 6 10
17 7 10
18 8 10
19 9 10
20 0 20
21 2 19
22 4 18
23 6 17
24 8 16
25 10 15
26 12 14
27 14 13
28 16 12
29 18 11
30 0 30
31 3 28
32 6 26
33 9 24
34 12 22
35 15 20
36 18 18
37 21 16
38 24 14
39 27 12
40 0 40
41 4 37
42 8 34
43 12 31
44 16 28
45 20 25
46 24 22
47 28 19
48 32 16
49 36 13
50 0 50

The result is 0 if \(x\) has only one digit, and it is \(x\) itself if \(x\) contains at least one digit 0, because then \(P(x) = 0\). The result is 10 for every \(x\) from 10 to 19: a two-digit number \(x = 10a + d\) gives \(x - P(x) = 10a - (a - 1)d\), which equals 10 when \(a = 1\). Besides that, there are some repetitions, as in \(x=27\) and \(x=49\), which both give 13, so for \(n=49\) and \(b=13\) we should return at least 2. What are the lower and upper bounds for the expected result? The question matters because either the output can be a number so large that we cannot count the solutions one by one, or it is small enough to be enumerated. The former implies that we have to find some sort of formula or an algorithm with logarithmic complexity, while the latter means that the answer can be computed directly in a reasonable time. So, the question becomes: how many different values of \(x-P(x)\) exist?

from functools import reduce, cache
from operator import mul

@cache
def p(n):
    return reduce(mul, map(int, str(n)))

print(len({x - p(x) for x in range(10**7)}))
6953658

That is a big number, which would require either a formula or a logarithmic algorithm. Looking at the values of \(x-P(x)\), the problem seems unlikely to be solved with a formula. Why? \(P(x)\) depends on the digits of \(x\), and how would a closed formula handle this property? Wait. What do we know about \(P(x)\)? It is never larger than \(x\). It can be the same for different \(x\); for example, \(x=123\) and \(x=321\) have the same \(P(x)=6\). How many different \(P(x)\) exist?

from functools import reduce, cache
from operator import mul

@cache
def p(n):
    return reduce(mul, map(int, str(n)))

def count_p(n):
    return len({p(x) for x in range(n)})

# Print a header row, then one row per power of ten.
print("interval", "# of distinct P(x)")
for e in range(4, 9):
    print(f"$0...10^{e}$", count_p(10**e))
interval # of distinct P(x)
\(0...10^4\) 226
\(0...10^5\) 442
\(0...10^6\) 785
\(0...10^7\) 1297
\(0...10^8\) 2026
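
One way to see why the counts stay small: the digits of \(x\) can be reordered and padded with 1s without changing \(P(x)\), so for all numbers with at most \(d\) digits it is enough to look at multisets of \(d\) digits from 1 to 9, plus the value 0 for numbers that contain a zero. The sketch below (the name distinct_products is an assumption, not part of the original solution) should agree with the table above without scanning every \(x\). It only works for limits of the form \(10^d\) and does not handle an arbitrary \(n\).

```python
from itertools import combinations_with_replacement
from math import prod


def distinct_products(d):
    # Number of distinct values of P(x) for 0 <= x < 10**d.
    # Order does not matter and a 1 acts as padding, so multisets of exactly
    # d digits in 1..9 cover every number that has no zero digit.
    products = {prod(c) for c in combinations_with_replacement(range(1, 10), d)}
    products.add(0)  # x = 0 and every x containing a zero digit
    return len(products)


for d in range(4, 9):
    print(f"0...10^{d}", distinct_products(d))
```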

There are not many of them, but how can we generate them efficiently? For that, we can use digit dynamic programming (digit DP) to go over all possible numbers \(x \leq n\) and collect \(P(x)\) for each of them. The function rec tries every possible digit for the \(i\)th position and returns every product that the digits from position \(i\) onward can form. With all possible values of \(P(x)\) generated, we iterate over each one of them, say \(z\), and count those where \(z + b \leq n\) and \(P(z + b) = z\).

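With \(k\) being the number of different values of \(P(x)\), the final loop performs \(k\) checks, and each check computes one digit product in \(O(\log n)\) time.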

Python solution

from functools import cache, reduce
from operator import mul


def gen(n):
    digits = list(map(int, str(n)))

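    @cache
    def rec(i, is_smaller, is_first_digit):
        # i: current position; is_smaller: the prefix is already below n's prefix;
        # is_first_digit: only leading zeros have been placed so far.
        if i == len(digits):
            # An all-zero string is x = 0, which is outside [1, n].
            return set([1] if not is_first_digit else [])

        ans = set()
        # Any digit is allowed once the prefix is smaller; otherwise stay within n.
        lim = 9 if is_smaller else digits[i]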
        for k in range(lim + 1):
            for j in rec(
                i + 1, is_smaller or (k < digits[i]), is_first_digit and k == 0
            ):
                if is_first_digit and k == 0:
                    ans.add(j)
                else:
                    ans.add(k * j)
        return ans

    return rec(0, False, True)


def p(n):
    return reduce(mul, map(int, str(n)))


def solve(n, b):
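    ans = 0
    for px in gen(n):
        x = b + px  # the only x that can satisfy x - P(x) = b for this product
        # x must lie in [1, n]; the lower bound keeps x = 0 out when b = 0
        if 1 <= x <= n and p(x) == px:
            ans += 1
    return ans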


assert solve(99999999999, 502) == 12
assert solve(999, 434) == 2
assert solve(255, 15) == 2
assert solve(9999999999, 1) == 0
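
As a sanity check on the generation step, the digit DP can be compared against direct enumeration for small limits. The block below is a compact restatement of gen under different names (products_up_to, tight and leading are illustrative, with tight being the negation of is_smaller), not a replacement for the solution above.

```python
from functools import cache, reduce
from operator import mul


def digit_product(x):
    # P(x): product of the decimal digits of x.
    return reduce(mul, map(int, str(x)))


def products_up_to(n):
    # Digit DP: the set of all P(x) for 1 <= x <= n.
    digits = list(map(int, str(n)))

    @cache
    def rec(i, tight, leading):
        # tight: the prefix equals n's prefix; leading: only zeros so far.
        if i == len(digits):
            return frozenset() if leading else frozenset({1})
        out = set()
        for k in range((digits[i] if tight else 9) + 1):
            still_leading = leading and k == 0
            for j in rec(i + 1, tight and k == digits[i], still_leading):
                # Leading zeros are not digits of x, so they do not multiply.
                out.add(j if still_leading else k * j)
        return frozenset(out)

    return rec(0, True, True)


# Grow the brute-force set one x at a time and compare at every limit.
expected = set()
for n in range(1, 2000):
    expected.add(digit_product(n))
    assert products_up_to(n) == expected, n
```

If the loop finishes without an AssertionError, the two approaches agree for every limit below 2000.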

C++ solution

#include <set>
#include <vector>
#include <cstdio>
#include <algorithm>
#include <string.h>
#include <iostream>
#include <assert.h>

using namespace std;

#define ll long long

int vis[12][2][2];
set<ll> memo[12][2][2];
vector<int> digits;

set<ll> gen(int i, int is_smaller, int is_first) {
  set<ll> ans;
  if (i == digits.size()) {
    if (!is_first) {
      ans.insert(1);
    }
    return ans;
  }
  if (vis[i][is_smaller][is_first]) {
    return memo[i][is_smaller][is_first];
  }

  int lim = is_smaller ? 9 : digits[i];
  for (int k = 0; k <= lim; k++) {
    set<ll> cur = gen(i + 1, is_smaller || (k < digits[i]), is_first && k == 0);
    for (set<ll>::iterator it = cur.begin(); it != cur.end(); it++) {
      ll j = *it;
      ans.insert(is_first && k == 0 ? j : k * j);
    }
  }
  memo[i][is_smaller][is_first] = ans;
  vis[i][is_smaller][is_first] = 1;
  return ans;
}

ll p(ll x) {
  if (x < 10) {
    return x;
  }
  ll ans = 1;  // 9^11 does not fit in a 32-bit int
  while (x > 0) {
    ans *= (x % 10);
    x = x / 10;
  }
  return ans;
}

vector<int> to_digits(ll n) {
  vector<int> digits;

  if (n < 10) {
    digits.push_back(n);
    return digits;
  }
  while (n > 0) {
    digits.push_back(n % 10);
    n = n / 10;
  }
  reverse(digits.begin(), digits.end());
  return digits;
}

int solve(ll n, ll b) {
  memset(vis, 0, sizeof(vis));
  digits = to_digits(n);

  set<ll> candidates = gen(0, 0, 1);

  int ans = 0;
  for (set<ll>::iterator it = candidates.begin(); it != candidates.end(); it++) {
    ll px = *it;
    ll x = b + px;
    if (x >= 1 && x <= n && p(x) == px) {  // x = 0 is outside [1, n]
      ans++;
    }
  }
  return ans;
}

int main() {

  assert(solve(99999999999, 502) == 12);
  assert(solve(999, 434) == 2);
  assert(solve(255, 15) == 2);
  assert(solve(9999999999, 1) == 0);

  ll n, b;
  scanf("%lld %lld", &n, &b);
  printf("%d\n", solve(n, b));  // solve() builds the digit list itself

  return 0;
}