Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support: Fixd UUID and prefix outputs string #2358

Merged
merged 2 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Use different colors to distinguish logs
- Change all the `_outputs.` to `_outputs__`
- Disable cdc on output tables.
- Remove `-` from the uuid of the component.


#### New Features & Functionality
Expand All @@ -31,6 +32,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Support eager mode
- Add CSN env var
- Make tests configurable against backend
- Make the prefix of the output string configurable


#### Bug Fixes
Expand Down
15 changes: 9 additions & 6 deletions superduper/backends/base/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import abstractmethod
from functools import wraps

from superduper import CFG
from superduper.base.document import Document, _unpack
from superduper.base.leaf import Leaf

Expand Down Expand Up @@ -327,7 +328,8 @@ def _wrap_document(document):
field = [
k
for k in table.schema.fields
if k not in [self.primary_id, '_fold', '_outputs']
if k not in [self.primary_id, '_fold']
and not k.startswith(CFG.output_prefix)
]
assert len(field) == 1
document = Document({field[0]: document})
Expand Down Expand Up @@ -648,12 +650,13 @@ def table_or_collection(self):


def _parse_query_part(part, documents, query, builder_cls, db=None):
key = part.split('.')
if key[0] == '_outputs':
table = f'{key[0]}.{key[1]}'
part = part.split('.')[2:]
if '.' in CFG.output_prefix and part.startswith(CFG.output_prefix):
rest_part = part[len(CFG.output_prefix) :].split('.')
table = f'{CFG.output_prefix}{rest_part[0]}'
part = rest_part[1:]

else:
table = key[0]
table = part.split('.')[0]
part = part.split('.')[1:]

current = builder_cls(table=table, parts=(), db=db)
Expand Down
7 changes: 4 additions & 3 deletions superduper/backends/ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pandas.core.frame import DataFrame
from sqlalchemy.exc import NoSuchTableError

