diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py index 86f5b0aabe72..dd545a778578 100644 --- a/python/mxnet/amp/lists/symbol_bf16.py +++ b/python/mxnet/amp/lists/symbol_bf16.py @@ -362,6 +362,7 @@ 'zeros_like', '_sg_onednn_conv', '_sg_onednn_fully_connected', + '_sg_onednn_batch_dot', 'broadcast_mul', 'Convolution_v1', 'IdentityAttachKLSparseReg', diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index b561b335d9a7..b2a0a9d90a3c 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ b/python/mxnet/amp/lists/symbol_fp16.py @@ -615,6 +615,7 @@ '_sg_onednn_fully_connected', '_sg_onednn_selfatt_qk', '_sg_onednn_selfatt_valatt', + '_sg_onednn_batch_dot' ]) # Functions that have to be cast to FP32 only for