-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
112 lines (83 loc) · 3.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_experimental.sql.base import SQLDatabaseChain
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_community.utilities import SQLDatabase
from langchain.chains import LLMChain
from langchain.memory import ConversationBufferMemory
import configparser
import os
import retrying
def read_properties_file(file_path):
    """Parse an INI-style properties file and return its configured values.

    Args:
        file_path: Path to the properties file.

    Returns:
        tuple: (db_path, gemini_api_key) read from the [DEFAULT] section.

    Raises:
        FileNotFoundError: if file_path does not point to an existing file.
    """
    # Guard clause: configparser silently ignores missing files, so fail loudly here.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"The file '{file_path}' does not exist.")
    parser = configparser.ConfigParser()
    parser.read(file_path)
    defaults = parser['DEFAULT']
    return defaults['db_path'], defaults['gemini_api_key']
def get_property():
    """Load application configuration from 'config.properties' in the CWD.

    Returns:
        tuple: (db_path, gemini_api_key).

    Raises:
        FileNotFoundError: if config.properties is missing (propagated from
            read_properties_file).
    """
    file_path = 'config.properties'
    db_path, gemini_api_key = read_properties_file(file_path)
    print("Database path:", db_path)
    # SECURITY FIX: the original printed the raw API key to stdout, leaking the
    # secret into logs/terminals. Confirm it loaded without echoing its value.
    print("Gemini API Key loaded")
    return db_path, gemini_api_key
def get_llm(gemini_api_key):
    """Create an instance of the Google Gemini Pro chat model.

    Args:
        gemini_api_key: API key for the Google Generative AI service.

    Returns:
        ChatGoogleGenerativeAI: model handle with temperature 0.0 for
        deterministic output and system messages converted to human turns.
    """
    return ChatGoogleGenerativeAI(
        model="gemini-pro",
        google_api_key=gemini_api_key,
        convert_system_message_to_human=True,
        temperature=0.0,
    )
def db_connection(db_path):
    """Open a LangChain SQLDatabase wrapper over a local SQLite file.

    Args:
        db_path: Filesystem path to the SQLite database file.

    Returns:
        SQLDatabase: the connected database wrapper.
    """
    db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
    print(db.dialect)
    print(db.get_usable_table_names())
    # FIX: the original ran a hard-coded probe query against an 'Employees'
    # table unconditionally, crashing connection setup for any database that
    # lacks that table. Keep the probe as a best-effort diagnostic only.
    try:
        print(db.run("SELECT * FROM Employees LIMIT 10;"))
    except Exception as probe_err:
        print(f"Probe query skipped: {probe_err}")
    return db
def create_conversational_chain():
    """Build the SQL-generation chain plus a plain LLM->string chain.

    Returns:
        tuple: (db_chain, chain) where db_chain is a SQLDatabaseChain with
        conversation memory that generates and directly runs SQL against the
        configured database, and chain is the raw LLM piped through a string
        output parser.

    Raises:
        FileNotFoundError: if config.properties is missing.
    """
    # FIX: the original bound the DB *path* and the DB *connection* to the same
    # name `db`; use distinct names. Also dropped the no-op
    # `except Exception as e: raise e` wrapper, which only obscured tracebacks.
    db_path, gemini_api_key = get_property()
    llm = get_llm(gemini_api_key)
    db = db_connection(db_path)
    sql_prompt_template = """
Only use the following tables:
{table_info}
Question: {input}
Given an input question, first create a syntactically correct
{dialect} query to run.
Relevant pieces of previous conversation:
{history}
(You do not need to use these pieces of information if not relevant)
Dont include ```, ```sql and \n in the output.
"""
    prompt = PromptTemplate(
        input_variables=["input", "table_info", "dialect", "history"],
        template=sql_prompt_template,
    )
    # Conversation memory feeds the {history} slot of the prompt.
    memory = ConversationBufferMemory(memory_key="history")
    db_chain = SQLDatabaseChain.from_llm(
        llm, db, memory=memory, prompt=prompt, return_direct=True, verbose=True
    )
    # Secondary chain: raw model output passed through a string parser.
    chain = llm | StrOutputParser()
    return db_chain, chain