Skip to content

Commit

Permalink
Update optim.grad_scaler to use torch.amp
Browse files Browse the repository at this point in the history
Co-authored-by: Luciferian Ink <LuciferianInk@protonmail.com>
  • Loading branch information
mryab committed Jun 9, 2024
1 parent 232c6b7 commit 8859194
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions hivemind/optim/grad_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@
from typing import Dict, Optional

import torch
from packaging import version

# Strip any local build suffix (e.g. "2.3.0+cu121" -> "2.3.0") so the
# version comparison below sees a plain PEP 440 version string.
torch_version = torch.__version__.split("+")[0]

# The device-agnostic torch.amp.GradScaler (and the torch.amp.grad_scaler
# submodule that exposes OptState / _refresh_per_optimizer_state) only
# exists since PyTorch 2.3.0 — NOT 1.12; on 1.12-2.2 importing it raises
# ImportError. Older versions fall back to the CUDA-specific class, which
# torch >= 2.4 deprecates in favor of torch.amp.
if version.parse(torch_version) >= version.parse("2.3.0"):
    from torch.amp import GradScaler as TorchGradScaler
    from torch.amp.grad_scaler import OptState, _refresh_per_optimizer_state
else:
    from torch.cuda.amp import GradScaler as TorchGradScaler
    from torch.cuda.amp.grad_scaler import OptState, _refresh_per_optimizer_state

from torch.optim import Optimizer as TorchOptimizer

import hivemind
Expand Down

0 comments on commit 8859194

Please sign in to comment.