Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit17f8c32

Browse files
nickggfacebook-github-bot
authored andcommitted
[NNC] IRSimplifier rules for Compare and Mod (#46412)
Summary:Adds new rules to the NNC IRSimplifier to take care of the following cases:* Comparisons which are symbolic but have a constant difference. E.g. this is most useful in cases like `if (x > x + 4) ...` which we can now eliminate.* Simplification of `Mod` nodes, including simple rules such as `0 % x` and `x % 1`, but also factorization of both sides to find common symbolic multiples. E.g. `(x * y) % x` can be cancelled out to `0`.See tests for many more examples!Pull Requestresolved:#46412Reviewed By: navahgarDifferential Revision: D24396151Pulled By: nickggfbshipit-source-id: abb954dc930867d62010dcbcd8a4701430733715
1 parenta06b95b commit17f8c32

File tree

4 files changed

+442
-18
lines changed

4 files changed

+442
-18
lines changed

‎test/cpp/tensorexpr/test_simplify.cpp‎

Lines changed: 303 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,12 +1120,130 @@ void testSimplifyDiv() {
11201120

11211121
IS_VAR_WITH_NAME(simplified.node(),"x");
11221122
}
1123+
}
1124+
1125+
voidtestSimplifyMod() {
1126+
KernelScope kernel_scope;
1127+
VarHandlex("x",kInt);
1128+
VarHandley("y",kInt);
1129+
VarHandlez("z",kInt);
1130+
1131+
{
1132+
// Constant folding works.
1133+
ExprHandle body =ExprHandle(10) %8;
1134+
ExprHandle simplified =IRSimplifier::simplify(body);
1135+
IS_IMM_WITH_VAL(Int, simplified.node(),2);
1136+
}
11231137

11241138
{
1125-
ExprHandle body = x / x;
1139+
// x % x => 0
1140+
ExprHandle body = x % x;
11261141
ExprHandle simplified =IRSimplifier::simplify(body);
1142+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1143+
}
11271144

1128-
IS_IMM_WITH_VAL(Int, simplified.node(),1);
1145+
{
1146+
// 0 % x => 0
1147+
ExprHandle body =ExprHandle(0) % x;
1148+
ExprHandle simplified =IRSimplifier::simplify(body);
1149+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1150+
}
1151+
1152+
{
1153+
// x % 1 => 0
1154+
ExprHandle body = x %1;
1155+
ExprHandle simplified =IRSimplifier::simplify(body);
1156+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1157+
}
1158+
1159+
{
1160+
// Doesn't change unknown mods.
1161+
// x % y => x % y
1162+
ExprHandle body = x % y;
1163+
ExprHandle simplified =IRSimplifier::simplify(body);
1164+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1165+
IS_VAR_WITH_NAME(mod->lhs(),"x");
1166+
IS_VAR_WITH_NAME(mod->rhs(),"y");
1167+
}
1168+
1169+
{
1170+
// don't touch if RHS is unknown.
1171+
// 4 % x => 4 % x
1172+
ExprHandle body =ExprHandle(4) % x;
1173+
ExprHandle simplified =IRSimplifier::simplify(body);
1174+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1175+
IS_IMM_WITH_VAL(Int, mod->lhs(),4);
1176+
IS_VAR_WITH_NAME(mod->rhs(),"x");
1177+
}
1178+
1179+
{
1180+
// don't touch if LHS is unknown.
1181+
// x % 4 => x % 4
1182+
ExprHandle body = x %4;
1183+
ExprHandle simplified =IRSimplifier::simplify(body);
1184+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1185+
IS_VAR_WITH_NAME(mod->lhs(),"x");
1186+
IS_IMM_WITH_VAL(Int, mod->rhs(),4);
1187+
}
1188+
1189+
{
1190+
// if LHS is a multiple of RHS, mod is zero.
1191+
// 2 * x % x => 0
1192+
ExprHandle body = (x *2) % x;
1193+
ExprHandle simplified =IRSimplifier::simplify(body);
1194+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1195+
}
1196+
1197+
{
1198+
// true even if the multiple is not constant.
1199+
// x * y % x => 0
1200+
ExprHandle body = (x * y) % x;
1201+
ExprHandle simplified =IRSimplifier::simplify(body);
1202+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1203+
}
1204+
1205+
{
1206+
// true with multiple unknown values in LHS.
1207+
// x * y * z % x => 0
1208+
ExprHandle body = (x * y * z) % x;
1209+
ExprHandle simplified =IRSimplifier::simplify(body);
1210+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1211+
}
1212+
1213+
{
1214+
// true if the denom is compound.
1215+
// x * y * z % y * z => 0
1216+
ExprHandle body = (x * y * z) % (y * z);
1217+
ExprHandle simplified =IRSimplifier::simplify(body);
1218+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1219+
}
1220+
1221+
{
1222+
// Sanity check true with scalars that are multiples.
1223+
// 12 * x % 4 => 0
1224+
ExprHandle body = (x *12) %4;
1225+
ExprHandle simplified =IRSimplifier::simplify(body);
1226+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
1227+
}
1228+
1229+
{
1230+
// Sanity check not true if the smaller scalar is on LHS.
1231+
// 4 * x % 12 => 4 * x % 12
1232+
ExprHandle body = (x *4) %12;
1233+
ExprHandle simplified =IRSimplifier::simplify(body);
1234+
IS_NODE_WITH_NAME(Mod, simplified.node(), mod);
1235+
IS_NODE_WITH_NAME(Mul, mod->lhs(), mul);
1236+
IS_IMM_WITH_VAL(Int, mul->lhs(),4);
1237+
IS_VAR_WITH_NAME(mul->rhs(),"x");
1238+
IS_IMM_WITH_VAL(Int, mod->rhs(),12);
1239+
}
1240+
1241+
{
1242+
// Both scalar and symbolic in multiple.
1243+
// (6 * x * y) % (3 * x * y) => 0
1244+
ExprHandle body = (ExprHandle(6) * x * y) % (x * y *3);
1245+
ExprHandle simplified =IRSimplifier::simplify(body);
1246+
IS_IMM_WITH_VAL(Int, simplified.node(),0);
11291247
}
11301248
}
11311249

