Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbc21b44

Browse files
committed
move factorials to library file
1 parent0f96754 commitbc21b44

File tree

3 files changed

+107
-105
lines changed

3 files changed

+107
-105
lines changed

‎cp-algo/math/factorials.hpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#ifndef CP_ALGO_MATH_FACTORIALS_HPP
2+
#defineCP_ALGO_MATH_FACTORIALS_HPP
3+
#include"cp-algo/util/checkpoint.hpp"
4+
#include"cp-algo/util/bump_alloc.hpp"
5+
#include"cp-algo/util/simd.hpp"
6+
#include"cp-algo/math/common.hpp"
7+
#include"cp-algo/number_theory/modint.hpp"
8+
#include<ranges>
9+
10+
namespacecp_algo::math {
11+
template<bool use_bump_alloc =false,int maxn =100'000>
12+
autofacts(autoconst& args) {
13+
constexprint max_mod =1'000'000'000;
14+
constexprint accum =4;
15+
constexprint simd_size =8;
16+
constexprint block =1 <<18;
17+
constexprint subblock = block / simd_size;
18+
using base = std::decay_t<decltype(args[0])>;
19+
static_assert(modint_type<base>,"Base type must be a modint type");
20+
using T = std::array<int,2>;
21+
using alloc = std::conditional_t<use_bump_alloc,
22+
bump_alloc<T,30 * maxn>,
23+
big_alloc<T>>;
24+
std::basic_string<T, std::char_traits<T>, alloc> odd_args_per_block[max_mod / subblock];
25+
std::basic_string<T, std::char_traits<T>, alloc> reg_args_per_block[max_mod / subblock];
26+
constexprint limit_reg = max_mod /64;
27+
int limit_odd =0;
28+
29+
std::vector<base, big_alloc<base>>res(size(args),1);
30+
constint mod =base::mod();
31+
constint imod = -math::inv2(mod);
32+
for(auto [i, xy]:std::views::zip(args, res) | std::views::enumerate) {
33+
auto [x, y] = xy;
34+
int t = x.getr();
35+
if(t >= mod /2) {
36+
t = mod - t -1;
37+
y = t %2 ?1 : mod-1;
38+
}
39+
int pw =0;
40+
while(t > limit_reg) {
41+
limit_odd =std::max(limit_odd, (t -1) /2);
42+
odd_args_per_block[(t -1) /2 / subblock].push_back({int(i), (t -1) /2});
43+
t /=2;
44+
pw += t;
45+
}
46+
reg_args_per_block[t / subblock].push_back({int(i), t});
47+
y *=bpow(base(2), pw);
48+
}
49+
checkpoint("init");
50+
uint32_t b2x32 = (1ULL <<32) % mod;
51+
auto process = [&](int limit,auto &args_per_block,auto step,auto &&proj) {
52+
base fact =1;
53+
for(int b =0; b <= limit; b += accum * block) {
54+
u32x8 cur[accum];
55+
static std::array<u32x8, subblock> prods[accum];
56+
for(int z =0; z < accum; z++) {
57+
for(int j =0; j < simd_size; j++) {
58+
cur[z][j] =uint32_t(b + z * block + j * subblock);
59+
cur[z][j] =proj(cur[z][j]);
60+
prods[z][0][j] = cur[z][j] + !cur[z][j];
61+
#pragma GCC diagnostic push
62+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
63+
cur[z][j] =uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
64+
#pragma GCC diagnostic pop
65+
}
66+
}
67+
for(int i =1; i < block / simd_size; i++) {
68+
for(int z =0; z < accum; z++) {
69+
cur[z] += step;
70+
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
71+
prods[z][i] =montgomery_mul(prods[z][i -1], cur[z], mod, imod);
72+
}
73+
}
74+
checkpoint("inner loop");
75+
for(int z =0; z < accum; z++) {
76+
for(int j =0; j < simd_size; j++) {
77+
int bl = b + z * block + j * subblock;
78+
for(auto [i, x]: args_per_block[bl / subblock]) {
79+
res[i] *= fact * prods[z][x - bl][j];
80+
}
81+
fact *=base(prods[z].back()[j]);
82+
}
83+
}
84+
checkpoint("mul ans");
85+
}
86+
};
87+
uint32_t b2x33 = (1ULL <<33) % mod;
88+
process(limit_reg, reg_args_per_block, b2x32, std::identity{});
89+
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return2 * x +1;});
90+
for(auto [i, x]: res | std::views::enumerate) {
91+
if (args[i] >= mod /2) {
92+
x = x.inv();
93+
}
94+
}
95+
checkpoint("inv ans");
96+
return res;
97+
}
98+
}
99+
#endif// CP_ALGO_MATH_FACTORIALS_HPP

