Skip to the content.

:warning: Dirichlet 積の prefix sum
(number-theory/dirichlet-series-prefix-sum.hpp)

Code

#pragma once

// https://maspypy.com/dirichlet-%E7%A9%8D%E3%81%A8%E3%80%81%E6%95%B0%E8%AB%96%E9%96%A2%E6%95%B0%E3%81%AE%E7%B4%AF%E7%A9%8D%E5%92%8C
/**
 * @brief Dirichlet 積の prefix sum
 */
template <class mint>
struct DirichletSeriesPrefixSum {
  using DP = DirichletSeriesPrefixSum<mint>;
  using u64 = uint64_t;

 public:
  u64 N;
  size_t K, L;
  vector<mint> a, s, S;

  DirichletSeriesPrefixSum(u64 n)
      : N(n),
        K(max(sqrt(N), pow(max(1.0, N / log(N)), 2.0 / 3))),
        L((N - 1) / K + 1) {
    a.assign(K + 1, 0);
    s.assign(K + 1, 0);
    S.assign(L + 1, 0);
  }
  DirichletSeriesPrefixSum(const DP &d)
      : N(d.N), K(d.K), L(d.L), a(d.a), s(d.s), S(d.S) {}
  static DP id(u64 n) {
    DP z(n);
    return z.add(1, 1);
  }
  // {floor(n/k):1<=k<=n}={x[0],...,x[m-1]}, x[i-1]<x[i]
  // a[i]=sum_{1<=j<=x[i]}f(j)
  DirichletSeriesPrefixSum(u64 n, const vector<mint> &a) {

  }

  // zeta(s-k)
  static DP zeta(u64 n, size_t k = 0) {
    assert(k <= 2);
    DP z(n);
    for (size_t i = 1; i <= z.K; i++) z.a[i] = mint(i).pow(k);
    for (size_t i = 1; i <= z.K; i++) z.s[i] = z.s[i - 1] + z.a[i];
    for (size_t i = 1; i <= z.L; i++) {
      u64 x = n / i;
      if (k == 0)
        z.S[i] = x;
      else if (k == 1)
        z.S[i] = mint(x | 1) * mint((x + 1) / 2);
      else if (k == 2) {
        array<u64, 3> xs{x, x + 1, 2 * x + 1};
        xs[x & 1] /= 2;
        xs[(3 - (x % 3)) % 3] /= 3;
        z.S[i] = mint(xs[0]) * mint(xs[1]) * mint(xs[2]);
      }
    }
    return z;
  }

  // n(n+1)...(n+k-1)/k!
  static DP rising(u64 n, size_t k) {
    if (k == 0) return zeta(n);
    DP z(n);
    for (size_t i = 1; i <= z.K; i++) z.a[i] = rising_mint(i, k);
    for (size_t i = 1; i <= z.K; i++) z.s[i] = z.s[i - 1] + z.a[i];
    for (size_t i = 1; i <= z.L; i++) {
      u64 x = n / i;
      z.S[i] = rising_mint(x, k + 1);
    }
    return z;
  }
  // n(n-1)...(n-k+1)/k!
  static DP falling(u64 n, size_t k) {
    if (k == 0) return zeta(n);
    DP z(n);
    for (size_t i = 1; i <= z.K; i++) z.a[i] = falling_mint(i, k);
    for (size_t i = 1; i <= z.K; i++) z.s[i] = z.s[i - 1] + z.a[i];
    for (size_t i = 1; i <= z.L; i++) {
      u64 x = n / i;
      z.S[i] = falling_mint(x + 1, k + 1);  // -falling_mint(1, k + 1);
    }
    return z;
  }

  void calc_small_sum() {
    for (size_t i = 1; i <= K; i++) s[i] = s[i - 1] + a[i];
  }

  // += v*x^(-s)
  DP add(u64 x, mint v) {
    DP f(*this);
    if (1 <= x && x <= K) {
      f.a[x] += v;
      for (size_t i = x; i <= K; i++) f.s[i] += v;
    }
    for (size_t i = 1; i <= L; i++)
      if (x <= N / i) f.S[i] += v;
    return f;
  }
  // *= x^(-s)
  DP shift(u64 x) {
    DP f(N);
    for (size_t i = 1; i <= K / x; i++) f.a[i * x] = a[i];
    f.calc_small_sum();
    for (size_t i = 1; i <= L; i++) {
      u64 y = N / i / x;
      f.S[i] = y <= K ? s[y] : S[i * x];
    }
    return f;
  }
  DP inv() { return id(N) / DP(*this); }

