Skip to content

Commit

Permalink
Run remote Launchplan from pyflyte run (flyteorg#1785)
Browse files Browse the repository at this point in the history
* Beautified pyflyte run for every task and workflow

- identify a task or a workflow
- task or workflow help menus show types and use rich to beautify

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* one more improvement

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* updated

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* updated command

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Updated

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* updated formatting

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* updated

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* updated

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* bug fixed in types

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Updated

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* lint

Signed-off-by: Kevin Su <pingsutw@apache.org>

---------

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>
Signed-off-by: Kevin Su <pingsutw@apache.org>
Co-authored-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
2 people authored and hhcs9527 committed Sep 9, 2023
1 parent 5a103d2 commit fd49c2b
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 34 deletions.
234 changes: 202 additions & 32 deletions flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import yaml
from dataclasses_json import DataClassJsonMixin
from pytimeparse import parse
from rich.progress import Progress
from typing_extensions import get_args

from flytekit import BlobType, Literal, Scalar
Expand All @@ -42,11 +43,11 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase
from flytekit.models import literals
from flytekit.models.interface import Variable
from flytekit.models.interface import Parameter, Variable
from flytekit.models.literals import Blob, BlobMetadata, LiteralCollection, LiteralMap, Primitive, Union
from flytekit.models.types import LiteralType, SimpleType
from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.remote.remote import FlyteRemote
from flytekit.tools import module_loader, script_mode
from flytekit.tools.script_mode import _find_project_root
from flytekit.tools.translator import Options
Expand Down Expand Up @@ -115,15 +116,13 @@ class PickleParamType(click.ParamType):
def convert(
    self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context]
) -> typing.Any:
    """
    Pickle ``value`` into a freshly generated local path and return a FileParam pointing at the resolved
    location of that file.
    """
    file_access = FlyteContextManager.current_context().file_access
    local_path = file_access.get_random_local_path()
    with open(local_path, "w+b") as pickle_file:
        cloudpickle.dump(value, pickle_file)
    resolved_path = pathlib.Path(local_path).resolve()
    return FileParam(filepath=str(resolved_path))


class DateTimeType(click.DateTime):

_NOW_FMT = "now"
_ADDITONAL_FORMATS = [_NOW_FMT]

