Skip to content

Commit

Permalink
fix: Disable tools embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
whiterabbit1983 committed May 7, 2024
1 parent 250fc30 commit 79eeeb6
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 34 deletions.
12 changes: 5 additions & 7 deletions agents-api/agents_api/models/entry/proc_mem_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def proc_mem_context_query(
Parameters:
session_id (UUID),
tool_query_embedding (list[float]),
doc_query_embedding (list[float]),
tools_confidence (float),
docs_confidence (float),
Expand All @@ -31,7 +30,6 @@ def proc_mem_context_query(
"""
VECTOR_SIZE = 1024
session_id = str(session_id)
assert len(tool_query_embedding) == len(doc_query_embedding) == VECTOR_SIZE

tools_radius: float = 1.0 - tools_confidence
docs_radius: float = 1.0 - docs_confidence
Expand All @@ -41,14 +39,14 @@ def proc_mem_context_query(
{{
# Input table for the query
# (This is temporary to this query)
input[session_id, tool_query, doc_query] <- [[
input[session_id, doc_query] <- [[
to_uuid($session_id),
$tool_query_embedding,
# $tool_query_embedding,
$doc_query_embedding,
]]
?[session_id, tool_query, doc_query, agent_id, user_id] :=
input[session_id, tool_query, doc_query],
?[session_id, doc_query, agent_id, user_id] :=
input[session_id, doc_query],
*session_lookup{{
session_id,
agent_id,
Expand All @@ -59,7 +57,7 @@ def proc_mem_context_query(
session_id: Uuid,
agent_id: Uuid,
user_id: Uuid,
tool_query: <F32; {VECTOR_SIZE}>,
# tool_query: <F32; {VECTOR_SIZE}>,
doc_query: <F32; {VECTOR_SIZE}>,
}}
}} {{
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/models/entry/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def _():
client=client,
),
embed_docs_snippets_query(
agent_doc_id, snippet_indices=[0], embeddings=[[1.0] * 768], client=client
agent_doc_id, snippet_indices=[0], embeddings=[[1.0] * 1024], client=client
),
embed_docs_snippets_query(
user_doc_id, snippet_indices=[0], embeddings=[[1.0] * 768], client=client
user_doc_id, snippet_indices=[0], embeddings=[[1.0] * 1024], client=client
),
]

Expand All @@ -185,7 +185,7 @@ def _():
result = proc_mem_context_query(
session_id=session_id,
tool_query_embedding=[0.9] * 768,
doc_query_embedding=[0.9] * 768,
doc_query_embedding=[0.9] * 1024,
client=client,
)

Expand Down
48 changes: 24 additions & 24 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,19 +264,19 @@ async def create_agent(

if request.tools:
functions = [t.function for t in request.tools]
embeddings = await embed(
[
function_embed_instruction
+ f"{function.name}, {function.description}, "
+ "required_params:"
+ function.parameters.model_dump_json()
for function in functions
]
)
# embeddings = await embed(
# [
# function_embed_instruction
# + f"{function.name}, {function.description}, "
# + "required_params:"
# + function.parameters.model_dump_json()
# for function in functions
# ]
# )
create_tools_query(
new_agent_id,
functions,
embeddings,
[[0.0] * 768],
)

return res
Expand Down Expand Up @@ -441,20 +441,20 @@ async def create_tool(
created_at=resp["created_at"][0],
)

embeddings = await embed(
[
function_embed_instruction
+ request.function.description
+ "\nParameters: "
+ json.dumps(request.function.parameters.model_dump())
]
)

embed_functions_query(
agent_id=agent_id,
tool_ids=[tool_id],
embeddings=embeddings,
)
# embeddings = await embed(
# [
# function_embed_instruction
# + request.function.description
# + "\nParameters: "
# + json.dumps(request.function.parameters.model_dump())
# ]
# )

# embed_functions_query(
# agent_id=agent_id,
# tool_ids=[tool_id],
# embeddings=embeddings,
# )

return res

Expand Down

0 comments on commit 79eeeb6

Please sign in to comment.