Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/template initialization 2 #2532

Merged
merged 2 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ jobs:
key: ${{ matrix.os }}_python-${{ matrix.python-version }}_${{ hashFiles('pyproject.toml', '*/pyproject.toml') }}
#---------------------------------------------------

- name: Install SuperDuperDB Project
- name: Install superduper-framework
run: |
# Install core and testsuite dependencies on the cached python environment.
python -m pip install '.[test]'
Expand Down
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Before you create a Pull Request, remember to update the Changelog with your changes.**

## Changes Since Last Release
## Changes Since Last Release

#### Changed defaults / behaviours

Expand Down
5 changes: 2 additions & 3 deletions plugins/mongodb/superduper_mongodb/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ def __init__(
else:
self.conn = conn
self.name = name
self.db = self.conn[self.name]
self.filesystem = gridfs.GridFS(self.db)
self.filesystem = gridfs.GridFS(self.conn[self.name])

def url(self):
"""Return the URL of the database."""
Expand All @@ -55,7 +54,7 @@ def drop(self, force: bool = False):
default=False,
):
logging.warn('Aborting...')
return self.db.client.drop_database(self.db.name)
return self.conn.drop_database(self.name)

def _exists(self, file_id):
return self.filesystem.find_one({'file_id': file_id}) is not None
Expand Down
1 change: 1 addition & 0 deletions superduper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@
'requires_packages',
'pickle_serializer',
'dill_serializer',
'templates',
)
6 changes: 5 additions & 1 deletion superduper/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,15 @@ def apply(

logging.info('JOBS EVENTS:')
steps = {j.job_id: str(i) for i, j in enumerate(unique_job_events)}

def uniquify(x):
return sorted(list(set(x)))

for i, j in enumerate(unique_job_events):
if j.dependencies:
logging.info(
f'[{i}]: {j.huuid}: {j.method} ~ '
f'[{",".join([steps[d] for d in j.dependencies])}]'
f'[{",".join(uniquify([steps[d] for d in j.dependencies]))}]'
)
else:
logging.info(f'[{i}]: {j.huuid}: {j.method}')
Expand Down
7 changes: 5 additions & 2 deletions superduper/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None)
builds[k] = leaf
to_del = []
keys = list(builds.keys())

for other in keys:
import re

Expand All @@ -568,7 +569,7 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None)
for match in matches:
got = _get_leaf_from_cache(match, builds, getters, db=db)
other = other.replace(f'?({match})', got)
builds[other] = leaf
builds[other] = builds[old_other]
to_del.append(old_other)

for other in to_del:
Expand All @@ -577,8 +578,10 @@ def _get_leaf_from_cache(k, builds, getters, db: t.Optional['Datalayer'] = None)
if isinstance(leaf, Leaf):
if not leaf.db:
leaf.db = db

if attribute is not None:
return getattr(leaf, attribute)

return leaf


