Commit

example(LLM): use @starwhale.argument feature to refactor llm finetune examples (#3125)
tianweidut authored Jan 16, 2024
1 parent c474ca6 commit 09edf32
Showing 5 changed files with 177 additions and 74 deletions.
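
The refactor in this commit replaces scattered os.environ lookups with dataclasses whose fields become run parameters through the new @starwhale.argument decorator. As a rough, library-agnostic sketch of that dataclass-to-CLI pattern (standard library only; the parser below illustrates the idea and is not Starwhale's actual implementation):

import argparse
import dataclasses


@dataclasses.dataclass
class ModelGenerateArguments:
    max_new_tokens: int = dataclasses.field(
        default=512, metadata={"help": "max length of generated text"}
    )
    temperature: float = dataclasses.field(default=0.7, metadata={"help": "temperature"})


def build_parser(dc: type) -> argparse.ArgumentParser:
    # Derive one CLI option per dataclass field, reusing the "help" metadata
    # that the dataclasses in this commit already carry.
    parser = argparse.ArgumentParser()
    for f in dataclasses.fields(dc):
        parser.add_argument(
            f"--{f.name}",
            type=f.type,
            default=f.default,
            help=f.metadata.get("help", ""),
        )
    return parser


cli = build_parser(ModelGenerateArguments).parse_args(["--temperature", "0.9"])
print(ModelGenerateArguments(**vars(cli)))  # max_new_tokens=512, temperature=0.9
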
example/llm-finetune/models/baichuan2/evaluation.py (32 changes: 21 additions & 11 deletions)
@@ -2,14 +2,15 @@

import os
import typing as t
+import dataclasses

import torch
import gradio
from peft import PeftModel
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig

-from starwhale import handler, evaluation
+from starwhale import handler, argument, evaluation

try:
from .utils import BASE_MODEL_DIR, ADAPTER_MODEL_DIR
@@ -20,6 +21,22 @@
_g_tokenizer = None


+@dataclasses.dataclass
+class ModelGenerateArguments:
+    max_new_tokens: int = dataclasses.field(
+        default=512, metadata={"help": "max length of generated text"}
+    )
+    do_sample: bool = dataclasses.field(default=True, metadata={"help": "do sample"})
+    temperature: float = dataclasses.field(
+        default=0.7, metadata={"help": "temperature"}
+    )
+    top_p: float = dataclasses.field(default=0.9, metadata={"help": "top p"})
+    top_k: int = dataclasses.field(default=30, metadata={"help": "top k"})
+    repetition_penalty: float = dataclasses.field(
+        default=1.3, metadata={"help": "repetition penalty"}
+    )


def _load_model_and_tokenizer() -> t.Tuple:
global _g_model, _g_tokenizer

@@ -57,26 +74,19 @@ def _load_model_and_tokenizer() -> t.Tuple:
return _g_model, _g_tokenizer


+@argument(ModelGenerateArguments)
@evaluation.predict(
resources={"nvidia.com/gpu": 1},
replicas=1,
log_mode="plain",
)
-def copilot_predict(data: dict) -> str:
+def copilot_predict(data: dict, arguments: ModelGenerateArguments) -> str:
model, tokenizer = _load_model_and_tokenizer()
# support z-bench-common dataset: https://cloud.starwhale.cn/projects/401/datasets/161/versions/223/files
messages = [{"role": "user", "content": data["prompt"]}]

config_dict = model.generation_config.to_dict()
-    # TODO: use arguments
-    config_dict.update(
-        max_new_tokens=int(os.environ.get("MAX_MODEL_LENGTH", 512)),
-        do_sample=True,
-        temperature=float(os.environ.get("TEMPERATURE", 0.7)),
-        top_p=float(os.environ.get("TOP_P", 0.9)),
-        top_k=int(os.environ.get("TOP_K", 30)),
-        repetition_penalty=float(os.environ.get("REPETITION_PENALTY", 1.3)),
-    )
+    config_dict.update(dataclasses.asdict(arguments))
return model.chat(
tokenizer,
messages=messages,
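
The core of the change above: six environment-variable reads collapse into a single dataclasses.asdict merge into the model's generation config. A standalone sketch of that merge, with a plain dict standing in for model.generation_config.to_dict():

import dataclasses


@dataclasses.dataclass
class ModelGenerateArguments:
    max_new_tokens: int = 512
    do_sample: bool = True
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 30
    repetition_penalty: float = 1.3


# Stand-in for model.generation_config.to_dict() on a real checkpoint.
config_dict = {"max_new_tokens": 2048, "temperature": 1.0, "eos_token_id": 2}

# Every field of the dataclass overrides the matching config key;
# unrelated keys such as eos_token_id are left untouched.
config_dict.update(dataclasses.asdict(ModelGenerateArguments()))
print(config_dict["max_new_tokens"], config_dict["temperature"])  # 512 0.7
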
example/llm-finetune/models/baichuan2/finetune.py (108 changes: 75 additions & 33 deletions)
@@ -1,8 +1,7 @@
from __future__ import annotations

-import os
import typing as t
-from dataclasses import dataclass
+import dataclasses

import torch
from peft import (
@@ -20,9 +19,9 @@
AutoModelForCausalLM,
)
from torch.utils.data import ChainDataset
-from transformers.training_args import TrainingArguments
+from transformers.training_args import TrainingArguments as HFTrainingArguments

-from starwhale import Dataset, finetune
+from starwhale import Dataset, argument, finetune

try:
from .utils import BASE_MODEL_DIR, ADAPTER_MODEL_DIR
@@ -36,7 +35,7 @@
torch.backends.cuda.matmul.allow_tf32 = True


-@dataclass
+@dataclasses.dataclass
class DataCollatorForCausalLM:
tokenizer: PreTrainedTokenizer
source_max_len: int
@@ -81,12 +80,78 @@ def __call__(self, example: t.Dict) -> t.Dict:
}


+@dataclasses.dataclass
+class ModelArguments:
+    model_max_length: int = dataclasses.field(
+        default=512, metadata={"help": "max length of generated text"}
+    )
+
+
+@dataclasses.dataclass
+class TrainingArguments(HFTrainingArguments):
+    # change default value
+    # copy from https://github.com/baichuan-inc/Baichuan2/blob/main/README.md#%E5%8D%95%E6%9C%BA%E8%AE%AD%E7%BB%83
+    output_dir: str = dataclasses.field(
+        default=str(ADAPTER_MODEL_DIR), metadata={"help": "output dir"}
+    )
+    optim: str = dataclasses.field(
+        default="adamw_torch", metadata={"help": "optimizer"}
+    )
+    report_to: str = dataclasses.field(
+        default="none", metadata={"help": "report metrics to some service"}
+    )
+    num_train_epochs: int = dataclasses.field(default=2, metadata={"help": "epochs"})
+    max_steps: int = dataclasses.field(default=18, metadata={"help": "max steps"})
+    per_device_train_batch_size: int = dataclasses.field(
+        default=2, metadata={"help": "per device train batch size"}
+    )
+    gradient_accumulation_steps: int = dataclasses.field(
+        default=16, metadata={"help": "gradient accumulation steps"}
+    )
+    save_strategy: str = dataclasses.field(
+        default="no", metadata={"help": "save strategy"}
+    )
+    learning_rate: float = dataclasses.field(
+        default=2e-5, metadata={"help": "learning rate"}
+    )
+    lr_scheduler_type: str = dataclasses.field(
+        default="constant", metadata={"help": "lr scheduler type"}
+    )
+    adam_beta1: float = dataclasses.field(default=0.9, metadata={"help": "adam beta1"})
+    adam_beta2: float = dataclasses.field(default=0.98, metadata={"help": "adam beta2"})
+    adam_epsilon: float = dataclasses.field(
+        default=1e-8, metadata={"help": "adam epsilon"}
+    )
+    max_grad_norm: float = dataclasses.field(
+        default=1.0, metadata={"help": "max grad norm"}
+    )
+    weight_decay: float = dataclasses.field(
+        default=1e-4, metadata={"help": "weight decay"}
+    )
+    warmup_ratio: float = dataclasses.field(
+        default=0.0, metadata={"help": "warmup ratio"}
+    )
+    logging_steps: int = dataclasses.field(
+        default=10, metadata={"help": "logging steps"}
+    )
+    gradient_checkpointing: bool = dataclasses.field(
+        default=False, metadata={"help": "gradient checkpointing"}
+    )
+    remove_unused_columns: bool = dataclasses.field(
+        default=False, metadata={"help": "remove unused columns"}
+    )


+@argument((ModelArguments, TrainingArguments))
@finetune(
resources={"nvidia.com/gpu": 1},
require_train_datasets=True,
model_modules=[copilot_predict, "finetune:lora_finetune"],
)
-def lora_finetune(train_datasets: t.List[Dataset]) -> None:
+def lora_finetune(
+    train_datasets: t.List[Dataset],
+    arguments: t.Tuple[ModelArguments, HFTrainingArguments],
+) -> None:
model = AutoModelForCausalLM.from_pretrained(
BASE_MODEL_DIR,
trust_remote_code=True,
@@ -133,38 +198,15 @@ def lora_finetune(train_datasets: t.List[Dataset]) -> None:
if "norm" in name:
module = module.to(torch.float32)

+    model_arguments, train_arguments = arguments

tokenizer = AutoTokenizer.from_pretrained(
BASE_MODEL_DIR,
use_fast=False,
trust_remote_code=True,
-        model_max_length=int(os.environ.get("MODEL_MAX_LENGTH", 512)),
+        model_max_length=model_arguments.model_max_length,
)

-    # TODO: support finetune arguments
-    # copy from https://github.com/baichuan-inc/Baichuan2/blob/main/README.md#%E5%8D%95%E6%9C%BA%E8%AE%AD%E7%BB%83
-    train_args = TrainingArguments(
-        output_dir=str(ADAPTER_MODEL_DIR),
-        optim="adamw_torch",
-        report_to="none",
-        num_train_epochs=int(os.environ.get("NUM_TRAIN_EPOCHS", 2)),
-        max_steps=int(os.environ.get("MAX_STEPS", 18)),
-        per_device_train_batch_size=2, # more batch size will cause OOM
-        gradient_accumulation_steps=16,
-        save_strategy="no", # no need to save checkpoint for finetune
-        learning_rate=2e-5,
-        lr_scheduler_type="constant",
-        adam_beta1=0.9,
-        adam_beta2=0.98,
-        adam_epsilon=1e-8,
-        max_grad_norm=1.0,
-        weight_decay=1e-4,
-        warmup_ratio=0.0,
-        logging_steps=10,
-        gradient_checkpointing=False,
-        remove_unused_columns=False,
-    )

# TODO: support deepspeed
train_dataset = ChainDataset(
[
ds.to_pytorch(
Expand All @@ -179,7 +221,7 @@ def lora_finetune(train_datasets: t.List[Dataset]) -> None:
trainer = Trainer(
model=model,
tokenizer=tokenizer,
-        args=train_args,
+        args=train_arguments,
train_dataset=train_dataset,
)

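
Subclassing the Hugging Face TrainingArguments just to change defaults keeps every hyperparameter overridable at run time instead of hard-coding the Baichuan2 recipe. The same pair of dataclasses can also be exercised outside Starwhale with stock transformers tooling; a sketch using HfArgumentParser (transformers' own dataclass-driven CLI parser, which the @argument decorator appears to parallel; the field list is trimmed here for brevity):

import dataclasses

from transformers import HfArgumentParser
from transformers import TrainingArguments as HFTrainingArguments


@dataclasses.dataclass
class ModelArguments:
    model_max_length: int = dataclasses.field(
        default=512, metadata={"help": "max length of generated text"}
    )


@dataclasses.dataclass
class TrainingArguments(HFTrainingArguments):
    # Redeclare inherited fields to bake in new defaults.
    output_dir: str = dataclasses.field(default="adapter", metadata={"help": "output dir"})
    max_steps: int = dataclasses.field(default=18, metadata={"help": "max steps"})


# Parse a tuple of dataclasses, mirroring @argument((ModelArguments, TrainingArguments)).
parser = HfArgumentParser((ModelArguments, TrainingArguments))
model_args, train_args = parser.parse_args_into_dataclasses(
    ["--model_max_length", "1024", "--learning_rate", "1e-4"]
)
print(model_args.model_max_length, train_args.max_steps)  # 1024 18
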
example/llm-finetune/models/chatglm3/evaluation.py (32 changes: 22 additions & 10 deletions)
@@ -1,12 +1,12 @@
from __future__ import annotations

-import os
import typing as t
+import dataclasses

import torch
from transformers import AutoModel, AutoConfig, AutoTokenizer

-from starwhale import evaluation
+from starwhale import argument, evaluation
from starwhale.api.service import api, LLMChat

try:
@@ -18,15 +18,26 @@
_g_tokenizer = None


-def _load_model_and_tokenizer() -> t.Tuple:
+@dataclasses.dataclass
+class ModelGenerateArguments:
+    max_length: int = dataclasses.field(
+        default=512, metadata={"help": "max length of generated text"}
+    )
+    top_p: float = dataclasses.field(default=0.9, metadata={"help": "top p"})
+    temperature: float = dataclasses.field(
+        default=1.2, metadata={"help": "temperature"}
+    )
+    pre_seq_len: int = dataclasses.field(default=128, metadata={"help": "pre seq len"})
+
+
+def _load_model_and_tokenizer(pre_seq_len: int = 128) -> t.Tuple:
global _g_model, _g_tokenizer

if _g_model is None:
-        # TODO: after starwhale supports parameters, we can remove os environ.
config = AutoConfig.from_pretrained(
BASE_MODEL_DIR,
trust_remote_code=True,
-            pre_seq_len=int(os.environ.get("PT_PRE_SEQ_LEN", "128")),
+            pre_seq_len=pre_seq_len,
)
_g_model = (
AutoModel.from_pretrained(
@@ -59,21 +70,22 @@ def _load_model_and_tokenizer() -> t.Tuple:
return _g_model, _g_tokenizer


+@argument(ModelGenerateArguments)
@evaluation.predict(
resources={"nvidia.com/gpu": 1},
replicas=1,
log_mode="plain",
)
-def copilot_predict(data: dict) -> str:
-    model, tokenizer = _load_model_and_tokenizer()
+def copilot_predict(data: dict, argument: ModelGenerateArguments) -> str:
+    model, tokenizer = _load_model_and_tokenizer(argument.pre_seq_len)
print(data["prompt"])
response, _ = model.chat(
tokenizer,
data["prompt"],
history=[],
-        max_length=int(os.environ.get("MAX_LENGTH", "512")),
-        top_p=float(os.environ.get("TOP_P", "0.9")),
-        temperature=float(os.environ.get("TEMPERATURE", "1.2")),
+        max_length=argument.max_length,
+        top_p=argument.top_p,
+        temperature=argument.temperature,
)
return response

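
One detail in this diff worth calling out: pre_seq_len reaches the model as an extra keyword to AutoConfig.from_pretrained, which overrides the matching attribute on the loaded ChatGLM config before the model is built. A tiny sketch of that override mechanism using a bare PretrainedConfig (whose constructor stores extra keywords as config attributes):

from transformers import PretrainedConfig

# Extra keywords become attributes on the config object; this is how the
# pre_seq_len argument above flows into the ChatGLM3 model's P-Tuning prefix.
config = PretrainedConfig(pre_seq_len=128)
print(config.pre_seq_len)  # 128
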
(Diffs for the remaining 2 of the 5 changed files did not load.)