2
2
#define PROBLEM " https://judge.yosupo.jp/problem/many_factorials"
3
3
#pragma GCC optimize("Ofast,unroll-loops")
4
4
#include < bits/stdc++.h>
5
- #define CP_ALGO_CHECKPOINT
6
- #include " cp-algo/util/checkpoint.hpp"
5
+ // #define CP_ALGO_CHECKPOINT
7
6
#include " blazingio/blazingio.min.hpp"
7
+ #include " cp-algo/util/checkpoint.hpp"
8
8
#include " cp-algo/util/simd.hpp"
9
9
#include " cp-algo/math/common.hpp"
10
10
@@ -14,84 +14,97 @@ using namespace cp_algo;
14
14
constexpr int mod =998244353 ;
15
15
constexpr int imod = -math::inv2(mod);
16
16
17
- void facts_inplace (vector<int > &args) {
18
- constexpr int block =1 <<16 ;
19
- static basic_string<size_t > args_per_block[mod / block];
20
- uint64_t limit =0 ;
21
- for (auto [i, x]: args | views::enumerate) {
22
- if (x < mod /2 ) {
23
- limit =max (limit,uint64_t (x));
24
- args_per_block[x / block].push_back (i);
25
- }else {
26
- limit =max (limit,uint64_t (mod - x -1 ));
27
- args_per_block[(mod - x -1 ) / block].push_back (i);
17
+ vector<int >facts (vector<int >const & args) {
18
+ constexpr int accum =4 ;
19
+ constexpr int simd_size =8 ;
20
+ constexpr int block =1 <<18 ;
21
+ constexpr int subblock = block / simd_size;
22
+ static basic_string<array<int ,2 >> odd_args_per_block[mod / subblock];
23
+ static basic_string<array<int ,2 >> reg_args_per_block[mod / subblock];
24
+ constexpr int limit_reg = mod /64 ;
25
+ int limit_odd =0 ;
26
+
27
+ vector<int >res (size (args),1 );
28
+ auto prod_mod = [&](uint64_t a,uint64_t b) {
29
+ return (a * b) % mod;
30
+ };
31
+ for (auto [i, xy]:views::zip (args, res) | views::enumerate) {
32
+ auto [x, y] = xy;
33
+ auto t = x;
34
+ if (t >= mod /2 ) {
35
+ t = mod - t -1 ;
36
+ y = t %2 ?1 : mod -1 ;
37
+ }
38
+ int pw =0 ;
39
+ while (t > limit_reg) {
40
+ limit_odd =max (limit_odd, (t -1 ) /2 );
41
+ odd_args_per_block[(t -1 ) /2 / subblock].push_back ({int (i), (t -1 ) /2 });
42
+ t /=2 ;
43
+ pw += t;
28
44
}
45
+ reg_args_per_block[t / subblock].push_back ({int (i), t});
46
+ y =int (y *math::bpow (2 , pw,1ULL , prod_mod) % mod);
29
47
}
30
48
cp_algo::checkpoint (" init" );
31
49
uint32_t b2x32 = (1ULL <<32 ) % mod;
32
- uint64_t fact =1 ;
33
- const int accum =4 ;
34
- const int simd_size =8 ;
35
- for (uint64_t b =0 ; b <= limit; b += accum * block) {
36
- u32x8 cur[accum];
37
- static array<u32x8, block / simd_size> prods[accum];
38
- for (int z =0 ; z < accum; z++) {
39
- for (int j =0 ; j < simd_size; j++) {
40
- cur[z][j] =uint32_t (b + z * block + j * (block / simd_size));
41
- prods[z][0 ][j] = cur[z][j] + !(b || z || j);
42
- #pragma GCC diagnostic push
43
- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
44
- cur[z][j] =uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
45
- #pragma GCC diagnostic pop
46
- }
47
- }
48
- for (int i =1 ; i < block / simd_size; i++) {
50
+ auto process = [&](int limit,auto &args_per_block,auto step,auto &&proj) {
51
+ uint64_t fact =1 ;
52
+ for (int b =0 ; b <= limit; b += accum * block) {
53
+ u32x8 cur[accum];
54
+ static array<u32x8, subblock> prods[accum];
49
55
for (int z =0 ; z < accum; z++) {
50
- cur[z] += b2x32;
51
- cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
52
- prods[z][i] =montgomery_mul (prods[z][i -1 ], cur[z], mod, imod);
53
- }
54
- }
55
- cp_algo::checkpoint (" inner loop" );
56
- for (int z =0 ; z < accum; z++) {
57
- uint64_t bl = b + z * block;
58
- for (auto i: args_per_block[bl / block]) {
59
- size_t x = args[i];
60
- if (x >= mod /2 ) {
61
- x = mod - x -1 ;
62
- }
63
- x -= bl;
64
- auto pre_blocks = x / (block / simd_size);
65
- auto in_block = x % (block / simd_size);
66
- auto ans = fact * prods[z][in_block][pre_blocks] % mod;
67
- for (size_t j =0 ; j < pre_blocks; j++) {
68
- ans = ans * prods[z].back ()[j] % mod;
56
+ for (int j =0 ; j < simd_size; j++) {
57
+ cur[z][j] =uint32_t (b + z * block + j * subblock);
58
+ cur[z][j] =proj (cur[z][j]);
59
+ prods[z][0 ][j] = cur[z][j] + !cur[z][j];
60
+ #pragma GCC diagnostic push
61
+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
62
+ cur[z][j] =uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
63
+ #pragma GCC diagnostic pop
69
64
}
70
- if (args[i] >= mod /2 ) {
71
- ans =math::bpow (ans, mod -2 ,1ULL , [](auto a,auto b){return a * b % mod;});
72
- args[i] =int (x %2 ? ans : mod - ans);
73
- }else {
74
- args[i] =int (ans);
65
+ }
66
+ for (int i =1 ; i < block / simd_size; i++) {
67
+ for (int z =0 ; z < accum; z++) {
68
+ cur[z] += step;
69
+ cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
70
+ prods[z][i] =montgomery_mul (prods[z][i -1 ], cur[z], mod, imod);
75
71
}
76
72
}
77
- args_per_block[bl / block].clear ();
78
- for (int j =0 ; j < simd_size; j++) {
79
- fact = fact * prods[z].back ()[j] % mod;
73
+ cp_algo::checkpoint (" inner loop" );
74
+ for (int z =0 ; z < accum; z++) {
75
+ for (int j =0 ; j < simd_size; j++) {
76
+ int bl = b + z * block + j * subblock;
77
+ for (auto [i, x]: args_per_block[bl / subblock]) {
78
+ auto ans = fact * prods[z][x - bl][j] % mod;
79
+ res[i] =int (res[i] * ans % mod);
80
+ }
81
+ fact = fact * prods[z].back ()[j] % mod;
82
+ }
80
83
}
84
+ cp_algo::checkpoint (" mul ans" );
85
+ }
86
+ };
87
+ uint32_t b2x33 = (1ULL <<33 ) % mod;
88
+ process (limit_reg, reg_args_per_block, b2x32, identity{});
89
+ process (limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x +1 ;});
90
+ for (auto [i, x]: res | views::enumerate) {
91
+ if (args[i] >= mod /2 ) {
92
+ x =int (math::bpow (x, mod -2 ,1ULL , prod_mod));
81
93
}
82
- cp_algo::checkpoint (" write ans" );
83
94
}
95
+ cp_algo::checkpoint (" inv ans" );
96
+ return res;
84
97
}
85
98
86
99
void solve () {
87
100
int n;
88
101
cin >> n;
89
102
vector<int >args (n);
90
103
for (auto &x : args) {cin >> x;}
91
- cp_algo::checkpoint (" input read" );
92
- facts_inplace (args);
93
- for (auto it:args ) {cout << it <<" \n " ;}
94
- cp_algo::checkpoint (" output written " );
104
+ cp_algo::checkpoint (" read" );
105
+ auto res = facts (args);
106
+ for (auto it:res ) {cout << it <<" \n " ;}
107
+ cp_algo::checkpoint (" write " );
95
108
cp_algo::checkpoint<1 >();
96
109
}
97
110