Skip to content

Commit

Permalink
Harrison/align table (langchain-ai#1081)
Browse files Browse the repository at this point in the history
Co-authored-by: Francisco Ingham <fpingham@gmail.com>
  • Loading branch information
2 people authored and dongreenberg committed Feb 17, 2023
1 parent 9860181 commit 45c2eae
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 42 deletions.
51 changes: 41 additions & 10 deletions docs/modules/chains/examples/sqlite.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "d0e27d88",
"metadata": {
"pycharm": {
Expand All @@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "72ede462",
"metadata": {
"pycharm": {
Expand Down Expand Up @@ -346,15 +346,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CREATE TABLE [Album]\n",
"(\n",
" [AlbumId] INTEGER NOT NULL,\n",
" [Title] NVARCHAR(160) NOT NULL,\n",
" [ArtistId] INTEGER NOT NULL,\n",
" CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]),\n",
" FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
")\n",
"\n",
" Table data will be described in the following format:\n",
"SELECT * FROM 'Album' LIMIT 2\n",
"AlbumId Title ArtistId\n",
"1 For Those About To Rock We Salute You 1\n",
"2 Balls to the Wall 2\n",
"\n",
" Table 'table name' has columns: {column1 name: (column1 type, [list of example values for column1]),\n",
" column2 name: (column2 type, [list of example values for column2], ...)\n",
"\n",
" These are the tables you can use, together with their column information:\n",
"CREATE TABLE [Track]\n",
"(\n",
" [TrackId] INTEGER NOT NULL,\n",
" [Name] NVARCHAR(200) NOT NULL,\n",
" [AlbumId] INTEGER,\n",
" [MediaTypeId] INTEGER NOT NULL,\n",
" [GenreId] INTEGER,\n",
" [Composer] NVARCHAR(220),\n",
" [Milliseconds] INTEGER NOT NULL,\n",
" [Bytes] INTEGER,\n",
" [UnitPrice] NUMERIC(10,2) NOT NULL,\n",
" CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]),\n",
" FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
" FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION,\n",
" FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) \n",
"\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n",
")\n",
"\n",
" Table 'Track' has columns: {'TrackId': ['INTEGER', ['1', '2']], 'Name': ['NVARCHAR(200)', ['For Those About To Rock (We Salute You)', 'Balls to the Wall']], 'AlbumId': ['INTEGER', ['1', '2']], 'MediaTypeId': ['INTEGER', ['1', '2']], 'GenreId': ['INTEGER', ['1', '1']], 'Composer': ['NVARCHAR(220)', ['Angus Young, Malcolm Young, Brian Johnson', 'None']], 'Milliseconds': ['INTEGER', ['343719', '342562']], 'Bytes': ['INTEGER', ['11170334', '5510424']], 'UnitPrice': ['NUMERIC(10, 2)', ['0.99', '0.99']]}\n"
"SELECT * FROM 'Track' LIMIT 2\n",
"TrackId Name AlbumId MediaTypeId GenreId Composer Milliseconds Bytes UnitPrice\n",
"1 For Those About To Rock (We Salute You) 1 1 1 Angus Young, Malcolm Young, Brian Johnson 343719 11170334 0.99\n",
"2 Balls to the Wall 2 2 1 None 342562 5510424 0.99\n"
]
}
],
Expand Down Expand Up @@ -492,9 +523,9 @@
"lastKernelId": null
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "langchain",
"language": "python",
"name": "python3"
"name": "langchain"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -506,7 +537,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.8.16"
}
},
"nbformat": 4,
Expand Down
55 changes: 40 additions & 15 deletions langchain/sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import ast
from collections import defaultdict
from typing import Any, Iterable, List, Optional

from sqlalchemy import create_engine, inspect
Expand Down Expand Up @@ -30,7 +29,7 @@ def __init__(
schema: Optional[str] = None,
ignore_tables: Optional[List[str]] = None,
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 0,
sample_rows_in_table_info: int = 3,
):
"""Create engine from database URI."""
self._engine = engine
Expand Down Expand Up @@ -80,9 +79,12 @@ def table_info(self) -> str:
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
"""Get information about specified tables.
Follows best practices as specified in: Rajkumar et al, 2022
(https://arxiv.org/abs/2204.00498)
If `sample_rows_in_table_info`, the specified number of sample rows will be
appended to each table description. This can increase performance as
demonstrated by Rajkumar et al, 2022 (https://arxiv.org/abs/2204.00498).
demonstrated in the paper.
"""
all_table_names = self.get_table_names()
if table_names is not None:
Expand All @@ -93,33 +95,51 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:

tables = []
for table_name in all_table_names:
columns = defaultdict(list)
columns = []
create_table = self.run(
(
"SELECT sql FROM sqlite_master WHERE "
f"type='table' AND name='{table_name}'"
),
fetch="one",
)

