Commit

Update run.py

add Ernie
ErnestinaQiu authored Jan 25, 2024
1 parent 26179dc commit e1fdd67
Showing 1 changed file with 49 additions and 52 deletions.
101 changes: 49 additions & 52 deletions pipelines/examples/tree-of-thought/run.py
@@ -13,87 +13,79 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

+import time
import argparse
import json
import os

from src.tot.methods.bfs import naive_solve, solve
from src.tot.models import gpt_usage
from src.tot.tasks import get_task
+from src.llm.llama import llm_config, llamaChatCompletion, Ernie_llm_list, Ernie


-def run(args):
+def run(args, chatter):
    task = get_task(args.task)
    logs, cnt_avg, cnt_any = [], 0, 0
    if args.naive_run:
-        file = f"./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
+        file = f'./logs/{args.task}/{args.backend}_{args.temperature}_naive_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
+        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_select}_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
    else:
-        file = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json"
+        file = f'./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}.json'
+        metric_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.method_generate}{args.n_generate_sample}_{args.method_evaluate}{args.n_evaluate_sample}_{args.method_select}{args.n_select_sample}_start{args.task_start_index}_end{args.task_end_index}_metric.txt"
    os.makedirs(os.path.dirname(file), exist_ok=True)


    for i in range(args.task_start_index, args.task_end_index):
+        args.log_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}.log"
+        args.query_fp = f"./logs/{args.task}/{args.backend}_{args.temperature}_{args.prompt_sample}_sample_{args.n_generate_sample}_start{args.task_start_index}_end{args.task_end_index}_query.log"
+        f = open(args.log_fp, 'a', encoding='utf8')
+        f.write(f"------ index: {i}")
+        f.close()
+
+        f = open(args.query_fp, 'a', encoding='utf8')
+        f.write(f"------ index: {i}")
+        f.close()
+
+        chatter.query = []
+        chatter.tokenizer.init_chat_template(os.path.join(os.getcwd(), "pipelines", "examples", "tree-of-thought", "src", "llm", "chat_template.json"))

        # solve
        if args.naive_run:
-            ys, info = naive_solve(args, task, i)
+            ys, info = naive_solve(args, task, i, chatter=chatter)
        else:
-            ys, info = solve(args, task, i)
+            ys, info = solve(args, task, i, chatter=chatter)

        # log
        infos = [task.test_output(i, y) for y in ys]
-        info.update(
-            {
-                "idx": i,
-                "ys": ys,
-                "infos": infos,
-                "usage_so_far": gpt_usage(args.backend),
-            }
-        )
+        info.update({'idx': i, 'ys': ys, 'infos': infos, 'usage_so_far': gpt_usage(args.backend)})
        logs.append(info)
-        with open(file, "w") as f:
+        with open(file, 'w') as f:
            json.dump(logs, f, indent=4)

        # log main metric
-        accs = [info["r"] for info in infos]
+        accs = [info['r'] for info in infos]
        cnt_avg += sum(accs) / len(accs)
        cnt_any += any(accs)
        print(i, "sum(accs)", sum(accs), "cnt_avg", cnt_avg, "cnt_any", cnt_any, "\n")
+        mes = f"{i}, 'sum(accs)', {sum(accs)}, 'cnt_avg', {cnt_avg}, 'cnt_any', {cnt_any}, '\n'"
+        f = open(metric_fp, 'a', encoding="utf8")
+        f.write(mes)
+        f.close()
+
+        f = open(args.query_fp, 'a', encoding="utf8")
+        f.write(json.dumps(chatter.query))
+        f.close()

    n = args.task_end_index - args.task_start_index
    print(cnt_avg / n, cnt_any / n)
    print("usage_so_far", gpt_usage(args.backend))
+    mes2 = f"cnt_avg / n: {cnt_avg / n}, cnt_any / n: {cnt_any / n}"
+    mes3 = f"'usage_so_far', {gpt_usage(args.backend)}"
+    f = open(metric_fp, 'a', encoding="utf8")
+    f.write(mes2)
+    f.write(mes3)
+    f.close()


-llm_backend_choices = [
-    "llama-2-7b",
-    "llama-2-7b-chat",
-    "llama-2-13b",
-    "llama-2-13b-chat",
-    "llama-2-70b",
-    "llama-2-70b-chat",
-    "llama-7b",
-    "llama-13b",
-    "llama-30b",
-    "llama-65b",
-    "ziqingyang/chinese-llama-7b",
-    "ziqingyang/chinese-llama-13b",
-    "ziqingyang/chinese-alpaca-7b",
-    "ziqingyang/chinese-alpaca-13b",
-    "idea-ccnl/ziya-llama-13b-v1",
-    "linly-ai/chinese-llama-2-7b",
-    "linly-ai/chinese-llama-2-13b",
-    "baichuan-inc/Baichuan-7B",
-    "baichuan-inc/Baichuan-13B-Base",
-    "baichuan-inc/Baichuan-13B-Chat",
-    "baichuan-inc/Baichuan2-7B-Base",
-    "baichuan-inc/Baichuan2-7B-Chat",
-    "baichuan-inc/Baichuan2-13B-Base",
-    "baichuan-inc/Baichuan2-13B-Chat",
-    "FlagAlpha/Llama2-Chinese-7b-Chat",
-    "FlagAlpha/Llama2-Chinese-13b-Chat",
-]
+llm_backend_choices = list(llm_config.keys())

def parse_args():
    args = argparse.ArgumentParser()
    args.add_argument("--backend", type=str, choices=llm_backend_choices, default="llama-2-7b-chat")
@@ -115,11 +107,16 @@ def parse_args():
    args.add_argument("--n_evaluate_sample", type=int, default=1)
    args.add_argument("--n_select_sample", type=int, default=1)

+    args.add_argument("--query_fp", type=str, default=f"./logs/default/query_{int(time.time())}.log")
+
    args = args.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    print(args)
-    run(args)
+    if args.backend in llm_backend_choices:
+        chatter = llamaChatCompletion(args.backend)
+    elif args.backend in Ernie_llm_list:
+        chatter = Ernie(model=args.backend)
+    run(args, chatter=chatter)
