Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy errors #1313

Merged
merged 28 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ fmt: ## Format code with black and isort

.PHONY: lint
lint: ## Run linters
mypy flytekit/core || true
mypy flytekit/types || true
mypy tests/flytekit/unit/core || true
# Exclude setup.py to fix error: Duplicate module named "setup"
mypy plugins --exclude setup.py || true
mypy flytekit/core
mypy flytekit/types
# allow-empty-bodies: Allow empty body in function.
# disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked".
# Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass.
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core
pre-commit run --all-files

.PHONY: spellcheck
Expand Down
2 changes: 2 additions & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ tensorflow==2.8.1
# we put this constraint while we do not have per-environment requirements files
torch<=1.12.1
scikit-learn
types-croniter
types-protobuf
8 changes: 4 additions & 4 deletions flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Optional, Type, TypeVar
from typing import Any, Dict, Optional, Tuple, Type, TypeVar

from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
Expand All @@ -22,11 +22,11 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[T] = None,
task_type="sql_task",
inputs: Optional[Dict[str, Type]] = None,
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
metadata: Optional[TaskMetadata] = None,
task_config: Optional[T] = None,
outputs: Dict[str, Type] = None,
outputs: Optional[Dict[str, Type]] = None,
**kwargs,
):
"""
Expand Down
42 changes: 25 additions & 17 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
import datetime
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.promise import (
Expand Down Expand Up @@ -168,7 +174,7 @@ def __init__(
FlyteEntities.entities.append(self)

@property
def interface(self) -> Optional[_interface_models.TypedInterface]:
def interface(self) -> _interface_models.TypedInterface:
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return self._interface

@property
Expand Down Expand Up @@ -232,8 +238,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
kwargs = translate_inputs_to_literals(
ctx,
incoming_values=kwargs,
flyte_interface_types=self.interface.inputs, # type: ignore
native_types=self.get_input_types(),
flyte_interface_types=self.interface.inputs,
native_types=self.get_input_types(), # type: ignore
)
input_literal_map = _literal_models.LiteralMap(literals=kwargs)

Expand All @@ -258,8 +264,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
logger.info("Cache hit")
else:
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
es = cast(ExecutionState, ctx.execution_state)
b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
outputs_literals = outputs_literal_map.literals
Expand All @@ -279,8 +285,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
vals = [Promise(var, outputs_literals[var]) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args, **kwargs):
return flyte_entity_call_handler(self, *args, **kwargs)
def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore

def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")
Expand Down Expand Up @@ -361,7 +367,7 @@ def __init__(
self,
task_type: str,
name: str,
task_config: T,
task_config: Optional[T],
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
interface: Optional[Interface] = None,
environment: Optional[Dict[str, str]] = None,
disable_deck: bool = True,
Expand Down Expand Up @@ -400,25 +406,25 @@ def python_interface(self) -> Interface:
return self._python_interface

@property
def task_config(self) -> T:
def task_config(self) -> Optional[T]:
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the user-specified task config which is used for plugin-specific handling of the task.
"""
return self._task_config

def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]:
def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]:
"""
Returns the python type for an input variable by name.
"""
return self._python_interface.inputs[k]

def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]:
def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]:
"""
Returns the python type for the specified output variable by name.
"""
return self._python_interface.outputs[k]

