From af3a5e16d8953ad96dccdd1a04061576f07dfadc Mon Sep 17 00:00:00 2001
From: pranavm
Date: Fri, 6 Dec 2024 11:23:23 -0800
Subject: [PATCH] Updates nanoGPT to not store certain tensors as members

---
 tripy/examples/nanogpt/example.py       |  2 +-
 tripy/examples/nanogpt/model.py         | 17 ++++++-----------
 tripy/examples/nanogpt/weight_loader.py | 12 ++++++++----
 3 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/tripy/examples/nanogpt/example.py b/tripy/examples/nanogpt/example.py
index 0041b6426..5c7b79715 100644
--- a/tripy/examples/nanogpt/example.py
+++ b/tripy/examples/nanogpt/example.py
@@ -103,7 +103,7 @@ def main():
             1,
             # We can specify dynamic dimensions by using a sequence indicating the min/opt/max values that
             # a dimension should support:
-            [1, len(input_ids), padded_seq_len],
+            (1, len(input_ids), padded_seq_len),
         )
     model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.int32)])
     compile_end_time = time.perf_counter()
diff --git a/tripy/examples/nanogpt/model.py b/tripy/examples/nanogpt/model.py
index e641674bb..4f839adf9 100644
--- a/tripy/examples/nanogpt/model.py
+++ b/tripy/examples/nanogpt/model.py
@@ -67,7 +67,6 @@ def __init__(self, config):
             tp.tril(tp.ones((config.block_size, config.block_size), dtype=config.dtype)),
             (1, 1, config.block_size, config.block_size),
         )
-        self.zeros = tp.zeros((1, 1, self.seq_len, self.seq_len), dtype=config.dtype)
 
     def __call__(self, x: tp.Tensor):
         B, T = x.shape[0:2]
@@ -87,11 +86,7 @@ def __call__(self, x: tp.Tensor):
         k_t = tp.transpose(k, -2, -1)
         att = (q @ k_t) * (1.0 / math.sqrt(self.embedding_size // self.num_heads))
 
-        att = tp.masked_fill(
-            att,
-            self.bias[:, :, :T, :T] == self.zeros[:, :, :T, :T],
-            float("-inf"),
-        )
+        att = tp.masked_fill(att, self.bias[:, :, :T, :T] == 0, float("-inf"))
 
         att = tp.softmax(att, dim=-1)
 
@@ -135,18 +130,18 @@ def __call__(self, x):
 class Transformer(tp.Module):
     def __init__(self, config):
         super().__init__()
+        self.seq_len = config.seq_len
         self.wte = tp.Embedding(config.vocab_size, config.embedding_size, dtype=config.dtype)
         self.wpe = tp.Embedding(config.block_size, config.embedding_size, dtype=config.dtype)
-        self.h = [Block(config) for _ in range(config.num_layers)]
+        self.h = tp.Sequential(*[Block(config) for _ in range(config.num_layers)])
         self.ln_f = tp.LayerNorm(config.embedding_size)
-        self.pos = tp.reshape(tp.arange(0, config.seq_len, dtype=tp.int32), (1, config.seq_len))
 
     def __call__(self, idx):
         tok_emb = self.wte(idx)  # token embeddings of shape (batch_size, seq_len, embedding_size)
-        pos_emb = self.wpe(self.pos[:, : idx.shape[1]])  # position embeddings of shape (seq_len, embedding_size)
+        pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
+        pos_emb = self.wpe(pos)  # position embeddings of shape (seq_len, embedding_size)
         x = tok_emb + pos_emb  # (batch_size, seq_len, embedding_size)
-        for block in self.h:
-            x = block(x)
+        x = self.h(x)
         x = tp.cast(self.ln_f(tp.cast(x, self.ln_f.dtype)), x.dtype)
         return x
 
diff --git a/tripy/examples/nanogpt/weight_loader.py b/tripy/examples/nanogpt/weight_loader.py
index be84c628f..77e734bec 100644
--- a/tripy/examples/nanogpt/weight_loader.py
+++ b/tripy/examples/nanogpt/weight_loader.py
@@ -26,16 +26,20 @@ def load_weights_from_hf(model, model_type, dtype):
     tripy_state_dict = model.state_dict()
 
     # attention biases are initialized in the model based on block size.
-    tripy_keys = [key for key in tripy_state_dict.keys() if not key.endswith(".attn.bias")]
+    tripy_keys = {key for key in tripy_state_dict.keys() if not key.endswith(".attn.bias")}
 
     # Load huggingface/transformers model
     model_hf = GPT2LMHeadModel.from_pretrained(model_type)
     hf_state_dict = model_hf.state_dict()
     # We ignore some of the keys in the HF checkpoint:
-    hf_keys = [
+    hf_keys = {
         key for key in hf_state_dict.keys() if not key.endswith(".attn.masked_bias") and not key.endswith(".attn.bias")
-    ]
-    assert len(hf_keys) == len(tripy_keys), f"Mismatched keys: {hf_keys} != {tripy_keys}"
+    }
+    assert hf_keys == tripy_keys, (
+        f"Mismatched keys. Note:\n"
+        f"`hf_keys` extra keys: {hf_keys - tripy_keys}\n"
+        f"`tripy_keys` extra keys: {tripy_keys - hf_keys}"
+    )
 
     # See https://paperswithcode.com/method/weight-tying for details on why we do this:
     hf_state_dict["transformer.wte.weight"] = hf_state_dict["lm_head.weight"]
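
Notes (not part of the patch):

* example.py: the min/opt/max bounds for the sequence dimension are now given as a
  tuple rather than a list. A minimal sketch of the surrounding call, assuming the
  example's `tp` alias (e.g. `import tripy as tp`) and reusing `model`, `input_ids`,
  and `padded_seq_len` from example.py; per the comment in the hunk, the inner tuple
  lists the min/opt/max sequence lengths the compiled model should support:

      input_shape = (
          1,                                    # batch dimension is fixed at 1
          (1, len(input_ids), padded_seq_len),  # min/opt/max for the sequence dimension
      )
      model = tp.compile(model, args=[tp.InputInfo(input_shape, dtype=tp.int32)])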
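
* model.py: both the attention and positional-embedding changes follow the same
  pattern: helper tensors that used to be stored as members (`self.zeros`, `self.pos`)
  are now derived inline from values available at call time. A rough sketch of the two
  call-site replacements, reusing the names from the hunks above (`att`, `self.bias`,
  `T`, `idx`, `self.wpe`, `self.seq_len`):

      # Causal mask: compare the sliced bias against the scalar 0 rather than a
      # stored zeros tensor of shape (1, 1, seq_len, seq_len).
      att = tp.masked_fill(att, self.bias[:, :, :T, :T] == 0, float("-inf"))

      # Position indices: build them on the fly and slice to the input length,
      # instead of keeping a precomputed self.pos member.
      pos = tp.unsqueeze(tp.arange(self.seq_len, dtype=tp.int32)[: idx.shape[1]], 0)
      pos_emb = self.wpe(pos)

  The block list also becomes a tp.Sequential, so the per-block loop collapses to a
  single call: x = self.h(x).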
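
* weight_loader.py: comparing key sets instead of list lengths means a mismatch
  reports exactly which keys are extra on each side, not just that the counts differ.
  A small self-contained sketch of the pattern in plain Python, independent of Tripy,
  with hypothetical key sets for illustration:

      hf_keys = {"transformer.wte.weight", "lm_head.weight"}  # hypothetical key sets
      tripy_keys = {"transformer.wte.weight"}
      if hf_keys != tripy_keys:
          print("`hf_keys` extra keys:", hf_keys - tripy_keys)     # {'lm_head.weight'}
          print("`tripy_keys` extra keys:", tripy_keys - hf_keys)  # set()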