Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve jupyter repr and __repr__ for Flyte models #2647

Merged
40 changes: 37 additions & 3 deletions flytekit/models/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import abc
import json
import re
import os
from contextlib import closing
from io import StringIO
from textwrap import shorten
from typing import Dict

from flyteidl.admin import common_pb2 as _common_pb2
Expand Down Expand Up @@ -40,6 +43,29 @@
pass


def _repr_idl_yaml_like(idl, indent=0) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make this the default repr

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to make it the default, but I saw the current repr strips away the newlines:

literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip()

Was there a reason for making it all one line?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think so, not a good one anyways. feel free to overwrite. the reason likely was logging platforms (cloudwatch logs, splunk, grafana, etc) don't render logs with new lines well. but i don't know how much this could matter for python. that's more of a backend issue.

"""Formats an IDL into a YAML-like representation."""
if not hasattr(idl, "ListFields"):
return str(idl)

Check warning on line 49 in flytekit/models/common.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/common.py#L49

Added line #L49 was not covered by tests

with closing(StringIO()) as out:
for descriptor, field in idl.ListFields():
try:
inner_fields = field.ListFields()
# if inner_fields is empty, then we do not render the descriptor,
# because it is empty
if inner_fields:
out.write(" " * indent + descriptor.name + ":" + os.linesep)
out.write(_repr_idl_yaml_like(field, indent + 2))
except AttributeError:
# No ListFields -> Must be a scalar
str_repr = shorten(str(field).strip(), width=80)
if str_repr:
out.write(" " * indent + descriptor.name + ": " + str_repr + os.linesep)

return out.getvalue()


class FlyteIdlEntity(object, metaclass=FlyteType):
def __eq__(self, other):
return isinstance(other, FlyteIdlEntity) and other.to_flyte_idl() == self.to_flyte_idl()
Expand All @@ -60,9 +86,9 @@
"""
:rtype: Text
"""
literal_str = re.sub(r"\s+", " ", str(self.to_flyte_idl())).strip()
str_repr = _repr_idl_yaml_like(self.to_flyte_idl(), indent=2).rstrip(os.linesep)
type_str = type(self).__name__
return f"[Flyte Serialized object: Type: <{type_str}> Value: <{literal_str}>]"
return f"Flyte Serialized object ({type_str}):" + os.linesep + str_repr
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a few more tests to cover the different types of literals (collections, maps, etc)? These are useful especially in flyteremote when debugging.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added more tests in 21372e4


