diff --git a/aiida/backends/tests/work/workChain.py b/aiida/backends/tests/work/workChain.py index 2f662f8ff8..9ab8eb3539 100644 --- a/aiida/backends/tests/work/workChain.py +++ b/aiida/backends/tests/work/workChain.py @@ -14,6 +14,7 @@ from aiida.backends.testbase import AiidaTestCase from plum.engine.ticking import TickingEngine +from plum.persistence.bundle import Bundle import plum.process_monitor from aiida.orm.calculation.work import WorkCalculation from aiida.work.workchain import WorkChain, \ @@ -61,7 +62,7 @@ def __init__(self): [self.s1.__name__, self.s2.__name__, self.s3.__name__, self.s4.__name__, self.s5.__name__, self.s6.__name__, self.isA.__name__, self.isB.__name__, self.ltN.__name__] - } + } def s1(self): self._set_finished(inspect.stack()[0][3]) @@ -120,6 +121,33 @@ def test_dict(self): c['new_attr'] +class IfTest(WorkChain): + @classmethod + def define(cls, spec): + super(IfTest, cls).define(spec) + spec.outline( + if_(cls.condition)( + cls.step1, + cls.step2 + ) + ) + + def on_create(self, pid, inputs, saved_state): + super(IfTest, self).on_create(pid, inputs, saved_state) + if saved_state is None: + self.ctx.s1 = False + self.ctx.s2 = False + + def condition(self): + return True + + def step1(self): + self.ctx.s1 = True + + def step2(self): + self.ctx.s2 = True + + class TestWorkchain(AiidaTestCase): def setUp(self): super(TestWorkchain, self).setUp() @@ -321,6 +349,29 @@ def run(self): run(MainWorkChain) + def test_if_block_persistence(self): + """ This test was created to capture issue #902 """ + wc = IfTest.new_instance() + + while not wc.ctx.s1 and not wc.has_finished(): + wc.tick() + self.assertTrue(wc.ctx.s1) + self.assertFalse(wc.ctx.s2) + + # Now bundle the thing + b = Bundle() + wc.save_instance_state(b) + # Abort the current one + wc.stop() + wc.destroy(execute=True) + + # Load from saved tate + wc = IfTest.create_from(b) + self.assertTrue(wc.ctx.s1) + self.assertFalse(wc.ctx.s2) + + wc.run_until_complete() + def _run_with_checkpoints(self, wf_class, inputs=None): finished_steps = {} @@ -336,7 +387,6 @@ def _run_with_checkpoints(self, wf_class, inputs=None): class TestWorkchainWithOldWorkflows(AiidaTestCase): - def setUp(self): super(TestWorkchainWithOldWorkflows, self).setUp() import logging @@ -409,10 +459,12 @@ def test_get_proc_outputs(self): self.assertEquals(outputs['a'], a) self.assertEquals(outputs['b'], b) + class TestWorkChainAbort(AiidaTestCase): """ Test the functionality to abort a workchain """ + class AbortableWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -490,11 +542,13 @@ def test_simple_kill_through_process(self): self.assertEquals(future.process.calc.has_aborted(), True) engine.shutdown() + class TestWorkChainAbortChildren(AiidaTestCase): """ Test the functionality to abort a workchain and verify that children are also aborted appropriately """ + class SubWorkChain(WorkChain): @classmethod def define(cls, spec): @@ -575,4 +629,4 @@ def test_simple_kill_through_node(self): self.assertEquals(future.process.calc.has_finished_ok(), False) self.assertEquals(future.process.calc.has_failed(), False) self.assertEquals(future.process.calc.has_aborted(), True) - engine.shutdown() \ No newline at end of file + engine.shutdown() diff --git a/aiida/work/workchain.py b/aiida/work/workchain.py index f35f630adf..e5a2564ad9 100644 --- a/aiida/work/workchain.py +++ b/aiida/work/workchain.py @@ -342,6 +342,7 @@ def abort(self, msg=None, timeout=None): self._aborted = True self.stop() + def ToContext(**kwargs): """ Utility function that returns a list of UpdateContext Interstep instances @@ -366,6 +367,7 @@ class _InterstepFactory(object): Factory to create the appropriate Interstep instance based on the class string that was written to the bundle """ + def create(self, bundle): class_string = bundle[Bundle.CLASS] if class_string == get_class_string(ToContext): @@ -567,22 +569,21 @@ class Stepper(Stepper): def __init__(self, workflow, if_spec): super(_If.Stepper, self).__init__(workflow) self._if_spec = if_spec - self._pos = 0 + self._pos = -1 self._current_stepper = None def step(self): if self._current_stepper is None: - stepper = self._get_next_stepper() - # If we can't get a stepper then no conditions match, return - if stepper is None: - return True, None - self._current_stepper = stepper + self._create_stepper() + + # If we can't get a stepper then no conditions match, return + if self._current_stepper is None: + return True, None finished, retval = self._current_stepper.step() if finished: self._current_stepper = None - else: - self._pos += 1 + self._pos = -1 return finished, retval @@ -596,15 +597,24 @@ def save_position(self, out_position): def load_position(self, bundle): self._pos = bundle[self._POSITION] if self._STEPPER_POS in bundle: - self._current_stepper = self._get_next_stepper() + self._create_stepper() self._current_stepper.load_position(bundle[self._STEPPER_POS]) + else: + self._current_stepper = None - def _get_next_stepper(self): - # Check the conditions until we find that that is true - for conditional in self._if_spec.conditionals[self._pos:]: - if conditional.is_true(self._workflow): - return conditional.body.create_stepper(self._workflow) - return None + def _create_stepper(self): + if self._pos == -1: + self._current_stepper = None + # Check the conditions until we find one that is true + for idx, condition in enumerate(self._if_spec.conditionals): + if condition.is_true(self._workflow): + stepper = condition.body.create_stepper(self._workflow) + self._pos = idx + self._current_stepper = stepper + return + else: + branch = self._if_spec.conditionals[self._pos] + self._current_stepper = branch.body.create_stepper(self._workflow) def __init__(self, condition): super(_If, self).__init__()