Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitd131d01

Browse files
committed
update many_facts
1 parent6ef9873 commitd131d01

File tree

1 file changed

+75
-62
lines changed

1 file changed

+75
-62
lines changed

‎verify/simd/many_facts.test.cpp

Lines changed: 75 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#definePROBLEM"https://judge.yosupo.jp/problem/many_factorials"
33
#pragma GCC optimize("Ofast,unroll-loops")
44
#include<bits/stdc++.h>
5-
#defineCP_ALGO_CHECKPOINT
6-
#include"cp-algo/util/checkpoint.hpp"
5+
//#define CP_ALGO_CHECKPOINT
76
#include"blazingio/blazingio.min.hpp"
7+
#include"cp-algo/util/checkpoint.hpp"
88
#include"cp-algo/util/simd.hpp"
99
#include"cp-algo/math/common.hpp"
1010

@@ -14,84 +14,97 @@ using namespace cp_algo;
1414
constexprint mod =998244353;
1515
constexprint imod = -math::inv2(mod);
1616

17-
voidfacts_inplace(vector<int> &args) {
18-
constexprint block =1 <<16;
19-
static basic_string<size_t> args_per_block[mod / block];
20-
uint64_t limit =0;
21-
for(auto [i, x]: args | views::enumerate) {
22-
if(x < mod /2) {
23-
limit =max(limit,uint64_t(x));
24-
args_per_block[x / block].push_back(i);
25-
}else {
26-
limit =max(limit,uint64_t(mod - x -1));
27-
args_per_block[(mod - x -1) / block].push_back(i);
17+
vector<int>facts(vector<int>const& args) {
18+
constexprint accum =4;
19+
constexprint simd_size =8;
20+
constexprint block =1 <<18;
21+
constexprint subblock = block / simd_size;
22+
static basic_string<array<int,2>> odd_args_per_block[mod / subblock];
23+
static basic_string<array<int,2>> reg_args_per_block[mod / subblock];
24+
constexprint limit_reg = mod /64;
25+
int limit_odd =0;
26+
27+
vector<int>res(size(args),1);
28+
auto prod_mod = [&](uint64_t a,uint64_t b) {
29+
return (a * b) % mod;
30+
};
31+
for(auto [i, xy]:views::zip(args, res) | views::enumerate) {
32+
auto [x, y] = xy;
33+
auto t = x;
34+
if(t >= mod /2) {
35+
t = mod - t -1;
36+
y = t %2 ?1 : mod -1;
37+
}
38+
int pw =0;
39+
while(t > limit_reg) {
40+
limit_odd =max(limit_odd, (t -1) /2);
41+
odd_args_per_block[(t -1) /2 / subblock].push_back({int(i), (t -1) /2});
42+
t /=2;
43+
pw += t;
2844
}
45+
reg_args_per_block[t / subblock].push_back({int(i), t});
46+
y =int(y *math::bpow(2, pw,1ULL, prod_mod) % mod);
2947
}
3048
cp_algo::checkpoint("init");
3149
uint32_t b2x32 = (1ULL <<32) % mod;
32-
uint64_t fact =1;
33-
constint accum =4;
34-
constint simd_size =8;
35-
for(uint64_t b =0; b <= limit; b += accum * block) {
36-
u32x8 cur[accum];
37-
static array<u32x8, block / simd_size> prods[accum];
38-
for(int z =0; z < accum; z++) {
39-
for(int j =0; j < simd_size; j++) {
40-
cur[z][j] =uint32_t(b + z * block + j * (block / simd_size));
41-
prods[z][0][j] = cur[z][j] + !(b || z || j);
42-
#pragma GCC diagnostic push
43-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
44-
cur[z][j] =uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
45-
#pragma GCC diagnostic pop
46-
}
47-
}
48-
for(int i =1; i < block / simd_size; i++) {
50+
auto process = [&](int limit,auto &args_per_block,auto step,auto &&proj) {
51+
uint64_t fact =1;
52+
for(int b =0; b <= limit; b += accum * block) {
53+
u32x8 cur[accum];
54+
static array<u32x8, subblock> prods[accum];
4955
for(int z =0; z < accum; z++) {
50-
cur[z] += b2x32;
51-
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
52-
prods[z][i] =montgomery_mul(prods[z][i -1], cur[z], mod, imod);
53-
}
54-
}
55-
cp_algo::checkpoint("inner loop");
56-
for(int z =0; z < accum; z++) {
57-
uint64_t bl = b + z * block;
58-
for(auto i: args_per_block[bl / block]) {
59-
size_t x = args[i];
60-
if(x >= mod /2) {
61-
x = mod - x -1;
62-
}
63-
x -= bl;
64-
auto pre_blocks = x / (block / simd_size);
65-
auto in_block = x % (block / simd_size);
66-
auto ans = fact * prods[z][in_block][pre_blocks] % mod;
67-
for(size_t j =0; j < pre_blocks; j++) {
68-
ans = ans * prods[z].back()[j] % mod;
56+
for(int j =0; j < simd_size; j++) {
57+
cur[z][j] =uint32_t(b + z * block + j * subblock);
58+
cur[z][j] =proj(cur[z][j]);
59+
prods[z][0][j] = cur[z][j] + !cur[z][j];
60+
#pragma GCC diagnostic push
61+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
62+
cur[z][j] =uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
63+
#pragma GCC diagnostic pop
6964
}
70-
if(args[i] >= mod /2) {
71-
ans =math::bpow(ans, mod -2,1ULL, [](auto a,auto b){return a * b % mod;});
72-
args[i] =int(x %2 ? ans : mod - ans);
73-
}else {
74-
args[i] =int(ans);
65+
}
66+
for(int i =1; i < block / simd_size; i++) {
67+
for(int z =0; z < accum; z++) {
68+
cur[z] += step;
69+
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
70+
prods[z][i] =montgomery_mul(prods[z][i -1], cur[z], mod, imod);
7571
}
7672
}
77-
args_per_block[bl / block].clear();
78-
for(int j =0; j < simd_size; j++) {
79-
fact = fact * prods[z].back()[j] % mod;
73+
cp_algo::checkpoint("inner loop");
74+
for(int z =0; z < accum; z++) {
75+
for(int j =0; j < simd_size; j++) {
76+
int bl = b + z * block + j * subblock;
77+
for(auto [i, x]: args_per_block[bl / subblock]) {
78+
auto ans = fact * prods[z][x - bl][j] % mod;
79+
res[i] =int(res[i] * ans % mod);
80+
}
81+
fact = fact * prods[z].back()[j] % mod;
82+
}
8083
}
84+
cp_algo::checkpoint("mul ans");
85+
}
86+
};
87+
uint32_t b2x33 = (1ULL <<33) % mod;
88+
process(limit_reg, reg_args_per_block, b2x32, identity{});
89+
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return2 * x +1;});
90+
for(auto [i, x]: res | views::enumerate) {
91+
if (args[i] >= mod /2) {
92+
x =int(math::bpow(x, mod -2,1ULL, prod_mod));
8193
}
82-
cp_algo::checkpoint("write ans");
8394
}
95+
cp_algo::checkpoint("inv ans");
96+
return res;
8497
}
8598

8699
voidsolve() {
87100
int n;
88101
cin >> n;
89102
vector<int>args(n);
90103
for(auto &x : args) {cin >> x;}
91-
cp_algo::checkpoint("inputread");
92-
facts_inplace(args);
93-
for(auto it:args) {cout << it <<"\n";}
94-
cp_algo::checkpoint("output written");
104+
cp_algo::checkpoint("read");
105+
auto res =facts(args);
106+
for(auto it:res) {cout << it <<"\n";}
107+
cp_algo::checkpoint("write");
95108
cp_algo::checkpoint<1>();
96109
}
97110

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp