diff --git a/llm_gpt4all.py b/llm_gpt4all.py
index 9b4b137..8d97ae9 100644
--- a/llm_gpt4all.py
+++ b/llm_gpt4all.py
@@ -35,10 +35,13 @@ def retrieve_model(
     verbose: bool = False,
 ) -> str:
     try:
-        return _GPT4All.retrieve_model(model_name, model_path, allow_download, verbose)
+        return _GPT4All.retrieve_model(
+            model_name, model_path, allow_download, verbose
+        )
     except requests.exceptions.ConnectionError:
-        return _GPT4All.retrieve_model(model_name, model_path, allow_download=False, verbose=verbose)
-
+        return _GPT4All.retrieve_model(
+            model_name, model_path, allow_download=False, verbose=verbose
+        )
 
 
 def get_gpt4all_models():
@@ -73,22 +76,28 @@ class Options(llm.Options):
         max_tokens: int = Field(
             description="The maximum number of tokens to generate.", default=200
         )
         temp: float = Field(
-            description="The model temperature. Larger values increase creativity but decrease factuality.", default=0.7
+            description="The model temperature. Larger values increase creativity but decrease factuality.",
+            default=0.7,
         )
         top_k: int = Field(
-            description="Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.", default=40
+            description="Randomly sample from the top_k most likely tokens at each generation step. Set this to 1 for greedy decoding.",
+            default=40,
         )
         top_p: float = Field(
-            description="Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.", default=0.4
+            description="Randomly sample at each generation step from the top most likely tokens whose probabilities add up to top_p.",
+            default=0.4,
        )
         repeat_penalty: float = Field(
-            description="Penalize the model for repetition. Higher values result in less repetition.", default=1.18
+            description="Penalize the model for repetition. Higher values result in less repetition.",
+            default=1.18,
         )
         repeat_last_n: int = Field(
-            description="How far in the models generation history to apply the repeat penalty.", default=64
+            description="How far in the models generation history to apply the repeat penalty.",
+            default=64,
         )
         n_batch: int = Field(
-            description="Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.", default=8
+            description="Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.",
+            default=8,
         )
     def __init__(self, details):
@@ -158,13 +167,13 @@ def execute(self, prompt, stream, response, conversation):
         output = gpt_model.generate(
             text_prompt,
             streaming=True,
-            max_tokens = prompt.options.max_tokens or 400,
-            temp = prompt.options.temp,
-            top_k = prompt.options.top_k,
-            top_p = prompt.options.top_p,
-            repeat_penalty = prompt.options.repeat_penalty,
-            repeat_last_n = prompt.options.repeat_last_n,
-            n_batch = prompt.options.n_batch,
+            max_tokens=prompt.options.max_tokens or 400,
+            temp=prompt.options.temp,
+            top_k=prompt.options.top_k,
+            top_p=prompt.options.top_p,
+            repeat_penalty=prompt.options.repeat_penalty,
+            repeat_last_n=prompt.options.repeat_last_n,
+            n_batch=prompt.options.n_batch,
         )
         yield from output
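
Note: the retrieve_model hunk above is formatting-only, but the behavior it preserves is worth calling out: try the networked code path first, and retry with allow_download=False when requests cannot reach the host, so cached models still work offline. A minimal standalone sketch of that pattern follows; the load_with_fallback name and the stub retrieve helper are illustrative, not part of the plugin or the gpt4all API.

import requests


def load_with_fallback(model_name: str) -> str:
    # Hypothetical stand-in for _GPT4All.retrieve_model: the real call
    # downloads the model file when allow_download=True.
    def retrieve(name: str, allow_download: bool) -> str:
        if allow_download:
            # Network access; raises ConnectionError when offline.
            requests.head("https://gpt4all.io", timeout=5)
        return f"/models/{name}.bin"

    try:
        return retrieve(model_name, allow_download=True)
    except requests.exceptions.ConnectionError:
        # Offline: fall back to whatever is already on disk.
        return retrieve(model_name, allow_download=False)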