

Update llama.py
Compatible with Ernie
ErnestinaQiu authored Jan 25, 2024
1 parent 671c6b8 commit 1face8b
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions pipelines/examples/tree-of-thought/src/llm/llama.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
@@ -53,11 +54,14 @@ class llamaChatCompletion:

def __init__(self, model="llama-2-7b-chat") -> None:
config_path = llm_config[model]
self.model_name = model
self.tokenizer = AutoTokenizer.from_pretrained(config_path)
self.generator = AutoModelForCausalLM.from_pretrained(config_path, dtype="float16")
self.tokenizer.init_chat_template(os.path.join(os.getcwd(), "pipelines", "examples", "tree-of-thought", "src", "llm", "chat_template.json"))
self.query = []
self.query_count = 0

# @staticmethod
def create(self, messages, temperature=0.6, top_p=0.9, max_gen_len=518):
def create(self, messages, temperature=0.6, top_p=0.9, max_gen_len=512):
"""
Entry point of the program for generating text using a pretrained model.
@@ -70,7 +74,7 @@ def create(self, messages, temperature=0.6, top_p=0.9, max_gen_len=518):
Defaults to 0.6.
top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
Defaults to 0.9.
max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512. Max length is 4096
max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be
set to the model's max sequence length. Defaults to None.
@@ -79,22 +83,25 @@
"choices": [],
"created": time.time(),
"id": "llama2_{}".format(int(time.time())),
"model": "llama-2-7b-chat",
"model": self.model_name,
"object": "chat.completion",
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}

for i in range(len(messages)):
one_mes = messages[i][0]
assert len(messages[i]) == 1
if one_mes["role"] != "user":
continue
mes = one_mes["content"]
input_features = self.tokenizer(mes, return_tensors="pd")
outputs = self.generator.generate(**input_features, max_new_tokens=max_gen_len)
self.query.append([mes])
self.query_count += len(mes)
while self.query_count > max_gen_len and len(self.query) > 2:
pop_size = len("".join(self.query.pop(0)))
self.query_count -= pop_size
input_features = self.tokenizer.apply_chat_template(self.query, return_tensors="pd")
outputs = self.generator.generate(**input_features, decode_strategy="greedy_search", temperature=temperature, top_p=top_p, max_new_tokens=max_gen_len)
out_0 = self.tokenizer.batch_decode(outputs[0])
print(f"dialog: \n {one_mes}")
print(out_0)
self.query[-1].append(out_0[0])
self.query_count += len(out_0[0])
if i == len(messages) - 1:
finish_reason = "stop"
else:
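For readers adapting this example, the sketch below shows one way the patched create() interface could be driven. It is inferred only from the lines visible in this diff: the import path, the prompt strings, and the assumption that the collapsed portion of the method populates response["choices"] are all hypothetical and not part of the commit.

# Hypothetical driver for the patched llamaChatCompletion (a sketch, not part of the commit).
# Assumes the PaddleNLP checkpoint behind llm_config["llama-2-7b-chat"] is available locally
# and that chat_template.json exists at the path hard-coded in __init__.
from src.llm.llama import llamaChatCompletion  # import path is an assumption

chat = llamaChatCompletion(model="llama-2-7b-chat")

# create() reads messages[i][0] and skips turns whose role is not "user",
# so each turn is wrapped in its own single-element list.
messages = [
    [{"role": "user", "content": "Propose three candidate next steps."}],
    [{"role": "user", "content": "Pick the most promising one and justify it."}],
]

response = chat.create(messages, temperature=0.6, top_p=0.9, max_gen_len=512)

print(response["model"])        # now self.model_name rather than a hard-coded string
print(response.get("choices"))  # presumably filled in by the collapsed part of the diff

Because create() now accumulates turns in self.query and pops the oldest entries once query_count exceeds max_gen_len (keeping at least two turns), successive turns reuse earlier dialog through tokenizer.apply_chat_template instead of re-tokenizing each prompt in isolation; loading the template from chat_template.json is presumably the mechanism behind the "Compatible with Ernie" commit message.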
