Skip to content

Commit

Permalink
Merge pull request #1 from myscale/master
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
mpskex authored Jun 13, 2023
2 parents 5b6bbf4 + 1bba1ee commit a79ca91
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 1 deletion.
2 changes: 2 additions & 0 deletions langchain/chains/query_constructor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class Comparator(str, Enum):
GTE = "gte"
LT = "lt"
LTE = "lte"
CONTAIN = 'contain'
LIKE = "like"


class FilterDirective(Expr, ABC):
Expand Down
7 changes: 7 additions & 0 deletions langchain/chains/query_constructor/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Optional, Sequence, Union
import datetime

try:
import lark
Expand Down Expand Up @@ -34,12 +35,14 @@ def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore
?value: SIGNED_INT -> int
| SIGNED_FLOAT -> float
| TIMESTAMP -> timestamp
| list
| string
| ("false" | "False" | "FALSE") -> false
| ("true" | "True" | "TRUE") -> true
args: expr ("," expr)*
TIMESTAMP.2: /["'](\d{4}-[01]\d-[0-3]\d)["']/
string: /'[^']*'/ | ESCAPED_STRING
list: "[" [args] "]"
Expand Down Expand Up @@ -119,6 +122,10 @@ def int(self, item: Any) -> int:

def float(self, item: Any) -> float:
return float(item)

def timestamp(self, item: Any):
item = item.replace("'", '"')
return datetime.datetime.strptime(item, '"%Y-%m-%d"').date()

def string(self, item: Any) -> str:
# Remove escaped quotes
Expand Down
3 changes: 3 additions & 0 deletions langchain/chains/query_constructor/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
Expand Down Expand Up @@ -179,6 +180,8 @@
Make sure that you only use the comparators and logical operators listed above and \
no others.
Make sure that filters only refer to attributes that exist in the data source.
Make sure that filters only use the attributed names with its function names if there are functions applied on them.
Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values.
Make sure that filters take into account the descriptions of attributes and only make \
comparisons that are feasible given the type of data being stored.
Make sure that filters are only used as needed. If there are no filters that should be \
Expand Down
6 changes: 5 additions & 1 deletion langchain/retrievers/self_query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from langchain.retrievers.self_query.pinecone import PineconeTranslator
from langchain.retrievers.self_query.qdrant import QdrantTranslator
from langchain.retrievers.self_query.weaviate import WeaviateTranslator
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate
from langchain.vectorstores import Chroma, Pinecone, Qdrant, VectorStore, Weaviate, MyScale


def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Expand All @@ -24,6 +25,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
Chroma: ChromaTranslator,
Weaviate: WeaviateTranslator,
Qdrant: QdrantTranslator,
MyScale: MyScaleTranslator,
}
if vectorstore_cls not in BUILTIN_TRANSLATORS:
raise ValueError(
Expand All @@ -32,6 +34,8 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
)
if isinstance(vectorstore, Qdrant):
return QdrantTranslator(metadata_key=vectorstore.metadata_payload_key)
elif isinstance(vectorstore, MyScale):
return MyScaleTranslator(metadata_key=vectorstore.metadata_column)
return BUILTIN_TRANSLATORS[vectorstore_cls]()


Expand Down
95 changes: 95 additions & 0 deletions langchain/retrievers/self_query/myscale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import re
import datetime
from typing import Dict, Tuple, Union
from langchain.chains.query_constructor.ir import (
Comparator,
Comparison,
Operation,
Operator,
StructuredQuery,
Visitor,
)

def DEFAULT_COMPOSER(op):
def f(*args):
args = map(str, args)
return f' {op} '.join(args)
return f


def FUNCTION_COMPOSER(op):
def f(*args):
args = map(str, args)
return f"{op}({','.join(args)})"
return f


class MyScaleTranslator(Visitor):
"""Logic for converting internal query language elements to valid filters."""

allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
"""Subset of allowed logical operators."""

allowed_comparators = [Comparator.EQ,
Comparator.GT,
Comparator.GTE,
Comparator.LT,
Comparator.LTE,
Comparator.CONTAIN,
Comparator.LIKE]

map_dict = {Operator.AND: DEFAULT_COMPOSER("AND"),
Operator.OR: DEFAULT_COMPOSER("OR"),
Operator.NOT: DEFAULT_COMPOSER("NOT"),
Comparator.EQ: DEFAULT_COMPOSER('='),
Comparator.GT: DEFAULT_COMPOSER('>'),
Comparator.GTE: DEFAULT_COMPOSER('>='),
Comparator.LT: DEFAULT_COMPOSER('<='),
Comparator.LTE: DEFAULT_COMPOSER('<'),
Comparator.CONTAIN: FUNCTION_COMPOSER('has'),
Comparator.LIKE: DEFAULT_COMPOSER("ILIKE"),
}

def __init__(self, metadata_key: str = 'metadata') -> None:
super().__init__()
self.metadata_key = metadata_key

def visit_operation(self, operation: Operation) -> Dict:
args = [arg.accept(self) for arg in operation.arguments]
func = operation.operator
self._validate_func(func)
return self.map_dict[func](*args)

def visit_comparison(self, comparison: Comparison) -> Dict:
regex = '\((.*?)\)'
matched = re.search('\(\w+\)', comparison.attribute)

# If arbitrary function is applied to an attribute
if matched:
attr = re.sub(regex, f'({self.metadata_key}.{matched.group(0)[1:-1]})', comparison.attribute)
else:
attr = f'{self.metadata_key}.{comparison.attribute}'
value = comparison.value
comp = comparison.comparator

value = f"'{value}'" if type(value) is str else value

# convert timestamp for datetime objects
if type(value) is datetime.date:
attr = f"parseDateTime32BestEffort({attr})"
value = f"parseDateTime32BestEffort('{value.strftime('%Y-%m-%d')}')"

# string pattern match
if comp is Comparator.LIKE:
value = f"'%{value[1:-1]}%'"
return self.map_dict[comp](attr, value)

def visit_structured_query(
self, structured_query: StructuredQuery
) -> Tuple[str, dict]:
print(structured_query)
if structured_query.filter is None:
kwargs = {}
else:
kwargs = {"where_str": structured_query.filter.accept(self)}
return structured_query.query, kwargs

0 comments on commit a79ca91

Please sign in to comment.