Updates nanoGPT to not store certain tensors as members
pranavm-nvidia committed Dec 6, 2024
1 parent c15f675 commit af3a5e1
Showing 3 changed files with 15 additions and 16 deletions.
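
In practice, the commit replaces tensors that were stored on modules (the attention-mask zeros and the precomputed position indices) with values that are computed, or compared against literals, inside __call__. A minimal before/after sketch of the pattern, using illustrative names and shapes rather than the actual model code:

import tripy as tp  # assumed import alias; the example refers to the library as `tp`


class Before(tp.Module):
    def __init__(self, seq_len):
        super().__init__()
        # A constant tensor stored as a member solely to compare against later.
        self.zeros = tp.zeros((1, seq_len), dtype=tp.float32)

    def __call__(self, x):
        return tp.masked_fill(x, x == self.zeros, float("-inf"))


class After(tp.Module):
    def __init__(self, seq_len):
        super().__init__()
        self.seq_len = seq_len  # plain Python values are fine to keep; no tensor members

    def __call__(self, x):
        # Compare against a literal zero instead of a stored zeros tensor.
        return tp.masked_fill(x, x == 0, float("-inf"))
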
tripy/examples/nanogpt/example.py (2 changes: 1 addition & 1 deletion)
@@ -103,7 +103,7 @@ def main():
         1,
         # We can specify dynamic dimensions by using a sequence indicating the min/opt/max values that
         # a dimension should support:
-        [1, len(input_ids), padded_seq_len],
+        (1, len(input_ids), padded_seq_len),
     )
     model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.int32)])
     compile_end_time = time.perf_counter()
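
For context, the tuple above is the (min, opt, max) specification for the dynamic sequence dimension that feeds into tp.compile. A minimal sketch of how these pieces fit together, with placeholder values standing in for the model and inputs that main() builds earlier:

import tripy as tp  # assumed import alias, as used throughout the example

padded_seq_len = 256          # placeholder; the real value comes from the example's padding logic
input_ids = list(range(100))  # placeholder token ids

input_shape = (
    1,  # fixed batch dimension
    # Dynamic sequence dimension as (min, opt, max): any length from 1 up to
    # padded_seq_len is supported, with len(input_ids) as the size optimized for.
    (1, len(input_ids), padded_seq_len),
)
# With a real `model`, compilation is the line shown in the diff:
# model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.int32)])
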
tripy/examples/nanogpt/model.py (17 changes: 6 additions & 11 deletions)
@@ -67,7 +67,6 @@ def __init__(self, config):
             tp.tril(tp.ones((config.block_size, config.block_size), dtype=config.dtype)),
             (1, 1, config.block_size, config.block_size),
         )
-        self.zeros = tp.zeros((1, 1, self.seq_len, self.seq_len), dtype=config.dtype)

     def __call__(self, x: tp.Tensor):
         B, T = x.shape[0:2]
@@ -87,11 +86,7 @@ def __call__(self, x: tp.Tensor):

         k_t = tp.transpose(k, -2, -1)
         att = (q @ k_t) * (1.0 / math.sqrt(self.embedding_size // self.num_heads))
-        att = tp.masked_fill(
-            att,
-            self.bias[:, :, :T, :T] == self.zeros[:, :, :T, :T],
-            float("-inf"),
-        )
+        att = tp.masked_fill(att, self.bias[:, :, :T, :T] == 0, float("-inf"))

         att = tp.softmax(att, dim=-1)
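
The net effect of the change above: the causal mask is still the tril-of-ones bias, but masked positions are now found by comparing that bias against the literal 0 rather than against a separately stored zeros tensor. A small standalone sketch of the masking pattern (sizes and dtype are illustrative):

import tripy as tp  # assumed import alias

block_size, T = 8, 4  # illustrative sizes; T plays the role of the current sequence length
bias = tp.reshape(
    tp.tril(tp.ones((block_size, block_size), dtype=tp.float32)),
    (1, 1, block_size, block_size),
)
att = tp.ones((1, 1, T, T), dtype=tp.float32)  # stand-in for the q @ k_t attention scores
# Mask out future positions (where the tril bias is 0), then normalize.
att = tp.masked_fill(att, bias[:, :, :T, :T] == 0, float("-inf"))
att = tp.softmax(att, dim=-1)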

@@ -135,18 +130,18 @@ def __call__(self, x):
 class Transformer(tp.Module):
     def __init__(self, config):
         super().__init__()
+        self.seq_len = config.seq_len
         self.wte = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
         self.wpe = tp.Embedding(config.block_size, config.embedding_size, dtype=config.dtype)
-        self.h = [Block(config) for _ in range(config.num_layers)]
+        self.h = tp.Sequential(*[Block(config) for _ in range(config.num_layers)])
         self.ln_f = tp.LayerNorm(config.embedding_size)
-        self.pos = tp.reshape(tp.arange(0, config.seq_len, dtype=tp.int32), (1, config.seq_len))

     def __call__(self, idx):
         tok_emb = self.wte(idx)  # token embeddings of shape (batch_size, seq_len, embedding_size)
-        pos_emb = self.wpe(self.pos[:, : idx.shape[1]])  # position embeddings of shape (seq_len, embedding_size)
+        pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
+        pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, embedding_size)
         x = tok_emb + pos_emb  # (batch_size, seq_len, embedding_size)
-        for block in self.h:
-            x = block(x)
+        x = self.h(x)
         x = tp.cast(self.ln_f(tp.cast(x, self.ln_f.dtype)), x.dtype)
         return x
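
Two simplifications above go together: the Blocks now live in a tp.Sequential, so they remain registered as submodules and run with a single call (x = self.h(x)), and the position indices are rebuilt from tp.arange on every call instead of being stored as self.pos. A tiny sketch of the tp.Sequential part in isolation, using LayerNorm layers purely as stand-ins for the transformer Blocks (and assuming their default dtype matches the float32 input):

import tripy as tp  # assumed import alias

embedding_size = 16
# Modules applied in order; stand-ins for Block(config) from the model above.
stack = tp.Sequential(*[tp.LayerNorm(embedding_size) for _ in range(3)])

x = tp.ones((2, 4, embedding_size), dtype=tp.float32)
y = stack(x)  # equivalent to looping over the modules and applying each one in turn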

tripy/examples/nanogpt/weight_loader.py (12 changes: 8 additions & 4 deletions)
@@ -26,16 +26,20 @@ def load_weights_from_hf(model, model_type, dtype):

     tripy_state_dict = model.state_dict()
     # attention biases are initialized in the model based on block size.
-    tripy_keys = [key for key in tripy_state_dict.keys() if not key.endswith(".attn.bias")]
+    tripy_keys = {key for key in tripy_state_dict.keys() if not key.endswith(".attn.bias")}

     # Load huggingface/transformers model
     model_hf = GPT2LMHeadModel.from_pretrained(model_type)
     hf_state_dict = model_hf.state_dict()
     # We ignore some of the keys in the HF checkpoint:
-    hf_keys = [
+    hf_keys = {
         key for key in hf_state_dict.keys() if not key.endswith(".attn.masked_bias") and not key.endswith(".attn.bias")
-    ]
-    assert len(hf_keys) == len(tripy_keys), f"Mismatched keys: {hf_keys} != {tripy_keys}"
+    }
+    assert hf_keys == tripy_keys, (
+        f"Mismatched keys. Note:\n"
+        f"`hf_keys` extra keys: {hf_keys - tripy_keys}\n"
+        f"`tripy_keys` extra keys: {tripy_keys - hf_keys}"
+    )

     # See https://paperswithcode.com/method/weight-tying for details on why we do this:
     hf_state_dict["transformer.wte.weight"] = hf_state_dict["lm_head.weight"]
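
The switch from comparing list lengths to comparing sets also tightens the check: the old assertion passed whenever the two collections merely had the same size, while the new one compares contents and reports exactly which keys are missing on each side. The same pattern in isolation, with made-up key names (plain Python, no Tripy required):

hf_keys = {"transformer.wte.weight", "transformer.wpe.weight", "lm_head.weight"}
tripy_keys = {"transformer.wte.weight", "transformer.wpe.weight"}

assert hf_keys == tripy_keys, (
    f"Mismatched keys. Note:\n"
    f"`hf_keys` extra keys: {hf_keys - tripy_keys}\n"
    f"`tripy_keys` extra keys: {tripy_keys - hf_keys}"
)
# Raises: AssertionError: Mismatched keys. Note:
# `hf_keys` extra keys: {'lm_head.weight'}
# `tripy_keys` extra keys: set()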