FIM benchmark for StarCoder 2
arjunguha committed Nov 2, 2023 · 1 parent 5414751 · commit 2c8c56b
Showing 2 changed files with 86 additions and 5 deletions.
fim_inference.py (5 additions, 5 deletions)
@@ -1,22 +1,22 @@
-# I've hacked this to hardcode BigCode15B.
+# I've hacked this to hardcode StarCoder2.
 import json
 from pathlib import Path
 import argparse
 from more_itertools import chunked
 from tqdm import tqdm
-import bigcode15b
+import starcoder2


 def main():
     parser = argparse.ArgumentParser()
     # parser.add_argument("--model", type=str, required=True, help="Module name of the model to use")
     parser.add_argument("--model-path", type=Path, required=True, help="Path to the model to use")
     parser.add_argument("--batch-size", type=int, required=True, help="Batch size to use")
     parser.add_argument("--output-dir", type=Path, default=Path("."), help="Output directory for results")

     args = parser.parse_args()

-    model = bigcode15b.Model(bigcode15b.CHECKPOINT_TO_REVISION["1000m"])
-    name = "bigcode15b"
+    name = args.model_path.name
+    model = starcoder2.Model(args.model_path)

     # Load existing results if any
     result_path = args.output_dir / f"fim-results-{name}.jsonl"
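
With this change the driver script takes the checkpoint location on the command line rather than hardcoding a model. An invocation would look something like "python fim_inference.py --model-path ./starcoder2 --batch-size 16" (the path and batch size here are illustrative); results are written to fim-results-<checkpoint-name>.jsonl in the output directory.
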
starcoder2.py (new file, 81 additions)
@@ -0,0 +1,81 @@
import torch
from typing import List, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from multipl_e.completions import partial_arg_parser, make_main, stop_at_stop_token

FIM_PREFIX = "<fim_prefix>"
FIM_MIDDLE = "<fim_middle>"
FIM_SUFFIX = "<fim_suffix>"
FIM_PAD = "<fim_pad>"
EOD = "<|endoftext|>"
SPEC_TOKS = [EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD]
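
# Note: fill_in_the_middle below assembles prompts in the prefix-suffix-middle
# (PSM) layout, <fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>, and the
# model then generates the middle, terminated by <|endoftext|>.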

def extract_fim_part(s: str):
    # Find the index of <fim_middle>
    start = s.find(FIM_MIDDLE) + len(FIM_MIDDLE)
    # str.find returns -1 when <|endoftext|> is absent, so check explicitly
    # rather than relying on `or`, which would treat -1 as a valid index.
    stop = s.find(EOD, start)
    if stop == -1:
        stop = len(s)
    return s[start:stop]
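
# Illustrative example (assumed, not from the commit): decoding an output such
# as "<fim_prefix>def f(<fim_suffix>\n<fim_middle>x): return x<|endoftext|>"
# and passing it through extract_fim_part yields "x): return x".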

class Model:
    def __init__(self, name):
        self.model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, torch_dtype=torch.float16)
        self.model = self.model.cuda()
        # Left padding keeps every prompt in a batch flush against the
        # position where generation starts.
        self.tokenizer = AutoTokenizer.from_pretrained(name, padding_side="left", trust_remote_code=True)
        self.tokenizer.pad_token = "<|endoftext|>"
        self.special_tokens = SPEC_TOKS

    def completion_tensors(
        self,
        prompt: str,
        max_length: int,
        temperature: float,
        n: int,
        top_p: float,
    ):
        """
        Produces n samples.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.cuda()
        # generate's max_length counts the prompt, so budget for its tokens too.
        max_length = max_length + input_ids.flatten().size(0)
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device="cuda")
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                num_return_sequences=n,
                max_length=max_length,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.pad_token_id
            )
        return output

    def decode_single_output(self, output_tensor, prompt):
        detok_hypo_str = self.tokenizer.decode(
            output_tensor, clean_up_tokenization_spaces=False
        )
        # Skip the prompt (which may even have stop_tokens)
        return detok_hypo_str[len(prompt) :]

    def fill_in_the_middle(self, prefix_suffix_tuples: List[Tuple[str, str]], max_tokens: int, temperature: float) -> List[str]:
        prompts = [f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" for prefix, suffix in prefix_suffix_tuples]
        result = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)
        input_ids = result.input_ids.cuda()
        attention_mask = result.attention_mask.cuda()
        max_length = input_ids[0].size(0) + max_tokens
        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_sample=True,
                temperature=temperature,
                top_p=0.95,
                max_length=max_length,
                pad_token_id=self.tokenizer.pad_token_id
            )
        # WARNING: cannot use skip_special_tokens, because it clobbers the FIM special tokens
        return [
            extract_fim_part(self.tokenizer.decode(tensor)) for tensor in output
        ]
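
As a sanity check, here is a minimal sketch of driving the new module directly; it requires a GPU, since the module moves everything to CUDA, and the checkpoint path, prefix/suffix pair, and sampling parameters are illustrative assumptions rather than values from this commit:

import starcoder2

model = starcoder2.Model("/path/to/starcoder2")  # illustrative checkpoint path
pairs = [("def add(a, b):\n    return ", "\n")]  # (prefix, suffix) tuples
middles = model.fill_in_the_middle(pairs, max_tokens=32, temperature=0.2)
print(middles[0])  # ideally something like "a + b"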