def verbose_string(self):
"""
Expand All @@ -73,6 +99,14 @@
def serialize_to_string(self) -> str:
return self.to_flyte_idl().SerializeToString()

def _repr_html_(self) -> str:
"""HTML repr for object."""
# `_repr_html_` is used by Jupyter to render objects
type_str = type(self).__name__
idl = self.to_flyte_idl()
str_repr = _repr_idl_yaml_like(idl).rstrip(os.linesep)
return f"<h4>{type_str}</h4><pre>{str_repr}</pre>"

Check warning on line 108 in flytekit/models/common.py

View check run for this annotation

Codecov / codecov/patch

flytekit/models/common.py#L105-L108

Added lines #L105 - L108 were not covered by tests

@property
def is_empty(self):
return len(self.to_flyte_idl().SerializeToString()) == 0
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def wf(i: int, j: int):
# which is incorrect
with pytest.raises(
FlyteAssertion,
match=r"Missing input `i` type `\[Flyte Serialized object: Type: <LiteralType> Value: <simple: INTEGER>\]`",
match=r"Missing input `i` type `Flyte Serialized object \(LiteralType\):",
):
create_and_link_node_from_remote(ctx, lp)

Expand Down
3 changes: 1 addition & 2 deletions tests/flytekit/unit/core/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,8 +1767,7 @@ def wf2(a: typing.Union[int, str]) -> typing.Union[int, str]:
match=re.escape(
"Error encountered while executing 'wf2':\n"
f" Failed to convert inputs of task '{prefix}tests.flytekit.unit.core.test_type_hints.t2':\n"
' Cannot convert from [Flyte Serialized object: Type: <Literal> Value: <scalar { union { value { scalar { primitive { string_value: "2" } } } '
'type { simple: STRING structure { tag: "str" } } } }>] to typing.Union[float, dict] (using tag str)'
r' Cannot convert from Flyte Serialized object (Literal):'
),
):
assert wf2(a="2") == "2"
Expand Down
114 changes: 111 additions & 3 deletions tests/flytekit/unit/models/test_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import datetime
from datetime import timezone, timedelta
import textwrap

from flytekit.models import common as _common
from flytekit.models.core import execution as _execution

from flytekit.models.execution import ExecutionClosure

from flytekit.models.execution import LiteralMapBlob
from flytekit.models.literals import LiteralMap, Scalar, Primitive, Literal, RetryStrategy
from flytekit.models.core.execution import WorkflowExecutionPhase
from flytekit.models.task import TaskMetadata, RuntimeMetadata
from flytekit.models.project import Project


def test_notification_email():
obj = _common.EmailNotification(["a", "b", "c"])
Expand Down Expand Up @@ -106,7 +118,103 @@ def test_auth_role_empty():


def test_short_string_raw_output_data_config():
""""""
obj = _common.RawOutputDataConfig("s3://bucket")
assert "Flyte Serialized object: Type: <RawOutputDataConfig> Value" in obj.short_string()
assert "Flyte Serialized object: Type: <RawOutputDataConfig> Value" in repr(obj)
assert "Flyte Serialized object (RawOutputDataConfig):" in obj.short_string()
assert "Flyte Serialized object (RawOutputDataConfig):" in repr(obj)


def test_html_repr_data_config():
obj = _common.RawOutputDataConfig("s3://bucket")

out = obj._repr_html_()
assert "output_location_prefix: s3://bucket" in out
assert "<h4>RawOutputDataConfig</h4>" in out


def test_short_string_entities_ExecutionClosure():
_OUTPUT_MAP = LiteralMap(
{"b": Literal(scalar=Scalar(primitive=Primitive(integer=2)))}
)

test_datetime = datetime.datetime(year=2022, month=1, day=1, tzinfo=timezone.utc)
test_timedelta = datetime.timedelta(seconds=10)
test_outputs = LiteralMapBlob(values=_OUTPUT_MAP, uri="http://foo/")

obj = ExecutionClosure(
phase=WorkflowExecutionPhase.SUCCEEDED,
started_at=test_datetime,
duration=test_timedelta,
outputs=test_outputs,
created_at=None,
updated_at=test_datetime,
)
expected_result = textwrap.dedent("""\
Flyte Serialized object (ExecutionClosure):
outputs:
uri: http://foo/
phase: 4
started_at:
seconds: 1640995200
duration:
seconds: 10
updated_at:
seconds: 1640995200""")

assert repr(obj) == expected_result
assert obj.short_string() == expected_result


def test_short_string_entities_Primitive():
obj = Primitive(integer=1)
expected_result = textwrap.dedent("""\
Flyte Serialized object (Primitive):
integer: 1""")

assert repr(obj) == expected_result
assert obj.short_string() == expected_result


def test_short_string_entities_TaskMetadata():
obj = TaskMetadata(
True,
RuntimeMetadata(RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
timedelta(days=1),
RetryStrategy(3),
True,
"0.1.1b0",
"This is deprecated!",
True,
"A",
(),
)

expected_result = textwrap.dedent("""\
Flyte Serialized object (TaskMetadata):
discoverable: True
runtime:
type: 1
version: 1.0.0
flavor: python
timeout:
seconds: 86400
retries:
retries: 3
discovery_version: 0.1.1b0
deprecated_error_message: This is deprecated!
interruptible: True
cache_serializable: True
pod_template_name: A""")
assert repr(obj) == expected_result
assert obj.short_string() == expected_result


def test_short_string_entities_Project():
obj = Project("project_id", "project_name", "project_description")
expected_result = textwrap.dedent("""\
Flyte Serialized object (Project):
id: project_id
name: project_name
description: project_description""")

assert repr(obj) == expected_result
assert obj.short_string() == expected_result
Loading