🐛 Describe the bug
This issue reports a numerical inconsistency between the eager backend and the Inductor backend.
For aten.cosine_similarity with bfloat16 inputs, the output produced by Inductor is not numerically consistent with eager execution, exceeding the default tolerances of torch.testing.assert_close.
This indicates a correctness issue rather than an expected precision difference.
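For context on why this is likely more than rounding noise: assert_close's default relative tolerance for bfloat16 is 1.6e-2 (visible in the traceback below as "up to 0.016 allowed"), which roughly matches bfloat16's ~2**-7 unit roundoff (the format keeps 7 explicit significand bits). The reported relative difference of 0.245 is more than an order of magnitude beyond that. A minimal sketch of bfloat16 rounding error, emulated here by truncating the float32 bit pattern (truncation slightly overstates round-to-nearest error, which is fine for a magnitude check; this helper is illustrative, not PyTorch's conversion code):

```python
import struct

def to_bfloat16(x: float) -> float:
    # Emulate bfloat16 by zeroing the low 16 bits of the float32
    # representation (bfloat16 is the top 16 bits of float32).
    bits = struct.unpack('>I', struct.pack('>f', x))[0]
    return struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))[0]

x = 1.2345
rel_err = abs(to_bfloat16(x) - x) / x
# rel_err stays below 2**-7 (~0.0078), far below the 0.245 reported here
print(rel_err < 2 ** -7)
```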
Result

```
Traceback (most recent call last):
  File "/data/shenqingchao/tritonFuzz/src/../results/12-17_19-25/res_executions/seed_test_cosine_similarity_1.py", line 12, in <module>
    torch.testing.assert_close(out1, out_inductor)
  File "/home/shenqingchao/miniconda3/lib/python3.12/site-packages/torch/testing/_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 3 / 30 (10.0%)
Greatest absolute difference: 0.0012359619140625 at index (0, 4) (up to 1e-05 allowed)
Greatest relative difference: 0.2451171875 at index (0, 4) (up to 0.016 allowed)
```

Reproduction script:

```python
import torch

def model_func(x1, x2, dim, eps):
    out = torch.ops.aten.cosine_similarity(x1, x2=x2, dim=dim, eps=eps)
    return out

op_config = {
    'x1': torch.randn([5, 6, 7], dtype=torch.bfloat16, device='cuda') * 0.1,
    'x2': torch.randn([5, 6, 7], dtype=torch.bfloat16, device='cuda') * 0.1,
    'dim': 2,
    'eps': 1e-05,
}

compiled_eager = torch.compile(model_func, backend="eager")
out1 = compiled_eager(**op_config)

compiled_inductor = torch.compile(
    model_func,
    backend="inductor",
    options={"trace.enabled": True},
)
out_inductor = compiled_inductor(**op_config)

torch.testing.assert_close(out1, out_inductor)
```
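For reference, per the torch.nn.functional.cosine_similarity documentation the operation computes dot(x1, x2) / max(‖x1‖₂ · ‖x2‖₂, eps) along the reduced dimension. A pure-Python sketch of that formula for 1-D vectors (a hypothetical helper, not PyTorch's actual kernel) can serve as a high-precision float64 reference when triaging which backend is off:

```python
import math

def cosine_similarity_ref(a, b, eps=1e-5):
    # dot(a, b) / max(||a||_2 * ||b||_2, eps), computed in Python floats
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)

print(cosine_similarity_ref([1.0, 2.0], [1.0, 2.0]))  # parallel vectors, ~1.0
print(cosine_similarity_ref([1.0, 0.0], [0.0, 1.0]))  # orthogonal vectors, 0.0
```

Comparing both backends' outputs against such a reference (rather than only against each other) would show whether eager or Inductor is the one deviating.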
Versions
```
[pip3] mypy_extensions==1.1.0
[pip3] numpy==2.2.1
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] torch==2.5.1
[pip3] triton==3.1.0
[conda] numpy                     2.2.1          pypi_0  pypi
[conda] nvidia-cublas-cu12        12.4.5.8       pypi_0  pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127       pypi_0  pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127       pypi_0  pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127       pypi_0  pypi
[conda] nvidia-cudnn-cu12         9.1.0.70       pypi_0  pypi
[conda] nvidia-cufft-cu12         11.2.1.3       pypi_0  pypi
[conda] nvidia-curand-cu12        10.3.5.147     pypi_0  pypi
[conda] nvidia-cusolver-cu12      11.6.1.9       pypi_0  pypi
[conda] nvidia-cusparse-cu12      12.3.1.170     pypi_0  pypi
[conda] nvidia-cusparselt-cu12    0.7.1          pypi_0  pypi
[conda] nvidia-nccl-cu12          2.21.5         pypi_0  pypi
[conda] nvidia-nvjitlink-cu12     12.4.127       pypi_0  pypi
[conda] nvidia-nvtx-cu12          12.4.127       pypi_0  pypi
[conda] torch                     2.5.1          pypi_0  pypi
[conda] triton                    3.1.0          pypi_0  pypi
```

cc @ezyang @gchanan @kadeng @msaroufim @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @muchulee8 @amjames @aakhundov @coconutruben @jataylo