Skip to content

Commit

Permalink
Add _execute_select and filter in the Query class.
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed Aug 2, 2024
1 parent 33a89e8 commit e3ca89e
Show file tree
Hide file tree
Showing 16 changed files with 345 additions and 251 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change all the `_outputs.` to `_outputs__`
- Disable cdc on output tables.
- Remove `-` from the uuid of the component.
- Add `_execute_select` and `filter` to the `Query` class.



#### New Features & Functionality
Expand Down
83 changes: 45 additions & 38 deletions superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import abstractmethod
from functools import wraps

from superduper import CFG
from superduper import CFG, logging
from superduper.base.document import Document, _unpack
from superduper.base.leaf import Leaf

Expand Down Expand Up @@ -70,7 +70,6 @@ def _build_hr_identifier(self):
return identifier

def __getattr__(self, item):
item = type(self).methods_mapping.get(item, item)
return type(self)(
db=self.db,
table=self.table,
Expand Down Expand Up @@ -164,15 +163,16 @@ class Query(_BaseQuery):
"""

flavours: t.ClassVar[t.Dict[str, str]] = {}
methods_mapping: t.ClassVar[t.Dict[str, str]] = {}

table: str
parts: t.Sequence[t.Union[t.Tuple, str]] = ()
identifier: str = ''

def __getitem__(self, item):
if isinstance(item, str):
return getattr(self, item)
if not isinstance(item, slice):
return super().__getitem__(item)
raise TypeError('Query index must be a string or a slice')
assert isinstance(item, slice)
parts = self.parts[item]
return type(self)(db=self.db, table=self.table, parts=parts)
Expand Down Expand Up @@ -281,6 +281,9 @@ def _get_flavour(self):
def _get_parent(self):
return self.db.databackend.get_table_or_collection(self.table)

def _execute_select(self, parent):
raise NotImplementedError

def _prepare_pre_like(self, parent):
like_args, like_kwargs = self.parts[0][1:]
like_args = list(like_args)
Expand Down Expand Up @@ -389,40 +392,37 @@ def __repr__(self):
output = output.replace(f'documents[{i}]', doc_string)
return output

def __eq__(self, other):
def _ops(self, op, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__eq__', (other,), {})],
parts=self.parts + [(op, (other,), {})],
)

def __eq__(self, other):
return self._ops('__eq__', other)

def __ne__(self, other):
return self._ops('__ne__', other)

def __lt__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__lt__', (other,), {})],
)
return self._ops('__lt__', other)

def __gt__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__gt__', (other,), {})],
)
return self._ops('__gt__', other)

def __le__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__le__', (other,), {})],
)
return self._ops('__le__', other)

def __ge__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__ge__', (other,), {})],
)
return self._ops('__ge__', other)

def isin(self, other):
    """Build an ``isin`` query part.

    Delegates to ``_ops`` with the operation name ``'isin'``.

    :param other: The value to check against.
    """
    operation = 'isin'
    return self._ops(operation, other)

def _encode_or_unpack_args(self, r, db, method='encode', parent=None):
if isinstance(r, Document):
Expand Down Expand Up @@ -456,17 +456,25 @@ def _encode_or_unpack_args(self, r, db, method='encode', parent=None):
return r

def _execute(self, parent, method='encode'):
for part in self.parts:
if isinstance(part, str):
parent = getattr(parent, part)
continue
args = self._encode_or_unpack_args(
part[1], self.db, method=method, parent=parent
)
kwargs = self._encode_or_unpack_args(
part[2], self.db, method=method, parent=parent
)
parent = getattr(parent, part[0])(*args, **kwargs)
return self._get_chain_native_query(parent, self.parts, method)

def _get_chain_native_query(self, parent, parts, method='encode'):
try:
for part in parts:
if isinstance(part, str):
parent = getattr(parent, part)
continue
args = self._encode_or_unpack_args(
part[1], self.db, method=method, parent=parent
)
kwargs = self._encode_or_unpack_args(
part[2], self.db, method=method, parent=parent
)
parent = getattr(parent, part[0])(*args, **kwargs)
except Exception as e:
logging.error(f'Error in executing query, parts: {parts}')
raise e

return parent

@abstractmethod
Expand Down Expand Up @@ -723,7 +731,6 @@ class Model(_BaseQuery):
table: str
identifier: str = ''
parts: t.Sequence[t.Union[t.Tuple, str]] = ()
methods_mapping: t.ClassVar[t.Dict[str, str]] = {}
type: t.ClassVar[str] = 'predict'

def execute(self):
Expand Down
33 changes: 12 additions & 21 deletions superduper/backends/ibis/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import typing as t
import uuid
from collections import defaultdict
Expand Down Expand Up @@ -106,6 +105,12 @@ class IbisQuery(Query):
'anti_join': r'^[^\(]+\.anti_join\(.*\)$',
}

# Used to control class construction within LeafMeta: turn off the
# dataclass `eq` and `order` options for this class.
__dataclass_params__: t.ClassVar[t.Dict[str, t.Any]] = {
'eq': False,
'order': False,
}

@property
@applies_to('insert')
def documents(self):
Expand Down Expand Up @@ -184,15 +189,6 @@ def _execute_pre_like(self, parent):
result.scores = similar_scores
return result

def __eq__(self, other):
return super().__eq__(other)

def __leq__(self, other):
return super().__leq__(other)

def __geq__(self, other):
return super().__geq__(other)

def _execute_post_like(self, parent):
pre_like_parts = []
like_part = []
Expand Down Expand Up @@ -363,17 +359,12 @@ def outputs(self, *predict_ids):
:param predict_ids: The predict ids.
"""
find_args = ()
if self.parts:
find_args, _ = self.parts[0][1:]
find_args = copy.deepcopy(list(find_args))

if not find_args:
find_args = [{}]

if not find_args[1:]:
find_args.append({})

for part in self.parts:
if part[0] == 'select':
args = part[1]
assert (
self.primary_id in args
), f'Primary id: `{self.primary_id}` not in select when using outputs'
query = self
attr = getattr(query, self.primary_id)
for identifier in predict_ids:
Expand Down
Loading

0 comments on commit e3ca89e

Please sign in to comment.