Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix docs search #311

Merged
merged 7 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions agents-api/agents_api/activities/summarization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
entries_summarization_query,
)
from agents_api.common.protocol.entries import Entry
from ..env import summarization_model_name
from ..model_registry import JULEP_MODELS
from ..env import summarization_model_name, model_inference_url, model_api_key


example_previous_memory = """
Expand Down Expand Up @@ -128,6 +129,12 @@ async def run_prompt(
parser: Callable[[str], str] = lambda x: x,
**kwargs,
) -> str:
api_base = None
api_key = None
if model in JULEP_MODELS:
api_base = model_inference_url
api_key = model_api_key
model = f"openai/{model}"
prompt = make_prompt(dialog, previous_memories, **kwargs)
response = await acompletion(
model=model,
Expand All @@ -141,6 +148,8 @@ async def run_prompt(
temperature=temperature,
stop=["<", "<|"],
stream=False,
api_base=api_base,
api_key=api_key,
)

content = response.choices[0].message.content
Expand All @@ -159,7 +168,7 @@ async def summarization(session_id: str) -> None:
assert len(entries) > 0, "no need to summarize on empty entries list"

response = await run_prompt(
dialog=entries, previous_memories=[], model=summarization_model_name
dialog=entries, previous_memories=[], model=f"openai/{summarization_model_name}"
)

new_entry = Entry(
Expand Down
16 changes: 7 additions & 9 deletions agents-api/agents_api/models/entry/proc_mem_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def proc_mem_context_query(

Parameters:
session_id (UUID),
tool_query_embedding (list[float]),
doc_query_embedding (list[float]),
tools_confidence (float),
docs_confidence (float),
Expand All @@ -29,9 +28,8 @@ def proc_mem_context_query(
Return type:
A pandas DataFrame containing the query results.
"""
VECTOR_SIZE = 768
VECTOR_SIZE = 1024
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
session_id = str(session_id)
assert len(tool_query_embedding) == len(doc_query_embedding) == VECTOR_SIZE