from superduper import CFG
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.backends.ibis.db_helper import get_db_helper
Expand Down Expand Up @@ -167,10 +168,10 @@ def create_output_dest(
fields = {
INPUT_KEY: dtype('string'),
'id': dtype('string'),
f'_outputs__{predict_id}': output_type,
f'{CFG.output_prefix}{predict_id}': output_type,
}
return Table(
identifier=f'_outputs__{predict_id}',
identifier=f'{CFG.output_prefix}{predict_id}',
schema=Schema(identifier=f'_schema/{predict_id}', fields=fields),
)

Expand All @@ -180,7 +181,7 @@ def check_output_dest(self, predict_id) -> bool:
:param predict_id: The identifier of the prediction.
"""
try:
self.conn.table(f'_outputs__{predict_id}')
self.conn.table(f'{CFG.output_prefix}{predict_id}')
return True
except (NoSuchTableError, ibis.IbisError):
return False
Expand Down
14 changes: 8 additions & 6 deletions superduper/backends/ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pandas

from superduper import Document
from superduper import CFG, Document
from superduper.backends.base.query import (
Query,
applies_to,
Expand Down Expand Up @@ -83,13 +83,13 @@ def _model_update_impl(
for output, source_id in zip(outputs, ids):
d = {
'_source': str(source_id),
f'_outputs__{predict_id}': output.x
f'{CFG.output_prefix}{predict_id}': output.x
if isinstance(output, Encodable)
else output,
'id': str(uuid.uuid4()),
}
documents.append(Document(d))
return db[f'_outputs__{predict_id}'].insert(documents)
return db[f'{CFG.output_prefix}{predict_id}'].insert(documents)


class IbisQuery(Query):
Expand Down Expand Up @@ -355,7 +355,7 @@ def drop_outputs(self, predict_id: str):
:param predict_ids: The ids of the predictions to select.
"""
return self.db.databackend.conn.drop_table(f'_outputs__{predict_id}')
return self.db.databackend.conn.drop_table(f'{CFG.output_prefix}{predict_id}')

@applies_to('select')
def outputs(self, *predict_ids):
Expand All @@ -378,7 +378,9 @@ def outputs(self, *predict_ids):
attr = getattr(query, self.primary_id)
for identifier in predict_ids:
identifier = (
identifier if '_outputs' in identifier else f'_outputs__{identifier}'
identifier
if identifier.startswith(CFG.output_prefix)
else f'{CFG.output_prefix}{identifier}'
)
symbol_table = self.db[identifier]

Expand All @@ -399,7 +401,7 @@ def select_ids_of_missing_outputs(self, predict_id: str):

assert isinstance(self.db, Datalayer)

output_table = self.db[f'_outputs__{predict_id}']
output_table = self.db[f'{CFG.output_prefix}{predict_id}']
return self.anti_join(
output_table,
output_table._source == getattr(self, self.primary_id),
Expand Down
8 changes: 3 additions & 5 deletions superduper/backends/mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import mongomock
import pymongo

from superduper import logging
from superduper import CFG, logging
from superduper.backends.base.data_backend import BaseDataBackend
from superduper.backends.base.metadata import MetaDataStoreProxy
from superduper.backends.ibis.field_types import FieldType
Expand Down Expand Up @@ -107,10 +107,8 @@ def build_artifact_store(self):
def drop_outputs(self):
"""Drop all outputs."""
for collection in self.db.list_collection_names():
if collection.startswith('_outputs__'):
if collection.startswith(CFG.output_prefix):
self.db.drop_collection(collection)
else:
self.db[collection].update_many({}, {'$unset': {'_outputs': ''}})

def drop_table_or_collection(self, name: str):
"""Drop the table or collection.
Expand Down Expand Up @@ -188,7 +186,7 @@ def check_output_dest(self, predict_id) -> bool:
:param predict_id: identifier of the prediction
"""
return self.db[f'_outputs__{predict_id}'].find_one() is not None
return self.db[f'{CFG.output_prefix}{predict_id}'].find_one() is not None

@staticmethod
def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None):
Expand Down
38 changes: 22 additions & 16 deletions superduper/backends/mongodb/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pymongo
from bson import ObjectId

from superduper import logging
from superduper import CFG, logging
from superduper.backends.base.query import (
Query,
applies_to,
Expand Down Expand Up @@ -367,7 +367,9 @@ def drop_outputs(self, predict_id: str):
:param predict_ids: The ids of the predictions to select.
"""
return self.db.databackend.drop_table_or_collection(f'_outputs__{predict_id}')
return self.db.databackend.drop_table_or_collection(
f'{CFG.output_prefix}{predict_id}'
)

@applies_to('find')
def outputs(self, *predict_ids):
Expand Down Expand Up @@ -459,13 +461,13 @@ def select_ids_of_missing_outputs(self, predict_id: str):
{
'$and': [
args[0],
{f'_outputs__{predict_id}': {'$exists': 0}},
{f'{CFG.output_prefix}{predict_id}': {'$exists': 0}},
]
},
*args[1:],
]
else:
args = [{f'_outputs__{predict_id}': {'$exists': 0}}]
args = [{f'{CFG.output_prefix}{predict_id}': {'$exists': 0}}]

if len(args) == 1:
args.append({})
Expand Down Expand Up @@ -530,15 +532,17 @@ def model_update(
for output, id in zip(outputs, ids):
documents.append(
{
f'_outputs__{predict_id}': output,
f'{CFG.output_prefix}{predict_id}': output,
'_source': ObjectId(id),
}
)

from superduper.base.datalayer import Datalayer

assert isinstance(self.db, Datalayer)
output_query = self.db[f'_outputs__{predict_id}'].insert_many(documents)
output_query = self.db[f'{CFG.output_prefix}{predict_id}'].insert_many(
documents
)
output_query.is_output_query = True
output_query.updated_key = predict_id
return output_query
Expand Down Expand Up @@ -616,18 +620,18 @@ def _execute(self, parent, method='encode'):
predict_ids = sum([p[1] for p in outputs_parts], ())

pipeline = []
filter_mapping_base, filter_mapping_outptus = self._get_filter_mapping()
filter_mapping_base, filter_mapping_outputs = self._get_filter_mapping()
if filter_mapping_base:
pipeline.append({"$match": filter_mapping_base})
project.update({k: 1 for k in filter_mapping_base.keys()})

predict_ids_in_filter = list(filter_mapping_outptus.keys())
predict_ids_in_filter = list(filter_mapping_outputs.keys())

predict_ids = list(set(predict_ids).union(predict_ids_in_filter))
# After the join, the complete outputs data can be queried as
# _outputs__{predict_id}._outputs.{predict_id} : result.
# {CFG.output_prefix}{predict_id}._outputs.{predict_id} : result.
for predict_id in predict_ids:
key = f'_outputs__{predict_id}'
key = f'{CFG.output_prefix}{predict_id}'
lookup = {
"$lookup": {
"from": key,
Expand All @@ -640,9 +644,9 @@ def _execute(self, parent, method='encode'):
project[key] = 1
pipeline.append(lookup)

if predict_id in filter_mapping_outptus:
if predict_id in filter_mapping_outputs:
filter_key, filter_value = list(
filter_mapping_outptus[predict_id].items()
filter_mapping_outputs[predict_id].items()
)[0]
pipeline.append({"$match": {f'{key}.{filter_key}': filter_value}})

Expand Down Expand Up @@ -675,11 +679,11 @@ def _get_filter_mapping(self):
filter_mapping_outputs = defaultdict(dict)

for key, value in filter.items():
if '_outputs__' not in key:
if '{CFG.output_prefix}' not in key:
filter_mapping_base[key] = value
continue

if key.startswith('_outputs__'):
if key.startswith('{CFG.output_prefix}'):
predict_id = key.split('__')[1]
filter_mapping_outputs[predict_id] = {key: value}

Expand All @@ -688,7 +692,7 @@ def _get_filter_mapping(self):
def _postprocess_result(self, result):
"""Postprocess the result of the query.
Merge the outputs/_builds/_files/_blobs from the _outputs__* keys to the result
Merge the outputs/_builds/_files/_blobs from the output keys to the result
:param result: The result to postprocess.
"""
Expand All @@ -697,7 +701,9 @@ def _postprocess_result(self, result):
merge_files = result.get('_files', {})
merge_blobs = result.get('_blobs', {})

output_keys = {key for key in result.keys() if key.startswith('_outputs__')}
output_keys = {
key for key in result.keys() if key.startswith(CFG.output_prefix)
}
for output_key in output_keys:
output_data = result[output_key]
output_result = output_data[output_key]
Expand Down
2 changes: 2 additions & 0 deletions superduper/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ class Config(BaseConfig):
:param auto_schema: Whether to automatically create the schema.
If True, the schema will be created if it does not exist.
:param log_colorize: Whether to colorize the logs
:param output_prefix: The prefix for the output table and output field key
"""

envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None
Expand All @@ -312,6 +313,7 @@ class Config(BaseConfig):

bytes_encoding: BytesEncoding = BytesEncoding.BYTES
auto_schema: bool = True
output_prefix: str = '_outputs__'

def __post_init__(self, envs):
if envs is not None:
Expand Down
6 changes: 3 additions & 3 deletions superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import networkx

import superduper as s
from superduper import logging
from superduper import CFG, logging
from superduper.backends.base.artifacts import ArtifactStore
from superduper.backends.base.backends import vector_searcher_implementations
from superduper.backends.base.compute import ComputeBackend
Expand Down Expand Up @@ -338,7 +338,7 @@ def _insert(
inserted_ids = insert.do_execute(self)

cdc_status = s.CFG.cluster.cdc.uri is not None
is_output_table = insert.table.startswith('_outputs__')
is_output_table = insert.table.startswith(CFG.output_prefix)

if refresh:
if cdc_status and not is_output_table:
Expand Down Expand Up @@ -440,7 +440,7 @@ def _update(self, update: Query, refresh: bool = True) -> UpdateResult:
updated_ids = update.do_execute(self)

cdc_status = s.CFG.cluster.cdc.uri is not None
is_output_table = update.table.startswith('_outputs__')
is_output_table = update.table.startswith(CFG.output_prefix)
if refresh and updated_ids:
if cdc_status and not is_output_table:
logging.warn('CDC service is active, skipping model/listener refresh')
Expand Down
1 change: 0 additions & 1 deletion superduper/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
'remote_code': Code,
}
_LEAF_TYPES.update(_ENCODABLES)
_OUTPUTS_KEY = '_outputs'


class _Getters:
Expand Down
2 changes: 1 addition & 1 deletion superduper/base/leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class Leaf(metaclass=LeafMeta):

identifier: str
db: dc.InitVar[t.Optional['Datalayer']] = None
uuid: str = dc.field(default_factory=lambda: str(uuid.uuid4()))
uuid: str = dc.field(default_factory=lambda: str(uuid.uuid4()).replace('-', ''))

@property
def metadata(self):
Expand Down
4 changes: 2 additions & 2 deletions superduper/components/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing as t

from superduper import logging
from superduper import CFG, logging

from .component import Component

Expand Down Expand Up @@ -74,7 +74,7 @@ def build_from_db(cls, identifier, db: "Datalayer"):
if any(
[
component.type_id == "table"
and component.identifier.startswith("_outputs"),
and component.identifier.startswith(CFG.output_prefix),
component.type_id == "schema"
and component.identifier.startswith("_schema/"),
]
Expand Down
7 changes: 3 additions & 4 deletions superduper/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from superduper import CFG, logging
from superduper.backends.base.query import Query
from superduper.base.document import _OUTPUTS_KEY
from superduper.components.model import Mapping
from superduper.misc.server import request_server

Expand Down Expand Up @@ -54,7 +53,7 @@ def mapping(self):
@property
def outputs(self):
"""Get reference to outputs of listener model."""
return f'{_OUTPUTS_KEY}__{self.uuid}'
return f'{CFG.output_prefix}{self.uuid}'

@property
def outputs_key(self):
Expand Down Expand Up @@ -128,8 +127,8 @@ def dependencies(self):
all_ = list(args) + list(kwargs.values())
out = []
for x in all_:
if x.startswith('_outputs__'):
listener_id = x.split('__')[1]
if x.startswith(CFG.output_prefix):
listener_id = x[len(CFG.output_prefix) :]
out.append(listener_id)
return out

Expand Down
Loading
Loading