@@ -2807,6 +2925,189 @@ void testSimplifyEliminateEmptyCond() {
28072925
}
28082926
}
28092927

2928+
voidtestSimplifyConstantComparisons() {
2929+
KernelScope kernel_scope;
2930+
2931+
auto ComparisonTest =
2932+
[](ExprHandle a, ExprHandle b, CompareSelectOperation op,int result) {
2933+
ExprHandle body =CompareSelect::make(a, b, op);
2934+
ExprHandle simplified =IRSimplifier::simplify(body);
2935+
IS_IMM_WITH_VAL(Int, simplified.node(), result);
2936+
};
2937+
2938+
// Equals.
2939+
ComparisonTest(2,2,kEQ,1);
2940+
ComparisonTest(1,2,kEQ,0);
2941+
ComparisonTest(2,1,kEQ,0);
2942+
2943+
// Greater than.
2944+
ComparisonTest(2,2,kGT,0);
2945+
ComparisonTest(1,2,kGT,0);
2946+
ComparisonTest(2,1,kGT,1);
2947+
2948+
// Greater or Equal.
2949+
ComparisonTest(2,2,kGE,1);
2950+
ComparisonTest(1,2,kGE,0);
2951+
ComparisonTest(2,1,kGE,1);
2952+
2953+
// Less Than.
2954+
ComparisonTest(2,2,kLT,0);
2955+
ComparisonTest(1,2,kLT,1);
2956+
ComparisonTest(2,1,kLT,0);
2957+
2958+
// Less or Equal.
2959+
ComparisonTest(2,2,kLE,1);
2960+
ComparisonTest(1,2,kLE,1);
2961+
ComparisonTest(2,1,kLE,0);
2962+
2963+
// Not equal.
2964+
ComparisonTest(2,2,kNE,0);
2965+
ComparisonTest(1,2,kNE,1);
2966+
ComparisonTest(2,1,kNE,1);
2967+
2968+
// With specified results:
2969+
ExprHandle body =CompareSelect::make(2,2,5,42,kNE);
2970+
ExprHandle simplified =IRSimplifier::simplify(body);
2971+
IS_IMM_WITH_VAL(Int, simplified.node(),42);
2972+
}
2973+
2974+
voidtestSimplifySymbolicComparisons() {
2975+
KernelScope kernel_scope;
2976+
VarHandlex("x",kInt);
2977+
VarHandley("y",kInt);
2978+
2979+
auto TookTrueBranch = [](ExprHandle a) {IS_IMM_WITH_VAL(Int, a.node(),1); };
2980+
auto TookFalseBranch = [](ExprHandle a) {
2981+
IS_IMM_WITH_VAL(Int, a.node(),0);
2982+
};
2983+
2984+
// EQ
2985+
2986+
// x == x => 1
2987+
ExprHandle body =CompareSelect::make(x, x,kEQ);
2988+
TookTrueBranch(IRSimplifier::simplify(body));
2989+
2990+
// x == x+1 => 0
2991+
body =CompareSelect::make(x, x +1,kEQ);
2992+
TookFalseBranch(IRSimplifier::simplify(body));
2993+
2994+
// x == x * 2 cannot simplify since we don't know x is nonzero.
2995+
body =CompareSelect::make(x, x *2,kEQ);
2996+
IS_NODE(CompareSelect,IRSimplifier::simplify(body).node());
2997+
2998+
// x == x * 1 => 1
2999+
body =CompareSelect::make(x, x *1,kEQ);
3000+
TookTrueBranch(IRSimplifier::simplify(body));
3001+
3002+
{
3003+
// x == y => x == y
3004+
body =CompareSelect::make(x, y,kEQ);
3005+
ExprHandle simplified =IRSimplifier::simplify(body);
3006+
IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3007+
ASSERT_EQ(cmp->compare_select_op(),kEQ);
3008+
IS_VAR_WITH_NAME(cmp->lhs(),"x");
3009+
IS_VAR_WITH_NAME(cmp->rhs(),"y");
3010+
}
3011+
3012+
{
3013+
// x == 5 => x == 5
3014+
body =CompareSelect::make(x,5,kEQ);
3015+
ExprHandle simplified =IRSimplifier::simplify(body);
3016+
IS_NODE_WITH_NAME(CompareSelect, simplified.node(), cmp);
3017+
ASSERT_EQ(cmp->compare_select_op(),kEQ);
3018+
IS_VAR_WITH_NAME(cmp->lhs(),"x");
3019+
IS_IMM_WITH_VAL(Int, cmp->rhs(),5);
3020+
}
3021+
3022+
// GT
3023+
3024+
// x+1 > x => 1
3025+
body =CompareSelect::make(x +1, x,kGT);
3026+
TookTrueBranch(IRSimplifier::simplify(body));
3027+
3028+
// x > x + 1 => 0
3029+
body =CompareSelect::make(x, x +1,kGT);
3030+
TookFalseBranch(IRSimplifier::simplify(body));
3031+
3032+
// x > x - 1 => 1
3033+
body =CompareSelect::make(x, x -1,kGT);
3034+
TookTrueBranch(IRSimplifier::simplify(body));
3035+
3036+
// x - 1 > x => 0
3037+
body =CompareSelect::make(x -1, x,kGT);
3038+
TookFalseBranch(IRSimplifier::simplify(body));
3039+
3040+
// x > x => 0
3041+
body =CompareSelect::make(x, x,kGT);
3042+
TookFalseBranch(IRSimplifier::simplify(body));
3043+
3044+
// x * 2 > x => x * 2 > x
3045+
// since we don't know the sign of x.
3046+
body =CompareSelect::make(x *2, x,kGT);
3047+
IS_NODE(CompareSelect,IRSimplifier::simplify(body).node());
3048+
3049+
// GE
3050+
3051+
// x+1 >= x => 1
3052+
body =CompareSelect::make(x +1, x,kGE);
3053+
TookTrueBranch(IRSimplifier::simplify(body));
3054+
3055+
// x >= x + 1 => 0
3056+
body =CompareSelect::make(x, x +1,kGE);
3057+
TookFalseBranch(IRSimplifier::simplify(body));
3058+
3059+
// x >= x => 1
3060+
body =CompareSelect::make(x, x,kGE);
3061+
TookTrueBranch(IRSimplifier::simplify(body));
3062+
3063+
// x * 2 >= x => x * 2 >= x
3064+
// since we don't know the sign of x.
3065+
body =CompareSelect::make(x *2, x,kGE);
3066+
IS_NODE(CompareSelect,IRSimplifier::simplify(body).node());
3067+
3068+
// LT
3069+
3070+
// x+1 < x => 0
3071+
body =CompareSelect::make(x +1, x,kLT);
3072+
TookFalseBranch(IRSimplifier::simplify(body));
3073+
3074+
// x < x + 1 => 1
3075+
body =CompareSelect::make(x, x +1,kLT);
3076+
TookTrueBranch(IRSimplifier::simplify(body));
3077+
3078+
// x < x => 0
3079+
body =CompareSelect::make(x, x,kLT);
3080+
TookFalseBranch(IRSimplifier::simplify(body));
3081+
3082+
// LE
3083+
3084+
// x+1 <= x => 0
3085+
body =CompareSelect::make(x +1, x,kLE);
3086+
TookFalseBranch(IRSimplifier::simplify(body));
3087+
3088+
// x <= x + 1 => 1
3089+
body =CompareSelect::make(x, x +1,kLE);
3090+
TookTrueBranch(IRSimplifier::simplify(body));
3091+
3092+
// x <= x => 1
3093+
body =CompareSelect::make(x, x,kLE);
3094+
TookTrueBranch(IRSimplifier::simplify(body));
3095+
3096+
// NE
3097+
3098+
// x+1 != x => 1
3099+
body =CompareSelect::make(x +1, x,kNE);
3100+
TookTrueBranch(IRSimplifier::simplify(body));
3101+
3102+
// x != x + 1 => 1
3103+
body =CompareSelect::make(x, x +1,kNE);
3104+
TookTrueBranch(IRSimplifier::simplify(body));
3105+
3106+
// x != x => 0
3107+
body =CompareSelect::make(x, x,kNE);
3108+
TookFalseBranch(IRSimplifier::simplify(body));
3109+
}
3110+
28103111
voidtestSimplifyEliminateZeroLengthFor() {
28113112
KernelScope kernel_scope;
28123113

‎test/cpp/tensorexpr/tests.h‎

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ namespace jit {
194194
_(SimplifyMuls) \
195195
_(SimplifySubs) \
196196
_(SimplifyDiv) \
197+
_(SimplifyMod) \
197198
_(SimplifyMultiOp) \
198199
_(SimplifyManyOps) \
199200
_(SimplifyFactorization) \
@@ -214,6 +215,8 @@ namespace jit {
214215
_(SimplifyConstantBranches) \
215216
_(SimplifyConstantCond) \
216217
_(SimplifyEliminateEmptyCond) \
218+
_(SimplifyConstantComparisons) \
219+
_(SimplifySymbolicComparisons) \
217220
_(SimplifyEliminateZeroLengthFor) \
218221
_(SimplifyOneLoopFor) \
219222
_(SimplifyForWontLoseLoopOptions) \

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp