Skip to content

Commit ff12831

Browse files
authored
ModInt 現代化 (#310)
* modint: assume C++17 * Fix precalculation rule * Update workflow * modint: update nCr functions
1 parent 590ff63 commit ff12831

File tree

3 files changed

+83
-43
lines changed

3 files changed

+83
-43
lines changed

.github/workflows/verify.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,4 @@ jobs:
3434
GH_PAT: ${{ secrets.GH_PAT }}
3535
YUKICODER_TOKEN: ${{ secrets.YUKICODER_TOKEN }}
3636
DROPBOX_TOKEN: ${{ secrets.DROPBOX_TOKEN }}
37-
run: oj-verify all --jobs 2
37+
run: oj-verify all --jobs 2 --tle 20 --timeout 3600

modint.hpp

+79-39
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
#pragma once
2+
#include <cassert>
23
#include <iostream>
34
#include <set>
45
#include <vector>
56

67
template <int md> struct ModInt {
7-
#if __cplusplus >= 201402L
8-
#define MDCONST constexpr
9-
#else
10-
#define MDCONST
11-
#endif
128
using lint = long long;
13-
MDCONST static int mod() { return md; }
9+
constexpr static int mod() { return md; }
1410
static int get_primitive_root() {
1511
static int primitive_root = 0;
1612
if (!primitive_root) {
@@ -36,52 +32,53 @@ template <int md> struct ModInt {
3632
}
3733
int val_;
3834
int val() const noexcept { return val_; }
39-
MDCONST ModInt() : val_(0) {}
40-
MDCONST ModInt &_setval(lint v) { return val_ = (v >= md ? v - md : v), *this; }
41-
MDCONST ModInt(lint v) { _setval(v % md + md); }
42-
MDCONST explicit operator bool() const { return val_ != 0; }
43-
MDCONST ModInt operator+(const ModInt &x) const {
35+
constexpr ModInt() : val_(0) {}
36+
constexpr ModInt &_setval(lint v) { return val_ = (v >= md ? v - md : v), *this; }
37+
constexpr ModInt(lint v) { _setval(v % md + md); }
38+
constexpr explicit operator bool() const { return val_ != 0; }
39+
constexpr ModInt operator+(const ModInt &x) const {
4440
return ModInt()._setval((lint)val_ + x.val_);
4541
}
46-
MDCONST ModInt operator-(const ModInt &x) const {
42+
constexpr ModInt operator-(const ModInt &x) const {
4743
return ModInt()._setval((lint)val_ - x.val_ + md);
4844
}
49-
MDCONST ModInt operator*(const ModInt &x) const {
45+
constexpr ModInt operator*(const ModInt &x) const {
5046
return ModInt()._setval((lint)val_ * x.val_ % md);
5147
}
52-
MDCONST ModInt operator/(const ModInt &x) const {
48+
constexpr ModInt operator/(const ModInt &x) const {
5349
return ModInt()._setval((lint)val_ * x.inv().val() % md);
5450
}
55-
MDCONST ModInt operator-() const { return ModInt()._setval(md - val_); }
56-
MDCONST ModInt &operator+=(const ModInt &x) { return *this = *this + x; }
57-
MDCONST ModInt &operator-=(const ModInt &x) { return *this = *this - x; }
58-
MDCONST ModInt &operator*=(const ModInt &x) { return *this = *this * x; }
59-
MDCONST ModInt &operator/=(const ModInt &x) { return *this = *this / x; }
60-
friend MDCONST ModInt operator+(lint a, const ModInt &x) {
51+
constexpr ModInt operator-() const { return ModInt()._setval(md - val_); }
52+
constexpr ModInt &operator+=(const ModInt &x) { return *this = *this + x; }
53+
constexpr ModInt &operator-=(const ModInt &x) { return *this = *this - x; }
54+
constexpr ModInt &operator*=(const ModInt &x) { return *this = *this * x; }
55+
constexpr ModInt &operator/=(const ModInt &x) { return *this = *this / x; }
56+
friend constexpr ModInt operator+(lint a, const ModInt &x) {
6157
return ModInt()._setval(a % md + x.val_);
6258
}
63-
friend MDCONST ModInt operator-(lint a, const ModInt &x) {
59+
friend constexpr ModInt operator-(lint a, const ModInt &x) {
6460
return ModInt()._setval(a % md - x.val_ + md);
6561
}
66-
friend MDCONST ModInt operator*(lint a, const ModInt &x) {
62+
friend constexpr ModInt operator*(lint a, const ModInt &x) {
6763
return ModInt()._setval(a % md * x.val_ % md);
6864
}
69-
friend MDCONST ModInt operator/(lint a, const ModInt &x) {
65+
friend constexpr ModInt operator/(lint a, const ModInt &x) {
7066
return ModInt()._setval(a % md * x.inv().val() % md);
7167
}
72-
MDCONST bool operator==(const ModInt &x) const { return val_ == x.val_; }
73-
MDCONST bool operator!=(const ModInt &x) const { return val_ != x.val_; }
74-
MDCONST bool operator<(const ModInt &x) const {
68+
constexpr bool operator==(const ModInt &x) const { return val_ == x.val_; }
69+
constexpr bool operator!=(const ModInt &x) const { return val_ != x.val_; }
70+
constexpr bool operator<(const ModInt &x) const {
7571
return val_ < x.val_;
7672
} // To use std::map<ModInt, T>
7773
friend std::istream &operator>>(std::istream &is, ModInt &x) {
7874
lint t;
7975
return is >> t, x = ModInt(t), is;
8076
}
81-
MDCONST friend std::ostream &operator<<(std::ostream &os, const ModInt &x) {
77+
constexpr friend std::ostream &operator<<(std::ostream &os, const ModInt &x) {
8278
return os << x.val_;
8379
}
84-
MDCONST ModInt pow(lint n) const {
80+
81+
constexpr ModInt pow(lint n) const {
8582
ModInt ans = 1, tmp = *this;
8683
while (n) {
8784
if (n & 1) ans *= tmp;
@@ -90,9 +87,11 @@ template <int md> struct ModInt {
9087
return ans;
9188
}
9289

90+
static constexpr int cache_limit = std::min(md, 1 << 21);
9391
static std::vector<ModInt> facs, facinvs, invs;
94-
MDCONST static void _precalculation(int N) {
95-
int l0 = facs.size();
92+
93+
constexpr static void _precalculation(int N) {
94+
const int l0 = facs.size();
9695
if (N > md) N = md;
9796
if (N <= l0) return;
9897
facs.resize(N), facinvs.resize(N), invs.resize(N);
@@ -101,33 +100,74 @@ template <int md> struct ModInt {
101100
for (int i = N - 2; i >= l0; i--) facinvs[i] = facinvs[i + 1] * (i + 1);
102101
for (int i = N - 1; i >= l0; i--) invs[i] = facinvs[i] * facs[i - 1];
103102
}
104-
MDCONST ModInt inv() const {
105-
if (this->val_ < std::min(md >> 1, 1 << 21)) {
103+
104+
constexpr ModInt inv() const {
105+
if (this->val_ < cache_limit) {
106106
if (facs.empty()) facs = {1}, facinvs = {1}, invs = {0};
107107
while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
108108
return invs[this->val_];
109109
} else {
110110
return this->pow(md - 2);
111111
}
112112
}
113-
MDCONST ModInt fac() const {
113+
constexpr ModInt fac() const {
114114
while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
115115
return facs[this->val_];
116116
}
117-
MDCONST ModInt facinv() const {
117+
constexpr ModInt facinv() const {
118118
while (this->val_ >= int(facs.size())) _precalculation(facs.size() * 2);
119119
return facinvs[this->val_];
120120
}
121-
MDCONST ModInt doublefac() const {
121+
constexpr ModInt doublefac() const {
122122
lint k = (this->val_ + 1) / 2;
123123
return (this->val_ & 1) ? ModInt(k * 2).fac() / (ModInt(2).pow(k) * ModInt(k).fac())
124124
: ModInt(k).fac() * ModInt(2).pow(k);
125125
}
126-
MDCONST ModInt nCr(const ModInt &r) const {
127-
return (this->val_ < r.val_) ? 0 : this->fac() * (*this - r).facinv() * r.facinv();
126+
127+
constexpr ModInt nCr(int r) const {
128+
if (r < 0 or this->val_ < r) return ModInt(0);
129+
return this->fac() * (*this - r).facinv() * ModInt(r).facinv();
130+
}
131+
132+
constexpr ModInt nPr(int r) const {
133+
if (r < 0 or this->val_ < r) return ModInt(0);
134+
return this->fac() * (*this - r).facinv();
135+
}
136+
137+
static ModInt binom(int n, int r) {
138+
static long long bruteforce_times = 0;
139+
140+
if (r < 0 or n < r) return ModInt(0);
141+
if (n <= bruteforce_times or n < (int)facs.size()) return ModInt(n).nCr(r);
142+
143+
r = std::min(r, n - r);
144+
145+
ModInt ret = ModInt(r).facinv();
146+
for (int i = 0; i < r; ++i) ret *= n - i;
147+
bruteforce_times += r;
148+
149+
return ret;
150+
}
151+
152+
// Multinomial coefficient, (k_1 + k_2 + ... + k_m)! / (k_1! k_2! ... k_m!)
153+
// Complexity: O(sum(ks))
154+
template <class Vec> static ModInt multinomial(const Vec &ks) {
155+
ModInt ret{1};
156+
int sum = 0;
157+
for (int k : ks) {
158+
assert(k >= 0);
159+
ret *= ModInt(k).facinv(), sum += k;
160+
}
161+
return ret * ModInt(sum).fac();
128162
}
129-
MDCONST ModInt nPr(const ModInt &r) const {
130-
return (this->val_ < r.val_) ? 0 : this->fac() * (*this - r).facinv();
163+
164+
// Catalan number, C_n = binom(2n, n) / (n + 1)
165+
// C_0 = 1, C_1 = 1, C_2 = 2, C_3 = 5, C_4 = 14, ...
166+
// https://oeis.org/A000108
167+
// Complexity: O(n)
168+
static ModInt catalan(int n) {
169+
if (n < 0) return ModInt(0);
170+
return ModInt(n * 2).fac() * ModInt(n + 1).facinv() * ModInt(n).facinv();
131171
}
132172

133173
ModInt sqrt() const {

number/modint_runtime.hpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ struct ModIntRuntime {
122122
: ModIntRuntime(k).fac() * ModIntRuntime(2).pow(k);
123123
}
124124

125-
ModIntRuntime nCr(const ModIntRuntime &r) const {
126-
return (this->val_ < r.val_) ? ModIntRuntime(0)
127-
: this->fac() / ((*this - r).fac() * r.fac());
125+
ModIntRuntime nCr(int r) const {
126+
if (r < 0 or this->val_ < r) return ModIntRuntime(0);
127+
return this->fac() / ((*this - r).fac() * ModIntRuntime(r).fac());
128128
}
129129

130130
ModIntRuntime sqrt() const {

0 commit comments

Comments
 (0)