Avoid loading a different frequency in rotary embedding
jmercat committed May 16, 2024
1 parent: fde6e25 · commit: 9a92986
Showing 1 changed file with 6 additions and 1 deletion.
open_lm/positional_embedding/rotary.py (6 additions, 1 deletion)

@@ -57,6 +57,11 @@ def __init__(self, dim_model: int, seq_len: int, frequency: float = 10000, *_, *
         self.frequency = frequency
         self.reset_parameters()

+    def load_state_dict(self, state_dict, strict=True):
+        # The state dict is not used, as the parameters are not trainable
+        # We want to avoid loading the inv_freq buffer in case the frequency is different
+        pass
+
     def reset_parameters(self):
         self.inv_freq = 1.0 / (self.frequency ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
         self._update_cos_sin_tables(self.seq_len)

@@ -72,7 +77,7 @@ def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None
         if seq_len > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
             self._seq_len_cached = seq_len
             t = torch.arange(seq_len, device=device, dtype=torch.float32)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(dtype))
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(device=device, dtype=dtype))
             emb = torch.cat((freqs, freqs), dim=-1).to(device)

             self._cos_cached = emb.cos()[None, :, None, :].to(dtype)
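
Why the no-op load_state_dict works: inv_freq is fully determined by the frequency argument passed at construction, so restoring it from a checkpoint that was saved with a different frequency would silently overwrite the newly configured value. The sketch below illustrates the effect with a hypothetical, simplified stand-in (ToyRotary is illustration only, not the actual open_lm class):

import torch


class ToyRotary(torch.nn.Module):
    """Simplified stand-in for a rotary embedding module (illustration only)."""

    def __init__(self, dim_model: int, frequency: float = 10000.0):
        super().__init__()
        # inv_freq is derived from `frequency`, so it is neither trained nor loaded.
        inv_freq = 1.0 / (frequency ** (torch.arange(0, dim_model, 2).float() / dim_model))
        self.register_buffer("inv_freq", inv_freq)

    def load_state_dict(self, state_dict, strict=True):
        # Deliberately a no-op: a checkpoint saved with a different `frequency`
        # would otherwise clobber the inv_freq computed in __init__.
        pass


old = ToyRotary(dim_model=8, frequency=10000.0)
new = ToyRotary(dim_model=8, frequency=500000.0)
new.load_state_dict(old.state_dict())  # skipped, so the new frequency survives
assert not torch.allclose(new.inv_freq, old.inv_freq)

The override applies when load_state_dict is called on this module directly, as in the example.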

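The one-line change in the second hunk is a companion device fix: the old code only converted the dtype of inv_freq, so if the cached tables were rebuilt after the model had moved to a GPU, t (created on device) and inv_freq (potentially still on the CPU) would meet in the einsum and raise a device-mismatch error. Passing device=device along with dtype=dtype puts both operands on the same device first.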