‎cp-algo/util/simd.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,9 @@ namespace cp_algo {
4141
[[gnu::always_inline]]inline u64x4low32(u64x4 x) {
4242
return x &uint32_t(-1);
4343
}
44-
[[gnu::always_inline]]inlineautorotr(auto x) {
45-
returndecltype(x)(__builtin_shufflevector(u32x8(x),u32x8(x),1,2,3,0,5,6,7,4));
44+
[[gnu::always_inline]]inlineautoswap_bytes(auto x) {
45+
returndecltype(x)(__builtin_shufflevector(u32x8(x),u32x8(x),1,0,3,2,5,4,7,6));
4646
}
47-
[[gnu::always_inline]]inlineautorotl(auto x) {
48-
returndecltype(x)(__builtin_shufflevector(u32x8(x),u32x8(x),3,0,1,2,7,4,5,6));
49-
}
50-
5147
[[gnu::always_inline]]inline u64x4montgomery_reduce(u64x4 x,uint32_t mod,uint32_t imod) {
5248
#ifdef __AVX2__
5349
auto x_ninv =u64x4(_mm256_mul_epu32(__m256i(x),__m256i() + imod));
@@ -56,7 +52,7 @@ namespace cp_algo {
5652
auto x_ninv = x * imod;
5753
x +=low32(x_ninv) * mod;
5854
#endif
59-
returnrotr(x);
55+
returnswap_bytes(x);
6056
}
6157

6258
[[gnu::always_inline]]inline u64x4montgomery_mul(u64x4 x, u64x4 y,uint32_t mod,uint32_t imod) {
@@ -68,7 +64,7 @@ namespace cp_algo {
6864
}
6965
[[gnu::always_inline]]inline u32x8montgomery_mul(u32x8 x, u32x8 y,uint32_t mod,uint32_t imod) {
7066
returnu32x8(montgomery_mul(u64x4(x),u64x4(y), mod, imod)) |
71-
u32x8(rotl(montgomery_mul(u64x4(rotr(x)),u64x4(rotr(y)), mod, imod)));
67+
u32x8(swap_bytes(montgomery_mul(u64x4(swap_bytes(x)),u64x4(swap_bytes(y)), mod, imod)));
7268
}
7369
[[gnu::always_inline]]inline dx4rotate_right(dx4 x) {
7470
staticconstexpr u64x4 shuffler = {3,0,1,2};

‎verify/simd/many_facts.test.cpp

Lines changed: 4 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,18 @@
11
// @brief Many Factorials
22
#definePROBLEM"https://judge.yosupo.jp/problem/many_factorials"
33
#pragma GCC optimize("Ofast,unroll-loops")
4+
#defineCP_ALGO_CHECKPOINT
45
#include<bits/stdc++.h>
5-
//#define CP_ALGO_CHECKPOINT
66
#include"blazingio/blazingio.min.hpp"
7-
#include"cp-algo/util/checkpoint.hpp"
8-
#include"cp-algo/util/simd.hpp"
9-
#include"cp-algo/util/bump_alloc.hpp"
10-
#include"cp-algo/math/common.hpp"
7+
#include"cp-algo/math/factorials.hpp"
118

129
usingnamespacestd;
13-
usingnamespacecp_algo;
14-
15-
constexprint mod =998244353;
16-
constexprint imod = -math::inv2(mod);
17-
18-
template<bool use_bump_alloc =false,int maxn =100'000>
19-
vector<int>facts(vector<int>const& args) {
20-
constexprint accum =4;
21-
constexprint simd_size =8;
22-
constexprint block =1 <<18;
23-
constexprint subblock = block / simd_size;
24-
using T = array<int,2>;
25-
using alloc =conditional_t<use_bump_alloc,
26-
bump_alloc<T,30 * maxn>,
27-
allocator<T>>;
28-
basic_string<T, char_traits<T>, alloc> odd_args_per_block[mod / subblock];
29-
basic_string<T, char_traits<T>, alloc> reg_args_per_block[mod / subblock];
30-
constexprint limit_reg = mod /64;
31-
int limit_odd =0;
32-
33-
vector<int>res(size(args),1);
34-
auto prod_mod = [&](uint64_t a,uint64_t b) {
35-
return (a * b) % mod;
36-
};
37-
for(auto [i, xy]:views::zip(args, res) | views::enumerate) {
38-
auto [x, y] = xy;
39-
auto t = x;
40-
if(t >= mod /2) {
41-
t = mod - t -1;
42-
y = t %2 ?1 : mod -1;
43-
}
44-
int pw =0;
45-
while(t > limit_reg) {
46-
limit_odd =max(limit_odd, (t -1) /2);
47-
odd_args_per_block[(t -1) /2 / subblock].push_back({int(i), (t -1) /2});
48-
t /=2;
49-
pw += t;
50-
}
51-
reg_args_per_block[t / subblock].push_back({int(i), t});
52-
y =int(y *math::bpow(2, pw,1ULL, prod_mod) % mod);
53-
}
54-
cp_algo::checkpoint("init");
55-
uint32_t b2x32 = (1ULL <<32) % mod;
56-
auto process = [&](int limit,auto &args_per_block,auto step,auto &&proj) {
57-
uint64_t fact =1;
58-
for(int b =0; b <= limit; b += accum * block) {
59-
u32x8 cur[accum];
60-
static array<u32x8, subblock> prods[accum];
61-
for(int z =0; z < accum; z++) {
62-
for(int j =0; j < simd_size; j++) {
63-
cur[z][j] =uint32_t(b + z * block + j * subblock);
64-
cur[z][j] =proj(cur[z][j]);
65-
prods[z][0][j] = cur[z][j] + !cur[z][j];
66-
#pragma GCC diagnostic push
67-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
68-
cur[z][j] =uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
69-
#pragma GCC diagnostic pop
70-
}
71-
}
72-
for(int i =1; i < block / simd_size; i++) {
73-
for(int z =0; z < accum; z++) {
74-
cur[z] += step;
75-
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
76-
prods[z][i] =montgomery_mul(prods[z][i -1], cur[z], mod, imod);
77-
}
78-
}
79-
cp_algo::checkpoint("inner loop");
80-
for(int z =0; z < accum; z++) {
81-
for(int j =0; j < simd_size; j++) {
82-
int bl = b + z * block + j * subblock;
83-
for(auto [i, x]: args_per_block[bl / subblock]) {
84-
auto ans = fact * prods[z][x - bl][j] % mod;
85-
res[i] =int(res[i] * ans % mod);
86-
}
87-
fact = fact * prods[z].back()[j] % mod;
88-
}
89-
}
90-
cp_algo::checkpoint("mul ans");
91-
}
92-
};
93-
uint32_t b2x33 = (1ULL <<33) % mod;
94-
process(limit_reg, reg_args_per_block, b2x32, identity{});
95-
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return2 * x +1;});
96-
for(auto [i, x]: res | views::enumerate) {
97-
if (args[i] >= mod /2) {
98-
x =int(math::bpow(x, mod -2,1ULL, prod_mod));
99-
}
100-
}
101-
cp_algo::checkpoint("inv ans");
102-
return res;
103-
}
10+
using base = cp_algo::math::modint<998244353>;
10411

10512
voidsolve() {
10613
int n;
10714
cin >> n;
108-
vector<int>args(n);
15+
vector<base>args(n);
10916
for(auto &x : args) {cin >> x;}
11017
cp_algo::checkpoint("read");
11118
auto res =facts(args);

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp