Speed monitor refactor (mosaicml#1987)
* add speed monitor refactor

* fix docs

* fix tests

* fix remove 1

* extend test

* format

* respond to comments

* restore caching

* add deepcopy

* add comment
mvpatel2000 authored and bmosaicml committed Mar 2, 2023
1 parent e8fb131 commit e471ca5
Showing 1 changed file with 29 additions and 8 deletions.
37 changes: 29 additions & 8 deletions composer/optim/decoupled_weight_decay.py
@@ -225,9 +225,9 @@ def __init__(self,

@staticmethod
def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[torch.Tensor],
- exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int], *,
- amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float, weight_decay: float,
- eps: float) -> None:
+ exp_avg_sqs: List[torch.Tensor], max_exp_avg_sqs: List[torch.Tensor], state_steps: List[int],
+ masks_on: List[bool], *, amsgrad: bool, beta1: float, beta2: float, lr: float, initial_lr: float,
+ weight_decay: float, eps: float) -> None:
r"""Functional API that performs AdamW algorithm computation with decoupled weight decay.
Args:
@@ -250,6 +250,7 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
exp_avg = exp_avgs[i]
exp_avg_sq = exp_avg_sqs[i]
step = state_steps[i]
+ mask_on = masks_on[i]

# Perform stepweight decay
if weight_decay != 0:
@@ -259,20 +260,32 @@ def adamw(params: List[torch.Tensor], grads: List[torch.Tensor], exp_avgs: List[
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step

+ # mask out any params from the moment that point in the opposite direction of
+ # the grad
+ if mask_on:
+ update = exp_avg.sign().mul_(grad.sign()).sign_().clamp_(0)
+ update.mul_(exp_avg).mul_(beta1).add_(grad, alpha=1 - beta1)
+ else:
+ update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)

# Decay the first and second moment running average coefficient
- exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
# Use the max. for normalizing running avg. of gradient
- denom = (max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps)
+ update.div_((max_exp_avg_sqs[i].sqrt() / math.sqrt(bias_correction2)).add_(eps))
else:
- denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+ update.div_((exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps))

step_size = lr / bias_correction1

- param.addcdiv_(exp_avg, denom, value=-step_size)
+ param.add_(update, alpha=-step_size)

+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

+ def turn_on_masking(self, param):
+ self.state[param]['mask'] = True

@torch.no_grad()
def step(self, closure=None):
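
The hunk above replaces the stock AdamW numerator with a sign-agreement mask: entries of `exp_avg` whose sign disagrees with the current gradient are zeroed before the parameter update, and `exp_avg` itself is only refreshed with the raw gradient after the step. Below is a minimal single-tensor sketch of that path; it is illustrative only: the helper name `masked_adamw_step`, the single-tensor interface, the default hyperparameters, and the omission of weight decay and amsgrad are assumptions, not part of this commit.

```python
import math

import torch


def masked_adamw_step(param, grad, exp_avg, exp_avg_sq, step,
                      lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Illustrative single-tensor version of the masked branch above
    (no weight decay, no amsgrad)."""
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    # 1 where the running first moment and the gradient agree in sign, 0 otherwise.
    mask = exp_avg.sign().mul_(grad.sign()).sign_().clamp_(0)
    update = mask.mul_(exp_avg).mul_(beta1).add_(grad, alpha=1 - beta1)

    # The second moment is decayed exactly as in stock AdamW.
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    update.div_((exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps))

    param.add_(update, alpha=-(lr / bias_correction1))

    # The first moment is refreshed with the unmasked gradient after the step,
    # mirroring the reordering in the diff.
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
```
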
@@ -292,6 +305,7 @@ def step(self, closure=None):
grads = []
exp_avgs = []
exp_avg_sqs = []
+ masks_on = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
@@ -326,6 +340,7 @@ def step(self, closure=None):

exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
+ masks_on.append('mask' in state and state['mask'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])

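The `masks_on` flag gathered above is read from per-parameter optimizer state, which `turn_on_masking` populates via `state[param]['mask'] = True`; parameters it was never called on default to unmasked. A hypothetical sketch of enabling it follows. The `model` object, the constructor arguments, and the choice to mask only matrix-shaped weights are placeholders, and how this call should be ordered relative to optimizer state initialization is not visible in these hunks.

```python
# Hypothetical wiring, not part of this commit: enable the sign-agreement mask
# only for 2-D (matrix) parameters; `model` and the hyperparameters are placeholders.
optimizer = DecoupledAdamW(model.parameters(), lr=1e-3, weight_decay=1e-5)
for name, p in model.named_parameters():
    if p.dim() >= 2:
        optimizer.turn_on_masking(p)
```
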
@@ -340,6 +355,7 @@ def step(self, closure=None):
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
+ masks_on,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
@@ -414,7 +430,12 @@ def report_per_parameter_metrics(self, param: torch.Tensor, name: str, optimizer
bias_correction2 = 1 - beta2**step
denom = (param_optim_state['exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
- step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
+ if 'mask' in param_optim_state and param_optim_state['mask']:
+ step_tensor = param_optim_state['exp_avg'].sign().mul_(param.grad.sign()).sign_().clamp_(0)
+ step_tensor.mul_(param_optim_state['exp_avg']).mul_(beta1).add_(param.grad, alpha=1 - beta1)
+ step_tensor = step_size * step_tensor.div(denom)
+ else:
+ step_tensor = step_size * param_optim_state['exp_avg'].div(denom)
decay_factor = (lr / initial_lr) if initial_lr else 1.0
step_tensor.add_(param, alpha=-weight_decay * decay_factor)
for metric in self.metric_functions:
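
For concreteness, the mask recomputed for logging in this hunk only ever contains zeros and ones; a toy example with made-up values:

```python
import torch

exp_avg = torch.tensor([0.5, -0.2, 0.0])
grad = torch.tensor([1.0, 0.3, 0.7])

# Same expression as in the hunk above: 1 where moment and gradient agree in sign,
# 0 where they disagree or the moment entry is zero.
mask = exp_avg.sign().mul_(grad.sign()).sign_().clamp_(0)
print(mask)  # tensor([1., 0., 0.])
```
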
