Skip to content

Commit

Permalink
Speedup EMA
Browse files Browse the repository at this point in the history
  • Loading branch information
flytocc committed Mar 1, 2023
1 parent 4da55fe commit 35fc732
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions ppcls/utils/ema.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ def __init__(self, model, decay=0.9999):

@paddle.no_grad()
def _update(self, model, update_fn):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.set_value(update_fn(ema_v, model_v))
for ema_v, model_v in zip(self.module.state_dict().values(),
model.state_dict().values()):
ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy()))

def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
self._update(
model,
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

def set(self, model):
self._update(model, update_fn=lambda e, m: m)
self._update(model, update_fn=lambda e, m: m)

0 comments on commit 35fc732

Please sign in to comment.