diff --git a/CHANGELOG.md b/CHANGELOG.md index 49d966050..1af659426 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,8 +25,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Move all plugins superduperdb/ext/* to /plugins - Optimize the logic for file saving and retrieval in the artifact_store. - - #### New Features & Functionality - Modify the field name output to _outputs.predict_id in the model results of Ibis. @@ -41,7 +39,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add testing utils for plugins - Add `cache` field in Component - #### Bug Fixes - Vector-search vector-loading bug fixed diff --git a/superduper/backends/local/__init__.py b/superduper/backends/local/__init__.py index 6298af57a..752d6d061 100644 --- a/superduper/backends/local/__init__.py +++ b/superduper/backends/local/__init__.py @@ -1,3 +1,4 @@ from .artifacts import FileSystemArtifactStore as ArtifactStore +from .compute import LocalComputeBackend as ComputeBackend -__all__ = ["ArtifactStore"] +__all__ = ["ArtifactStore", "ComputeBackend"] diff --git a/superduper/backends/local/artifacts.py b/superduper/backends/local/artifacts.py index 0e165e815..2530bb857 100644 --- a/superduper/backends/local/artifacts.py +++ b/superduper/backends/local/artifacts.py @@ -25,6 +25,8 @@ def __init__( name: t.Optional[str] = None, flavour: t.Optional[str] = None, ): + if conn.startswith('filesystem://'): + conn = conn.split('filesystem://')[-1] super().__init__(conn, name, flavour) if not os.path.exists(self.conn): logging.info('Creating artifact store directory') diff --git a/superduper/backends/local/compute.py b/superduper/backends/local/compute.py index 96048a858..fb27ba507 100644 --- a/superduper/backends/local/compute.py +++ b/superduper/backends/local/compute.py @@ -12,16 +12,19 @@ class LocalComputeBackend(ComputeBackend): :param uri: Optional uri param. :param queue: Optional pluggable queue. + :param kwargs: Optional kwargs. """ def __init__( self, uri: t.Optional[str] = None, queue: BaseQueuePublisher = LocalQueuePublisher(), + kwargs: t.Dict = {}, ): self.__outputs: t.Dict = {} self.uri = uri self.queue = queue + self.kwargs = kwargs @property def remote(self) -> bool: diff --git a/superduper/base/build.py b/superduper/base/build.py index 5e79399ce..ea92e5d76 100644 --- a/superduper/base/build.py +++ b/superduper/base/build.py @@ -17,19 +17,33 @@ class _Loader: not_supported: t.Tuple = () @classmethod - def create(cls, uri): - """Helper method to create metadata backend.""" + def match(cls, uri): + """Check if the uri matches the pattern.""" + plugin, flavour = None, None for pattern in cls.patterns: if re.match(pattern, uri) is not None: - plugin, flavour = cls.patterns[pattern] - if cls.not_supported and (plugin, flavour) in cls.not_supported: - raise ValueError( - f"{plugin} with flavour {flavour} not supported " - "to create metadata store." - ) - impl = getattr(load_plugin(plugin), cls.impl) - return impl(uri, flavour=flavour) - raise ValueError(f"{cls.__name__} No support for uri: {uri}") + selection = cls.patterns[pattern] + if isinstance(selection, tuple): + plugin, flavour = selection + else: + assert isinstance(selection, str) + plugin = selection + break + if plugin is None: + raise ValueError(f"{cls.__name__} No support for uri: {uri}") + return plugin, flavour + + @classmethod + def create(cls, uri): + """Helper method to create metadata backend.""" + plugin, flavour = cls.match(uri) + if cls.not_supported and (plugin, flavour) in cls.not_supported: + raise ValueError( + f"{plugin} with flavour {flavour} not supported " + "to create metadata store." + ) + impl = getattr(load_plugin(plugin), cls.impl) + return impl(uri, flavour=flavour) class _MetaDataLoader(_Loader): @@ -92,9 +106,8 @@ def _build_compute(cfg): :param cfg: SuperDuper config. """ - from superduper.backends.local.compute import LocalComputeBackend - - return LocalComputeBackend() + backend = getattr(load_plugin(cfg.cluster.compute.backend), 'ComputeBackend') + return backend(uri=cfg.cluster.compute.uri, **cfg.cluster.compute.kwargs) def build_datalayer(cfg=None, **kwargs) -> Datalayer: diff --git a/superduper/base/config.py b/superduper/base/config.py index af7c0f2c5..3a6b3c7a1 100644 --- a/superduper/base/config.py +++ b/superduper/base/config.py @@ -181,12 +181,12 @@ class Compute(BaseConfig): """Describes the configuration for distributed computing. :param uri: The URI for the compute service. - :param compute_kwargs: The keyword arguments to pass to the compute service. + :param kwargs: The keyword arguments to pass to the compute service. :param backend: Compute backend. """ uri: t.Optional[str] = None - compute_kwargs: t.Dict = dc.field(default_factory=dict) + kwargs: t.Dict = dc.field(default_factory=dict) backend: str = 'local' diff --git a/superduper/components/model.py b/superduper/components/model.py index 6752e616e..2b5e89d9d 100644 --- a/superduper/components/model.py +++ b/superduper/components/model.py @@ -504,7 +504,7 @@ def __post_init__(self, db, artifacts): super().__post_init__(db, artifacts) from superduper import CFG - compute_kwargs = CFG.cluster.compute.compute_kwargs + compute_kwargs = CFG.cluster.compute.kwargs self.compute_kwargs = self.compute_kwargs or compute_kwargs self._is_initialized = False if not self.identifier: diff --git a/superduper/components/template.py b/superduper/components/template.py index 330dc3b3c..d7818eb30 100644 --- a/superduper/components/template.py +++ b/superduper/components/template.py @@ -190,6 +190,7 @@ def form_template(self): for k, v in self.template.items() if k not in {'_blobs', 'identifier', '_path'} }, + '_path': self.template['_path'], } def execute(self, **kwargs): diff --git a/superduper/jobs/job.py b/superduper/jobs/job.py index 05a464521..957f32807 100644 --- a/superduper/jobs/job.py +++ b/superduper/jobs/job.py @@ -187,7 +187,7 @@ def __init__( db: t.Optional['Datalayer'] = None, component: 'Component' = None, ): - self.compute_kwargs = compute_kwargs or CFG.cluster.compute.compute_kwargs + self.compute_kwargs = compute_kwargs or CFG.cluster.compute.kwargs super().__init__(args=args, kwargs=kwargs, db=db, identifier=identifier) diff --git a/test/unittest/backends/local/test_artifact_store.py b/test/unittest/backends/local/test_artifact_store.py index 1c5477271..656374ef0 100644 --- a/test/unittest/backends/local/test_artifact_store.py +++ b/test/unittest/backends/local/test_artifact_store.py @@ -7,7 +7,7 @@ @pytest.fixture def artifact_store(tmpdir): - artifact_store = FileSystemArtifactStore(tmpdir) + artifact_store = FileSystemArtifactStore(str(tmpdir)) yield artifact_store artifact_store.drop(True)