Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with jumps #150

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions spine_engine/jumpster.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, event_type, item_name, direction, **kwargs):


class Failure(BaseException):
"""A failure"""
"""Item execution stopped by failure or user interruption."""


class Output:
Expand All @@ -80,7 +80,7 @@ def __init__(self, item_finish_state):

class JumpsterThreadError(Exception):
"""An exception has occurred in one or more of the threads jumpster manages.
This error forwards the message and stack trace for all of the collected errors.
This error forwards the message and stack trace for all the collected errors.
"""

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -251,17 +251,14 @@ def execute(self):
empty_iters.append(key)
except StopIteration:
empty_iters.append(key)
# TODO: Anything about loops?
# clear and mark complete finished iterators
for key in empty_iters:
del active_iters[key]
errs = {tid: err for tid, err in errors.items() if err}
if errs:
raise JumpsterThreadError(
"During multithread execution errors occurred in threads:\n{error_list}".format(
error_list="\n".join(
["In thread {tid}: {err}".format(tid=tid, err=err.to_string()) for tid, err in errs.items()]
)
error_list="\n".join([f"In thread {tid}: {err.to_string()}" for tid, err in errs.items()])
),
thread_error_infos=list(errs.values()),
)
Expand Down Expand Up @@ -338,6 +335,7 @@ def execute_step_in_thread(step, errors):

Args:
step (Step): step to execute
errors (dict): mapping from thread id to error info
"""
event_queue = queue.Queue()
thread = threading.Thread(target=_do_execute_step_in_thread, args=(event_queue, step))
Expand Down Expand Up @@ -394,4 +392,4 @@ def _do_execute_step_in_thread(event_queue, step):
ThreadSystemErrorEvent(tid=tid, error_info=serializable_error_info_from_exc_info(sys.exc_info()))
)
except Failure:
pass
event_queue.put(JumpsterEvent(JumpsterEventType.STEP_FAILURE, step.item_name, step.direction))
10 changes: 6 additions & 4 deletions spine_engine/project_item/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,20 +589,22 @@ def from_dict(cls, connection_dict):
class Jump(ConnectionBase):
"""Represents a conditional jump between two project items."""

_DEFAULT_CONDITION = {"type": "python-script", "script": "exit(1)", "specification": ""}

def __init__(
self, source_name, source_position, destination_name, destination_position, condition={}, cmd_line_args=()
self, source_name, source_position, destination_name, destination_position, condition=None, cmd_line_args=()
):
"""
Args:
source_name (str): source project item's name
source_position (str): source anchor's position
destination_name (str): destination project item's name
destination_position (str): destination anchor's position
condition (dict): jump condition
condition (dict, optional): jump condition
cmd_line_args (Iterable of str): command line arguments
"""
super().__init__(source_name, source_position, destination_name, destination_position)
default_condition = {"type": "python-script", "script": "exit(1)", "specification": ""}
self.condition = condition if condition else default_condition
self.condition = condition if condition is not None else self._DEFAULT_CONDITION
self._resources_from_source = set()
self._resources_from_destination = set()
self.cmd_line_args = list(cmd_line_args)
Expand Down
6 changes: 2 additions & 4 deletions spine_engine/project_item/executable_item_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@
# this program. If not, see <http://www.gnu.org/licenses/>.
######################################################################################################################

"""
Contains ExecutableItem, a project item's counterpart in execution as well as support utilities.

