
Commit 0aef44c

IvanYashchuk authored and facebook-github-bot committed
Add forward AD for torch.linalg.eigh (#62163)
Summary: This PR adds forward mode differentiation for `torch.linalg.eigh` and a few other functions required for the tests to pass. For some reason, running the tests for `torch.linalg.eigvalsh` and complex `torch.linalg.eigh` hangs; these tests are skipped for now.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry heitorschueroff walterddr IvanYashchuk xwang233

Pull Request resolved: #62163
Reviewed By: jbschlosser
Differential Revision: D30903988
Pulled By: albanD
fbshipit-source-id: d6a74adb9e6d2f4be8ac707848ecabf06d629823
1 parent 35c82db · commit 0aef44c
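As a quick usage illustration (not part of the commit itself), the new rules let torch.linalg.eigh be differentiated in forward mode through the dual-number API in torch.autograd.forward_ad; the random symmetric input and tangent below are made up for the example:

import torch
import torch.autograd.forward_ad as fwAD

# Symmetric primal and tangent (direction of perturbation); real dtype only,
# since this PR raises NotImplementedError for complex inputs.
a = torch.randn(4, 4, dtype=torch.float64)
a = 0.5 * (a + a.transpose(-2, -1))
da = torch.randn(4, 4, dtype=torch.float64)
da = 0.5 * (da + da.transpose(-2, -1))

with fwAD.dual_level():
    dual_a = fwAD.make_dual(a, da)
    L, V = torch.linalg.eigh(dual_a)
    # unpack_dual returns (primal, tangent); the tangents come from the new
    # eigh_jvp_eigenvalues / eigh_jvp_eigenvectors rules added in this commit.
    L_primal, L_tangent = fwAD.unpack_dual(L)
    V_primal, V_tangent = fwAD.unpack_dual(V)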

File tree

4 files changed: +74, -3 lines changed


tools/autograd/derivatives.yaml

Lines changed: 5 additions & 0 deletions
@@ -479,6 +479,7 @@

 - name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
   self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2)
+  result: auto_linear

 - name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
   self: norm_backward(grad, self - other, p, result)

@@ -579,10 +580,12 @@

 - name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
   self: zeros_like(grad)
+  result: self_t.fill_(0)

 - name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
   self: zeros_like(grad)
   value: grad.sum()
+  result: self_t.fill_(value_t)

 - name: floor(Tensor self) -> Tensor
   self: zeros_like(grad)

@@ -1338,6 +1341,8 @@

 - name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
   self: eigh_backward(grads, self, /*eigenvectors=*/true, eigenvalues, eigenvectors)
+  eigenvalues: eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors)
+  eigenvectors: eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors)

 - name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
   self: linalg_eig_backward(grads, self, eigenvalues, eigenvectors)
torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 48 additions & 0 deletions
@@ -2409,6 +2409,54 @@ Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads,
   }
 }

+// jvp functions for eigenvalues and eigenvectors are separate
+// because currently forward AD only works with one rule per output
+Tensor eigh_jvp_eigenvalues(
+    const Tensor& input_tangent,
+    const Tensor& eigenvalues,
+    const Tensor& eigenvectors) {
+  // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation
+  // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
+  // Section 3.1 Eigenvalues and eigenvectors
+
+  // TODO: gradcheck from test_ops.py hangs with complex inputs
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      !input_tangent.is_complex(),
+      "the derivative for 'eigh' with complex inputs is not implemented.");
+
+  // see the note in the implementation of eigh_backward that tangent should be Hermitian
+  auto hermitian_tangent = 0.5 * (input_tangent + input_tangent.transpose(-2, -1).conj());
+
+  auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
+  auto eigenvalues_tangent = tmp.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
+  if (eigenvalues_tangent.is_complex()) {
+    return at::real(eigenvalues_tangent);
+  }
+  return eigenvalues_tangent;
+}
+
+Tensor eigh_jvp_eigenvectors(
+    const Tensor& input_tangent,
+    const Tensor& eigenvalues,
+    const Tensor& eigenvectors) {
+  // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation
+  // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
+  // Section 3.1 Eigenvalues and eigenvectors
+
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      !input_tangent.is_complex(),
+      "the derivative for 'eigh' with complex inputs is not implemented.");
+
+  auto E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1);
+  E.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
+
+  // see the note in the implementation of eigh_backward that tangent should be Hermitian
+  auto hermitian_tangent = 0.5 * (input_tangent + input_tangent.transpose(-2, -1).conj());
+
+  auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
+  return at::matmul(eigenvectors, tmp.div(E));
+}
+
 Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                      bool eigenvectors, const Tensor& L, const Tensor& V) {
   // This function is used for both torch.symeig and torch.linalg.eigh.
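The formulas above follow Section 3.1 of the report cited in the comments: dL = diag(V^H dA V) and dV = V((V^H dA V) / E), where E[i][j] = l_j - l_i off the diagonal. A small NumPy sketch (illustration only, not part of the commit; the eigh_jvp helper and the finite-difference probe are made up here) that mirrors the two functions for a real symmetric matrix:

import numpy as np

def eigh_jvp(da, eigenvalues, eigenvectors):
    # Mirror of eigh_jvp_eigenvalues / eigh_jvp_eigenvectors for real symmetric input.
    da = 0.5 * (da + da.T)                            # eigh only sees the symmetric part
    tmp = eigenvectors.T @ da @ eigenvectors          # V^T dA V
    d_eigenvalues = np.diag(tmp).copy()               # dL[i] = v_i^T dA v_i
    E = eigenvalues[None, :] - eigenvalues[:, None]   # E[i, j] = l_j - l_i
    np.fill_diagonal(E, np.inf)                       # diagonal term contributes 0
    d_eigenvectors = eigenvectors @ (tmp / E)
    return d_eigenvalues, d_eigenvectors

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
A = 0.5 * (A + A.T)
dA = rng.standard_normal((4, 4))
dA = 0.5 * (dA + dA.T)

L, V = np.linalg.eigh(A)
dL, dV = eigh_jvp(dA, L, V)

# Finite-difference probe of the eigenvalue tangents.
eps = 1e-6
L_eps, _ = np.linalg.eigh(A + eps * dA)
print(np.allclose((L_eps - L) / eps, dL, atol=1e-4))  # expected: True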

torch/csrc/autograd/FunctionsManual.h

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,8 @@ Tensor slice_backward_wrapper(
     int64_t step);
 Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                            const Tensor& L, const Tensor& V);
+Tensor eigh_jvp_eigenvectors(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
+Tensor eigh_jvp_eigenvalues(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
 Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                      bool eigenvectors, const Tensor& L, const Tensor& V);
 std::tuple<Tensor, Tensor> triangular_solve_backward(

torch/testing/_internal/common_methods_invocations.py

Lines changed: 19 additions & 3 deletions
@@ -6583,6 +6583,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
     OpInfo('diagonal',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_diagonal_diag_embed),
     OpInfo('eq',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),

@@ -6969,16 +6970,25 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            aten_name='linalg_eigh',
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_linalg_eigh,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck for complex hangs for this function, therefore it raises NotImplementedError for now
+               SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
+           ),
     OpInfo('linalg.eigvalsh',
            aten_name='linalg_eigvalsh',
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
            sample_inputs_func=sample_inputs_linalg_eigh,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],),
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck hangs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD'),),
+           ),
     OpInfo('linalg.householder_product',
            aten_name='linalg_householder_product',
            op=torch.linalg.householder_product,

@@ -8429,7 +8439,11 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            check_batched_gradgrad=False,
            sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck hangs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD'),),
+           ),
     OpInfo('eig',
            op=torch.eig,
            dtypes=floating_and_complex_types(),

@@ -8448,6 +8462,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half,
                                                                 *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_einsum,
            skips=(
                # test does not work with passing lambda for op

@@ -8877,6 +8892,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            op=lambda x, scalar: torch.fill_(x.clone(), scalar),
            method_variant=None,
            inplace_variant=torch.Tensor.fill_,
+           supports_forward_ad=True,
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
            skips=(
