Merge pull request #94 from acon96/release/v0.2.9
Release v0.2.9
acon96 authored Mar 21, 2024
2 parents 316459b + 7637c75 commit 1ab0d82
Showing 26 changed files with 679 additions and 364 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -129,6 +129,7 @@ In order to facilitate running the project entirely on the system where Home Ass
## Version History
| Version | Description |
| ------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| v0.2.9 | Fix HuggingFace Download, Fix llama.cpp wheel installation, Fix light color changing, Add in-context-learning support |
| v0.2.8 | Fix ollama model names with colons |
| v0.2.7 | Publish model v3, Multiple Ollama backend improvements, Updates for HA 2024.02, support for voice assistant aliases |
| v0.2.6 | Bug fixes, add options for limiting chat history, HTTPS endpoint support, added zephyr prompt format. |
7 changes: 5 additions & 2 deletions TODO.md
@@ -1,12 +1,15 @@
# TODO
- [ ] setup github actions to build wheels that are optimized for RPIs
- [ ] setup github actions to build wheels that are optimized for RPIs??
- [ ] setup github actions to publish docker images for text-gen-webui addon
- [ ] detection/mitigation of too many entities being exposed & blowing out the context length
- [ ] areas/room support
- [ ] figure out DPO for refusals + fixing incorrect entity id
- [ ] mixtral + prompting (no fine tuning)
- [x] mixtral + prompting (no fine tuning)
- add in context learning variables to sys prompt template
- add new options to setup process for setting prompt style + picking fine-tuned/ICL
- [ ] prime kv cache with current "state" so that requests are faster
- [ ] support fine-tuning with RoPE for longer contexts
- [ ] support config via yaml instead of configflow
- [x] ChatML format (actually need to add special tokens)
- [x] Vicuna dataset merge (yahma/alpaca-cleaned)
- [x] Phi-2 fine tuning
128 changes: 112 additions & 16 deletions custom_components/llama_conversation/__init__.py
@@ -9,6 +9,8 @@
import re
import os
import json
import csv
import random

import homeassistant.components.conversation as ha_conversation
from homeassistant.components.conversation import ConversationInput, ConversationResult, AbstractConversationAgent
@@ -33,8 +35,10 @@
CONF_BACKEND_TYPE,
CONF_DOWNLOADED_MODEL_FILE,
CONF_EXTRA_ATTRIBUTES_TO_EXPOSE,
CONF_ALLOWED_SERVICE_CALL_ARGUMENTS,
CONF_PROMPT_TEMPLATE,
CONF_USE_GBNF_GRAMMAR,
CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES,
CONF_TEXT_GEN_WEBUI_PRESET,
CONF_OPENAI_API_KEY,
CONF_TEXT_GEN_WEBUI_ADMIN_KEY,
@@ -53,8 +57,10 @@
DEFAULT_BACKEND_TYPE,
DEFAULT_REQUEST_TIMEOUT,
DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE,
DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS,
DEFAULT_PROMPT_TEMPLATE,
DEFAULT_USE_GBNF_GRAMMAR,
DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES,
DEFAULT_REFRESH_SYSTEM_PROMPT,
DEFAULT_REMEMBER_CONVERSATION,
DEFAULT_SERVICE_CALL_REGEX,
@@ -73,6 +79,7 @@
TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
DOMAIN,
GBNF_GRAMMAR_FILE,
IN_CONTEXT_EXAMPLES_FILE,
PROMPT_TEMPLATE_DESCRIPTIONS,
)

@@ -83,6 +90,10 @@
async def update_listener(hass, entry):
"""Handle options update."""
hass.data[DOMAIN][entry.entry_id] = entry

# call update handler
agent = await ha_conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
agent._update_options()
return True

async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
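The body of `async_setup_entry` is elided by the diff at this point. For orientation, here is a minimal sketch of the wiring such a setup function typically performs so that `update_listener` above actually fires when the user saves new options — this is an assumption about the project's setup code, built only from standard `ConfigEntry` APIs:

```python
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Set up the conversation agent from a config entry (illustrative sketch)."""
    hass.data.setdefault(DOMAIN, {})[entry.entry_id] = entry

    # Home Assistant invokes update_listener whenever the options flow is
    # saved; registering via async_on_unload cleans it up on unload.
    entry.async_on_unload(entry.add_update_listener(update_listener))
    return True
```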
@@ -158,6 +169,7 @@ class LLaMAAgent(AbstractConversationAgent):
hass: Any
entry_id: str
history: dict[str, list[dict]]
in_context_examples: list[dict]

def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
@@ -169,8 +181,34 @@ def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
CONF_BACKEND_TYPE, DEFAULT_BACKEND_TYPE
)

if entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
self._load_icl_examples()
else:
self.in_context_examples = None

self._load_model(entry)

def _load_icl_examples(self):
try:
icl_filename = os.path.join(os.path.dirname(__file__), IN_CONTEXT_EXAMPLES_FILE)

with open(icl_filename) as f:
self.in_context_examples = list(csv.DictReader(f))

if set(self.in_context_examples[0].keys()) != set(["service", "response" ]):
raise Exception("ICL csv file did not have 2 columns: service & response")

_LOGGER.debug(f"Loaded {len(self.in_context_examples)} examples for ICL")
except Exception:
_LOGGER.exception("Failed to load in context learning examples!")
self.in_context_examples = None
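For reference, `IN_CONTEXT_EXAMPLES_FILE` is expected to be a two-column CSV shipped alongside the component. A self-contained sketch of that layout and the same `csv.DictReader` check, with invented rows — the real file's contents are not shown in this diff:

```python
import csv
import io

# Illustrative rows only; not the contents of the component's actual CSV.
SAMPLE_ICL_CSV = """service,response
light.turn_on,Turning on the light for you.
light.turn_off,Turning off the light now.
climate.set_temperature,Setting the temperature now.
"""

examples = list(csv.DictReader(io.StringIO(SAMPLE_ICL_CSV)))

# Same sanity check as _load_icl_examples: exactly these two columns.
assert set(examples[0].keys()) == {"service", "response"}
print(examples[0])  # {'service': 'light.turn_on', 'response': 'Turning on the light for you.'}
```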

def _update_options(self):
if self.entry.options.get(CONF_USE_IN_CONTEXT_LEARNING_EXAMPLES, DEFAULT_USE_IN_CONTEXT_LEARNING_EXAMPLES):
self._load_icl_examples()
else:
self.in_context_examples = None

@property
def entry(self):
return self.hass.data[DOMAIN][self.entry_id]
@@ -201,8 +239,8 @@ async def async_process(
remember_conversation = self.entry.options.get(CONF_REMEMBER_CONVERSATION, DEFAULT_REMEMBER_CONVERSATION)
remember_num_interactions = self.entry.options.get(CONF_REMEMBER_NUM_INTERACTIONS, False)
service_call_regex = self.entry.options.get(CONF_SERVICE_CALL_REGEX, DEFAULT_SERVICE_CALL_REGEX)
extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
allowed_service_call_arguments = self.entry.options \
.get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS)

try:
service_call_pattern = re.compile(service_call_regex)
@@ -291,6 +329,9 @@ async def async_process(
service = json_output["service"]
entity = json_output["target_device"]
domain, service = tuple(service.split("."))
if "to_say" in json_output:
to_say = to_say + json_output.pop("to_say")

extra_arguments = { k: v for k, v in json_output.items() if k not in [ "service", "target_device" ] }
except Exception:
try:
@@ -307,17 +348,22 @@
if "brightness" in extra_arguments and 0.0 < extra_arguments["brightness"] < 1.0:
extra_arguments["brightness"] = int(extra_arguments["brightness"] * 255)

# convert string "tuple" to a list for RGB colors
if "rgb_color" in extra_arguments and isinstance(extra_arguments["rgb_color"], str):
extra_arguments["rgb_color"] = [ int(x) for x in extra_arguments["rgb_color"][1:-1].split(",") ]

# only acknowledge requests to exposed entities
if entity not in exposed_entities:
to_say += f" Can't find device '{entity}'!"
else:
# copy arguments to service call
service_data = {ATTR_ENTITY_ID: entity}
for attr in extra_attributes_to_expose:
for attr in allowed_service_call_arguments:
if attr in extra_arguments.keys():
service_data[attr] = extra_arguments[attr]

try:
_LOGGER.debug(f"service data: {service_data}")
await self.hass.services.async_call(
domain,
service,
@@ -326,7 +372,7 @@
)
except Exception as err:
to_say += f"\nFailed to run: {line}"
_LOGGER.debug(f"err: {err}; {repr(err)}")
_LOGGER.exception(f"Failed to run: {line}")

to_say = to_say.replace("<|im_end|>", "") # remove the eos token if it is returned (some backends + the old model does this)

@@ -367,6 +413,11 @@ def _format_prompt(
prompt_template = self.entry.options.get(CONF_PROMPT_TEMPLATE, DEFAULT_PROMPT_TEMPLATE)
template_desc = PROMPT_TEMPLATE_DESCRIPTIONS[prompt_template]

# handle models without a system prompt
if prompt[0]["role"] == "system" and "system" not in template_desc:
system_prompt = prompt.pop(0)
prompt[0]["message"] = system_prompt["message"] + prompt[0]["message"]

for message in prompt:
role = message["role"]
message = message["message"]
Expand All @@ -378,7 +429,7 @@ def _format_prompt(
if include_generation_prompt:
formatted_prompt = formatted_prompt + template_desc["generation_prompt"]

# _LOGGER.debug(formatted_prompt)
_LOGGER.debug(formatted_prompt)
return formatted_prompt
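To illustrate the system-prompt folding above: when the selected template defines no `system` role, the system message is concatenated onto the first user turn (with no separator, matching the diff) before per-role formatting. A sketch with an assumed descriptor shape — the real keys inside `PROMPT_TEMPLATE_DESCRIPTIONS` are not shown in this diff:

```python
# Assumed shape for one PROMPT_TEMPLATE_DESCRIPTIONS entry (illustrative).
template_desc = {
    "user": {"prefix": "<|im_start|>user\n", "suffix": "<|im_end|>\n"},
    "assistant": {"prefix": "<|im_start|>assistant\n", "suffix": "<|im_end|>\n"},
    "generation_prompt": "<|im_start|>assistant\n",
}

prompt = [
    {"role": "system", "message": "You control a smart home."},
    {"role": "user", "message": "Turn on the lamp."},
]

# No "system" key: fold the system prompt into the first user message.
if prompt[0]["role"] == "system" and "system" not in template_desc:
    system_prompt = prompt.pop(0)
    prompt[0]["message"] = system_prompt["message"] + prompt[0]["message"]

formatted_prompt = ""
for message in prompt:
    desc = template_desc[message["role"]]
    formatted_prompt += desc["prefix"] + message["message"] + desc["suffix"]
formatted_prompt += template_desc["generation_prompt"]
print(formatted_prompt)
```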

def _generate_system_prompt(self, prompt_template: str) -> str:
@@ -387,6 +438,32 @@

extra_attributes_to_expose = self.entry.options \
.get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
allowed_service_call_arguments = self.entry.options \
.get(CONF_ALLOWED_SERVICE_CALL_ARGUMENTS, DEFAULT_ALLOWED_SERVICE_CALL_ARGUMENTS)

def icl_example_generator(num_examples, entity_names, service_names):
entity_domains = set([x.split(".")[0] for x in entity_names])
entity_names = entity_names[:]

# filter out examples for disabled services
selected_in_context_examples = []
for x in self.in_context_examples:
if x["service"] in service_names and x["service"].split(".")[0] in entity_domains:
selected_in_context_examples.append(x)

random.shuffle(selected_in_context_examples)
random.shuffle(entity_names)

for x in range(num_examples):
chosen_example = selected_in_context_examples.pop()
chosen_service = chosen_example["service"]
device = [ x for x in entity_names if x.split(".")[0] == chosen_service.split(".")[0] ][0]
example = {
"to_say": chosen_example["response"],
"service": chosen_service,
"target_device": device,
}
yield json.dumps(example) + "\n"

def expose_attributes(attributes):
result = attributes["state"]
@@ -428,25 +505,34 @@ def expose_attributes(attributes):

service_dict = self.hass.services.async_services()
all_services = []
all_service_names = []
for domain in domains:
for name, service in service_dict.get(domain, {}).items():
args = flatten_vol_schema(service.schema)
args_to_expose = set(args).intersection(extra_attributes_to_expose)
args_to_expose = set(args).intersection(allowed_service_call_arguments)
all_services.append(f"{domain}.{name}({','.join(args_to_expose)})")
all_service_names.append(f"{domain}.{name}")
formatted_services = ", ".join(all_services)

render_variables = {
"devices": formatted_states,
"services": formatted_services,
}

if self.in_context_examples:
# TODO: make number of examples configurable
render_variables["response_examples"] = "\n".join(icl_example_generator(4, list(entities_to_expose.keys()), all_service_names))

return template.Template(prompt_template, self.hass).async_render(
{
"devices": formatted_states,
"services": formatted_services,
},
render_variables,
parse_result=False,
)
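The `render_variables` above feed the user-configurable system prompt, which is an ordinary Home Assistant template. A sketch of what such a template could look like once ICL is enabled — the text is illustrative, not the component's shipped default; each `response_examples` line is one JSON object yielded by `icl_example_generator`:

```python
# Illustrative system prompt template; {{ devices }}, {{ services }}, and
# {{ response_examples }} are the variables passed to async_render above.
EXAMPLE_PROMPT_TEMPLATE = """You are a smart home assistant. Answer requests with a JSON service call.
Services: {{ services }}
Devices:
{{ devices }}
Respond like these examples:
{{ response_examples }}"""

# One line icl_example_generator might contribute to {{ response_examples }}:
# {"to_say": "Turning on the light for you.", "service": "light.turn_on", "target_device": "light.lamp"}
```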

class LocalLLaMAAgent(LLaMAAgent):
model_path: str
llm: Any
grammar: Any
llama_cpp_module: Any

def _load_model(self, entry: ConfigEntry) -> None:
self.model_path = entry.data.get(CONF_DOWNLOADED_MODEL_FILE)
Expand All @@ -460,17 +546,16 @@ def _load_model(self, entry: ConfigEntry) -> None:

# don't import it until now because the wheel is installed by config_flow.py
try:
module = importlib.import_module("llama_cpp")
self.llama_cpp_module = importlib.import_module("llama_cpp")
except ModuleNotFoundError:
# attempt to re-install llama-cpp-python if it was uninstalled for some reason
install_result = install_llama_cpp_python(self.hass.config.config_dir)
if not install_result == True:
raise ConfigEntryError("llama-cpp-python was not installed on startup and re-installing it led to an error!")

module = importlib.import_module("llama_cpp")
self.llama_cpp_module = importlib.import_module("llama_cpp")

Llama = getattr(module, "Llama")
LlamaGrammar = getattr(module, "LlamaGrammar")
Llama = getattr(self.llama_cpp_module, "Llama")

_LOGGER.debug("Loading model...")
self.llm = Llama(
Expand All @@ -482,6 +567,11 @@ def _load_model(self, entry: ConfigEntry) -> None:
# n_threads_batch=4,
)

if entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar()

def _load_grammar(self):
LlamaGrammar = getattr(self.llama_cpp_module, "LlamaGrammar")
_LOGGER.debug("Loading grammar...")
try:
# TODO: make grammar configurable
@@ -492,6 +582,13 @@ def _load_model(self, entry: ConfigEntry) -> None:
except Exception:
_LOGGER.exception("Failed to load grammar!")
self.grammar = None

def _update_options(self):
LLaMAAgent._update_options(self)
if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR):
self._load_grammar()
else:
self.grammar = None
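For context, a GBNF grammar lets llama.cpp constrain token sampling so the model can only emit strings the grammar accepts. A minimal sketch of constructing one from an inline string — the toy grammar below is an invented example, not the component's actual grammar file:

```python
from llama_cpp import LlamaGrammar

# Toy grammar: accept only a lowercase word in double quotes (illustrative).
TOY_GBNF = r'''
root ::= "\"" [a-z]+ "\""
'''

grammar = LlamaGrammar.from_string(TOY_GBNF)
# Passed as grammar=... to the Llama call in _generate below, forcing the
# model's output to parse under the grammar.
```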

def _generate(self, conversation: dict) -> str:
prompt = self._format_prompt(conversation)
@@ -503,7 +600,6 @@ def _generate(self, conversation: dict) -> str:
temperature = self.entry.options.get(CONF_TEMPERATURE, DEFAULT_TEMPERATURE)
top_k = int(self.entry.options.get(CONF_TOP_K, DEFAULT_TOP_K))
top_p = self.entry.options.get(CONF_TOP_P, DEFAULT_TOP_P)
grammar = self.grammar if self.entry.options.get(CONF_USE_GBNF_GRAMMAR, DEFAULT_USE_GBNF_GRAMMAR) else None

_LOGGER.debug(f"Options: {self.entry.options}")

@@ -513,7 +609,7 @@ def _generate(self, conversation: dict) -> str:
temp=temperature,
top_k=top_k,
top_p=top_p,
grammar=grammar
grammar=self.grammar
)

result_tokens = []