Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy errors #1313

Merged
merged 28 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ fmt: ## Format code with black and isort

.PHONY: lint
lint: ## Run linters
mypy flytekit/core || true
mypy flytekit/types || true
mypy tests/flytekit/unit/core || true
# Exclude setup.py to fix error: Duplicate module named "setup"
mypy plugins --exclude setup.py || true
mypy flytekit/core
mypy flytekit/types
# allow-empty-bodies: Allow empty body in function.
# disable-error-code="annotation-unchecked": Remove the warning "By default the bodies of untyped functions are not checked".
# Mypy raises a warning because it cannot determine the type from the dataclass, despite we specified the type in the dataclass.
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
mypy --allow-empty-bodies --disable-error-code="annotation-unchecked" tests/flytekit/unit/core
pre-commit run --all-files

.PHONY: spellcheck
Expand Down
3 changes: 3 additions & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ tensorflow==2.8.1
# we put this constraint while we do not have per-environment requirements files
torch<=1.12.1
scikit-learn
types-protobuf
types-croniter
types-mock
46 changes: 17 additions & 29 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#
# This file is autogenerated by pip-compile with Python 3.7
# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
# make dev-requirements.txt
# pip-compile dev-requirements.in
#
-e file:.#egg=flytekit
# via
Expand All @@ -12,6 +12,8 @@ absl-py==1.3.0
# via
# tensorboard
# tensorflow
appnope==0.1.3
# via ipython
arrow==1.2.3
# via
# -c requirements.txt
Expand All @@ -32,8 +34,6 @@ binaryornot==0.4.4
# via
# -c requirements.txt
# cookiecutter
cached-property==1.5.2
# via docker-compose
cachetools==5.2.0
# via google-auth
certifi==2022.12.7
Expand Down Expand Up @@ -83,7 +83,6 @@ cryptography==38.0.4
# -c requirements.txt
# paramiko
# pyopenssl
# secretstorage
dataclasses-json==0.5.7
# via
# -c requirements.txt
Expand Down Expand Up @@ -136,6 +135,10 @@ flyteidl==1.3.1
# flytekit
gast==0.5.3
# via tensorflow
gitdb==4.0.10
# via gitpython
gitpython==3.1.30
# via flytekit
google-api-core[grpc]==2.11.0
# via
# google-cloud-bigquery
Expand Down Expand Up @@ -167,6 +170,7 @@ googleapis-common-protos==1.57.0
# via
# -c requirements.txt
# flyteidl
# flytekit
# google-api-core
# grpcio-status
grpcio==1.51.1
Expand Down Expand Up @@ -194,15 +198,9 @@ idna==3.4
importlib-metadata==5.1.0
# via
# -c requirements.txt
# click
# flytekit
# jsonschema
# keyring
# markdown
# pluggy
# pre-commit
# pytest
# virtualenv
iniconfig==1.1.1
# via pytest
ipython==7.34.0
Expand All @@ -213,11 +211,6 @@ jaraco-classes==3.2.3
# keyring
jedi==0.18.2
# via ipython
jeepney==0.8.0
# via
# -c requirements.txt
# keyring
# secretstorage
jinja2==3.1.2
# via
# -c requirements.txt
Expand Down Expand Up @@ -470,14 +463,6 @@ scikit-learn==1.0.2
# via -r dev-requirements.in
scipy==1.7.3
# via scikit-learn
secretstorage==3.3.3
# via
# -c requirements.txt
# keyring
singledispatchmethod==1.0
# via
# -c requirements.txt
# flytekit
six==1.16.0
# via
# -c requirements.txt
Expand All @@ -491,6 +476,8 @@ six==1.16.0
# python-dateutil
# tensorflow
# websocket-client
smmap==5.0.0
# via gitdb
sortedcontainers==2.4.0
# via
# -c requirements.txt
Expand Down Expand Up @@ -537,20 +524,21 @@ traitlets==5.6.0
# via
# ipython
# matplotlib-inline
typed-ast==1.5.4
# via mypy
types-croniter==1.3.2.2
# via -r dev-requirements.in
types-mock==5.0.0.2
# via -r dev-requirements.in
types-protobuf==4.21.0.3
# via -r dev-requirements.in
types-toml==0.10.8.1
# via
# -c requirements.txt
# responses
typing-extensions==4.4.0
# via
# -c requirements.txt
# arrow
# flytekit
# importlib-metadata
# mypy
# responses
# tensorflow
# torch
# typing-inspect
Expand Down
8 changes: 4 additions & 4 deletions flytekit/core/base_sql_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Dict, Optional, Type, TypeVar
from typing import Any, Dict, Optional, Tuple, Type, TypeVar

