Skip to content

Commit

Permalink
Merge pull request #245 from ag2ai/neo4j_fix
Browse files Browse the repository at this point in the history
Neo4j fix
  • Loading branch information
AgentGenie authored Dec 20, 2024
2 parents c16b9b3 + ea78a01 commit a7a50fe
Show file tree
Hide file tree
Showing 6 changed files with 465 additions and 113 deletions.
160 changes: 109 additions & 51 deletions autogen/agentchat/contrib/graph_rag/neo4j_graph_query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from llama_index.core import PropertyGraphIndex, SimpleDirectoryReader
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.indices.property_graph import SchemaLLMPathExtractor
from llama_index.core.indices.property_graph import (
DynamicLLMPathExtractor,
SchemaLLMPathExtractor,
)
from llama_index.core.indices.property_graph.transformations.schema_llm import Triple
from llama_index.core.llms import LLM
from llama_index.embeddings.openai import OpenAIEmbedding
Expand All @@ -19,14 +22,21 @@

class Neo4jGraphQueryEngine(GraphQueryEngine):
"""
This class serves as a wrapper for a Neo4j database-backed PropertyGraphIndex query engine,
facilitating the creation, updating, and querying of graphs.
This class serves as a wrapper for a property graph query engine backed by LlamaIndex and Neo4j,
facilitating the creation, connection, updating, and querying of LlamaIndex property graphs.
It builds a PropertyGraph Index from input documents,
storing and retrieving data from a property graph in the Neo4j database.
It builds a property graph Index from input documents,
storing and retrieving data from the property graph in the Neo4j database.
Using SchemaLLMPathExtractor, it defines schemas with entities, relationships, and other properties based on the input,
which are added into the property graph.
It extracts triplets, i.e., [entity] -> [relationship] -> [entity] sets,
from the input documents using LlamaIndex extractors.
Users can provide custom entities, relationships, and schema to guide the extraction process.
If strict is True, the engine will extract triplets following the schema
of allowed relationships for each entity specified in the schema.
It also leverages LlamaIndex's chat engine which has a conversation history internally to provide context-aware responses.
For usage, please refer to example notebook/agentchat_graph_rag_neo4j.ipynb
"""
Expand All @@ -42,8 +52,8 @@ def __init__(
embedding: BaseEmbedding = OpenAIEmbedding(model_name="text-embedding-3-small"),
entities: Optional[TypeAlias] = None,
relations: Optional[TypeAlias] = None,
validation_schema: Optional[Union[Dict[str, str], List[Triple]]] = None,
strict: Optional[bool] = True,
schema: Optional[Union[Dict[str, str], List[Triple]]] = None,
strict: Optional[bool] = False,
):
"""
Initialize a Neo4j Property graph.
Expand All @@ -56,12 +66,12 @@ def __init__(
database (str): Neo4j database name.
username (str): Neo4j username.
password (str): Neo4j password.
llm (LLM): Language model to use for extracting tripletss.
llm (LLM): Language model to use for extracting triplets.
embedding (BaseEmbedding): Embedding model to use for constructing the index and querying
entities (Optional[TypeAlias]): Custom possible entities to include in the graph.
relations (Optional[TypeAlias]): Custom possible relations to include in the graph.
validation_schema (Optional[Union[Dict[str, str], List[Triple]]): Custom schema to validate the extracted triplets
strict (Optional[bool]): If false, allows for values outside of the schema, useful for using the schema as a suggestion.
entities (Optional[TypeAlias]): Custom suggested entities to include in the graph.
relations (Optional[TypeAlias]): Custom suggested relations to include in the graph.
schema (Optional[Union[Dict[str, str], List[Triple]]): Custom schema to specify allowed relationships for each entity.
strict (Optional[bool]): If false, allows for values outside of the input schema.
"""
self.host = host
self.port = port
Expand All @@ -72,19 +82,15 @@ def __init__(
self.embedding = embedding
self.entities = entities
self.relations = relations
self.validation_schema = validation_schema
self.schema = schema
self.strict = strict

def init_db(self, input_doc: List[Document] | None = None):
"""
Build the knowledge graph with input documents.
"""
self.input_files = []
for doc in input_doc:
if os.path.exists(doc.path_or_url):
self.input_files.append(doc.path_or_url)
else:
raise ValueError(f"Document file not found: {doc.path_or_url}")

self.documents = self._load_doc(input_doc)

self.graph_store = Neo4jPropertyGraphStore(
username=self.username,
Expand All @@ -96,19 +102,8 @@ def init_db(self, input_doc: List[Document] | None = None):
# delete all entities and relationships in case a graph pre-exists
self._clear()

self.documents = SimpleDirectoryReader(input_files=self.input_files).load_data()

# Extract paths following a strict schema of allowed entities, relationships, and which entities can be connected to which relationships.
# To add more extractors, please refer to https://docs.llamaindex.ai/en/latest/module_guides/indexing/lpg_index_guide/#construction
self.kg_extractors = [
SchemaLLMPathExtractor(
llm=self.llm,
possible_entities=self.entities,
possible_relations=self.relations,
kg_validation_schema=self.validation_schema,
strict=self.strict,
)
]
# Create knowledge graph extractors.
self.kg_extractors = self._create_kg_extractors()

self.index = PropertyGraphIndex.from_documents(
self.documents,
Expand All @@ -118,6 +113,26 @@ def init_db(self, input_doc: List[Document] | None = None):
show_progress=True,
)

def connect_db(self):
    """
    Connect this engine to an existing Neo4j property graph database.

    Opens the graph store using the connection settings captured at
    construction time, rebuilds the knowledge-graph extractors, and loads
    the property graph index from the existing store instead of
    re-ingesting documents.
    """
    neo4j_url = f"{self.host}:{self.port}"
    self.graph_store = Neo4jPropertyGraphStore(
        username=self.username,
        password=self.password,
        url=neo4j_url,
        database=self.database,
    )

    # Same extractor configuration as init_db, so later add_records calls behave consistently.
    self.kg_extractors = self._create_kg_extractors()

    self.index = PropertyGraphIndex.from_existing(
        property_graph_store=self.graph_store,
        kg_extractors=self.kg_extractors,
        embed_model=self.embedding,
        show_progress=True,
    )

def add_records(self, new_records: List) -> bool:
"""
Add new records to the knowledge graph. Must be local files.
Expand All @@ -129,7 +144,7 @@ def add_records(self, new_records: List) -> bool:
bool: True if successful, False otherwise.
"""
if self.graph_store is None:
raise ValueError("Knowledge graph is not initialized. Please call init_db first.")
raise ValueError("Knowledge graph is not initialized. Please call init_db or connect_db first.")

try:
"""
Expand All @@ -149,7 +164,11 @@ def add_records(self, new_records: List) -> bool:

def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryResult:
"""
Query the knowledge graph with a question.
Query the property graph with a question using LlamaIndex chat engine.
We use the condense_plus_context chat mode
which condenses the conversation history and the user query into a standalone question,
and then builds a context for the standalone question
from the property graph to generate a response.
Args:
question: a human input question.
Expand All @@ -158,23 +177,15 @@ def query(self, question: str, n_results: int = 1, **kwargs) -> GraphStoreQueryR
Returns:
A GraphStoreQueryResult object containing the answer and related triplets.
"""
if self.graph_store is None:
raise ValueError("Knowledge graph is not created.")
if not hasattr(self, "index"):
raise ValueError("Property graph index is not created.")

# query the graph to get the answer
query_engine = self.index.as_query_engine(include_text=True)
response = str(query_engine.query(question))
# Initialize chat engine if not already initialized
if not hasattr(self, "chat_engine"):
self.chat_engine = self.index.as_chat_engine(chat_mode="condense_plus_context", llm=self.llm)

# retrieve source triplets that are semantically related to the question
retriever = self.index.as_retriever(include_text=False)
nodes = retriever.retrieve(question)
triplets = []
for node in nodes:
entities = [sub.split("(")[0].strip() for sub in node.text.split("->")]
triplet = " -> ".join(entities)
triplets.append(triplet)

return GraphStoreQueryResult(answer=response, results=triplets)
response = self.chat_engine.chat(question)
return GraphStoreQueryResult(answer=str(response))

def _clear(self) -> None:
"""
Expand All @@ -183,3 +194,50 @@ def _clear(self) -> None:
"""
with self.graph_store._driver.session() as session:
session.run("MATCH (n) DETACH DELETE n;")

def _load_doc(self, input_doc: List[Document]) -> List[Document]:
"""
Load documents from the input files.
"""
input_files = []
for doc in input_doc:
if os.path.exists(doc.path_or_url):
input_files.append(doc.path_or_url)
else:
raise ValueError(f"Document file not found: {doc.path_or_url}")

return SimpleDirectoryReader(input_files=input_files).load_data()

def _create_kg_extractors(self):
    """
    Build the list of knowledge-graph path extractors used for indexing.

    A SchemaLLMPathExtractor is always included; when ``strict`` is True it
    extracts paths following the strict schema of allowed relationships for
    each entity. When ``strict`` is False, a DynamicLLMPathExtractor is
    appended as well, which auto-creates relationships and a schema that
    fit the graph.

    To add more extractors, refer to
    https://docs.llamaindex.ai/en/latest/module_guides/indexing/lpg_index_guide/#construction
    """
    schema_extractor = SchemaLLMPathExtractor(
        llm=self.llm,
        possible_entities=self.entities,
        possible_relations=self.relations,
        kg_validation_schema=self.schema,
        strict=self.strict,
    )
    extractors = [schema_extractor]

    if not self.strict:
        # Allow the LLM to discover entity/relation types beyond the suggested schema.
        dynamic_extractor = DynamicLLMPathExtractor(
            llm=self.llm,
            allowed_entity_types=self.entities,
            allowed_relation_types=self.relations,
        )
        extractors.append(dynamic_extractor)

    return extractors
370 changes: 316 additions & 54 deletions notebook/agentchat_graph_rag_neo4j.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions notebook/neo4j_property_graph_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions notebook/neo4j_property_graph_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions notebook/neo4j_property_graph_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
36 changes: 34 additions & 2 deletions test/agentchat/contrib/graph_rag/test_neo4j_graph_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def neo4j_query_engine():
]

# define which entities can have which relations
validation_schema = {
schema = {
"EMPLOYEE": ["FOLLOWS", "APPLIES_TO", "ASSIGNED_TO", "ENTITLED_TO", "REPORTS_TO"],
"EMPLOYER": ["PROVIDES", "DEFINED_AS", "MANAGES", "REQUIRES"],
"POLICY": ["APPLIES_TO", "DEFINED_AS", "REQUIRES"],
Expand All @@ -69,7 +69,7 @@ def neo4j_query_engine():
database="neo4j", # Change if you want to store the graph in your custom database
entities=entities, # possible entities
relations=relations, # possible relations
validation_schema=validation_schema, # schema to validate the extracted triplets
schema=schema,
strict=True, # enforce the extracted triplets to be in the schema
)

Expand All @@ -78,6 +78,23 @@ def neo4j_query_engine():
return query_engine


# Test fixture to test auto-generation without given schema
@pytest.fixture(scope="module")
def neo4j_query_engine_auto():
    """
    Build a query engine that attaches to an already-populated Neo4j
    property graph, exercising the auto-generated (schema-free) path.
    """
    engine = Neo4jGraphQueryEngine(
        username="neo4j",
        password="password",
        host="bolt://172.17.0.3",
        port=7687,
        database="neo4j",
    )
    # Connect to the existing graph rather than rebuilding it from documents.
    engine.connect_db()
    return engine


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"] or skip or skip_openai,
reason=reason,
Expand Down Expand Up @@ -117,3 +134,18 @@ def test_neo4j_add_records(neo4j_query_engine):
print(query_result.answer)

assert query_result.answer.find("Keanu Reeves") >= 0


@pytest.mark.skipif(
    sys.platform in ["darwin", "win32"] or skip or skip_openai,
    reason=reason,
)
def test_neo4j_auto(neo4j_query_engine_auto):
    """
    Query the auto-generated property graph and verify the expected
    employer name appears in the answer.
    """
    query_result: GraphStoreQueryResult = neo4j_query_engine_auto.query(
        question="Which company is the employer?"
    )

    print(query_result.answer)
    assert "BUZZ" in query_result.answer

0 comments on commit a7a50fe

Please sign in to comment.