
Commit

Applied latest Black
simonw committed Apr 20, 2024
1 parent 9ff25ae commit 3ff82d8
Showing 1 changed file with 25 additions and 16 deletions.
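The diff is purely mechanical formatting: long call sites and Field definitions are split across multiple lines, and spaces around keyword "=" signs are removed. As a rough sketch of reproducing that kind of reformatting from Python (an illustration only, not necessarily how this commit was produced; it assumes Black is installed and that Black's default Mode matches the project's configuration):

# Sketch only: re-run Black over the plugin module programmatically.
# Assumes the "black" package is installed; Mode() uses Black's defaults,
# which may differ from the repository's own configuration.
from pathlib import Path

import black

path = Path("llm_gpt4all.py")
source = path.read_text()
formatted = black.format_str(source, mode=black.Mode())
if formatted != source:
    path.write_text(formatted)
    print("reformatted:", path)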
41 changes: 25 additions & 16 deletions llm_gpt4all.py
@@ -35,10 +35,13 @@ def retrieve_model(
     verbose: bool = False,
 ) -> str:
     try:
-        return _GPT4All.retrieve_model(model_name, model_path, allow_download, verbose)
+        return _GPT4All.retrieve_model(
+            model_name, model_path, allow_download, verbose
+        )
     except requests.exceptions.ConnectionError:
-        return _GPT4All.retrieve_model(model_name, model_path, allow_download=False, verbose=verbose)
-
+        return _GPT4All.retrieve_model(
+            model_name, model_path, allow_download=False, verbose=verbose
+        )
 
 
 def get_gpt4all_models():
@@ -73,22 +76,28 @@ class Options(llm.Options):
             description="The maximum number of tokens to generate.", default=200
         )
         temp: float = Field(
-            description="The model temperature. Larger values increase creativity but decrease factuality.", default=0.7
+            description="The model temperature. Larger values increase creativity but decrease factuality.",
+            default=0.7,
         )
         top_k: int = Field(
-            description="Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.", default=40
+            description="Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.",
+            default=40,
         )
         top_p: float = Field(
-            description="Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.", default=0.4
+            description="Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.",
+            default=0.4,
         )
         repeat_penalty: float = Field(
-            description="Penalize the model for repetition. Higher values result in less repetition.", default=1.18
+            description="Penalize the model for repetition. Higher values result in less repetition.",
+            default=1.18,
         )
         repeat_last_n: int = Field(
-            description="How far in the models generation history to apply the repeat penalty.", default=64
+            description="How far in the models generation history to apply the repeat penalty.",
+            default=64,
         )
         n_batch: int = Field(
-            description="Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.", default=8
+            description="Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.",
+            default=8,
         )
 
     def __init__(self, details):
@@ -158,13 +167,13 @@ def execute(self, prompt, stream, response, conversation):
         output = gpt_model.generate(
             text_prompt,
             streaming=True,
-            max_tokens = prompt.options.max_tokens or 400,
-            temp = prompt.options.temp,
-            top_k = prompt.options.top_k,
-            top_p = prompt.options.top_p,
-            repeat_penalty = prompt.options.repeat_penalty,
-            repeat_last_n = prompt.options.repeat_last_n,
-            n_batch = prompt.options.n_batch,
+            max_tokens=prompt.options.max_tokens or 400,
+            temp=prompt.options.temp,
+            top_k=prompt.options.top_k,
+            top_p=prompt.options.top_p,
+            repeat_penalty=prompt.options.repeat_penalty,
+            repeat_last_n=prompt.options.repeat_last_n,
+            n_batch=prompt.options.n_batch,
         )
         yield from output
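For context on how the Options fields above are consumed, here is a hedged sketch of calling the plugin through llm's Python API. The model id is a placeholder, and the option names simply mirror the Field definitions shown in the diff (temp, top_k, top_p, repeat_penalty, repeat_last_n, n_batch, max_tokens):

# Illustrative sketch, not part of this commit. Assumes llm and llm-gpt4all are
# installed; replace the placeholder model id with one listed by "llm models".
import llm

model = llm.get_model("orca-mini-3b-gguf2-q4_0")  # placeholder model id
response = model.prompt(
    "Name three uses for a brick",
    temp=0.7,  # defaults shown in the Options class above
    top_k=40,
    max_tokens=200,
)
print(response.text())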
