add mistral+vllm example for llm-leaderboard
tianweidut committed Jan 30, 2024
1 parent 411c05d commit fb88541
Showing 5 changed files with 223 additions and 54 deletions.
2 changes: 2 additions & 0 deletions example/llm-leaderboard/leaderboard.md
@@ -15,6 +15,8 @@ Current supported LLMs:
- chatglm 6b
- chatglm2 6b
- aquila 7b/7b-chat
- mistral-7b-instruct
- mistral-8*7b-instruct

## Build Starwhale Runtime

4 changes: 4 additions & 0 deletions example/llm-leaderboard/src/benchmark/cmmlu.py
@@ -121,6 +121,10 @@ def _ingest_choice(self, content: str) -> str:
if match:
return match.group(index)

m = re.findall(r"[ABCD]", content)
if len(m) >= 1:
return m[0]

raise ValueError(f"cannot ingest ABCD choice from {content}")

def calculate_score(
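
A minimal sketch of the fallback added above: when the stricter pattern does not match, the first bare A/B/C/D character in the model output is taken as the answer. The standalone function below is a hypothetical reduction of _ingest_choice for illustration only; it is not part of the commit.

import re

def ingest_choice(content: str) -> str:
    # Hypothetical standalone copy of the new fallback logic.
    m = re.findall(r"[ABCD]", content)
    if len(m) >= 1:
        return m[0]
    raise ValueError(f"cannot ingest ABCD choice from {content}")

print(ingest_choice("The correct answer should be B, because ..."))  # -> "B"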
107 changes: 76 additions & 31 deletions example/llm-leaderboard/src/evaluation.py
@@ -3,11 +3,12 @@
import os
import typing as t
import threading
import dataclasses
from collections import defaultdict

import numpy

from starwhale import evaluation
from starwhale import argument, evaluation
from starwhale.utils.debug import console

try:
@@ -30,18 +31,39 @@
_g_llm = None
_g_benchmarks: t.Dict[str, BenchmarkBase] = {}

max_prompt_length = int(os.environ.get("MAX_PROMPT_LENGTH", 2048))
max_new_tokens = int(os.environ.get("MAX_NEW_TOKENS", 256))

@dataclasses.dataclass
class ModelGenerateArguments:
max_prompt_length: int = dataclasses.field(
default=2048, metadata={"help": "max length of prompt"}
)
max_new_tokens: int = dataclasses.field(
default=256, metadata={"help": "max length of generated text"}
)
batch: int = dataclasses.field(
default=1, metadata={"help": "batch size for inference"}
)
temperature: float = dataclasses.field(
default=0.8, metadata={"help": "temperature"}
)
top_p: float = dataclasses.field(default=0.95, metadata={"help": "top p"})
tensor_parallel: int = dataclasses.field(
default=1, metadata={"help": "tensor parallel for vllm"}
)


# TODO: support multi-gpus evaluation
# TODO: enhance selected features
@argument(ModelGenerateArguments)
@evaluation.predict(
resources={"nvidia.com/gpu": 1},
replicas=1,
batch_size=32,
auto_log=False,
)
def predict_question(data: dict, external: dict) -> None:
def predict_question(
data: t.List[dict], external: dict, arguments: ModelGenerateArguments
) -> None:
# dev split is used for few shot samples
if data.get("_hf_split", "") == "dev":
return
@@ -50,7 +72,7 @@ def predict_question(data: dict, external: dict) -> None:
global _g_llm
with threading.Lock():
if _g_llm is None:
_g_llm = get_built_llm()
_g_llm = get_built_llm(tensor_parallel=arguments.tensor_parallel)

global _g_benchmarks
dataset_uri = external["dataset_uri"]
@@ -59,34 +81,57 @@ def predict_question(data: dict, external: dict) -> None:
# TODO: use dataset_info to get benchmark
_g_benchmarks[dataset_name] = get_benchmark(dataset_name)

result = {}
benchmark = _g_benchmarks[dataset_name]()
for shot, show_name in few_shot_choices.items():
prompt = benchmark.generate_prompt(
data,
few_shot=shot,
dataset_uri=dataset_uri,
max_length=max_prompt_length,
len_tokens=_g_llm.calculate_tokens_length,
)
predict_result = _g_llm.do_predict(
prompt,
benchmark_type=benchmark.get_type(),
max_new_tokens=max_new_tokens,
predict_choice_by_logits=True,

inputs = []
for _index, _data in zip(external["index"], data):
for _shot, _show_name in few_shot_choices.items():
_prompt = benchmark.generate_prompt(
_data,
few_shot=_shot,
dataset_uri=dataset_uri,
max_length=arguments.max_prompt_length,
len_tokens=_g_llm.calculate_tokens_length,
)
inputs.append((_index, _show_name, _data, _prompt))

predict_results = []
for idx in range(0, len(inputs), arguments.batch):
batch_prompts = [x[-1] for x in inputs[idx : idx + arguments.batch]]

if _g_llm.support_batch_inference():
_results = _g_llm.do_batch_predict(
batch_prompts,
benchmark_type=benchmark.get_type(),
max_new_tokens=arguments.max_new_tokens,
predict_choice_by_logits=True,
)
predict_results.extend(_results)
else:
for _prompt in batch_prompts:
_result = _g_llm.do_predict(
_prompt,
benchmark_type=benchmark.get_type(),
max_new_tokens=arguments.max_new_tokens,
predict_choice_by_logits=True,
)
predict_results.append(_result)

for (_index, _show_name, _data, _prompt), predict_result in zip(
inputs, predict_results
):
score = benchmark.calculate_score(predict_result, _data)
console.trace(f"prompt:\n {_prompt}")
console.trace(f"answer: {_data['answer']}, predict: {score}")

evaluation.log(
category="results",
id=f"{benchmark.get_name()}-{_index}",
metrics={
"input": benchmark.make_input_features_display(_data),
"output": {_show_name: score},
},
)
result[show_name] = benchmark.calculate_score(predict_result, data)
console.trace(f"prompt:\n {prompt}")
console.trace(f"answer: {data['answer']}, predict: {result[show_name]}")

evaluation.log(
category="results",
id=f"{benchmark.get_name()}-{external['index']}",
metrics={
"input": benchmark.make_input_features_display(data),
"output": result,
},
)


@evaluation.evaluate(needs=[predict_question], use_predict_auto_log=False)
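
A rough sketch of how the two additions in this file fit together, assuming the defaults shown above: ModelGenerateArguments is a plain dataclass, so it can be constructed directly, and the prediction loop walks the prompt list in batch-sized slices. The fake_batch_predict helper below is a hypothetical stand-in for _g_llm.do_batch_predict, used only to keep the sketch runnable without a GPU.

import typing as t
from dataclasses import dataclass

@dataclass
class ModelGenerateArguments:  # mirrors the dataclass added in evaluation.py
    max_prompt_length: int = 2048
    max_new_tokens: int = 256
    batch: int = 1
    temperature: float = 0.8
    top_p: float = 0.95
    tensor_parallel: int = 1

def fake_batch_predict(prompts: t.List[str]) -> t.List[str]:
    # hypothetical stand-in for _g_llm.do_batch_predict
    return [f"answer for: {p}" for p in prompts]

args = ModelGenerateArguments(batch=2)
prompts = [f"question {i}" for i in range(5)]

results: t.List[str] = []
for idx in range(0, len(prompts), args.batch):
    results.extend(fake_batch_predict(prompts[idx : idx + args.batch]))

assert len(results) == len(prompts)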
121 changes: 98 additions & 23 deletions example/llm-leaderboard/src/llm/base.py
@@ -5,6 +5,7 @@
import typing as t
import logging
from abc import ABC, abstractmethod
from typing import Dict, List
from pathlib import Path
from dataclasses import dataclass

@@ -37,7 +38,8 @@ class LLMModelDesc:


class LLMBase(ABC):
def __init__(self, rootdir: Path | None = None) -> None:
def __init__(self, **kwargs: t.Any) -> None:
rootdir = kwargs.get("rootdir")
self.rootdir = rootdir if rootdir is not None else Path.cwd()

@classmethod
@@ -54,6 +56,9 @@ def get_description(cls) -> LLMModelDesc:
def download(self) -> None:
raise NotImplementedError

def support_batch_inference(self) -> bool:
return False

def ensure_swignore(self) -> None:
swi_path = self.rootdir / ".swignore"

@@ -80,16 +85,29 @@ def get_pretrained_dir(self) -> Path:
@abstractmethod
def do_predict(
self,
input_prompt: str,
input_prompt: str | t.List[str],
benchmark_type: BenchmarkType = BenchmarkType.MultipleChoice,
max_new_tokens: int = 50,
predict_choice_by_logits: bool = False,
) -> t.Dict | str:
**kwargs: t.Any,
) -> t.Dict | str | t.List:
raise NotImplementedError

def calculate_tokens_length(self, input_prompt: str) -> int:
return len(input_prompt)

def _simplify_content(self, content: str) -> str:
content = content.strip()

for token in ("###", "[UNK]", "</s>", "]", ")", ")", "】"):
if content.endswith(token):
content = content[: -len(token)].strip()

for prefix in (":", ":", "(", "(", "[", "【"):
if content.startswith(prefix):
content = content[len(prefix) :].strip()

return content

def ensure_readme(self) -> None:
readme_path = self.rootdir / "README.md"
desc = self.get_description()
@@ -141,9 +159,78 @@ def ensure_readme(self) -> None:
ensure_file(readme_path, content)


class vLLMBase(LLMBase):
def __init__(self, **kwargs: t.Any) -> None:
super().__init__(**kwargs)

self._model = None
self._tensor_parallel = kwargs.get("tensor_parallel", 1)

@abstractmethod
def get_hf_repo_id(self) -> str:
raise NotImplementedError

def get_base_dir(self) -> Path:
return self.get_pretrained_dir() / self.get_name() / "base"

def download(self) -> None:
from huggingface_hub import snapshot_download

local_dir = self.get_base_dir()
ensure_dir(local_dir)
snapshot_download(
repo_id=self.get_hf_repo_id(), local_dir=local_dir, max_workers=16
)

def support_batch_inference(self) -> bool:
return True

def _get_model(self) -> t.Any:
if self._model is None:
path = self.get_base_dir()
console.print(f":monkey: try to load model({path}) into memory...")

import vllm

self._model = vllm.LLM(
path, dtype=torch.float16, tensor_parallel_size=self._tensor_parallel
)

return self._model

@torch.no_grad()
def do_predict(
self,
input_prompts: str | List[str],
benchmark_type: BenchmarkType = BenchmarkType.MultipleChoice,
max_new_tokens: int = 50,
**kwargs: t.Any,
) -> Dict | str | t.List:
if isinstance(input_prompts, str):
input_prompts = [input_prompts]

import vllm

temperature = kwargs.get("temperature", 0.8)
top_p = kwargs.get("top_p", 0.95)

sp = vllm.SamplingParams(
temperature=temperature, top_p=top_p, max_tokens=max_new_tokens
)
outputs = self._get_model().generate(input_prompts, sp)
outputs.sort(key=lambda x: x.request_id)

ret = []
for output in outputs:
content = "".join([o.text for o in output.outputs])
content = self._simplify_content(content)
ret.append(content)
return ret


class HuggingfaceLLMBase(LLMBase):
def __init__(self, rootdir: Path | None = None) -> None:
super().__init__(rootdir)
def __init__(self, **kwargs: t.Any) -> None:
super().__init__(**kwargs)

self._tokenizer = None
self._model = None
@@ -247,30 +334,17 @@ def get_generate_kwargs(self) -> t.Dict[str, t.Any]:
repetition_penalty=float(os.environ.get("REPETITION_PENALTY", 1.3)),
)

def _simplify_content(self, content: str) -> str:
content = content.strip()

for token in ("###", "[UNK]", "</s>", "]", ")", ")", "】"):
if content.endswith(token):
content = content[: -len(token)].strip()

for prefix in (":", ":", "(", "(", "[", "【"):
if content.startswith(prefix):
content = content[len(prefix) :].strip()

return content

def do_predict(
self,
input_prompt: str,
benchmark_type: BenchmarkType = BenchmarkType.MultipleChoice,
max_new_tokens: int = 50,
predict_choice_by_logits: bool = False,
) -> t.Dict | str:
**kwargs: t.Any,
) -> t.Dict | str | t.List:
# TODO: add self prompt wrapper
content = self._do_predict_with_generate(input_prompt, max_new_tokens)
ret = {"content": self._simplify_content(content)}
if predict_choice_by_logits:
if kwargs.get("predict_choice_by_logits", False):
if benchmark_type != BenchmarkType.MultipleChoice:
raise ValueError(
"predict_choice_by_logits only support BenchmarkType.MultipleChoice"
@@ -363,7 +437,8 @@ def get_llm(name: str, **kwargs: t.Any) -> LLMBase:
return _SUPPORTED_LLM[name](**kwargs)


def get_built_llm(rootdir: Path | None = None, **kwargs: t.Any) -> LLMBase:
def get_built_llm(**kwargs: t.Any) -> LLMBase:
rootdir = kwargs.get("rootdir")
rootdir = rootdir if rootdir is not None else Path.cwd()
config_path = rootdir / "pretrained" / "sw_config.json"
if not config_path.exists():
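
For reference, a minimal standalone sketch of the vllm call pattern that vLLMBase.do_predict wraps, under the assumption that a GPU and the vllm package are available. The model id matches the repo id used in mistral.py; the prompts are placeholders, not values taken from the commit.

import vllm

# Load the model once; tensor_parallel_size mirrors the new tensor_parallel argument.
llm = vllm.LLM(
    "mistralai/Mistral-7B-Instruct-v0.2",
    dtype="float16",  # the diff passes torch.float16; the string form is equivalent
    tensor_parallel_size=1,
)
sampling = vllm.SamplingParams(temperature=0.8, top_p=0.95, max_tokens=50)

# generate() accepts a list of prompts, which is what enables batch inference here.
outputs = llm.generate(["Prompt one", "Prompt two"], sampling)
for output in outputs:
    print("".join(o.text for o in output.outputs))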
43 changes: 43 additions & 0 deletions example/llm-leaderboard/src/llm/mistral.py
@@ -0,0 +1,43 @@
from __future__ import annotations

from .base import register, vLLMBase, LLMModelDesc


@register()
class Mistral7BInstruct(vLLMBase):
def get_hf_repo_id(self) -> str:
return "mistralai/Mistral-7B-Instruct-v0.2"

@classmethod
def get_name(cls) -> str:
return "mistral-7b-instruct"

@classmethod
def get_description(cls) -> LLMModelDesc:
return LLMModelDesc(
params="7b",
intro=(
"We introduce Mistral 7B v0.1, a 7-billion-parameter language model engineered for superior performance and efficiency. "
"Mistral 7B outperforms Llama 2 13B across all evaluated benchmarks, and Llama 1 34B in reasoning, mathematics, and code generation."
"Our model leverages grouped-query attention (GQA) for faster inference, coupled with sliding window attention (SWA) to effectively handle sequences of arbitrary length with a reduced inference cost. "
"We also provide a model fine-tuned to follow instructions, Mistral 7B -- Instruct, that surpasses the Llama 2 13B -- Chat model both on human and automated benchmarks. Our models are released under the Apache 2.0 license."
),
license="apache-2.0",
author="Mistral",
github="https://github.com/mistralai/mistral-src",
type="fine-tuned",
)

def download(self) -> None:
super().download()

local_dir = self.get_base_dir()
# We only need safetensors files.
useless_fnames = (
"pytorch_model-00001-of-00003.bin",
"pytorch_model-00002-of-00003.bin",
"pytorch_model-00003-of-00003.bin",
"pytorch_model.bin.index.json",
)
for fname in useless_fnames:
(local_dir / fname).unlink()

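As a usage sketch, assuming the registry helpers shown in llm/base.py and that importing the mistral module is what triggers its @register() decorator (the import paths below are assumptions about the example's package layout):

from llm import mistral  # noqa: F401  -- importing the module registers Mistral7BInstruct
from llm.base import get_llm

llm = get_llm("mistral-7b-instruct", tensor_parallel=1)
llm.download()  # snapshot_download of the HF repo into the example's pretrained directory

if llm.support_batch_inference():
    outputs = llm.do_predict(
        ["Question: 1 + 1 = ? Answer:", "Question: capital of France? Answer:"],
        max_new_tokens=16,
        temperature=0.8,
        top_p=0.95,
    )
    print(outputs)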