Skip to content

Commit

Permalink
chore: use standard collections as type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Oct 16, 2024
1 parent c2a47d8 commit aa80672
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 64 deletions.
4 changes: 2 additions & 2 deletions src/langchain_google_cloud_sql_mysql/chat_message_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List

import sqlalchemy
from langchain_core.chat_history import BaseChatMessageHistory
Expand Down Expand Up @@ -78,7 +78,7 @@ def _verify_schema(self) -> None:
)

@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> list[BaseMessage]: # type: ignore
"""Retrieve the messages from Cloud SQL"""
query = f"SELECT data, type FROM `{self.table_name}` WHERE session_id = :session_id ORDER BY id;"
with self.engine.connect() as conn:
Expand Down
12 changes: 6 additions & 6 deletions src/langchain_google_cloud_sql_mysql/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
from typing import TYPE_CHECKING, Optional, Sequence

import google.auth
import google.auth.transport.requests
Expand Down Expand Up @@ -77,7 +77,7 @@ def _get_iam_principal_email(
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
response = requests.get(url)
response.raise_for_status()
response_json: Dict = response.json()
response_json: dict = response.json()
email = response_json.get("email")
if email is None:
raise ValueError(
Expand Down Expand Up @@ -287,7 +287,7 @@ def init_chat_history_table(self, table_name: str) -> None:
def init_document_table(
self,
table_name: str,
metadata_columns: List[sqlalchemy.Column] = [],
metadata_columns: list[sqlalchemy.Column] = [],
content_column: str = "page_content",
metadata_json_column: Optional[str] = "langchain_metadata",
overwrite_existing: bool = False,
Expand All @@ -297,7 +297,7 @@ def init_document_table(
Args:
table_name (str): The MySQL database table name.
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
metadata_columns (list[sqlalchemy.Column]): A list of SQLAlchemy Columns
to create for custom metadata. Optional.
content_column (str): The column to store document content.
Default: `page_content`.
Expand Down Expand Up @@ -351,7 +351,7 @@ def init_vectorstore_table(
vector_size: int,
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: List[Column] = [],
metadata_columns: list[Column] = [],
metadata_json_column: str = "langchain_metadata",
id_column: str = "langchain_id",
overwrite_existing: bool = False,
Expand All @@ -367,7 +367,7 @@ def init_vectorstore_table(
Default: `page_content`.
embedding_column (str) : Name of the column to store vector embeddings.
Default: `embedding`.
metadata_columns (List[Column]): A list of Columns to create for custom
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: `langchain_metadata`. Optional.
Expand Down
30 changes: 15 additions & 15 deletions src/langchain_google_cloud_sql_mysql/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, Iterable, Iterator, List, Optional, cast
from typing import Any, Iterable, Iterator, Optional, cast

import pymysql
import sqlalchemy
Expand All @@ -28,13 +28,13 @@
def _parse_doc_from_row(
content_columns: Iterable[str],
metadata_columns: Iterable[str],
row: Dict,
row: dict,
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
) -> Document:
page_content = " ".join(
str(row[column]) for column in content_columns if column in row
)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
# unnest metadata from langchain_metadata column
if row.get(metadata_json_column):
for k, v in row[metadata_json_column].items():
Expand All @@ -51,9 +51,9 @@ def _parse_row_from_doc(
doc: Document,
content_column: str = DEFAULT_CONTENT_COL,
metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Dict:
) -> dict:
doc_metadata = doc.metadata.copy()
row: Dict[str, Any] = {content_column: doc.page_content}
row: dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
Expand All @@ -72,8 +72,8 @@ def __init__(
engine: MySQLEngine,
table_name: str = "",
query: str = "",
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
content_columns: Optional[list[str]] = None,
metadata_columns: Optional[list[str]] = None,
metadata_json_column: Optional[str] = None,
):
"""
Expand All @@ -89,9 +89,9 @@ def __init__(
engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
table_name (str): The MySQL database table name. (OneOf: table_name, query).
query (str): The query to execute in MySQL format. (OneOf: table_name, query).
content_columns (List[str]): The columns to write into the `page_content`
content_columns (list[str]): The columns to write into the `page_content`
of the document. Optional.
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
metadata_columns (list[str]): The columns to write into the `metadata` of the document.
Optional.
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
dictionary. Default: `langchain_metadata`. Optional.
Expand All @@ -110,12 +110,12 @@ def __init__(
"entire table or 'query' to load a specific query."
)

def load(self) -> List[Document]:
def load(self) -> list[Document]:
"""
Load langchain documents from a Cloud SQL MySQL database.
Returns:
(List[langchain_core.documents.Document]): a list of Documents with metadata from
(list[langchain_core.documents.Document]): a list of Documents with metadata from
specific columns.
"""
return list(self.lazy_load())
Expand Down Expand Up @@ -231,13 +231,13 @@ def __init__(
)
self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL

def add_documents(self, docs: List[Document]) -> None:
def add_documents(self, docs: list[Document]) -> None:
"""
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
stored in langchain_metadata JSON column.
Args:
docs (List[langchain_core.documents.Document]): a list of documents to be saved.
docs (list[langchain_core.documents.Document]): a list of documents to be saved.
"""
with self.engine.connect() as conn:
for doc in docs:
Expand All @@ -250,13 +250,13 @@ def add_documents(self, docs: List[Document]) -> None:
conn.execute(sqlalchemy.insert(self._table).values(row))
conn.commit()

def delete(self, docs: List[Document]) -> None:
def delete(self, docs: list[Document]) -> None:
"""
Delete all instances of a document from the DocumentSaver table by matching the entire Document
object.
Args:
docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
docs (list[langchain_core.documents.Document]): a list of documents to be deleted.
"""
with self.engine.connect() as conn:
for doc in docs:
Expand Down
Loading

0 comments on commit aa80672

Please sign in to comment.