Skip to content

Commit

Permalink
Fix issues with jumps (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Sep 4, 2024
2 parents 4fa66a3 + 6ba35fc commit 965a60e
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 28 deletions.
12 changes: 5 additions & 7 deletions spine_engine/jumpster.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, event_type, item_name, direction, **kwargs):


class Failure(BaseException):
"""A failure"""
"""Item execution stopped by failure or user interruption."""


class Output:
Expand All @@ -80,7 +80,7 @@ def __init__(self, item_finish_state):

class JumpsterThreadError(Exception):
"""An exception has occurred in one or more of the threads jumpster manages.
This error forwards the message and stack trace for all of the collected errors.
This error forwards the message and stack trace for all the collected errors.
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -251,17 +251,14 @@ def execute(self):
empty_iters.append(key)
except StopIteration:
empty_iters.append(key)
# TODO: Anything about loops?
# clear and mark complete finished iterators
for key in empty_iters:
del active_iters[key]
errs = {tid: err for tid, err in errors.items() if err}
if errs:
raise JumpsterThreadError(
"During multithread execution errors occurred in threads:\n{error_list}".format(
error_list="\n".join(
["In thread {tid}: {err}".format(tid=tid, err=err.to_string()) for tid, err in errs.items()]
)
error_list="\n".join([f"In thread {tid}: {err.to_string()}" for tid, err in errs.items()])
),
thread_error_infos=list(errs.values()),
)
Expand Down Expand Up @@ -338,6 +335,7 @@ def execute_step_in_thread(step, errors):
Args:
step (Step): step to execute
errors (dict): mapping from thread id to error info
"""
event_queue = queue.Queue()
thread = threading.Thread(target=_do_execute_step_in_thread, args=(event_queue, step))
Expand Down Expand Up @@ -394,4 +392,4 @@ def _do_execute_step_in_thread(event_queue, step):
ThreadSystemErrorEvent(tid=tid, error_info=serializable_error_info_from_exc_info(sys.exc_info()))
)
except Failure:
pass
event_queue.put(JumpsterEvent(JumpsterEventType.STEP_FAILURE, step.item_name, step.direction))
10 changes: 6 additions & 4 deletions spine_engine/project_item/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,20 +589,22 @@ def from_dict(cls, connection_dict):
class Jump(ConnectionBase):
"""Represents a conditional jump between two project items."""

_DEFAULT_CONDITION = {"type": "python-script", "script": "exit(1)", "specification": ""}

