Skip to content

Commit

Permalink
Revert "Fix mypy errors (#1313)"
Browse files Browse the repository at this point in the history
This reverts commit af49155.

Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
  • Loading branch information
eapolinario committed May 17, 2023
1 parent f24f261 commit 14fb5b5
Show file tree
Hide file tree
Showing 55 changed files with 348 additions and 404 deletions.
11 changes: 5 additions & 6 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,11 @@ fmt: ## Format code with black and isort

.PHONY: lint
lint: ## Run linters
mypy flytekit/core
mypy flytekit/types
# allow-empty-bodies: Allow empty body in function.
# disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked".
# Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass.
mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core
mypy flytekit/core || true
mypy flytekit/types || true
mypy tests/flytekit/unit/core || true
# Exclude setup.py to fix error: Duplicate module named "setup"
mypy plugins --exclude setup.py || true
pre-commit run --all-files

.PHONY: spellcheck
Expand Down
8 changes: 4 additions & 4 deletions flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Optional, Tuple, Type, TypeVar
from typing import Any, Dict, Optional, Type, TypeVar

from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
Expand All @@ -22,11 +22,11 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[T] = None,
task_type="sql_task",
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[Dict[str, Type]] = None,
metadata: Optional[TaskMetadata] = None,
outputs: Optional[Dict[str, Type]] = None,
task_config: Optional[T] = None,
outputs: Dict[str, Type] = None,
**kwargs,
):
"""
Expand Down
48 changes: 18 additions & 30 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,10 @@
import datetime
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.promise import (
Expand Down Expand Up @@ -162,7 +156,7 @@ def __init__(
self,
task_type: str,
name: str,
interface: _interface_models.TypedInterface,
interface: Optional[_interface_models.TypedInterface] = None,
metadata: Optional[TaskMetadata] = None,
task_type_version=0,
security_ctx: Optional[SecurityContext] = None,
Expand All @@ -180,7 +174,7 @@ def __init__(
FlyteEntities.entities.append(self)

@property
def interface(self) -> _interface_models.TypedInterface:
def interface(self) -> Optional[_interface_models.TypedInterface]:
return self._interface

@property
Expand Down Expand Up @@ -300,8 +294,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
vals = [Promise(var, outputs_literals[var]) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore
def __call__(self, *args, **kwargs):
return flyte_entity_call_handler(self, *args, **kwargs)

def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")
Expand Down Expand Up @@ -345,8 +339,8 @@ def sandbox_execute(
"""
Call dispatch_execute, in the context of a local sandbox execution. Not invoked during runtime.
"""
es = cast(ExecutionState, ctx.execution_state)
b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox()
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
return self.dispatch_execute(ctx, input_literal_map)

Expand Down Expand Up @@ -395,7 +389,7 @@ def __init__(
self,
task_type: str,
name: str,
task_config: Optional[T],
task_config: T,
interface: Optional[Interface] = None,
environment: Optional[Dict[str, str]] = None,
disable_deck: bool = True,
Expand Down Expand Up @@ -432,13 +426,9 @@ def __init__(
)
else:
if self._python_interface.docstring.short_description:
cast(
Documentation, self._docs
).short_description = self._python_interface.docstring.short_description
self._docs.short_description = self._python_interface.docstring.short_description
if self._python_interface.docstring.long_description:
cast(Documentation, self._docs).long_description = Description(
value=self._python_interface.docstring.long_description
)
self._docs.long_description = Description(value=self._python_interface.docstring.long_description)

# TODO lets call this interface and the other as flyte_interface?
@property
Expand All @@ -449,25 +439,25 @@ def python_interface(self) -> Interface:
return self._python_interface

@property
def task_config(self) -> Optional[T]:
def task_config(self) -> T:
"""
Returns the user-specified task config which is used for plugin-specific handling of the task.
"""
return self._task_config

def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]:
def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]:
"""
Returns the python type for an input variable by name.
"""
return self._python_interface.inputs[k]

def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]:
def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]:
"""
Returns the python type for the specified output variable by name.
"""
return self._python_interface.outputs[k]

def get_input_types(self) -> Dict[str, type]:
def get_input_types(self) -> Optional[Dict[str, type]]:
"""
Returns the names and python types as a dictionary for the inputs of this task.
"""
Expand Down Expand Up @@ -513,9 +503,7 @@ def dispatch_execute(

# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
ctx.with_execution_state(
cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params)
)
ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params))
# type: ignore
) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
Expand Down Expand Up @@ -608,7 +596,7 @@ def dispatch_execute(
# After the execute has been successfully completed
return outputs_literal_map

def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore
def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
"""
This is the method that will be invoked directly before executing the task method and before all the inputs
are converted. One particular case where this is useful is if the context is to be modified for the user process
Expand All @@ -626,7 +614,7 @@ def execute(self, **kwargs) -> Any:
"""
pass

def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any:
def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
"""
Post execute is called after the execution has completed, with the user_params and can be used to clean-up,
or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/class_based_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs):
def name(self) -> str:
return "ClassStorageTaskResolver"

def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore
def get_all_tasks(self) -> List[PythonAutoContainerTask]:
return self.mapping

def add(self, t: PythonAutoContainerTask):
Expand All @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
idx = int(loader_args[0])
return self.mapping[idx]

def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore
def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]:
"""
This is responsible for turning an instance of a task into args that the load_task function can reconstitute.
"""
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP
return self._compute_outputs(n)
return self._condition