  DP &operator+=(const DP &r) {
    assert(N == r.N);
    for (size_t i = 1; i <= K; i++) a[i] += r.a[i];
    for (size_t i = 1; i <= K; i++) s[i] += r.s[i];
    for (size_t i = 1; i <= L; i++) S[i] += r.S[i];
    return *this;
  }
  DP &operator-=(const DP &r) {
    assert(N == r.N);
    for (size_t i = 1; i <= K; i++) a[i] -= r.a[i];
    for (size_t i = 1; i <= K; i++) s[i] -= r.s[i];
    for (size_t i = 1; i <= L; i++) S[i] -= r.S[i];
    return *this;
  }
  DP &operator*=(const mint &r) {
    for (auto &v : a) v *= r;
    for (auto &v : s) v *= r;
    for (auto &v : S) v *= r;
    return *this;
  }
  DP &operator*=(const DP &r) {
    for (size_t i = 1; i <= L; i++) {
      mint v = 0;
      size_t m = sqrt(N / i);
      for (size_t x = 1; x <= m; x++) {
        u64 t = N / i / x;
        v += a[x] * (t <= K ? r.s[t] : r.S[i * x]);
      }
      for (size_t y = 1; y <= m; y++) {
        u64 t = N / i / y;
        v += r.a[y] * ((t <= K ? s[t] : S[i * y]) - s[m]);
      }
      S[i] = v;
    }
    for (size_t i = K; i >= 1; i--) {
      for (size_t j = 2; u64(i) * j <= K; j++) a[i * j] += a[i] * r.a[j];
      a[i] *= r.a[1];
    }
    calc_small_sum();
    return *this;
  }
  DP &operator/=(const DP &r) {
    mint inv = r.a[1].inv();
    for (size_t i = 1; i <= K; i++) {
      a[i] *= inv;
      for (size_t j = 2; u64(i) * j <= K; j++) a[i * j] -= a[i] * r.a[j];
      s[i] = s[i - 1] + a[i];
    }
    for (size_t i = L; i >= 1; i--) {
      size_t m = sqrt(N / i);
      for (size_t x = 1; x <= m; x++) {
        u64 t = N / i / x;
        S[i] -= a[x] * (t <= K ? r.s[t] : r.S[i * x]);
      }
      for (size_t y = 2; y <= m; y++) {
        u64 t = N / i / y;
        S[i] -= r.a[y] * ((t <= K ? s[t] : S[i * y]) - s[m]);
      }
      S[i] *= inv;
      S[i] += s[m];
    }
    return *this;
  }
  DP operator+(const DP &r) const { return DP(*this) += r; }
  DP operator-(const DP &r) const { return DP(*this) -= r; }
  DP operator*(const DP &r) const { return DP(*this) *= r; }
  DP operator*(const mint &r) const { return DP(*this) *= r; }
  DP operator/(const DP &r) const { return DP(*this) /= r; }

  // sparse
  DP &operator*=(const vector<pair<u64, mint>> &r) {
    DP f(N);
    for (auto [x, v] : r) {
      for (size_t i = 1; i <= K / x; i++) f.a[i * x] += v * a[i];
      for (size_t i = 1; i <= L; i++) {
        u64 t = N / i / x;
        f.S[i] += v * (t <= K ? s[t] : f.S[i * x]);
      }
    }
    f.calc_small_sum();
    return f;
  }

  friend ostream &operator<<(ostream &os, const DP &f) {
    os << "a: ";
    for (size_t i = 1; i <= f.K; i++) os << f.a[i].val() << ",\n"[i == f.K];
    os << "s: ";
    for (size_t i = 1; i <= f.K; i++) os << f.s[i].val() << ",\n"[i == f.K];
    os << "S: ";
    for (size_t i = f.L; i >= 1; i--)
      os << (f.N / i) << ":" << f.S[i].val() << ",\n"[i == 1];
    return os;
  };

 private:
  static mint rising_mint(i64 x, size_t k) {
    vector<i64> xs(k);
    iota(xs.begin(), xs.end(), x);
    for (i64 v = 2; v <= k; v++) {
      i64 w = v;
      for (auto &y : xs) {
        i64 g = gcd(w, y);
        w /= g;
        y /= g;
      }
      assert(w == 1);
    }
    mint z = 1;
    for (auto &y : xs) z *= y;
    return z;
  }
  static mint falling_mint(i64 x, size_t k) {
    return rising_mint(x - k + 1, k);
  }
};
#line 2 "number-theory/dirichlet-series-prefix-sum.hpp"

