Error when composing jit.script and jacrev #170680

Status: Open
Labels: oncall: jit
Author: @Giodiro

Description

🐛 Describe the bug

I think there is a bug when combining a very simple function (minimal_fn) with torch.jit.script and then torch.func.jacrev to compute its Jacobian.
Strangely, the error only occurs on the third call to the Jacobian function (the first and second calls work fine), and only when minimal_fn contains three or more additions.
Apart from this case, the jit + jacrev combination mostly works fine, even for much more complex functions; it is the sum of three terms that chokes it.

MRE:

import torch

device = "cuda:0"

def minimal_fn(x: torch.Tensor) -> torch.Tensor:
    return x + x + x
    # return x + x  # with two summands no error

def bug_report():
    x = torch.ones(128, device=device)
    jitted_fn = torch.jit.script(minimal_fn)
    jacfn = torch.func.jacrev(jitted_fn, argnums=0)
    out = jacfn(x)
    print("Run 1 jitted OK.")
    out = jacfn(x)
    print("Run 2 jitted OK.")
    out = jacfn(x)  # error thrown here
    print("Run 3 jitted OK.")

if __name__ == "__main__":
    bug_report()

Output and error trace:

Run 1 jitted OK.
Run 2 jitted OK.
Traceback (most recent call last):
  File "/home/gmeanti/franken/test_jfwd.py", line 22, in <module>
    test()
  File "/home/gmeanti/franken/test_jfwd.py", line 18, in test
    out = jacfn(x)
          ^^^^^^^^
  File "/srv/storage/thoth1@storage4.grenoble.grid5000.fr/gmeanti/conda/torch/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 570, in wrapper_fn
    vjp_out = _vjp_with_argnums(func, *args, argnums=argnums, has_aux=has_aux)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/srv/storage/thoth1@storage4.grenoble.grid5000.fr/gmeanti/conda/torch/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 47, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/srv/storage/thoth1@storage4.grenoble.grid5000.fr/gmeanti/conda/torch/lib/python3.12/site-packages/torch/_functorch/eager_transforms.py", line 358, in _vjp_with_argnums
    primals_out = func(*primals)
                  ^^^^^^^^^^^^^^
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
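For comparison (this check is not in the original report), calling jacrev on the plain, unscripted function should never reach the TorchScript interpreter at all. A minimal sanity-check sketch, reusing minimal_fn and device from the MRE above, under the assumption that the eager path is unaffected:

import torch

device = "cuda:0"

def minimal_fn(x: torch.Tensor) -> torch.Tensor:
    return x + x + x

def eager_check():
    # Assumption (not from the report): with torch.jit.script skipped, no
    # TorchScript interpreter is involved, so repeated jacrev calls should pass.
    x = torch.ones(128, device=device)
    jacfn = torch.func.jacrev(minimal_fn, argnums=0)
    for i in range(3):
        out = jacfn(x)
        print(f"Run {i + 1} eager OK.")

if __name__ == "__main__":
    eager_check()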

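The failure surfacing only on the third call is consistent with TorchScript's profiling executor, which profiles the first couple of runs and then swaps in an optimized graph. This hypothesis is mine, not confirmed in the report; a speculative way to probe it is the private torch._C._jit_set_profiling_executor switch (an internal API that may change between releases):

import torch

# Speculative diagnostic (assumption, not from the report): disable the
# profiling executor so no optimized graph is swapped in after the profiled runs.
torch._C._jit_set_profiling_executor(False)

device = "cuda:0"

def minimal_fn(x: torch.Tensor) -> torch.Tensor:
    return x + x + x

jitted_fn = torch.jit.script(minimal_fn)
jacfn = torch.func.jacrev(jitted_fn, argnums=0)
x = torch.ones(128, device=device)
for i in range(4):
    out = jacfn(x)  # if the optimized graph is at fault, all runs should now pass
    print(f"Run {i + 1} jitted OK.")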
Versions

PyTorch version: 2.9.1+cu126
Is debug build: False
CUDA used to build PyTorch: 12.6
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 11 (bullseye) (x86_64)
GCC version: (Debian 10.2.1-6) 10.2.1 20210110
Clang version: Could not collect
CMake version: version 3.18.4
Libc version: glibc-2.31

Python version: 3.12.11 | packaged by Anaconda, Inc. | (main, Jun 5 2025, 13:09:17) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.10.0-36-amd64-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.2.152
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA RTX A5000
Nvidia driver version: 535.183.06
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 16
On-line CPU(s) list: 0-15
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 2
NUMA node(s): 2
Vendor ID: GenuineIntel
CPU family: 6
Model: 63
Model name: Intel(R) Xeon(R) CPU E5-2623 v3 @ 3.00GHz
Stepping: 2
CPU MHz: 3441.994
CPU max MHz: 3500.0000
CPU min MHz: 1200.0000
BogoMIPS: 6000.24
Virtualization: VT-x
L1d cache: 256 KiB
L1i cache: 256 KiB
L2 cache: 2 MiB
L3 cache: 20 MiB
NUMA node0 CPU(s): 0,2,4,6,8,10,12,14
NUMA node1 CPU(s): 1,3,5,7,9,11,13,15
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Not affected
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm xsaveopt cqm_llc cqm_occup_llc dtherm ida arat pln pts md_clear flush_l1d ibpb_exit_to_user

Versions of relevant libraries:
[pip3] botorch==0.16.1
[pip3] gpytorch==1.14.3
[pip3] mace-torch==0.3.14
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.1.2
[pip3] nvidia-cublas-cu12==12.6.4.1
[pip3] nvidia-cuda-cupti-cu12==12.6.80
[pip3] nvidia-cuda-nvrtc-cu12==12.6.77
[pip3] nvidia-cuda-runtime-cu12==12.6.77
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cufft-cu12==11.3.0.4
[pip3] nvidia-curand-cu12==10.3.7.77
[pip3] nvidia-cusolver-cu12==11.7.1.2
[pip3] nvidia-cusparse-cu12==12.5.4.2
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.6.85
[pip3] nvidia-nvtx-cu12==12.6.77
[pip3] pytorch-lightning==2.5.6
[pip3] torch==2.9.1+cu126
[pip3] torch-ema==0.3
[pip3] torchdiffeq==0.2.5
[pip3] torchmetrics==1.8.2
[pip3] torchvision==0.24.1+cu126
[pip3] triton==3.5.1
[conda] botorch 0.16.1 pypi_0 pypi
[conda] gpytorch 1.14.3 pypi_0 pypi
[conda] mace-torch 0.3.14 pypi_0 pypi
[conda] numpy 2.1.2 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.6.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.6.80 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.6.77 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.0.4 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.7.77 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.1.2 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.4.2 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.6.85 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.6.77 pypi_0 pypi
[conda] pytorch-lightning 2.5.6 pypi_0 pypi
[conda] torch 2.9.1+cu126 pypi_0 pypi
[conda] torch-ema 0.3 pypi_0 pypi
[conda] torchdiffeq 0.2.5 pypi_0 pypi
[conda] torchmetrics 1.8.2 pypi_0 pypi
[conda] torchvision 0.24.1+cu126 pypi_0 pypi
[conda] triton 3.5.1 pypi_0 pypi

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel
