Support generation tasks for eval.py #206

Open · wants to merge 1 commit into main
README.md (2 changes: 0 additions & 2 deletions)
@@ -205,8 +205,6 @@ We use the EleutherAI evaluation harness to evaluate our model accuracy. To eval
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --compile --tasks hellaswag winogrande
```

Note: Generative tasks are currently not supported for gpt-fast

Installation Instructions for the evaluation harness: https://github.com/EleutherAI/lm-evaluation-harness/tree/master#install

### GPTQ
eval.py (71 changes: 67 additions & 4 deletions)
@@ -6,11 +6,12 @@
import sys
import time
from pathlib import Path
from typing import Optional
from typing import List, Optional, Tuple

import torch
import torch._dynamo.config
import torch._inductor.config
from torch.nn.utils.rnn import pad_sequence

torch._dynamo.config.automatic_dynamic_shapes = True
torch._inductor.config.triton.unique_kernel_names = True
@@ -21,6 +22,7 @@
from tokenizer import get_tokenizer

from model import Transformer
from generate import generate

try:
import lm_eval
@@ -91,12 +93,14 @@ def __init__(
model: Transformer,
tokenizer,
max_seq_length: Optional[int]=None,
batch_size: int = 1,
):
super().__init__()
self._model = model
self._tokenizer = tokenizer
self._device = torch.device('cuda')
self._max_seq_length = 2048 if max_seq_length is None else max_seq_length
self._batch_size = batch_size

@property
def eot_token_id(self):
@@ -112,7 +116,7 @@ def max_gen_toks(self):

@property
def batch_size(self):
return 1
return self._batch_size

@property
def device(self):
@@ -127,6 +131,24 @@ def tok_encode(self, string: str, **kwargs):
encoded = encoded.tolist()
return encoded

def tok_batch_encode(
self, text: List[str], **kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
tokenized_text = [self.tok_encode(x) for x in text]

# pad left
x = pad_sequence(
[
torch.tensor(x[::-1]) for x in tokenized_text
], # first flip each sequence and pad
batch_first=True,
padding_value=self._tokenizer.pad_id(),
).flip(
dims=[1]
) # flip back to correct order

return x, torch.ones_like(x) # return 'mask' b/c it's expected by the harness

def tok_decode(self, tokens):
decoded = self._tokenizer.decode(tokens)
return decoded
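
The left-padding in tok_batch_encode above (reverse each sequence, right-pad, then flip back) is easiest to see on a toy input. Below is a minimal standalone sketch of the same pattern, with made-up token IDs and a pad ID of 0:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Toy token sequences of unequal length; 0 stands in for the pad ID.
seqs = [[5, 6, 7], [8, 9]]

# Reverse each sequence, right-pad with pad_sequence, then flip each row so
# the padding ends up on the left and the tokens regain their original order.
x = pad_sequence(
    [torch.tensor(s[::-1]) for s in seqs],
    batch_first=True,
    padding_value=0,
).flip(dims=[1])

print(x)
# tensor([[5, 6, 7],
#         [0, 8, 9]])
```
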
@@ -147,8 +169,49 @@ def _model_call(self, inps):
logits = model_forward(self._model, x, input_pos)
return logits

def _model_generate(self, context, max_length, eos_token_id):
raise Exception('unimplemented')
def _model_generate(self, context, max_length, stop, **generation_kwargs):
curr_batch_size = context.size(0)
assert curr_batch_size == 1, f"Currently generation only supports a batch size of 1; the provided prompt has batch size {curr_batch_size}."

# Default temperature to 0.0 if it is not set.
# If do_sample is left unset and temperature == 0.0, treat the request as
# greedy decoding and set do_sample=False explicitly below.
generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
do_sample = generation_kwargs.get("do_sample", None)

# The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
generation_kwargs["do_sample"] = do_sample = False

# TODO: handle top_p and top_k

# Set up caches for the given batch size.
# Technically this is not necessary, but it's a good way to ensure that
# the caches won't error on a different batch size. In addition, caches
# are not needed for a regular model call, so we just set them up here.
# TODO: call setup_cache_padded_seq_input_pos_max_seq_length_for_prefill() instead
with context.device:
self._model.setup_caches(max_batch_size=curr_batch_size, max_seq_length=max_length)

# TODO: currently, the generate() function assumes a 1D tensor with batch size 1. Need to update it to accept a 2D tensor.
context = context.flatten(0)

toks, accept_counts = generate(
self._model,
context,
max_new_tokens=self.max_gen_toks,
interactive=False,
draft_model=None,
temperature=generation_kwargs["temperature"],
# top_k=None, # do_sample is not supported currently
# stop_tokens=self._tokenizer.stop_tokens,
)

# TODO: the output from generate() is a 1D tensor with batch size 1. Need to update it to return a 2D tensor.
toks = toks.unsqueeze(0)

return toks.to(dtype=torch.int32)


@torch.no_grad()
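The evaluation harness invokes these hooks internally; the sketch below only illustrates how the methods added above relate to one another, not the harness's actual control flow. The helper name run_one_generation, the wrapper argument, and the default max_length are placeholders and assumptions, not part of this diff:

```python
def run_one_generation(wrapper, prompt: str, max_length: int = 2048) -> str:
    """Illustrative only: drive the methods added in this diff for one prompt."""
    # 1. Batch-encode the prompt (left-padded to a common length).
    context, _mask = wrapper.tok_batch_encode([prompt])

    # 2. Generate a continuation; generation currently supports batch size 1 only.
    out = wrapper._model_generate(context, max_length=max_length, stop=None)

    # 3. Decode the generated token IDs (a [1, seq_len] tensor) back to text.
    return wrapper.tok_decode(out[0].tolist())
```
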
tokenizer.py (11 changes: 11 additions & 0 deletions)
@@ -21,6 +21,9 @@ def bos_id(self):
def eos_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

def pad_id(self):
raise NotImplementedError("This method should be overridden by subclasses.")

class SentencePieceWrapper(TokenizerInterface):
def __init__(self, model_path):
super().__init__(model_path)
@@ -38,6 +41,10 @@ def bos_id(self):
def eos_id(self):
return self.processor.eos_id()

def pad_id(self):
# TODO: handle tokenizer models that define a dedicated pad_id
return self.processor.eos_id()

class TiktokenWrapper(TokenizerInterface):
"""
Tokenizing and encoding/decoding text using the Tiktoken tokenizer.
@@ -94,6 +101,10 @@ def bos_id(self):
def eos_id(self):
return self._eos_id

def pad_id(self):
# TODO: handle tokenizer models that define a dedicated pad_id
return self._eos_id

def get_tokenizer(tokenizer_model_path, model_name):
"""
Factory function to get the appropriate tokenizer based on the model name.
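
Both wrappers currently fall back to the EOS id for padding. As a sketch of what the TODO might look like once resolved, a subclass for a SentencePiece model trained with a dedicated pad token could override pad_id() directly. This is an assumption about a possible follow-up, not part of this diff; SentencePieceProcessor.pad_id() returns -1 when no pad token is defined:

```python
from tokenizer import SentencePieceWrapper


class PaddedSentencePieceWrapper(SentencePieceWrapper):
    """Hypothetical wrapper for a SentencePiece model with a real pad token."""

    def pad_id(self):
        pad = self.processor.pad_id()  # -1 if the model defines no pad token
        # Fall back to EOS when no dedicated pad token exists.
        return pad if pad >= 0 else self.processor.eos_id()
```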