Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support checkpointing in local mode from cached tasks #1457

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
# The cache returns None iff the key does not exist in the cache
if outputs_literal_map is None:
logger.info("Cache miss, task will be executed now")
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
# TODO: need `native_inputs`
LocalTaskCache.set(self.name, self.metadata.cache_version, input_literal_map, outputs_literal_map)
logger.info(
Expand All @@ -268,10 +268,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
else:
logger.info("Cache hit")
else:
es = ctx.execution_state
b = es.user_space_params.with_task_sandbox()
ctx = ctx.current_context().with_execution_state(es.with_params(user_space_params=b.build())).build()
outputs_literal_map = self.dispatch_execute(ctx, input_literal_map)
# Keep this call in sync with the `sandbox_execute` call in the cache-miss branch above.
# Repeating the call keeps the code simple and avoids metaprogramming, but changing one
# call site without the other risks introducing a regression.
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map)
outputs_literals = outputs_literal_map.literals

# TODO maybe this is the part that should be done for local execution, we pass the outputs to some special
Expand Down Expand Up @@ -326,6 +326,19 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
"""
return None

def sandbox_execute(
    self,
    ctx: FlyteContext,
    input_literal_map: _literal_models.LiteralMap,
) -> _literal_models.LiteralMap:
    """
    Run ``dispatch_execute`` with a context whose user-space params have the task sandbox applied.

    Intended for local sandbox execution; not invoked during runtime.

    :param ctx: context whose execution state is used to derive the sandboxed context.
    :param input_literal_map: literal inputs, passed through unchanged to ``dispatch_execute``.
    :return: the literal map returned by ``dispatch_execute``.
    """
    exec_state = ctx.execution_state
    sandbox_builder = exec_state.user_space_params.with_task_sandbox()
    base_ctx = ctx.current_context()
    # Swap the sandboxed user-space params into a copy of the execution state.
    sandboxed_state = exec_state.with_params(user_space_params=sandbox_builder.build())
    sandbox_ctx = base_ctx.with_execution_state(sandboxed_state).build()
    return self.dispatch_execute(sandbox_ctx, input_literal_map)

@abstractmethod
def dispatch_execute(
self,
Expand Down
19 changes: 19 additions & 0 deletions tests/flytekit/unit/core/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import flytekit
from flytekit.core.checkpointer import SyncCheckpoint
from flytekit.core.local_cache import LocalTaskCache


def test_sync_checkpoint_write(tmpdir):
Expand Down Expand Up @@ -123,5 +124,23 @@ def t1(n: int) -> int:
return n + 1


@flytekit.task(cache=True, cache_version="v0")
def t2(n: int) -> int:
    """Cached task that writes a checkpoint and returns ``n + 1``."""
    checkpoint = flytekit.current_context().checkpoint
    incremented = n + 1
    # bytes(k) for an int k is k zero bytes, so the checkpoint payload is `incremented` null bytes.
    checkpoint.write(bytes(incremented))
    return incremented


@pytest.fixture(scope="function", autouse=True)
def setup():
    """Initialize and then clear the local task cache before each test function."""
    LocalTaskCache.initialize()
    LocalTaskCache.clear()


def test_checkpoint_task():
    """Calling the checkpointing task ``t1`` locally with ``n=5`` returns 6."""
    result = t1(n=5)
    assert result == 6


def test_checkpoint_cached_task():
    """Calling the cached checkpointing task ``t2`` locally with ``n=5`` returns 6."""
    result = t2(n=5)
    assert result == 6