@@ -1083,6 +1083,20 @@ RegisterOperators logging_operators(
10831083DEFINE_GENERIC_OP (aten_op, op, op,bool ,bool ), \
10841084DEFINE_INT_FLOAT_OP (aten_op, op,bool ), DEFINE_STR_CMP_OP(aten_op, op)
10851085
1086+ #define DEFINE_UNARY_OP (aten_op, op, int_result, float_result ) \
1087+ Operator (#aten_op" (int a) ->" #int_result, [](Stack& stack) { \
1088+ int64_t a; \
1089+ pop (stack, a); \
1090+ push (stack, op); \
1091+ return 0 ; \
1092+ }), \
1093+ Operator (#aten_op" (float a) ->" #float_result, [](Stack& stack) { \
1094+ double a; \
1095+ pop (stack, a); \
1096+ push (stack, op); \
1097+ return 0 ; \
1098+ })
1099+
10861100#define DEFINE_BOOL_OP (aten_op, op ) \
10871101Operator (#aten_op" (bool a, bool b) -> bool" , [](Stack& stack) { \
10881102bool a, b; \
@@ -2070,29 +2084,31 @@ RegisterOperators reg2({
20702084float ),
20712085DEFINE_INT_FLOAT_OP (aten::floordiv,std::floor (a / b),float ),
20722086
2087+ // NB: This is the python truediv operation
2088+ DEFINE_GENERIC_OP (
2089+ aten::div,
2090+ static_cast <double >(a) /static_cast <double >(b),
2091+ a / b,
2092+ float ,
2093+ float ),
2094+
20732095// only used in loop unrolling, not exposed to end users
20742096DEFINE_INT_OP (aten::__round_to_zero_floordiv, a / b),
20752097
20762098DEFINE_INT_OP (aten::__and__, a& b),
20772099DEFINE_INT_OP (aten::__or__, a | b),
20782100DEFINE_INT_OP (aten::__xor__, a ^ b),
20792101
2080- Operator (
2081- " prim::abs(int x) -> int" ,
2082- [](Stack& stack) {
2083- int64_t x;
2084- pop (stack, x);
2085- push (stack,std::abs (x));
2086- return 0 ;
2087- }),
2088- Operator (
2089- " prim::abs(float x) -> float" ,
2090- [](Stack& stack) {
2091- float x;
2092- pop (stack, x);
2093- push (stack,std::abs (x));
2094- return 0 ;
2095- }),
2102+ DEFINE_UNARY_OP (aten::floor,std::floor (a),float ,float ),
2103+ DEFINE_UNARY_OP (aten::ceil,std::ceil (a),float ,float ),
2104+ DEFINE_UNARY_OP (aten::log,std::log (a),float ,float ),
2105+ DEFINE_UNARY_OP (aten::log1p,std::log1p (a),float ,float ),
2106+ DEFINE_UNARY_OP (aten::log10,std::log10 (a),float ,float ),
2107+ DEFINE_UNARY_OP (aten::exp,std::exp (a),float ,float ),
2108+ DEFINE_UNARY_OP (aten::sqrt,std::sqrt (a),float ,float ),
2109+
2110+ // TODO: move abs to aten namespace because it's schematized!
2111+ DEFINE_UNARY_OP (prim::abs,std::abs (a),int ,float ),
20962112Operator (
20972113" prim::abs(Tensor x) -> Tensor" ,
20982114 [](Stack& stack) {
@@ -2102,127 +2118,6 @@ RegisterOperators reg2({
21022118return 0 ;
21032119 }),
21042120
2105- // NB: This is the python truediv operation
2106- Operator (
2107- " aten::div(int a, int b) -> float" ,
2108- [](Stack& stack) {
2109- int64_t a, b;
2110- pop (stack, a, b);
2111- push (stack,static_cast <double >(a) /static_cast <double >(b));
2112- return 0 ;
2113- }),
2114- Operator (
2115- " aten::div(float a, float b) -> float" ,
2116- [](Stack& stack) {
2117- double a, b;
2118- pop (stack, a, b);
2119- push (stack, a / b);
2120- return 0 ;
2121- }),
2122-
2123- Operator (
2124- " aten::floor(float a) -> float" ,
2125- [](Stack& stack) {
2126- double a;
2127- pop (stack, a);
2128- push (stack,std::floor (a));
2129- return 0 ;
2130- }),
2131-
2132- Operator (
2133- " aten::ceil(float a) -> float" ,
2134- [](Stack& stack) {
2135- double a;
2136- pop (stack, a);
2137- push (stack,std::ceil (a));
2138- return 0 ;
2139- }),
2140-
2141- Operator (
2142- " aten::log(float a) -> float" ,
2143- [](Stack& stack) {
2144- double a;
2145- pop (stack, a);
2146- push (stack,std::log (a));
2147- return 0 ;
2148- }),
2149- Operator (
2150- " aten::log(int a) -> float" ,
2151- [](Stack& stack) {
2152- int64_t a;
2153- pop (stack, a);
2154- push (stack,std::log (a));
2155- return 0 ;
2156- }),
2157-
2158- Operator (
2159- " aten::log1p(float a) -> float" ,
2160- [](Stack& stack) {
2161- double a;
2162- pop (stack, a);
2163- push (stack,std::log1p (a));
2164- return 0 ;
2165- }),
2166- Operator (
2167- " aten::log1p(int a) -> float" ,
2168- [](Stack& stack) {
2169- int64_t a;
2170- pop (stack, a);
2171- push (stack,std::log1p (a));
2172- return 0 ;
2173- }),
2174-
2175- Operator (
2176- " aten::log10(float a) -> float" ,
2177- [](Stack& stack) {
2178- double a;
2179- pop (stack, a);
2180- push (stack,std::log10 (a));
2181- return 0 ;
2182- }),
2183- Operator (
2184- " aten::log10(int a) -> float" ,
2185- [](Stack& stack) {
2186- int64_t a;
2187- pop (stack, a);
2188- push (stack,std::log10 (a));
2189- return 0 ;
2190- }),
2191-
2192- Operator (
2193- " aten::exp(float a) -> float" ,
2194- [](Stack& stack) {
2195- double a;
2196- pop (stack, a);
2197- push (stack,std::exp (a));
2198- return 0 ;
2199- }),
2200- Operator (
2201- " aten::exp(int a) -> float" ,
2202- [](Stack& stack) {
2203- int64_t a;
2204- pop (stack, a);
2205- push (stack,std::exp (a));
2206- return 0 ;
2207- }),
2208-
2209- Operator (
2210- " aten::sqrt(float a) -> float" ,
2211- [](Stack& stack) {
2212- double a;
2213- pop (stack, a);
2214- push (stack,std::sqrt (a));
2215- return 0 ;
2216- }),
2217- Operator (
2218- " aten::sqrt(int a) -> float" ,
2219- [](Stack& stack) {
2220- int64_t a;
2221- pop (stack, a);
2222- push (stack,std::sqrt (a));
2223- return 0 ;
2224- }),
2225-
22262121DEFINE_INT_OP (aten::gcd,gcd (a, b)),
22272122
22282123DEFINE_GENERIC_OP (
@@ -2233,28 +2128,12 @@ RegisterOperators reg2({
22332128float ),
22342129DEFINE_INT_FLOAT_OP (aten::copysign,std::copysign (a, b),float ),
22352130
2236- #define DEFINE_MATH_OP (aten_op, op, int_result, float_result ) \
2237- Operator ( \
2238- #aten_op" (int a) ->" #int_result, \
2239- [](Stack& stack) { \
2240- int64_t a; \
2241- pop (stack, a); \
2242- push (stack, op); \
2243- return 0 ; \
2244- }), \
2245- Operator (#aten_op" (float a) ->" #float_result, [](Stack& stack) { \
2246- double a; \
2247- pop (stack, a); \
2248- push (stack, op); \
2249- return 0 ; \
2250- })
2251-
2252- DEFINE_MATH_OP (aten::gamma,std::tgamma (a),float ,float ),
2253- DEFINE_MATH_OP (aten::erf,std::erf (a),float ,float ),
2254- DEFINE_MATH_OP (aten::erfc,std::erfc (a),float ,float ),
2255- DEFINE_MATH_OP (aten::expm1,std::expm1 (a),float ,float ),
2256- DEFINE_MATH_OP (aten::fabs,std::fabs (a),float ,float ),
2257- DEFINE_MATH_OP (aten::lgamma,std::lgamma (a),float ,float ),
2131+ DEFINE_UNARY_OP (aten::gamma,std::tgamma (a),float ,float ),
2132+ DEFINE_UNARY_OP (aten::erf,std::erf (a),float ,float ),
2133+ DEFINE_UNARY_OP (aten::erfc,std::erfc (a),float ,float ),
2134+ DEFINE_UNARY_OP (aten::expm1,std::expm1 (a),float ,float ),
2135+ DEFINE_UNARY_OP (aten::fabs,std::fabs (a),float ,float ),
2136+ DEFINE_UNARY_OP (aten::lgamma,std::lgamma (a),float ,float ),
22582137
22592138DEFINE_COMPARISON_OP (aten::ne, a != b),
22602139DEFINE_COMPARISON_OP (aten::eq, a == b),
@@ -2266,18 +2145,7 @@ RegisterOperators reg2({
22662145DEFINE_BOOL_OP (aten::__or__, a || b),
22672146DEFINE_BOOL_OP (aten::__xor__, a != b),
22682147
2269- Operator (
2270- " aten::neg(int self) -> int" ,
2271- [](Stack& stack) {
2272- push (stack, -pop (stack).toInt ());
2273- return 0 ;
2274- }),
2275- Operator (
2276- " aten::neg(float self) -> float" ,
2277- [](Stack& stack) {
2278- push (stack, -pop (stack).toDouble ());
2279- return 0 ;
2280- }),
2148+ DEFINE_UNARY_OP (aten::neg, -a,int ,float ),
22812149Operator (
22822150" aten::__not__(bool self) -> bool" ,
22832151 [](Stack& stack) {