tools_radius: float = 1.0 - tools_confidence
docs_radius: float = 1.0 - docs_confidence
Expand All @@ -41,14 +39,14 @@ def proc_mem_context_query(
{{
# Input table for the query
# (This is temporary to this query)
input[session_id, tool_query, doc_query] <- [[
input[session_id, doc_query] <- [[
to_uuid($session_id),
$tool_query_embedding,
# $tool_query_embedding,
$doc_query_embedding,
]]

?[session_id, tool_query, doc_query, agent_id, user_id] :=
input[session_id, tool_query, doc_query],
?[session_id, doc_query, agent_id, user_id] :=
input[session_id, doc_query],
*session_lookup{{
session_id,
agent_id,
Expand All @@ -59,7 +57,7 @@ def proc_mem_context_query(
session_id: Uuid,
agent_id: Uuid,
user_id: Uuid,
tool_query: <F32; {VECTOR_SIZE}>,
# tool_query: <F32; {VECTOR_SIZE}>,
doc_query: <F32; {VECTOR_SIZE}>,
}}
}} {{
Expand Down Expand Up @@ -130,7 +128,7 @@ def proc_mem_context_query(

# Search for tools
?[role, name, content, token_count, created_at, index] :=
*_input{{agent_id, tool_query}},
#*_input{{agent_id, tool_query}},
# ~agent_functions:embedding_space {{
# agent_id,
# name: fn_name,
Expand Down
6 changes: 3 additions & 3 deletions agents-api/agents_api/models/entry/test_entry_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def _():
client=client,
),
embed_docs_snippets_query(
agent_doc_id, snippet_indices=[0], embeddings=[[1.0] * 768], client=client
agent_doc_id, snippet_indices=[0], embeddings=[[1.0] * 1024], client=client
),
embed_docs_snippets_query(
user_doc_id, snippet_indices=[0], embeddings=[[1.0] * 768], client=client
user_doc_id, snippet_indices=[0], embeddings=[[1.0] * 1024], client=client
),
]

Expand All @@ -185,7 +185,7 @@ def _():
result = proc_mem_context_query(
session_id=session_id,
tool_query_embedding=[0.9] * 768,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tool_query_embedding size should be updated to 1024 to match the new VECTOR_SIZE defined in proc_mem_context.py.

Suggested change
tool_query_embedding=[0.9] * 768,
tool_query_embedding=[0.9] * 1024,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tool_query_embedding size should be updated to 1024 to match the new VECTOR_SIZE in proc_mem_context.py.

Suggested change
tool_query_embedding=[0.9] * 768,
tool_query_embedding=[0.9] * 1024,

doc_query_embedding=[0.9] * 768,
doc_query_embedding=[0.9] * 1024,
client=client,
)

Expand Down
57 changes: 31 additions & 26 deletions agents-api/agents_api/routers/agents/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
embed_docs_snippets_query,
)
from agents_api.models.tools.create_tools import create_function_query
from agents_api.models.tools.embed_tools import embed_functions_query
from agents_api.models.tools.list_tools import list_functions_by_agent_query
from agents_api.models.tools.get_tools import get_function_by_id_query
from agents_api.models.tools.delete_tools import delete_function_by_id_query
Expand Down Expand Up @@ -247,30 +246,36 @@ async def create_agent(

if request.docs:
for info in request.docs:
content = [
(c.model_dump() if isinstance(c, ContentItem) else c)
for c in (
[info.content] if isinstance(info.content, str) else info.content
)
]
create_docs_query(
owner_type="agent",
owner_id=new_agent_id,
id=uuid4(),
title=info.title,
content=info.content,
content=content,
metadata=info.metadata or {},
)

if request.tools:
functions = [t.function for t in request.tools]
embeddings = await embed(
[
function_embed_instruction
+ f"{function.name}, {function.description}, "
+ "required_params:"
+ function.parameters.model_dump_json()
for function in functions
]
)
# embeddings = await embed(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The embedding functionality is commented out here and in several other places. If embedding is supposed to be integrated as per the PR description, these should be enabled or correctly implemented. Please review and ensure the embedding integration is complete.

# [
# function_embed_instruction
# + f"{function.name}, {function.description}, "
# + "required_params:"
# + function.parameters.model_dump_json()
# for function in functions
# ]
# )
create_tools_query(
new_agent_id,
functions,
embeddings,
[[0.0] * 768],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be 1024 or no? Just heads up

)

return res
Expand Down Expand Up @@ -435,20 +440,20 @@ async def create_tool(
created_at=resp["created_at"][0],
)

embeddings = await embed(
[
function_embed_instruction
+ request.function.description
+ "\nParameters: "
+ json.dumps(request.function.parameters.model_dump())
]
)

embed_functions_query(
agent_id=agent_id,
tool_ids=[tool_id],
embeddings=embeddings,
)
# embeddings = await embed(
# [
# function_embed_instruction
# + request.function.description
# + "\nParameters: "
# + json.dumps(request.function.parameters.model_dump())
# ]
# )

# embed_functions_query(
# agent_id=agent_id,
# tool_ids=[tool_id],
# embeddings=embeddings,
# )

return res

Expand Down
8 changes: 7 additions & 1 deletion agents-api/agents_api/routers/sessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from ...common.protocol.entries import Entry
from ...common.protocol.sessions import SessionData
from ...common.utils.template import render_template
from ...env import summarization_tokens_threshold
from ...env import (
summarization_tokens_threshold,
docs_embedding_service_url,
docs_embedding_model_id,
)
from ...model_registry import (
JULEP_MODELS,
get_extra_settings,
Expand Down Expand Up @@ -201,6 +205,8 @@ async def forward(
]
],
join_inputs=False,
embedding_service_url=docs_embedding_service_url,
embedding_model_name=docs_embedding_model_id,
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
)

entries: list[Entry] = []
Expand Down
2 changes: 1 addition & 1 deletion model-serving/model_api/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class Type(Enum):
class Tool(BaseModel):
type: Type
function: FunctionDef
id: str
id: str | None = None


class SamplingParams(SamplingParams):
Expand Down
Loading