Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0885dd2

Browse files
wanchaolfacebook-github-bot
authored andcommitted
refactor register_prim_ops (#21001)
Summary:Pull Requestresolved:#21001ghimport-source-id:f1b8e39Differential Revision: D15523445Pulled By: wanchaolfbshipit-source-id: c1e29b0985bde580703a1fca9df46da773826df6
1 parentb85c529 commit0885dd2

File tree

1 file changed

+39
-171
lines changed

1 file changed

+39
-171
lines changed

‎torch/csrc/jit/register_prim_ops.cpp‎

Lines changed: 39 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,20 @@ RegisterOperators logging_operators(
10831083
DEFINE_GENERIC_OP(aten_op, op, op,bool,bool), \
10841084
DEFINE_INT_FLOAT_OP(aten_op, op,bool), DEFINE_STR_CMP_OP(aten_op, op)
10851085

1086+
#defineDEFINE_UNARY_OP(aten_op, op, int_result, float_result) \
1087+
Operator(#aten_op"(int a) ->" #int_result, [](Stack& stack) { \
1088+
int64_t a; \
1089+
pop(stack, a); \
1090+
push(stack, op); \
1091+
return0; \
1092+
}), \
1093+
Operator(#aten_op"(float a) ->" #float_result, [](Stack& stack) { \
1094+
double a; \
1095+
pop(stack, a); \
1096+
push(stack, op); \
1097+
return0; \
1098+
})
1099+
10861100
#defineDEFINE_BOOL_OP(aten_op, op) \
10871101
Operator(#aten_op"(bool a, bool b) -> bool", [](Stack& stack) { \
10881102
bool a, b; \
@@ -2070,29 +2084,31 @@ RegisterOperators reg2({
20702084
float),
20712085
DEFINE_INT_FLOAT_OP(aten::floordiv,std::floor(a / b),float),
20722086

2087+
// NB: This is the python truediv operation
2088+
DEFINE_GENERIC_OP(
2089+
aten::div,
2090+
static_cast<double>(a) /static_cast<double>(b),
2091+
a / b,
2092+
float,
2093+
float),
2094+
20732095
// only used in loop unrolling, not exposed to end users
20742096
DEFINE_INT_OP(aten::__round_to_zero_floordiv, a / b),
20752097

20762098
DEFINE_INT_OP(aten::__and__, a& b),
20772099
DEFINE_INT_OP(aten::__or__, a | b),
20782100
DEFINE_INT_OP(aten::__xor__, a ^ b),
20792101

2080-
Operator(
2081-
"prim::abs(int x) -> int",
2082-
[](Stack& stack) {
2083-
int64_t x;
2084-
pop(stack, x);
2085-
push(stack,std::abs(x));
2086-
return0;
2087-
}),
2088-
Operator(
2089-
"prim::abs(float x) -> float",
2090-
[](Stack& stack) {
2091-
float x;
2092-
pop(stack, x);
2093-
push(stack,std::abs(x));
2094-
return0;
2095-
}),
2102+
DEFINE_UNARY_OP(aten::floor,std::floor(a),float,float),
2103+
DEFINE_UNARY_OP(aten::ceil,std::ceil(a),float,float),
2104+
DEFINE_UNARY_OP(aten::log,std::log(a),float,float),
2105+
DEFINE_UNARY_OP(aten::log1p,std::log1p(a),float,float),
2106+
DEFINE_UNARY_OP(aten::log10,std::log10(a),float,float),
2107+
DEFINE_UNARY_OP(aten::exp,std::exp(a),float,float),
2108+
DEFINE_UNARY_OP(aten::sqrt,std::sqrt(a),float,float),
2109+
2110+
// TODO: move abs to aten namespace because it's schematized!
2111+
DEFINE_UNARY_OP(prim::abs,std::abs(a),int,float),
20962112
Operator(
20972113
"prim::abs(Tensor x) -> Tensor",
20982114
[](Stack& stack) {
@@ -2102,127 +2118,6 @@ RegisterOperators reg2({
21022118
return0;
21032119
}),
21042120

2105-
// NB: This is the python truediv operation
2106-
Operator(
2107-
"aten::div(int a, int b) -> float",
2108-
[](Stack& stack) {
2109-
int64_t a, b;
2110-
pop(stack, a, b);
2111-
push(stack,static_cast<double>(a) /static_cast<double>(b));
2112-
return0;
2113-
}),
2114-
Operator(
2115-
"aten::div(float a, float b) -> float",
2116-
[](Stack& stack) {
2117-
double a, b;
2118-
pop(stack, a, b);
2119-
push(stack, a / b);
2120-
return0;
2121-
}),
2122-
2123-
Operator(
2124-
"aten::floor(float a) -> float",
2125-
[](Stack& stack) {
2126-
double a;
2127-
pop(stack, a);
2128-
push(stack,std::floor(a));
2129-
return0;
2130-
}),
2131-
2132-
Operator(
2133-
"aten::ceil(float a) -> float",
2134-
[](Stack& stack) {
2135-
double a;
2136-
pop(stack, a);
2137-
push(stack,std::ceil(a));
2138-
return0;
2139-
}),
2140-
2141-
Operator(
2142-
"aten::log(float a) -> float",
2143-
[](Stack& stack) {
2144-
double a;
2145-
pop(stack, a);
2146-
push(stack,std::log(a));
2147-
return0;
2148-
}),
2149-
Operator(
2150-
"aten::log(int a) -> float",
2151-
[](Stack& stack) {
2152-
int64_t a;
2153-
pop(stack, a);
2154-
push(stack,std::log(a));
2155-
return0;
2156-
}),
2157-
2158-
Operator(
2159-
"aten::log1p(float a) -> float",
2160-
[](Stack& stack) {
2161-
double a;
2162-
pop(stack, a);
2163-
push(stack,std::log1p(a));
2164-
return0;
2165-
}),
2166-
Operator(
2167-
"aten::log1p(int a) -> float",
2168-
[](Stack& stack) {
2169-
int64_t a;
2170-
pop(stack, a);
2171-
push(stack,std::log1p(a));
2172-
return0;
2173-
}),
2174-
2175-
Operator(
2176-
"aten::log10(float a) -> float",
2177-
[](Stack& stack) {
2178-
double a;
2179-
pop(stack, a);
2180-
push(stack,std::log10(a));
2181-
return0;
2182-
}),
2183-
Operator(
2184-
"aten::log10(int a) -> float",
2185-
[](Stack& stack) {
2186-
int64_t a;
2187-
pop(stack, a);
2188-
push(stack,std::log10(a));
2189-
return0;
2190-
}),
2191-
2192-
Operator(
2193-
"aten::exp(float a) -> float",
2194-
[](Stack& stack) {
2195-
double a;
2196-
pop(stack, a);
2197-
push(stack,std::exp(a));
2198-
return0;
2199-
}),
2200-
Operator(
2201-
"aten::exp(int a) -> float",
2202-
[](Stack& stack) {
2203-
int64_t a;
2204-
pop(stack, a);
2205-
push(stack,std::exp(a));
2206-
return0;
2207-
}),
2208-
2209-
Operator(
2210-
"aten::sqrt(float a) -> float",
2211-
[](Stack& stack) {
2212-
double a;
2213-
pop(stack, a);
2214-
push(stack,std::sqrt(a));
2215-
return0;
2216-
}),
2217-
Operator(
2218-
"aten::sqrt(int a) -> float",
2219-
[](Stack& stack) {
2220-
int64_t a;
2221-
pop(stack, a);
2222-
push(stack,std::sqrt(a));
2223-
return0;
2224-
}),
2225-
22262121
DEFINE_INT_OP(aten::gcd,gcd(a, b)),
22272122

22282123
DEFINE_GENERIC_OP(
@@ -2233,28 +2128,12 @@ RegisterOperators reg2({
22332128
float),
22342129
DEFINE_INT_FLOAT_OP(aten::copysign,std::copysign(a, b),float),
22352130

2236-
#defineDEFINE_MATH_OP(aten_op, op, int_result, float_result) \
2237-
Operator( \
2238-
#aten_op"(int a) ->" #int_result, \
2239-
[](Stack& stack) { \
2240-
int64_t a; \
2241-
pop(stack, a); \
2242-
push(stack, op); \
2243-
return0; \
2244-
}), \
2245-
Operator(#aten_op"(float a) ->" #float_result, [](Stack& stack) { \
2246-
double a; \
2247-
pop(stack, a); \
2248-
push(stack, op); \
2249-
return0; \
2250-
})
2251-
2252-
DEFINE_MATH_OP(aten::gamma,std::tgamma(a),float,float),
2253-
DEFINE_MATH_OP(aten::erf,std::erf(a),float,float),
2254-
DEFINE_MATH_OP(aten::erfc,std::erfc(a),float,float),
2255-
DEFINE_MATH_OP(aten::expm1,std::expm1(a),float,float),
2256-
DEFINE_MATH_OP(aten::fabs,std::fabs(a),float,float),
2257-
DEFINE_MATH_OP(aten::lgamma,std::lgamma(a),float,float),
2131+
DEFINE_UNARY_OP(aten::gamma,std::tgamma(a),float,float),
2132+
DEFINE_UNARY_OP(aten::erf,std::erf(a),float,float),
2133+
DEFINE_UNARY_OP(aten::erfc,std::erfc(a),float,float),
2134+
DEFINE_UNARY_OP(aten::expm1,std::expm1(a),float,float),
2135+
DEFINE_UNARY_OP(aten::fabs,std::fabs(a),float,float),
2136+
DEFINE_UNARY_OP(aten::lgamma,std::lgamma(a),float,float),
22582137

22592138
DEFINE_COMPARISON_OP(aten::ne, a != b),
22602139
DEFINE_COMPARISON_OP(aten::eq, a == b),
@@ -2266,18 +2145,7 @@ RegisterOperators reg2({
22662145
DEFINE_BOOL_OP(aten::__or__, a || b),
22672146
DEFINE_BOOL_OP(aten::__xor__, a != b),
22682147

2269-
Operator(
2270-
"aten::neg(int self) -> int",
2271-
[](Stack& stack) {
2272-
push(stack, -pop(stack).toInt());
2273-
return0;
2274-
}),
2275-
Operator(
2276-
"aten::neg(float self) -> float",
2277-
[](Stack& stack) {
2278-
push(stack, -pop(stack).toDouble());
2279-
return0;
2280-
}),
2148+
DEFINE_UNARY_OP(aten::neg, -a,int,float),
22812149
Operator(
22822150
"aten::__not__(bool self) -> bool",
22832151
[](Stack& stack) {

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp