From 92f45d33d2910098ea56ecce74ab081fc637da05 Mon Sep 17 00:00:00 2001 From: JieguangZhou Date: Wed, 31 Jul 2024 11:32:24 +0800 Subject: [PATCH 1/2] Use a random SHA-1 hash instead of uuid.uuid4() to create a UUID. --- superduper/base/leaf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/superduper/base/leaf.py b/superduper/base/leaf.py index 03d59303d..b516c6e64 100644 --- a/superduper/base/leaf.py +++ b/superduper/base/leaf.py @@ -128,7 +128,7 @@ class Leaf(metaclass=LeafMeta): identifier: str db: dc.InitVar[t.Optional['Datalayer']] = None - uuid: str = dc.field(default_factory=lambda: str(uuid.uuid4())) + uuid: str = dc.field(default_factory=lambda: str(uuid.uuid4()).replace('-', '')) @property def metadata(self): From 8050b45940e386fc9a16cf0d7d788540b637c4d8 Mon Sep 17 00:00:00 2001 From: JieguangZhou Date: Wed, 31 Jul 2024 14:13:50 +0800 Subject: [PATCH 2/2] Make the prefix of the output string configurable --- CHANGELOG.md | 2 + superduper/backends/base/query.py | 15 +++-- superduper/backends/ibis/data_backend.py | 7 ++- superduper/backends/ibis/query.py | 14 +++-- superduper/backends/mongodb/data_backend.py | 8 +-- superduper/backends/mongodb/query.py | 38 +++++++------ superduper/base/config.py | 2 + superduper/base/datalayer.py | 6 +- superduper/base/document.py | 1 - superduper/components/application.py | 4 +- superduper/components/listener.py | 7 +-- superduper/components/model.py | 6 +- superduper/components/table.py | 3 +- superduper/components/vector_index.py | 5 +- superduper/misc/eager.py | 6 +- superduper/vector_search/atlas.py | 9 ++- superduper/vector_search/update_tasks.py | 8 ++- .../integration/usecase/test_output_prefix.py | 38 +++++++++++++ test/utils/smoke/__init__.py | 0 test/utils/smoke/chain_listener.py | 57 +++++++++++++++++++ 20 files changed, 172 insertions(+), 64 deletions(-) create mode 100644 test/integration/usecase/test_output_prefix.py create mode 100644 test/utils/smoke/__init__.py create mode 100644 test/utils/smoke/chain_listener.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ded6e13a..0ddf953d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Use different colors to distinguish logs - Change all the `_outputs.` to `_outputs__` - Disable cdc on output tables. +- Remove `-` from the uuid of the component. #### New Features & Functionality @@ -31,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Support eager mode - Add CSN env var - Make tests configurable against backend +- Make the prefix of the output string configurable #### Bug Fixes diff --git a/superduper/backends/base/query.py b/superduper/backends/base/query.py index 9a6113fc7..961fcd531 100644 --- a/superduper/backends/base/query.py +++ b/superduper/backends/base/query.py @@ -5,6 +5,7 @@ from abc import abstractmethod from functools import wraps +from superduper import CFG from superduper.base.document import Document, _unpack from superduper.base.leaf import Leaf @@ -327,7 +328,8 @@ def _wrap_document(document): field = [ k for k in table.schema.fields - if k not in [self.primary_id, '_fold', '_outputs'] + if k not in [self.primary_id, '_fold'] + and not k.startswith(CFG.output_prefix) ] assert len(field) == 1 document = Document({field[0]: document}) @@ -648,12 +650,13 @@ def table_or_collection(self): def _parse_query_part(part, documents, query, builder_cls, db=None): - key = part.split('.') - if key[0] == '_outputs': - table = f'{key[0]}.{key[1]}' - part = part.split('.')[2:] + if '.' in CFG.output_prefix and part.startswith(CFG.output_prefix): + rest_part = part[len(CFG.output_prefix) :].split('.') + table = f'{CFG.output_prefix}{rest_part[0]}' + part = rest_part[1:] + else: - table = key[0] + table = part.split('.')[0] part = part.split('.')[1:] current = builder_cls(table=table, parts=(), db=db) diff --git a/superduper/backends/ibis/data_backend.py b/superduper/backends/ibis/data_backend.py index a2375cfdc..21c74a443 100644 --- a/superduper/backends/ibis/data_backend.py +++ b/superduper/backends/ibis/data_backend.py @@ -8,6 +8,7 @@ from pandas.core.frame import DataFrame from sqlalchemy.exc import NoSuchTableError +from superduper import CFG from superduper.backends.base.data_backend import BaseDataBackend from superduper.backends.base.metadata import MetaDataStoreProxy from superduper.backends.ibis.db_helper import get_db_helper @@ -167,10 +168,10 @@ def create_output_dest( fields = { INPUT_KEY: dtype('string'), 'id': dtype('string'), - f'_outputs__{predict_id}': output_type, + f'{CFG.output_prefix}{predict_id}': output_type, } return Table( - identifier=f'_outputs__{predict_id}', + identifier=f'{CFG.output_prefix}{predict_id}', schema=Schema(identifier=f'_schema/{predict_id}', fields=fields), ) @@ -180,7 +181,7 @@ def check_output_dest(self, predict_id) -> bool: :param predict_id: The identifier of the prediction. """ try: - self.conn.table(f'_outputs__{predict_id}') + self.conn.table(f'{CFG.output_prefix}{predict_id}') return True except (NoSuchTableError, ibis.IbisError): return False diff --git a/superduper/backends/ibis/query.py b/superduper/backends/ibis/query.py index cce96a0d0..d629cfd4d 100644 --- a/superduper/backends/ibis/query.py +++ b/superduper/backends/ibis/query.py @@ -5,7 +5,7 @@ import pandas -from superduper import Document +from superduper import CFG, Document from superduper.backends.base.query import ( Query, applies_to, @@ -83,13 +83,13 @@ def _model_update_impl( for output, source_id in zip(outputs, ids): d = { '_source': str(source_id), - f'_outputs__{predict_id}': output.x + f'{CFG.output_prefix}{predict_id}': output.x if isinstance(output, Encodable) else output, 'id': str(uuid.uuid4()), } documents.append(Document(d)) - return db[f'_outputs__{predict_id}'].insert(documents) + return db[f'{CFG.output_prefix}{predict_id}'].insert(documents) class IbisQuery(Query): @@ -355,7 +355,7 @@ def drop_outputs(self, predict_id: str): :param predict_ids: The ids of the predictions to select. """ - return self.db.databackend.conn.drop_table(f'_outputs__{predict_id}') + return self.db.databackend.conn.drop_table(f'{CFG.output_prefix}{predict_id}') @applies_to('select') def outputs(self, *predict_ids): @@ -378,7 +378,9 @@ def outputs(self, *predict_ids): attr = getattr(query, self.primary_id) for identifier in predict_ids: identifier = ( - identifier if '_outputs' in identifier else f'_outputs__{identifier}' + identifier + if identifier.startswith(CFG.output_prefix) + else f'{CFG.output_prefix}{identifier}' ) symbol_table = self.db[identifier] @@ -399,7 +401,7 @@ def select_ids_of_missing_outputs(self, predict_id: str): assert isinstance(self.db, Datalayer) - output_table = self.db[f'_outputs__{predict_id}'] + output_table = self.db[f'{CFG.output_prefix}{predict_id}'] return self.anti_join( output_table, output_table._source == getattr(self, self.primary_id), diff --git a/superduper/backends/mongodb/data_backend.py b/superduper/backends/mongodb/data_backend.py index f8cea7918..4fd6fc413 100644 --- a/superduper/backends/mongodb/data_backend.py +++ b/superduper/backends/mongodb/data_backend.py @@ -5,7 +5,7 @@ import mongomock import pymongo -from superduper import logging +from superduper import CFG, logging from superduper.backends.base.data_backend import BaseDataBackend from superduper.backends.base.metadata import MetaDataStoreProxy from superduper.backends.ibis.field_types import FieldType @@ -107,10 +107,8 @@ def build_artifact_store(self): def drop_outputs(self): """Drop all outputs.""" for collection in self.db.list_collection_names(): - if collection.startswith('_outputs__'): + if collection.startswith(CFG.output_prefix): self.db.drop_collection(collection) - else: - self.db[collection].update_many({}, {'$unset': {'_outputs': ''}}) def drop_table_or_collection(self, name: str): """Drop the table or collection. @@ -188,7 +186,7 @@ def check_output_dest(self, predict_id) -> bool: :param predict_id: identifier of the prediction """ - return self.db[f'_outputs__{predict_id}'].find_one() is not None + return self.db[f'{CFG.output_prefix}{predict_id}'].find_one() is not None @staticmethod def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): diff --git a/superduper/backends/mongodb/query.py b/superduper/backends/mongodb/query.py index a811b4d28..af902e8cc 100644 --- a/superduper/backends/mongodb/query.py +++ b/superduper/backends/mongodb/query.py @@ -8,7 +8,7 @@ import pymongo from bson import ObjectId -from superduper import logging +from superduper import CFG, logging from superduper.backends.base.query import ( Query, applies_to, @@ -367,7 +367,9 @@ def drop_outputs(self, predict_id: str): :param predict_ids: The ids of the predictions to select. """ - return self.db.databackend.drop_table_or_collection(f'_outputs__{predict_id}') + return self.db.databackend.drop_table_or_collection( + f'{CFG.output_prefix}{predict_id}' + ) @applies_to('find') def outputs(self, *predict_ids): @@ -459,13 +461,13 @@ def select_ids_of_missing_outputs(self, predict_id: str): { '$and': [ args[0], - {f'_outputs__{predict_id}': {'$exists': 0}}, + {f'{CFG.output_prefix}{predict_id}': {'$exists': 0}}, ] }, *args[1:], ] else: - args = [{f'_outputs__{predict_id}': {'$exists': 0}}] + args = [{f'{CFG.output_prefix}{predict_id}': {'$exists': 0}}] if len(args) == 1: args.append({}) @@ -530,7 +532,7 @@ def model_update( for output, id in zip(outputs, ids): documents.append( { - f'_outputs__{predict_id}': output, + f'{CFG.output_prefix}{predict_id}': output, '_source': ObjectId(id), } ) @@ -538,7 +540,9 @@ def model_update( from superduper.base.datalayer import Datalayer assert isinstance(self.db, Datalayer) - output_query = self.db[f'_outputs__{predict_id}'].insert_many(documents) + output_query = self.db[f'{CFG.output_prefix}{predict_id}'].insert_many( + documents + ) output_query.is_output_query = True output_query.updated_key = predict_id return output_query @@ -616,18 +620,18 @@ def _execute(self, parent, method='encode'): predict_ids = sum([p[1] for p in outputs_parts], ()) pipeline = [] - filter_mapping_base, filter_mapping_outptus = self._get_filter_mapping() + filter_mapping_base, filter_mapping_outputs = self._get_filter_mapping() if filter_mapping_base: pipeline.append({"$match": filter_mapping_base}) project.update({k: 1 for k in filter_mapping_base.keys()}) - predict_ids_in_filter = list(filter_mapping_outptus.keys()) + predict_ids_in_filter = list(filter_mapping_outputs.keys()) predict_ids = list(set(predict_ids).union(predict_ids_in_filter)) # After the join, the complete outputs data can be queried as - # _outputs__{predict_id}._outputs.{predict_id} : result. + # {CFG.output_prefix}{predict_id}._outputs.{predict_id} : result. for predict_id in predict_ids: - key = f'_outputs__{predict_id}' + key = f'{CFG.output_prefix}{predict_id}' lookup = { "$lookup": { "from": key, @@ -640,9 +644,9 @@ def _execute(self, parent, method='encode'): project[key] = 1 pipeline.append(lookup) - if predict_id in filter_mapping_outptus: + if predict_id in filter_mapping_outputs: filter_key, filter_value = list( - filter_mapping_outptus[predict_id].items() + filter_mapping_outputs[predict_id].items() )[0] pipeline.append({"$match": {f'{key}.{filter_key}': filter_value}}) @@ -675,11 +679,11 @@ def _get_filter_mapping(self): filter_mapping_outputs = defaultdict(dict) for key, value in filter.items(): - if '_outputs__' not in key: + if '{CFG.output_prefix}' not in key: filter_mapping_base[key] = value continue - if key.startswith('_outputs__'): + if key.startswith('{CFG.output_prefix}'): predict_id = key.split('__')[1] filter_mapping_outputs[predict_id] = {key: value} @@ -688,7 +692,7 @@ def _get_filter_mapping(self): def _postprocess_result(self, result): """Postprocess the result of the query. - Merge the outputs/_builds/_files/_blobs from the _outputs__* keys to the result + Merge the outputs/_builds/_files/_blobs from the output keys to the result :param result: The result to postprocess. """ @@ -697,7 +701,9 @@ def _postprocess_result(self, result): merge_files = result.get('_files', {}) merge_blobs = result.get('_blobs', {}) - output_keys = {key for key in result.keys() if key.startswith('_outputs__')} + output_keys = { + key for key in result.keys() if key.startswith(CFG.output_prefix) + } for output_key in output_keys: output_data = result[output_key] output_result = output_data[output_key] diff --git a/superduper/base/config.py b/superduper/base/config.py index 76feef04e..15bfacf9e 100644 --- a/superduper/base/config.py +++ b/superduper/base/config.py @@ -290,6 +290,7 @@ class Config(BaseConfig): :param auto_schema: Whether to automatically create the schema. If True, the schema will be created if it does not exist. :param log_colorize: Whether to colorize the logs + :param output_prefix: The prefix for the output table and output field key """ envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None @@ -312,6 +313,7 @@ class Config(BaseConfig): bytes_encoding: BytesEncoding = BytesEncoding.BYTES auto_schema: bool = True + output_prefix: str = '_outputs__' def __post_init__(self, envs): if envs is not None: diff --git a/superduper/base/datalayer.py b/superduper/base/datalayer.py index 199e76dbd..922be264e 100644 --- a/superduper/base/datalayer.py +++ b/superduper/base/datalayer.py @@ -8,7 +8,7 @@ import networkx import superduper as s -from superduper import logging +from superduper import CFG, logging from superduper.backends.base.artifacts import ArtifactStore from superduper.backends.base.backends import vector_searcher_implementations from superduper.backends.base.compute import ComputeBackend @@ -338,7 +338,7 @@ def _insert( inserted_ids = insert.do_execute(self) cdc_status = s.CFG.cluster.cdc.uri is not None - is_output_table = insert.table.startswith('_outputs__') + is_output_table = insert.table.startswith(CFG.output_prefix) if refresh: if cdc_status and not is_output_table: @@ -440,7 +440,7 @@ def _update(self, update: Query, refresh: bool = True) -> UpdateResult: updated_ids = update.do_execute(self) cdc_status = s.CFG.cluster.cdc.uri is not None - is_output_table = update.table.startswith('_outputs__') + is_output_table = update.table.startswith(CFG.output_prefix) if refresh and updated_ids: if cdc_status and not is_output_table: logging.warn('CDC service is active, skipping model/listener refresh') diff --git a/superduper/base/document.py b/superduper/base/document.py index 16b31e32f..d70b08379 100644 --- a/superduper/base/document.py +++ b/superduper/base/document.py @@ -36,7 +36,6 @@ 'remote_code': Code, } _LEAF_TYPES.update(_ENCODABLES) -_OUTPUTS_KEY = '_outputs' class _Getters: diff --git a/superduper/components/application.py b/superduper/components/application.py index e0d020723..f17217bcf 100644 --- a/superduper/components/application.py +++ b/superduper/components/application.py @@ -1,6 +1,6 @@ import typing as t -from superduper import logging +from superduper import CFG, logging from .component import Component @@ -74,7 +74,7 @@ def build_from_db(cls, identifier, db: "Datalayer"): if any( [ component.type_id == "table" - and component.identifier.startswith("_outputs"), + and component.identifier.startswith(CFG.output_prefix), component.type_id == "schema" and component.identifier.startswith("_schema/"), ] diff --git a/superduper/components/listener.py b/superduper/components/listener.py index 6dd0d5164..d8c097ea2 100644 --- a/superduper/components/listener.py +++ b/superduper/components/listener.py @@ -5,7 +5,6 @@ from superduper import CFG, logging from superduper.backends.base.query import Query -from superduper.base.document import _OUTPUTS_KEY from superduper.components.model import Mapping from superduper.misc.server import request_server @@ -54,7 +53,7 @@ def mapping(self): @property def outputs(self): """Get reference to outputs of listener model.""" - return f'{_OUTPUTS_KEY}__{self.uuid}' + return f'{CFG.output_prefix}{self.uuid}' @property def outputs_key(self): @@ -128,8 +127,8 @@ def dependencies(self): all_ = list(args) + list(kwargs.values()) out = [] for x in all_: - if x.startswith('_outputs__'): - listener_id = x.split('__')[1] + if x.startswith(CFG.output_prefix): + listener_id = x[len(CFG.output_prefix) :] out.append(listener_id) return out diff --git a/superduper/components/model.py b/superduper/components/model.py index efe7969ba..7c998ee11 100644 --- a/superduper/components/model.py +++ b/superduper/components/model.py @@ -14,7 +14,7 @@ import requests import tqdm -from superduper import logging +from superduper import CFG, logging from superduper.backends.base.query import Query from superduper.backends.ibis.field_types import FieldType from superduper.backends.query_dataset import CachedQueryDataset, QueryDataset @@ -380,7 +380,7 @@ def id_key(self): for arg in self.mapping[0]: outputs.append(arg) for key, value in self.mapping[1].items(): - if key.startswith('_outputs__'): + if key.startswith(CFG.output_prefix): key = key.split('.')[1] outputs.append(f'{key}={value}') return ', '.join(outputs) @@ -622,8 +622,6 @@ def _get_ids_from_select( if not overwrite: if ids: select = select.select_using_ids(ids) - if '_outputs' in X: - X = X.split('.')[1] query = select.select_ids_of_missing_outputs(predict_id=predict_id) else: if ids: diff --git a/superduper/components/table.py b/superduper/components/table.py index 015c59b15..45c7c6b7c 100644 --- a/superduper/components/table.py +++ b/superduper/components/table.py @@ -1,5 +1,6 @@ import typing as t +from superduper import CFG from superduper.components.component import Component from superduper.components.schema import Schema, _Native @@ -51,7 +52,7 @@ def pre_create(self, db: 'Datalayer'): for e in self.schema.encoders: db.add(e) if db.databackend.in_memory: - if '_outputs' in self.identifier: + if self.identifier.startswith(CFG.output_prefix): db.databackend.in_memory_tables[ self.identifier ] = db.databackend.create_table_and_schema(self.identifier, self.schema) diff --git a/superduper/components/vector_index.py b/superduper/components/vector_index.py index 21102ce5a..e716d8f18 100644 --- a/superduper/components/vector_index.py +++ b/superduper/components/vector_index.py @@ -79,10 +79,7 @@ def get_vector( """ document = MongoStyleDict(like.unpack()) if outputs is not None: - outputs = outputs or {} - if '_outputs' not in document: - document['_outputs'] = {} - document['_outputs'].update(outputs) + document.update(outputs) assert not isinstance(self.indexing_listener, str) available_keys = list(document.keys()) diff --git a/superduper/misc/eager.py b/superduper/misc/eager.py index 19f719f8d..93f1ffc60 100644 --- a/superduper/misc/eager.py +++ b/superduper/misc/eager.py @@ -4,7 +4,7 @@ import networkx as nx -from superduper import logging +from superduper import CFG, logging from superduper.base.constant import KEY_BLOBS, KEY_BUILDS, KEY_FILES if t.TYPE_CHECKING: @@ -162,7 +162,7 @@ def key(self): if self.type == SuperDuperDataType.DATA: key = ".".join(self.ops) elif self.type == SuperDuperDataType.MODEL_OUTPUT: - prefix = f"_outputs__{self.predict_id}" + prefix = f"{CFG.output_prefix}{self.predict_id}" if self.ops: key = f"{prefix}.{'.'.join(self.ops)}" else: @@ -389,7 +389,7 @@ def _get_select(self, node: SuperDuperData): else: if len(relations) != 1: raise ValueError(_MIXED_FLATTEN_ERROR_MESSAGE) - main_table = f"_outputs.{upstream_node.predict_id}" + main_table = f"{CFG.output_prefix}{upstream_node.predict_id}" predict_ids = [] else: raise ValueError(f"Unknown node type: {upstream_node.type}") diff --git a/superduper/vector_search/atlas.py b/superduper/vector_search/atlas.py index 8ce2bd852..790f078d8 100644 --- a/superduper/vector_search/atlas.py +++ b/superduper/vector_search/atlas.py @@ -71,15 +71,17 @@ def from_component(cls, vi: 'VectorIndex'): assert isinstance( indexing_key, str ), 'Only single key is support for atlas search' - if indexing_key.startswith('_outputs'): - indexing_key = indexing_key.split('.')[1] + if indexing_key.startswith(CFG.output_prefix): + indexing_key = indexing_key[len(CFG.output_prefix) :] assert isinstance(vi.indexing_listener.model, ObjectModel) or isinstance( vi.indexing_listener.model, APIBaseModel ) assert isinstance(collection, str), 'Collection is required to be a string' indexing_model = vi.indexing_listener.model.identifier indexing_version = vi.indexing_listener.model.version - output_path = f'_outputs__{indexing_key}.{indexing_model}.{indexing_version}' + output_path = ( + f'{CFG.output_prefix}{indexing_key}.{indexing_model}.{indexing_version}' + ) return MongoAtlasVectorSearcher( identifier=vi.identifier, @@ -181,6 +183,7 @@ def _create_index(self, collection: str, output_path: str): :param output_path: Path to the output """ _, key, model, version = output_path.split('.') + # TODO: Need to fix this and test it with CFG.output_prefix if re.match(r'^_outputs\.[A-Za-z0-9_]+\.[A-Za-z0-9_]+', key): key = key.split('.')[1] diff --git a/superduper/vector_search/update_tasks.py b/superduper/vector_search/update_tasks.py index 94dbc0588..2b8038a50 100644 --- a/superduper/vector_search/update_tasks.py +++ b/superduper/vector_search/update_tasks.py @@ -1,6 +1,6 @@ import typing as t -from superduper import Document, logging +from superduper import CFG, Document, logging from superduper.backends.base.query import Query from superduper.misc.special_dicts import MongoStyleDict from superduper.vector_search.base import VectorItem @@ -49,14 +49,16 @@ def copy_vectors( docs = [doc.unpack() for doc in docs] key = vi.indexing_listener.key - if '_outputs__' in key: + if CFG.output_prefix in key: key = key.split('.')[1] vectors = [] nokeys = 0 for doc in docs: try: - vector = MongoStyleDict(doc)[f'_outputs__{vi.indexing_listener.predict_id}'] + vector = MongoStyleDict(doc)[ + f'{CFG.output_prefix}{vi.indexing_listener.predict_id}' + ] except KeyError: nokeys += 1 continue diff --git a/test/integration/usecase/test_output_prefix.py b/test/integration/usecase/test_output_prefix.py new file mode 100644 index 000000000..66ff13667 --- /dev/null +++ b/test/integration/usecase/test_output_prefix.py @@ -0,0 +1,38 @@ +import unittest.mock as mock +from test.utils.smoke.chain_listener import build_chain_listener + +from superduper import CFG + + +def test_output_prefix(db): + with mock.patch.object(CFG, "output_prefix", "sddb_outputs_"): + # Mock CFG.output_prefix + build_chain_listener(db) + listener_a = db.listeners["a"] + listener_b = db.listeners["b"] + listener_c = db.listeners["c"] + + assert listener_a.outputs == "sddb_outputs_a" + assert listener_b.outputs == "sddb_outputs_b" + assert listener_c.outputs == "sddb_outputs_c" + + expect_tables = [ + "documents", + "sddb_outputs_a", + "sddb_outputs_b", + "sddb_outputs_c", + ] + + assert set(db.databackend.list_tables_or_collections()) == set(expect_tables) + + outputs_a = list(listener_a.outputs_select.execute()) + assert len(outputs_a) == 6 + assert all("sddb_outputs_a" in x for x in outputs_a) + + outputs_b = list(listener_b.outputs_select.execute()) + assert len(outputs_b) == 6 + assert all("sddb_outputs_b" in x for x in outputs_b) + + outputs_c = list(listener_c.outputs_select.execute()) + assert len(outputs_c) == 6 + assert all("sddb_outputs_c" in x for x in outputs_c) diff --git a/test/utils/smoke/__init__.py b/test/utils/smoke/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/utils/smoke/chain_listener.py b/test/utils/smoke/chain_listener.py new file mode 100644 index 000000000..bac5339c2 --- /dev/null +++ b/test/utils/smoke/chain_listener.py @@ -0,0 +1,57 @@ +import typing as t + +from superduper import ObjectModel + +if t.TYPE_CHECKING: + from superduper.base.datalayer import Datalayer + + +def build_chain_listener(db: "Datalayer"): + db.cfg.auto_schema = True + data = [ + {"x": 1}, + {"x": 2}, + {"x": 3}, + ] + + db["documents"].insert(data).execute() + + data = list(db["documents"].select().execute(eager_mode=True))[0] + + model_a = ObjectModel(identifier="a", object=lambda x: f"{x}->a") + + model_b = ObjectModel(identifier="b", object=lambda x: f"{x}->b") + + model_c = ObjectModel(identifier="c", object=lambda x: f"{x}->c") + + listener_a = model_a.to_listener( + select=db["documents"].select(), + key="x", + uuid="a", + identifier="a", + ) + db.apply(listener_a) + + listener_b = model_b.to_listener( + select=listener_a.outputs_select, + key=listener_a.outputs, + uuid="b", + identifier="b", + ) + db.apply(listener_b) + + listener_c = model_c.to_listener( + select=listener_b.outputs_select, + key=listener_b.outputs, + uuid="c", + identifier="c", + ) + db.apply(listener_c) + + data = [ + {"x": 4}, + {"x": 5}, + {"x": 6}, + ] + + db["documents"].insert(data).execute()