Expand Down Expand Up @@ -608,7 +611,7 @@ def _deep_flat_decode(r, builds, getters: _Getters, db: t.Optional['Datalayer']
if isinstance(r, dict):
literals = r.get('_literals', [])
return {
k: (
_deep_flat_decode(k, builds, getters=getters, db=db): (
_deep_flat_decode(v, builds, getters=getters, db=db)
if k not in literals
else v
Expand Down
1 change: 1 addition & 0 deletions superduper/base/leaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def variables(self) -> t.List[str]:

return list(set(_find_variables(self.encode())))

# TODO this is buggy - defaults don't work
@property
def defaults(self):
"""Get the default parameter values."""
Expand Down
83 changes: 52 additions & 31 deletions superduper/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import typing as t

from superduper import Component, logging, superduper
from superduper.components.template import Template
Expand All @@ -23,15 +24,13 @@ def start(
remote_port: int = 8000,
host: str = 'localhost',
headless: bool = False,
templates: bool = True,
):
"""Start the rest server and user interface.

:param port: Port to run the server on.
:param remote_port: Port to connect to remotely
:param host: Host to connect to remotely
:param headless: Toggle to ``True`` to suppress the browser.
:param templates: Toggle to ``False`` to suppress initializing templates.
"""
from superduper.rest.base import SuperDuperApp
from superduper.rest.build import build_frontend, build_rest_app
Expand All @@ -46,34 +45,27 @@ def start(
else:
logging.warn('Frontend pointing to remote server!')

build_frontend(app, port=remote_port, host=host)

if headless:
app.start()
else:
import threading
import time
import webbrowser
if not headless:
build_frontend(app, port=remote_port, host=host)

server_thread = threading.Thread(target=lambda: app.start())
server_thread.start()
logging.info('Waiting for server to start')
app.start()

time.sleep(3)

if templates:
from superduper import templates
@command(help='Initialize a template in the system')
def bootstrap(templates: t.List[str] | None = None):
"""Initialize a template in the system.

db = app.app.state.pool
existing = db.show('template')
prebuilt = templates.ls()
for t in prebuilt:
if t not in existing:
logging.info(f'Applying template \'{t}\'')
db.apply(getattr(templates, t), force=True)
:param templates: List of templates to initialize.
"""
from superduper import templates as inbuilt

if not headless:
webbrowser.open(f'http://localhost:{port}')
if templates is None:
templates = inbuilt.ls()
templates = ['rag']
db = superduper()
for tem in templates:
tem = getattr(inbuilt, tem)
db.apply(tem, force=True)


@command(help='Apply a template or application to a `superduper` deployment')
Expand All @@ -89,6 +81,24 @@ def ls():
print(r)


@command(help='Show available components')
def show(
type_id: str | None = None,
identifier: str | None = None,
version: int | None = None,
):
"""Apply a serialized component.

:param name: Path or name of the template/ component.
:param values: JSON string of values to apply to the template.
"""
db = superduper()
to_show = db.show(type_id=type_id, identifier=identifier, version=version)
import json

print(json.dumps(to_show, indent=2))


@command(help='`superduper` deployment')
def drop(data: bool = False, force: bool = False):
"""Apply a serialized component.
Expand All @@ -106,17 +116,28 @@ def _build_from_template(t):
loaded = json.loads(variables)
return t(**loaded)

db = superduper()

if os.path.exists(name):
try:
with open(name + '/component.json', 'r') as f:
info = json.load(f)
if info['type_id'] == 'template':
t = Template.read(name)
c = _build_from_template(t)
except Exception as e:
if 'Expecting' in str(e):
c = Component.read(name)
else:
c = Component.read(name)
else:
from superduper import templates
existing = db.show('template')
if name not in existing:
from superduper import templates

try:
t = getattr(templates, name)
except AttributeError:
raise Exception(f'No pre-built template found of that name: {name}')
else:
t = db.load('template', name)

t = getattr(templates, name)
c = _build_from_template(t)

try:
Expand Down
3 changes: 0 additions & 3 deletions superduper/components/cdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ class CDC(Component):
type_id: t.ClassVar[str] = 'cdc'
cdc_table: str

def __post_init__(self, db, artifacts):
super().__post_init__(db, artifacts)

def declare_component(self, cluster):
"""Declare the component to the cluster.

Expand Down
2 changes: 1 addition & 1 deletion superduper/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def export(
path: t.Optional[str] = None,
format: str = "json",
zip: bool = False,
defaults: bool = False,
defaults: bool = True,
metadata: bool = False,
hr: bool = False,
component: str = 'component',
Expand Down
1 change: 1 addition & 0 deletions superduper/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def _pre_create(self, db: Datalayer, startup_cache: t.Dict = {}):
"""Pre-create hook."""
if self.select is None:
return

if not db.cfg.auto_schema:
db.startup_cache[self.outputs] = None
return
Expand Down
20 changes: 20 additions & 0 deletions superduper/components/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def model(
model_update_kwargs: t.Optional[t.Dict] = None,
output_schema: t.Optional[Schema] = None,
num_workers: int = 0,
example: t.Any | None = None,
signature: Signature = '*args,**kwargs',
):
"""Decorator to wrap a function with `ObjectModel`.

Expand All @@ -55,6 +57,8 @@ def model(
:param model_update_kwargs: Dictionary to define update kwargs.
:param output_schema: Schema for the model outputs.
:param num_workers: Number of workers to use for parallel processing
:param example: Example to auto-determine the schema/ datatype.
:param signature: Signature for the model.
"""
if item is not None and (inspect.isclass(item) or callable(item)):
if inspect.isclass(item):
Expand All @@ -64,6 +68,12 @@ def object_model_factory(*args, **kwargs):
return ObjectModel(
object=object_,
identifier=identifier or object_.__class__.__name__,
datatype=datatype,
model_update_kwargs=model_update_kwargs or {},
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return object_model_factory
Expand All @@ -72,6 +82,12 @@ def object_model_factory(*args, **kwargs):
return ObjectModel(
identifier=item.__name__,
object=item,
datatype=datatype,
model_update_kwargs=model_update_kwargs or {},
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)
else:

Expand All @@ -87,6 +103,8 @@ def object_model_factory(*args, **kwargs):
model_update_kwargs=model_update_kwargs or {},
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return object_model_factory
Expand All @@ -99,6 +117,8 @@ def object_model_factory(*args, **kwargs):
model_update_kwargs=model_update_kwargs or {},
output_schema=output_schema,
num_workers=num_workers,
example=example,
signature=signature,
)

return decorated_function
Expand Down
4 changes: 2 additions & 2 deletions superduper/components/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class _BaseTemplate(Component):

def __post_init__(self, db, artifacts, substitutions):
if isinstance(self.template, Leaf):
self.template = self.template.encode(defaults=False, metadata=False)
self.template = self.template.encode(defaults=True, metadata=False)
self.template = SuperDuperFlatEncode(self.template)
if substitutions is not None:
self.template = QueryUpdateDocument(self.template).to_template(
Expand Down Expand Up @@ -120,7 +120,7 @@ def export(
path: t.Optional[str] = None,
format: str = 'json',
zip: bool = False,
defaults: bool = False,
defaults: bool = True,
metadata: bool = False,
):
"""
Expand Down
16 changes: 13 additions & 3 deletions superduper/rest/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ def db_show(
version: t.Optional[int] = None,
application: t.Optional[str] = None,
):
if type_id == 'template' and identifier is None:
out = app.db.show('template')
from superduper import templates

out += [x for x in templates.ls() if x not in out]
return out
if application is not None:
r = app.db.metadata.get_component('application', application)
return r['namespace']
Expand All @@ -91,16 +97,20 @@ def db_show_template(identifier: str, type_id: str = 'template'):
def db_metadata_show_jobs(type_id: str, identifier: t.Optional[str] = None):
return [
r['job_id']
for r in app.db.metadata.show_jobs(
type_id=type_id, component_identifier=identifier
)
for r in app.db.metadata.show_jobs(type_id=type_id, identifier=identifier)
if 'job_id' in r
]

@app.add('/db/execute', method='post')
def db_execute(
query: t.Dict,
):
if query['query'].startswith('db.show'):
output = eval(f'app.{query["query"]}')
logging.info('db.show results:')
logging.info(output)
return [{'_base': output}], []

if '_path' not in query:
plugin = app.db.databackend.type.__module__.split('.')[0]
query['_path'] = f'{plugin}.query.parse_query'
Expand Down
Loading
Loading