diff --git a/python/dgl/backend/pytorch/sparse.py b/python/dgl/backend/pytorch/sparse.py index 67034ad99579..cea390020a02 100644 --- a/python/dgl/backend/pytorch/sparse.py +++ b/python/dgl/backend/pytorch/sparse.py @@ -145,7 +145,7 @@ def __exit__(self, *args, **kargs): # and do it only in a nested autocast context. def _disable_autocast_if_enabled(): if th.is_autocast_enabled(): - return th.cuda.amp.autocast(enabled=False) + return th.amp.autocast("cuda", enabled=False) else: return empty_context() @@ -154,8 +154,8 @@ def _cast_if_autocast_enabled(*args): if not th.is_autocast_enabled(): return args else: - return th.cuda.amp.autocast_mode._cast( - args, th.get_autocast_gpu_dtype() + return th.amp.autocast_mode._cast( + args, "cuda", th.get_autocast_gpu_dtype() )