Skip to content

Commit

Permalink
Update bfs.py
Browse files Browse the repository at this point in the history
兼容Ernie
  • Loading branch information
ErnestinaQiu authored Jan 25, 2024
1 parent 1face8b commit e9b2100
Showing 1 changed file with 30 additions and 25 deletions.
55 changes: 30 additions & 25 deletions pipelines/examples/tree-of-thought/src/tot/methods/bfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,65 +12,68 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import itertools
from functools import partial

import numpy as np
from src.tot.models import gpt


def get_value(task, x, y, n_evaluate_sample, cache_value=True, chatter=None, args=None):
    """Score one candidate (partial) output ``y`` for input ``x``.

    Wraps ``y`` with the task's value prompt, queries the LLM, and unwraps
    the samples into a scalar value. Results are memoized in
    ``task.value_cache`` keyed by the full prompt text.

    Args:
        task: Task object providing ``value_prompt_wrap``,
            ``value_outputs_unwrap`` and a ``value_cache`` dict.
        x: The task input.
        y: The candidate (partial) output to evaluate.
        n_evaluate_sample: Number of LLM samples drawn for the evaluation.
        cache_value: If True, read/populate ``task.value_cache``.
        chatter: Optional chat backend handle forwarded to ``gpt``
            (presumably the Ernie client added by this change — TODO confirm).
        args: Optional run configuration forwarded to ``gpt``.

    Returns:
        The scalar value produced by ``task.value_outputs_unwrap``.
    """
    value_prompt = task.value_prompt_wrap(x, y)
    if cache_value and value_prompt in task.value_cache:
        return task.value_cache[value_prompt]
    # BUG FIX: the original passed `args=chatter` here; forward the real
    # `args`, consistent with get_values/get_votes/get_proposals/get_samples.
    value_outputs = gpt(value_prompt, n=n_evaluate_sample, stop=None, chatter=chatter, args=args)
    value = task.value_outputs_unwrap(x, y, value_outputs)
    if cache_value:
        task.value_cache[value_prompt] = value
    return value


def get_values(task, x, ys, n_evaluate_sample, cache_value=True, chatter=None, args=None):
    """Score every candidate in ``ys``, deduplicating within this call.

    Args:
        task: Task object (see ``get_value``).
        x: The task input.
        ys: Iterable of candidate (partial) outputs.
        n_evaluate_sample: Number of LLM samples per evaluation.
        cache_value: Forwarded to ``get_value`` for prompt-level caching.
        chatter: Optional chat backend handle forwarded to ``gpt``.
        args: Optional run configuration forwarded to ``gpt``.

    Returns:
        A list of scalar values, one per entry of ``ys`` (duplicates
        after the first occurrence get a value of 0).
    """
    values = []
    local_value_cache = {}
    for y in ys:  # each partial output
        if y in local_value_cache:  # avoid re-scoring duplicate candidates
            value = 0
        else:
            value = get_value(task, x, y, n_evaluate_sample, cache_value=cache_value, chatter=chatter, args=args)
            local_value_cache[y] = value
        values.append(value)
    return values


def get_votes(task, x, ys, n_evaluate_sample, chatter=None, args=None):
    """Rank candidates in ``ys`` by asking the LLM to vote among them.

    Args:
        task: Task object providing ``vote_prompt_wrap`` and
            ``vote_outputs_unwrap``.
        x: The task input.
        ys: Sequence of candidate outputs to vote over.
        n_evaluate_sample: Number of vote samples drawn from the LLM.
        chatter: Optional chat backend handle forwarded to ``gpt``.
        args: Optional run configuration forwarded to ``gpt``.

    Returns:
        A list of vote-derived scores, one per candidate in ``ys``.
    """
    vote_prompt = task.vote_prompt_wrap(x, ys)
    vote_outputs = gpt(vote_prompt, n=n_evaluate_sample, stop=None, chatter=chatter, args=args)
    values = task.vote_outputs_unwrap(vote_outputs, len(ys))
    return values


def get_proposals(task, x, y):
def get_proposals(task, x, y, chatter=None, args=None):
propose_prompt = task.propose_prompt_wrap(x, y)
proposals = gpt(propose_prompt, n=1, stop=None)[0].split("\n")
proposals = gpt(propose_prompt, n=1, stop=None, args=args, chatter=chatter)[0].split("\n")
return [y + _ + "\n" for _ in proposals]


