Hi,
I recently observed the following sanity test failures when running with PyTorch 2.4.0 + CUDA 12.4 + cuDNN 9.1.0.
=================================================================================================== short test summary info ====================================================================================================
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[True-LayerNorm-126m-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 359 with 0.2375284731388092 vs 0.23826351761817932 (diff 0.0007350444793701172).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[True-LayerNorm-126m-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 2663 with 0.2375284731388092 vs 0.23826351761817932 (diff 0.0007350444793701172).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[True-RMSNorm-126m-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 614 with 0.12908415496349335 vs 0.12974251806735992 (diff 0.0006583631038665771).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[True-RMSNorm-126m-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 1602 with 0.10864763706922531 vs 0.10934028774499893 (diff 0.0006926506757736206).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[False-LayerNorm-126m-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 359 with 0.2375284731388092 vs 0.23826351761817932 (diff 0.0007350444793701172).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[False-LayerNorm-126m-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 2663 with 0.2375284731388092 vs 0.23826351761817932 (diff 0.0007350444793701172).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[False-RMSNorm-126m-1-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 614 with 0.12908415496349335 vs 0.12974251806735992 (diff 0.0006583631038665771).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_linear_accuracy[False-RMSNorm-126m-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=0. Location of the maximum difference: 1602 with 0.10864763706922531 vs 0.10934028774499893 (diff 0.0006926506757736206).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_mlp_accuracy[LayerNorm-srelu-126m-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=1152. Location of the maximum difference: 608 with 21.20404052734375 vs 21.217058181762695 (diff 0.013017654418945312).
FAILED TransformerEngine/tests/pytorch/test_numerics.py::test_layernorm_mlp_accuracy[RMSNorm-srelu-126m-2-dtype0] - AssertionError: Outputs not close enough in tensor at idx=1152. Location of the maximum difference: 608 with 21.316505432128906 vs 21.329757690429688 (diff 0.01325225830078125).
============================================================================= 10 failed, 477 passed, 80 skipped, 663 warnings in 98.25s (0:01:38) =============================================================================
This is running on a single AWS p4d.24xlarge instance with A100 GPUs, inside a Docker container.
It seems that our numerical tolerances are too tight for TF32 compute. It is telling that only the FP32 tests are failing and that the errors in test_layernorm_linear_accuracy are close to the machine epsilon of TF32 (~5e-4). Also, TE configures GEMMs with FP32 data to run with TF32 compute.
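For illustration, here is a minimal sketch (plain PyTorch, not TE code) showing that toggling TF32 compute for an FP32 matmul produces differences on the same order as the failures above:

```python
import torch

# Minimal sketch (not TE code): compare an FP32 GEMM against the same GEMM
# run with TF32 compute. The maximum relative difference is typically on the
# order of TF32's machine epsilon (~5e-4), matching the failures above.
torch.manual_seed(0)
a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = False
ref = a @ b  # full FP32 accumulation

torch.backends.cuda.matmul.allow_tf32 = True
out = a @ b  # TF32 tensor-core compute

print((out - ref).abs().max().item() / ref.abs().max().item())
```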
I think we didn't notice this before because NVIDIA PyTorch containers enable TF32 by default, so we were using the same cuBLAS kernels in the TE module and the PyTorch reference. However, vanilla PyTorch disables TF32 by default.
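For reference, the relevant switches can be checked and set as below; these are the standard torch.backends flags, nothing TE-specific:

```python
import torch

# Stock PyTorch wheels default to allow_tf32=False for cuBLAS matmuls (since
# PyTorch 1.12), while NVIDIA NGC PyTorch containers enable TF32 by default.
print(torch.backends.cuda.matmul.allow_tf32)  # False on vanilla PyTorch
print(torch.backends.cudnn.allow_tf32)        # True by default (cuDNN convolutions)

# Opting in makes the PyTorch reference use the same TF32 cuBLAS kernels as TE:
torch.backends.cuda.matmul.allow_tf32 = True
```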
We have recently done some work to make our numerical testing more robust. In particular, #1229 reduces the size of the test cases. We should follow up by tweaking the tolerances based on the data and compute dtypes.
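As a rough sketch of what dtype-aware tolerances could look like (the table values and helper below are illustrative assumptions, not the actual test-suite changes):

```python
import torch

# Illustrative tolerance table keyed by data dtype; the "FP32 data + TF32
# compute" entry is loosened to a few times TF32's machine epsilon.
TOLS = {
    torch.float32:  dict(rtol=1e-5, atol=1e-6),    # FP32 data, FP32 compute
    "fp32_tf32":    dict(rtol=2e-3, atol=1e-3),    # FP32 data, TF32 GEMMs
    torch.float16:  dict(rtol=1e-3, atol=1e-3),
    torch.bfloat16: dict(rtol=1.6e-2, atol=1e-2),
}

def assert_close(test, ref, dtype, tf32_compute=False):
    """Hypothetical helper: pick tolerances from the data and compute dtypes."""
    key = "fp32_tf32" if (dtype is torch.float32 and tf32_compute) else dtype
    torch.testing.assert_close(test, ref, **TOLS[key])
```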