Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Issue-3617] Enables FlyteFiles, FlyteDirectories, and StructuredDatasets inputs in papermill plugin #1612

Merged
merged 8 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
record_outputs
"""

from .task import NotebookTask, record_outputs
from .task import NotebookTask, read_flytedirectory, read_flytefile, read_structureddataset, record_outputs
65 changes: 61 additions & 4 deletions plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,28 @@
import logging
import os
import sys
import tempfile
import typing
from typing import Any

import nbformat
import papermill as pm
from flyteidl.core.literals_pb2 import Literal as _pb2_Literal
from flyteidl.core.literals_pb2 import LiteralMap as _pb2_LiteralMap
from google.protobuf import text_format as _text_format
from nbconvert import HTMLExporter

from flytekit import FlyteContext, PythonInstanceTask
from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset
from flytekit.configuration import SerializationSettings
from flytekit.core import utils
from flytekit.core.context_manager import ExecutionParameters
from flytekit.deck.deck import Deck
from flytekit.extend import Interface, TaskPlugins, TypeEngine
from flytekit.loggers import logger
from flytekit.models import task as task_models
from flytekit.models.literals import LiteralMap
from flytekit.types.file import HTMLPage, PythonNotebook
from flytekit.models.literals import Literal, LiteralMap
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, HTMLPage, PythonNotebook

T = typing.TypeVar("T")

Expand Down Expand Up @@ -255,6 +259,10 @@ def execute(self, **kwargs) -> Any:
singleton
"""
logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.")
for k, v in kwargs.items():
if isinstance(v, (FlyteFile, FlyteDirectory, StructuredDataset)):
peridotml marked this conversation as resolved.
Show resolved Hide resolved
kwargs[k] = save_literal_to_file(v)

# Execute Notebook via Papermill.
pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore

Expand All @@ -265,6 +273,7 @@ def execute(self, **kwargs) -> Any:
if outputs:
m = outputs.literals
output_list = []

for k, type_v in self.python_interface.outputs.items():
if k == self._IMPLICIT_OP_NOTEBOOK:
output_list.append(self.output_notebook_path)
Expand All @@ -274,7 +283,7 @@ def execute(self, **kwargs) -> Any:
v = TypeEngine.to_python_value(ctx=FlyteContext.current_context(), lv=m[k], expected_python_type=type_v)
output_list.append(v)
else:
raise RuntimeError(f"Expected output {k} of type {v} not found in the notebook outputs")
raise RuntimeError(f"Expected output {k} of type {type_v} not found in the notebook outputs")
peridotml marked this conversation as resolved.
Show resolved Hide resolved

return tuple(output_list)

Expand Down Expand Up @@ -307,3 +316,51 @@ def record_outputs(**kwargs) -> str:
lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected)
m[k] = lit
return LiteralMap(literals=m).to_flyte_idl()


def save_literal_to_file(input: Any) -> str:
    """
    Serialize a Python value to a Flyte literal and write it to a temporary file.

    The literal type is derived from ``type(input)``. The value is converted with
    ``TypeEngine.to_literal`` and its protobuf form is written with
    ``utils.write_proto_to_file``.

    :param input: The Python value to serialize.
    :return: Path of the temporary file that holds the serialized literal. This function
        does not delete the file.
    """
    ctx = FlyteContext.current_context()
    python_type = type(input)
    expected = TypeEngine.to_literal_type(python_type)
    lit = TypeEngine.to_literal(ctx, python_type=python_type, python_val=input, expected=expected)

    # mkstemp creates the file atomically; the deprecated mktemp only returns a name
    # that another process could claim before it is written.
    fd, tmp_file = tempfile.mkstemp(suffix=".bin")
    # Only the path is needed: the proto writer opens the file itself.
    os.close(fd)
    utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file)
    return tmp_file


def read_input(path: str, dtype: typing.Type[T]) -> T:
    """
    Read a Flyte literal from a file and convert it to a Python value.

    If ``path`` is already an instance of ``dtype`` it is returned unchanged, so the same
    call works whether the caller holds a serialized-literal path or the value itself.

    :param path: Path to a file containing a serialized Flyte literal, or a value that is
        already an instance of ``dtype``.
    :param dtype: The Python type to convert the literal to.
    :return: The value of type ``dtype`` described by the literal.
    """
    # isinstance (rather than an exact type comparison) also accepts subclasses of dtype.
    # The ``type`` guard keeps non-class dtypes from raising TypeError inside isinstance.
    if isinstance(dtype, type) and isinstance(path, dtype):
        return path

    proto = utils.load_proto_from_file(_pb2_Literal, path)
    lit = Literal.from_flyte_idl(proto)
    ctx = FlyteContext.current_context()
    return TypeEngine.to_python_value(ctx, lit, dtype)


