From a75f531a94343bfd536e5da907a8b8c7b809d129 Mon Sep 17 00:00:00 2001 From: Sean Mackesey Date: Sun, 12 Feb 2023 15:13:17 -0500 Subject: [PATCH] [typing/static] Fix @repository decorator typing --- .../dagster/_core/definitions/assets_job.py | 2 +- .../_core/definitions/cacheable_assets.py | 4 +-- .../decorators/repository_decorator.py | 27 +++++++++++++-- .../definitions/repository_definition.py | 0 .../repository_definition/__init__.py | 2 ++ .../repository_definition/caching_index.py | 34 +++++++++---------- .../repository_definition.py | 3 +- .../valid_definitions.py | 14 ++++++-- .../test_external_execution_plan.py | 6 +++- 9 files changed, 65 insertions(+), 27 deletions(-) delete mode 100644 python_modules/dagster/dagster/_core/definitions/repository_definition.py diff --git a/python_modules/dagster/dagster/_core/definitions/assets_job.py b/python_modules/dagster/dagster/_core/definitions/assets_job.py index 01fe7891061a6..8b444f8dc0809 100644 --- a/python_modules/dagster/dagster/_core/definitions/assets_job.py +++ b/python_modules/dagster/dagster/_core/definitions/assets_job.py @@ -38,7 +38,7 @@ ASSET_BASE_JOB_PREFIX = "__ASSET_JOB" -def is_base_asset_job_name(name) -> bool: +def is_base_asset_job_name(name: str) -> bool: return name.startswith(ASSET_BASE_JOB_PREFIX) diff --git a/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py b/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py index 8604d1d75b183..68688853d3e0c 100644 --- a/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py +++ b/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py @@ -242,7 +242,7 @@ def __init__( ) super().__init__( - unique_id=f"{wrapped._unique_id}_prefix_or_group_{self._get_hash()}", # noqa: SLF001 + unique_id=f"{wrapped.unique_id}_prefix_or_group_{self._get_hash()}", wrapped=wrapped, ) @@ -331,7 +331,7 @@ def __init__( self._resource_defs = resource_defs super().__init__( - unique_id=f"{wrapped._unique_id}_resources_{self._get_hash()}", # noqa: SLF001 + unique_id=f"{wrapped.unique_id}_resources_{self._get_hash()}", wrapped=wrapped, ) diff --git a/python_modules/dagster/dagster/_core/definitions/decorators/repository_decorator.py b/python_modules/dagster/dagster/_core/definitions/decorators/repository_decorator.py index 4754ae2625c2a..2bc6ff6c35271 100644 --- a/python_modules/dagster/dagster/_core/definitions/decorators/repository_decorator.py +++ b/python_modules/dagster/dagster/_core/definitions/decorators/repository_decorator.py @@ -28,8 +28,10 @@ VALID_REPOSITORY_DATA_DICT_KEYS, CachingRepositoryData, PendingRepositoryDefinition, + PendingRepositoryListDefinition, RepositoryData, RepositoryDefinition, + RepositoryListDefinition, ) from ..schedule_definition import ScheduleDefinition from ..sensor_definition import SensorDefinition @@ -68,8 +70,20 @@ def __init__( top_level_resources, "top_level_resources", key_type=str, value_type=ResourceDefinition ) + @overload def __call__( - self, fn: Callable[[], Sequence[Any]] + self, fn: Callable[[], Sequence[PendingRepositoryListDefinition]] + ) -> PendingRepositoryDefinition: + ... + + @overload + def __call__( + self, fn: Callable[[], Sequence[RepositoryListDefinition]] + ) -> RepositoryDefinition: + ... + + def __call__( + self, fn: Callable[[], Sequence[PendingRepositoryListDefinition]] ) -> Union[RepositoryDefinition, PendingRepositoryDefinition]: from dagster._core.definitions import AssetGroup, AssetsDefinition, SourceAsset from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition @@ -180,7 +194,16 @@ def __call__( @overload -def repository(definitions_fn: Callable[..., Sequence[Any]]) -> RepositoryDefinition: +def repository( + definitions_fn: Callable[..., Sequence[RepositoryListDefinition]] +) -> RepositoryDefinition: + ... + + +@overload +def repository( + definitions_fn: Callable[..., Sequence[PendingRepositoryListDefinition]] +) -> PendingRepositoryDefinition: ... diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition.py b/python_modules/dagster/dagster/_core/definitions/repository_definition.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/__init__.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/__init__.py index d68c7b474727a..db847cd6699c6 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/__init__.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/__init__.py @@ -9,4 +9,6 @@ from .valid_definitions import ( SINGLETON_REPOSITORY_NAME as SINGLETON_REPOSITORY_NAME, VALID_REPOSITORY_DATA_DICT_KEYS as VALID_REPOSITORY_DATA_DICT_KEYS, + PendingRepositoryListDefinition as PendingRepositoryListDefinition, + RepositoryListDefinition as RepositoryListDefinition, ) diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/caching_index.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/caching_index.py index 3c8d7910eadd5..c9eb7920812a8 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/caching_index.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/caching_index.py @@ -13,20 +13,20 @@ import dagster._check as check from dagster._core.errors import DagsterInvariantViolationError -from .valid_definitions import RepositoryLevelDefinition +from .valid_definitions import T_RepositoryLevelDefinition -class CacheingDefinitionIndex(Generic[RepositoryLevelDefinition]): +class CacheingDefinitionIndex(Generic[T_RepositoryLevelDefinition]): def __init__( self, - definition_class: Type[RepositoryLevelDefinition], + definition_class: Type[T_RepositoryLevelDefinition], definition_class_name: str, definition_kind: str, definitions: Mapping[ - str, Union[RepositoryLevelDefinition, Callable[[], RepositoryLevelDefinition]] + str, Union[T_RepositoryLevelDefinition, Callable[[], T_RepositoryLevelDefinition]] ], - validation_fn: Callable[[RepositoryLevelDefinition], RepositoryLevelDefinition], - lazy_definitions_fn: Optional[Callable[[], Sequence[RepositoryLevelDefinition]]] = None, + validation_fn: Callable[[T_RepositoryLevelDefinition], T_RepositoryLevelDefinition], + lazy_definitions_fn: Optional[Callable[[], Sequence[T_RepositoryLevelDefinition]]] = None, ): """Args: definitions: A dictionary of definition names to definitions or functions that load @@ -47,27 +47,27 @@ def __init__( ), ) - self._definition_class: Type[RepositoryLevelDefinition] = definition_class + self._definition_class: Type[T_RepositoryLevelDefinition] = definition_class self._definition_class_name = definition_class_name self._definition_kind = definition_kind self._validation_fn: Callable[ - [RepositoryLevelDefinition], RepositoryLevelDefinition + [T_RepositoryLevelDefinition], T_RepositoryLevelDefinition ] = validation_fn self._definitions: Mapping[ - str, Union[RepositoryLevelDefinition, Callable[[], RepositoryLevelDefinition]] + str, Union[T_RepositoryLevelDefinition, Callable[[], T_RepositoryLevelDefinition]] ] = definitions - self._definition_cache: Dict[str, RepositoryLevelDefinition] = {} + self._definition_cache: Dict[str, T_RepositoryLevelDefinition] = {} self._definition_names: Optional[Sequence[str]] = None self._lazy_definitions_fn: Callable[ - [], Sequence[RepositoryLevelDefinition] + [], Sequence[T_RepositoryLevelDefinition] ] = lazy_definitions_fn or (lambda: []) - self._lazy_definitions: Optional[Sequence[RepositoryLevelDefinition]] = None + self._lazy_definitions: Optional[Sequence[T_RepositoryLevelDefinition]] = None - self._all_definitions: Optional[Sequence[RepositoryLevelDefinition]] = None + self._all_definitions: Optional[Sequence[T_RepositoryLevelDefinition]] = None - def _get_lazy_definitions(self) -> Sequence[RepositoryLevelDefinition]: + def _get_lazy_definitions(self) -> Sequence[T_RepositoryLevelDefinition]: if self._lazy_definitions is None: self._lazy_definitions = self._lazy_definitions_fn() for definition in self._lazy_definitions: @@ -98,7 +98,7 @@ def has_definition(self, definition_name: str) -> bool: return definition_name in self.get_definition_names() - def get_all_definitions(self) -> Sequence[RepositoryLevelDefinition]: + def get_all_definitions(self) -> Sequence[T_RepositoryLevelDefinition]: if self._all_definitions is not None: return self._all_definitions @@ -110,7 +110,7 @@ def get_all_definitions(self) -> Sequence[RepositoryLevelDefinition]: ) return self._all_definitions - def get_definition(self, definition_name: str) -> RepositoryLevelDefinition: + def get_definition(self, definition_name: str) -> T_RepositoryLevelDefinition: check.str_param(definition_name, "definition_name") if not self.has_definition(definition_name): @@ -142,7 +142,7 @@ def get_definition(self, definition_name: str) -> RepositoryLevelDefinition: return definition def _validate_and_cache_definition( - self, definition: RepositoryLevelDefinition, definition_dict_key: str + self, definition: T_RepositoryLevelDefinition, definition_dict_key: str ): check.invariant( isinstance(definition, self._definition_class), diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py index 10fc93a10fe76..6ba3f8b6365aa 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py @@ -40,7 +40,8 @@ from .valid_definitions import ( SINGLETON_REPOSITORY_NAME as SINGLETON_REPOSITORY_NAME, VALID_REPOSITORY_DATA_DICT_KEYS as VALID_REPOSITORY_DATA_DICT_KEYS, - RepositoryListDefinition, + PendingRepositoryListDefinition as PendingRepositoryListDefinition, + RepositoryListDefinition as RepositoryListDefinition, ) if TYPE_CHECKING: diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py index 339a3bb52588e..e811fa2e9237f 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/valid_definitions.py @@ -1,5 +1,7 @@ from typing import TYPE_CHECKING, TypeVar, Union +from typing_extensions import TypeAlias + from dagster._core.definitions.graph_definition import GraphDefinition from dagster._core.definitions.job_definition import JobDefinition from dagster._core.definitions.partition import PartitionSetDefinition @@ -11,6 +13,7 @@ if TYPE_CHECKING: from dagster._core.definitions import AssetGroup, AssetsDefinition + from dagster._core.definitions.cacheable_assets import CacheableAssetsDefinition from dagster._core.definitions.partitioned_schedule import ( UnresolvedPartitionedAssetScheduleDefinition, ) @@ -25,8 +28,8 @@ "jobs", } -RepositoryLevelDefinition = TypeVar( - "RepositoryLevelDefinition", +T_RepositoryLevelDefinition = TypeVar( + "T_RepositoryLevelDefinition", PipelineDefinition, JobDefinition, PartitionSetDefinition, @@ -34,7 +37,7 @@ SensorDefinition, ) -RepositoryListDefinition = Union[ +RepositoryListDefinition: TypeAlias = Union[ "AssetsDefinition", "AssetGroup", GraphDefinition, @@ -46,3 +49,8 @@ UnresolvedAssetJobDefinition, "UnresolvedPartitionedAssetScheduleDefinition", ] + +PendingRepositoryListDefinition: TypeAlias = Union[ + RepositoryListDefinition, + "CacheableAssetsDefinition", +] diff --git a/python_modules/dagster/dagster_tests/core_tests/test_external_execution_plan.py b/python_modules/dagster/dagster_tests/core_tests/test_external_execution_plan.py index f913b490f5c99..82aaf6461f941 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_external_execution_plan.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_external_execution_plan.py @@ -1,6 +1,7 @@ import os import pickle import re +from typing import Sequence import pytest from dagster import ( @@ -25,6 +26,9 @@ from dagster._core.definitions.output import Out from dagster._core.definitions.pipeline_base import InMemoryPipeline from dagster._core.definitions.reconstruct import ReconstructablePipeline, ReconstructableRepository +from dagster._core.definitions.repository_definition.valid_definitions import ( + PendingRepositoryListDefinition, +) from dagster._core.execution.api import create_execution_plan, execute_plan from dagster._core.execution.plan.plan import ExecutionPlan from dagster._core.instance import DagsterInstance @@ -362,5 +366,5 @@ def bar(foo): @repository -def pending_repo(): +def pending_repo() -> Sequence[PendingRepositoryListDefinition]: return [bar, MyCacheableAssetsDefinition("xyz"), define_asset_job("all_asset_job")]