inferrence.py
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import json
import argparse

parser = argparse.ArgumentParser(description='Batch inference with Alaya-7B-Chat')
parser.add_argument('input_file', type=str, help='input file, one instruction per line')
parser.add_argument('output_file', type=str, help='output file (JSON Lines)')
parser.add_argument('adapter_path', type=str, nargs='?', default=None, help='path to a PEFT adapter (optional)')
args = parser.parse_args()

input_file = args.input_file
output_file = args.output_file
adapter_path = args.adapter_path
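
# Example invocation (hypothetical file names; the adapter argument may be omitted):
#   python inferrence.py instructions.txt results.jsonl ./my_adapter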

model_name = "DataCanvas/Alaya-7B-Chat"
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.attn_config['attn_impl'] = 'torch'  # use the plain torch attention implementation
# config.max_seq_len = 4096  # (input + output) tokens can now be up to 4096
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.bfloat16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Optionally load a PEFT adapter on top of the base model
if adapter_path:
    model.load_adapter(adapter_path)

eos_token_id = 2     # token id used as both end-of-sequence and padding token
bad_words_ids = 3    # single token id that generation is never allowed to emit
gpu_id = '0'

pipe = pipeline('text-generation',
                model=model,
                tokenizer=tokenizer,
                bad_words_ids=[[bad_words_ids]],
                eos_token_id=eos_token_id,
                pad_token_id=eos_token_id,
                device='cuda:' + gpu_id
                )

# Read the input file; each line is treated as one instruction
with open(input_file, 'r', encoding='utf-8') as file:
    lines = file.readlines()
instructions = [line.strip() for line in lines]

def do_inference(instruction):
    # Wrap the raw instruction in the prompt template expected by the chat model
    PROMPT_FORMAT = '### Instruction:\t\n{instruction}\n\n'
    prompt = PROMPT_FORMAT.format(instruction=instruction)
    result = pipe(prompt, max_new_tokens=1000, do_sample=True, use_cache=True,
                  eos_token_id=eos_token_id, pad_token_id=eos_token_id)
    # Keep only the text after the output marker; fall back to '' if the marker is missing
    flag = '### Output:\t\n'
    try:
        output = result[0]['generated_text'].split(flag)[1].rstrip('\n')
    except IndexError:
        output = ''
    org_output = result[0]['generated_text']
    return output, org_output
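
# Each line of the output file is one JSON object with keys 'prompt', 'response',
# and 'response_org': 'response' holds only the text after the '### Output:' marker,
# while 'response_org' keeps the full generated text.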
with open(output_file, 'w', encoding='utf-8') as file:
    for ins in instructions:
        response, response_org = do_inference(ins)
        result = {'prompt': ins, 'response': response, 'response_org': response_org}
        print(result)
        json.dump(result, file, ensure_ascii=False)
        file.write('\n')
print('All done')