3
3
#include " cp-algo/util/checkpoint.hpp"
4
4
#include " cp-algo/util/bump_alloc.hpp"
5
5
#include " cp-algo/util/simd.hpp"
6
- #include " cp-algo/math/common .hpp"
6
+ #include " cp-algo/math/combinatorics .hpp"
7
7
#include " cp-algo/number_theory/modint.hpp"
8
8
#include < ranges>
9
9
@@ -37,38 +37,37 @@ namespace cp_algo::math {
37
37
t = mod - t -1 ;
38
38
y = t %2 ?1 : mod-1 ;
39
39
}
40
- int pw =0 ;
40
+ auto pw =32ull * (t + 1 ) ;
41
41
while (t > limit_reg) {
42
42
limit_odd =std::max (limit_odd, (t -1 ) /2 );
43
43
odd_args_per_block[(t -1 ) /2 / subblock].push_back ({int (i), (t -1 ) /2 });
44
44
t /=2 ;
45
45
pw += t;
46
46
}
47
47
reg_args_per_block[t / subblock].push_back ({int (i), t});
48
- y *=bpow ( base ( 2 ), pw );
48
+ y *=pow_fixed< base, 2 >( int (pw % (mod - 1 )) );
49
49
}
50
50
checkpoint (" init" );
51
- uint32_t b2x32 =( 1ULL << 32 ) % mod ;
51
+ base bi2x32 =pow_fixed<base, 2 >( 32 ). inv () ;
52
52
auto process = [&](int limit,auto &args_per_block,auto step,auto &&proj) {
53
53
base fact =1 ;
54
54
for (int b =0 ; b <= limit; b += accum * block) {
55
55
u32x8 cur[accum];
56
56
static std::array<u32x8, subblock> prods[accum];
57
57
for (int z =0 ; z < accum; z++) {
58
58
for (int j =0 ; j < simd_size; j++) {
59
+ #pragma GCC diagnostic push
60
+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
59
61
cur[z][j] =uint32_t (b + z * block + j * subblock);
60
62
cur[z][j] =proj (cur[z][j]);
61
63
prods[z][0 ][j] = cur[z][j] + !cur[z][j];
62
- #pragma GCC diagnostic push
63
- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
64
- cur[z][j] =uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
64
+ prods[z][0 ][j] =uint32_t (uint64_t (prods[z][0 ][j]) * bi2x32.getr () % mod);
65
65
#pragma GCC diagnostic pop
66
66
}
67
67
}
68
68
for (int i =1 ; i < block / simd_size; i++) {
69
69
for (int z =0 ; z < accum; z++) {
70
70
cur[z] += step;
71
- cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
72
71
prods[z][i] =montgomery_mul (prods[z][i -1 ], cur[z], mod, imod);
73
72
}
74
73
}
@@ -85,12 +84,12 @@ namespace cp_algo::math {
85
84
checkpoint (" mul ans" );
86
85
}
87
86
};
88
- uint32_t b2x33 = ( 1ULL << 33 ) % mod ;
89
- process (limit_reg, reg_args_per_block, b2x32, std::identity{ });
90
- process (limit_odd, odd_args_per_block, b2x33, []( uint32_t x) { return 2 * x + 1 ;} );
87
+ process (limit_reg, reg_args_per_block, 1 , std::identity{}) ;
88
+ process (limit_odd, odd_args_per_block, 2 , []( uint32_t x) { return 2 * x + 1 ; });
89
+ auto invs = bulk_invs<base>(res );
91
90
for (auto [i, x]: res | std::views::enumerate) {
92
91
if (args[i] >= mod /2 ) {
93
- x =x. inv () ;
92
+ x =invs[i] ;
94
93
}
95
94
}
96
95
checkpoint (" inv ans" );