Make the prefix of the output string configurable
jieguangzhou authored and blythed committed Jul 31, 2024
1 parent 618f436 commit ae852ec
Showing 20 changed files with 172 additions and 64 deletions.
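
In short: every place that previously hard-coded the `_outputs__` prefix for output tables, collections, and field keys now reads it from the new `output_prefix` config field. A minimal sketch of the renaming, assuming the default configuration (the `predict_id` is invented):

```python
# Sketch only: `predict_id` is a made-up listener uuid.
from superduper import CFG  # CFG.output_prefix defaults to '_outputs__' (added below)

predict_id = 'abc123'

before = f'_outputs__{predict_id}'          # hard-coded everywhere before this commit
after = f'{CFG.output_prefix}{predict_id}'  # derived from config everywhere after it

assert before == after                      # identical under the default prefix
```
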
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use different colors to distinguish logs
- Change all the `_outputs.` to `_outputs__`
- Disable cdc on output tables.
- Remove `-` from the uuid of the component.


#### New Features & Functionality
@@ -31,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support eager mode
- Add CSN env var
- Make tests configurable against backend
- Make the prefix of the output string configurable


#### Bug Fixes
15 changes: 9 additions & 6 deletions superduper/backends/base/query.py
@@ -5,6 +5,7 @@
from abc import abstractmethod
from functools import wraps

from superduper import CFG
from superduper.base.document import Document, _unpack
from superduper.base.leaf import Leaf

@@ -327,7 +328,8 @@ def _wrap_document(document):
field = [
k
for k in table.schema.fields
if k not in [self.primary_id, '_fold', '_outputs']
if k not in [self.primary_id, '_fold']
and not k.startswith(CFG.output_prefix)
]
assert len(field) == 1
document = Document({field[0]: document})
@@ -648,12 +650,13 @@ def table_or_collection(self):


def _parse_query_part(part, documents, query, builder_cls, db=None):
key = part.split('.')
if key[0] == '_outputs':
table = f'{key[0]}.{key[1]}'
part = part.split('.')[2:]
if '.' in CFG.output_prefix and part.startswith(CFG.output_prefix):
rest_part = part[len(CFG.output_prefix) :].split('.')
table = f'{CFG.output_prefix}{rest_part[0]}'
part = rest_part[1:]

else:
table = key[0]
table = part.split('.')[0]
part = part.split('.')[1:]

current = builder_cls(table=table, parts=(), db=db)
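Read in isolation, the new branch in `_parse_query_part` only matters when the configured prefix itself contains a dot (as the legacy `_outputs.` form did), since a plain `split('.')` would then cut through the table name. A standalone sketch of just that string handling, with invented query strings and prefixes:

```python
# Re-implementation of only the table/part splitting logic, for illustration.
def split_table(part, output_prefix):
    if '.' in output_prefix and part.startswith(output_prefix):
        # a dotted prefix (e.g. legacy '_outputs.') would be cut by split('.'),
        # so strip it first and glue it back onto the table name
        rest_part = part[len(output_prefix):].split('.')
        return f'{output_prefix}{rest_part[0]}', rest_part[1:]
    return part.split('.')[0], part.split('.')[1:]

assert split_table('docs.find', '_outputs__') == ('docs', ['find'])
assert split_table('_outputs__abc.find', '_outputs__') == ('_outputs__abc', ['find'])
assert split_table('_outputs.abc.find', '_outputs.') == ('_outputs.abc', ['find'])
```
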
7 changes: 4 additions & 3 deletions superduper/backends/ibis/data_backend.py
@@ -8,6 +8,7 @@
from pandas.core.frame import DataFrame
from sqlalchemy.exc import NoSuchTableError

from superduper import CFG
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.backends.ibis.db_helper import get_db_helper
@@ -167,10 +168,10 @@ def create_output_dest(
fields = {
INPUT_KEY: dtype('string'),
'id': dtype('string'),
f'_outputs__{predict_id}': output_type,
f'{CFG.output_prefix}{predict_id}': output_type,
}
return Table(
identifier=f'_outputs__{predict_id}',
identifier=f'{CFG.output_prefix}{predict_id}',
schema=Schema(identifier=f'_schema/{predict_id}', fields=fields),
)

@@ -180,7 +181,7 @@ def check_output_dest(self, predict_id) -> bool:
:param predict_id: The identifier of the prediction.
"""
try:
self.conn.table(f'_outputs__{predict_id}')
self.conn.table(f'{CFG.output_prefix}{predict_id}')
return True
except (NoSuchTableError, ibis.IbisError):
return False
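For the ibis backend, the table identifier and the output column now share the configured prefix. Roughly the shape of the table definition `create_output_dest` builds, spelled out for one invented `predict_id`; `INPUT_KEY` is assumed to be `'_source'` (cf. the anti-join on `output_table._source` further down), and the dtypes are stand-ins:

```python
# Hypothetical expansion of the fields dict above; '_source' for INPUT_KEY
# is an assumption, and output_type depends on the model.
from superduper import CFG

predict_id = 'abc123'
table_identifier = f'{CFG.output_prefix}{predict_id}'  # '_outputs__abc123' by default
fields = {
    '_source': 'string',                                 # id of the input row (assumed)
    'id': 'string',                                      # id of the output row
    f'{CFG.output_prefix}{predict_id}': '<output_type>'  # the model's output column
}
```
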
14 changes: 8 additions & 6 deletions superduper/backends/ibis/query.py
@@ -5,7 +5,7 @@

import pandas

from superduper import Document
from superduper import CFG, Document
from superduper.backends.base.query import (
Query,
applies_to,
@@ -83,13 +83,13 @@ def _model_update_impl(
for output, source_id in zip(outputs, ids):
d = {
'_source': str(source_id),
f'_outputs__{predict_id}': output.x
f'{CFG.output_prefix}{predict_id}': output.x
if isinstance(output, Encodable)
else output,
'id': str(uuid.uuid4()),
}
documents.append(Document(d))
return db[f'_outputs__{predict_id}'].insert(documents)
return db[f'{CFG.output_prefix}{predict_id}'].insert(documents)


class IbisQuery(Query):
@@ -355,7 +355,7 @@ def drop_outputs(self, predict_id: str):
:param predict_id: The id of the prediction whose outputs to drop.
"""
return self.db.databackend.conn.drop_table(f'_outputs__{predict_id}')
return self.db.databackend.conn.drop_table(f'{CFG.output_prefix}{predict_id}')

@applies_to('select')
def outputs(self, *predict_ids):
@@ -378,7 +378,9 @@ def outputs(self, *predict_ids):
attr = getattr(query, self.primary_id)
for identifier in predict_ids:
identifier = (
identifier if '_outputs' in identifier else f'_outputs__{identifier}'
identifier
if identifier.startswith(CFG.output_prefix)
else f'{CFG.output_prefix}{identifier}'
)
symbol_table = self.db[identifier]

@@ -399,7 +401,7 @@ def select_ids_of_missing_outputs(self, predict_id: str):

assert isinstance(self.db, Datalayer)

output_table = self.db[f'_outputs__{predict_id}']
output_table = self.db[f'{CFG.output_prefix}{predict_id}']
return self.anti_join(
output_table,
output_table._source == getattr(self, self.primary_id),
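The normalisation in `outputs()` accepts either a bare `predict_id` or an already-prefixed identifier; restated standalone with made-up ids:

```python
# Standalone restatement of the identifier normalisation in outputs() above.
from superduper import CFG

def normalise(identifier):
    return (
        identifier
        if identifier.startswith(CFG.output_prefix)
        else f'{CFG.output_prefix}{identifier}'
    )

assert normalise('abc') == f'{CFG.output_prefix}abc'
assert normalise(f'{CFG.output_prefix}abc') == f'{CFG.output_prefix}abc'
```
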
8 changes: 3 additions & 5 deletions superduper/backends/mongodb/data_backend.py
@@ -5,7 +5,7 @@
import mongomock
import pymongo

from superduper import logging
from superduper import CFG, logging
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.backends.ibis.field_types import FieldType
@@ -107,10 +107,8 @@ def build_artifact_store(self):
def drop_outputs(self):
"""Drop all outputs."""
for collection in self.db.list_collection_names():
if collection.startswith('_outputs__'):
if collection.startswith(CFG.output_prefix):
self.db.drop_collection(collection)
else:
self.db[collection].update_many({}, {'$unset': {'_outputs': ''}})

def drop_table_or_collection(self, name: str):
"""Drop the table or collection.
@@ -188,7 +186,7 @@ def check_output_dest(self, predict_id) -> bool:
:param predict_id: identifier of the prediction
"""
return self.db[f'_outputs__{predict_id}'].find_one() is not None
return self.db[f'{CFG.output_prefix}{predict_id}'].find_one() is not None

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
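Note the semantic change in the MongoDB `drop_outputs`: previously, non-output collections also had their embedded `_outputs` field `$unset`; now that outputs live only in their own prefixed collections, those collections are simply dropped. A condensed restatement with an invented collection list:

```python
# Invented collection names; only whole prefixed collections are dropped now.
from superduper import CFG

collections = ['documents', f'{CFG.output_prefix}abc123']
to_drop = [c for c in collections if c.startswith(CFG.output_prefix)]
assert to_drop == [f'{CFG.output_prefix}abc123']
```
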
38 changes: 22 additions & 16 deletions superduper/backends/mongodb/query.py
@@ -8,7 +8,7 @@
import pymongo
from bson import ObjectId

from superduper import logging
from superduper import CFG, logging
from superduper.backends.base.query import (
Query,
applies_to,
@@ -367,7 +369,9 @@ def drop_outputs(self, predict_id: str):
:param predict_id: The id of the prediction whose outputs to drop.
"""
return self.db.databackend.drop_table_or_collection(f'_outputs__{predict_id}')
return self.db.databackend.drop_table_or_collection(
f'{CFG.output_prefix}{predict_id}'
)

@applies_to('find')
def outputs(self, *predict_ids):
@@ -459,13 +461,13 @@ def select_ids_of_missing_outputs(self, predict_id: str):
{
'$and': [
args[0],
{f'_outputs__{predict_id}': {'$exists': 0}},
{f'{CFG.output_prefix}{predict_id}': {'$exists': 0}},
]
},
*args[1:],
]
else:
args = [{f'_outputs__{predict_id}': {'$exists': 0}}]
args = [{f'{CFG.output_prefix}{predict_id}': {'$exists': 0}}]

if len(args) == 1:
args.append({})
@@ -530,15 +532,17 @@ def model_update(
for output, id in zip(outputs, ids):
documents.append(
{
f'_outputs__{predict_id}': output,
f'{CFG.output_prefix}{predict_id}': output,
'_source': ObjectId(id),
}
)

from superduper.base.datalayer import Datalayer

assert isinstance(self.db, Datalayer)
output_query = self.db[f'_outputs__{predict_id}'].insert_many(documents)
output_query = self.db[f'{CFG.output_prefix}{predict_id}'].insert_many(
documents
)
output_query.is_output_query = True
output_query.updated_key = predict_id
return output_query
@@ -616,18 +620,18 @@ def _execute(self, parent, method='encode'):
predict_ids = sum([p[1] for p in outputs_parts], ())

pipeline = []
filter_mapping_base, filter_mapping_outptus = self._get_filter_mapping()
filter_mapping_base, filter_mapping_outputs = self._get_filter_mapping()
if filter_mapping_base:
pipeline.append({"$match": filter_mapping_base})
project.update({k: 1 for k in filter_mapping_base.keys()})

predict_ids_in_filter = list(filter_mapping_outptus.keys())
predict_ids_in_filter = list(filter_mapping_outputs.keys())

predict_ids = list(set(predict_ids).union(predict_ids_in_filter))
# After the join, the complete outputs data can be queried as
# _outputs__{predict_id}._outputs.{predict_id} : result.
# {CFG.output_prefix}{predict_id}.{CFG.output_prefix}{predict_id} : result.
for predict_id in predict_ids:
key = f'_outputs__{predict_id}'
key = f'{CFG.output_prefix}{predict_id}'
lookup = {
"$lookup": {
"from": key,
@@ -640,9 +644,9 @@ def _execute(self, parent, method='encode'):
project[key] = 1
pipeline.append(lookup)

if predict_id in filter_mapping_outptus:
if predict_id in filter_mapping_outputs:
filter_key, filter_value = list(
filter_mapping_outptus[predict_id].items()
filter_mapping_outputs[predict_id].items()
)[0]
pipeline.append({"$match": {f'{key}.{filter_key}': filter_value}})

@@ -675,11 +679,11 @@ def _get_filter_mapping(self):
filter_mapping_outputs = defaultdict(dict)

for key, value in filter.items():
if '_outputs__' not in key:
if CFG.output_prefix not in key:
filter_mapping_base[key] = value
continue

if key.startswith('_outputs__'):
if key.startswith(CFG.output_prefix):
predict_id = key[len(CFG.output_prefix) :].split('.')[0]
filter_mapping_outputs[predict_id] = {key: value}

@@ -688,7 +692,7 @@ def _postprocess_result(self, result):
def _postprocess_result(self, result):
"""Postprocess the result of the query.
Merge the outputs/_builds/_files/_blobs from the _outputs__* keys to the result
Merge the outputs/_builds/_files/_blobs from the output keys to the result
:param result: The result to postprocess.
"""
Expand All @@ -697,7 +701,9 @@ def _postprocess_result(self, result):
merge_files = result.get('_files', {})
merge_blobs = result.get('_blobs', {})

output_keys = {key for key in result.keys() if key.startswith('_outputs__')}
output_keys = {
key for key in result.keys() if key.startswith(CFG.output_prefix)
}
for output_key in output_keys:
output_data = result[output_key]
output_result = output_data[output_key]
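For reference, the `$lookup` stage assembled per `predict_id` joins each output collection back onto the source documents. A hypothetical reconstruction for one invented id; the truncated hunk hides the join fields, so the `_id`/`_source` pairing below is an assumption inferred from the `'_source': ObjectId(id)` written by `model_update` above:

```python
# Hypothetical reconstruction; the '_id'/'_source' join fields are assumed.
from superduper import CFG

predict_id = 'abc123'
key = f'{CFG.output_prefix}{predict_id}'

lookup = {
    '$lookup': {
        'from': key,            # the output collection, e.g. '_outputs__abc123'
        'localField': '_id',    # assumed: id of the source document
        'foreignField': '_source',
        'as': key,              # joined outputs appear under the same key
    }
}
```
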
2 changes: 2 additions & 0 deletions superduper/base/config.py
@@ -290,6 +290,7 @@ class Config(BaseConfig):
:param auto_schema: Whether to automatically create the schema.
If True, the schema will be created if it does not exist.
:param log_colorize: Whether to colorize the logs
:param output_prefix: The prefix for the output table and output field key
"""

envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None
@@ -312,6 +313,7 @@

bytes_encoding: BytesEncoding = BytesEncoding.BYTES
auto_schema: bool = True
output_prefix: str = '_outputs__'

def __post_init__(self, envs):
if envs is not None:
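With the new field in place, the prefix can be overridden like any other config value. A sketch, assuming superduper's usual `SUPERDUPER_<FIELD>` environment-variable convention (the exact variable name is not shown in this diff, and `'_preds__'` is invented):

```python
# Assumption: config fields map to SUPERDUPER_<FIELD> env vars.
import os

os.environ['SUPERDUPER_OUTPUT_PREFIX'] = '_preds__'  # set before importing superduper

from superduper import CFG

CFG.output_prefix = '_preds__'  # or override programmatically

# From here on, output tables/keys are named '_preds__<predict_id>'.
print(CFG.output_prefix)
```
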
6 changes: 3 additions & 3 deletions superduper/base/datalayer.py
@@ -8,7 +8,7 @@
import networkx

import superduper as s
from superduper import logging
from superduper import CFG, logging
from superduper.backends.base.artifacts import ArtifactStore
from superduper.backends.base.backends import vector_searcher_implementations
from superduper.backends.base.compute import ComputeBackend
@@ -338,7 +338,7 @@ def _insert(
inserted_ids = insert.do_execute(self)

cdc_status = s.CFG.cluster.cdc.uri is not None
is_output_table = insert.table.startswith('_outputs__')
is_output_table = insert.table.startswith(CFG.output_prefix)

if refresh:
if cdc_status and not is_output_table:
@@ -440,7 +440,7 @@ def _update(self, update: Query, refresh: bool = True) -> UpdateResult:
updated_ids = update.do_execute(self)

cdc_status = s.CFG.cluster.cdc.uri is not None
is_output_table = update.table.startswith('_outputs__')
is_output_table = update.table.startswith(CFG.output_prefix)
if refresh and updated_ids:
if cdc_status and not is_output_table:
logging.warn('CDC service is active, skipping model/listener refresh')
1 change: 0 additions & 1 deletion superduper/base/document.py
@@ -36,7 +36,6 @@
'remote_code': Code,
}
_LEAF_TYPES.update(_ENCODABLES)
_OUTPUTS_KEY = '_outputs'


class _Getters:
4 changes: 2 additions & 2 deletions superduper/components/application.py
@@ -1,6 +1,6 @@
import typing as t

from superduper import logging
from superduper import CFG, logging

from .component import Component

@@ -74,7 +74,7 @@ def build_from_db(cls, identifier, db: "Datalayer"):
if any(
[
component.type_id == "table"
and component.identifier.startswith("_outputs"),
and component.identifier.startswith(CFG.output_prefix),
component.type_id == "schema"
and component.identifier.startswith("_schema/"),
]
7 changes: 3 additions & 4 deletions superduper/components/listener.py
@@ -5,7 +5,6 @@

from superduper import CFG, logging
from superduper.backends.base.query import Query
from superduper.base.document import _OUTPUTS_KEY
from superduper.components.model import Mapping
from superduper.misc.server import request_server

@@ -54,7 +53,7 @@ def mapping(self):
@property
def outputs(self):
"""Get reference to outputs of listener model."""
return f'{_OUTPUTS_KEY}__{self.uuid}'
return f'{CFG.output_prefix}{self.uuid}'

@property
def outputs_key(self):
@@ -128,8 +127,8 @@ def dependencies(self):
all_ = list(args) + list(kwargs.values())
out = []
for x in all_:
if x.startswith('_outputs__'):
listener_id = x.split('__')[1]
if x.startswith(CFG.output_prefix):
listener_id = x[len(CFG.output_prefix) :]
out.append(listener_id)
return out

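Taken together, a listener's outputs destination and its upstream dependencies are now both derived from the same prefix; a toy illustration with an invented uuid:

```python
# Mirrors Listener.outputs and dependencies() above; the uuid is made up.
from superduper import CFG

uuid = '0123abcd'
outputs_key = f'{CFG.output_prefix}{uuid}'  # what Listener.outputs returns

# dependencies(): keys naming another listener's outputs are recognised by
# the prefix and reduced to that listener's uuid.
keys = ['text', outputs_key]
deps = [k[len(CFG.output_prefix):] for k in keys if k.startswith(CFG.output_prefix)]
assert deps == [uuid]
```
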