From e9b210068437ecd51f6c2e8f2af0d1fff3696823 Mon Sep 17 00:00:00 2001
From: Ernestina <48557439+ErnestinaQiu@users.noreply.github.com>
Date: Fri, 26 Jan 2024 00:23:24 +0800
Subject: [PATCH] Update bfs.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Compatible with Ernie: thread a `chatter` backend object and the run
`args` through every gpt() call, log instead of printing, and optionally
dump the recorded queries to `args.query_fp`.

---
 .../tree-of-thought/src/tot/methods/bfs.py | 57 ++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 32 insertions(+), 25 deletions(-)

diff --git a/pipelines/examples/tree-of-thought/src/tot/methods/bfs.py b/pipelines/examples/tree-of-thought/src/tot/methods/bfs.py
index 5cef7dff9073..0b7ae4644353 100644
--- a/pipelines/examples/tree-of-thought/src/tot/methods/bfs.py
+++ b/pipelines/examples/tree-of-thought/src/tot/methods/bfs.py
@@ -12,65 +12,68 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import itertools
+import logging
 from functools import partial
-
 import numpy as np
 from src.tot.models import gpt
 
 
-def get_value(task, x, y, n_evaluate_sample, cache_value=True):
+def get_value(task, x, y, n_evaluate_sample, cache_value=True, chatter=None, args=None):
     value_prompt = task.value_prompt_wrap(x, y)
     if cache_value and value_prompt in task.value_cache:
         return task.value_cache[value_prompt]
-    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None)
+    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None, chatter=chatter, args=args)
     value = task.value_outputs_unwrap(x, y, value_outputs)
     if cache_value:
         task.value_cache[value_prompt] = value
     return value
 
 
-def get_values(task, x, ys, n_evaluate_sample, cache_value=True):
+def get_values(task, x, ys, n_evaluate_sample, cache_value=True, chatter=None, args=None):
     values = []
     local_value_cache = {}
     for y in ys:  # each partial output
         if y in local_value_cache:  # avoid duplicate candidates
             value = 0
         else:
-            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value)
+            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value, chatter=chatter, args=args)
             local_value_cache[y] = value
         values.append(value)
     return values
 
 
-def get_votes(task, x, ys, n_evaluate_sample):
+def get_votes(task, x, ys, n_evaluate_sample, chatter=None, args=None):
     vote_prompt = task.vote_prompt_wrap(x, ys)
-    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None)
+    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None, chatter=chatter, args=args)
     values = task.vote_outputs_unwrap(vote_outputs, len(ys))
     return values
 
 
-def get_proposals(task, x, y):
+def get_proposals(task, x, y, chatter=None, args=None):
     propose_prompt = task.propose_prompt_wrap(x, y)
-    proposals = gpt(propose_prompt, n=1, stop=None)[0].split("\n")
+    proposals = gpt(propose_prompt, n=1, stop=None, chatter=chatter, args=args)[0].split("\n")
     return [y + _ + "\n" for _ in proposals]
 
 
-def get_samples(task, x, y, n_generate_sample, prompt_sample, stop):
+def get_samples(task, x, y, n_generate_sample, prompt_sample, stop, chatter=None, args=None):
     if prompt_sample == "standard":
         prompt = task.standard_prompt_wrap(x, y)
     elif prompt_sample == "cot":
         prompt = task.cot_prompt_wrap(x, y)
     else:
         raise ValueError(f"prompt_sample {prompt_sample} not recognized")
-    samples = gpt(prompt, n=n_generate_sample, stop=stop)
+    samples = gpt(prompt, n=n_generate_sample, stop=stop, chatter=chatter, args=args)
     return [y + _ for _ in samples]
 
 
-def solve(args, task, idx, to_print=True):
+def solve(args, task, idx, to_print=True, chatter=None):
     global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
-    print(gpt)
+    if chatter:
+        chatter.query = []
+
+    gpt = partial(gpt, model=args.backend, temperature=args.temperature, args=args, chatter=chatter)
+    logging.info(gpt)
     x = task.get_input(idx)  # input
     ys = [""]  # current output candidates
     infos = []
@@ -85,18 +88,20 @@ def solve(args, task, idx, to_print=True):
                     args.n_generate_sample,
                     prompt_sample=args.prompt_sample,
                     stop=task.stops[step],
+                    chatter=chatter,
+                    args=args,
                 )
                 for y in ys
             ]
         elif args.method_generate == "propose":
-            new_ys = [get_proposals(task, x, y) for y in ys]
+            new_ys = [get_proposals(task, x, y, chatter=chatter, args=args) for y in ys]
         new_ys = list(itertools.chain(*new_ys))
         ids = list(range(len(new_ys)))
        # evaluation
         if args.method_evaluate == "vote":
-            values = get_votes(task, x, new_ys, args.n_evaluate_sample)
+            values = get_votes(task, x, new_ys, args.n_evaluate_sample, chatter=chatter, args=args)
         elif args.method_evaluate == "value":
-            values = get_values(task, x, new_ys, args.n_evaluate_sample)
+            values = get_values(task, x, new_ys, args.n_evaluate_sample, chatter=chatter, args=args)
 
         # selection
         if args.method_select == "sample":
@@ -109,9 +114,9 @@ def solve(args, task, idx, to_print=True):
         # log
         if to_print:
             sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
-            print(
-                f"-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n"
-            )
+            logging.info(
+                f"-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n"
+            )
 
         infos.append(
             {
@@ -125,14 +130,16 @@ def solve(args, task, idx, to_print=True):
         )
         ys = select_new_ys
 
-    if to_print:
-        print(ys)
+    if args.query_fp and chatter:
+        with open(args.query_fp, "w", encoding="utf8") as f:
+            f.write(str(chatter.query))
+
     return ys, {"steps": infos}
 
 
-def naive_solve(args, task, idx, to_print=True):
+def naive_solve(args, task, idx, to_print=True, chatter=None):
     global gpt
-    gpt = partial(gpt, model=args.backend, temperature=args.temperature)
+    gpt = partial(gpt, model=args.backend, temperature=args.temperature, args=args, chatter=chatter)
     x = task.get_input(idx)  # input
-    ys = get_samples(task, x, "", args.n_generate_sample, args.prompt_sample, stop=None)
+    ys = get_samples(task, x, "", args.n_generate_sample, args.prompt_sample, stop=None, chatter=chatter, args=args)
     return ys, {}
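A note for reviewers on why the patch passes `args=args` explicitly at every call site even though `solve` already binds `args` into the `gpt` partial: `functools.partial` binds keyword arguments only as overridable defaults, so a helper that forwarded its own `args=None` default would silently replace the bound value with `None`. A minimal standalone sketch of that Python behavior (the `fake_gpt` name is illustrative, not from the module):

```python
from functools import partial


def fake_gpt(prompt, args=None, chatter=None):
    # Stand-in for src.tot.models.gpt; it just echoes the kwarg it received.
    return args


bound = partial(fake_gpt, args="run-config")
print(bound("prompt"))             # -> "run-config": the partial binding holds
print(bound("prompt", args=None))  # -> None: an explicit keyword overrides it
```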
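And a rough driver sketch of how the new `chatter` plumbing is meant to be exercised end to end. Everything marked as assumed below (the `Game24Task` import path, the `ernie-3.5` backend name, the `Chatter` stand-in) is for illustration only and is not defined by this patch; the only contract `solve` uses directly is that `chatter` exposes a `query` attribute, which it resets up front and dumps to `args.query_fp` at the end, with the patched `gpt()` expected to append each exchange in between.

```python
# Hypothetical driver for the patched solve(); see the caveats above.
from types import SimpleNamespace

from src.tot.methods.bfs import solve
from src.tot.tasks.game24 import Game24Task  # assumed task import path


class Chatter:
    """Minimal stand-in for the Ernie chat wrapper: solve() only needs
    a mutable `query` list that gpt() can append prompt/response pairs to."""

    def __init__(self):
        self.query = []


args = SimpleNamespace(
    backend="ernie-3.5",       # assumed Ernie backend name
    temperature=0.7,
    method_generate="propose",
    method_evaluate="value",
    method_select="greedy",
    n_generate_sample=1,
    n_evaluate_sample=3,
    n_select_sample=5,
    prompt_sample=None,
    query_fp="query_log.txt",  # recorded queries are written here after the run
)

task = Game24Task()
ys, info = solve(args, task, idx=900, to_print=False, chatter=Chatter())
print(ys)
```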