def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case:
def if_(self, expr: bool) -> Case:
return self._condition._if(expr)

def compute_output_vars(self) -> typing.Optional[typing.List[str]]:
Expand Down Expand Up @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str:
return f"{node_id}.{var}"


def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]:
def merge_promises(*args: Promise) -> typing.List[Promise]:
node_vars: typing.Set[typing.Tuple[str, str]] = set()
merged_promises: typing.List[Promise] = []
for p in args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def transform_to_boolexpr(


def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]:
expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr))
expr, promises = transform_to_boolexpr(c.expr)
n = c.output_promise.ref.node # type: ignore
return _core_wf.IfBlock(condition=expr, then_node=n), promises

Expand Down
14 changes: 7 additions & 7 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand Down Expand Up @@ -38,16 +38,16 @@ def __init__(
name: str,
image: str,
command: List[str],
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
inputs: Optional[Dict[str, Type]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
arguments: List[str] = None,
outputs: Dict[str, Type] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
input_data_dir: Optional[str] = None,
output_data_dir: Optional[str] = None,
input_data_dir: str = None,
output_data_dir: str = None,
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: Optional[IOStrategy] = None,
io_strategy: IOStrategy = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
Expand Down
20 changes: 10 additions & 10 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[])

if typing.TYPE_CHECKING:
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.base_task import TaskResolverMixin


# Identifier fields use placeholders for registration-time substitution.
Expand Down Expand Up @@ -108,7 +108,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:

def build(self) -> ExecutionParameters:
if not isinstance(self.working_dir, utils.AutoDeletingTempDir):
pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True)
pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True)
return ExecutionParameters(
execution_date=self.execution_date,
stats=self.stats,
Expand All @@ -123,14 +123,14 @@ def build(self) -> ExecutionParameters:
)

@staticmethod
def new_builder(current: Optional[ExecutionParameters] = None) -> Builder:
def new_builder(current: ExecutionParameters = None) -> Builder:
return ExecutionParameters.Builder(current=current)

def with_task_sandbox(self) -> Builder:
prefix = self.working_directory
if isinstance(self.working_directory, utils.AutoDeletingTempDir):
prefix = self.working_directory.name
task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore
task_sandbox_dir = tempfile.mkdtemp(prefix=prefix)
p = pathlib.Path(task_sandbox_dir)
cp_dir = p.joinpath("__cp")
cp_dir.mkdir(exist_ok=True)
Expand Down Expand Up @@ -299,7 +299,7 @@ def get(self, key: str) -> typing.Any:
"""
Returns task specific context if present else raise an error. The returned context will match the key
"""
return self.__getattr__(attr_name=key) # type: ignore
return self.__getattr__(attr_name=key)


class SecretsManager(object):
Expand Down Expand Up @@ -480,14 +480,14 @@ class Mode(Enum):
LOCAL_TASK_EXECUTION = 3

mode: Optional[ExecutionState.Mode]
working_dir: Union[os.PathLike, str]
working_dir: os.PathLike
engine_dir: Optional[Union[os.PathLike, str]]
branch_eval_mode: Optional[BranchEvalMode]
user_space_params: Optional[ExecutionParameters]

def __init__(
self,
working_dir: Union[os.PathLike, str],
working_dir: os.PathLike,
mode: Optional[ExecutionState.Mode] = None,
engine_dir: Optional[Union[os.PathLike, str]] = None,
branch_eval_mode: Optional[BranchEvalMode] = None,
Expand Down Expand Up @@ -620,7 +620,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec
return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params)

@staticmethod
def current_context() -> FlyteContext:
def current_context() -> Optional[FlyteContext]:
"""
This method exists only to maintain backwards compatibility. Please use
``FlyteContextManager.current_context()`` instead.
Expand Down Expand Up @@ -652,7 +652,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig
"""
from flytekit.deck.deck import _get_deck

return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params)
return _get_deck(self.execution_state.user_space_params)

@dataclass
class Builder(object):
Expand Down Expand Up @@ -865,7 +865,7 @@ class FlyteEntities(object):
registration process
"""

entities: List[Union["LaunchPlan", Task, "WorkflowBase"]] = [] # type: ignore
entities = []


FlyteContextManager.initialize()
2 changes: 1 addition & 1 deletion flytekit/core/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Docstring(object):
def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None):
def __init__(self, docstring: str = None, callable_: Callable = None):
if docstring is not None:
self._parsed_docstring = parse(docstring)
else:
Expand Down
5 changes: 2 additions & 3 deletions flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
)
else:
# We don't know how to find the python interface here, approve() sets it below, See the code.
self._python_interface = None # type: ignore
self._python_interface = None

@property
def name(self) -> str:
Expand Down Expand Up @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
return p

# Assume this is an approval operation since that's the only remaining option.
msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?"
msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?"
proceed = click.confirm(msg, default=True)
if proceed:
# We need to return a promise here, and a promise is what should've been passed in by the call in approve()
Expand Down Expand Up @@ -167,7 +167,6 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st
raise ValueError("You can't use approval on a task that doesn't return anything.")

ctx = FlyteContextManager.current_context()
upstream_item = typing.cast(Promise, upstream_item)
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
if not upstream_item.ref.node.flyte_entity.python_interface:
raise ValueError(
Expand Down
Loading

0 comments on commit 14fb5b5

Please sign in to comment.