-
Notifications
You must be signed in to change notification settings - Fork 0
/
vanna_test.py
32 lines (25 loc) · 973 Bytes
/
vanna_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import vanna
import pandas as pd
import mysql.connector
from vanna.remote import VannaDefault
def run_sql(sql: str, connect=None) -> pd.DataFrame:
    """Execute ``sql`` and return its result set as a DataFrame.

    Used as the ``run_sql`` hook of the Vanna instance in this script, so the
    SQL it receives is generated by the LLM.

    Args:
        sql: The SQL statement to execute.
        connect: Optional zero-argument callable returning a DB-API 2.0
            connection. Defaults to the local MySQL database ``TTS``
            (user ``root``, empty password, host ``localhost``).

    Returns:
        A DataFrame with one row per fetched row, and columns named after
        the cursor's result columns. For a statement that produces no
        result set (where the driver allows fetching), the DataFrame is
        empty with no columns.
    """
    if connect is None:
        # Default: connect to the local MySQL database.
        # NOTE(review): this runs LLM-generated SQL as root with an empty
        # password; a read-only account would be safer.
        def connect():
            return mysql.connector.connect(
                user='root', password='', host='localhost', database='TTS'
            )

    cnx = connect()
    try:
        cursor = cnx.cursor()
        try:
            cursor.execute(sql)
            rows = cursor.fetchall()
            # DB-API 2.0: each description entry starts with the column name
            # (these are the names mysql.connector exposes as column_names).
            # description is None when the statement yields no result set.
            columns = [desc[0] for desc in (cursor.description or ())]
        finally:
            cursor.close()
    finally:
        # Always release the connection; previously it was never closed.
        cnx.close()
    return pd.DataFrame(rows, columns=columns)
import os

from vanna.flask import VannaFlaskApp

# Read the Vanna API key from the environment instead of committing it to
# source control.
# NOTE(review): the key previously hard-coded here was published with this
# file and should be revoked/rotated.
api_key = os.environ.get('VANNA_API_KEY')
if not api_key:
    raise RuntimeError(
        'Set the VANNA_API_KEY environment variable to your Vanna API key.'
    )
vanna_model_name = 'tts'

# Use Vanna (VannaDefault from vanna.remote) to connect to the LLM.
vn = VannaDefault(model=vanna_model_name, api_key=api_key)
# Route Vanna's SQL execution through the local-database helper.
vn.run_sql = run_sql
vn.run_sql_is_set = True

# Train with an example question -> SQL pair
# (question: "count the number of each ethnic group").
vn.train(question='统计不同民族数量?', sql='SELECT nation, COUNT(*) as count FROM customer GROUP BY nation ORDER BY count DESC;')

# allow_llm_to_see_data=True presumably lets query results be shared with the
# LLM -- confirm against the Vanna docs before using with sensitive data.
app = VannaFlaskApp(vn, allow_llm_to_see_data=True)

if __name__ == '__main__':
    # Only ask the demo question and start the (blocking) web server when run
    # as a script, not when this module is imported.
    vn.ask('统计不同民族数量?')
    app.run()