diff --git a/superduperdb/backends/base/query.py b/superduperdb/backends/base/query.py index df7dc7184..ae0bdbd6d 100644 --- a/superduperdb/backends/base/query.py +++ b/superduperdb/backends/base/query.py @@ -41,6 +41,8 @@ def decorated(self, *args, **kwargs): class _BaseQuery(Leaf): def __post_init__(self, db: t.Optional['Datalayer'] = None): super().__post_init__(db) + self._is_output_query = False + self._updated_key = None if not self.identifier: self.identifier = self._build_hr_identifier() @@ -201,6 +203,26 @@ def _set_db(r, db): parts.append((part, part_args, part_kwargs)) self.parts = parts + @property + def is_output_query(self): + """Check if query is of output type.""" + return self._is_output_query + + @is_output_query.setter + def is_output_query(self, b): + """Property setter.""" + self._is_output_query = b + + @property + def updated_key(self): + """Return query updated key.""" + return self._updated_key + + @updated_key.setter + def updated_key(self, update): + """Property setter.""" + self._updated_key = update + @property def dependencies( self, @@ -211,12 +233,7 @@ def dependencies( dependencies = [] def _check_query_match(listener, query): - if ( - listener.select.table_or_collection.identifier - == query.table_or_collection.identifier - ): - return True - return False + return listener.depends(query) for listener in listeners: listener = self.db.listeners[listener] diff --git a/superduperdb/backends/ibis/query.py b/superduperdb/backends/ibis/query.py index 53d6ad4b7..274334eee 100644 --- a/superduperdb/backends/ibis/query.py +++ b/superduperdb/backends/ibis/query.py @@ -338,6 +338,9 @@ def model_update( :param flatten: Whether to flatten the outputs. :param kwargs: Additional keyword arguments. """ + self.is_output_query = True + self.updated_key = predict_id + if not flatten: return _model_update_impl( db=self.db, diff --git a/superduperdb/backends/local/compute.py b/superduperdb/backends/local/compute.py index 6eb7eb0e0..44fc69dcb 100644 --- a/superduperdb/backends/local/compute.py +++ b/superduperdb/backends/local/compute.py @@ -42,14 +42,7 @@ def broadcast(self, events: t.List, to: tuple = ()): :param events: List of events. :param to: Destination component. """ - jobs = [] - if isinstance(to, (list, tuple)): - for dep in to: - jobs.append(self.queue.publish(events, to=dep)) - else: - job = self.queue.publish(events, to=to) - jobs.append(job) - return jobs + return self.queue.publish(events, to=to) def submit( self, function: t.Callable, *args, compute_kwargs: t.Dict = {}, **kwargs diff --git a/superduperdb/backends/mongodb/query.py b/superduperdb/backends/mongodb/query.py index 368b134b5..3abb2cf3d 100644 --- a/superduperdb/backends/mongodb/query.py +++ b/superduperdb/backends/mongodb/query.py @@ -268,6 +268,8 @@ def _execute_bulk_write(self, parent): operations = self.parts[0][1][0] for query in operations: assert isinstance(query, (BulkOp)) + + query.is_output_query = self.is_output_query if not query.kwargs.get('arg_ids', None): raise ValueError( 'Please provided update/delete id in args', @@ -601,6 +603,7 @@ def model_update( ) document_embedded = kwargs.get('document_embedded', True) + if document_embedded: outputs = [Document({"_base": output}).encode() for output in outputs] bulk_operations = [] @@ -621,7 +624,9 @@ def model_update( update=update, ) ) - return self.table_or_collection.bulk_write(bulk_operations) + output_query = self.table_or_collection.bulk_write( + bulk_operations, output_query=True + ) else: documents = [] @@ -636,7 +641,10 @@ def model_update( from superduperdb.base.datalayer import Datalayer assert isinstance(self.db, Datalayer) - return self.db[f'_outputs.{predict_id}'].insert_many(documents) + output_query = self.db[f'_outputs.{predict_id}'].insert_many(documents) + output_query.is_output_query = True + output_query.updated_key = predict_id + return output_query def InsertOne(**kwargs): diff --git a/superduperdb/base/datalayer.py b/superduperdb/base/datalayer.py index 8fcd1d988..e55ec38bc 100644 --- a/superduperdb/base/datalayer.py +++ b/superduperdb/base/datalayer.py @@ -423,6 +423,8 @@ def on_event(self, query: Query, ids: t.Sequence[str], event_type: str = 'insert :param ids: IDs that further reduce the scope of computations. """ deps = query.dependencies + if not deps: + return events = [{'identifier': id, 'type': event_type} for id in ids] return self.compute.broadcast(events, to=deps) @@ -441,23 +443,16 @@ def _write(self, write: Query, refresh: bool = True) -> UpdateResult: logging.warn('CDC service is active, skipping model/listener refresh') else: jobs = [] - for u, d in zip(write_result['update'], write_result['delete']): - q = u['query'] - ids = u['ids'] - # Overwrite should be true for update operation since updates - # could be done on collections with already existing outputs - # We need overwrite ouptuts on those select and recompute predict - job_update = self.on_event( - query=q, ids=ids, event_type=Event.upsert + if updated_ids: + job = self.on_event( + query=write, ids=updated_ids, event_type=Event.update ) - jobs.append(job_update) - - q = d['query'] - ids = d['ids'] - job_update = self.on_event( - query=q, ids=ids, event_type=Event.delete + jobs.append(job) + if deleted_ids: + job = self.on_event( + query=write, ids=deleted_ids, event_type=Event.delete ) - jobs.append(job_update) + jobs.append(job) return updated_ids, deleted_ids, jobs return updated_ids, deleted_ids, None @@ -731,9 +726,11 @@ def _apply( if parent is not None: self.metadata.create_parent_child(parent, object.uuid) + deps = [] for job in jobs: - if not isinstance(job, dict): - dependencies.append(job.job_id) + if isinstance(job, Job): + deps.append(job.job_id) + dependencies = [*deps, *dependencies] # type: ignore[list-item] object.post_create(self) self._add_component_to_cache(object) diff --git a/superduperdb/components/listener.py b/superduperdb/components/listener.py index 2309ffd12..a52802b44 100644 --- a/superduperdb/components/listener.py +++ b/superduperdb/components/listener.py @@ -151,6 +151,28 @@ def predict_id(self): """Get predict ID.""" return self.uuid + def depends(self, query): + """Check if query depends on the listener.""" + if query.is_output_query: + if self.key.startswith('_outputs.'): + key = self.key.split('_outputs.')[-1] + else: + key = self.key + + if key in query.updated_key: + if ( + self.select.table_or_collection.identifier + == query.table_or_collection.identifier + ): + return True + else: + if ( + self.select.table_or_collection.identifier + == query.table_or_collection.identifier + ): + return True + return False + @override def schedule_jobs( self, diff --git a/superduperdb/jobs/queue.py b/superduperdb/jobs/queue.py index 23ad10242..287d60326 100644 --- a/superduperdb/jobs/queue.py +++ b/superduperdb/jobs/queue.py @@ -1,10 +1,11 @@ import typing as t +DependencyType = t.Union[t.Dict[str, str], t.Sequence[t.Dict[str, str]]] + class LocalSequentialQueue: """ - LocalSequentialQueue for handling publisher and consumer process - in local queue. + LocalSequentialQueue for handling publisher and consumer process. Local queue which holds listeners, vector indices as queue which consists of events to be consumed by the corresponding components. @@ -31,18 +32,26 @@ def db(self): def db(self, db): self._db = db - def publish(self, events: t.List[t.Dict], to: t.Dict[str, str]): + def publish(self, events: t.List[t.Dict], to: DependencyType): """ Publish events to local queue. :param events: list of events :param to: Component name for events to be published. """ - identifier = to['identifier'] - type_id = to['type_id'] - self._component_map.update(to) - self.queue[f'{type_id}.{identifier}'].extend(events) + def _publish(events, to): + identifier = to['identifier'] + type_id = to['type_id'] + self._component_map.update(to) + + self.queue[f'{type_id}.{identifier}'].extend(events) + + if isinstance(to, (tuple, list)): + for dep in to: + _publish(events, dep) + else: + _publish(events, to) return self.consume() def consume(self): diff --git a/test/unittest/test_quality.py b/test/unittest/test_quality.py index ca2655ce4..23a712762 100644 --- a/test/unittest/test_quality.py +++ b/test/unittest/test_quality.py @@ -19,7 +19,7 @@ ALLOWABLE_DEFECTS = { 'cast': 5, # Try to keep this down 'noqa': 5, # This should never change - 'type_ignore': 12, # This should only ever increase in obscure edge cases + 'type_ignore': 13, # This should only ever increase in obscure edge cases }