from flytekit.core.base_task import PythonTask, TaskMetadata
from flytekit.core.interface import Interface
Expand All @@ -22,11 +22,11 @@ def __init__(
self,
name: str,
query_template: str,
task_config: Optional[T] = None,
task_type="sql_task",
inputs: Optional[Dict[str, Type]] = None,
inputs: Optional[Dict[str, Tuple[Type, Any]]] = None,
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
metadata: Optional[TaskMetadata] = None,
task_config: Optional[T] = None,
outputs: Dict[str, Type] = None,
outputs: Optional[Dict[str, Type]] = None,
**kwargs,
):
"""
Expand Down
52 changes: 32 additions & 20 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,16 @@
import datetime
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContext, FlyteContextManager, FlyteEntities
from flytekit.core.context_manager import (
ExecutionParameters,
ExecutionState,
FlyteContext,
FlyteContextManager,
FlyteEntities,
)
from flytekit.core.interface import Interface, transform_interface_to_typed_interface
from flytekit.core.local_cache import LocalTaskCache
from flytekit.core.promise import (
Expand Down Expand Up @@ -153,7 +159,7 @@ def __init__(
self,
task_type: str,
name: str,
interface: Optional[_interface_models.TypedInterface] = None,
interface: _interface_models.TypedInterface,
metadata: Optional[TaskMetadata] = None,
task_type_version=0,
security_ctx: Optional[SecurityContext] = None,
Expand All @@ -171,7 +177,7 @@ def __init__(
FlyteEntities.entities.append(self)

@property
def interface(self) -> Optional[_interface_models.TypedInterface]:
def interface(self) -> _interface_models.TypedInterface:
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return self._interface

@property
Expand Down Expand Up @@ -239,8 +245,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
kwargs = translate_inputs_to_literals(
ctx,
incoming_values=kwargs,
flyte_interface_types=self.interface.inputs, # type: ignore
native_types=self.get_input_types(),
flyte_interface_types=self.interface.inputs,
native_types=self.get_input_types(), # type: ignore
)
input_literal_map = _literal_models.LiteralMap(literals=kwargs)

Expand All @@ -265,8 +271,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
logger.info("Cache hit")
else:
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
es = cast(ExecutionState, ctx.execution_state)
b = cast(ExecutionParameters, es.user_space_params).with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
outputs_literals = outputs_literal_map.literals
Expand All @@ -286,8 +292,8 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
vals = [Promise(var, outputs_literals[var]) for var in output_names]
return create_task_output(vals, self.python_interface)

def __call__(self, *args, **kwargs):
return flyte_entity_call_handler(self, *args, **kwargs)
def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]:
return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore

def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")
Expand Down Expand Up @@ -368,7 +374,7 @@ def __init__(
self,
task_type: str,
name: str,
task_config: T,
task_config: Optional[T],
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
interface: Optional[Interface] = None,
environment: Optional[Dict[str, str]] = None,
disable_deck: bool = True,
Expand Down Expand Up @@ -405,9 +411,13 @@ def __init__(
)
else:
if self._python_interface.docstring.short_description:
self._docs.short_description = self._python_interface.docstring.short_description
cast(
Documentation, self._docs
).short_description = self._python_interface.docstring.short_description
if self._python_interface.docstring.long_description:
self._docs.long_description = Description(value=self._python_interface.docstring.long_description)
cast(Documentation, self._docs).long_description = Description(
value=self._python_interface.docstring.long_description
)

# TODO lets call this interface and the other as flyte_interface?
@property
Expand All @@ -418,25 +428,25 @@ def python_interface(self) -> Interface:
return self._python_interface

@property
def task_config(self) -> T:
def task_config(self) -> Optional[T]:
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns the user-specified task config which is used for plugin-specific handling of the task.
"""
return self._task_config

def get_type_for_input_var(self, k: str, v: Any) -> Optional[Type[Any]]:
def get_type_for_input_var(self, k: str, v: Any) -> Type[Any]:
"""
Returns the python type for an input variable by name.
"""
return self._python_interface.inputs[k]

def get_type_for_output_var(self, k: str, v: Any) -> Optional[Type[Any]]:
def get_type_for_output_var(self, k: str, v: Any) -> Type[Any]:
"""
Returns the python type for the specified output variable by name.
"""
return self._python_interface.outputs[k]

def get_input_types(self) -> Optional[Dict[str, type]]:
def get_input_types(self) -> Dict[str, type]:
"""
Returns the names and python types as a dictionary for the inputs of this task.
"""
Expand Down Expand Up @@ -482,7 +492,9 @@ def dispatch_execute(

# Create another execution context with the new user params, but let's keep the same working dir
with FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(user_space_params=new_user_params))
ctx.with_execution_state(
cast(ExecutionState, ctx.execution_state).with_params(user_space_params=new_user_params)
)
# type: ignore
) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
Expand Down Expand Up @@ -563,7 +575,7 @@ def dispatch_execute(
# After the execute has been successfully completed
return outputs_literal_map

def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore
"""
This is the method that will be invoked directly before executing the task method and before all the inputs
are converted. One particular case where this is useful is if the context is to be modified for the user process
Expand All @@ -581,7 +593,7 @@ def execute(self, **kwargs) -> Any:
"""
pass

def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any:
"""
Post execute is called after the execution has completed, with the user_params and can be used to clean-up,
or alter the outputs to match the intended tasks outputs. If not overridden, then this function is a No-op
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/class_based_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, *args, **kwargs):
def name(self) -> str:
return "ClassStorageTaskResolver"

def get_all_tasks(self) -> List[PythonAutoContainerTask]:
def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type:ignore
return self.mapping

def add(self, t: PythonAutoContainerTask):
Expand All @@ -33,7 +33,7 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask:
idx = int(loader_args[0])
return self.mapping[idx]

def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]:
def loader_args(self, settings: SerializationSettings, t: PythonAutoContainerTask) -> List[str]: # type: ignore
"""
This is responsible for turning an instance of a task into args that the load_task function can reconstitute.
"""
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def end_branch(self) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidP
return self._compute_outputs(n)
return self._condition

def if_(self, expr: bool) -> Case:
def if_(self, expr: Union[ComparisonExpression, ConjunctionExpression]) -> Case:
return self._condition._if(expr)

def compute_output_vars(self) -> typing.Optional[typing.List[str]]:
Expand Down Expand Up @@ -360,7 +360,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str:
return f"{node_id}.{var}"


def merge_promises(*args: Promise) -> typing.List[Promise]:
def merge_promises(*args: Optional[Promise]) -> typing.List[Promise]:
node_vars: typing.Set[typing.Tuple[str, str]] = set()
merged_promises: typing.List[Promise] = []
for p in args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def transform_to_boolexpr(


def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]:
expr, promises = transform_to_boolexpr(c.expr)
expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr))
n = c.output_promise.ref.node # type: ignore
return _core_wf.IfBlock(condition=expr, then_node=n), promises

Expand Down
Loading