From e1fdd678fed99244fa66d11dafedaf6efb910e81 Mon Sep 17 00:00:00 2001
From: Ernestina <48557439+ErnestinaQiu@users.noreply.github.com>
Date: Fri, 26 Jan 2024 00:13:32 +0800
Subject: [PATCH] Update run.py

Add Ernie as a selectable LLM backend, and write per-task metrics and
model queries to log files instead of printing them to stdout.
---
 pipelines/examples/tree-of-thought/run.py | 101 +++++++++++-----------
 1 file changed, 49 insertions(+), 52 deletions(-)

diff --git a/pipelines/examples/tree-of-thought/run.py b/pipelines/examples/tree-of-thought/run.py
index 4fcf8eb9b1dc..197dfd938fe4 100644
--- a/pipelines/examples/tree-of-thought/run.py
+++ b/pipelines/examples/tree-of-thought/run.py
@@ -13,87 +13,79 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import argparse
 import json
 import os
+import time
 
+from src.llm.llama import Ernie, Ernie_llm_list, llamaChatCompletion, llm_config
 from src.tot.methods.bfs import naive_solve, solve
 from src.tot.models import gpt_usage
 from src.tot.tasks import get_task
 
 
-def run(args):
+def run(args, chatter):
     task = get_task(args.task)
     logs, cnt_avg, cnt_any = [], 0, 0
     if args.naive_run:
         file = f"./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
+        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_select}_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
     else:
         file = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
+        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
     os.makedirs(os.path.dirname(file), exist_ok=True)
+
     for i in range(args.task_start_index, args.task_end_index):
+        args.log_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.log"
+        args.query_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_query.log"
+        # mark the current task index in both log files
+        with open(args.log_fp, "a", encoding="utf8") as f:
+            f.write(f"------ index: {i}\n")
+        with open(args.query_fp, "a", encoding="utf8") as f:
+            f.write(f"------ index: {i}\n")
+
+        # reset the per-task query buffer and load the chat template
+        chatter.query = []
+        chatter.tokenizer.init_chat_template(
+            os.path.join(os.getcwd(), "pipelines", "examples", "tree-of-thought", "src", "llm", "chat_template.json")
+        )
+
         # solve
         if args.naive_run:
-            ys, info = naive_solve(args, task, i)
+            ys, info = naive_solve(args, task, i, chatter=chatter)
         else:
-            ys, info = solve(args, task, i)
+            ys, info = solve(args, task, i, chatter=chatter)
 
         # log
         infos = [task.test_output(i, y) for y in ys]
         info.update(
             {
                 "idx": i,
                 "ys": ys,
                 "infos": infos,
                 "usage_so_far": gpt_usage(args.backend),
             }
         )
         logs.append(info)
         with open(file, "w") as f:
             json.dump(logs, f, indent=4)
 
         # log main metric
         accs = [info["r"] for info in infos]
         cnt_avg += sum(accs) / len(accs)
         cnt_any += any(accs)
-        print(i, "sum(accs)", sum(accs), "cnt_avg", cnt_avg, "cnt_any", cnt_any, "\n")
+        mes = f"{i}, sum(accs): {sum(accs)}, cnt_avg: {cnt_avg}, cnt_any: {cnt_any}\n"
+        with open(metric_fp, "a", encoding="utf8") as f:
+            f.write(mes)
+
+        with open(args.query_fp, "a", encoding="utf8") as f:
+            f.write(json.dumps(chatter.query))
 
     n = args.task_end_index - args.task_start_index
-    print(cnt_avg / n, cnt_any / n)
-    print("usage_so_far", gpt_usage(args.backend))
-
-
-llm_backend_choices = [
-    "llama-2-7b",
-    "llama-2-7b-chat",
-    "llama-2-13b",
-    "llama-2-13b-chat",
-    "llama-2-70b",
-    "llama-2-70b-chat",
-    "llama-7b",
-    "llama-13b",
-    "llama-30b",
-    "llama-65b",
-    "ziqingyang/chinese-llama-7b",
-    "ziqingyang/chinese-llama-13b",
-    "ziqingyang/chinese-alpaca-7b",
-    "ziqingyang/chinese-alpaca-13b",
-    "idea-ccnl/ziya-llama-13b-v1",
-    "linly-ai/chinese-llama-2-7b",
-    "linly-ai/chinese-llama-2-13b",
-    "baichuan-inc/Baichuan-7B",
-    "baichuan-inc/Baichuan-13B-Base",
-    "baichuan-inc/Baichuan-13B-Chat",
-    "baichuan-inc/Baichuan2-7B-Base",
-    "baichuan-inc/Baichuan2-7B-Chat",
-    "baichuan-inc/Baichuan2-13B-Base",
-    "baichuan-inc/Baichuan2-13B-Chat",
-    "FlagAlpha/Llama2-Chinese-7b-Chat",
-    "FlagAlpha/Llama2-Chinese-13b-Chat",
-]
+    mes2 = f"cnt_avg / n: {cnt_avg / n}, cnt_any / n: {cnt_any / n}\n"
+    mes3 = f"usage_so_far: {gpt_usage(args.backend)}\n"
+    with open(metric_fp, "a", encoding="utf8") as f:
+        f.write(mes2)
+        f.write(mes3)
+
+
+llm_backend_choices = list(llm_config.keys())
 
 
 def parse_args():
     args = argparse.ArgumentParser()
-    args.add_argument("--backend", type=str, choices=llm_backend_choices, default="llama-2-7b-chat")
+    args.add_argument("--backend", type=str, choices=llm_backend_choices + Ernie_llm_list, default="llama-2-7b-chat")
@@ -115,11 +107,16 @@ def parse_args():
     args.add_argument("--n_evaluate_sample", type=int, default=1)
     args.add_argument("--n_select_sample", type=int, default=1)
 
+    args.add_argument("--query_fp", type=str, default=f"./logs/default/query_{int(time.time())}.log")
+
     args = args.parse_args()
     return args
 
 
 if __name__ == "__main__":
     args = parse_args()
-    print(args)
-    run(args)
+    if args.backend in llm_backend_choices:
+        chatter = llamaChatCompletion(args.backend)
+    elif args.backend in Ernie_llm_list:
+        chatter = Ernie(model=args.backend)
+    else:
+        raise ValueError(f"Unsupported backend: {args.backend}")
+    run(args, chatter=chatter)
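Reviewer note: a minimal sketch of the query-logging contract the new loop assumes. run.py now resets `chatter.query = []` before each task and dumps `json.dumps(chatter.query)` to `args.query_fp` afterwards, so both `llamaChatCompletion` and `Ernie` are expected to accumulate every model call in `self.query`. The class and method names below (`RecordingChatter`, `create`) are illustrative assumptions, not code from this PR; the real implementations live in `src/llm/llama.py`, which this patch does not show.

    import json

    class RecordingChatter:
        """Hypothetical stand-in for llamaChatCompletion / Ernie."""

        def __init__(self, backend):
            self.backend = backend
            self.query = []  # run.py resets this to [] at the top of each task

        def create(self, prompt):  # method name is an assumption, not from this PR
            # A real chatter would call the model here; this sketch just echoes.
            response = f"[{self.backend}] echo: {prompt}"
            self.query.append({"prompt": prompt, "response": response})
            return response

    chatter = RecordingChatter("llama-2-7b-chat")
    chatter.query = []                        # what run.py does per task index
    chatter.create("Propose the next step.")
    print(json.dumps(chatter.query))          # what run.py writes to args.query_fp

Under this contract, the per-task query log stays in sync with the task index markers written above it, regardless of which backend produced the calls.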