Speed up an integer to the power of a positive integer on CPU #26020


Closed

xuhdev wants to merge 1 commit into pytorch:master from xuhdev:int-pow
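
For context, the call pattern this PR speeds up is the scalar-exponent overload of torch.pow applied to an integral CPU tensor with a positive integer exponent. A minimal timing sketch of that pattern follows; the tensor size, value range, exponent, and repeat count are illustrative assumptions, not measurements from this PR.

import timeit

import torch

# Integer base tensor raised to a positive integer exponent: the case this PR
# targets. Size, value range, and exponent are arbitrary illustrative choices.
base = torch.randint(1, 10, (1_000_000,), dtype=torch.int64)

elapsed = timeit.timeit(lambda: torch.pow(base, 3), number=100)
print(f"torch.pow(int64 tensor, 3), 100 runs: {elapsed:.3f} s")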
aten/src/ATen/native/cpu/PowKernel.cpp (120 changes: 68 additions & 52 deletions)
@@ -35,10 +35,8 @@ void pow_tensor_tensor_kernel(TensorIterator& iter) {
 }
 
 void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
-  // Casting exponent to double(not tensor.dtype) allows powering integral
-  // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
-  const auto exp = exp_scalar.to<double>();
   if (isFloatingType(iter.dtype())) {
+    const auto exp = exp_scalar.to<double>();
     // Floating types allow AVX2 vector optimizations for pow/sqrt/rsqrt:
     AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "pow", [&]() {
       using Vec = Vec256<scalar_t>;
@@ -98,55 +96,73 @@ void pow_tensor_scalar_kernel(TensorIterator& iter, Scalar exp_scalar) {
     // Trying to implement pow/sqrt/rsqrt as loop in vec256_int.h does not allow
     // powering integral tensor to float exponent. That's why we need this code
     // duplication:
-    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
-      if (exp == 0.5) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return std::sqrt(static_cast<long double>(base));
-          }
-        );
-      } else if (exp == 2) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return ld_base * ld_base;
-          }
-        );
-      } else if (exp == 3) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return ld_base * ld_base * ld_base;
-          }
-        );
-      } else if (exp == -0.5) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return 1.0 / std::sqrt(static_cast<long double>(base));
-          }
-        );
-      } else if (exp == -1) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            return 1.0 / static_cast<long double>(base);
-          }
-        );
-      } else if (exp == -2) {
-        cpu_kernel(iter,
-          [](scalar_t base) -> scalar_t {
-            const auto ld_base = static_cast<long double>(base);
-            return 1.0 / (ld_base * ld_base);
-          }
-        );
-      } else {
-        cpu_kernel(iter,
-          [=](scalar_t base) -> scalar_t {
-            return std::pow(static_cast<long double>(base),
-              static_cast<long double>(exp));
-          }
-        );
-      }
-    });
+
+    if (exp_scalar.isIntegral(true) && exp_scalar.to<int64_t>() >= 0) {
+      // Specifically deal with an integer to the power of a positive integer for better efficiency.
+      const auto exp = exp_scalar.to<int64_t>();
+
+      AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+        switch (exp) {
+          case 2:
+            cpu_kernel(iter,
+              [](scalar_t base) -> scalar_t {
+                return base * base;
+              }
+            );
+            break;
+          case 3:
+            cpu_kernel(iter,
+              [](scalar_t base) -> scalar_t {
+                return base * base * base;
+              }
+            );
+            break;
+          default:
+            cpu_kernel(iter,
+              [=](scalar_t base) -> scalar_t {
+                return std::pow(base, exp);
+              }
+            );
+        }
+      });
+    } else {
+      // Casting exponent to double(not tensor.dtype) allows powering integral
+      // tensors to float exponent e.g. tensor([4]).pow(0.5) will be tensor([2])
+      const auto exp = exp_scalar.to<double>();
+      AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "pow", [&]() {
+        if (exp == 0.5) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return std::sqrt(static_cast<long double>(base));
+            }
+          );
+        } else if (exp == -0.5) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / std::sqrt(static_cast<long double>(base));
+            }
+          );
+        } else if (exp == -1) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / static_cast<long double>(base);
+            }
+          );
+        } else if (exp == -2) {
+          cpu_kernel(iter,
+            [](scalar_t base) -> scalar_t {
+              return 1.0 / (base * base);
+            }
+          );
+        } else {
+          cpu_kernel(iter,
+            [=](scalar_t base) -> scalar_t {
+              return std::pow(static_cast<long double>(base), exp);
+            }
+          );
+        }
+      });
+    }
   }
 }

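In short, the rewritten kernel first checks whether the scalar exponent is a non-negative integer; if so it stays in integer arithmetic, with dedicated cases for squaring and cubing, and only otherwise does it cast to floating point so fractional exponents keep working. Below is a rough Python model of that dispatch for illustration only; the helper name is invented here and is not part of PyTorch or this PR.

def pow_scalar_model(base: int, exp):
    # Hypothetical reference model of the scalar-exponent dispatch above, for
    # illustration only; it is not the C++ kernel and is not used by PyTorch.
    if isinstance(exp, int) and exp >= 0:
        # Fast path added by this PR: stay in integer arithmetic.
        if exp == 2:
            return base * base
        if exp == 3:
            return base * base * base
        return base ** exp  # analogous to the std::pow(base, exp) default case
    # Otherwise compute in floating point so fractional exponents work,
    # e.g. pow(4, 0.5) -> 2. (Negative integer exponents on integral tensors
    # are rejected with a RuntimeError before reaching this kernel, as the
    # updated test below asserts.)
    return float(base) ** float(exp)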
test/test_torch.py (105 changes: 60 additions & 45 deletions)
@@ -1346,51 +1346,6 @@ def test_baddbmm(self):
         res6 = torch.baddbmm(.1, res2, .5, b1, b2)
         self.assertEqual(res6, res2 * .1 + res * .5)
 
-    def test_pow(self):
-        # [res] torch.pow([res,] x)
-
-        # pow has dedicated implementation for different exponents
-        for exponent in [-2, -1, -0.5, 0.5, 1, 2, 3, 4]:
-            # base - tensor, exponent - number
-            # contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[4][i], exponent)
-            self.assertEqual(res1, res2)
-
-            # non-contiguous
-            m1 = torch.rand(100, 100) + 0.5
-            res1 = torch.pow(m1[:, 4], exponent)
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(m1[i, 4], exponent)
-            self.assertEqual(res1, res2)
-
-            # base - number, exponent - tensor
-            # contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[4, i])
-            self.assertEqual(res1, res2)
-
-            # non-contiguous
-            m1 = torch.randn(100, 100)
-            res1 = torch.pow(3, m1[:, 4])
-            res2 = res1.clone().zero_()
-            for i in range(res2.size(0)):
-                res2[i] = math.pow(3, m1[i][4])
-            self.assertEqual(res1, res2)
-
-        # resize behavior for exp == 1
-        m1 = torch.randn(2, 2)
-        out = torch.randn([0])
-        torch.pow(m1, 1, out=out)
-        self.assertEqual(out, m1)
-
     def _test_cop(self, torchfn, mathfn):
         def reference_implementation(res2):
             for i, j in iter_indices(sm1):
@@ -7013,6 +6968,66 @@ def test_diagonal(self, device):
         expected = torch.diag(x, 17)
         self.assertEqual(result, expected)
 
+    def test_pow(self, device):
+        # [res] torch.pow([res,] x)
+
+        # pow has dedicated implementation for different exponents
+        for dtype in torch.testing.get_all_math_dtypes(device):
+
+            # This test won't work on torch.half because math.pow will generate a much more accurate result. We skip it
+            # for now.
+            if dtype == torch.half:
+                continue
+
+            m1 = torch.empty(0, dtype=dtype, device=device)
+            if m1.is_floating_point():
+                m1 = torch.rand(100, 100, dtype=dtype, device=device) + 0.5
+            else:
+                # math.pow will overflow and throw exceptions for large integers
+                range_high = 4 if dtype in (torch.int8, torch.uint8) else 10
+                m1 = torch.randint(1, range_high, (100, 100), dtype=dtype, device=device)
+
+            for num in [-2.8, -2, -1, -0.5, 0, 0.5, 1, 2, 3, 4, 3.3]:
+                if isinstance(num, int) and num < 0 and not m1.is_floating_point():
+                    with self.assertRaisesRegex(RuntimeError,
+                                                r'Integers to negative integer powers are not allowed\.'):
+                        torch.pow(m1[4], num)
+                else:
+                    # base - tensor, exponent - number
+                    # contiguous
+                    res1 = torch.pow(m1[4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[4][i], num)
+                    self.assertEqual(res1, res2)
+
+                    # non-contiguous
+                    res1 = torch.pow(m1[:, 4], num)
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(m1[i, 4], num)
+                    self.assertEqual(res1, res2)
+
+                    # base - number, exponent - tensor
+                    # contiguous
+                    res1 = torch.pow(3, m1[4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[4, i])
+                    self.assertEqual(res1, res2)
+
+                    # non-contiguous
+                    res1 = torch.pow(3, m1[:, 4])
+                    res2 = res1.clone().zero_()
+                    for i in range(res2.size(0)):
+                        res2[i] = math.pow(3, m1[i][4])
+                    self.assertEqual(res1, res2)
+
+            # resize behavior for exp == 1
+            out = torch.zeros(1, dtype=dtype, device=device)
+            torch.pow(m1, 1, out=out)
+            self.assertEqual(out, m1)
+
     def test_neg(self, device):
         int_types = [torch.int, torch.short, torch.int8, torch.uint8]
         float_types = [torch.float, torch.double, torch.long]
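
The updated test exercises both sides of the new dispatch: exact results on the integer fast path and a RuntimeError for negative integer exponents on integral tensors. A small interactive sketch of those behaviors, assuming a CPU build that includes this change (the tensor values are illustrative):

import torch

t = torch.arange(1, 5)      # int64 base tensor
print(torch.pow(t, 2))      # integer fast path: tensor([ 1,  4,  9, 16])
print(torch.pow(t, 3))      # dedicated cubing case: tensor([ 1,  8, 27, 64])

try:
    torch.pow(t, -1)        # integral base, negative integral exponent
except RuntimeError as err:
    print(err)              # "Integers to negative integer powers are not allowed."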