-
Notifications
You must be signed in to change notification settings - Fork 0
/
thesa.py
31 lines (24 loc) · 927 Bytes
/
thesa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from cleaning import CleanData
from finetuning import FineTune
from inference import InferFromModel
from transformers import AutoTokenizer
# --- Data -------------------------------------------------------------------
# Fetch the cleaned dataset that is handed to the fine-tuning step below.
thesa_dataset = CleanData().get_data()

# --- Base checkpoint --------------------------------------------------------
# Zephyr 7B (alpha, GPTQ) checkpoint identifier; only the tokenizer is loaded
# here, the checkpoint name itself is passed on to FineTune.
checkpoint_zephyr = "TheBloke/zephyr-7B-alpha-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_zephyr)

# --- Fine-tuning ------------------------------------------------------------
# Build the fine-tuning wrapper (its training parameters live in finetuning.py)
# and run it for 10 epochs, writing to "/thesa".
finetuned_model = FineTune(
    dataset=thesa_dataset,
    checkpoint=checkpoint_zephyr,
    tokenizer=tokenizer,
)
finetuned_model.finetune(output="/thesa", epochs=10)

# --- Inference --------------------------------------------------------------
# NOTE(review): inference loads the "johnhandleyd/thesa" checkpoint by name,
# not the "/thesa" output path used above -- presumably the published result
# of the fine-tune; confirm.
checkpoint_thesa = "johnhandleyd/thesa"
tokenizer_thesa = AutoTokenizer.from_pretrained(checkpoint_thesa)
sample = InferFromModel(model=checkpoint_thesa, tokenizer=tokenizer_thesa)

# Run a single example prompt through the model and print the reply.
example = "I've been feeling depressed lately. Can you help me?"
print(sample.infer(example=example))