Skip to content

Commit

Permalink
[refactor] Removes frozentags class
Browse files Browse the repository at this point in the history
  • Loading branch information
smackesey committed Feb 12, 2023
1 parent bc9f2fb commit d676cea
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _core_submit_execution(
run_config: Optional[Mapping[str, Any]] = None,
mode: Optional[str] = None,
preset: Optional[str] = None,
tags: Optional[Dict[str, Any]] = None,
tags: Optional[Mapping[str, object]] = None,
solid_selection: Optional[List[str]] = None,
is_using_job_op_graph_apis: Optional[bool] = False,
):
Expand Down
15 changes: 7 additions & 8 deletions python_modules/dagster/dagster/_core/definitions/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
DagsterInvalidInvocationError,
DagsterInvariantViolationError,
)
from dagster._utils import frozentags

from .config import ConfigMapping
from .dependency import (
Expand Down Expand Up @@ -238,7 +237,7 @@ def observe_invocation(
given_alias: Optional[str],
node_def: NodeDefinition,
input_bindings: Mapping[str, InputSource],
tags: Optional[frozentags],
tags: Optional[Mapping[str, str]],
hook_defs: Optional[AbstractSet[HookDefinition]],
retry_policy: Optional[RetryPolicy],
) -> str:
Expand Down Expand Up @@ -403,21 +402,21 @@ def the_graph():

node_def: NodeDefinition
given_alias: Optional[str]
tags: Optional[frozentags]
tags: Optional[Mapping[str, str]]
hook_defs: AbstractSet[HookDefinition]
retry_policy: Optional[RetryPolicy]

def __init__(
self,
node_def: NodeDefinition,
given_alias: Optional[str],
tags: Optional[frozentags],
tags: Optional[Mapping[str, str]],
hook_defs: Optional[AbstractSet[HookDefinition]],
retry_policy: Optional[RetryPolicy],
):
self.node_def = check.inst_param(node_def, "node_def", NodeDefinition)
self.given_alias = check.opt_str_param(given_alias, "given_alias")
self.tags = check.opt_inst_param(tags, "tags", frozentags)
self.tags = check.opt_mapping_param(tags, "tags", key_type=str, value_type=str)
self.hook_defs = check.opt_set_param(hook_defs, "hook_defs", HookDefinition)
self.retry_policy = check.opt_inst_param(retry_policy, "retry_policy", RetryPolicy)

Expand Down Expand Up @@ -665,7 +664,7 @@ def tag(self, tags: Optional[Mapping[str, str]]) -> "PendingNodeInvocation":
return PendingNodeInvocation(
node_def=self.node_def,
given_alias=self.given_alias,
tags=frozentags(tags) if self.tags is None else self.tags.updated_with(tags),
tags={**(self.tags or {}), **tags},
hook_defs=self.hook_defs,
retry_policy=self.retry_policy,
)
Expand Down Expand Up @@ -725,7 +724,7 @@ def to_job(
description=description,
resource_defs=resource_defs,
config=config,
tags=tags if not self.tags else self.tags.updated_with(tags),
tags={**(self.tags or {}), **tags},
logger_defs=logger_defs,
executor_def=executor_def,
hooks=job_hooks,
Expand Down Expand Up @@ -783,7 +782,7 @@ class InvokedNode(NamedTuple):
node_name: str
node_def: NodeDefinition
input_bindings: Mapping[str, InputSource]
tags: Optional[frozentags]
tags: Optional[Mapping[str, str]]
hook_defs: Optional[AbstractSet[HookDefinition]]
retry_policy: Optional[RetryPolicy]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
register_serdes_tuple_fallbacks,
whitelist_for_serdes,
)
from dagster._utils import frozentags

from .hook_definition import HookDefinition
from .input import FanInInputPointer, InputDefinition, InputMapping, InputPointer
Expand Down Expand Up @@ -101,7 +100,7 @@ def __new__(
cls,
name=check.str_param(name, "name"),
alias=check.opt_str_param(alias, "alias"),
tags=frozentags(check.opt_mapping_param(tags, "tags", value_type=str, key_type=str)),
tags=check.opt_mapping_param(tags, "tags", value_type=str, key_type=str),
hook_defs=frozenset(
check.opt_set_param(hook_defs, "hook_defs", of_type=HookDefinition)
),
Expand Down Expand Up @@ -190,7 +189,7 @@ def output_dict(self) -> Mapping[str, OutputDefinition]:
return self.definition.output_dict

@property
def tags(self) -> frozentags:
def tags(self) -> Mapping[str, str]:
# Type-ignore temporarily pending assessment of right data structure for `tags`
return self.definition.tags.updated_with(self._additional_tags) # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
)
from dagster._core.storage.tags import MEMOIZED_RUN_TAG
from dagster._core.utils import str_format_set
from dagster._utils import frozentags
from dagster._utils.backcompat import experimental_class_warning
from dagster._utils.merger import merge_dicts

Expand Down Expand Up @@ -329,7 +328,7 @@ def describe_target(self) -> str:

@property
def tags(self) -> Mapping[str, str]:
return frozentags(**merge_dicts(self._graph_def.tags, self._tags))
return merge_dicts(self._graph_def.tags, self._tags)

@property
def metadata(self) -> Sequence[Union[MetadataEntry, PartitionMetadataEntry]]:
Expand Down
47 changes: 25 additions & 22 deletions python_modules/dagster/dagster/_core/definitions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import dagster._seven as seven
from dagster._core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster._core.storage.tags import check_reserved_tags
from dagster._utils import frozentags
from dagster._utils.yaml_utils import merge_yaml_strings, merge_yamls

DEFAULT_OUTPUT = "result"
Expand Down Expand Up @@ -92,40 +91,44 @@ def struct_to_string(name: str, **kwargs: object) -> str:
return "{name}({props_str})".format(name=name, props_str=props_str)


def validate_tags(tags: Optional[Mapping[str, Any]], allow_reserved_tags=True) -> frozentags:
valid_tags = {}
def validate_tags(
tags: Optional[Mapping[str, object]], allow_reserved_tags=True
) -> Mapping[str, str]:
valid_tags: Dict[str, str] = {}
for key, value in check.opt_mapping_param(tags, "tags", key_type=str).items():
if not isinstance(value, str):
valid = False
err_reason = 'Could not JSON encode value "{}"'.format(value)
str_val = None
try:
str_val = seven.json.dumps(value)
err_reason = (
'JSON encoding "{json}" of value "{val}" is not equivalent to original value'
.format(json=str_val, val=value)
)

valid = seven.json.loads(str_val) == value
except Exception:
pass

if not valid:
raise DagsterInvalidDefinitionError(
'Invalid value for tag "{key}", {err_reason}. Tag values must be strings '
"or meet the constraint that json.loads(json.dumps(value)) == value.".format(
key=key, err_reason=err_reason
if not valid:
raise DagsterInvalidDefinitionError(
_get_tags_error_msg(
key,
(
f'JSON encoding "{str_val}" of value "{value}" is not equivalent to'
" original value"
),
)
)
valid_tags[key] = str_val
except TypeError: # thrown for unencodable json
raise DagsterInvalidDefinitionError(
_get_tags_error_msg(key, f'Could not JSON encode value "{value}"')
)

valid_tags[key] = str_val
else:
valid_tags[key] = value

if not allow_reserved_tags:
check_reserved_tags(valid_tags)

return frozentags(valid_tags)
return valid_tags


def _get_tags_error_msg(key: str, error_reason: str) -> str:
return (
f'Invalid value for tag "{key}", {error_reason}. Tag values must be strings or meet the'
" constraint that json.loads(json.dumps(value)) == value."
)


def validate_group_name(group_name: Optional[str]) -> str:
Expand Down
4 changes: 2 additions & 2 deletions python_modules/dagster/dagster/_core/snap/dep_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ class SolidInvocationSnap(
[
("solid_name", str),
("solid_def_name", str),
("tags", Mapping[object, object]),
("tags", Mapping[str, str]),
("input_dep_snaps", Sequence[InputDependencySnap]),
("is_dynamic_mapped", bool),
],
Expand All @@ -228,7 +228,7 @@ def __new__(
cls,
solid_name: str,
solid_def_name: str,
tags: Mapping[object, object],
tags: Mapping[str, str],
input_dep_snaps: Sequence[InputDependencySnap],
is_dynamic_mapped: bool = False,
):
Expand Down
18 changes: 0 additions & 18 deletions python_modules/dagster/dagster/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,24 +478,6 @@ def datetime_as_float(dt):
return float((dt - EPOCH).total_seconds())


# hashable frozen string to string dict
class frozentags(frozendict, Mapping[str, str]):
def __init__(self, *args, **kwargs):
super(frozentags, self).__init__(*args, **kwargs)
check.dict_param(self, "self", key_type=str, value_type=str)

def __hash__(self):
return hash(tuple(sorted(self.items())))

def updated_with(self, new_tags):
check.dict_param(new_tags, "new_tags", key_type=str, value_type=str)
updated = dict(self)
for key, value in new_tags.items():
updated[key] = value

return frozentags(updated)


T_GeneratedContext = TypeVar("T_GeneratedContext")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from dagster._core.storage.pipeline_run import DagsterRun, DagsterRunStatus
from dagster._core.storage.tags import DOCKER_IMAGE_TAG
from dagster._serdes import ConfigurableClass, ConfigurableClassData
from dagster._utils import frozentags
from dagster._utils.error import serializable_error_info_from_exc_info
from dagster._utils.merger import merge_dicts
from dagster_k8s.client import DagsterKubernetesClient
Expand Down Expand Up @@ -198,7 +197,7 @@ def launch_run(self, context: LaunchRunContext) -> None:
{DOCKER_IMAGE_TAG: job_config.job_image},
)

user_defined_k8s_config = get_user_defined_k8s_config(frozentags(run.tags))
user_defined_k8s_config = get_user_defined_k8s_config(run.tags)

from dagster._cli.api import ExecuteRunArgs

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Mapping

import dask
import dask.distributed
from dagster import (
Expand All @@ -18,7 +20,7 @@
from dagster._core.execution.plan.plan import ExecutionPlan
from dagster._core.execution.retries import RetryMode
from dagster._core.instance import DagsterInstance
from dagster._utils import frozentags, iterate_with_context
from dagster._utils import iterate_with_context

# Dask resource requirements are specified under this key
DASK_RESOURCE_REQUIREMENTS_KEY = "dagster-dask/resource_requirements"
Expand Down Expand Up @@ -143,8 +145,8 @@ def query_on_dask_worker(
)


def get_dask_resource_requirements(tags):
check.inst_param(tags, "tags", frozentags)
def get_dask_resource_requirements(tags: Mapping[str, str]):
check.mapping_param(tags, "tags", key_type=str, value_type=str)
req_str = tags.get(DASK_RESOURCE_REQUIREMENTS_KEY)
if req_str is not None:
return _seven.json.loads(req_str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dagster._core.errors import DagsterInvalidConfigError
from dagster._core.storage.pipeline_run import DagsterRun
from dagster._core.utils import parse_env_var
from dagster._utils import frozentags, make_readonly_value
from dagster._utils import make_readonly_value

if TYPE_CHECKING:
from . import K8sRunLauncher
Expand Down Expand Up @@ -175,7 +175,7 @@ def create_for_run(
K8sContainerContext.create_from_config(run_container_context)
)

user_defined_k8s_config = get_user_defined_k8s_config(frozentags(pipeline_run.tags))
user_defined_k8s_config = get_user_defined_k8s_config(pipeline_run.tags)

context = context.merge(
K8sContainerContext(run_k8s_config=user_defined_k8s_config.to_dict())
Expand Down
3 changes: 1 addition & 2 deletions python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
StepHandler,
StepHandlerContext,
)
from dagster._utils import frozentags
from dagster._utils.merger import merge_dicts

from dagster_k8s.launcher import K8sRunLauncher
Expand Down Expand Up @@ -224,7 +223,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags
raise Exception("No image included in either executor config or the job")

user_defined_k8s_config = get_user_defined_k8s_config(
frozentags(step_handler_context.step_tags[step_key])
step_handler_context.step_tags[step_key]
)

job = construct_dagster_k8s_job(
Expand Down
21 changes: 10 additions & 11 deletions python_modules/libraries/dagster-k8s/dagster_k8s/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from dagster._core.errors import DagsterInvalidConfigError
from dagster._core.utils import parse_env_var
from dagster._serdes import whitelist_for_serdes
from dagster._utils import frozentags
from dagster._utils.merger import merge_dicts

from .models import k8s_model_from_dict, k8s_snake_case_dict
Expand Down Expand Up @@ -149,8 +148,8 @@ def from_dict(cls, config_dict):
)


def get_k8s_resource_requirements(tags):
check.inst_param(tags, "tags", frozentags)
def get_k8s_resource_requirements(tags: Mapping[str, str]):
check.mapping_param(tags, "tags", key_type=str, value_type=str)
check.invariant(K8S_RESOURCE_REQUIREMENTS_KEY in tags)

resource_requirements = json.loads(tags[K8S_RESOURCE_REQUIREMENTS_KEY])
Expand All @@ -166,8 +165,8 @@ def get_k8s_resource_requirements(tags):
return result.value


def get_user_defined_k8s_config(tags):
check.inst_param(tags, "tags", frozentags)
def get_user_defined_k8s_config(tags: Mapping[str, str]):
check.mapping_param(tags, "tags", key_type=str, value_type=str)

if not any(key in tags for key in [K8S_RESOURCE_REQUIREMENTS_KEY, USER_DEFINED_K8S_CONFIG_KEY]):
return UserDefinedDagsterK8sConfig()
Expand All @@ -187,7 +186,7 @@ def get_user_defined_k8s_config(tags):

user_defined_k8s_config = result.value

container_config = user_defined_k8s_config.get("container_config", {})
container_config = user_defined_k8s_config.get("container_config", {}) # type: ignore

# Backcompat for resource requirements key
if K8S_RESOURCE_REQUIREMENTS_KEY in tags:
Expand All @@ -198,11 +197,11 @@ def get_user_defined_k8s_config(tags):

return UserDefinedDagsterK8sConfig(
container_config=container_config,
pod_template_spec_metadata=user_defined_k8s_config.get("pod_template_spec_metadata"),
pod_spec_config=user_defined_k8s_config.get("pod_spec_config"),
job_config=user_defined_k8s_config.get("job_config"),
job_metadata=user_defined_k8s_config.get("job_metadata"),
job_spec_config=user_defined_k8s_config.get("job_spec_config"),
pod_template_spec_metadata=user_defined_k8s_config.get("pod_template_spec_metadata"), # type: ignore
pod_spec_config=user_defined_k8s_config.get("pod_spec_config"), # type: ignore
job_config=user_defined_k8s_config.get("job_config"), # type: ignore
job_metadata=user_defined_k8s_config.get("job_metadata"), # type: ignore
job_spec_config=user_defined_k8s_config.get("job_spec_config"), # type: ignore
)


Expand Down

0 comments on commit d676cea

Please sign in to comment.