def __init__(
self, source_name, source_position, destination_name, destination_position, condition={}, cmd_line_args=()
self, source_name, source_position, destination_name, destination_position, condition=None, cmd_line_args=()
):
"""
Args:
source_name (str): source project item's name
source_position (str): source anchor's position
destination_name (str): destination project item's name
destination_position (str): destination anchor's position
condition (dict): jump condition
condition (dict, optional): jump condition
cmd_line_args (Iterable of str): command line arguments
"""
super().__init__(source_name, source_position, destination_name, destination_position)
default_condition = {"type": "python-script", "script": "exit(1)", "specification": ""}
self.condition = condition if condition else default_condition
self.condition = condition if condition is not None else self._DEFAULT_CONDITION
self._resources_from_source = set()
self._resources_from_destination = set()
self.cmd_line_args = list(cmd_line_args)
Expand Down
6 changes: 2 additions & 4 deletions spine_engine/project_item/executable_item_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Contains ExecutableItem, a project item's counterpart in execution as well as support utilities.
"""
""" Contains ExecutableItem, a project item's counterpart in execution as well as support utilities. """
from hashlib import sha1
from pathlib import Path
from ..utils.helpers import ExecutionDirection, ItemExecutionFinishState, shorten
Expand All @@ -28,6 +25,7 @@ def __init__(self, name, project_dir, logger, group_id=None):
name (str): item's name
project_dir (str): absolute path to project directory
logger (LoggerInterface): a logger
group_id (str, optional): execution group identifier
"""
self._name = name
self._project_dir = project_dir
Expand Down
30 changes: 22 additions & 8 deletions spine_engine/spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,14 @@ def __init__(
Raises:
EngineInitFailed: Raised if initialization fails
"""
super().__init__()
self._queue = mp.Queue()
if items is None:
items = {}
self._items = items
if execution_permits is None:
execution_permits = {}
self._execution_permits = execution_permits
connections = list(map(Connection.from_dict, connections)) # Deserialize connections
connections = list(map(Connection.from_dict, connections))
project_item_loader = ProjectItemLoader()
self._executable_item_classes = project_item_loader.load_executable_item_classes(items_module_name)
required_items = required_items_for_execution(
Expand Down Expand Up @@ -137,8 +136,11 @@ def __init__(
self._item_names = list(self._dag) # Names of permitted items and their neighbors
if jumps is None:
jumps = []
self._jumps = list(map(Jump.from_dict, jumps))
validate_jumps(self._jumps, self._dag)
else:
jumps = list(map(Jump.from_dict, jumps))
items_by_jump = _get_items_by_jump(jumps, self._dag)
self._jumps = filter_unneeded_jumps(jumps, items_by_jump, execution_permits)
validate_jumps(self._jumps, items_by_jump, self._dag)
for x in self._connections + self._jumps:
x.make_logger(self._queue)
for x in self._jumps:
Expand Down Expand Up @@ -275,10 +277,11 @@ def _get_event_stream(self):
self._thread.start()
while True:
msg = self._queue.get()
yield msg
if msg[0] == "dag_exec_finished":
break
yield msg
self._thread.join()
yield msg

def answer_prompt(self, prompter_id, answer):
"""Answers the prompt for the specified prompter id."""
Expand Down Expand Up @@ -760,14 +763,25 @@ def _validate_dag(dag):
raise EngineInitFailed("DAG contains unconnected items.")


def validate_jumps(jumps, dag):
def filter_unneeded_jumps(jumps, items_by_jump, execution_permits):
"""Drops jumps whose items are not going to be executed.
Args:
jumps (Iterable of Jump): jumps to filter
items_by_jump (dict): mapping from jump to list of item names
execution_permits (dict): mapping from item name to boolean telling if its is permitted to execute
"""
return [jump for jump in jumps if all(execution_permits[item] for item in items_by_jump[jump])]


def validate_jumps(jumps, items_by_jump, dag):
"""Raises an exception in case jumps are not valid.
Args:
jumps (list of Jump): jumps
items_by_jump (dict): mapping from jump to list of item names
dag (DiGraph): jumps' DAG
"""
items_by_jump = _get_items_by_jump(jumps, dag)
for jump in jumps:
validate_single_jump(jump, jumps, dag, items_by_jump)

Expand All @@ -777,7 +791,7 @@ def validate_single_jump(jump, jumps, dag, items_by_jump=None):
Args:
jump (Jump): the jump to check
jumps (list of Jump): all jumps in dag
jumps (list of Jump): all jumps in DAG
dag (DiGraph): jumps' DAG
items_by_jump (dict, optional): mapping jumps to a set of items in between destination and source
"""
Expand Down
6 changes: 3 additions & 3 deletions spine_engine/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def make_connections(connections, permitted_items):
list of Connection: List of permitted Connections or an empty list if the DAG contains no connections
"""
if not connections:
return list()
return []
connections = connections_to_selected_items(connections, permitted_items)
return connections

Expand Down Expand Up @@ -433,10 +433,10 @@ def dag_edges(connections):
Returns:
dict: DAG edges. Mapping of source item (node) to a list of destination items (nodes)
"""
edges = dict()
edges = {}
for connection in connections:
source, destination = connection.source, connection.destination
edges.setdefault(source, list()).append(destination)
edges.setdefault(source, []).append(destination)
return edges


Expand Down
28 changes: 27 additions & 1 deletion tests/project_item/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
######################################################################################################################
""" Uni tests for the ``connection`` module. """
import os.path
import pathlib
from tempfile import TemporaryDirectory
import unittest
from unittest.mock import Mock
from spine_engine.project_item.connection import Connection, FilterSettings, Jump
from spine_engine.project_item.project_item_resource import database_resource
from spine_engine.project_item.project_item_resource import LabelArg, database_resource, file_resource
from spinedb_api import DatabaseMapping, import_alternatives, import_entity_classes, import_scenarios
from spinedb_api.filters.scenario_filter import SCENARIO_FILTER_TYPE

Expand Down Expand Up @@ -169,6 +170,31 @@ def test_counter_passed_to_condition(self):
jump.make_logger(Mock())
self.assertTrue(jump.is_condition_true(23))

def test_command_line_args_with_whitespace_are_not_broken_into_tokens(self):
# Curiously, this test fails when run under PyCharm's debugger.
with TemporaryDirectory() as temp_dir:
path = pathlib.Path(temp_dir) / "path with spaces" / "file name.txt"
condition = {
"type": "python-script",
"script": "\n".join(
(
"from pathlib import Path",
"import sys",
"if len(sys.argv) != 3:",
" exit(1)",
f"expected_path = Path(r'{str(path)}').resolve()",
"if Path(sys.argv[1]).resolve() != expected_path:",
" exit(1)",
"exit(0 if int(sys.argv[2]) == 23 else 1)",
)
),
}
jump = Jump("source", "bottom", "destination", "top", condition, [LabelArg("arg_label")])
resource = file_resource("provider: unit test", str(path), "arg_label")
jump.receive_resources_from_source([resource])
jump.make_logger(Mock())
self.assertTrue(jump.is_condition_true(23))

def test_dictionary(self):
jump = Jump("source", "bottom", "destination", "top", {"type": "python-script", "script": "exit(23)"})
jump_dict = jump.to_dict()
Expand Down
52 changes: 51 additions & 1 deletion tests/test_spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Inspired from tests for spinetoolbox.ExecutionInstance and spinetoolbox.ResourceMap,
and intended to supersede them.
"""
from functools import partial
import os.path
import sys
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -43,6 +44,11 @@ class TestSpineEngine(unittest.TestCase):
"script": "\n".join(["import sys", "loop_counter = int(sys.argv[1])", "exit(0 if loop_counter < 2 else 1)"]),
}

_LOOP_FOREVER = {
"type": "python-script",
"script": "\n".join(["exit(0)"]),
}

@staticmethod
def _mock_item(
name, resources_forward=None, resources_backward=None, execute_outcome=ItemExecutionFinishState.SUCCESS
Expand Down Expand Up @@ -97,7 +103,7 @@ def _default_backward_url_resource(url, item_name, successor_name, scenarios=Non
}
return resource

def _run_engine(self, items, connections, item_instances, execution_permits=None, jumps=None):
def _create_engine(self, items, connections, item_instances, execution_permits=None, jumps=None):
if execution_permits is None:
execution_permits = {item_name: True for item_name in items}
with patch("spine_engine.spine_engine.create_timestamp") as mock_create_timestamp:
Expand All @@ -116,6 +122,10 @@ def make_item(name, direction):
return item_instances[name][0]

engine.make_item = make_item
return engine

def _run_engine(self, items, connections, item_instances, execution_permits=None, jumps=None):
engine = self._create_engine(items, connections, item_instances, execution_permits, jumps)
engine.run()
self.assertEqual(engine.state(), SpineEngineState.COMPLETED)

Expand Down Expand Up @@ -858,6 +868,46 @@ def test_nested_jump_with_inner_self_jump(self):
expected = 2 * [[[self._default_forward_url_resource(url_fw_b, "b")], []]]
self._assert_resource_args(item_c.execute.call_args_list, expected)

def test_stopping_execution_in_the_middle_of_a_loop_does_not_leave_multithread_executor_running(self):
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("a", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
engine = self._create_engine(items, connections, item_instances, jumps=jumps)

def execute_item_a(loop_counter, *args, **kwargs):
if loop_counter[0] == 2:
engine.stop()
return ItemExecutionFinishState.STOPPED
loop_counter[0] += 1
return ItemExecutionFinishState.SUCCESS

loop_counter = [0]
item_a.execute.side_effect = partial(execute_item_a, loop_counter)
engine.run()
self.assertEqual(engine.state(), SpineEngineState.USER_STOPPED)
self.assertEqual(item_a.execute.call_count, 3)
item_b.execute.assert_not_called()

def test_executing_loop_source_item_only_does_not_execute_the_loop(self):
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b, item_b, item_b, item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("b", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
self._run_engine(items, connections, item_instances, execution_permits={"a": False, "b": True}, jumps=jumps)
self.assertEqual(item_a.execute.call_count, 0)
self.assertEqual(item_b.execute.call_count, 1)

def _assert_resource_args(self, arg_packs, expected_packs):
self.assertEqual(len(arg_packs), len(expected_packs))
for pack, expected_pack in zip(arg_packs, expected_packs):
Expand Down

0 comments on commit 965a60e

Please sign in to comment.