From d5db43f9437d7e4e59455905e24df4d2e7b958d4 Mon Sep 17 00:00:00 2001
From: Sean Mackesey
Date: Thu, 23 Mar 2023 13:19:22 -0400
Subject: [PATCH] [refactor] Remove `frozen{list,dict,tags}` classes (#12293)

### Summary & Motivation

Internal companion PR: https://github.com/dagster-io/internal/pull/5239
Relevant discussion: https://github.com/dagster-io/internal/discussions/4859#discussioncomment-4950032

Removes the `frozenlist`, `frozendict`, and `frozentags` classes. Reasons:

- They don't play nicely with static type-checking and are a source of type-checker diagnostic noise.
- Static type-checking with `Sequence`/`Mapping` largely solves the same problem (in the same way that annotations let us drop runtime check calls from internal methods).
- They are used inconsistently in the codebase (most places where we intend to create an immutable list/dict don't use them). Where they are used, it is not at all obvious why at first glance.
- They generally complicate the code.

The main purpose the `frozen*` classes were serving in our code was to make a few select `NamedTuple` classes hashable. If a `NamedTuple` contains mutable collection members, its default hash function will fail. Replacing those mutable collections with immutable ones lets the hash succeed, which in turn lets the `NamedTuple` be cached via `lru_cache` or used as a dict key. The classes that need to be made hashable for `lru_cache` purposes are:

- `CustomPointer`
- `AssetsDefinitionCacheableData`
- `RepositoryLoadData`
- `ReconstructableRepository`

These are part of a tree of objects rooted on `ReconstructablePipeline`, which uses `lru_cache` in its `get_definition` method. The above classes can be made hashable in a more legible and type-annotation-friendly way by defining a `__hash__` method. This PR does that-- wherever we have a `NamedTuple` that needed to be hashed, `frozen*` instantiations in the constructor were removed and a `__hash__` method was added. All of the `__hash__` methods are the same: they call `dagster._utils.hash_collection` and cache the result (see the sketch at the end of this description). This function converts mutable collections to tuples prior to hashing, which allows hashing to succeed because tuples are immutable.

Aside from cases where frozen classes were used to achieve `NamedTuple` hashability, the other uses were:

- Scattered uses in places where whoever wrote the code thought the structure should be immutable. I removed these for the reasons above (where we want structures to be immutable, most of the codebase signals this with `Sequence`/`Mapping`).
- A few other cases where dicts need to be hashable to sort them or use them in a cache. For these, I provided simple substitute solutions that achieve the same effect-- e.g. instead of using `frozendict`, converting a dict to a tuple of 2-tuples before hashing.

---

- Resolves #3008
- Resolves #3641

### How I Tested These Changes

Existing test suite.
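### Appendix: `__hash__` pattern sketch

For reference, here is the shared `__hash__` pattern as a standalone, runnable sketch. The class and field names are hypothetical; the `__hash__` body is the one this PR adds to each of the classes listed above:

```python
from typing import Dict, List, NamedTuple

from dagster._utils import hash_collection


class CacheableData(
    NamedTuple("_CacheableData", [("keys", List[str]), ("metadata", Dict[str, int])])
):
    # The mutable list/dict members make the default tuple-based hash raise
    # `TypeError: unhashable type`. `hash_collection` recursively converts
    # lists, dicts, and sets to tuples before hashing; the result is cached
    # on the instance (attribute assignment works because this subclass,
    # unlike the generated NamedTuple base, has a `__dict__`).
    def __hash__(self) -> int:
        if not hasattr(self, "_hash"):
            self._hash = hash_collection(self)
        return self._hash


assert hash(CacheableData(["a"], {"n": 1})) == hash(CacheableData(["a"], {"n": 1}))
```

With this in place, instances can be used as dict keys or as arguments to `lru_cache`-decorated functions. The "tuple of 2-tuples" substitute mentioned in the last bullet above is what `dagster._utils.make_hashable` produces for a dict:

```python
from dagster._utils import make_hashable

# Dicts are unhashable; make_hashable converts one (recursively) into a
# sorted tuple of 2-tuples, which can be hashed, sorted, or cached.
assert make_hashable({"b": 2, "a": 1}) == (("a", 1), ("b", 2))
```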
--- .../dagster_graphql/client/client.py | 2 +- .../dagster/dagster/_check/__init__.py | 66 +++------- .../dagster/dagster/_config/post_process.py | 12 +- .../dagster/dagster/_config/validate.py | 10 +- .../dagster/dagster/_core/code_pointer.py | 19 +-- .../_core/definitions/cacheable_assets.py | 80 ++++++------ .../dagster/_core/definitions/composition.py | 15 ++- .../dagster/_core/definitions/data_time.py | 17 +-- .../dagster/_core/definitions/dependency.py | 19 +-- .../_core/definitions/metadata/table.py | 7 +- .../_core/definitions/node_definition.py | 9 +- .../dagster/_core/definitions/partition.py | 5 +- .../_core/definitions/pipeline_definition.py | 3 +- .../dagster/_core/definitions/reconstruct.py | 15 ++- .../repository_definition.py | 14 ++- .../dagster/_core/definitions/utils.py | 11 +- .../dagster/_core/instance/__init__.py | 13 +- .../dagster/dagster/_core/origin.py | 13 +- .../dagster/_core/snap/dep_snapshot.py | 6 +- python_modules/dagster/dagster/_core/utils.py | 21 ++-- .../dagster/dagster/_grpc/server.py | 39 +++--- python_modules/dagster/dagster/_grpc/types.py | 3 +- .../dagster/dagster/_utils/__init__.py | 119 +++++------------- .../dagster_tests/core_tests/test_utils.py | 31 ++++- .../execution_tests/test_metadata.py | 33 +++-- .../general_tests/check_tests/test_check.py | 7 -- .../utils_tests/test_frozendict.py | 18 --- .../utils_tests/test_frozenlist.py | 23 ---- .../dagster_celery_k8s/launcher.py | 3 +- .../dagster-dask/dagster_dask/executor.py | 8 +- .../dagster_k8s/container_context.py | 12 +- .../dagster-k8s/dagster_k8s/executor.py | 3 +- .../libraries/dagster-k8s/dagster_k8s/job.py | 21 ++-- .../dagster-k8s/dagster_k8s/models.py | 5 +- .../unit_tests/test_container_context.py | 9 +- 35 files changed, 306 insertions(+), 385 deletions(-) delete mode 100644 python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozendict.py delete mode 100644 python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozenlist.py diff --git a/python_modules/dagster-graphql/dagster_graphql/client/client.py b/python_modules/dagster-graphql/dagster_graphql/client/client.py index 9bcf623033e9c..b42b1c0d97bf0 100644 --- a/python_modules/dagster-graphql/dagster_graphql/client/client.py +++ b/python_modules/dagster-graphql/dagster_graphql/client/client.py @@ -122,7 +122,7 @@ def _core_submit_execution( run_config: Optional[Mapping[str, Any]] = None, mode: Optional[str] = None, preset: Optional[str] = None, - tags: Optional[Dict[str, Any]] = None, + tags: Optional[Mapping[str, str]] = None, solid_selection: Optional[List[str]] = None, is_using_job_op_graph_apis: Optional[bool] = False, ): diff --git a/python_modules/dagster/dagster/_check/__init__.py b/python_modules/dagster/dagster/_check/__init__.py index da4d16892b737..3f80cd1586ee3 100644 --- a/python_modules/dagster/dagster/_check/__init__.py +++ b/python_modules/dagster/dagster/_check/__init__.py @@ -245,11 +245,9 @@ def dict_param( """Ensures argument obj is a native Python dictionary, raises an exception if not, and otherwise returns obj. 
""" - from dagster._utils import frozendict - - if not isinstance(obj, (frozendict, dict)): + if not isinstance(obj, dict): raise _param_type_mismatch_exception( - obj, (frozendict, dict), param_name, additional_message=additional_message + obj, dict, param_name, additional_message=additional_message ) if not (key_type or value_type): @@ -268,12 +266,8 @@ def opt_dict_param( """Ensures argument obj is either a dictionary or None; if the latter, instantiates an empty dictionary. """ - from dagster._utils import frozendict - - if obj is not None and not isinstance(obj, (frozendict, dict)): - raise _param_type_mismatch_exception( - obj, (frozendict, dict), param_name, additional_message - ) + if obj is not None and not isinstance(obj, dict): + raise _param_type_mismatch_exception(obj, dict, param_name, additional_message) if not obj: return {} @@ -312,12 +306,8 @@ def opt_nullable_dict_param( additional_message: Optional[str] = None, ) -> Optional[Dict]: """Ensures argument obj is either a dictionary or None.""" - from dagster._utils import frozendict - - if obj is not None and not isinstance(obj, (frozendict, dict)): - raise _param_type_mismatch_exception( - obj, (frozendict, dict), param_name, additional_message - ) + if obj is not None and not isinstance(obj, dict): + raise _param_type_mismatch_exception(obj, dict, param_name, additional_message) if not obj: return None if obj is None else {} @@ -361,8 +351,6 @@ def dict_elem( value_type: Optional[TypeOrTupleOfTypes] = None, additional_message: Optional[str] = None, ) -> Dict: - from dagster._utils import frozendict - dict_param(obj, "obj") str_param(key, "key") @@ -370,8 +358,8 @@ def dict_elem( raise CheckError(f"{key} not present in dictionary {obj}") value = obj[key] - if not isinstance(value, (frozendict, dict)): - raise _element_check_error(key, value, obj, (frozendict, dict), additional_message) + if not isinstance(value, dict): + raise _element_check_error(key, value, obj, dict, additional_message) else: return _check_mapping_entries(value, key_type, value_type, mapping_type=dict) @@ -383,8 +371,6 @@ def opt_dict_elem( value_type: Optional[TypeOrTupleOfTypes] = None, additional_message: Optional[str] = None, ) -> Dict: - from dagster._utils import frozendict - dict_param(obj, "obj") str_param(key, "key") @@ -392,7 +378,7 @@ def opt_dict_elem( if value is None: return {} - elif not isinstance(value, (frozendict, dict)): + elif not isinstance(value, dict): raise _element_check_error(key, value, obj, dict, additional_message) else: return _check_mapping_entries(value, key_type, value_type, mapping_type=dict) @@ -405,8 +391,6 @@ def opt_nullable_dict_elem( value_type: Optional[TypeOrTupleOfTypes] = None, additional_message: Optional[str] = None, ) -> Optional[Dict]: - from dagster._utils import frozendict - dict_param(obj, "obj") str_param(key, "key") @@ -414,7 +398,7 @@ def opt_nullable_dict_elem( if value is None: return None - elif not isinstance(value, (frozendict, dict)): + elif not isinstance(value, dict): raise _element_check_error(key, value, obj, dict, additional_message) else: return _check_mapping_entries(value, key_type, value_type, mapping_type=dict) @@ -446,10 +430,8 @@ def is_dict( value_type: Optional[TypeOrTupleOfTypes] = None, additional_message: Optional[str] = None, ) -> Dict: - from dagster._utils import frozendict - - if not isinstance(obj, (frozendict, dict)): - raise _type_mismatch_error(obj, (frozendict, dict), additional_message) + if not isinstance(obj, dict): + raise _type_mismatch_error(obj, dict, 
additional_message) if not (key_type or value_type): return obj @@ -768,12 +750,8 @@ def list_param( of_type: Optional[TypeOrTupleOfTypes] = None, additional_message: Optional[str] = None, ) -> List[Any]: - from dagster._utils import frozenlist - - if not isinstance(obj, (frozenlist, list)): - raise _param_type_mismatch_exception( - obj, (frozenlist, list), param_name, additional_message - ) + if not isinstance(obj, list): + raise _param_type_mismatch_exception(obj, list, param_name, additional_message) if not of_type: return obj @@ -793,12 +771,8 @@ def opt_list_param( If the of_type argument is provided, also ensures that list items conform to the type specified by of_type. """ - from dagster._utils import frozenlist - - if obj is not None and not isinstance(obj, (frozenlist, list)): - raise _param_type_mismatch_exception( - obj, (frozenlist, list), param_name, additional_message - ) + if obj is not None and not isinstance(obj, list): + raise _param_type_mismatch_exception(obj, list, param_name, additional_message) if not obj: return [] @@ -840,12 +814,8 @@ def opt_nullable_list_param( If the of_type argument is provided, also ensures that list items conform to the type specified by of_type. """ - from dagster._utils import frozenlist - - if obj is not None and not isinstance(obj, (frozenlist, list)): - raise _param_type_mismatch_exception( - obj, (frozenlist, list), param_name, additional_message - ) + if obj is not None and not isinstance(obj, list): + raise _param_type_mismatch_exception(obj, list, param_name, additional_message) if not obj: return None if obj is None else [] diff --git a/python_modules/dagster/dagster/_config/post_process.py b/python_modules/dagster/dagster/_config/post_process.py index e8d21ee5dc7e8..adb1c44cac70d 100644 --- a/python_modules/dagster/dagster/_config/post_process.py +++ b/python_modules/dagster/dagster/_config/post_process.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Mapping, Optional, cast import dagster._check as check -from dagster._utils import ensure_single_item, frozendict, frozenlist +from dagster._utils import ensure_single_item from dagster._utils.error import serializable_error_info_from_exc_info from .config_type import ConfigType, ConfigTypeKind @@ -122,7 +122,7 @@ def _recurse_in_to_selector( else incoming_field_value, ) if field_evr.success: - return EvaluateValueResult.for_value(frozendict({field_name: field_evr.value})) + return EvaluateValueResult.for_value({field_name: field_evr.value}) return field_evr @@ -183,7 +183,7 @@ def _recurse_in_to_shape( return EvaluateValueResult.for_errors(errors) return EvaluateValueResult.for_value( - frozendict({key: result.value for key, result in processed_fields.items()}) + {key: result.value for key, result in processed_fields.items()} ) @@ -210,7 +210,7 @@ def _recurse_in_to_array(context: TraversalContext, config_value: Any) -> Evalua if errors: return EvaluateValueResult.for_errors(errors) - return EvaluateValueResult.for_value(frozenlist([result.value for result in results])) + return EvaluateValueResult.for_value([result.value for result in results]) def _recurse_in_to_map(context: TraversalContext, config_value: Any) -> EvaluateValueResult[Any]: @@ -243,6 +243,4 @@ def _recurse_in_to_map(context: TraversalContext, config_value: Any) -> Evaluate if errors: return EvaluateValueResult.for_errors(errors) - return EvaluateValueResult.for_value( - frozendict({key: result.value for key, result in results.items()}) - ) + return EvaluateValueResult.for_value({key: result.value for key, 
result in results.items()}) diff --git a/python_modules/dagster/dagster/_config/validate.py b/python_modules/dagster/dagster/_config/validate.py index 779af934da644..c9c3415e53915 100644 --- a/python_modules/dagster/dagster/_config/validate.py +++ b/python_modules/dagster/dagster/_config/validate.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, TypeVar, cast import dagster._check as check -from dagster._utils import ensure_single_item, frozendict +from dagster._utils import ensure_single_item from .config_type import ConfigScalarKind, ConfigType, ConfigTypeKind from .errors import ( @@ -210,7 +210,7 @@ def validate_selector_config( if child_evaluate_value_result.success: return EvaluateValueResult.for_value( # type: ignore - frozendict({field_name: child_evaluate_value_result.value}) + {field_name: child_evaluate_value_result.value} ) else: return child_evaluate_value_result @@ -289,7 +289,7 @@ def _validate_shape_config( if errors: return EvaluateValueResult.for_errors(errors) else: - return EvaluateValueResult.for_value(frozendict(config_value)) # type: ignore + return EvaluateValueResult.for_value(config_value) # type: ignore def validate_permissive_shape_config( @@ -304,7 +304,7 @@ def validate_permissive_shape_config( def validate_map_config( context: ValidationContext, config_value: object -) -> EvaluateValueResult[Mapping[str, object]]: +) -> EvaluateValueResult[Mapping[object, object]]: check.inst_param(context, "context", ValidationContext) check.invariant(context.config_type_snap.kind == ConfigTypeKind.MAP) check.not_none_param(config_value, "config_value") @@ -325,7 +325,7 @@ def validate_map_config( if not result.success: errors += cast(List, result.errors) - return EvaluateValueResult(not bool(errors), frozendict(config_value), errors) + return EvaluateValueResult(not bool(errors), config_value, errors) def validate_shape_config( diff --git a/python_modules/dagster/dagster/_core/code_pointer.py b/python_modules/dagster/dagster/_core/code_pointer.py index e42c5e66ddffc..9b83a86d20b0f 100644 --- a/python_modules/dagster/dagster/_core/code_pointer.py +++ b/python_modules/dagster/dagster/_core/code_pointer.py @@ -10,7 +10,7 @@ from dagster._core.errors import DagsterImportError, DagsterInvariantViolationError from dagster._serdes import whitelist_for_serdes from dagster._seven import get_import_error_message, import_module_from_path -from dagster._utils import alter_sys_path, frozenlist +from dagster._utils import alter_sys_path, hash_collection class CodePointer(ABC): @@ -296,13 +296,6 @@ def __new__( ), ) - # These are frozenlists, rather than lists, so that they can be hashed and the pointer - # stored in the lru_cache on the repository and pipeline get_definition methods - reconstructable_args = frozenlist(reconstructable_args) - reconstructable_kwargs = frozenlist( - [frozenlist(reconstructable_kwarg) for reconstructable_kwarg in reconstructable_kwargs] - ) - return super(CustomPointer, cls).__new__( cls, reconstructor_pointer, @@ -321,3 +314,13 @@ def describe(self) -> str: return "reconstructable using {module}.{fn_name}".format( module=self.reconstructor_pointer.module, fn_name=self.reconstructor_pointer.fn_name ) + + # Allow this to be hashed for use in `lru_cache`. 
This is needed because: + # - `ReconstructablePipeline` uses `lru_cache` + # - `ReconstructablePipeline` has a `ReconstructableRepository` attribute + # - `ReconstructableRepository` has a `CodePointer` attribute + # - `CustomPointer` has collection attributes that are unhashable by default + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + self._hash = hash_collection(self) + return self._hash diff --git a/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py b/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py index 68688853d3e0c..954cd34f9d037 100644 --- a/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py +++ b/python_modules/dagster/dagster/_core/definitions/cacheable_assets.py @@ -14,7 +14,7 @@ from dagster._core.definitions.resource_definition import ResourceDefinition from dagster._core.definitions.resource_requirement import ResourceAddable from dagster._serdes import whitelist_for_serdes -from dagster._utils import frozendict, frozenlist, make_readonly_value +from dagster._utils import hash_collection @whitelist_for_serdes @@ -45,36 +45,11 @@ def __new__( internal_asset_deps: Optional[Mapping[str, AbstractSet[AssetKey]]] = None, group_name: Optional[str] = None, metadata_by_output_name: Optional[Mapping[str, MetadataUserInput]] = None, - key_prefix: Optional[CoercibleToAssetKeyPrefix] = None, + key_prefix: Optional[Sequence[str]] = None, can_subset: bool = False, extra_metadata: Optional[Mapping[Any, Any]] = None, freshness_policies_by_output_name: Optional[Mapping[str, FreshnessPolicy]] = None, ): - keys_by_input_name = check.opt_nullable_mapping_param( - keys_by_input_name, "keys_by_input_name", key_type=str, value_type=AssetKey - ) - - keys_by_output_name = check.opt_nullable_mapping_param( - keys_by_output_name, "keys_by_output_name", key_type=str, value_type=AssetKey - ) - - internal_asset_deps = check.opt_nullable_mapping_param( - internal_asset_deps, "internal_asset_deps", key_type=str, value_type=(set, frozenset) - ) - - metadata_by_output_name = check.opt_nullable_mapping_param( - metadata_by_output_name, "metadata_by_output_name", key_type=str, value_type=dict - ) - - freshness_policies_by_output_name = check.opt_nullable_mapping_param( - freshness_policies_by_output_name, - "freshness_policies_by_output_name", - key_type=str, - value_type=FreshnessPolicy, - ) - - key_prefix = check.opt_inst_param(key_prefix, "key_prefix", (str, list)) - extra_metadata = check.opt_nullable_mapping_param(extra_metadata, "extra_metadata") try: # check that the value is JSON serializable @@ -84,25 +59,46 @@ def __new__( return super().__new__( cls, - keys_by_input_name=frozendict(keys_by_input_name) if keys_by_input_name else None, - keys_by_output_name=frozendict(keys_by_output_name) if keys_by_output_name else None, - internal_asset_deps=frozendict( - {k: frozenset(v) for k, v in internal_asset_deps.items()} - ) - if internal_asset_deps - else None, + keys_by_input_name=check.opt_nullable_mapping_param( + keys_by_input_name, "keys_by_input_name", key_type=str, value_type=AssetKey + ), + keys_by_output_name=check.opt_nullable_mapping_param( + keys_by_output_name, "keys_by_output_name", key_type=str, value_type=AssetKey + ), + internal_asset_deps=check.opt_nullable_mapping_param( + internal_asset_deps, + "internal_asset_deps", + key_type=str, + value_type=(set, frozenset), + ), group_name=check.opt_str_param(group_name, "group_name"), - metadata_by_output_name=make_readonly_value(metadata_by_output_name) - if 
metadata_by_output_name - else None, - key_prefix=frozenlist(key_prefix) if key_prefix else None, + metadata_by_output_name=check.opt_nullable_mapping_param( + metadata_by_output_name, "metadata_by_output_name", key_type=str + ), + key_prefix=[key_prefix] + if isinstance(key_prefix, str) + else check.opt_list_param(key_prefix, "key_prefix", of_type=str), can_subset=check.opt_bool_param(can_subset, "can_subset", default=False), - extra_metadata=make_readonly_value(extra_metadata) if extra_metadata else None, - freshness_policies_by_output_name=frozendict(freshness_policies_by_output_name) - if freshness_policies_by_output_name - else None, + extra_metadata=extra_metadata, + freshness_policies_by_output_name=check.opt_nullable_mapping_param( + freshness_policies_by_output_name, + "freshness_policies_by_output_name", + key_type=str, + value_type=FreshnessPolicy, + ), ) + # Allow this to be hashed for use in `lru_cache`. This is needed because: + # - `ReconstructablePipeline` uses `lru_cache` + # - `ReconstructablePipeline` has a `ReconstructableRepository` attribute + # - `ReconstructableRepository` has a `RepositoryLoadData` attribute + # - `RepositoryLoadData` has a `Mapping` attribute containing `AssetsDefinitionCacheableData` + # - `AssetsDefinitionCacheableData` has collection attributes that are unhashable by default + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + self._hash = hash_collection(self) + return self._hash + class CacheableAssetsDefinition(ResourceAddable, ABC): def __init__(self, unique_id: str): diff --git a/python_modules/dagster/dagster/_core/definitions/composition.py b/python_modules/dagster/dagster/_core/definitions/composition.py index 866c3e5218568..4d9c8e97c1514 100644 --- a/python_modules/dagster/dagster/_core/definitions/composition.py +++ b/python_modules/dagster/dagster/_core/definitions/composition.py @@ -26,7 +26,6 @@ DagsterInvalidInvocationError, DagsterInvariantViolationError, ) -from dagster._utils import frozentags from .config import ConfigMapping from .dependency import ( @@ -231,7 +230,7 @@ def observe_invocation( given_alias: Optional[str], node_def: NodeDefinition, input_bindings: Mapping[str, InputSource], - tags: Optional[frozentags], + tags: Optional[Mapping[str, str]], hook_defs: Optional[AbstractSet[HookDefinition]], retry_policy: Optional[RetryPolicy], ) -> str: @@ -396,7 +395,7 @@ def the_graph(): node_def: NodeDefinition given_alias: Optional[str] - tags: Optional[frozentags] + tags: Optional[Mapping[str, str]] hook_defs: AbstractSet[HookDefinition] retry_policy: Optional[RetryPolicy] @@ -404,13 +403,13 @@ def __init__( self, node_def: NodeDefinition, given_alias: Optional[str], - tags: Optional[frozentags], + tags: Optional[Mapping[str, str]], hook_defs: Optional[AbstractSet[HookDefinition]], retry_policy: Optional[RetryPolicy], ): self.node_def = check.inst_param(node_def, "node_def", NodeDefinition) self.given_alias = check.opt_str_param(given_alias, "given_alias") - self.tags = check.opt_inst_param(tags, "tags", frozentags) + self.tags = check.opt_mapping_param(tags, "tags", key_type=str, value_type=str) self.hook_defs = check.opt_set_param(hook_defs, "hook_defs", HookDefinition) self.retry_policy = check.opt_inst_param(retry_policy, "retry_policy", RetryPolicy) @@ -665,7 +664,7 @@ def tag(self, tags: Optional[Mapping[str, str]]) -> "PendingNodeInvocation": return PendingNodeInvocation( node_def=self.node_def, given_alias=self.given_alias, - tags=frozentags(tags) if self.tags is None else self.tags.updated_with(tags), + 
tags={**(self.tags or {}), **tags}, hook_defs=self.hook_defs, retry_policy=self.retry_policy, ) @@ -725,7 +724,7 @@ def to_job( description=description, resource_defs=resource_defs, config=config, - tags=tags if not self.tags else self.tags.updated_with(tags), + tags={**(self.tags or {}), **tags}, logger_defs=logger_defs, executor_def=executor_def, hooks=job_hooks, @@ -783,7 +782,7 @@ class InvokedNode(NamedTuple): node_name: str node_def: NodeDefinition input_bindings: Mapping[str, InputSource] - tags: Optional[frozentags] + tags: Optional[Mapping[str, str]] hook_defs: Optional[AbstractSet[HookDefinition]] retry_policy: Optional[RetryPolicy] diff --git a/python_modules/dagster/dagster/_core/definitions/data_time.py b/python_modules/dagster/dagster/_core/definitions/data_time.py index 8f72743f7501a..1d0b842d494c7 100644 --- a/python_modules/dagster/dagster/_core/definitions/data_time.py +++ b/python_modules/dagster/dagster/_core/definitions/data_time.py @@ -14,7 +14,7 @@ from dagster._core.event_api import EventLogRecord from dagster._core.instance import DagsterInstance from dagster._core.storage.pipeline_run import FINISHED_STATUSES, DagsterRunStatus, RunsFilter -from dagster._utils import frozendict +from dagster._utils import make_hashable from dagster._utils.cached_method import cached_method from dagster._utils.caching_instance_queryer import CachingInstanceQueryer @@ -168,8 +168,9 @@ def _calculate_used_data_unpartitioned( asset_key: AssetKey, record_id: int, record_timestamp: Optional[float], - record_tags: Mapping[str, str], + record_tags: Tuple[Tuple[str, str]], # for hashability ) -> Mapping[AssetKey, Tuple[Optional[int], Optional[float]]]: + record_tags_dict = dict(record_tags) if record_id is None: return {key: (None, None) for key in asset_graph.get_non_source_roots(asset_key)} @@ -184,10 +185,10 @@ def _calculate_used_data_unpartitioned( continue input_event_pointer_tag = get_input_event_pointer_tag(parent_key) - if input_event_pointer_tag in record_tags: + if input_event_pointer_tag in record_tags_dict: # get the upstream materialization event which was consumed when producing this # materialization event - pointer_tag = record_tags[input_event_pointer_tag] + pointer_tag = record_tags_dict[input_event_pointer_tag] if pointer_tag and pointer_tag != "NULL": input_record_id = int(pointer_tag) parent_record = self._instance_queryer.get_latest_materialization_record( @@ -209,7 +210,7 @@ def _calculate_used_data_unpartitioned( asset_key=parent_key, record_id=parent_record.storage_id if parent_record else None, record_timestamp=parent_record.event_log_entry.timestamp if parent_record else None, - record_tags=frozendict( + record_tags=make_hashable( ( parent_record.asset_materialization.tags if parent_record and parent_record.asset_materialization @@ -233,7 +234,7 @@ def _calculate_used_data( asset_key: AssetKey, record_id: Optional[int], record_timestamp: Optional[float], - record_tags: Mapping[str, str], + record_tags: Tuple[Tuple[str, str]], # for hashability ) -> Mapping[AssetKey, Tuple[Optional[int], Optional[float]]]: if record_id is None: return {key: (None, None) for key in asset_graph.get_non_source_roots(asset_key)} @@ -279,7 +280,7 @@ def get_used_data_times_for_record( asset_key=record.asset_key, record_id=record.storage_id, record_timestamp=record.event_log_entry.timestamp, - record_tags=frozendict(record.asset_materialization.tags or {}), + record_tags=make_hashable(record.asset_materialization.tags or {}), ) return { @@ -330,7 +331,7 @@ def 
_get_in_progress_data_times_for_key_in_run( asset_key=asset_key, record_id=latest_record.storage_id if latest_record else None, record_timestamp=latest_record.event_log_entry.timestamp if latest_record else None, - record_tags=frozendict( + record_tags=make_hashable( ( latest_record.asset_materialization.tags if latest_record and latest_record.asset_materialization diff --git a/python_modules/dagster/dagster/_core/definitions/dependency.py b/python_modules/dagster/dagster/_core/definitions/dependency.py index fe810ef3babd4..9f718e7dc6c3b 100644 --- a/python_modules/dagster/dagster/_core/definitions/dependency.py +++ b/python_modules/dagster/dagster/_core/definitions/dependency.py @@ -30,7 +30,7 @@ from dagster._serdes.serdes import ( whitelist_for_serdes, ) -from dagster._utils import frozentags +from dagster._utils import hash_collection from .hook_definition import HookDefinition from .input import FanInInputPointer, InputDefinition, InputMapping, InputPointer @@ -98,13 +98,17 @@ def __new__( cls, name=check.str_param(name, "name"), alias=check.opt_str_param(alias, "alias"), - tags=frozentags(check.opt_mapping_param(tags, "tags", value_type=str, key_type=str)), - hook_defs=frozenset( - check.opt_set_param(hook_defs, "hook_defs", of_type=HookDefinition) - ), + tags=check.opt_mapping_param(tags, "tags", value_type=str, key_type=str), + hook_defs=check.opt_set_param(hook_defs, "hook_defs", of_type=HookDefinition), retry_policy=check.opt_inst_param(retry_policy, "retry_policy", RetryPolicy), ) + # Needs to be hashable because this class is used as a key in dependencies dicts + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + self._hash = hash_collection(self) + return self._hash + class Node(ABC): """Node invocation within a graph. Identified by its name inside the graph.""" @@ -185,9 +189,8 @@ def output_dict(self) -> Mapping[str, OutputDefinition]: return self.definition.output_dict @property - def tags(self) -> frozentags: - # Type-ignore temporarily pending assessment of right data structure for `tags` - return self.definition.tags.updated_with(self._additional_tags) # type: ignore + def tags(self) -> Mapping[str, str]: + return {**self.definition.tags, **self._additional_tags} def container_maps_input(self, input_name: str) -> bool: return ( diff --git a/python_modules/dagster/dagster/_core/definitions/metadata/table.py b/python_modules/dagster/dagster/_core/definitions/metadata/table.py index 832e0964f9673..cf2f60b78b8e9 100644 --- a/python_modules/dagster/dagster/_core/definitions/metadata/table.py +++ b/python_modules/dagster/dagster/_core/definitions/metadata/table.py @@ -5,7 +5,6 @@ from dagster._serdes.serdes import ( whitelist_for_serdes, ) -from dagster._utils import frozenlist # ######################## # ##### TABLE RECORD @@ -109,7 +108,7 @@ def __new__( ): return super(TableSchema, cls).__new__( cls, - columns=frozenlist(check.sequence_param(columns, "columns", of_type=TableColumn)), + columns=check.sequence_param(columns, "columns", of_type=TableColumn), constraints=check.opt_inst_param( constraints, "constraints", TableConstraints, default=_DEFAULT_TABLE_CONSTRAINTS ), @@ -157,7 +156,7 @@ def __new__( ): return super(TableConstraints, cls).__new__( cls, - other=frozenlist(check.sequence_param(other, "other", of_type=str)), + other=check.sequence_param(other, "other", of_type=str), ) @@ -255,7 +254,7 @@ def __new__( cls, nullable=check.bool_param(nullable, "nullable"), unique=check.bool_param(unique, "unique"), - other=frozenlist(check.opt_sequence_param(other, 
"other")), + other=check.opt_sequence_param(other, "other"), ) diff --git a/python_modules/dagster/dagster/_core/definitions/node_definition.py b/python_modules/dagster/dagster/_core/definitions/node_definition.py index 3b118571a2597..d12eb0e17c123 100644 --- a/python_modules/dagster/dagster/_core/definitions/node_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/node_definition.py @@ -13,7 +13,6 @@ import dagster._check as check from dagster._core.definitions.configurable import NamedConfigurableDefinition from dagster._core.definitions.policy import RetryPolicy -from dagster._utils import frozendict, frozenlist from .hook_definition import HookDefinition from .utils import check_valid_name, validate_tags @@ -54,11 +53,11 @@ def __init__( self._name = check_valid_name(name) self._description = check.opt_str_param(description, "description") self._tags = validate_tags(tags) - self._input_defs = frozenlist(input_defs) - self._input_dict = frozendict({input_def.name: input_def for input_def in input_defs}) + self._input_defs = input_defs + self._input_dict = {input_def.name: input_def for input_def in input_defs} check.invariant(len(self._input_defs) == len(self._input_dict), "Duplicate input def names") - self._output_defs = frozenlist(output_defs) - self._output_dict = frozendict({output_def.name: output_def for output_def in output_defs}) + self._output_defs = output_defs + self._output_dict = {output_def.name: output_def for output_def in output_defs} check.invariant( len(self._output_defs) == len(self._output_dict), "Duplicate output def names" ) diff --git a/python_modules/dagster/dagster/_core/definitions/partition.py b/python_modules/dagster/dagster/_core/definitions/partition.py index 3cf28cf12eadd..ef329914332b5 100644 --- a/python_modules/dagster/dagster/_core/definitions/partition.py +++ b/python_modules/dagster/dagster/_core/definitions/partition.py @@ -38,7 +38,6 @@ from dagster._core.storage.tags import PARTITION_NAME_TAG from dagster._serdes import whitelist_for_serdes from dagster._seven.compat.pendulum import PendulumDateTime, to_timezone -from dagster._utils import frozenlist from dagster._utils.backcompat import deprecation_warning, experimental_arg_warning from dagster._utils.cached_method import cached_method from dagster._utils.merger import merge_dicts @@ -961,9 +960,7 @@ def _execution_fn(context): return selected_partitions = ( - selector_result - if isinstance(selector_result, (frozenlist, list)) - else [selector_result] + selector_result if isinstance(selector_result, list) else [selector_result] ) check.is_list(selected_partitions, of_type=Partition) diff --git a/python_modules/dagster/dagster/_core/definitions/pipeline_definition.py b/python_modules/dagster/dagster/_core/definitions/pipeline_definition.py index b116ca70914dc..baefcab328f07 100644 --- a/python_modules/dagster/dagster/_core/definitions/pipeline_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/pipeline_definition.py @@ -27,7 +27,6 @@ from dagster._core.storage.tags import MEMOIZED_RUN_TAG from dagster._core.types.dagster_type import DagsterType from dagster._core.utils import str_format_set -from dagster._utils import frozentags from dagster._utils.backcompat import experimental_class_warning from dagster._utils.merger import merge_dicts @@ -357,7 +356,7 @@ def describe_target(self) -> str: @property def tags(self) -> Mapping[str, str]: - return frozentags(**merge_dicts(self._graph_def.tags, self._tags)) + return merge_dicts(self._graph_def.tags, self._tags) 
@property def metadata(self) -> Sequence[MetadataEntry]: diff --git a/python_modules/dagster/dagster/_core/definitions/reconstruct.py b/python_modules/dagster/dagster/_core/definitions/reconstruct.py index b13cda3a20256..adfa17650ef7d 100644 --- a/python_modules/dagster/dagster/_core/definitions/reconstruct.py +++ b/python_modules/dagster/dagster/_core/definitions/reconstruct.py @@ -40,7 +40,7 @@ ) from dagster._core.selector import parse_solid_selection from dagster._serdes import pack_value, unpack_value, whitelist_for_serdes -from dagster._utils import frozenlist, make_readonly_value +from dagster._utils import hash_collection from .events import AssetKey from .pipeline_base import IPipeline @@ -94,12 +94,12 @@ def __new__( container_image=check.opt_str_param(container_image, "container_image"), executable_path=check.opt_str_param(executable_path, "executable_path"), entry_point=( - frozenlist(check.sequence_param(entry_point, "entry_point", of_type=str)) + check.sequence_param(entry_point, "entry_point", of_type=str) if entry_point is not None else DEFAULT_DAGSTER_ENTRY_POINT ), container_context=( - make_readonly_value(check.mapping_param(container_context, "container_context")) + check.mapping_param(container_context, "container_context") if container_context is not None else None ), @@ -163,6 +163,15 @@ def get_python_origin(self) -> RepositoryPythonOrigin: def get_python_origin_id(self) -> str: return self.get_python_origin().get_id() + # Allow this to be hashed for use in `lru_cache`. This is needed because: + # - `ReconstructablePipeline` uses `lru_cache` + # - `ReconstructablePipeline` has a `ReconstructableRepository` attribute + # - `ReconstructableRepository` has `Sequence` attributes that are unhashable by default + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + self._hash = hash_collection(self) + return self._hash + @whitelist_for_serdes class ReconstructablePipeline( diff --git a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py index 6ba3f8b6365aa..ee9d2056d1a0f 100644 --- a/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py +++ b/python_modules/dagster/dagster/_core/definitions/repository_definition/repository_definition.py @@ -34,7 +34,7 @@ from dagster._core.instance import DagsterInstance from dagster._core.selector import parse_solid_selection from dagster._serdes import whitelist_for_serdes -from dagster._utils import make_readonly_value +from dagster._utils import hash_collection from .repository_data import CachingRepositoryData, RepositoryData from .valid_definitions import ( @@ -62,7 +62,7 @@ class RepositoryLoadData( def __new__(cls, cached_data_by_key: Mapping[str, Sequence[AssetsDefinitionCacheableData]]): return super(RepositoryLoadData, cls).__new__( cls, - cached_data_by_key=make_readonly_value( + cached_data_by_key=( check.mapping_param( cached_data_by_key, "cached_data_by_key", @@ -72,6 +72,16 @@ def __new__(cls, cached_data_by_key: Mapping[str, Sequence[AssetsDefinitionCache ), ) + # Allow this to be hashed for use in `lru_cache`. 
This is needed because: + # - `ReconstructablePipeline` uses `lru_cache` + # - `ReconstructablePipeline` has a `ReconstructableRepository` attribute + # - `ReconstructableRepository` has a `RepositoryLoadData` attribute + # - `RepositoryLoadData` has collection attributes that are unhashable by default + def __hash__(self) -> int: + if not hasattr(self, "_hash"): + self._hash = hash_collection(self) + return self._hash + class RepositoryDefinition: """Define a repository that contains a group of definitions. diff --git a/python_modules/dagster/dagster/_core/definitions/utils.py b/python_modules/dagster/dagster/_core/definitions/utils.py index c087817044d14..b935b3411e31a 100644 --- a/python_modules/dagster/dagster/_core/definitions/utils.py +++ b/python_modules/dagster/dagster/_core/definitions/utils.py @@ -10,7 +10,6 @@ import dagster._seven as seven from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError from dagster._core.storage.tags import check_reserved_tags -from dagster._utils import frozentags from dagster._utils.yaml_utils import merge_yaml_strings, merge_yamls DEFAULT_OUTPUT = "result" @@ -92,8 +91,10 @@ def struct_to_string(name: str, **kwargs: object) -> str: return "{name}({props_str})".format(name=name, props_str=props_str) -def validate_tags(tags: Optional[Mapping[str, Any]], allow_reserved_tags=True) -> frozentags: - valid_tags = {} +def validate_tags( + tags: Optional[Mapping[str, Any]], allow_reserved_tags: bool = True +) -> Mapping[str, str]: + valid_tags: Dict[str, str] = {} for key, value in check.opt_mapping_param(tags, "tags", key_type=str).items(): if not isinstance(value, str): valid = False @@ -118,14 +119,14 @@ def validate_tags(tags: Optional[Mapping[str, Any]], allow_reserved_tags=True) - ) ) - valid_tags[key] = str_val + valid_tags[key] = str_val # type: ignore # (possible none) else: valid_tags[key] = value if not allow_reserved_tags: check_reserved_tags(valid_tags) - return frozentags(valid_tags) + return valid_tags def validate_group_name(group_name: Optional[str]) -> str: diff --git a/python_modules/dagster/dagster/_core/instance/__init__.py b/python_modules/dagster/dagster/_core/instance/__init__.py index cbf6cebf8d7af..d75b353cbda79 100644 --- a/python_modules/dagster/dagster/_core/instance/__init__.py +++ b/python_modules/dagster/dagster/_core/instance/__init__.py @@ -71,7 +71,7 @@ from dagster._core.utils import str_format_list from dagster._serdes import ConfigurableClass from dagster._seven import get_current_datetime_in_utc -from dagster._utils import PrintFn, frozentags, traced +from dagster._utils import PrintFn, traced from dagster._utils.backcompat import deprecation_warning, experimental_functionality_warning from dagster._utils.error import serializable_error_info_from_exc_info from dagster._utils.merger import merge_dicts @@ -1043,7 +1043,7 @@ def _construct_run_with_snapshots( solids_to_execute: Optional[AbstractSet[str]], step_keys_to_execute: Optional[Sequence[str]], status: Optional[DagsterRunStatus], - tags: frozentags, + tags: Mapping[str, str], root_run_id: Optional[str], parent_run_id: Optional[str], pipeline_snapshot: Optional[PipelineSnapshot], @@ -1057,7 +1057,10 @@ def _construct_run_with_snapshots( # https://github.com/dagster-io/dagster/issues/2403 if tags and IS_AIRFLOW_INGEST_PIPELINE_STR in tags: if AIRFLOW_EXECUTION_DATE_STR not in tags: - tags[AIRFLOW_EXECUTION_DATE_STR] = get_current_datetime_in_utc().isoformat() + tags = { + **tags, + AIRFLOW_EXECUTION_DATE_STR: 
get_current_datetime_in_utc().isoformat(), + } check.invariant( not (not pipeline_snapshot and execution_plan_snapshot), @@ -1360,7 +1363,7 @@ def create_run( solids_to_execute=solids_to_execute, step_keys_to_execute=step_keys_to_execute, status=status, - tags=dict(validated_tags), # type: ignore + tags=validated_tags, root_run_id=root_run_id, parent_run_id=parent_run_id, pipeline_snapshot=pipeline_snapshot, @@ -1479,7 +1482,7 @@ def register_managed_run( mode: Optional[str], solids_to_execute: Optional[AbstractSet[str]], step_keys_to_execute: Optional[Sequence[str]], - tags: frozentags, + tags: Mapping[str, str], root_run_id: Optional[str], parent_run_id: Optional[str], pipeline_snapshot: Optional[PipelineSnapshot], diff --git a/python_modules/dagster/dagster/_core/origin.py b/python_modules/dagster/dagster/_core/origin.py index d13b41310d22c..802b4ad98a8c7 100644 --- a/python_modules/dagster/dagster/_core/origin.py +++ b/python_modules/dagster/dagster/_core/origin.py @@ -1,15 +1,16 @@ -from typing import Any, List, Mapping, NamedTuple, Optional, Sequence +from typing import Any, Mapping, NamedTuple, Optional, Sequence + +from typing_extensions import Final import dagster._check as check from dagster._core.code_pointer import CodePointer from dagster._serdes import create_snapshot_id, whitelist_for_serdes -from dagster._utils import frozenlist -DEFAULT_DAGSTER_ENTRY_POINT = frozenlist(["dagster"]) +DEFAULT_DAGSTER_ENTRY_POINT: Final = ["dagster"] -def get_python_environment_entry_point(executable_path: str) -> List[str]: - return frozenlist([executable_path, "-m", "dagster"]) +def get_python_environment_entry_point(executable_path: str) -> Sequence[str]: + return [executable_path, "-m", "dagster"] @whitelist_for_serdes @@ -52,7 +53,7 @@ def __new__( check.inst_param(code_pointer, "code_pointer", CodePointer), check.opt_str_param(container_image, "container_image"), ( - frozenlist(check.sequence_param(entry_point, "entry_point", of_type=str)) + check.sequence_param(entry_point, "entry_point", of_type=str) if entry_point is not None else None ), diff --git a/python_modules/dagster/dagster/_core/snap/dep_snapshot.py b/python_modules/dagster/dagster/_core/snap/dep_snapshot.py index cdd08e042536c..5d4fb08e32243 100644 --- a/python_modules/dagster/dagster/_core/snap/dep_snapshot.py +++ b/python_modules/dagster/dagster/_core/snap/dep_snapshot.py @@ -217,7 +217,7 @@ class SolidInvocationSnap( [ ("solid_name", str), ("solid_def_name", str), - ("tags", Mapping[object, object]), + ("tags", Mapping[str, str]), ("input_dep_snaps", Sequence[InputDependencySnap]), ("is_dynamic_mapped", bool), ], @@ -227,7 +227,7 @@ def __new__( cls, solid_name: str, solid_def_name: str, - tags: Mapping[object, object], + tags: Mapping[str, str], input_dep_snaps: Sequence[InputDependencySnap], is_dynamic_mapped: bool = False, ): @@ -235,7 +235,7 @@ def __new__( cls, solid_name=check.str_param(solid_name, "solid_name"), solid_def_name=check.str_param(solid_def_name, "solid_def_name"), - tags=check.mapping_param(tags, "tags"), + tags=check.mapping_param(tags, "tags", key_type=str, value_type=str), input_dep_snaps=check.sequence_param( input_dep_snaps, "input_dep_snaps", of_type=InputDependencySnap ), diff --git a/python_modules/dagster/dagster/_core/utils.py b/python_modules/dagster/dagster/_core/utils.py index a141e64695318..026d34bc8470f 100644 --- a/python_modules/dagster/dagster/_core/utils.py +++ b/python_modules/dagster/dagster/_core/utils.py @@ -7,17 +7,20 @@ from typing import AbstractSet, Any, Iterable, Mapping, 
Sequence, Tuple, TypeVar, Union, cast import toposort as toposort_ +from typing_extensions import Final import dagster._check as check -from dagster._utils import frozendict, library_version_from_core_version, parse_package_version +from dagster._utils import library_version_from_core_version, parse_package_version BACKFILL_TAG_LENGTH = 8 -PYTHON_LOGGING_LEVELS_MAPPING = frozendict( - OrderedDict({"CRITICAL": 50, "ERROR": 40, "WARNING": 30, "INFO": 20, "DEBUG": 10}) +PYTHON_LOGGING_LEVELS_MAPPING: Final[Mapping[str, int]] = OrderedDict( + {"CRITICAL": 50, "ERROR": 40, "WARNING": 30, "INFO": 20, "DEBUG": 10} ) -PYTHON_LOGGING_LEVELS_ALIASES = frozendict(OrderedDict({"FATAL": "CRITICAL", "WARN": "WARNING"})) +PYTHON_LOGGING_LEVELS_ALIASES: Final[Mapping[str, str]] = OrderedDict( + {"FATAL": "CRITICAL", "WARN": "WARNING"} +) PYTHON_LOGGING_LEVELS_NAMES = frozenset( [ @@ -35,18 +38,18 @@ def coerce_valid_log_level(log_level: Union[str, int]) -> int: """Convert a log level into an integer for consumption by the low-level Python logging API.""" if isinstance(log_level, int): return log_level - check.str_param(log_level, "log_level") + str_log_level = check.str_param(log_level, "log_level") check.invariant( - log_level.lower() in PYTHON_LOGGING_LEVELS_NAMES, + str_log_level.lower() in PYTHON_LOGGING_LEVELS_NAMES, "Bad value for log level {level}: permissible values are {levels}.".format( - level=log_level, + level=str_log_level, levels=", ".join( ["'{}'".format(level_name.upper()) for level_name in PYTHON_LOGGING_LEVELS_NAMES] ), ), ) - log_level = PYTHON_LOGGING_LEVELS_ALIASES.get(log_level.upper(), log_level.upper()) - return PYTHON_LOGGING_LEVELS_MAPPING[log_level] + str_log_level = PYTHON_LOGGING_LEVELS_ALIASES.get(log_level.upper(), log_level.upper()) + return PYTHON_LOGGING_LEVELS_MAPPING[str_log_level] def toposort(data: Mapping[T, AbstractSet[T]]) -> Sequence[Sequence[T]]: diff --git a/python_modules/dagster/dagster/_grpc/server.py b/python_modules/dagster/dagster/_grpc/server.py index f45a5b67f3de0..807aeae475176 100644 --- a/python_modules/dagster/dagster/_grpc/server.py +++ b/python_modules/dagster/dagster/_grpc/server.py @@ -39,7 +39,6 @@ from dagster._serdes.ipc import IPCErrorMessage, ipc_write_stream, open_ipc_subprocess from dagster._utils import ( find_free_port, - frozenlist, get_run_crash_explanation, safe_tempfile_path_unmanaged, ) @@ -237,7 +236,7 @@ def __init__( self._serializable_load_error = None self._entry_point = ( - frozenlist(check.sequence_param(entry_point, "entry_point", of_type=str)) + check.sequence_param(entry_point, "entry_point", of_type=str) if entry_point is not None else DEFAULT_DAGSTER_ENTRY_POINT ) @@ -1092,24 +1091,24 @@ def open_server_process( executable_path = loadable_target_origin.executable_path if loadable_target_origin else None - subprocess_args = ( - get_python_environment_entry_point(executable_path or sys.executable) - + ["api", "grpc"] - + ["--lazy-load-user-code"] - + (["--port", str(port)] if port else []) - + (["--socket", socket] if socket else []) - + (["-n", str(max_workers)] if max_workers else []) - + (["--heartbeat"] if heartbeat else []) - + (["--heartbeat-timeout", str(heartbeat_timeout)] if heartbeat_timeout else []) - + (["--fixed-server-id", fixed_server_id] if fixed_server_id else []) - + (["--override-system-timezone", mocked_system_timezone] if mocked_system_timezone else []) - + (["--log-level", log_level]) - # only use the Python environment if it has been explicitly set in the workspace - + 
(["--use-python-environment-entry-point"] if executable_path else []) - + (["--inject-env-vars-from-instance"]) - + (["--instance-ref", serialize_value(instance_ref)]) - + (["--location-name", location_name] if location_name else []) - ) + subprocess_args = [ + *get_python_environment_entry_point(executable_path or sys.executable), + *["api", "grpc"], + *["--lazy-load-user-code"], + *(["--port", str(port)] if port else []), + *(["--socket", socket] if socket else []), + *(["-n", str(max_workers)] if max_workers else []), + *(["--heartbeat"] if heartbeat else []), + *(["--heartbeat-timeout", str(heartbeat_timeout)] if heartbeat_timeout else []), + *(["--fixed-server-id", fixed_server_id] if fixed_server_id else []), + *(["--override-system-timezone", mocked_system_timezone] if mocked_system_timezone else []), + *(["--log-level", log_level]), + # only use the Python environment if it has been explicitly set in the workspace, + *(["--use-python-environment-entry-point"] if executable_path else []), + *(["--inject-env-vars-from-instance"]), + *(["--instance-ref", serialize_value(instance_ref)]), + *(["--location-name", location_name] if location_name else []), + ] if loadable_target_origin: subprocess_args += loadable_target_origin.get_cli_args() diff --git a/python_modules/dagster/dagster/_grpc/types.py b/python_modules/dagster/dagster/_grpc/types.py index eae7b043934f0..4bf945c4b0cc1 100644 --- a/python_modules/dagster/dagster/_grpc/types.py +++ b/python_modules/dagster/dagster/_grpc/types.py @@ -15,7 +15,6 @@ from dagster._core.instance.ref import InstanceRef from dagster._core.origin import PipelinePythonOrigin, get_python_environment_entry_point from dagster._serdes import serialize_value, whitelist_for_serdes -from dagster._utils import frozenlist from dagster._utils.error import SerializableErrorInfo @@ -317,7 +316,7 @@ def __new__( value_type=CodePointer, ), entry_point=( - frozenlist(check.sequence_param(entry_point, "entry_point", of_type=str)) + check.sequence_param(entry_point, "entry_point", of_type=str) if entry_point is not None else None ), diff --git a/python_modules/dagster/dagster/_utils/__init__.py b/python_modules/dagster/dagster/_utils/__init__.py index 6bab21dc9a91f..5af75349987f3 100644 --- a/python_modules/dagster/dagster/_utils/__init__.py +++ b/python_modules/dagster/dagster/_utils/__init__.py @@ -20,18 +20,21 @@ from signal import Signals from typing import ( TYPE_CHECKING, + AbstractSet, Any, Callable, ContextManager, Dict, Generator, Generic, + Hashable, Iterator, List, Mapping, NamedTuple, Optional, Sequence, + Set, Tuple, Type, TypeVar, @@ -211,91 +214,55 @@ def mkdir_p(path: str) -> str: raise -# TODO: Make frozendict generic for type annotations -# https://github.com/dagster-io/dagster/issues/3641 -class frozendict(dict): - def __readonly__(self, *args, **kwargs): - raise RuntimeError("Cannot modify ReadOnlyDict") +def hash_collection( + collection: Union[ + Mapping[Hashable, Any], Sequence[Any], AbstractSet[Any], Tuple[Any, ...], NamedTuple + ] +) -> int: + """Hash a mutable collection or immutable collection containing mutable elements. - # https://docs.python.org/3/library/pickle.html#object.__reduce__ - # - # For a dict, the default behavior for pickle is to iteratively call __setitem__ (see 5th item - # in __reduce__ tuple). Since we want to disable __setitem__ and still inherit dict, we - # override this behavior by defining __reduce__. We return the 3rd item in the tuple, which is - # passed to __setstate__, allowing us to restore the frozendict. 
- - def __reduce__(self): - return (frozendict, (), dict(self)) - - def __setstate__(self, state): - self.__init__(state) - - __setitem__ = __readonly__ - __delitem__ = __readonly__ - pop = __readonly__ - popitem = __readonly__ - clear = __readonly__ - update = __readonly__ - setdefault = __readonly__ # type: ignore[assignment] - del __readonly__ - - def __hash__(self): - return hash(tuple(sorted(self.items()))) - - -class frozenlist(list): - def __readonly__(self, *args, **kwargs): - raise RuntimeError("Cannot modify ReadOnlyList") - - # https://docs.python.org/3/library/pickle.html#object.__reduce__ - # - # Like frozendict, implement __reduce__ and __setstate__ to handle pickling. - # Otherwise, __setstate__ will be called to restore the frozenlist, causing - # a RuntimeError because frozenlist is not mutable. - - def __reduce__(self): - return (frozenlist, (), list(self)) + This is useful for hashing Dagster-specific NamedTuples that contain mutable lists or dicts. + The default NamedTuple __hash__ function assumes the contents of the NamedTuple are themselves + hashable, and will throw an error if they are not. This can occur when trying to e.g. compute a + cache key for the tuple for use with `lru_cache`. - def __setstate__(self, state): - self.__init__(state) + This alternative implementation will recursively process collection elements to convert basic + lists and dicts to tuples prior to hashing. It is recommended to cache the result: - __setitem__ = __readonly__ - __delitem__ = __readonly__ - append = __readonly__ - clear = __readonly__ - extend = __readonly__ - insert = __readonly__ - pop = __readonly__ - remove = __readonly__ - reverse = __readonly__ - sort = __readonly__ # type: ignore[assignment] + Example: + .. code-block:: python - def __hash__(self): - return hash(tuple(self)) + def __hash__(self): + if not hasattr(self, '_hash'): + self._hash = hash_collection(self) + return self._hash + """ + assert isinstance( + collection, (list, dict, set, tuple) + ), f"Cannot hash collection of type {type(collection)}" + return hash(make_hashable(collection)) @overload -def make_readonly_value(value: List[T]) -> Sequence[T]: +def make_hashable(value: Union[List[Any], Set[Any]]) -> Tuple[Any, ...]: ... @overload -def make_readonly_value(value: Dict[T, U]) -> Mapping[T, U]: +def make_hashable(value: Dict[Any, Any]) -> Tuple[Tuple[Any, Any], ...]: ... @overload -def make_readonly_value(value: T) -> T: +def make_hashable(value: Any) -> Any: ... 
-def make_readonly_value(value: Any) -> Any: - if isinstance(value, list): - return frozenlist(list(map(make_readonly_value, value))) - elif isinstance(value, dict): - return frozendict({key: make_readonly_value(value) for key, value in value.items()}) - elif isinstance(value, set): - return frozenset(map(make_readonly_value, value)) +def make_hashable(value: Any) -> Any: + if isinstance(value, dict): + return tuple(sorted((key, make_hashable(value)) for key, value in value.items())) + elif isinstance(value, (list, tuple, set)): + return tuple([make_hashable(x) for x in value]) else: return value @@ -481,24 +448,6 @@ def datetime_as_float(dt: datetime.datetime) -> float: return float((dt - EPOCH).total_seconds()) -# hashable frozen string to string dict -class frozentags(frozendict, Mapping[str, str]): - def __init__(self, *args, **kwargs): - super(frozentags, self).__init__(*args, **kwargs) - check.dict_param(self, "self", key_type=str, value_type=str) - - def __hash__(self): - return hash(tuple(sorted(self.items()))) - - def updated_with(self, new_tags): - check.dict_param(new_tags, "new_tags", key_type=str, value_type=str) - updated = dict(self) - for key, value in new_tags.items(): - updated[key] = value - - return frozentags(updated) - - T_GeneratedContext = TypeVar("T_GeneratedContext") diff --git a/python_modules/dagster/dagster_tests/core_tests/test_utils.py b/python_modules/dagster/dagster_tests/core_tests/test_utils.py index 7415b2f79d916..c552ee2f208e0 100644 --- a/python_modules/dagster/dagster_tests/core_tests/test_utils.py +++ b/python_modules/dagster/dagster_tests/core_tests/test_utils.py @@ -1,11 +1,12 @@ import warnings +from typing import Dict, List, NamedTuple import dagster.version import pytest from dagster._core.libraries import DagsterLibraryRegistry from dagster._core.test_utils import environ from dagster._core.utils import check_dagster_package_version, parse_env_var -from dagster._utils import library_version_from_core_version +from dagster._utils import hash_collection, library_version_from_core_version def test_parse_env_var_no_equals(): @@ -67,3 +68,31 @@ def test_library_version_from_core_version(): def test_library_registry(): assert DagsterLibraryRegistry.get() == {"dagster": dagster.version.__version__} + + +def test_hash_collection(): + # lists have different hashes depending on order + assert hash_collection([1, 2, 3]) == hash_collection([1, 2, 3]) + assert hash_collection([1, 2, 3]) != hash_collection([2, 1, 3]) + + # dicts have same hash regardless of order + assert hash_collection({"a": 1, "b": 2}) == hash_collection({"b": 2, "a": 1}) + + assert hash_collection(set(range(10))) == hash_collection(set(range(10))) + + with pytest.raises(AssertionError): + hash_collection(object()) + + class Foo(NamedTuple): + a: List[int] + b: Dict[str, int] + c: str + + with pytest.raises(Exception): + hash(Foo([1, 2, 3], {"a": 1}, "alpha")) + + class Bar(Foo): + def __hash__(self): + return hash_collection(self) + + assert hash(Bar([1, 2, 3], {"a": 1}, "alpha")) == hash(Bar([1, 2, 3], {"a": 1}, "alpha")) diff --git a/python_modules/dagster/dagster_tests/execution_tests/test_metadata.py b/python_modules/dagster/dagster_tests/execution_tests/test_metadata.py index 3641a058973de..678b21ff34ab8 100644 --- a/python_modules/dagster/dagster_tests/execution_tests/test_metadata.py +++ b/python_modules/dagster/dagster_tests/execution_tests/test_metadata.py @@ -33,7 +33,6 @@ from dagster._core.execution.results import OpExecutionResult, PipelineExecutionResult from 
dagster._legacy import execute_pipeline, pipeline from dagster._serdes.serdes import deserialize_value, serialize_value -from dagster._utils import frozendict def solid_events_for_type( @@ -229,23 +228,21 @@ def test_table_metadata_value_schema_inference(): ] -bad_values = frozendict( - { - "table_schema": {"columns": False, "constraints": False}, - "table_column": { - "name": False, - "type": False, - "description": False, - "constraints": False, - }, - "table_constraints": {"other": False}, - "table_column_constraints": { - "nullable": "foo", - "unique": "foo", - "other": False, - }, - } -) +bad_values = { + "table_schema": {"columns": False, "constraints": False}, + "table_column": { + "name": False, + "type": False, + "description": False, + "constraints": False, + }, + "table_constraints": {"other": False}, + "table_column_constraints": { + "nullable": "foo", + "unique": "foo", + "other": False, + }, +} def test_table_column_keys(): diff --git a/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py b/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py index e9490895c596c..b35a7acc3e8cb 100644 --- a/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py +++ b/python_modules/dagster/dagster_tests/general_tests/check_tests/test_check.py @@ -12,7 +12,6 @@ NotImplementedCheckError, ParameterCheckError, ) -from dagster._utils import frozendict, frozenlist @contextmanager @@ -216,7 +215,6 @@ class AlsoWrong: DICT_TEST_CASES = [ (dict(obj={}), True), - (dict(obj=frozendict()), True), (dict(obj={"a": 2}), True), (dict(obj=None), False), (dict(obj=0), False), @@ -322,7 +320,6 @@ class AlsoWrong: def test_opt_dict_param(): assert check.opt_dict_param(None, "opt_dict_param") == {} assert check.opt_dict_param({}, "opt_dict_param") == {} - assert check.opt_dict_param(frozendict(), "opt_dict_param") == {} ddict = {"a": 2} assert check.opt_dict_param(ddict, "opt_dict_param") == ddict @@ -345,7 +342,6 @@ def test_opt_dict_param(): def test_opt_nullable_dict_param(): assert check.opt_nullable_dict_param(None, "opt_nullable_dict_param") is None assert check.opt_nullable_dict_param({}, "opt_nullable_dict_param") == {} - assert check.opt_nullable_dict_param(frozendict(), "opt_nullable_dict_param") == {} ddict = {"a": 2} assert check.opt_nullable_dict_param(ddict, "opt_nullable_dict_param") == ddict @@ -776,7 +772,6 @@ class Baaz: def test_list_param(): assert check.list_param([], "list_param") == [] - assert check.list_param(frozenlist(), "list_param") == [] assert check.list_param(["foo"], "list_param", of_type=str) == ["foo"] @@ -812,7 +807,6 @@ def test_opt_list_param(): assert check.opt_list_param(None, "list_param") == [] assert check.opt_list_param(None, "list_param", of_type=str) == [] assert check.opt_list_param([], "list_param") == [] - assert check.opt_list_param(frozenlist(), "list_param") == [] obj_list = [1] assert check.list_param(obj_list, "list_param") == obj_list assert check.opt_list_param(["foo"], "list_param", of_type=str) == ["foo"] @@ -852,7 +846,6 @@ class Bar: def test_opt_nullable_list_param(): assert check.opt_nullable_list_param(None, "list_param") is None assert check.opt_nullable_list_param([], "list_param") == [] - assert check.opt_nullable_list_param(frozenlist(), "list_param") == [] obj_list = [1] assert check.opt_nullable_list_param(obj_list, "list_param") == obj_list diff --git a/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozendict.py 
b/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozendict.py
deleted file mode 100644
index a5772d90f407d..0000000000000
--- a/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozendict.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import pickle
-
-import pytest
-from dagster._utils import frozendict
-
-
-def test_frozendict():
-    d = frozendict({"foo": "bar"})
-    with pytest.raises(RuntimeError):
-        d["zip"] = "zowie"
-
-
-def test_pickle_frozendict():
-    orig_dict = frozendict({"foo": "bar"})
-    data = pickle.dumps(orig_dict)
-    loaded_dict = pickle.loads(data)
-
-    assert orig_dict == loaded_dict
diff --git a/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozenlist.py b/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozenlist.py
deleted file mode 100644
index b3119e4d669f8..0000000000000
--- a/python_modules/dagster/dagster_tests/general_tests/utils_tests/test_frozenlist.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import pickle
-
-import pytest
-from dagster._utils import frozenlist
-
-
-def test_pickle_frozenlist():
-    orig_list = frozenlist([1, "a", {}])
-    data = pickle.dumps(orig_list)
-    loaded_list = pickle.loads(data)
-
-    assert orig_list == loaded_list
-
-
-def test_hash_frozen_list():
-    assert hash(frozenlist([]))
-    assert hash(frozenlist(["foo", "bar"]))
-
-    with pytest.raises(TypeError, match="unhashable type"):
-        hash(frozenlist([[]]))
-
-    with pytest.raises(TypeError, match="unhashable type"):
-        hash(frozenlist([{}]))
diff --git a/python_modules/libraries/dagster-celery-k8s/dagster_celery_k8s/launcher.py b/python_modules/libraries/dagster-celery-k8s/dagster_celery_k8s/launcher.py
index 7ca1fdcbac06f..ca7f8ba6b3f54 100644
--- a/python_modules/libraries/dagster-celery-k8s/dagster_celery_k8s/launcher.py
+++ b/python_modules/libraries/dagster-celery-k8s/dagster_celery_k8s/launcher.py
@@ -16,7 +16,6 @@
 from dagster._core.storage.pipeline_run import DagsterRun, DagsterRunStatus
 from dagster._core.storage.tags import DOCKER_IMAGE_TAG
 from dagster._serdes import ConfigurableClass, ConfigurableClassData
-from dagster._utils import frozentags
 from dagster._utils.error import serializable_error_info_from_exc_info
 from dagster._utils.merger import merge_dicts
 from dagster_k8s.client import DagsterKubernetesClient
@@ -200,7 +199,7 @@ def launch_run(self, context: LaunchRunContext) -> None:
             {DOCKER_IMAGE_TAG: job_config.job_image},
         )
 
-        user_defined_k8s_config = get_user_defined_k8s_config(frozentags(run.tags))
+        user_defined_k8s_config = get_user_defined_k8s_config(run.tags)
 
         from dagster._cli.api import ExecuteRunArgs
 
diff --git a/python_modules/libraries/dagster-dask/dagster_dask/executor.py b/python_modules/libraries/dagster-dask/dagster_dask/executor.py
index 8508f1852b30a..68061576aa7bf 100644
--- a/python_modules/libraries/dagster-dask/dagster_dask/executor.py
+++ b/python_modules/libraries/dagster-dask/dagster_dask/executor.py
@@ -1,3 +1,5 @@
+from typing import Mapping
+
 import dask
 import dask.distributed
 from dagster import (
@@ -18,7 +20,7 @@
 from dagster._core.execution.plan.plan import ExecutionPlan
 from dagster._core.execution.retries import RetryMode
 from dagster._core.instance import DagsterInstance
-from dagster._utils import frozentags, iterate_with_context
+from dagster._utils import iterate_with_context
 
 # Dask resource requirements are specified under this key
 DASK_RESOURCE_REQUIREMENTS_KEY = "dagster-dask/resource_requirements"
@@ -143,8 +145,8 @@ def query_on_dask_worker(
     )
 
 
-def get_dask_resource_requirements(tags):
-    
check.inst_param(tags, "tags", frozentags) +def get_dask_resource_requirements(tags: Mapping[str, str]): + check.mapping_param(tags, "tags", key_type=str, value_type=str) req_str = tags.get(DASK_RESOURCE_REQUIREMENTS_KEY) if req_str is not None: return _seven.json.loads(req_str) diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/container_context.py b/python_modules/libraries/dagster-k8s/dagster_k8s/container_context.py index 27bb90cac60a4..5c41e4b1eac2f 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/container_context.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/container_context.py @@ -8,7 +8,7 @@ from dagster._core.errors import DagsterInvalidConfigError from dagster._core.storage.pipeline_run import DagsterRun from dagster._core.utils import parse_env_var -from dagster._utils import frozentags, make_readonly_value +from dagster._utils import hash_collection if TYPE_CHECKING: from . import K8sRunLauncher @@ -18,7 +18,13 @@ def _dedupe_list(values): - return sorted(list(set([make_readonly_value(value) for value in values])), key=hash) + new_list = [] + for value in values: + if value not in new_list: + new_list.append(value) + return sorted( + new_list, key=lambda x: hash_collection(x) if isinstance(x, (list, dict)) else hash(x) + ) class K8sContainerContext( @@ -185,7 +191,7 @@ def create_for_run( ) if include_run_tags: - user_defined_k8s_config = get_user_defined_k8s_config(frozentags(pipeline_run.tags)) + user_defined_k8s_config = get_user_defined_k8s_config(pipeline_run.tags) context = context.merge( K8sContainerContext(run_k8s_config=user_defined_k8s_config.to_dict()) diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py index 4697894bc70dd..9087061dec557 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py @@ -21,7 +21,6 @@ StepHandler, StepHandlerContext, ) -from dagster._utils import frozentags from dagster._utils.merger import merge_dicts from dagster_k8s.launcher import K8sRunLauncher @@ -188,7 +187,7 @@ def _get_container_context( context = context.merge(self._executor_container_context) user_defined_k8s_config = get_user_defined_k8s_config( - frozentags(step_handler_context.step_tags[step_key]) + step_handler_context.step_tags[step_key] ) return context.merge(K8sContainerContext(run_k8s_config=user_defined_k8s_config.to_dict())) diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/job.py b/python_modules/libraries/dagster-k8s/dagster_k8s/job.py index ec6cb45f82c4c..cf20d7d275266 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/job.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/job.py @@ -20,7 +20,6 @@ from dagster._core.errors import DagsterInvalidConfigError from dagster._core.utils import parse_env_var from dagster._serdes import whitelist_for_serdes -from dagster._utils import frozentags from dagster._utils.merger import merge_dicts from .models import k8s_model_from_dict, k8s_snake_case_dict @@ -149,8 +148,8 @@ def from_dict(cls, config_dict): ) -def get_k8s_resource_requirements(tags): - check.inst_param(tags, "tags", frozentags) +def get_k8s_resource_requirements(tags: Mapping[str, str]): + check.mapping_param(tags, "tags", key_type=str, value_type=str) check.invariant(K8S_RESOURCE_REQUIREMENTS_KEY in tags) resource_requirements = json.loads(tags[K8S_RESOURCE_REQUIREMENTS_KEY]) @@ -166,8 +165,8 @@ def 
get_k8s_resource_requirements(tags): return result.value -def get_user_defined_k8s_config(tags): - check.inst_param(tags, "tags", frozentags) +def get_user_defined_k8s_config(tags: Mapping[str, str]): + check.mapping_param(tags, "tags", key_type=str, value_type=str) if not any(key in tags for key in [K8S_RESOURCE_REQUIREMENTS_KEY, USER_DEFINED_K8S_CONFIG_KEY]): return UserDefinedDagsterK8sConfig() @@ -187,7 +186,7 @@ def get_user_defined_k8s_config(tags): user_defined_k8s_config = result.value - container_config = user_defined_k8s_config.get("container_config", {}) + container_config = user_defined_k8s_config.get("container_config", {}) # type: ignore # Backcompat for resource requirements key if K8S_RESOURCE_REQUIREMENTS_KEY in tags: @@ -198,11 +197,11 @@ def get_user_defined_k8s_config(tags): return UserDefinedDagsterK8sConfig( container_config=container_config, - pod_template_spec_metadata=user_defined_k8s_config.get("pod_template_spec_metadata"), - pod_spec_config=user_defined_k8s_config.get("pod_spec_config"), - job_config=user_defined_k8s_config.get("job_config"), - job_metadata=user_defined_k8s_config.get("job_metadata"), - job_spec_config=user_defined_k8s_config.get("job_spec_config"), + pod_template_spec_metadata=user_defined_k8s_config.get("pod_template_spec_metadata"), # type: ignore + pod_spec_config=user_defined_k8s_config.get("pod_spec_config"), # type: ignore + job_config=user_defined_k8s_config.get("job_config"), # type: ignore + job_metadata=user_defined_k8s_config.get("job_metadata"), # type: ignore + job_spec_config=user_defined_k8s_config.get("job_spec_config"), # type: ignore ) diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/models.py b/python_modules/libraries/dagster-k8s/dagster_k8s/models.py index 442e805b66ace..0862e2a42cdf5 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/models.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/models.py @@ -4,7 +4,6 @@ import dagster._check as check import kubernetes -from dagster._utils import frozendict from dateutil.parser import parse from kubernetes.client import ApiClient @@ -59,7 +58,7 @@ def _k8s_parse_value(data, classname, attr_name): elif klass == datetime.datetime: return parse(data) else: - if not isinstance(data, (frozendict, dict)): + if not isinstance(data, dict): raise Exception( f"Attribute {attr_name} of type {klass.__name__} must be a dict, received" f" {data} instead" @@ -87,7 +86,7 @@ def _k8s_snake_case_value(val, attr_type, attr_name): ): return val else: - if not isinstance(val, (frozendict, dict)): + if not isinstance(val, dict): raise Exception( f"Attribute {attr_name} of type {klass.__name__} must be a dict, received" f" {val} instead" diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_container_context.py b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_container_context.py index 802a84872cd26..52e3b5f90f4f8 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_container_context.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_container_context.py @@ -1,6 +1,6 @@ import pytest from dagster._core.errors import DagsterInvalidConfigError -from dagster._utils import make_readonly_value +from dagster._utils import hash_collection from dagster_k8s.container_context import K8sContainerContext @@ -189,9 +189,10 @@ def test_invalid_config(): def _check_same_sorted(list1, list2): - assert sorted( - [make_readonly_value(val) for val in list1], key=lambda val: 
val.__hash__() - ) == sorted([make_readonly_value(val) for val in list2], key=lambda val: val.__hash__()) + key_fn = lambda x: hash_collection(x) if isinstance(x, (list, dict)) else hash(x) + sorted1 = sorted(list1, key=key_fn) + sorted2 = sorted(list2, key=key_fn) + assert sorted1 == sorted2 def test_camel_case_volumes(container_context_camel_case_volumes, container_context):
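For reference, a runnable sketch of the sort-key idiom that `_dedupe_list` and `_check_same_sorted` now share: values that are plain lists or dicts get an ordering key from `hash_collection`, while everything else uses the built-in `hash`. The helper names and sample values below are invented for illustration.

```python
from dagster._utils import hash_collection


def sort_key(value):
    # Lists and dicts are unhashable, so derive an ordering key with
    # hash_collection; hashable values can use the built-in hash directly.
    return hash_collection(value) if isinstance(value, (list, dict)) else hash(value)


def dedupe_and_sort(values):
    # Keep only the first occurrence of each value, then impose a
    # deterministic (if arbitrary) order so that comparisons are stable.
    seen = []
    for value in values:
        if value not in seen:
            seen.append(value)
    return sorted(seen, key=sort_key)


volumes = [{"name": "vol1"}, "config-map-a", {"name": "vol1"}]
assert dedupe_and_sort(volumes) == dedupe_and_sort(list(reversed(volumes)))
```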