// https://maspypy.com/dirichlet-%E7%A9%8D%E3%81%A8%E3%80%81%E6%95%B0%E8%AB%96%E9%96%A2%E6%95%B0%E3%81%AE%E7%B4%AF%E7%A9%8D%E5%92%8C
/**
 * @brief Dirichlet 積の prefix sum
 */
template <class mint>
struct DirichletSeriesPrefixSum {
  using DP = DirichletSeriesPrefixSum<mint>;
  using u64 = uint64_t;

 public:
  u64 N;
  size_t K, L;
  vector<mint> a, s, S;

  DirichletSeriesPrefixSum(u64 n)
      : N(n),
        K(max(sqrt(N), pow(max(1.0, N / log(N)), 2.0 / 3))),
        L((N - 1) / K + 1) {
    a.assign(K + 1, 0);
    s.assign(K + 1, 0);
    S.assign(L + 1, 0);
  }
  DirichletSeriesPrefixSum(const DP &d)
      : N(d.N), K(d.K), L(d.L), a(d.a), s(d.s), S(d.S) {}
  static DP id(u64 n) {
    DP z(n);
    return z.add(1, 1);
  }
  // {floor(n/k):1<=k<=n}={x[0],...,x[m-1]}, x[i-1]<x[i]
  // a[i]=sum_{1<=j<=x[i]}f(j)
  DirichletSeriesPrefixSum(u64 n, const vector<mint> &a) {

  }

  // zeta(s-k)
  static DP zeta(u64 n, size_t k = 0) {
    assert(k <= 2);
    DP z(n);
    for (size_t i = 1; i <= z.K; i++) z.a[i] = mint(i).pow(k);
    for (size_t i = 1; i <= z.K; i++) z.s[i] = z.s[i - 1] + z.a[i];
    for (size_t i = 1; i <= z.L; i++) {
      u64 x = n / i;
      if (k == 0)
        z.S[i] = x;
      else if (k == 1)
        z.S[i] = mint(x | 1) * mint((x + 1) / 2);
      else if (k == 2) {
        array<u64, 3> xs{x, x + 1, 2 * x + 1};
        xs[x & 1] /= 2;
        xs[(3 - (x % 3)) % 3] /= 3;
        z.S[i] = mint(xs[0]) * mint(xs[1]) * mint(xs[2]);
      }
    }
    return z;
  }

  // n(n+1)...(n+k-1)/k!
  static DP rising(u64 n, size_t k) {
    if (k == 0) return zeta(n);
    DP z(n);
    for (size_t i = 1; i <= z.K; i++) z.a[i] = rising_mint(i, k);
    for (size_t i = 1; i <= z.K; i++) z.s[i] = z.s[i - 1] + z.a[i];
    for (size_t i = 1; i <= z.L; i++) {
      u64 x = n / i;
      z.S[i] = rising_mint(x, k + 1);
    }
    return z;
  }
  // n(n-1)...(n-k+1)/k!
  static DP falling(u64 n, size_t k) {
    if (k == 0) return zeta(n);
    DP z(n);
    for (size_t i = 1; i <= z.K; i++) z.a[i] = falling_mint(i, k);
    for (size_t i = 1; i <= z.K; i++) z.s[i] = z.s[i - 1] + z.a[i];
    for (size_t i = 1; i <= z.L; i++) {
      u64 x = n / i;
      z.S[i] = falling_mint(x + 1, k + 1);  // -falling_mint(1, k + 1);
    }
    return z;
  }

  void calc_small_sum() {
    for (size_t i = 1; i <= K; i++) s[i] = s[i - 1] + a[i];
  }

  // += v*x^(-s)
  DP add(u64 x, mint v) {
    DP f(*this);
    if (1 <= x && x <= K) {
      f.a[x] += v;
      for (size_t i = x; i <= K; i++) f.s[i] += v;
    }
    for (size_t i = 1; i <= L; i++)
      if (x <= N / i) f.S[i] += v;
    return f;
  }
  // *= x^(-s)
  DP shift(u64 x) {
    DP f(N);
    for (size_t i = 1; i <= K / x; i++) f.a[i * x] = a[i];
    f.calc_small_sum();
    for (size_t i = 1; i <= L; i++) {
      u64 y = N / i / x;
      f.S[i] = y <= K ? s[y] : S[i * x];
    }
    return f;
  }
  DP inv() { return id(N) / DP(*this); }

