Skip to content

Commit

Permalink
Fix query dependencies property
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Jul 9, 2024
1 parent 4d2e11f commit 4949f21
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 41 deletions.
29 changes: 23 additions & 6 deletions superduperdb/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def decorated(self, *args, **kwargs):
class _BaseQuery(Leaf):
def __post_init__(self, db: t.Optional['Datalayer'] = None):
super().__post_init__(db)
self._is_output_query = False
self._updated_key = None
if not self.identifier:
self.identifier = self._build_hr_identifier()

Expand Down Expand Up @@ -201,6 +203,26 @@ def _set_db(r, db):
parts.append((part, part_args, part_kwargs))
self.parts = parts

@property
def is_output_query(self):
"""Check if query is of output type."""
return self._is_output_query

@is_output_query.setter
def is_output_query(self, b):
"""Property setter."""
self._is_output_query = b

@property
def updated_key(self):
"""Return query updated key."""
return self._updated_key

@updated_key.setter
def updated_key(self, update):
"""Property setter."""
self._updated_key = update

@property
def dependencies(
self,
Expand All @@ -211,12 +233,7 @@ def dependencies(
dependencies = []

def _check_query_match(listener, query):
if (
listener.select.table_or_collection.identifier
== query.table_or_collection.identifier
):
return True
return False
return listener.depends(query)

for listener in listeners:
listener = self.db.listeners[listener]
Expand Down
3 changes: 3 additions & 0 deletions superduperdb/backends/ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,9 @@ def model_update(
:param flatten: Whether to flatten the outputs.
:param kwargs: Additional keyword arguments.
"""
self.is_output_query = True
self.updated_key = predict_id

if not flatten:
return _model_update_impl(
db=self.db,
Expand Down
9 changes: 1 addition & 8 deletions superduperdb/backends/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,7 @@ def broadcast(self, events: t.List, to: tuple = ()):
:param events: List of events.
:param to: Destination component.
"""
jobs = []
if isinstance(to, (list, tuple)):
for dep in to:
jobs.append(self.queue.publish(events, to=dep))
else:
job = self.queue.publish(events, to=to)
jobs.append(job)
return jobs
return self.queue.publish(events, to=to)

def submit(
self, function: t.Callable, *args, compute_kwargs: t.Dict = {}, **kwargs
Expand Down
12 changes: 10 additions & 2 deletions superduperdb/backends/mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ def _execute_bulk_write(self, parent):
operations = self.parts[0][1][0]
for query in operations:
assert isinstance(query, (BulkOp))

query.is_output_query = self.is_output_query
if not query.kwargs.get('arg_ids', None):
raise ValueError(
'Please provided update/delete id in args',
Expand Down Expand Up @@ -601,6 +603,7 @@ def model_update(
)

document_embedded = kwargs.get('document_embedded', True)

if document_embedded:
outputs = [Document({"_base": output}).encode() for output in outputs]
bulk_operations = []
Expand All @@ -621,7 +624,9 @@ def model_update(
update=update,
)
)
return self.table_or_collection.bulk_write(bulk_operations)
output_query = self.table_or_collection.bulk_write(
bulk_operations, output_query=True
)

else:
documents = []
Expand All @@ -636,7 +641,10 @@ def model_update(
from superduperdb.base.datalayer import Datalayer

assert isinstance(self.db, Datalayer)
return self.db[f'_outputs.{predict_id}'].insert_many(documents)
output_query = self.db[f'_outputs.{predict_id}'].insert_many(documents)
output_query.is_output_query = True
output_query.updated_key = predict_id
return output_query


def InsertOne(**kwargs):
Expand Down
31 changes: 14 additions & 17 deletions superduperdb/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ def on_event(self, query: Query, ids: t.Sequence[str], event_type: str = 'insert
:param ids: IDs that further reduce the scope of computations.
"""
deps = query.dependencies
if not deps:
return
events = [{'identifier': id, 'type': event_type} for id in ids]
return self.compute.broadcast(events, to=deps)

Expand All @@ -441,23 +443,16 @@ def _write(self, write: Query, refresh: bool = True) -> UpdateResult:
logging.warn('CDC service is active, skipping model/listener refresh')
else:
jobs = []
for u, d in zip(write_result['update'], write_result['delete']):
q = u['query']
ids = u['ids']
# Overwrite should be true for update operation since updates
# could be done on collections with already existing outputs
# We need overwrite ouptuts on those select and recompute predict
job_update = self.on_event(
query=q, ids=ids, event_type=Event.upsert
if updated_ids:
job = self.on_event(
query=write, ids=updated_ids, event_type=Event.update
)
jobs.append(job_update)

q = d['query']
ids = d['ids']
job_update = self.on_event(
query=q, ids=ids, event_type=Event.delete
jobs.append(job)
if deleted_ids:
job = self.on_event(
query=write, ids=deleted_ids, event_type=Event.delete
)
jobs.append(job_update)
jobs.append(job)

return updated_ids, deleted_ids, jobs
return updated_ids, deleted_ids, None
Expand Down Expand Up @@ -731,9 +726,11 @@ def _apply(
if parent is not None:
self.metadata.create_parent_child(parent, object.uuid)

deps = []
for job in jobs:
if not isinstance(job, dict):
dependencies.append(job.job_id)
if isinstance(job, Job):
deps.append(job.job_id)
dependencies = [*deps, *dependencies] # type: ignore[list-item]

object.post_create(self)
self._add_component_to_cache(object)
Expand Down
22 changes: 22 additions & 0 deletions superduperdb/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ def predict_id(self):
"""Get predict ID."""
return self.uuid

def depends(self, query):
"""Check if query depends on the listener."""
if query.is_output_query:
if self.key.startswith('_outputs.'):
key = self.key.split('_outputs.')[-1]
else:
key = self.key

if key in query.updated_key:
if (
self.select.table_or_collection.identifier
== query.table_or_collection.identifier
):
return True
else:
if (
self.select.table_or_collection.identifier
== query.table_or_collection.identifier
):
return True
return False

@override
def schedule_jobs(
self,
Expand Down
23 changes: 16 additions & 7 deletions superduperdb/jobs/queue.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import typing as t

DependencyType = t.Union[t.Dict[str, str], t.Sequence[t.Dict[str, str]]]


class LocalSequentialQueue:
"""
LocalSequentialQueue for handling publisher and consumer process
in local queue.
LocalSequentialQueue for handling publisher and consumer process.
Local queue which holds listeners, vector indices as queue which
consists of events to be consumed by the corresponding components.
Expand All @@ -31,18 +32,26 @@ def db(self):
def db(self, db):
self._db = db

def publish(self, events: t.List[t.Dict], to: t.Dict[str, str]):
def publish(self, events: t.List[t.Dict], to: DependencyType):
"""
Publish events to local queue.
:param events: list of events
:param to: Component name for events to be published.
"""
identifier = to['identifier']
type_id = to['type_id']
self._component_map.update(to)

self.queue[f'{type_id}.{identifier}'].extend(events)
def _publish(events, to):
identifier = to['identifier']
type_id = to['type_id']
self._component_map.update(to)

self.queue[f'{type_id}.{identifier}'].extend(events)

if isinstance(to, (tuple, list)):
for dep in to:
_publish(events, dep)
else:
_publish(events, to)
return self.consume()

def consume(self):
Expand Down
2 changes: 1 addition & 1 deletion test/unittest/test_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ALLOWABLE_DEFECTS = {
'cast': 5, # Try to keep this down
'noqa': 5, # This should never change
'type_ignore': 12, # This should only ever increase in obscure edge cases
'type_ignore': 13, # This should only ever increase in obscure edge cases
}


Expand Down

0 comments on commit 4949f21

Please sign in to comment.