Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit6948419

Browse files
gderossipytorchmergebot
authored andcommitted
Fix scaled_matmul_cuda tests (#169834)
This PR fixes a few test failures in `test_scaled_matmul_cuda.py` by adding Thor to a list of devices not compatible with SM carveout and by updating an SM version check to include all devices with SM >= 10.x instead of just devices with SM == 10.x.Based on commit history, it looks like the `dprops->major == 10` was just a typo introduced when upgrading to the new `scaled_mm_v2` API, but if it was intentional I can look into alternative fixes to these tests.Fixes#169833Pull Requestresolved:#169834Approved by:https://github.com/slayton58
1 parentf73345c commit6948419

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

‎aten/src/ATen/native/cuda/ScaledBlas.cpp‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ _scaled_rowwise_rowwise(
739739
auto dprops =at::cuda::getCurrentDeviceProperties();
740740
if (((dprops->major <9 || CUBLAS_VERSION <120900 ||cublasLtGetVersion() <120900)
741741
// cuBLAS only supports tiled 1D factor layout for 1D block scaling, no 2D block scales
742-
|| (dprops->major==10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
742+
|| (dprops->major>=10 && (!scale_a.sizes().empty() || !scale_b.sizes().empty())))) {
743743
TORCH_CHECK_VALUE(out.dtype() ==kBFloat16 || out.dtype() ==kHalf,"Only bf16 and fp16 high precision output types are supported for row-wise scaling.");
744744
at::cuda::detail::f8f8bf16_rowwise(
745745
mat_a,

‎test/test_scaled_matmul_cuda.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1790,7 +1790,7 @@ def test_honor_sm_carveout(self) -> None:
17901790

17911791
self.assertEqual(no_carveout,no_carveout_again)
17921792
capability=torch.cuda.get_device_capability()
1793-
ifcapabilityin {(10,0), (10,3), (12,0), (12,1)}:
1793+
ifcapabilityin {(10,0), (10,3), (11,0), (12,0), (12,1)}:
17941794
# expected failure
17951795
# CUTLASS only supports SM carveout via green contexts on SM100
17961796
self.assertEqual(no_carveout,carveout_66)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp