Skip to content

Commit

Permalink
Merge pull request #40 from jekalmin/v0.0.8
Browse files Browse the repository at this point in the history
bump to 0.0.8
  • Loading branch information
jekalmin authored Dec 3, 2023
2 parents 13867d1 + 19b05a2 commit 78f7ac2
Show file tree
Hide file tree
Showing 6 changed files with 570 additions and 13 deletions.
100 changes: 100 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,106 @@ When using [ytube_music_player](https://github.com/KoljaWindeler/ytube_music_pla

<img width="300" alt="스크린샷 2023-11-02 오후 8 40 36" src="https://github.com/jekalmin/extended_openai_conversation/assets/2917984/648efef8-40d1-45d2-b3f9-9bac4a36c517">

### 7. sqlite
#### 7-1. Let model generate a query
- Without examples, a query tries to fetch data only from "states" table like below
> Question: When did bedroom light turn on? <br/>
Query(generated by gpt-3.5): SELECT * FROM states WHERE entity_id = 'input_boolean.livingroom_light_2' AND state = 'on' ORDER BY last_changed DESC LIMIT 1
- Since "entity_id" is stored in "states_meta" table, we need to give examples of question and query.
- Not secured, but flexible way

```yaml
- spec:
name: query_histories_from_db
description: >-
Use this function to query histories from Home Assistant SQLite database.
Example:
Question: When did bedroom light turn on?
Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated_ts FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'light.bedroom' AND s.state = 'on' AND s.state != old.state ORDER BY s.last_updated_ts DESC LIMIT 1
Question: Was livingroom light on at 9 am?
Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated, s.state FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'switch.livingroom' AND s.state != old.state AND datetime(s.last_updated_ts, 'unixepoch', 'localtime') < '2023-11-17 08:00:00' ORDER BY s.last_updated_ts DESC LIMIT 1
parameters:
type: object
properties:
query:
type: string
description: A fully formed SQL query.
function:
type: sqlite
```

Get last changed date time of state | Get state at specific time
--|--
<img width="300" alt="스크린샷 2023-11-19 오후 5 32 56" src="https://github.com/jekalmin/extended_openai_conversation/assets/2917984/5a25db59-f66c-4dfd-9e7b-ae6982ed3cd2"> |<img width="300" alt="스크린샷 2023-11-19 오후 5 32 30" src="https://github.com/jekalmin/extended_openai_conversation/assets/2917984/51faaa26-3294-4f96-b115-c71b268b708e">


**FAQ**
1. Can gpt modify or delete data?
> No, since connection is created in a read only mode, data are only used for fetching.
2. Can gpt query data that are not exposed in database?
> Yes, it is hard to validate whether a query is only using exposed entities.
3. Query uses UTC time. Is there any way to adjust timezone?
> Yes. Set "TZ" environment variable to your [region](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) (eg. `Asia/Seoul`). <br/>
Or use plus/minus hours to adjust instead of 'localtime' (eg. `datetime(s.last_updated_ts, 'unixepoch', '+9 hours')`).


#### 7-2. Let model generate a query (with minimum validation)
- If need to check at least "entity_id" of exposed entities is present in a query, use "is_exposed_entity_in_query" in combination with "raise".
- Not secured enough, but flexible way
```yaml
- spec:
name: query_histories_from_db
description: >-
Use this function to query histories from Home Assistant SQLite database.
Example:
Question: When did bedroom light turn on?
Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated_ts FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'light.bedroom' AND s.state = 'on' AND s.state != old.state ORDER BY s.last_updated_ts DESC LIMIT 1
Question: Was livingroom light on at 9 am?
Answer: SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') last_updated, s.state FROM states s INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id INNER JOIN states old ON s.old_state_id = old.state_id WHERE sm.entity_id = 'switch.livingroom' AND s.state != old.state AND datetime(s.last_updated_ts, 'unixepoch', 'localtime') < '2023-11-17 08:00:00' ORDER BY s.last_updated_ts DESC LIMIT 1
parameters:
type: object
properties:
query:
type: string
description: A fully formed SQL query.
function:
type: sqlite
query: >-
{%- if is_exposed_entity_in_query(query) -%}
{{ query }}
{%- else -%}
{{ raise("entity_id should be exposed.") }}
{%- endif -%}
```

#### 7-3. Defined SQL manually
- Use a user defined query, which is verified. And model passes a requested entity to get data from database.
- Secured, but less flexible way
```yaml
- spec:
name: get_last_updated_time_of_entity
description: >
Use this function to get last updated time of entity
parameters:
type: object
properties:
entity_id:
type: string
description: The target entity
function:
type: sqlite
query: >-
{%- if is_exposed(entity_id) -%}
SELECT datetime(s.last_updated_ts, 'unixepoch', 'localtime') as last_updated_ts
FROM states s
INNER JOIN states_meta sm ON s.metadata_id = sm.metadata_id
INNER JOIN states old ON s.old_state_id = old.state_id
WHERE sm.entity_id = '{{entity_id}}' AND s.state != old.state ORDER BY s.last_updated_ts DESC LIMIT 1
{%- else -%}
{{ raise("entity_id should be exposed.") }}
{%- endif -%}
```

## Practical Usage
See more practical [examples](https://github.com/jekalmin/extended_openai_conversation/tree/main/examples).

Expand Down
101 changes: 89 additions & 12 deletions custom_components/extended_openai_conversation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import os
import yaml
import time
import sqlite3
from bs4 import BeautifulSoup
from typing import Any
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from openai.error import AuthenticationError
from urllib import parse

from homeassistant.components import automation, rest, scrape
from homeassistant.components.automation.config import _async_validate_config_item
Expand All @@ -22,7 +24,7 @@
CONF_ATTRIBUTE,
)
from homeassistant.config import AUTOMATION_CONFIG_PATH
from homeassistant.components import conversation
from homeassistant.components import conversation, recorder
from homeassistant.core import HomeAssistant
from homeassistant.helpers.template import Template
from homeassistant.helpers.script import (
Expand Down Expand Up @@ -128,7 +130,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
"""execute function"""


Expand All @@ -143,7 +145,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
name = function["name"]
if name == "execute_service":
return await self.execute_service(
Expand All @@ -163,7 +165,7 @@ async def execute_service(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
result = []
for service_argument in arguments.get("list", []):
domain = service_argument["domain"]
Expand Down Expand Up @@ -198,7 +200,7 @@ async def execute_service(
_LOGGER.error(e)
result.append(False)

return str(result)
return result

async def add_automation(
self,
Expand All @@ -207,7 +209,7 @@ async def add_automation(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
automation_config = yaml.safe_load(arguments["automation_config"])
config = {"id": str(round(time.time() * 1000))}
if isinstance(automation_config, list):
Expand Down Expand Up @@ -252,7 +254,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
script = Script(
hass,
function["sequence"],
Expand All @@ -279,7 +281,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
return Template(function["value_template"], hass).async_render(
arguments,
parse_result=False,
Expand All @@ -297,7 +299,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
config = function
rest_data = _get_rest_data(hass, config, arguments)

Expand All @@ -324,7 +326,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
config = function
rest_data = _get_rest_data(hass, config, arguments)
coordinator = scrape.coordinator.ScrapeCoordinator(
Expand Down Expand Up @@ -408,7 +410,7 @@ async def execute(
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
) -> str:
):
config = function
sequence = config["sequence"]

Expand All @@ -420,11 +422,85 @@ async def execute(

response_variable = executor_config.get("response_variable")
if response_variable:
arguments[response_variable] = str(result)
arguments[response_variable] = result

return result


class SqliteFunctionExecutor(FunctionExecutor):
def __init__(self) -> None:
"""initialize sqlite function"""

def is_exposed(self, entity_id, exposed_entities) -> bool:
return any(
exposed_entity["entity_id"] == entity_id
for exposed_entity in exposed_entities
)

def is_exposed_entity_in_query(self, query: str, exposed_entities) -> bool:
exposed_entity_ids = list(
map(lambda e: f"'{e['entity_id']}'", exposed_entities)
)
return any(
exposed_entity_id in query for exposed_entity_id in exposed_entity_ids
)

def raise_error(self, msg="Unexpected error occurred."):
raise HomeAssistantError(msg)

def get_default_db_url(self, hass: HomeAssistant) -> str:
db_file_path = os.path.join(hass.config.config_dir, recorder.DEFAULT_DB_FILE)
return f"file:{db_file_path}?mode=ro"

def set_url_read_only(self, url: str) -> str:
scheme, netloc, path, query_string, fragment = parse.urlsplit(url)
query_params = parse.parse_qs(query_string)

query_params["mode"] = ["ro"]
new_query_string = parse.urlencode(query_params, doseq=True)

return parse.urlunsplit((scheme, netloc, path, new_query_string, fragment))

async def execute(
self,
hass: HomeAssistant,
function,
arguments,
user_input: conversation.ConversationInput,
exposed_entities,
):
db_url = self.set_url_read_only(
function.get("db_url", self.get_default_db_url(hass))
)
query = function.get("query", "{{query}}")

template_arguments = {
"is_exposed": lambda e: self.is_exposed(e, exposed_entities),
"is_exposed_entity_in_query": lambda q: self.is_exposed_entity_in_query(
q, exposed_entities
),
"exposed_entities": exposed_entities,
"raise": self.raise_error,
}
template_arguments.update(arguments)

q = Template(query, hass).async_render(template_arguments)
_LOGGER.info("Rendered query: %s", q)
with sqlite3.connect(db_url, uri=True) as conn:
cursor = conn.execute(q)
names = [description[0] for description in cursor.description]

if function.get("single") is True:
row = cursor.fetchone()
return {name: val for name, val in zip(names, row)}

rows = cursor.fetchall()
result = []
for row in rows:
result.append({name: val for name, val in zip(names, row)})
return result


FUNCTION_EXECUTORS: dict[str, FunctionExecutor] = {
"predefined": NativeFunctionExecutor(),
"native": NativeFunctionExecutor(),
Expand All @@ -433,4 +509,5 @@ async def execute(
"rest": RestFunctionExecutor(),
"scrape": ScrapeFunctionExecutor(),
"composite": CompositeFunctionExecutor(),
"sqlite": SqliteFunctionExecutor(),
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
"requirements": [
"openai==0.27.2"
],
"version": "0.0.7"
"version": "0.0.8"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"config": {
"error": {
"cannot_connect": "Fehler beim Verbinden",
"invalid_auth": "Authentifizierung fehlgeschlagen",
"unknown": "Unbekannter Fehler"
},
"step": {
"user": {
"data": {
"name": "Name",
"api_key": "API Key",
"base_url": "Base Url"
}
}
}
},
"options": {
"step": {
"init": {
"data": {
"max_tokens": "Maximale Anzahl an Tokens, die in einer Antwort zurückgegeben werden",
"model": "Completion Model",
"prompt": "Prompt Vorlage",
"temperature": "Temperatur",
"top_p": "Top P",
"max_function_calls_per_conversation": "Maximale Anzahl an Funktionsaufrufen pro Konversation",
"functions": "Funktionen"
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"config": {
"error": {
"cannot_connect": "Verbinden mislukt",
"invalid_auth": "Ongeldige authenticatie",
"unknown": "Onverwachte fout"
},
"step": {
"user": {
"data": {
"name": "Naam",
"api_key": "API-sleutel",
"base_url": "Basis-URL"
}
}
}
},
"options": {
"step": {
"init": {
"data": {
"max_tokens": "Maximale aantal tokens dat mag worden gegenereerd",
"model": "Completion Model",
"prompt": "Prompt Sjabloon",
"temperature": "Temperatuur",
"top_p": "Top P",
"max_function_calls_per_conversation": "Maximale keren functies mogen worden aangeroepen per conversatie",
"functions": "Functies"
}
}
}
}
}
Loading

0 comments on commit 78f7ac2

Please sign in to comment.