"""
""" Contains ExecutableItem, a project item's counterpart in execution as well as support utilities. """
from hashlib import sha1
from pathlib import Path
from ..utils.helpers import ExecutionDirection, ItemExecutionFinishState, shorten
Expand All @@ -28,6 +25,7 @@ def __init__(self, name, project_dir, logger, group_id=None):
name (str): item's name
project_dir (str): absolute path to project directory
logger (LoggerInterface): a logger
group_id (str, optional): execution group identifier
"""
self._name = name
self._project_dir = project_dir
Expand Down
30 changes: 22 additions & 8 deletions spine_engine/spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,14 @@
Raises:
EngineInitFailed: Raised if initialization fails
"""
super().__init__()
self._queue = mp.Queue()
if items is None:
items = {}
self._items = items
if execution_permits is None:
execution_permits = {}
self._execution_permits = execution_permits
connections = list(map(Connection.from_dict, connections)) # Deserialize connections
connections = list(map(Connection.from_dict, connections))
project_item_loader = ProjectItemLoader()
self._executable_item_classes = project_item_loader.load_executable_item_classes(items_module_name)
required_items = required_items_for_execution(
Expand Down Expand Up @@ -137,8 +136,11 @@
self._item_names = list(self._dag) # Names of permitted items and their neighbors
if jumps is None:
jumps = []
self._jumps = list(map(Jump.from_dict, jumps))
validate_jumps(self._jumps, self._dag)
else:
jumps = list(map(Jump.from_dict, jumps))
items_by_jump = _get_items_by_jump(jumps, self._dag)
self._jumps = filter_unneeded_jumps(jumps, items_by_jump, execution_permits)
validate_jumps(self._jumps, items_by_jump, self._dag)
for x in self._connections + self._jumps:
x.make_logger(self._queue)
for x in self._jumps:
Expand Down Expand Up @@ -275,10 +277,11 @@
self._thread.start()
while True:
msg = self._queue.get()
yield msg
if msg[0] == "dag_exec_finished":
break
yield msg

Check warning on line 282 in spine_engine/spine_engine.py

View check run for this annotation

Codecov / codecov/patch

spine_engine/spine_engine.py#L282

Added line #L282 was not covered by tests
self._thread.join()
yield msg

Check warning on line 284 in spine_engine/spine_engine.py

View check run for this annotation

Codecov / codecov/patch

spine_engine/spine_engine.py#L284

Added line #L284 was not covered by tests

def answer_prompt(self, prompter_id, answer):
"""Answers the prompt for the specified prompter id."""
Expand Down Expand Up @@ -760,14 +763,25 @@
raise EngineInitFailed("DAG contains unconnected items.")


def validate_jumps(jumps, dag):
def filter_unneeded_jumps(jumps, items_by_jump, execution_permits):
"""Drops jumps whose items are not going to be executed.

Args:
jumps (Iterable of Jump): jumps to filter
items_by_jump (dict): mapping from jump to list of item names
execution_permits (dict): mapping from item name to boolean telling if its is permitted to execute
"""
return [jump for jump in jumps if all(execution_permits[item] for item in items_by_jump[jump])]


def validate_jumps(jumps, items_by_jump, dag):
"""Raises an exception in case jumps are not valid.

Args:
jumps (list of Jump): jumps
items_by_jump (dict): mapping from jump to list of item names
dag (DiGraph): jumps' DAG
"""
items_by_jump = _get_items_by_jump(jumps, dag)
for jump in jumps:
validate_single_jump(jump, jumps, dag, items_by_jump)

Expand All @@ -777,7 +791,7 @@

Args:
jump (Jump): the jump to check
jumps (list of Jump): all jumps in dag
jumps (list of Jump): all jumps in DAG
dag (DiGraph): jumps' DAG
items_by_jump (dict, optional): mapping jumps to a set of items in between destination and source
"""
Expand Down
6 changes: 3 additions & 3 deletions spine_engine/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def make_connections(connections, permitted_items):
list of Connection: List of permitted Connections or an empty list if the DAG contains no connections
"""
if not connections:
return list()
return []
connections = connections_to_selected_items(connections, permitted_items)
return connections

Expand Down Expand Up @@ -433,10 +433,10 @@ def dag_edges(connections):
Returns:
dict: DAG edges. Mapping of source item (node) to a list of destination items (nodes)
"""
edges = dict()
edges = {}
for connection in connections:
source, destination = connection.source, connection.destination
edges.setdefault(source, list()).append(destination)
edges.setdefault(source, []).append(destination)
return edges


Expand Down
28 changes: 27 additions & 1 deletion tests/project_item/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
######################################################################################################################
""" Uni tests for the ``connection`` module. """
import os.path
import pathlib
from tempfile import TemporaryDirectory
import unittest
from unittest.mock import Mock
from spine_engine.project_item.connection import Connection, FilterSettings, Jump
from spine_engine.project_item.project_item_resource import database_resource
from spine_engine.project_item.project_item_resource import LabelArg, database_resource, file_resource
from spinedb_api import DatabaseMapping, import_alternatives, import_entity_classes, import_scenarios
from spinedb_api.filters.scenario_filter import SCENARIO_FILTER_TYPE

Expand Down Expand Up @@ -169,6 +170,31 @@ def test_counter_passed_to_condition(self):
jump.make_logger(Mock())
self.assertTrue(jump.is_condition_true(23))

def test_command_line_args_with_whitespace_are_not_broken_into_tokens(self):
# Curiously, this test fails when run under PyCharm's debugger.
with TemporaryDirectory() as temp_dir:
path = pathlib.Path(temp_dir) / "path with spaces" / "file name.txt"
condition = {
"type": "python-script",
"script": "\n".join(
(
"from pathlib import Path",
"import sys",
"if len(sys.argv) != 3:",
" exit(1)",
f"expected_path = Path(r'{str(path)}').resolve()",
"if Path(sys.argv[1]).resolve() != expected_path:",
" exit(1)",
"exit(0 if int(sys.argv[2]) == 23 else 1)",
)
),
}
jump = Jump("source", "bottom", "destination", "top", condition, [LabelArg("arg_label")])
resource = file_resource("provider: unit test", str(path), "arg_label")
jump.receive_resources_from_source([resource])
jump.make_logger(Mock())
self.assertTrue(jump.is_condition_true(23))

def test_dictionary(self):
jump = Jump("source", "bottom", "destination", "top", {"type": "python-script", "script": "exit(23)"})
jump_dict = jump.to_dict()
Expand Down
52 changes: 51 additions & 1 deletion tests/test_spine_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Inspired from tests for spinetoolbox.ExecutionInstance and spinetoolbox.ResourceMap,
and intended to supersede them.
"""
from functools import partial
import os.path
import sys
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -43,6 +44,11 @@ class TestSpineEngine(unittest.TestCase):
"script": "\n".join(["import sys", "loop_counter = int(sys.argv[1])", "exit(0 if loop_counter < 2 else 1)"]),
}

