diff --git a/dev/dev.py b/dev/dev.py index 9b6394e..73a7a8a 100644 --- a/dev/dev.py +++ b/dev/dev.py @@ -7,15 +7,13 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -from anemoi.transform.filters import filter_factory -from anemoi.transform.sources import source_factory -from anemoi.transform.workflows import workflow_factory +from anemoi.transform.filters import filter_registry +from anemoi.transform.sources import source_registry +from anemoi.transform.workflows import workflow_registry ################ -mars = source_factory( - "mars", -) +mars = source_registry.create("mars") r = dict( param=["u", "v", "t", "q"], @@ -31,14 +29,14 @@ ################ -uv_2_ddff = filter_factory("uv_2_ddff") +uv_2_ddff = filter_registry.create("uv_2_ddff") data = uv_2_ddff.forward(data) for f in data: print(f) -ddff_2_uv = filter_factory("ddff_2_uv") +ddff_2_uv = filter_registry.create("ddff_2_uv") data = ddff_2_uv.forward(data) for f in data: print(f) @@ -46,7 +44,7 @@ ################ -pipeline = workflow_factory("pipeline", filters=[mars, uv_2_ddff, ddff_2_uv]) +pipeline = workflow_registry.create("pipeline", filters=[mars, uv_2_ddff, ddff_2_uv]) for f in pipeline(r): print(f) diff --git a/src/anemoi/transform/filters/__init__.py b/src/anemoi/transform/filters/__init__.py index 8524c67..b03adbe 100644 --- a/src/anemoi/transform/filters/__init__.py +++ b/src/anemoi/transform/filters/__init__.py @@ -6,18 +6,6 @@ # nor does it submit to any jurisdiction. -from ..registry import Registry +from anemoi.utils.registry import Registry -registry = Registry(__name__) - - -def register_filter(name, maker): - registry.register(name, maker) - - -def lookup_filter(name): - return registry.lookup(name) - - -def filter_factory(name, *args, **kwargs): - return lookup_filter(name)(*args, **kwargs) +filter_registry = Registry(__name__) diff --git a/src/anemoi/transform/filters/uv_to_ddff.py b/src/anemoi/transform/filters/uv_to_ddff.py index 2f6201d..8a763d3 100644 --- a/src/anemoi/transform/filters/uv_to_ddff.py +++ b/src/anemoi/transform/filters/uv_to_ddff.py @@ -11,7 +11,7 @@ from earthkit.meteo.wind.array import polar_to_xy from earthkit.meteo.wind.array import xy_to_polar -from . import register_filter +from . import filter_registry from .base import SimpleFilter @@ -80,5 +80,5 @@ def backward_transform(self, speed, direction): yield self.new_field_from_numpy(v, template=direction, param=self.v_component) -register_filter("uv_2_ddff", WindComponents) -register_filter("ddff_2_uv", WindComponents.reversed) +filter_registry.register("uv_2_ddff", WindComponents) +filter_registry.register("ddff_2_uv", WindComponents.reversed) diff --git a/src/anemoi/transform/registry.py b/src/anemoi/transform/registry.py deleted file mode 100644 index 2e1240b..0000000 --- a/src/anemoi/transform/registry.py +++ /dev/null @@ -1,75 +0,0 @@ -# (C) Copyright 2024 Anemoi contributors. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - - -import importlib -import logging -import os -import sys - -import entrypoints - -LOG = logging.getLogger(__name__) - - -class Registry: - """A registry of factories""" - - def __init__(self, package): - - self.package = package - self.registered = {} - self.kind = package.split(".")[-1] - - def register(self, name: str, factory: callable): - self.registered[name] = factory - - def _load(self, file): - name, _ = os.path.splitext(file) - try: - importlib.import_module(f".{name}", package=self.package) - except Exception: - LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True) - - def lookup(self, name: str) -> callable: - if name in self.registered: - return self.registered[name] - - directory = sys.modules[self.package].__path__[0] - - for file in os.listdir(directory): - - if file[0] == ".": - continue - - if file == "__init__.py": - continue - - full = os.path.join(directory, file) - if os.path.isdir(full): - if os.path.exists(os.path.join(full, "__init__.py")): - self._load(file) - continue - - if file.endswith(".py"): - self._load(file) - - entrypoint_group = f"anemoi.{self.kind}" - for entry_point in entrypoints.get_group_all(entrypoint_group): - if entry_point.name == name: - if name in self.registered: - LOG.warning( - f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'" - ) - self.registered[name] = entry_point.load() - - if name not in self.registered: - raise ValueError(f"Cannot load '{name}' from {self.package}") - - return self.registered[name] diff --git a/src/anemoi/transform/sources/__init__.py b/src/anemoi/transform/sources/__init__.py index 35098d5..26227ff 100644 --- a/src/anemoi/transform/sources/__init__.py +++ b/src/anemoi/transform/sources/__init__.py @@ -6,18 +6,6 @@ # nor does it submit to any jurisdiction. -from ..registry import Registry +from anemoi.utils.registry import Registry -registry = Registry(__name__) - - -def register_source(name, maker): - registry.register(name, maker) - - -def lookup_source(name): - return registry.lookup(name) - - -def source_factory(name, *args, **kwargs): - return lookup_source(name)(*args, **kwargs) +source_registry = Registry(__name__) diff --git a/src/anemoi/transform/sources/mars.py b/src/anemoi/transform/sources/mars.py index 0f18d64..06fc48f 100644 --- a/src/anemoi/transform/sources/mars.py +++ b/src/anemoi/transform/sources/mars.py @@ -11,9 +11,10 @@ import earthkit.data as ekd from ..source import Source -from . import register_source +from . import source_registry +@source_registry.register("mars") class Mars(Source): """A demo source""" @@ -37,4 +38,4 @@ def forward(self, data): return Input(data) -register_source("mars", Mars) +source_registry.register("mars", Mars) diff --git a/src/anemoi/transform/transform.py b/src/anemoi/transform/transform.py index a655614..1d57add 100644 --- a/src/anemoi/transform/transform.py +++ b/src/anemoi/transform/transform.py @@ -36,9 +36,9 @@ def reversed(cls, *args, **kwargs): return ReversedTransform(cls(*args, **kwargs)) def __or__(self, other): - from .workflows import workflow_factory + from .workflows import workflow_registry - return workflow_factory("pipeline", filters=[self, other]) + return workflow_registry.create("pipeline", filters=[self, other]) class ReversedTransform(Transform): diff --git a/src/anemoi/transform/workflows/__init__.py b/src/anemoi/transform/workflows/__init__.py index 5225904..10945d7 100644 --- a/src/anemoi/transform/workflows/__init__.py +++ b/src/anemoi/transform/workflows/__init__.py @@ -5,18 +5,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -from ..registry import Registry +from anemoi.utils.registry import Registry -registry = Registry(__name__) - - -def register_workflow(name, maker): - registry.register(name, maker) - - -def lookup_workflow(name): - return registry.lookup(name) - - -def workflow_factory(name, *args, **kwargs): - return lookup_workflow(name)(*args, **kwargs) +workflow_registry = Registry(__name__) diff --git a/src/anemoi/transform/workflows/pipeline.py b/src/anemoi/transform/workflows/pipeline.py index 6b3f83f..73e29bb 100644 --- a/src/anemoi/transform/workflows/pipeline.py +++ b/src/anemoi/transform/workflows/pipeline.py @@ -9,7 +9,7 @@ from ..workflow import Workflow -from . import register_workflow +from . import workflow_registry class Pipeline(Workflow): @@ -29,4 +29,4 @@ def backward(self, data): return data -register_workflow("pipeline", Pipeline) +workflow_registry.register("pipeline", Pipeline)