Merge pull request #98 from acon96/release/v0.2.10
Release v0.2.10
acon96 authored Mar 24, 2024
2 parents 5f6e3dd + 1c2cbc6 commit 0244e79
Showing 9 changed files with 438 additions and 264 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -129,6 +129,7 @@ In order to facilitate running the project entirely on the system where Home Ass
## Version History
| Version | Description |
| ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| v0.2.10 | Allow configuring the model parameters during initial setup, attempt to auto-detect defaults for recommended models, Fix to allow lights to be set to max brightness |
| v0.2.9 | Fix HuggingFace Download, Fix llama.cpp wheel installation, Fix light color changing, Add in-context-learning support |
| v0.2.8 | Fix ollama model names with colons |
| v0.2.7 | Publish model v3, Multiple Ollama backend improvements, Updates for HA 2024.02, support for voice assistant aliases |
19 changes: 7 additions & 12 deletions custom_components/llama_conversation/__init__.py
@@ -99,14 +99,6 @@ async def update_listener(hass, entry):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Local LLaMA Conversation from a config entry."""

# TODO: figure out how to make this happen as part of the config flow. when I tried it errored out passing options in
if len(entry.options) == 0:
entry.options = { **DEFAULT_OPTIONS }
copy_to_options = [ CONF_REMOTE_USE_CHAT_ENDPOINT, CONF_TEXT_GEN_WEBUI_CHAT_MODE, CONF_TEXT_GEN_WEBUI_PRESET ]
for item in copy_to_options:
value = entry.data.get(item)
if value:
entry.options[item] = value

def create_agent(backend_type):
agent_cls = None
@@ -181,10 +173,9 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
)

self.in_context_examples = None
if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
self._load_icl_examples()
else:
self.in_context_examples = None

self._load_model(entry)

@@ -235,6 +226,8 @@ async def async_process(
"""Process a sentence."""

raw_prompt = self.entry.options.get(CONF_PROMPT, DEFAULT_PROMPT)
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]
refresh_system_prompt = self.entry.options.get(CONF_REFRESH_SYSTEM_PROMPT, DEFAULT_REFRESH_SYSTEM_PROMPT)
remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, False)
@@ -345,7 +338,7 @@ async def async_process(

# fix certain arguments
# make sure brightness is 0-255 and not a percentage
if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] < 1.0:
if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] <= 1.0:
extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)

# convert string "tuple" to a list for RGB colors
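The widened comparison above is the "max brightness" fix from the changelog: a reply of 1.0 (meaning 100%) previously fell outside the 0.0 < x < 1.0 range, was passed through as an absolute value, and left the light at brightness 1 of 255. A minimal sketch of the fixed conversion; the helper name and the checks below are illustrative, not part of the integration:

def scale_brightness(extra_arguments: dict) -> dict:
    # Fractional values in (0, 1] are treated as percentages and scaled to HA's 0-255 range.
    if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] <= 1.0:
        extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)
    return extra_arguments

assert scale_brightness({"brightness": 0.5})["brightness"] == 127
assert scale_brightness({"brightness": 1.0})["brightness"] == 255  # max brightness now converts
assert scale_brightness({"brightness": 128})["brightness"] == 128  # already on the 0-255 scale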
@@ -374,7 +367,8 @@ async def async_process(
to_say += f"\nFailed to run: {line}"
_LOGGER.exception(f"Failed to run: {line}")

to_say = to_say.replace("<|im_end|>", "") # remove the eos token if it is returned (some backends + the old model does this)
if template_desc["assistant"]["suffix"]:
to_say = to_say.replace(template_desc["assistant"]["suffix"], "") # remove the eos token if it is returned (some backends + the old model does this)

intent_response = intent.IntentResponse(language=user_input.language)
intent_response.async_set_speech(to_say)
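The hard-coded "<|im_end|>" cleanup is replaced by looking up the configured prompt template's assistant suffix, so templates other than ChatML also get their end-of-turn marker stripped. A sketch of the idea; the dictionary below only illustrates the shape of a PROMPT_TEMPLATE_DESCRIPTIONS entry and is not copied from const.py:

CHATML_DESCRIPTION = {
    "system": {"prefix": "<|im_start|>system\n", "suffix": "<|im_end|>"},
    "user": {"prefix": "<|im_start|>user\n", "suffix": "<|im_end|>"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>"},
}

def strip_assistant_suffix(to_say: str, template_desc: dict) -> str:
    # Remove the end-of-turn token if the backend echoes it back in the response.
    suffix = template_desc["assistant"]["suffix"]
    return to_say.replace(suffix, "") if suffix else to_say

print(strip_assistant_suffix("Turning on the kitchen lights.<|im_end|>", CHATML_DESCRIPTION))
# -> Turning on the kitchen lights.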
@@ -567,6 +561,7 @@ def _load_model(self, entry: ConfigEntry) -> None:
# n_threads_batch=4,
)

self.grammar = None
if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar()

70 changes: 51 additions & 19 deletions custom_components/llama_conversation/config_flow.py
@@ -61,7 +61,6 @@
CONF_TEXT_GEN_WEBUI_CHAT_MODE,
CONF_OLLAMA_KEEP_ALIVE_MIN,
DEFAULT_CHAT_MODEL,
DEFAULT_HOST,
DEFAULT_PORT,
DEFAULT_SSL,
DEFAULT_MAX_TOKENS,
@@ -80,7 +79,6 @@
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_SERVICE_CALL_REGEX,
DEFAULT_OPTIONS,
DEFAULT_REMOTE_USE_CHAT_ENDPOINT,
DEFAULT_TEXT_GEN_WEBUI_CHAT_MODE,
DEFAULT_OLLAMA_KEEP_ALIVE_MIN,
@@ -95,6 +93,8 @@
TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
DEFAULT_OPTIONS,
OPTIONS_OVERRIDES,
)

_LOGGER = logging.getLogger(__name__)
@@ -138,31 +138,25 @@ def STEP_LOCAL_SETUP_DOWNLOAD_DATA_SCHEMA(*, chat_model=None, downloaded_model_q
}
)

def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None, use_chat_endpoint=None, webui_preset="", webui_chat_mode=""):
def STEP_REMOTE_SETUP_DATA_SCHEMA(backend_type: str, *, host=None, port=None, ssl=None, chat_model=None):

extra1, extra2 = ({}, {})
default_port = DEFAULT_PORT

if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_PRESET, default=webui_preset)] = str
extra1[vol.Optional(CONF_TEXT_GEN_WEBUI_CHAT_MODE, default=webui_chat_mode)] = SelectSelector(SelectSelectorConfig(
options=["", TEXT_GEN_WEBUI_CHAT_MODE_CHAT, TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT, TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT],
translation_key=CONF_TEXT_GEN_WEBUI_CHAT_MODE,
multiple=False,
mode=SelectSelectorMode.DROPDOWN,
))
extra2[vol.Optional(CONF_TEXT_GEN_WEBUI_ADMIN_KEY)] = TextSelector(TextSelectorConfig(type="password"))

elif backend_type == BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER:
default_port = "8000"
elif backend_type == BACKEND_TYPE_OLLAMA:
default_port = "11434"

return vol.Schema(
{
vol.Required(CONF_HOST, default=host if host else DEFAULT_HOST): str,
vol.Required(CONF_HOST, default=host if host else ""): str,
vol.Required(CONF_PORT, default=port if port else default_port): str,
vol.Required(CONF_SSL, default=ssl if ssl else DEFAULT_SSL): bool,
vol.Required(CONF_CHAT_MODEL, default=chat_model if chat_model else DEFAULT_CHAT_MODEL): str,
vol.Required(CONF_REMOTE_USE_CHAT_ENDPOINT, default=use_chat_endpoint if use_chat_endpoint else DEFAULT_REMOTE_USE_CHAT_ENDPOINT): bool,
**extra1,
vol.Optional(CONF_OPENAI_API_KEY): TextSelector(TextSelectorConfig(type="password")),
**extra2
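The connection step now asks only for host, port, SSL, and model name; the chat-endpoint toggle and the text-generation-webui preset and chat-mode selectors move to the new model-parameters step further down, and the host defaults to an empty string instead of DEFAULT_HOST. A condensed sketch of the resulting schema, with stand-in strings for the backend constants and default port:

import voluptuous as vol

def remote_setup_schema(backend_type: str, *, host=None, port=None) -> vol.Schema:
    default_port = "5000"  # stand-in for DEFAULT_PORT
    if backend_type == "llama_cpp_python_server":
        default_port = "8000"
    elif backend_type == "ollama":
        default_port = "11434"

    return vol.Schema({
        vol.Required("host", default=host if host else ""): str,
        vol.Required("port", default=port if port else default_port): str,
        vol.Required("ssl", default=False): bool,
        vol.Required("chat_model", default="llama3"): str,
    })

# Selecting the Ollama backend pre-fills port 11434 and an empty host field.
schema = remote_setup_schema("ollama")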
@@ -202,6 +196,12 @@ async def async_step_remote_model(
) -> FlowResult:
""" Configure a remote model """

@abstractmethod
async def async_step_model_parameters(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
""" Configure a remote model """

@abstractmethod
async def async_step_download(
self, user_input: dict[str, Any] | None = None
@@ -222,7 +222,8 @@ class ConfigFlow(BaseLlamaConversationConfigFlow, config_entries.ConfigFlow, dom
install_wheel_error = None
download_task = None
download_error = None
model_config: dict[str, Any] = {}
model_config: dict[str, Any]
options: dict[str, Any]

@property
def flow_manager(self) -> config_entries.ConfigEntriesFlowManager:
@@ -237,6 +238,8 @@ async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the initial step."""
self.model_config = {}
self.options = {}
return await self.async_step_pick_backend()

async def async_step_pick_backend(
@@ -384,7 +387,7 @@ async def async_step_download(
next_step = "local_model"
else:
self.model_config[CONF_DOWNLOADED_MODEL_FILE] = self.download_task.result()
next_step = "finish"
next_step = "model_parameters"

self.download_task = None
return self.async_show_progress_done(next_step_id=next_step)
@@ -404,6 +407,7 @@ def _validate_text_generation_webui(self, user_input: dict) -> str:

models_result = requests.get(
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
timeout=5, # quick timeout
headers=headers
)
models_result.raise_for_status()
@@ -435,6 +439,7 @@ def _validate_ollama(self, user_input: dict) -> str:

models_result = requests.get(
f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/api/tags",
timeout=5, # quick timeout
headers=headers
)
models_result.raise_for_status()
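Both the text-generation-webui and Ollama validation calls gain timeout=5 so the config flow fails fast instead of hanging when the server is unreachable. A small sketch of the effect; the endpoint path and error key are illustrative:

import requests

def probe_server(base_url: str, headers=None):
    """Return None if the server answered, or an error key for the form otherwise."""
    try:
        result = requests.get(f"{base_url}/api/tags", timeout=5, headers=headers)
        result.raise_for_status()
        return None
    except requests.exceptions.Timeout:
        # Without the timeout this call could block indefinitely on an unreachable host.
        return "failed_to_connect"
    except requests.exceptions.RequestException:
        return "failed_to_connect"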
@@ -466,7 +471,7 @@ async def async_step_remote_model(
if user_input:
try:
self.model_config.update(user_input)
error_reason = None
error_message = None

# validate and load when using text-generation-webui or ollama
if backend_type == BACKEND_TYPE_TEXT_GEN_WEBUI:
@@ -488,12 +493,9 @@
port=user_input[CONF_PORT],
ssl=user_input[CONF_SSL],
chat_model=user_input[CONF_CHAT_MODEL],
use_chat_endpoint=user_input[CONF_REMOTE_USE_CHAT_ENDPOINT],
webui_preset=user_input.get(CONF_TEXT_GEN_WEBUI_PRESET),
webui_chat_mode=user_input.get(CONF_TEXT_GEN_WEBUI_CHAT_MODE),
)
else:
return await self.async_step_finish()
return await self.async_step_model_parameters()

except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
@@ -502,6 +504,35 @@
return self.async_show_form(
step_id="remote_model", data_schema=schema, errors=errors, description_placeholders=description_placeholders,
)

async def async_step_model_parameters(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
errors = {}
description_placeholders = {}
backend_type = self.model_config[CONF_BACKEND_TYPE]
model_name = self.model_config[CONF_CHAT_MODEL].lower()

selected_default_options = { **DEFAULT_OPTIONS }
for key in OPTIONS_OVERRIDES.keys():
if key in model_name:
selected_default_options.update(OPTIONS_OVERRIDES[key])

schema = vol.Schema(local_llama_config_option_schema(selected_default_options, backend_type))

if user_input:
self.options = user_input
try:
# validate input
schema(user_input)
return await self.async_step_finish()
except Exception as ex:
_LOGGER.exception("An unknown error has occurred!")
errors["base"] = "unknown"

return self.async_show_form(
step_id="model_parameters", data_schema=schema, errors=errors, description_placeholders=description_placeholders,
)
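This new step is the auto-detection mentioned in the changelog: it starts from DEFAULT_OPTIONS and applies any OPTIONS_OVERRIDES entry whose key appears as a substring of the configured model name, then shows the merged values as the form defaults. A sketch of the lookup, with made-up override keys and option values:

DEFAULT_OPTIONS = {"prompt_template": "chatml", "max_new_tokens": 128}
OPTIONS_OVERRIDES = {
    "home-llm": {"prompt_template": "chatml"},
    "mistral": {"prompt_template": "mistral"},
    "zephyr": {"prompt_template": "zephyr"},
}

def detect_default_options(chat_model: str) -> dict:
    selected = {**DEFAULT_OPTIONS}
    for key, overrides in OPTIONS_OVERRIDES.items():
        if key in chat_model.lower():
            selected.update(overrides)
    return selected

print(detect_default_options("TheBloke/Mistral-7B-Instruct-v0.2-GGUF"))
# -> {'prompt_template': 'mistral', 'max_new_tokens': 128}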

async def async_step_finish(
self, user_input: dict[str, Any] | None = None
@@ -517,6 +548,7 @@ async def async_step_finish(
title=f"LLM Model '{model_name}' ({location})",
description="A Large Language Model Chat Agent",
data=self.model_config,
options=self.options,
)

@staticmethod
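Passing options= to async_create_entry means the parameters collected in the new step land directly in entry.options at creation time, while connection details stay in entry.data; this is what lets the copy-into-options workaround at the top of this diff be deleted from async_setup_entry. A schematic of the resulting split, with placeholder keys and values:

entry_data = {       # supplied as data=self.model_config
    "backend_type": "ollama",
    "host": "192.168.1.10",
    "port": "11434",
    "chat_model": "llama3",
}
entry_options = {    # supplied as options=self.options
    "prompt_template": "chatml",
    "max_new_tokens": 128,
    "temperature": 0.1,
}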