Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _execute_select and filter in the Query class. #2363

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change all the `_outputs.` to `_outputs__`
- Disable cdc on output tables.
- Remove `-` from the uuid of the component.
- Add _execute_select and filter in the Query class.



#### New Features & Functionality
Expand Down
83 changes: 45 additions & 38 deletions superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from abc import abstractmethod
from functools import wraps

from superduper import CFG
from superduper import CFG, logging
from superduper.base.document import Document, _unpack
from superduper.base.leaf import Leaf

Expand Down Expand Up @@ -70,7 +70,6 @@ def _build_hr_identifier(self):
return identifier

def __getattr__(self, item):
item = type(self).methods_mapping.get(item, item)
return type(self)(
db=self.db,
table=self.table,
Expand Down Expand Up @@ -164,15 +163,16 @@ class Query(_BaseQuery):
"""

flavours: t.ClassVar[t.Dict[str, str]] = {}
methods_mapping: t.ClassVar[t.Dict[str, str]] = {}

table: str
parts: t.Sequence[t.Union[t.Tuple, str]] = ()
identifier: str = ''

def __getitem__(self, item):
if isinstance(item, str):
kartik4949 marked this conversation as resolved.
Show resolved Hide resolved
return getattr(self, item)
if not isinstance(item, slice):
return super().__getitem__(item)
raise TypeError('Query index must be a string or a slice')
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The parent class does not define the `__getitem__` method.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

assert isinstance(item, slice)
parts = self.parts[item]
return type(self)(db=self.db, table=self.table, parts=parts)
Expand Down Expand Up @@ -281,6 +281,9 @@ def _get_flavour(self):
def _get_parent(self):
return self.db.databackend.get_table_or_collection(self.table)

def _execute_select(self, parent):
raise NotImplementedError

def _prepare_pre_like(self, parent):
like_args, like_kwargs = self.parts[0][1:]
like_args = list(like_args)
Expand Down Expand Up @@ -389,40 +392,37 @@ def __repr__(self):
output = output.replace(f'documents[{i}]', doc_string)
return output

def __eq__(self, other):
def _ops(self, op, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__eq__', (other,), {})],
parts=self.parts + [(op, (other,), {})],
)

def __eq__(self, other):
return self._ops('__eq__', other)

def __ne__(self, other):
return self._ops('__ne__', other)

def __lt__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__lt__', (other,), {})],
)
return self._ops('__lt__', other)

def __gt__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__gt__', (other,), {})],
)
return self._ops('__gt__', other)

def __le__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__le__', (other,), {})],
)
return self._ops('__le__', other)

def __ge__(self, other):
return type(self)(
db=self.db,
table=self.table,
parts=self.parts + [('__ge__', (other,), {})],
)
return self._ops('__ge__', other)

def isin(self, other):
"""Create an isin query.
:param other: The value to check against.
"""
return self._ops('isin', other)

def _encode_or_unpack_args(self, r, db, method='encode', parent=None):
if isinstance(r, Document):
Expand Down Expand Up @@ -456,17 +456,25 @@ def _encode_or_unpack_args(self, r, db, method='encode', parent=None):
return r

def _execute(self, parent, method='encode'):
for part in self.parts:
if isinstance(part, str):
parent = getattr(parent, part)
continue
args = self._encode_or_unpack_args(
part[1], self.db, method=method, parent=parent
)
kwargs = self._encode_or_unpack_args(
part[2], self.db, method=method, parent=parent
)
parent = getattr(parent, part[0])(*args, **kwargs)
return self._get_chain_native_query(parent, self.parts, method)

def _get_chain_native_query(self, parent, parts, method='encode'):
try:
for part in parts:
if isinstance(part, str):
parent = getattr(parent, part)
continue
args = self._encode_or_unpack_args(
part[1], self.db, method=method, parent=parent
)
kwargs = self._encode_or_unpack_args(
part[2], self.db, method=method, parent=parent
)
parent = getattr(parent, part[0])(*args, **kwargs)
except Exception as e:
logging.error(f'Error in executing query, parts: {parts}')
raise e

Comment on lines +459 to +477
Copy link
Collaborator Author

@jieguangzhou jieguangzhou Aug 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separating this method allows for easier future expansion of the query, enabling the query to flexibly construct the parts list without being tied to the expression.

In the future, all Query classes can be merged into one class, with Query only recording the expression and being independent of the type of the database.

We can use the Databackend to execute the Query.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, yes this is a possible approach.

return parent

@abstractmethod
Expand Down Expand Up @@ -723,7 +731,6 @@ class Model(_BaseQuery):
table: str
identifier: str = ''
parts: t.Sequence[t.Union[t.Tuple, str]] = ()
methods_mapping: t.ClassVar[t.Dict[str, str]] = {}
type: t.ClassVar[str] = 'predict'

def execute(self):
Expand Down
33 changes: 12 additions & 21 deletions superduper/backends/ibis/query.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import typing as t
import uuid
from collections import defaultdict
Expand Down Expand Up @@ -106,6 +105,12 @@ class IbisQuery(Query):
'anti_join': r'^[^\(]+\.anti_join\(.*\)$',
}

# Used to control the behavior in the class construction method within LeafMeta
__dataclass_params__: t.ClassVar[t.Dict[str, t.Any]] = {
'eq': False,
'order': False,
}

@property
@applies_to('insert')
def documents(self):
Expand Down Expand Up @@ -184,15 +189,6 @@ def _execute_pre_like(self, parent):
result.scores = similar_scores
return result

def __eq__(self, other):
return super().__eq__(other)

def __leq__(self, other):
return super().__leq__(other)

def __geq__(self, other):
return super().__geq__(other)

def _execute_post_like(self, parent):
pre_like_parts = []
like_part = []
Expand Down Expand Up @@ -363,17 +359,12 @@ def outputs(self, *predict_ids):
:param predict_ids: The predict ids.
"""
find_args = ()
if self.parts:
find_args, _ = self.parts[0][1:]
find_args = copy.deepcopy(list(find_args))

if not find_args:
find_args = [{}]

if not find_args[1:]:
find_args.append({})

for part in self.parts:
if part[0] == 'select':
args = part[1]
assert (
self.primary_id in args
), f'Primary id: `{self.primary_id}` not in select when using outputs'
query = self
attr = getattr(query, self.primary_id)
for identifier in predict_ids:
Expand Down
Loading
Loading