def get_samples(task, x, y, n_generate_sample, prompt_sample, stop, chatter=None, args=None):
    """Sample ``n_generate_sample`` continuations of ``y`` from the LLM.

    Args:
        task: Task object providing ``standard_prompt_wrap`` and
            ``cot_prompt_wrap``.
        x: The task input.
        y: The current partial output to continue.
        n_generate_sample: Number of samples to draw.
        prompt_sample: Prompting strategy, either ``"standard"`` or ``"cot"``.
        stop: Stop token(s) forwarded to ``gpt``.
        chatter: Optional chat backend handle forwarded to ``gpt``.
        args: Optional run configuration forwarded to ``gpt``.

    Returns:
        A list of candidates, each being ``y`` concatenated with one sample.

    Raises:
        ValueError: If ``prompt_sample`` is not a recognized strategy.
    """
    if prompt_sample == "standard":
        prompt = task.standard_prompt_wrap(x, y)
    elif prompt_sample == "cot":
        prompt = task.cot_prompt_wrap(x, y)
    else:
        raise ValueError(f"prompt_sample {prompt_sample} not recognized")
    samples = gpt(prompt, n=n_generate_sample, stop=stop, chatter=chatter, args=args)
    return [y + sample for sample in samples]


def solve(args, task, idx, to_print=True):
def solve(args, task, idx, to_print=True, chatter=None):
global gpt
gpt = partial(gpt, model=args.backend, temperature=args.temperature)
print(gpt)
if chatter:
chatter.query = []

gpt = partial(gpt, model=args.backend, temperature=args.temperature, args=args, chatter=chatter)
logging.info(gpt)
x = task.get_input(idx) # input
ys = [""] # current output candidates
infos = []
Expand All @@ -85,18 +88,20 @@ def solve(args, task, idx, to_print=True):
args.n_generate_sample,
prompt_sample=args.prompt_sample,
stop=task.stops[step],
chatter=chatter,
args=args
)
for y in ys
]
elif args.method_generate == "propose":
new_ys = [get_proposals(task, x, y) for y in ys]
new_ys = [get_proposals(task, x, y, chatter=chatter, args=args) for y in ys]
new_ys = list(itertools.chain(*new_ys))
ids = list(range(len(new_ys)))
# evaluation
if args.method_evaluate == "vote":
values = get_votes(task, x, new_ys, args.n_evaluate_sample)
values = get_votes(task, x, new_ys, args.n_evaluate_sample, chatter=chatter)
elif args.method_evaluate == "value":
values = get_values(task, x, new_ys, args.n_evaluate_sample)
values = get_values(task, x, new_ys, args.n_evaluate_sample, chatter=chatter)

# selection
if args.method_select == "sample":
Expand All @@ -109,9 +114,6 @@ def solve(args, task, idx, to_print=True):
# log
if to_print:
sorted_new_ys, sorted_values = zip(*sorted(zip(new_ys, values), key=lambda x: x[1], reverse=True))
print(
f"-- new_ys --: {sorted_new_ys}\n-- sol values --: {sorted_values}\n-- choices --: {select_new_ys}\n"
)

infos.append(
{
Expand All @@ -125,14 +127,17 @@ def solve(args, task, idx, to_print=True):
)
ys = select_new_ys

if to_print:
print(ys)
if args.query_fp and chatter:
f = open(args.query_fp, 'w', encoding="utf8")
f.write(str(chatter.query))
f.close()

return ys, {"steps": infos}


def naive_solve(args, task, idx, to_print=True, chatter=None):
    """Solve task instance ``idx`` with plain IO sampling (no tree search).

    Rebinds the module-level ``gpt`` to a partial carrying the backend,
    temperature, run config and chat handle, then draws
    ``args.n_generate_sample`` full samples in one shot.

    Args:
        args: Run configuration (``backend``, ``temperature``,
            ``n_generate_sample``, ``prompt_sample``).
        task: Task object providing ``get_input`` and prompt wrappers.
        idx: Index of the task instance to solve.
        to_print: Unused here; kept for signature parity with ``solve``.
        chatter: Optional chat backend handle forwarded to ``gpt``.

    Returns:
        A tuple ``(ys, info)`` of the sampled outputs and an empty info dict.
    """
    global gpt
    # NOTE: mutating the module-level `gpt` makes this non-reentrant; kept
    # for parity with `solve`, which does the same.
    gpt = partial(gpt, model=args.backend, temperature=args.temperature, args=args, chatter=chatter)
    x = task.get_input(idx)  # input
    ys = get_samples(task, x, "", args.n_generate_sample, args.prompt_sample, stop=None, chatter=chatter, args=args)
    return ys, {}

0 comments on commit e9b2100

Please sign in to comment.