for column in self._inspector.get_columns(table_name, schema=self._schema):
columns[f"{column['name']}"].append(str(column["type"]))
columns.append(column["name"])

if self._sample_rows_in_table_info:
sample_rows = self.run(
select_star = (
f"SELECT * FROM '{table_name}' LIMIT "
f"{self._sample_rows_in_table_info}"
)

sample_rows = self.run(select_star)

sample_rows_ls = ast.literal_eval(sample_rows)
sample_rows_ls = list(
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_ls)
)

for e, col in enumerate(columns):
columns[col].append(
[row[e] for row in sample_rows_ls] # type: ignore
)
columns_str = " ".join(columns)
sample_rows_str = "\n".join([" ".join(row) for row in sample_rows_ls])

tables.append(
create_table
+ "\n\n"
+ select_star
+ "\n"
+ columns_str
+ "\n"
+ sample_rows_str
)

table_str = f"Table '{table_name}' has columns: " + str(dict(columns))
tables.append(table_str)
else:
tables.append(create_table)

final_str = _TEMPLATE_PREFIX + "\n".join(tables)
final_str = "\n\n\n".join(tables)
return final_str

def run(self, command: str) -> str:
def run(self, command: str, fetch: str = "all") -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
Expand All @@ -130,6 +150,11 @@ def run(self, command: str) -> str:
connection.exec_driver_sql(f"SET search_path TO {self._schema}")
cursor = connection.exec_driver_sql(command)
if cursor.returns_rows:
result = cursor.fetchall()
if fetch == "all":
result = cursor.fetchall()
elif fetch == "one":
result = cursor.fetchone()[0]
else:
raise ValueError("Fetch parameter must be either 'one' or 'all'")
return str(result)
return ""
57 changes: 46 additions & 11 deletions tests/unit_tests/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,28 @@ def test_table_info() -> None:
metadata_obj.create_all(engine)
db = SQLDatabase(engine)
output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = (
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR(16)']}",
"Table 'company' has columns: {'company_id': ['INTEGER'], 'company_location': ['VARCHAR']}",
expected_output = """
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
assert sorted(output.split("\n")) == sorted(expected_output)
SELECT * FROM 'user' LIMIT 3
user_id user_name
CREATE TABLE company (
company_id INTEGER NOT NULL,
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 3
company_id company_location
"""

assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))


def test_table_info_w_sample_rows() -> None:
Expand All @@ -51,12 +67,31 @@ def test_table_info_w_sample_rows() -> None:
db = SQLDatabase(engine, sample_rows_in_table_info=2)

output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = (
"Table 'user' has columns: {'user_id': ['INTEGER', ['13', '14']], 'user_name': ['VARCHAR(16)', ['Harrison', 'Chase']]}",
"Table 'company' has columns: {'company_id': ['INTEGER', []], 'company_location': ['VARCHAR', []]}",
)
assert sorted(output.split("\n")) == sorted(expected_output)

expected_output = """
CREATE TABLE company (
company_id INTEGER NOT NULL,
company_location VARCHAR NOT NULL,
PRIMARY KEY (company_id)
)
SELECT * FROM 'company' LIMIT 2
company_id company_location
CREATE TABLE user (
user_id INTEGER NOT NULL,
user_name VARCHAR(16) NOT NULL,
PRIMARY KEY (user_id)
)
SELECT * FROM 'user' LIMIT 2
user_id user_name
13 Harrison
14 Chase
"""

assert sorted(output.split()) == sorted(expected_output.split())


def test_sql_database_run() -> None:
Expand Down
16 changes: 10 additions & 6 deletions tests/unit_tests/test_sql_database_schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# flake8: noqa
"""Test SQL database wrapper with schema support.
Using DuckDB as SQLite does not support schemas.
Expand All @@ -16,7 +17,7 @@
schema,
)

from langchain.sql_database import _TEMPLATE_PREFIX, SQLDatabase
from langchain.sql_database import SQLDatabase

metadata_obj = MetaData()

Expand Down Expand Up @@ -46,11 +47,14 @@ def test_table_info() -> None:
metadata_obj.create_all(engine)
db = SQLDatabase(engine, schema="schema_a")
output = db.table_info
output = output[len(_TEMPLATE_PREFIX) :]
expected_output = (
"Table 'user' has columns: {'user_id': ['INTEGER'], 'user_name': ['VARCHAR']}"
)
assert output == expected_output
expected_output = """
CREATE TABLE schema_a."user"(user_id INTEGER, user_name VARCHAR NOT NULL, PRIMARY KEY(user_id));
SELECT * FROM 'user' LIMIT 3
user_id user_name
"""

assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))


def test_sql_database_run() -> None:
Expand Down

0 comments on commit 45c2eae

Please sign in to comment.