Skip to content

Commit

Permalink
chore: add missing return type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Oct 16, 2024
1 parent 1da3e1d commit c2a47d8
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
10 changes: 7 additions & 3 deletions src/langchain_google_cloud_sql_mysql/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

import google.auth
import google.auth.transport.requests
import requests
import sqlalchemy
from sqlalchemy.engine.row import RowMapping

from google.cloud.sql.connector import Connector, RefreshStrategy

from .version import __version__
Expand Down Expand Up @@ -235,15 +237,17 @@ def _execute_outside_tx(self, query: str, params: Optional[dict] = None) -> None
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
conn.execute(sqlalchemy.text(query), params)

def _fetch(self, query: str, params: Optional[dict] = None):
def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
"""Fetch results from a SQL query."""
with self.engine.connect() as conn:
result = conn.execute(sqlalchemy.text(query), params)
result_map = result.mappings()
result_fetch = result_map.fetchall()
return result_fetch

def _fetch_rows(self, query: str, params: Optional[dict] = None):
def _fetch_rows(
self, query: str, params: Optional[dict] = None
) -> Sequence[RowMapping]:
"""Fetch results from a SQL query as rows."""
with self.engine.connect() as conn:
result = conn.execute(sqlalchemy.text(query), params)
Expand Down
12 changes: 7 additions & 5 deletions src/langchain_google_cloud_sql_mysql/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def delete(
self.engine._execute(query)
return True

def apply_vector_index(self, vector_index: VectorIndex):
def apply_vector_index(self, vector_index: VectorIndex) -> None:
# Construct the default index name
if not vector_index.name:
vector_index.name = f"{self.table_name}_{DEFAULT_INDEX_NAME_SUFFIX}"
Expand All @@ -243,7 +243,7 @@ def apply_vector_index(self, vector_index: VectorIndex):
# After applying an index to the table, set the query option search type to be ANN
self.query_options.search_type = SearchType.ANN

def alter_vector_index(self, vector_index: VectorIndex):
def alter_vector_index(self, vector_index: VectorIndex) -> None:
existing_index_name = self._get_vector_index_name()
if not existing_index_name:
raise ValueError("No existing vector index found.")
Expand All @@ -258,7 +258,9 @@ def alter_vector_index(self, vector_index: VectorIndex):
)
self.__exec_apply_vector_index(query_template, vector_index)

def __exec_apply_vector_index(self, query_template: str, vector_index: VectorIndex):
def __exec_apply_vector_index(
self, query_template: str, vector_index: VectorIndex
) -> None:
index_options = []
if vector_index.index_type:
index_options.append(f"index_type={vector_index.index_type.value}")
Expand All @@ -275,15 +277,15 @@ def __exec_apply_vector_index(self, query_template: str, vector_index: VectorInd
stmt = query_template.format(index_options_query)
self.engine._execute_outside_tx(stmt)

def _get_vector_index_name(self):
def _get_vector_index_name(self) -> Optional[str]:
query = f"SELECT index_name FROM mysql.vector_indexes WHERE table_name='{self.db_name}.{self.table_name}';"
result = self.engine._fetch(query)
if result:
return result[0]["index_name"]
else:
return None

def drop_vector_index(self):
def drop_vector_index(self) -> Optional[str]:
existing_index_name = self._get_vector_index_name()
if existing_index_name:
self.engine._execute_outside_tx(
Expand Down

0 comments on commit c2a47d8

Please sign in to comment.