Expand Down Expand Up @@ -458,6 +457,7 @@ def to_click_option(
python_type: typing.Type,
default_val: typing.Any,
get_upload_url_fn: typing.Callable,
required: bool,
) -> click.Option:
"""
This handles converting workflow input types to supported click parameters with callbacks to initialize
Expand All @@ -470,21 +470,24 @@ def to_click_option(
if literal_converter.is_bool() and not default_val:
default_val = False

description_extra = ""
if literal_var.type.simple == SimpleType.STRUCT:
if default_val:
if type(default_val) == dict or type(default_val) == list:
default_val = json.dumps(default_val)
else:
default_val = cast(DataClassJsonMixin, default_val).to_json()
if literal_var.type.metadata:
description_extra = f": {json.dumps(literal_var.type.metadata)}"

return click.Option(
param_decls=[f"--{input_name}"],
type=literal_converter.click_type,
is_flag=literal_converter.is_bool(),
default=default_val,
show_default=True,
required=default_val is None,
help=literal_var.description,
required=required,
help=literal_var.description + description_extra,
callback=literal_converter.convert,
)

Expand Down Expand Up @@ -592,6 +595,13 @@ def get_workflow_command_base_params() -> typing.List[click.Option]:
type=str,
help="Tags to set for the execution",
),
click.Option(
param_decls=["--limit", "limit"],
required=False,
type=int,
default=10,
help="Use this to limit number of launch plans retreived from the backend, if `from-server` option is used",
),
]


Expand Down Expand Up @@ -662,12 +672,59 @@ def get_entities_in_file(filename: pathlib.Path, should_delete: bool) -> Entitie
return Entities(workflows, tasks)


def run_remote(
    ctx: click.Context,
    remote: FlyteRemote,
    entity: typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan],
    project: str,
    domain: str,
    inputs: typing.Dict[str, typing.Any],
    run_level_params: typing.Dict[str, typing.Any],
    type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
):
    """
    Execute the given remote entity (FlyteLaunchPlan, FlyteWorkflow or FlyteTask) through ``remote``, print the
    console URL of the resulting execution, optionally dump a FlyteRemote snippet, and finally remove the file
    recorded under CTX_FILE_NAME in the run level params, if one is set.
    """
    # Options are attached to the execution only, so that registering a duplicate workflow does not fail. The
    # assumption is that the user wants to override the service account just for this execution.
    sa_override = run_level_params.get("service_account")
    options = Options.default_from(k8s_service_account=sa_override) if sa_override else None

    execution = remote.execute(
        entity,
        inputs=inputs,
        project=project,
        domain=domain,
        name=run_level_params.get("name"),
        wait=run_level_params.get("wait_execution"),
        options=options,
        type_hints=type_hints,
        overwrite_cache=run_level_params.get("overwrite_cache"),
        envs=run_level_params.get("envs"),
        tags=run_level_params.get("tag"),
    )

    click.secho(f"Go to {remote.generate_console_url(execution)} to see execution in the console.")

    if run_level_params.get("dump_snippet"):
        dump_flyte_remote_snippet(execution, project, domain)

    # The file name is read from the click context object, not from the ``run_level_params`` argument.
    recorded_file = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME)
    if recorded_file:
        os.remove(recorded_file)


def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]):
"""
Returns a function that is used to implement WorkflowCommand and execute a flyte workflow.
"""

def _run(*args, **kwargs):
"""
Click command function that is used to execute a flyte workflow from the given entity in the file.
"""
# By the time we get to this function, all the loading has already happened

run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY]
Expand Down Expand Up @@ -703,37 +760,137 @@ def _run(*args, **kwargs):
copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL),
)

options = None
service_account = run_level_params.get("service_account")
if service_account:
# options are only passed for the execution. This is to prevent errors when registering a duplicate workflow
# It is assumed that the users expectations is to override the service account only for the execution
options = Options.default_from(k8s_service_account=service_account)

execution = remote.execute(
run_remote(
ctx,
remote,
remote_entity,
inputs=inputs,
project=project,
domain=domain,
name=run_level_params.get("name"),
wait=run_level_params.get("wait_execution"),
options=options,
project,
domain,
inputs,
run_level_params,
type_hints=entity.python_interface.inputs,
overwrite_cache=run_level_params.get("overwrite_cache"),
envs=run_level_params.get("envs"),
tags=run_level_params.get("tag"),
)

console_url = remote.generate_console_url(execution)
click.secho(f"Go to {console_url} to see execution in the console.")
return _run

if run_level_params.get("dump_snippet"):
dump_flyte_remote_snippet(execution, project, domain)

if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME):
os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME))
class DynamicLaunchPlanCommand(click.RichCommand):
"""
This is a dynamic command that is created for each launch plan. This is used to execute a launch plan.
It will fetch the launch plan from remote and create parameters from all the inputs of the launch plan.
"""

return _run
def __init__(self, name: str, h: str, lp_name: str, **kwargs):
    """
    ``h`` is forwarded to the parent command as its help text; ``lp_name`` is the name of the launch plan this
    command fetches on demand. Any extra keyword arguments go to the parent constructor unchanged.
    """
    super().__init__(name=name, help=h, **kwargs)
    # The launch plan itself starts out unset and is filled in by the first fetch.
    self._lp = None
    self._lp_name = lp_name

def _fetch_launch_plan(self, ctx: click.Context) -> FlyteLaunchPlan:
    """
    Return the launch plan for this command. The first call fetches it from the remote using the project and
    domain stored in the run level params; later calls reuse the stored object.
    """
    if not self._lp:
        project = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT)
        domain = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN)
        remote = get_and_save_remote_with_click_context(ctx, project, domain)
        self._lp = remote.fetch_launch_plan(project, domain, self._lp_name)
    return self._lp

def _get_params(
    self,
    ctx: click.Context,
    inputs: typing.Dict[str, Variable],
    native_inputs: typing.Dict[str, type],
    fixed: typing.Dict[str, Literal],
    defaults: typing.Dict[str, Parameter],
) -> typing.List["click.Parameter"]:
    """
    Build one click option per launch plan input.

    Inputs named in ``fixed`` get no option at all. An input named in ``defaults`` becomes an optional option;
    every other input is required. No click-level default value is passed for any option.
    """
    project = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT)
    domain = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN)
    remote = get_and_save_remote_with_click_context(ctx, project, domain)
    upload_url_fn = functools.partial(remote.client.get_upload_signed_url, project=project, domain=domain)
    flyte_ctx = context_manager.FlyteContextManager.current_context()
    return [
        to_click_option(
            ctx,
            flyte_ctx,
            input_name,
            variable,
            native_inputs[input_name],
            None,
            upload_url_fn,
            not (defaults and input_name in defaults),
        )
        for input_name, variable in inputs.items()
        if not (fixed and input_name in fixed)
    ]

def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]:
    """
    Populate ``self.params`` from the launch plan's inputs the first time it is found empty, then defer to the
    parent implementation.
    """
    if self.params:
        return super().get_params(ctx)

    self.params = []
    launch_plan = self._fetch_launch_plan(ctx)
    # Only build options when the launch plan exposes an interface with at least one input.
    if launch_plan.interface and launch_plan.interface.inputs:
        guessed_types = TypeEngine.guess_python_types(launch_plan.interface.inputs)
        self.params = self._get_params(
            ctx,
            launch_plan.interface.inputs,
            guessed_types,
            launch_plan.fixed_inputs.literals,
            launch_plan.default_inputs.parameters,
        )
    return super().get_params(ctx)

def invoke(self, ctx: click.Context) -> typing.Any:
    """
    Execute the launch plan on the remote using the values supplied on the command line.

    Default or None values should be ignored. Only values that are provided by the user should be passed to the
    remote execution.
    """
    project = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT)
    domain = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN)
    r = get_and_save_remote_with_click_context(ctx, project, domain)
    lp = self._fetch_launch_plan(ctx)
    # Parameters that carry a None value are dropped so they are not sent as explicit inputs; this enforces the
    # contract stated in the docstring instead of forwarding ``ctx.params`` wholesale.
    # NOTE(review): presumably the backend then applies the launch plan's own defaults - confirm.
    provided_inputs = {name: value for name, value in ctx.params.items() if value is not None}
    run_remote(
        ctx,
        r,
        lp,
        project,
        domain,
        provided_inputs,
        ctx.obj[RUN_LEVEL_PARAMS_KEY],
        type_hints=lp.python_interface.inputs if lp.python_interface else None,
    )


class RemoteLaunchPlanGroup(click.RichGroup):
    """
    click multicommand that retrieves launchplans from a remote flyte instance and executes them.
    """

    # Name under which this group is exposed as a sub-command.
    COMMAND_NAME = "remote-launchplan"

    def __init__(self):
        super().__init__(
            # Keep the group's own name identical to COMMAND_NAME so that there is a single spelling of the
            # command instead of two diverging ones.
            name=self.COMMAND_NAME,
            help="Retrieve launchplans from a remote flyte instance and execute them.",
            params=[
                click.Option(
                    ["--limit"], help="Limit the number of launchplans to retrieve.", default=10, show_default=True
                )
            ],
        )
        # Names of the launch plans found on the remote; filled by the first successful listing.
        self._lps = []

    def list_commands(self, ctx):
        """
        Return the names of the launch plans available on the remote, one sub-command per launch plan.

        A non-empty result is stored and returned on subsequent calls. When the context carries no ``obj``
        the (empty) stored list is returned without contacting the remote.
        """
        if self._lps:
            return self._lps
        if ctx.obj is None:
            return self._lps
        run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY]
        project = run_level_params.get(CTX_PROJECT)
        domain = run_level_params.get(CTX_DOMAIN)
        # NOTE(review): the limit is read from the run level params; the ``--limit`` option declared on this
        # group is not consulted here - confirm which of the two is meant to win.
        limit = run_level_params.get("limit")
        r = get_and_save_remote_with_click_context(ctx, project, domain)
        progress = Progress(transient=True)
        task = progress.add_task(f"[cyan]Gathering [{limit}] remote LaunchPlans...", total=None)
        with progress:
            progress.start_task(task)
            lps = r.client.list_launch_plan_ids_paginated(project=project, domain=domain, limit=limit)
            # Only the first element of the returned value is used: the launch plan identifiers.
            self._lps = [lp_id.name for lp_id in lps[0]]
        return self._lps

    def get_command(self, ctx, name):
        """
        Build the command for the launch plan called ``name``; nothing is fetched from the remote here.
        """
        return DynamicLaunchPlanCommand(name=name, h="Execute a launchplan from remote.", lp_name=name)


class WorkflowCommand(click.RichGroup):
Expand All @@ -756,6 +913,8 @@ def __init__(self, filename: str, *args, **kwargs):
self._entities = None

def list_commands(self, ctx):
    """
    Return every entity found in the file, loading the entities only when none are stored yet.
    """
    if not self._entities:
        self._entities = get_entities_in_file(self._filename, self._should_delete)
    return self._entities.all()
Expand Down Expand Up @@ -806,8 +965,11 @@ def get_command(self, ctx, exe_entity):
for input_name, input_type_val in entity.python_interface.inputs_with_defaults.items():
literal_var = entity.interface.inputs.get(input_name)
python_type, default_val = input_type_val
required = default_val is None
params.append(
to_click_option(ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn)
to_click_option(
ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn, required
)
)

entity_type = "Workflow" if is_workflow else "Task"
Expand All @@ -831,13 +993,21 @@ class RunCommand(click.RichGroup):
def __init__(self, *args, **kwargs):
    """
    Set up the group with the shared workflow command options and an empty list of command names.
    """
    super().__init__(*args, params=get_workflow_command_base_params(), **kwargs)
    # Filled by the first call to list_commands.
    self._files = []

def list_commands(self, ctx):
    """
    Return the available sub-commands: every ``*.py`` file in the current directory except ``__init__.py``,
    followed by the remote launch plan command. A non-empty result is stored and reused on later calls.
    """
    # The former unconditional ``return`` of the bare file list is gone: it made the stored-result logic
    # below unreachable and left the remote launch plan command out of the listing.
    if self._files:
        return self._files
    local_files = [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"]
    self._files = local_files + [RemoteLaunchPlanGroup.COMMAND_NAME]
    return self._files

def get_command(self, ctx, filename):
    """
    Resolve ``filename`` to a command: the remote launch plan group when it equals that group's command name,
    otherwise a WorkflowCommand for the file. When the context has an ``obj``, this group's parsed parameters
    are stored on it first as the run level params.
    """
    if ctx.obj:
        ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params
    if filename != RemoteLaunchPlanGroup.COMMAND_NAME:
        return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}")
    return RemoteLaunchPlanGroup()


Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,11 @@ def to_literal(
def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
return expected_python_type(lv.scalar.primitive.string_value) # type: ignore

def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
    """
    Build an Enum class named ``DynamicEnum`` from the values of an enum literal type. Each value serves as
    both the member name (formatted as a string) and the member value.

    Raises ValueError when ``literal_type`` carries no enum type.
    """
    if not literal_type.enum_type:
        raise ValueError(f"Enum transformer cannot reverse {literal_type}")
    members = {f"{value}": value for value in literal_type.enum_type.values}
    return enum.Enum("DynamicEnum", members)  # type: ignore


def convert_json_schema_to_python_class(schema: Dict[str, Any], schema_name: str) -> Type[Any]:
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def build_image(self, img):
ImageBuildEngine.register("test", TestImageSpecBuilder())

@task
def a():
def tk():
...

mock_click_ctx = mock.MagicMock()
Expand Down Expand Up @@ -354,7 +354,7 @@ def check_image(*args, **kwargs):

mock_remote.register_script.side_effect = check_image

run_command(mock_click_ctx, a)()
run_command(mock_click_ctx, tk)()


def test_file_param():
Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,9 @@ def test_enum_type():
assert t.enum_type.values
assert t.enum_type.values == [c.value for c in Color]

g = TypeEngine.guess_python_type(t)
assert [e.value for e in g] == [e.value for e in Color]

ctx = FlyteContextManager.current_context()
lv = TypeEngine.to_literal(ctx, Color.RED, Color, TypeEngine.to_literal_type(Color))
assert lv
Expand Down

0 comments on commit fd49c2b

Please sign in to comment.