def get_input_types(self) -> Optional[Dict[str, type]]:
def get_input_types(self) -> Dict[str, type]:
"""
Returns the names and python types as a dictionary for the inputs of this task.
"""
Expand Down Expand Up @@ -464,7 +470,9 @@ def dispatch_execute(

# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params))
ctx.with_execution_state(
cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params)
)
# type: ignore
) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
Expand Down Expand Up @@ -545,7 +553,7 @@ def dispatch_execute(
# After the execute has been successfully completed
return outputs_literal_map

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore
"""
This is the method that will be invoked directly before executing the task method and before all the inputs
are converted. One particular case where this is useful is if the context is to be modified for the user process
Expand All @@ -563,7 +571,7 @@ def execute(self, **kwargs) -> Any:
"""
pass

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any:
"""
Post execute is called after the execution has completed, with the user_params and can be used to clean-up,
or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/class_based_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs):
def name(self) -> str:
return "ClassStorageTaskResolver"

def get_all_tasks(self) -> List[PythonAutoContainerTask]:
def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore
return self.mapping

def add(self, t: PythonAutoContainerTask):
Expand All @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
idx = int(loader_args[0])
return self.mapping[idx]

def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]:
def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore
"""
This is responsible for turning an instance of a task into args that the load_task function can reconstitute.
"""
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP
return self._compute_outputs(n)
return self._condition

def if_(self, expr: bool) -> Case:
def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case:
return self._condition._if(expr)

def compute_output_vars(self) -> typing.Optional[typing.List[str]]:
Expand Down Expand Up @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str:
return f"{node_id}.{var}"


def merge_promises(*args: Promise) -> typing.List[Promise]:
def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]:
node_vars: typing.Set[typing.Tuple[str, str]] = set()
merged_promises: typing.List[Promise] = []
for p in args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def transform_to_boolexpr(


def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]:
expr, promises = transform_to_boolexpr(c.expr)
expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr))
n = c.output_promise.ref.node # type: ignore
return _core_wf.IfBlock(condition=expr, then_node=n), promises

Expand Down
14 changes: 7 additions & 7 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional, Tuple, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand Down Expand Up @@ -35,16 +35,16 @@ def __init__(
name: str,
image: str,
command: List[str],
inputs: Optional[Dict[str, Type]] = None,
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
metadata: Optional[TaskMetadata] = None,
arguments: List[str] = None,
outputs: Dict[str, Type] = None,
arguments: Optional[List[str]] = None,
outputs: Optional[Dict[str, Type]] = None,
requests: Optional[Resources] = None,
limits: Optional[Resources] = None,
input_data_dir: str = None,
output_data_dir: str = None,
input_data_dir: Optional[str] = None,
output_data_dir: Optional[str] = None,
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: IOStrategy = None,
io_strategy: Optional[IOStrategy] = None,
secret_requests: Optional[List[Secret]] = None,
**kwargs,
):
Expand Down
20 changes: 10 additions & 10 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
flyte_context_Var: ContextVar[typing.List[FlyteContext]] = ContextVar("", default=[])

if typing.TYPE_CHECKING:
from flytekit.core.base_task import TaskResolverMixin
from flytekit.core.base_task import Task, TaskResolverMixin


# Identifier fields use placeholders for registration-time substitution.
Expand Down Expand Up @@ -108,7 +108,7 @@ def add_attr(self, key: str, v: typing.Any) -> ExecutionParameters.Builder:

def build(self) -> ExecutionParameters:
if not isinstance(self.working_dir, utils.AutoDeletingTempDir):
pathlib.Path(self.working_dir).mkdir(parents=True, exist_ok=True)
pathlib.Path(typing.cast(str, self.working_dir)).mkdir(parents=True, exist_ok=True)
return ExecutionParameters(
execution_date=self.execution_date,
stats=self.stats,
Expand All @@ -123,14 +123,14 @@ def build(self) -> ExecutionParameters:
)

@staticmethod
def new_builder(current: ExecutionParameters = None) -> Builder:
def new_builder(current: Optional[ExecutionParameters] = None) -> Builder:
return ExecutionParameters.Builder(current=current)

def with_task_sandbox(self) -> Builder:
prefix = self.working_directory
if isinstance(self.working_directory, utils.AutoDeletingTempDir):
prefix = self.working_directory.name
task_sandbox_dir = tempfile.mkdtemp(prefix=prefix)
task_sandbox_dir = tempfile.mkdtemp(prefix=prefix) # type: ignore
p = pathlib.Path(task_sandbox_dir)
cp_dir = p.joinpath("__cp")
cp_dir.mkdir(exist_ok=True)
Expand Down Expand Up @@ -287,7 +287,7 @@ def get(self, key: str) -> typing.Any:
"""
Returns task specific context if present else raise an error. The returned context will match the key
"""
return self.__getattr__(attr_name=key)
return self.__getattr__(attr_name=key) # type: ignore


class SecretsManager(object):
Expand Down Expand Up @@ -467,14 +467,14 @@ class Mode(Enum):
LOCAL_TASK_EXECUTION = 3

mode: Optional[ExecutionState.Mode]
working_dir: os.PathLike
working_dir: Union[os.PathLike, str]
engine_dir: Optional[Union[os.PathLike, str]]
branch_eval_mode: Optional[BranchEvalMode]
user_space_params: Optional[ExecutionParameters]

def __init__(
self,
working_dir: os.PathLike,
working_dir: Union[os.PathLike, str],
mode: Optional[ExecutionState.Mode] = None,
engine_dir: Optional[Union[os.PathLike, str]] = None,
branch_eval_mode: Optional[BranchEvalMode] = None,
Expand Down Expand Up @@ -607,7 +607,7 @@ def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> Exec
return ExecutionState(working_dir=working_dir, user_space_params=self.user_space_params)

@staticmethod
def current_context() -> Optional[FlyteContext]:
def current_context() -> FlyteContext:
"""
This method exists only to maintain backwards compatibility. Please use
``FlyteContextManager.current_context()`` instead.
Expand Down Expand Up @@ -639,7 +639,7 @@ def get_deck(self) -> typing.Union[str, "IPython.core.display.HTML"]: # type:ig
"""
from flytekit.deck.deck import _get_deck

return _get_deck(self.execution_state.user_space_params)
return _get_deck(typing.cast(ExecutionState, self.execution_state).user_space_params)

@dataclass
class Builder(object):
Expand Down Expand Up @@ -852,7 +852,7 @@ class FlyteEntities(object):
registration process
"""

entities = []
entities: List["LaunchPlan" | Task | "WorkflowBase"] = [] # type: ignore


FlyteContextManager.initialize()
6 changes: 3 additions & 3 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class DataPersistence(object):
Base abstract type for all DataPersistence operations. This can be extended using the flytekitplugins architecture
"""

def __init__(self, name: str, default_prefix: typing.Optional[str] = None, **kwargs):
def __init__(self, name: str = "", default_prefix: typing.Optional[str] = None, **kwargs):
self._name = name
self._default_prefix = default_prefix

Expand Down Expand Up @@ -142,7 +142,7 @@ def register_plugin(cls, protocol: str, plugin: typing.Type[DataPersistence], fo
cls._PLUGINS[protocol] = plugin

@staticmethod
def get_protocol(url: str):
def get_protocol(url: str) -> str:
# copy from fsspec https://github.com/fsspec/filesystem_spec/blob/fe09da6942ad043622212927df7442c104fe7932/fsspec/utils.py#L387-L391
parts = re.split(r"(\:\:|\://)", url, 1)
if len(parts) > 1:
Expand Down Expand Up @@ -458,7 +458,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart=False):
f"Original exception: {str(ex)}"
)

def put_data(self, local_path: Union[str, os.PathLike], remote_path: str, is_multipart=False):
def put_data(self, local_path: str, remote_path: str, is_multipart=False):
"""
The implication here is that we're always going to put data to the remote location, so we .remote to ensure
we don't use the true local proxy if the remote path is a file://
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class Docstring(object):
def __init__(self, docstring: str = None, callable_: Callable = None):
def __init__(self, docstring: Optional[str] = None, callable_: Optional[Callable] = None):
if docstring is not None:
self._parsed_docstring = parse(docstring)
else:
Expand Down
5 changes: 3 additions & 2 deletions flytekit/core/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
)
else:
# We don't know how to find the python interface here, approve() sets it below, See the code.
self._python_interface = None
self._python_interface = None # type: ignore

@property
def name(self) -> str:
Expand Down Expand Up @@ -105,7 +105,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
return p

# Assume this is an approval operation since that's the only remaining option.
msg = f"Pausing execution for {self.name}, literal value is:\n{self._upstream_item.val}\nContinue?"
msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?"
proceed = click.confirm(msg, default=True)
if proceed:
# We need to return a promise here, and a promise is what should've been passed in by the call in approve()
Expand Down Expand Up @@ -164,6 +164,7 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st
raise ValueError("You can't use approval on a task that doesn't return anything.")

ctx = FlyteContextManager.current_context()
upstream_item = typing.cast(Promise, upstream_item)
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
if not upstream_item.ref.node.flyte_entity.python_interface:
raise ValueError(
Expand Down
Loading