Skip to content

Commit

Permalink
Pagination for get by group_id (#218)
Browse files Browse the repository at this point in the history
* add pagination to subgraphs

* update pagination

* update LiteralString import

* cleanup

* cleanup

* update embedding dims
  • Loading branch information
prasmussen15 authored Dec 2, 2024
1 parent 397291d commit 0fbe5c0
Show file tree
Hide file tree
Showing 12 changed files with 123 additions and 35 deletions.
3 changes: 1 addition & 2 deletions examples/podcast/transcript_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import re
from datetime import datetime, timedelta, timezone
from typing import List

from pydantic import BaseModel

Expand Down Expand Up @@ -36,7 +35,7 @@ def parse_timestamp(timestamp: str) -> timedelta:
return timedelta() # Return 0 duration if parsing fails


def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]:
def parse_conversation_file(file_path: str, speakers: list[Speaker]) -> list[ParsedMessage]:
with open(file_path) as file:
content = file.read()

Expand Down
3 changes: 1 addition & 2 deletions graphiti_core/cross_encoder/bge_reranker_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

import asyncio
from typing import List, Tuple

from sentence_transformers import CrossEncoder

Expand All @@ -26,7 +25,7 @@ class BGERerankerClient(CrossEncoderClient):
def __init__(self):
self.model = CrossEncoder('BAAI/bge-reranker-v2-m3')

async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
if not passages:
return []

Expand Down
7 changes: 3 additions & 4 deletions graphiti_core/cross_encoder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""

from abc import ABC, abstractmethod
from typing import List, Tuple


class CrossEncoderClient(ABC):
Expand All @@ -26,16 +25,16 @@ class CrossEncoderClient(ABC):
"""

@abstractmethod
async def rank(self, query: str, passages: List[str]) -> List[Tuple[str, float]]:
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
"""
Rank the given passages based on their relevance to the query.
Args:
query (str): The query string.
passages (List[str]): A list of passages to rank.
passages (list[str]): A list of passages to rank.
Returns:
List[Tuple[str, float]]: A list of tuples containing the passage and its score,
List[tuple[str, float]]: A list of tuples containing the passage and its score,
sorted in descending order of relevance.
"""
pass
56 changes: 51 additions & 5 deletions graphiti_core/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@

from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from typing_extensions import LiteralString

from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE, parse_db_date
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT, parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_SAVE,
ENTITY_EDGE_SAVE,
Expand All @@ -50,7 +51,7 @@ async def save(self, driver: AsyncDriver): ...
async def delete(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MATCH (n)-[e {uuid: $uuid}]->(m)
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
Expand Down Expand Up @@ -137,19 +138,34 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
return edges

@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''

records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down Expand Up @@ -274,11 +290,22 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
return edges

@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''

records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid AS uuid,
n.uuid AS source_node_uuid,
Expand All @@ -292,8 +319,12 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.expired_at AS expired_at,
e.valid_at AS valid_at,
e.invalid_at AS invalid_at
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down Expand Up @@ -365,19 +396,34 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
return edges

@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''

records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
e.uuid As uuid,
e.group_id AS group_id,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
e.created_at AS created_at
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down
6 changes: 3 additions & 3 deletions graphiti_core/embedder/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@
"""

from abc import ABC, abstractmethod
from typing import Iterable, List, Literal
from collections.abc import Iterable

from pydantic import BaseModel, Field

EMBEDDING_DIM = 1024


class EmbedderConfig(BaseModel):
embedding_dim: Literal[1024] = Field(default=EMBEDDING_DIM, frozen=True)
embedding_dim: int = Field(default=EMBEDDING_DIM, frozen=True)


class EmbedderClient(ABC):
@abstractmethod
async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
pass
4 changes: 2 additions & 2 deletions graphiti_core/embedder/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Iterable, List
from collections.abc import Iterable

from openai import AsyncOpenAI
from openai.types import EmbeddingModel
Expand Down Expand Up @@ -42,7 +42,7 @@ def __init__(self, config: OpenAIEmbedderConfig | None = None):
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)

async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
result = await self.client.embeddings.create(
input=input_data, model=self.config.embedding_model
Expand Down
6 changes: 3 additions & 3 deletions graphiti_core/embedder/voyage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
"""

from typing import Iterable, List
from collections.abc import Iterable

import voyageai # type: ignore
from pydantic import Field
Expand All @@ -41,11 +41,11 @@ def __init__(self, config: VoyageAIEmbedderConfig | None = None):
self.client = voyageai.AsyncClient(api_key=config.api_key)

async def create(
self, input_data: str | List[str] | Iterable[int] | Iterable[Iterable[int]]
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
if isinstance(input_data, str):
input_list = [input_data]
elif isinstance(input_data, List):
elif isinstance(input_data, list):
input_list = [str(i) for i in input_data if i]
else:
input_list = [str(i) for i in input_data if i is not None]
Expand Down
1 change: 1 addition & 0 deletions graphiti_core/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DEFAULT_DATABASE = os.getenv('DEFAULT_DATABASE', None)
USE_PARALLEL_RUNTIME = bool(os.getenv('USE_PARALLEL_RUNTIME', False))
MAX_REFLEXION_ITERATIONS = 2
DEFAULT_PAGE_LIMIT = 20


def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None:
Expand Down
54 changes: 50 additions & 4 deletions graphiti_core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@

from neo4j import AsyncDriver
from pydantic import BaseModel, Field
from typing_extensions import LiteralString

from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import DEFAULT_DATABASE
from graphiti_core.helpers import DEFAULT_DATABASE, DEFAULT_PAGE_LIMIT
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_SAVE,
ENTITY_NODE_SAVE,
Expand Down Expand Up @@ -207,10 +208,21 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
return episodes

@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND e.created_at < $created_at' if created_at else ''

records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic) WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN DISTINCT
e.content AS content,
e.created_at AS created_at,
Expand All @@ -220,8 +232,12 @@ async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source
ORDER BY e.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down Expand Up @@ -308,19 +324,34 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
return nodes

@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''

records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
ORDER BY n.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down Expand Up @@ -407,19 +438,34 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]):
return communities

@classmethod
async def get_by_group_ids(cls, driver: AsyncDriver, group_ids: list[str]):
async def get_by_group_ids(
cls,
driver: AsyncDriver,
group_ids: list[str],
limit: int = DEFAULT_PAGE_LIMIT,
created_at: datetime | None = None,
):
cursor_query: LiteralString = 'AND n.created_at < $created_at' if created_at else ''

records, _, _ = await driver.execute_query(
"""
MATCH (n:Community) WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
n.uuid As uuid,
n.name AS name,
n.name_embedding AS name_embedding,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary
ORDER BY n.uuid DESC
LIMIT $limit
""",
group_ids=group_ids,
created_at=created_at,
limit=limit,
database_=DEFAULT_DATABASE,
routing_='r',
)
Expand Down
2 changes: 1 addition & 1 deletion graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

logger = logging.getLogger(__name__)

RELEVANT_SCHEMA_LIMIT = 3
RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
Expand Down
Loading

0 comments on commit 0fbe5c0

Please sign in to comment.