Skip to content

Commit

Permalink
Merge branch 'sklearn-extra' of https://github.com/flyteorg/flytekit
Browse files Browse the repository at this point in the history
…into sklearn-extra
  • Loading branch information
cosmicBboy committed Jan 25, 2023
2 parents 58d4ad5 + cb383d3 commit 728b9ab
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 13 deletions.
4 changes: 4 additions & 0 deletions docs/source/clients.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.. automodule:: flytekit.clients
:no-members:
:no-inherited-members:
:no-special-members:
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ Expected output:
flytekit
configuration
remote
clients
testing
extend
deck
Expand Down
6 changes: 6 additions & 0 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,12 @@ def _dispatch_execute(
logger.info(f"Engine folder written successfully to the output prefix {output_prefix}")
logger.debug("Finished _dispatch_execute")

if os.environ.get("FLYTE_FAIL_ON_ERROR", "").lower() == "true" and _constants.ERROR_FILE_NAME in output_file_dict:
# This env is set by the flytepropeller
# AWS batch job get the status from the exit code, so once we catch the error,
# we should return the error code here
exit(1)


def get_one_of(*args) -> str:
"""
Expand Down
19 changes: 19 additions & 0 deletions flytekit/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
=====================
Clients
=====================
.. currentmodule:: flytekit.clients
This module provides lower level access to a Flyte backend.
.. _clients_module:
.. autosummary::
:template: custom.rst
:toctree: generated/
:nosignatures:
~friendly.SynchronousFlyteClient
~raw.RawSynchronousFlyteClient
"""
3 changes: 2 additions & 1 deletion flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,8 @@ def create_and_link_node_from_remote(
extra_inputs = used_inputs ^ set(kwargs.keys())
if len(extra_inputs) > 0:
raise _user_exceptions.FlyteAssertion(
"Too many inputs were specified for the interface. Extra inputs were: {}".format(extra_inputs)
f"Too many inputs for [{entity.name}] Expected inputs: {typed_interface.inputs.keys()} "
f"- extra inputs: {extra_inputs}"
)

# Detect upstream nodes
Expand Down
23 changes: 14 additions & 9 deletions flytekit/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@

# By default, the root flytekit logger to debug so everything is logged, but enable fine-tuning
logger = logging.getLogger("flytekit")
# Root logger control
flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT"
if os.getenv(flytekit_root_env_var) is not None:
logger.setLevel(int(os.getenv(flytekit_root_env_var)))
else:
logger.setLevel(logging.DEBUG)

# Stop propagation so that configuration is isolated to this file (so that it doesn't matter what the
# global Python root logger is set to).
Expand All @@ -40,22 +34,33 @@

# create console handler
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)

# Root logger control
# Don't want to import the configuration library since that will cause all sorts of circular imports, let's
# just use the environment variable if it's defined. Decide in the future when we implement better controls
# if we should control with the channel or with the logger level.
# The handler log level controls whether log statements will actually print to the screen
flytekit_root_env_var = f"{LOGGING_ENV_VAR}_ROOT"
level_from_env = os.getenv(LOGGING_ENV_VAR)
if level_from_env is not None:
ch.setLevel(int(level_from_env))
root_level_from_env = os.getenv(flytekit_root_env_var)
if root_level_from_env is not None:
logger.setLevel(int(root_level_from_env))
elif level_from_env is not None:
logger.setLevel(int(level_from_env))
else:
ch.setLevel(logging.WARNING)
logger.setLevel(logging.WARNING)

for log_name, child_logger in child_loggers.items():
env_var = f"{LOGGING_ENV_VAR}_{log_name.upper()}"
level_from_env = os.getenv(env_var)
if level_from_env is not None:
child_logger.setLevel(int(level_from_env))
else:
if child_logger is user_space_logger:
child_logger.setLevel(logging.INFO)
else:
child_logger.setLevel(logging.WARNING)

# create formatter
formatter = jsonlogger.JsonFormatter(fmt="%(asctime)s %(name)s %(levelname)s %(message)s")
Expand Down
7 changes: 6 additions & 1 deletion flytekit/remote/lazy_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ def entity(self) -> T:
"""
with self._mutex:
if self._entity is None:
self._entity = self._getter()
try:
self._entity = self._getter()
except AttributeError as e:
raise RuntimeError(
f"Error downloading the entity {self._name}, (check original exception...)"
) from e
return self._entity

def __getattr__(self, item: str) -> typing.Any:
Expand Down
4 changes: 2 additions & 2 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
NotificationList,
WorkflowExecutionGetDataResponse,
)
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteWorkflow
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from flytekit.remote.interface import TypedInterface
from flytekit.remote.lazy_entity import LazyEntity
Expand Down Expand Up @@ -1460,7 +1460,7 @@ def sync_execution(
upstream_nodes=[],
bindings=[],
metadata=NodeMetadata(name=""),
flyte_task=flyte_entity,
task_node=FlyteTaskNode(flyte_entity),
)
}
if len(task_node_exec) >= 1
Expand Down
1 change: 1 addition & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def test_fetch_execute_task_convert_dict(flyteclient, flyte_workflows_register):
flyte_task = remote.fetch_task(name="workflows.basic.dict_str_wf.convert_to_string", version=f"v{VERSION}")
d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"}
execution = remote.execute(flyte_task, {"d": d}, wait=True)
remote.sync_execution(execution, sync_nodes=True)
assert json.loads(execution.outputs["o0"]) == {"key1": "value1", "key2": "value2"}


Expand Down
32 changes: 32 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import OrderedDict

import mock
import pytest
from flyteidl.core.errors_pb2 import ErrorDocument

from flytekit.bin.entrypoint import _dispatch_execute, normalize_inputs, setup_execution
Expand Down Expand Up @@ -110,6 +111,37 @@ def verify_output(*args, **kwargs):
assert mock_write_to_file.call_count == 1


@mock.patch.dict(os.environ, {"FLYTE_FAIL_ON_ERROR": "True"})
@mock.patch("flytekit.core.utils.load_proto_from_file")
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.get_data")
@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data")
@mock.patch("flytekit.core.utils.write_proto_to_file")
def test_dispatch_execute_return_error_code(mock_write_to_file, mock_upload_dir, mock_get_data, mock_load_proto):
mock_get_data.return_value = True
mock_upload_dir.return_value = True

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION)
)
) as ctx:
python_task = mock.MagicMock()
python_task.dispatch_execute.side_effect = Exception("random")

empty_literal_map = _literal_models.LiteralMap({}).to_flyte_idl()
mock_load_proto.return_value = empty_literal_map

def verify_output(*args, **kwargs):
assert isinstance(args[0], ErrorDocument)

mock_write_to_file.side_effect = verify_output

with pytest.raises(SystemExit) as cm:
_dispatch_execute(ctx, python_task, "inputs path", "outputs prefix")
pytest.assertEqual(cm.value.code, 1)


# This function collects outputs instead of writing them to a file.
# See flytekit.core.utils.write_proto_to_file for the original
def get_output_collector(results: OrderedDict):
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/remote/test_lazy_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,16 @@ def _getter():
e.compile(ctx)
assert e._entity is not None
assert e.entity == dummy_task


def test_lazy_loading_exception():
def _getter():
raise AttributeError("Error")

e = LazyEntity("x", _getter)
assert e.name == "x"
assert e._entity is None
with pytest.raises(RuntimeError) as exc:
assert e.blah

assert isinstance(exc.value.__cause__, AttributeError)

0 comments on commit 728b9ab

Please sign in to comment.