  DP &operator+=(const DP &r) {
    assert(N == r.N);
    for (size_t i = 1; i <= K; i++) a[i] += r.a[i];
    for (size_t i = 1; i <= K; i++) s[i] += r.s[i];
    for (size_t i = 1; i <= L; i++) S[i] += r.S[i];
    return *this;
  }
  DP &operator-=(const DP &r) {
    assert(N == r.N);
    for (size_t i = 1; i <= K; i++) a[i] -= r.a[i];
    for (size_t i = 1; i <= K; i++) s[i] -= r.s[i];
    for (size_t i = 1; i <= L; i++) S[i] -= r.S[i];
    return *this;
  }
  DP &operator*=(const mint &r) {
    for (auto &v : a) v *= r;
    for (auto &v : s) v *= r;
    for (auto &v : S) v *= r;
    return *this;
  }
  DP &operator*=(const DP &r) {
    for (size_t i = 1; i <= L; i++) {
      mint v = 0;
      size_t m = sqrt(N / i);
      for (size_t x = 1; x <= m; x++) {
        u64 t = N / i / x;
        v += a[x] * (t <= K ? r.s[t] : r.S[i * x]);
      }
      for (size_t y = 1; y <= m; y++) {
        u64 t = N / i / y;
        v += r.a[y] * ((t <= K ? s[t] : S[i * y]) - s[m]);
      }
      S[i] = v;
    }
    for (size_t i = K; i >= 1; i--) {
      for (size_t j = 2; u64(i) * j <= K; j++) a[i * j] += a[i] * r.a[j];
      a[i] *= r.a[1];
    }
    calc_small_sum();
    return *this;
  }
  DP &operator/=(const DP &r) {
    mint inv = r.a[1].inv();
    for (size_t i = 1; i <= K; i++) {
      a[i] *= inv;
      for (size_t j = 2; u64(i) * j <= K; j++) a[i * j] -= a[i] * r.a[j];
      s[i] = s[i - 1] + a[i];
    }
    for (size_t i = L; i >= 1; i--) {
      size_t m = sqrt(N / i);
      for (size_t x = 1; x <= m; x++) {
        u64 t = N / i / x;
        S[i] -= a[x] * (t <= K ? r.s[t] : r.S[i * x]);
      }
      for (size_t y = 2; y <= m; y++) {
        u64 t = N / i / y;
        S[i] -= r.a[y] * ((t <= K ? s[t] : S[i * y]) - s[m]);
      }
      S[i] *= inv;
      S[i] += s[m];
    }
    return *this;
  }
  DP operator+(const DP &r) const { return DP(*this) += r; }
  DP operator-(const DP &r) const { return DP(*this) -= r; }
  DP operator*(const DP &r) const { return DP(*this) *= r; }
  DP operator*(const mint &r) const { return DP(*this) *= r; }
  DP operator/(const DP &r) const { return DP(*this) /= r; }

  // sparse
  DP &operator*=(const vector<pair<u64, mint>> &r) {
    DP f(N);
    for (auto [x, v] : r) {
      for (size_t i = 1; i <= K / x; i++) f.a[i * x] += v * a[i];
      for (size_t i = 1; i <= L; i++) {
        u64 t = N / i / x;
        f.S[i] += v * (t <= K ? s[t] : f.S[i * x]);
      }
    }
    f.calc_small_sum();
    return f;
  }

  friend ostream &operator<<(ostream &os, const DP &f) {
    os << "a: ";
    for (size_t i = 1; i <= f.K; i++) os << f.a[i].val() << ",\n"[i == f.K];
    os << "s: ";
    for (size_t i = 1; i <= f.K; i++) os << f.s[i].val() << ",\n"[i == f.K];
    os << "S: ";
    for (size_t i = f.L; i >= 1; i--)
      os << (f.N / i) << ":" << f.S[i].val() << ",\n"[i == 1];
    return os;
  };

 private:
  static mint rising_mint(i64 x, size_t k) {
    vector<i64> xs(k);
    iota(xs.begin(), xs.end(), x);
    for (i64 v = 2; v <= k; v++) {
      i64 w = v;
      for (auto &y : xs) {
        i64 g = gcd(w, y);
        w /= g;
        y /= g;
      }
      assert(w == 1);
    }
    mint z = 1;
    for (auto &y : xs) z *= y;
    return z;
  }
  static mint falling_mint(i64 x, size_t k) {
    return rising_mint(x - k + 1, k);
  }
};
Back to top page