-
Notifications
You must be signed in to change notification settings - Fork 0
/
calculate_embeddings.py
64 lines (50 loc) · 2.42 KB
/
calculate_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from neo4j import GraphDatabase, Query
from openai.embeddings_utils import get_embedding
import openai
import os
"""
LoadEmbedding: calls the OpenAI embedding API to generate an embedding for a given property of each matching node in Neo4j
Version: 1.1
"""
# Name of the OpenAI embedding model passed to get_embedding().
EMBEDDING_MODEL = "text-embedding-ada-002"
# Connection settings and API key are read from the environment at import
# time; a missing variable raises KeyError before any work starts.
# Note that the password is read from NEO4J_PASS (not NEO4J_PASSWORD).
NEO4J_URL = os.environ['NEO4J_URI']
NEO4J_USER = os.environ['NEO4J_USER']
NEO4J_PASSWORD = os.environ['NEO4J_PASS']
OPENAI_KEY = os.environ['OPENAI_KEY']
class LoadEmbedding:
    """Generate OpenAI embeddings for a node property and store them in Neo4j.

    For every node that has the requested property, an ``Embedding`` node is
    created holding the embedding vector, and it is linked to the source node
    with a ``HAS_EMBEDDING`` relationship.

    Instances can be used as context managers so the underlying driver is
    closed even when an error occurs::

        with LoadEmbedding(uri, user, password) as loader:
            loader.load_embedding_to_node_property("Area", "Area")
    """

    def __init__(self, uri, user, password):
        """Create the Neo4j driver and set the OpenAI API key.

        Args:
            uri: Neo4j connection URI.
            user: Neo4j user name.
            password: Neo4j password.
        """
        self.driver = GraphDatabase.driver(uri, auth=(user, password))
        openai.api_key = OPENAI_KEY

    def __enter__(self):
        """Return this loader so it can be bound by a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the driver on leaving the ``with`` block.

        Returns ``False`` so an exception raised inside the block propagates.
        """
        self.close()
        return False

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def load_embedding_to_node_property(self, node_label, node_property):
        """Embed ``node_property`` of every ``node_label`` node and store it.

        Args:
            node_label: Label of the nodes to process. It is interpolated
                directly into the Cypher text, so it must be a trusted
                identifier.
            node_property: Name of the property to embed. It is interpolated
                directly into the Cypher text as well, so names that need
                quoting must be passed already wrapped in backticks. The same
                string is stored as the ``key`` of each ``Embedding`` node and
                is part of the text sent for embedding.

        Returns:
            The number of nodes for which an embedding was stored.
        """
        self.driver.verify_connectivity()

        # Loop-invariant write statement. The ``key`` property of the
        # Embedding node differentiates embeddings of different properties.
        # CREATE (not MERGE) is used, so calling this method again for the
        # same property adds further Embedding nodes.
        store_query = (
            "CREATE (e:Embedding) SET e.key=$key, e.value=$embedding"
            " WITH e MATCH (n) WHERE id(n) = $id CREATE (n) -[:HAS_EMBEDDING]-> (e)"
        )

        processed = 0
        with self.driver.session(database="neo4j") as session:
            result = session.run(f"""
                MATCH (a:{node_label})
                WHERE a.{node_property} IS NOT NULL
                RETURN id(a) AS id, a.{node_property} AS node_property
                """)
            # result.data() fetches all rows up front, before the same
            # session is reused for the write statements below.
            for record in result.data():
                node_id = record["id"]
                text = record["node_property"]
                # Instead of embedding the bare text, the label and property
                # name are put in front of it.
                embedding = get_embedding(
                    f"{node_label} {node_property} - {text}", EMBEDDING_MODEL
                )
                session.run(
                    store_query,
                    key=node_property,
                    embedding=embedding,
                    id=node_id,
                )
                processed += 1

        print(f"Processed {processed} {node_label} nodes for property @{node_property}.")
        return processed
if __name__ == "__main__":
    # Script entry point: embed one property for each of the three node
    # labels, then report completion.
    loader = LoadEmbedding(NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD)
    try:
        # Property names containing spaces are passed already backtick-quoted
        # because they are inserted verbatim into the Cypher query text.
        loader.load_embedding_to_node_property("Control", "`Control Name`")
        loader.load_embedding_to_node_property("Area", "Area")
        loader.load_embedding_to_node_property("Requirement", "`Verification Requirement`")
        print("done")
    finally:
        # Close the driver even if one of the loads above raised.
        loader.close()