Skip to content

Commit

Permalink
Handle partial functions in parallel.Threads (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
grusin-db authored Apr 3, 2024
1 parent ea62287 commit d2ceef7
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies = [
"isort>=2.5.0",
"mypy",
"types-PyYAML",
"types-requests",
"types-requests"
]

python="3.10"
Expand Down
23 changes: 21 additions & 2 deletions src/databricks/labs/blueprint/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,34 @@ def _progress_report(self, _):
logger.info(msg)

@staticmethod
def _wrap_result(func, name):
def _get_result_function_signature(func, name):
    """Build the label that identifies a task in failure log messages.

    A plain callable is reported as ``name``. A ``functools.partial`` is
    reported as ``name(arg, ..., key=value, ...)`` so that several partials
    of the same function can be told apart in the log; a partial without any
    bound arguments is reported as plain ``name``.

    The caller invokes this helper from inside an ``except`` block, so it
    must not raise an ``Exception`` itself: an argument whose ``__repr__``
    fails is rendered as ``<unrepresentable TypeName>``.

    Args:
        func: The task callable.
        name: The label of the task group, used as the signature prefix.

    Returns:
        ``name`` itself, or ``name`` followed by the bound arguments in
        parentheses.
    """
    if not isinstance(func, functools.partial):
        return name

    def _safe_repr(value):
        # A user-defined __repr__ may raise. Falling back to str(func) would
        # not help: the repr of a partial reprs its bound arguments as well,
        # so it would fail in exactly the same way.
        try:
            return repr(value)
        except Exception:  # pylint: disable=broad-exception-caught
            return f"<unrepresentable {type(value).__name__}>"

    # try to build up signature, this should never fail
    try:
        args = [_safe_repr(x) for x in func.args]
        args.extend(f"{k}={_safe_repr(v)}" for (k, v) in func.keywords.items())
        args_str = ", ".join(args)
        if args_str:
            return f"{name}({args_str})"
        return name
    # but if it would ever fail, better return the generic name than mess up
    # the traceback of the task failure even more...
    except Exception:  # pylint: disable=broad-exception-caught
        return name

# Wraps `func` in a closure that reports failures as values instead of raising:
# the closure returns `(value, None)` when `func` succeeds and `(None, err)`
# when `func` raises an Exception. `name` is the task label used in the error
# log line.
@classmethod
def _wrap_result(cls, func, name):
"""This method emulates GoLang's error return style"""

# functools.wraps copies the metadata of `func` onto the wrapper.
@functools.wraps(func)
def inner(*args, **kwargs):
try:
return func(*args, **kwargs), None
except Exception as err: # pylint: disable=broad-exception-caught
# NOTE(review): the next line and the pair after it both log the failure.
# The page header reports 2 deletions for this file, so the next line looks
# like the pre-change version kept by the diff rendering - confirm.
logger.error(f"{name} task failed: {err!s}", exc_info=err)
# For partial tasks the label also carries the bound arguments.
signature = cls._get_result_function_signature(func, name)
logger.error(f"{signature} task failed: {err!s}", exc_info=err)
return None, err

return inner
32 changes: 32 additions & 0 deletions tests/unit/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from functools import partial

from databricks.sdk.core import DatabricksError

Expand Down Expand Up @@ -117,3 +118,34 @@ def works():
assert [True, True, True, True] == results
assert 0 == len(errors)
assert ["Finished 'testing' tasks: 100% results available (4/4)"] == _predictable_messages(caplog)


def test_odd_partial_failed(caplog):
    """Failing ``partial`` tasks are logged together with their bound arguments."""
    caplog.set_level(logging.INFO)

    def fails_on_odd(n=1, dummy=None):
        # String input fails with a different error than an odd number does.
        if isinstance(n, str):
            raise RuntimeError("strings are not supported!")

        if n % 2:
            msg = "failed"
            raise DatabricksError(msg)

    # Every task below fails: three use an odd ``n``, one uses a string.
    keyword_only = partial(fails_on_odd, n=1)
    positional_and_keyword = partial(fails_on_odd, 1, dummy="6")
    no_bound_arguments = partial(fails_on_odd)
    string_argument = partial(fails_on_odd, n="aaa")
    tasks = [
        keyword_only,
        positional_and_keyword,
        no_bound_arguments,
        string_argument,
    ]

    results, errors = Threads.gather("testing", tasks)

    expected_messages = [
        "All 'testing' tasks failed!!!",
        "testing task failed: failed",
        "testing(1, dummy='6') task failed: failed",
        "testing(n='aaa') task failed: strings are not supported!",
        "testing(n=1) task failed: failed",
    ]
    assert [] == results
    assert 4 == len(errors)
    assert expected_messages == _predictable_messages(caplog)

0 comments on commit d2ceef7

Please sign in to comment.