_LOOP_FOREVER = {
"type": "python-script",
"script": "\n".join(["exit(0)"]),
}

@staticmethod
def _mock_item(
name, resources_forward=None, resources_backward=None, execute_outcome=ItemExecutionFinishState.SUCCESS
Expand Down Expand Up @@ -97,7 +103,7 @@ def _default_backward_url_resource(url, item_name, successor_name, scenarios=Non
}
return resource

def _run_engine(self, items, connections, item_instances, execution_permits=None, jumps=None):
def _create_engine(self, items, connections, item_instances, execution_permits=None, jumps=None):
if execution_permits is None:
execution_permits = {item_name: True for item_name in items}
with patch("spine_engine.spine_engine.create_timestamp") as mock_create_timestamp:
Expand All @@ -116,6 +122,10 @@ def make_item(name, direction):
return item_instances[name][0]

engine.make_item = make_item
return engine

def _run_engine(self, items, connections, item_instances, execution_permits=None, jumps=None):
engine = self._create_engine(items, connections, item_instances, execution_permits, jumps)
engine.run()
self.assertEqual(engine.state(), SpineEngineState.COMPLETED)

Expand Down Expand Up @@ -858,6 +868,46 @@ def test_nested_jump_with_inner_self_jump(self):
expected = 2 * [[[self._default_forward_url_resource(url_fw_b, "b")], []]]
self._assert_resource_args(item_c.execute.call_args_list, expected)

def test_stopping_execution_in_the_middle_of_a_loop_does_not_leave_multithread_executor_running(self):
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("a", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
engine = self._create_engine(items, connections, item_instances, jumps=jumps)

def execute_item_a(loop_counter, *args, **kwargs):
if loop_counter[0] == 2:
engine.stop()
return ItemExecutionFinishState.STOPPED
loop_counter[0] += 1
return ItemExecutionFinishState.SUCCESS

loop_counter = [0]
item_a.execute.side_effect = partial(execute_item_a, loop_counter)
engine.run()
self.assertEqual(engine.state(), SpineEngineState.USER_STOPPED)
self.assertEqual(item_a.execute.call_count, 3)
item_b.execute.assert_not_called()

def test_executing_loop_source_item_only_does_not_execute_the_loop(self):
item_a = self._mock_item("a")
item_b = self._mock_item("b")
item_instances = {"a": [item_a, item_a, item_a, item_a], "b": [item_b, item_b, item_b, item_b]}
items = {
"a": {"type": "TestItem"},
"b": {"type": "TestItem"},
}
connections = [c.to_dict() for c in (Connection("a", "right", "b", "left"),)]
jumps = [Jump("b", "right", "a", "right", self._LOOP_FOREVER).to_dict()]
self._run_engine(items, connections, item_instances, execution_permits={"a": False, "b": True}, jumps=jumps)
self.assertEqual(item_a.execute.call_count, 0)
self.assertEqual(item_b.execute.call_count, 1)

def _assert_resource_args(self, arg_packs, expected_packs):
self.assertEqual(len(arg_packs), len(expected_packs))
for pack, expected_pack in zip(arg_packs, expected_packs):
Expand Down