Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitbcee361

Browse files
Alexey Sidnevfacebook-github-bot
Alexey Sidnev
authored andcommitted
Replacetorch.det() with manual implementation for 3x3 matrix
Summary:# BackgroundThere is an unstable error during training (it can happen after several minutes or after several hours).The error is connected to `torch.det()` function in `_check_valid_rotation_matrix()`.if I remove the function `torch.det()` in `_check_valid_rotation_matrix()` or remove the whole functions `_check_valid_rotation_matrix()` the error is disappeared (D29555876).# SolutionReplace `torch.det()` with manual implementation for 3x3 matrix.Reviewed By: patricklabatutDifferential Revision: D29655924fbshipit-source-id: 41bde1119274a705ab849751ece28873d2c45155
1 parent2f668ec commitbcee361

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

‎pytorch3d/common/workaround.py‎

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
importtorch
9+
10+
11+
def_safe_det_3x3(t:torch.Tensor):
12+
"""
13+
Fast determinant calculation for a batch of 3x3 matrices.
14+
15+
Note, result of this function might not be the same as `torch.det()`.
16+
The differences might be in the last significant digit.
17+
18+
Args:
19+
t: Tensor of shape (N, 3, 3).
20+
21+
Returns:
22+
Tensor of shape (N) with determinants.
23+
"""
24+
25+
det= (
26+
t[...,0,0]* (t[...,1,1]*t[...,2,2]-t[...,1,2]*t[...,2,1])
27+
-t[...,0,1]* (t[...,1,0]*t[...,2,2]-t[...,2,0]*t[...,1,2])
28+
+t[...,0,2]* (t[...,1,0]*t[...,2,1]-t[...,2,0]*t[...,1,1])
29+
)
30+
31+
returndet

‎pytorch3d/transforms/transform3d.py‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
importtorch
1212

1313
from ..common.typesimportDevice,get_device,make_device
14+
from ..common.workaroundimport_safe_det_3x3
1415
from .rotation_conversionsimport_axis_angle_rotation
1516

1617

@@ -774,7 +775,7 @@ def _check_valid_rotation_matrix(R, tol: float = 1e-7):
774775
eye=torch.eye(3,dtype=R.dtype,device=R.device)
775776
eye=eye.view(1,3,3).expand(N,-1,-1)
776777
orthogonal=torch.allclose(R.bmm(R.transpose(1,2)),eye,atol=tol)
777-
det_R=torch.det(R)
778+
det_R=_safe_det_3x3(R)
778779
no_distortion=torch.allclose(det_R,torch.ones_like(det_R))
779780
ifnot (orthogonalandno_distortion):
780781
msg="R is not a valid rotation matrix"

‎tests/test_common_workaround.py‎

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
importunittest
9+
10+
importnumpyasnp
11+
importtorch
12+
fromcommon_testingimportTestCaseMixin
13+
frompytorch3d.common.workaroundimport_safe_det_3x3
14+
15+
16+
classTestSafeDet3x3(TestCaseMixin,unittest.TestCase):
17+
defsetUp(self)->None:
18+
super().setUp()
19+
torch.manual_seed(42)
20+
np.random.seed(42)
21+
22+
def_test_det_3x3(self,batch_size,device):
23+
t=torch.rand((batch_size,3,3),dtype=torch.float32,device=device)
24+
actual_det=_safe_det_3x3(t)
25+
expected_det=t.det()
26+
self.assertClose(actual_det,expected_det,atol=1e-7)
27+
28+
deftest_empty_batch(self):
29+
self._test_det_3x3(0,torch.device("cpu"))
30+
self._test_det_3x3(0,torch.device("cuda:0"))
31+
32+
deftest_manual(self):
33+
t=torch.Tensor(
34+
[
35+
[[1,0,0], [0,1,0], [0,0,1]],
36+
[[2,-5,3], [0,7,-2], [-1,4,1]],
37+
[[6,1,1], [4,-2,5], [2,8,7]],
38+
]
39+
).to(dtype=torch.float32)
40+
expected_det=torch.Tensor([1,41,-306]).to(dtype=torch.float32)
41+
self.assertClose(_safe_det_3x3(t),expected_det)
42+
43+
device_cuda=torch.device("cuda:0")
44+
self.assertClose(
45+
_safe_det_3x3(t.to(device=device_cuda)),expected_det.to(device=device_cuda)
46+
)
47+
48+
deftest_regression(self):
49+
tries=32
50+
device_cpu=torch.device("cpu")
51+
device_cuda=torch.device("cuda:0")
52+
batch_sizes=np.random.randint(low=1,high=128,size=tries)
53+
54+
forbatch_sizeinbatch_sizes:
55+
self._test_det_3x3(batch_size,device_cpu)
56+
self._test_det_3x3(batch_size,device_cuda)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp