Lazy load modules (#1590)
* lazy load module

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lazy load module

Signed-off-by: Kevin Su <pingsutw@apache.org>

* import

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* keep structured dataset in flytekit init

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* fixed tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* fixed tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* fixed tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* move import pandas to __init__

Signed-off-by: Kevin Su <pingsutw@apache.org>

* use lazy import loader instead

Signed-off-by: Kevin Su <pingsutw@apache.org>

* Fixed tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* Fixed tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* wip

Signed-off-by: Kevin Su <pingsutw@apache.org>

* fix tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* regular import

Signed-off-by: Kevin Su <pingsutw@apache.org>

* fixed tests

Signed-off-by: Kevin Su <pingsutw@apache.org>

* test

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

* nit

Signed-off-by: Kevin Su <pingsutw@apache.org>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

---------

Signed-off-by: Kevin Su <pingsutw@apache.org>
pingsutw authored and eapolinario committed May 16, 2023
1 parent d778232 commit 5a66411
Showing 23 changed files with 259 additions and 142 deletions.
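The thread running through the diffs below is a single pattern: heavy optional dependencies (pandas, pyarrow, kubernetes, tensorflow, torch, sklearn, joblib, docker_image) are no longer imported when flytekit itself is imported. They are either loaded through the new flytekit.lazy_import.lazy_module helper, imported inside the function that needs them, or referenced only under typing.TYPE_CHECKING. The helper's own implementation is not part of the hunks shown on this page, so what follows is only a minimal sketch of how such a helper is typically built on importlib.util.LazyLoader; flytekit's real helper may differ (for example in how it treats missing modules).

# Minimal sketch of a lazy_module-style helper, following the standard
# importlib.util.LazyLoader recipe. Illustration only; the helper actually
# added by this PR (flytekit/lazy_import/lazy_module.py) is not shown in the
# diffs on this page and may behave differently.
import importlib.util
import sys
import types


def lazy_module(name: str) -> types.ModuleType:
    """Return a module whose real import is deferred until first attribute access."""
    if name in sys.modules:
        return sys.modules[name]  # already imported elsewhere: reuse it
    spec = importlib.util.find_spec(name)
    if spec is None or spec.loader is None:
        raise ModuleNotFoundError(f"No module named {name!r}")
    loader = importlib.util.LazyLoader(spec.loader)
    spec.loader = loader
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    loader.exec_module(module)  # with LazyLoader this defers execution
    return module


# Importing is now cheap; the module body only runs on first real use.
json_mod = lazy_module("json")  # "json" stands in for a heavy dependency
print(json_mod.dumps({"lazy": True}))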
5 changes: 3 additions & 2 deletions flytekit/__init__.py
@@ -198,6 +198,8 @@

from rich import traceback

from flytekit.lazy_import.lazy_module import lazy_module

if sys.version_info < (3, 10):
from importlib_metadata import entry_points
else:
@@ -225,7 +227,6 @@
from flytekit.core.workflow import ImperativeWorkflow as Workflow
from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow
from flytekit.deck import Deck
from flytekit.extras import pytorch, sklearn, tensorflow
from flytekit.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models.common import Annotations, AuthRole, Labels
@@ -234,7 +235,7 @@
from flytekit.models.documentation import Description, Documentation, SourceCode
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType
from flytekit.types import directory, file, numpy, schema
from flytekit.types import directory, file
from flytekit.types.structured.structured_dataset import (
StructuredDataset,
StructuredDatasetFormat,
3 changes: 2 additions & 1 deletion flytekit/configuration/__init__.py
@@ -145,7 +145,6 @@

import yaml
from dataclasses_json import dataclass_json
from docker_image import reference

from flytekit.configuration import internal as _internal
from flytekit.configuration.default_images import DefaultImages
@@ -208,6 +207,8 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image
:param Text tag: e.g. somedocker.com/myimage:someversion123
:rtype: Text
"""
from docker_image import reference

if pathlib.Path(tag).is_file():
with open(tag, "r") as f:
image_spec_dict = yaml.safe_load(f)
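Moving from docker_image import reference out of module scope and into look_up_image_info is the simplest deferral used in this PR: a function-local import runs only when the function is actually called, so importing flytekit.configuration no longer requires docker_image up front. A minimal sketch of the pattern; the wrapper function below is illustrative, not flytekit API, and it assumes the docker-image-py Reference.parse call that look_up_image_info relies on.

def parse_image_reference(tag: str):
    # Deferred: docker_image is imported on the first call to this function,
    # not when the enclosing module is imported.
    from docker_image import reference

    return reference.Reference.parse(tag)


if __name__ == "__main__":
    print(parse_image_reference("somedocker.com/myimage:someversion123"))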
3 changes: 2 additions & 1 deletion flytekit/core/base_task.py
@@ -39,7 +39,6 @@
from flytekit.core.tracker import TrackedInstance
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.utils import timeit
from flytekit.deck.deck import Deck
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
@@ -578,6 +577,8 @@ def dispatch_execute(
raise TypeError(msg) from e

if self._disable_deck is False:
from flytekit.deck.deck import Deck

INPUT = "input"
OUTPUT = "output"

2 changes: 1 addition & 1 deletion flytekit/core/container_task.py
@@ -49,7 +49,7 @@ def __init__(
metadata_format: MetadataFormat = MetadataFormat.JSON,
io_strategy: IOStrategy = None,
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
**kwargs,
):
10 changes: 5 additions & 5 deletions flytekit/core/context_manager.py
@@ -27,7 +27,6 @@
from enum import Enum
from typing import Generator, List, Optional, Union

from flytekit.clients import friendly as friendly_client # noqa
from flytekit.configuration import Config, SecretsConfig, SerializationSettings
from flytekit.core import mock_stats, utils
from flytekit.core.checkpointer import Checkpoint, SyncCheckpoint
@@ -39,7 +38,8 @@
from flytekit.models.core import identifier as _identifier

if typing.TYPE_CHECKING:
from flytekit.deck.deck import Deck
from flytekit import Deck
from flytekit.clients import friendly as friendly_client # noqa

# TODO: resolve circular import from flytekit.core.python_auto_container import TaskResolverMixin

@@ -262,7 +262,7 @@ def decks(self) -> typing.List:

@property
def default_deck(self) -> Deck:
from flytekit.deck.deck import Deck
from flytekit import Deck

return Deck("default")

@@ -551,7 +551,7 @@ class FlyteContext(object):

file_access: FileAccessProvider
level: int = 0
flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None
flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None
compilation_state: Optional[CompilationState] = None
execution_state: Optional[ExecutionState] = None
serialization_settings: Optional[SerializationSettings] = None
@@ -660,7 +660,7 @@ class Builder(object):
level: int = 0
compilation_state: Optional[CompilationState] = None
execution_state: Optional[ExecutionState] = None
flyte_client: Optional[friendly_client.SynchronousFlyteClient] = None
flyte_client: Optional["friendly_client.SynchronousFlyteClient"] = None
serialization_settings: Optional[SerializationSettings] = None
in_a_condition: bool = False

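In context_manager.py the friendly client and Deck are now imported only under typing.TYPE_CHECKING, and the annotations that mention them become quoted forward references, so evaluating the class definitions no longer imports those modules at runtime. A self-contained sketch of that pattern, using a hypothetical heavy_sdk dependency in place of the real flytekit modules:

from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy, pyright), never at runtime.
    from heavy_sdk.client import HeavyClient  # hypothetical heavy dependency


@dataclass
class Context:
    # The quoted annotation keeps the runtime free of the import; type
    # checkers still resolve it through the TYPE_CHECKING block above.
    client: Optional["HeavyClient"] = None


ctx = Context()    # works even if heavy_sdk is not installed
print(ctx.client)  # None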
4 changes: 3 additions & 1 deletion flytekit/core/local_cache.py
@@ -1,10 +1,12 @@
from typing import Optional

import joblib
from diskcache import Cache

from flytekit import lazy_module
from flytekit.models.literals import Literal, LiteralCollection, LiteralMap

joblib = lazy_module("joblib")

# Location on the filesystem where serialized objects will be stored
# TODO: read from config
CACHE_LOCATION = "~/.flyte/local-cache"
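In local_cache.py, joblib becomes a module-level lazy proxy bound through lazy_module("joblib"), so importing flytekit.core.local_cache no longer pays joblib's import cost; the dependency is resolved the first time one of its attributes is used. A short usage sketch under that assumption (flytekit at this commit and joblib still have to be installed by the time the attribute is touched):

from flytekit import lazy_module  # re-exported from flytekit/__init__.py in this PR

joblib = lazy_module("joblib")  # cheap: no joblib code runs here


def cache_key(value) -> str:
    # The real joblib import (and its cost) happens on this first attribute
    # access; a missing joblib fails here, not at module import time.
    return joblib.hash(value)


print(cache_key({"a": 1}))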
15 changes: 10 additions & 5 deletions flytekit/core/pod_template.py
@@ -1,22 +1,27 @@
from dataclasses import dataclass
from typing import Dict, Optional

from kubernetes.client.models import V1PodSpec
from typing import TYPE_CHECKING, Dict, Optional

from flytekit.exceptions import user as _user_exceptions

if TYPE_CHECKING:
from kubernetes.client import V1PodSpec

PRIMARY_CONTAINER_DEFAULT_NAME = "primary"


@dataclass
@dataclass(init=True, repr=True, eq=True, frozen=False)
class PodTemplate(object):
"""Custom PodTemplate specification for a Task."""

pod_spec: V1PodSpec = V1PodSpec(containers=[])
pod_spec: Optional["V1PodSpec"] = None
primary_container_name: str = PRIMARY_CONTAINER_DEFAULT_NAME
labels: Optional[Dict[str, str]] = None
annotations: Optional[Dict[str, str]] = None

def __post_init__(self):
if self.pod_spec is None:
from kubernetes.client import V1PodSpec

self.pod_spec = V1PodSpec(containers=[])
if not self.primary_container_name:
raise _user_exceptions.FlyteValidationException("A primary container name cannot be undefined")
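PodTemplate now declares pod_spec as Optional and builds the default V1PodSpec inside __post_init__, so the kubernetes client is imported only when a template is constructed without an explicit pod_spec (type checkers still see it via the TYPE_CHECKING block). A minimal usage sketch; it assumes the kubernetes package is installed wherever the template is actually built:

from flytekit.core.pod_template import PodTemplate

# Importing the module above no longer imports kubernetes; the deferred import
# in __post_init__ runs only when a default pod_spec has to be created.
pt = PodTemplate(labels={"team": "ml"})
print(pt.primary_container_name)   # "primary"
print(type(pt.pod_spec).__name__)  # "V1PodSpec", built as V1PodSpec(containers=[])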
2 changes: 1 addition & 1 deletion flytekit/core/task.py
@@ -94,7 +94,7 @@ def task(
task_resolver: Optional[TaskResolverMixin] = None,
docs: Optional[Documentation] = None,
disable_deck: bool = True,
pod_template: Optional[PodTemplate] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
) -> Union[Callable, PythonFunctionTask]:
"""
44 changes: 40 additions & 4 deletions flytekit/core/type_engine.py
@@ -22,7 +22,6 @@
from google.protobuf.json_format import ParseDict as _ParseDict
from google.protobuf.struct_pb2 import Struct
from marshmallow_enum import EnumField, LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.annotation import FlyteAnnotation
@@ -31,6 +30,7 @@
from flytekit.core.type_helpers import load_type_from_tag
from flytekit.core.utils import timeit
from flytekit.exceptions import user as user_exceptions
from flytekit.lazy_import.lazy_module import is_imported
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import types as _type_models
@@ -329,6 +329,8 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
# https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228
if isinstance(v, EnumField):
v.load_by = LoadDumpOptions.name
from marshmallow_jsonschema import JSONSchema

schema = JSONSchema().dump(s)
except Exception as e:
# https://github.com/lovasoa/marshmallow_dataclass/issues/13
@@ -376,7 +378,7 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict,
# so here we convert it back to the Structured Dataset.
from flytekit import StructuredDataset
from flytekit.types.structured import StructuredDataset

if python_type == StructuredDataset and type(python_val) == dict:
return StructuredDataset(**python_val)
@@ -659,7 +661,8 @@ class TypeEngine(typing.Generic[T]):

_REGISTRY: typing.Dict[type, TypeTransformer[T]] = {}
_RESTRICTED_TYPES: typing.List[type] = []
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer()
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
has_lazy_import = False

@classmethod
def register(
@@ -717,7 +720,7 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:
Step 4:
if v is of type data class, use the dataclass transformer
"""

cls.lazy_import_transformers()
# Step 1
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]
Expand Down Expand Up @@ -759,6 +762,39 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]:

raise ValueError(f"Type {python_type} not supported currently in Flytekit. Please register a new transformer")

@classmethod
def lazy_import_transformers(cls):
"""
Only load the transformers if needed.
"""
if cls.has_lazy_import:
return
cls.has_lazy_import = True
from flytekit.types.structured import (
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
)

if is_imported("tensorflow"):
from flytekit.extras import tensorflow # noqa: F401
if is_imported("torch"):
from flytekit.extras import pytorch # noqa: F401
if is_imported("sklearn"):
from flytekit.extras import sklearn # noqa: F401
if is_imported("pandas"):
try:
from flytekit.types import schema # noqa: F401
except ValueError:
logger.debug("Transformer for pandas is already registered.")
register_pandas_handlers()
if is_imported("pyarrow"):
register_arrow_handlers()
if is_imported("google.cloud.bigquery"):
register_bigquery_handlers()
if is_imported("numpy"):
from flytekit.types import numpy # noqa: F401

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
"""
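TypeEngine.lazy_import_transformers registers only the optional transformers whose backing library the user has already imported, and the has_lazy_import flag makes the whole pass run at most once. The body of is_imported is not shown on this page; the sketch below assumes it amounts to a sys.modules membership test and uses a stand-in registry rather than the real TypeEngine:

import sys
from typing import Callable, Dict


def is_imported(module_name: str) -> bool:
    # Assumption: flytekit's helper in lazy_import/lazy_module.py behaves like
    # a sys.modules membership check.
    return module_name in sys.modules


class Registry:
    _handlers: Dict[str, Callable[[], None]] = {}
    _has_lazy_import = False  # mirrors TypeEngine.has_lazy_import

    @classmethod
    def lazy_import_handlers(cls) -> None:
        if cls._has_lazy_import:
            return  # idempotent: later calls are no-ops
        cls._has_lazy_import = True
        if is_imported("json"):  # "json" stands in for pandas/torch/tensorflow
            cls._handlers["json"] = lambda: print("json handlers registered")
            cls._handlers["json"]()


import json  # noqa: E402  # user code imports the optional dependency itself

Registry.lazy_import_handlers()  # registers handlers because json is loaded
Registry.lazy_import_handlers()  # second call does nothing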
30 changes: 19 additions & 11 deletions flytekit/core/utils.py
@@ -6,15 +6,15 @@
from functools import wraps
from hashlib import sha224 as _sha224
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast

from flyteidl.core import tasks_pb2 as _core_task
from kubernetes.client import ApiClient
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

from flytekit.core.pod_template import PodTemplate
from flytekit.loggers import logger
from flytekit.models import task as _task_model

if TYPE_CHECKING:
from flytekit.models import task as task_models


def _dnsify(value: str) -> str:
@@ -58,8 +58,8 @@ def _dnsify(value: str) -> str:
def _get_container_definition(
image: str,
command: List[str],
args: List[str],
data_loading_config: Optional[_task_models.DataLoadingConfig] = None,
args: Optional[List[str]] = None,
data_loading_config: Optional["task_models.DataLoadingConfig"] = None,
storage_request: Optional[str] = None,
ephemeral_storage_request: Optional[str] = None,
cpu_request: Optional[str] = None,
@@ -71,7 +71,7 @@
gpu_limit: Optional[str] = None,
memory_limit: Optional[str] = None,
environment: Optional[Dict[str, str]] = None,
) -> _task_models.Container:
) -> "task_models.Container":
storage_limit = storage_limit
storage_request = storage_request
ephemeral_storage_limit = ephemeral_storage_limit
Expand All @@ -83,6 +83,9 @@ def _get_container_definition(
memory_limit = memory_limit
memory_request = memory_request

from flytekit.models import task as task_models

# TODO: Use convert_resources_to_resource_model instead of manually fixing the resources.
requests = []
if storage_request:
requests.append(
@@ -133,12 +136,17 @@
)


def _sanitize_resource_name(resource: _task_model.Resources.ResourceEntry) -> str:
def _sanitize_resource_name(resource: "task_models.Resources.ResourceEntry") -> str:
return _core_task.Resources.ResourceName.Name(resource.name).lower().replace("_", "-")


def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_model.Container) -> Dict[str, Any]:
containers = cast(PodTemplate, pod_template).pod_spec.containers
def _serialize_pod_spec(pod_template: "PodTemplate", primary_container: "task_models.Container") -> Dict[str, Any]:
from kubernetes.client import ApiClient, V1PodSpec
from kubernetes.client.models import V1Container, V1EnvVar, V1ResourceRequirements

if pod_template.pod_spec is None:
return {}
containers = cast(V1PodSpec, pod_template.pod_spec).containers
primary_exists = False

for container in containers:
Expand Down Expand Up @@ -173,7 +181,7 @@ def _serialize_pod_spec(pod_template: PodTemplate, primary_container: _task_mode
container.env or []
)
final_containers.append(container)
cast(PodTemplate, pod_template).pod_spec.containers = final_containers
cast(V1PodSpec, pod_template.pod_spec).containers = final_containers

return ApiClient().sanitize_for_serialization(cast(PodTemplate, pod_template).pod_spec)

(Diffs for the remaining changed files are not shown here.)
