diff --git a/spine_engine/jumpster.py b/spine_engine/jumpster.py index ce55e09..117437d 100644 --- a/spine_engine/jumpster.py +++ b/spine_engine/jumpster.py @@ -65,7 +65,7 @@ def __init__(self, event_type, item_name, direction, **kwargs): class Failure(BaseException): - """A failure""" + """Item execution stopped by failure or user interruption.""" class Output: @@ -80,7 +80,7 @@ def __init__(self, item_finish_state): class JumpsterThreadError(Exception): """An exception has occurred in one or more of the threads jumpster manages. - This error forwards the message and stack trace for all of the collected errors. + This error forwards the message and stack trace for all the collected errors. """ def __init__(self, *args, **kwargs): @@ -251,7 +251,6 @@ def execute(self): empty_iters.append(key) except StopIteration: empty_iters.append(key) - # TODO: Anything about loops? # clear and mark complete finished iterators for key in empty_iters: del active_iters[key] @@ -259,9 +258,7 @@ def execute(self): if errs: raise JumpsterThreadError( "During multithread execution errors occurred in threads:\n{error_list}".format( - error_list="\n".join( - ["In thread {tid}: {err}".format(tid=tid, err=err.to_string()) for tid, err in errs.items()] - ) + error_list="\n".join([f"In thread {tid}: {err.to_string()}" for tid, err in errs.items()]) ), thread_error_infos=list(errs.values()), ) @@ -338,6 +335,7 @@ def execute_step_in_thread(step, errors): Args: step (Step): step to execute + errors (dict): mapping from thread id to error info """ event_queue = queue.Queue() thread = threading.Thread(target=_do_execute_step_in_thread, args=(event_queue, step)) @@ -394,4 +392,4 @@ def _do_execute_step_in_thread(event_queue, step): ThreadSystemErrorEvent(tid=tid, error_info=serializable_error_info_from_exc_info(sys.exc_info())) ) except Failure: - pass + event_queue.put(JumpsterEvent(JumpsterEventType.STEP_FAILURE, step.item_name, step.direction)) diff --git a/spine_engine/project_item/connection.py b/spine_engine/project_item/connection.py index 64d071c..0f28bed 100644 --- a/spine_engine/project_item/connection.py +++ b/spine_engine/project_item/connection.py @@ -589,8 +589,10 @@ def from_dict(cls, connection_dict): class Jump(ConnectionBase): """Represents a conditional jump between two project items.""" + _DEFAULT_CONDITION = {"type": "python-script", "script": "exit(1)", "specification": ""} + def __init__( - self, source_name, source_position, destination_name, destination_position, condition={}, cmd_line_args=() + self, source_name, source_position, destination_name, destination_position, condition=None, cmd_line_args=() ): """ Args: @@ -598,11 +600,11 @@ def __init__( source_position (str): source anchor's position destination_name (str): destination project item's name destination_position (str): destination anchor's position - condition (dict): jump condition + condition (dict, optional): jump condition + cmd_line_args (Iterable of str): command line arguments """ super().__init__(source_name, source_position, destination_name, destination_position) - default_condition = {"type": "python-script", "script": "exit(1)", "specification": ""} - self.condition = condition if condition else default_condition + self.condition = condition if condition is not None else self._DEFAULT_CONDITION self._resources_from_source = set() self._resources_from_destination = set() self.cmd_line_args = list(cmd_line_args) diff --git a/spine_engine/project_item/executable_item_base.py b/spine_engine/project_item/executable_item_base.py index 272ff0d..878bf1a 100644 --- a/spine_engine/project_item/executable_item_base.py +++ b/spine_engine/project_item/executable_item_base.py @@ -10,10 +10,7 @@ # this program. If not, see . ###################################################################################################################### -""" -Contains ExecutableItem, a project item's counterpart in execution as well as support utilities. - -""" +""" Contains ExecutableItem, a project item's counterpart in execution as well as support utilities. """ from hashlib import sha1 from pathlib import Path from ..utils.helpers import ExecutionDirection, ItemExecutionFinishState, shorten @@ -28,6 +25,7 @@ def __init__(self, name, project_dir, logger, group_id=None): name (str): item's name project_dir (str): absolute path to project directory logger (LoggerInterface): a logger + group_id (str, optional): execution group identifier """ self._name = name self._project_dir = project_dir diff --git a/spine_engine/spine_engine.py b/spine_engine/spine_engine.py index 8fe6b61..68d5175 100644 --- a/spine_engine/spine_engine.py +++ b/spine_engine/spine_engine.py @@ -101,7 +101,6 @@ def __init__( Raises: EngineInitFailed: Raised if initialization fails """ - super().__init__() self._queue = mp.Queue() if items is None: items = {} @@ -109,7 +108,7 @@ def __init__( if execution_permits is None: execution_permits = {} self._execution_permits = execution_permits - connections = list(map(Connection.from_dict, connections)) # Deserialize connections + connections = list(map(Connection.from_dict, connections)) project_item_loader = ProjectItemLoader() self._executable_item_classes = project_item_loader.load_executable_item_classes(items_module_name) required_items = required_items_for_execution( @@ -137,8 +136,11 @@ def __init__( self._item_names = list(self._dag) # Names of permitted items and their neighbors if jumps is None: jumps = [] - self._jumps = list(map(Jump.from_dict, jumps)) - validate_jumps(self._jumps, self._dag) + else: + jumps = list(map(Jump.from_dict, jumps)) + items_by_jump = _get_items_by_jump(jumps, self._dag) + self._jumps = filter_unneeded_jumps(jumps, items_by_jump, execution_permits) + validate_jumps(self._jumps, items_by_jump, self._dag) for x in self._connections + self._jumps: x.make_logger(self._queue) for x in self._jumps: @@ -275,10 +277,11 @@ def _get_event_stream(self): self._thread.start() while True: msg = self._queue.get() - yield msg if msg[0] == "dag_exec_finished": break + yield msg self._thread.join() + yield msg def answer_prompt(self, prompter_id, answer): """Answers the prompt for the specified prompter id.""" @@ -760,14 +763,25 @@ def _validate_dag(dag): raise EngineInitFailed("DAG contains unconnected items.") -def validate_jumps(jumps, dag): +def filter_unneeded_jumps(jumps, items_by_jump, execution_permits): + """Drops jumps whose items are not going to be executed. + + Args: + jumps (Iterable of Jump): jumps to filter + items_by_jump (dict): mapping from jump to list of item names + execution_permits (dict): mapping from item name to boolean telling if its is permitted to execute + """ + return [jump for jump in jumps if all(execution_permits[item] for item in items_by_jump[jump])] + + +def validate_jumps(jumps, items_by_jump, dag): """Raises an exception in case jumps are not valid. Args: jumps (list of Jump): jumps + items_by_jump (dict): mapping from jump to list of item names dag (DiGraph): jumps' DAG """ - items_by_jump = _get_items_by_jump(jumps, dag) for jump in jumps: validate_single_jump(jump, jumps, dag, items_by_jump) @@ -777,7 +791,7 @@ def validate_single_jump(jump, jumps, dag, items_by_jump=None): Args: jump (Jump): the jump to check - jumps (list of Jump): all jumps in dag + jumps (list of Jump): all jumps in DAG dag (DiGraph): jumps' DAG items_by_jump (dict, optional): mapping jumps to a set of items in between destination and source """ diff --git a/spine_engine/utils/helpers.py b/spine_engine/utils/helpers.py index 1d95a1d..1d840de 100644 --- a/spine_engine/utils/helpers.py +++ b/spine_engine/utils/helpers.py @@ -405,7 +405,7 @@ def make_connections(connections, permitted_items): list of Connection: List of permitted Connections or an empty list if the DAG contains no connections """ if not connections: - return list() + return [] connections = connections_to_selected_items(connections, permitted_items) return connections @@ -433,10 +433,10 @@ def dag_edges(connections): Returns: dict: DAG edges. Mapping of source item (node) to a list of destination items (nodes) """ - edges = dict() + edges = {} for connection in connections: source, destination = connection.source, connection.destination - edges.setdefault(source, list()).append(destination) + edges.setdefault(source, []).append(destination) return edges diff --git a/tests/project_item/test_connection.py b/tests/project_item/test_connection.py index a98bda5..4d1f812 100644 --- a/tests/project_item/test_connection.py +++ b/tests/project_item/test_connection.py @@ -11,11 +11,12 @@ ###################################################################################################################### """ Uni tests for the ``connection`` module. """ import os.path +import pathlib from tempfile import TemporaryDirectory import unittest from unittest.mock import Mock from spine_engine.project_item.connection import Connection, FilterSettings, Jump -from spine_engine.project_item.project_item_resource import database_resource +from spine_engine.project_item.project_item_resource import LabelArg, database_resource, file_resource from spinedb_api import DatabaseMapping, import_alternatives, import_entity_classes, import_scenarios from spinedb_api.filters.scenario_filter import SCENARIO_FILTER_TYPE @@ -169,6 +170,31 @@ def test_counter_passed_to_condition(self): jump.make_logger(Mock()) self.assertTrue(jump.is_condition_true(23)) + def test_command_line_args_with_whitespace_are_not_broken_into_tokens(self): + # Curiously, this test fails when run under PyCharm's debugger. + with TemporaryDirectory() as temp_dir: + path = pathlib.Path(temp_dir) / "path with spaces" / "file name.txt" + condition = { + "type": "python-script", + "script": "\n".join( + ( + "from pathlib import Path", + "import sys", + "if len(sys.argv) != 3:", + " exit(1)", + f"expected_path = Path(r'{str(path)}').resolve()", + "if Path(sys.argv[1]).resolve() != expected_path:", + " exit(1)", + "exit(0 if int(sys.argv[2]) == 23 else 1)", + ) + ), + } + jump = Jump("source", "bottom", "destination", "top", condition, [LabelArg("arg_label")]) + resource = file_resource("provider: unit test", str(path), "arg_label") + jump.receive_resources_from_source([resource]) + jump.make_logger(Mock()) + self.assertTrue(jump.is_condition_true(23)) + def test_dictionary(self): jump = Jump("source", "bottom", "destination", "top", {"type": "python-script", "script": "exit(23)"}) jump_dict = jump.to_dict() diff --git a/tests/test_spine_engine.py b/tests/test_spine_engine.py index 6228476..44310dc 100644 --- a/tests/test_spine_engine.py +++ b/tests/test_spine_engine.py @@ -16,6 +16,7 @@ Inspired from tests for spinetoolbox.ExecutionInstance and spinetoolbox.ResourceMap, and intended to supersede them. """ +from functools import partial import os.path import sys from tempfile import TemporaryDirectory @@ -43,6 +44,11 @@ class TestSpineEngine(unittest.TestCase): "script": "\n".join(["import sys", "loop_counter = int(sys.argv[1])", "exit(0 if loop_counter < 2 else 1)"]), } + _LOOP_FOREVER = { + "type": "python-script", + "script": "\n".join(["exit(0)"]), + } + @staticmethod def _mock_item( name, resources_forward=None, resources_backward=None, execute_outcome=ItemExecutionFinishState.SUCCESS @@ -97,7 +103,7 @@ def _default_backward_url_resource(url, item_name, successor_name, scenarios=Non } return resource - def _run_engine(self, items, connections, item_instances, execution_permits=None, jumps=None): + def _create_engine(self, items, connections, item_instances, execution_permits=None, jumps=None): if execution_permits is None: execution_permits = {item_name: True for item_name in items} with patch("spine_engine.spine_engine.create_timestamp") as mock_create_timestamp: @@ -116,6 +122,10 @@ def make_item(name, direction): return item_instances[name][0] engine.make_item = make_item + return engine + + def _run_engine(self, items, connections, item_instances, execution_permits=None, jumps=None): + engine = self._create_engine(items, connections, item_instances, execution_permits, jumps) engine.run() self.assertEqual(engine.state(), SpineEngineState.COMPLETED) @@ -858,6 +868,46 @@ def test_nested_jump_with_inner_self_jump(self): expected = 2 * [[[self._default_forward_url_resource(url_fw_b, "b")], []]] self._assert_resource_args(item_c.execute.call_args_list, expected) + def test_stopping_execution_in_the_middle_of_a_loop_does_not_leave_multithread_executor_running(self): + item_a = self._mock_item("a") + item_b = self._mock_item("b") + item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b]} + items = { + "a": {"type": "TestItem"}, + "b": {"type": "TestItem"}, + } + connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)] + jumps = [Jump("a", "right", "a", "right", self._LOOP_FOREVER).to_dict()] + engine = self._create_engine(items, connections, item_instances, jumps=jumps) + + def execute_item_a(loop_counter, *args, **kwargs): + if loop_counter[0] == 2: + engine.stop() + return ItemExecutionFinishState.STOPPED + loop_counter[0] += 1 + return ItemExecutionFinishState.SUCCESS + + loop_counter = [0] + item_a.execute.side_effect = partial(execute_item_a, loop_counter) + engine.run() + self.assertEqual(engine.state(), SpineEngineState.USER_STOPPED) + self.assertEqual(item_a.execute.call_count, 3) + item_b.execute.assert_not_called() + + def test_executing_loop_source_item_only_does_not_execute_the_loop(self): + item_a = self._mock_item("a") + item_b = self._mock_item("b") + item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b, item_b, item_b, item_b]} + items = { + "a": {"type": "TestItem"}, + "b": {"type": "TestItem"}, + } + connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)] + jumps = [Jump("b", "right", "a", "right", self._LOOP_FOREVER).to_dict()] + self._run_engine(items, connections, item_instances, execution_permits={"a": False, "b": True}, jumps=jumps) + self.assertEqual(item_a.execute.call_count, 0) + self.assertEqual(item_b.execute.call_count, 1) + def _assert_resource_args(self, arg_packs, expected_packs): self.assertEqual(len(arg_packs), len(expected_packs)) for pack, expected_pack in zip(arg_packs, expected_packs):