def read_flytefile(path: str) -> FlyteFile:
    """
    Use this method to read a FlyteFile literal from a file.

    :param path: Path to the file containing the serialized literal.
    :return: The value produced by ``read_input`` for ``dtype=FlyteFile``.
    """
    return read_input(path=path, dtype=FlyteFile)


def read_flytedirectory(path: str) -> FlyteDirectory:
    """
    Use this method to read a FlyteDirectory literal from a file.

    :param path: Path to the file containing the serialized literal.
    :return: The value produced by ``read_input`` for ``dtype=FlyteDirectory``.
    """
    return read_input(path=path, dtype=FlyteDirectory)


def read_structureddataset(path: str) -> StructuredDataset:
    """
    Use this method to read a StructuredDataset literal from a file.

    :param path: Path to the file containing the serialized literal.
    :return: The value produced by ``read_input`` for ``dtype=StructuredDataset``.
    """
    return read_input(path=path, dtype=StructuredDataset)
42 changes: 40 additions & 2 deletions plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import datetime
import os
import tempfile

import pandas as pd
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import kwtypes
from flytekit import StructuredDataset, kwtypes, task
from flytekit.configuration import Image, ImageConfig
from flytekit.types.file import PythonNotebook
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook

from .testdata.datatype import X

Expand Down Expand Up @@ -134,3 +137,38 @@ def test_notebook_pod_task():
nb.get_command(serialization_settings)
== nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"]
)


def test_flyte_types():
    """Run the ``nb-types`` notebook with FlyteFile, FlyteDirectory and StructuredDataset inputs."""

    @task
    def create_file() -> FlyteFile:
        # mkstemp creates the file securely; the deprecated mktemp only returns a name.
        handle, tmp_file = tempfile.mkstemp()
        with os.fdopen(handle, "w") as f:
            f.write("abc")
        return FlyteFile(path=tmp_file)

    @task
    def create_dir() -> FlyteDirectory:
        tmp_dir = tempfile.mkdtemp()
        with open(os.path.join(tmp_dir, "file.txt"), "w") as f:
            f.write("abc")
        return FlyteDirectory(path=tmp_dir)

    @task
    def create_sd() -> StructuredDataset:
        df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
        return StructuredDataset(dataframe=df)

    ff = create_file()
    fd = create_dir()
    sd = create_sd()

    nb_name = "nb-types"
    nb_types = NotebookTask(
        name="test",
        notebook_path=_get_nb_path(nb_name, abs=False),
        inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset),
        outputs=kwtypes(success=bool),
    )
    # execute returns the declared output followed by two further values,
    # which this test unpacks but does not check.
    success, out, render = nb_types.execute(ff=ff, fd=fd, sd=sd)
    assert success is True, "Notebook execution failed"
7 changes: 3 additions & 4 deletions plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
"outputs": [],
"source": [
"from flytekitplugins.papermill import record_outputs\n",
"\n",
"record_outputs(square=out)"
]
},
Expand All @@ -49,7 +48,7 @@
"metadata": {
"celltoolbar": "Tags",
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -63,9 +62,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
100 changes: 100 additions & 0 deletions plugins/flytekit-papermill/tests/testdata/nb-types.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"ff = None\n",
"fd = None\n",
"sd = None"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"from flytekitplugins.papermill import (\n",
" read_flytefile, read_flytedirectory, read_structureddataset,\n",
" record_outputs\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ff = read_flytefile(ff)\n",
"fd = read_flytedirectory(fd)\n",
"sd = read_structureddataset(sd)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read file\n",
"with open(ff.download(), 'r') as f:\n",
" text = f.read()\n",
" assert text == \"abc\", \"Text does not match\"\n",
"\n",
"# check file inside directory\n",
"with open(os.path.join(fd.download(),\"file.txt\"), 'r') as f:\n",
" text = f.read()\n",
" assert text == \"abc\", \"Text does not match\"\n",
"\n",
"# check dataset\n",
"df = sd.open(pd.DataFrame).all()\n",
"expected = pd.DataFrame({\"a\": [1, 2], \"b\": [3, 4]})\n",
"assert df.equals(expected), \"Dataframes do not match\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"outputs"
]
},
"outputs": [],
"source": [
"record_outputs(success=True)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}