Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix recreation of stepper in workchain 'if' logical block #904

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions aiida/backends/tests/work/workChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from aiida.backends.testbase import AiidaTestCase
from plum.engine.ticking import TickingEngine
from plum.persistence.bundle import Bundle
import plum.process_monitor
from aiida.orm.calculation.work import WorkCalculation
from aiida.work.workchain import WorkChain, \
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self):
[self.s1.__name__, self.s2.__name__, self.s3.__name__,
self.s4.__name__, self.s5.__name__, self.s6.__name__,
self.isA.__name__, self.isB.__name__, self.ltN.__name__]
}
}

def s1(self):
self._set_finished(inspect.stack()[0][3])
Expand Down Expand Up @@ -120,6 +121,33 @@ def test_dict(self):
c['new_attr']


class IfTest(WorkChain):
@classmethod
def define(cls, spec):
super(IfTest, cls).define(spec)
spec.outline(
if_(cls.condition)(
cls.step1,
cls.step2
)
)

def on_create(self, pid, inputs, saved_state):
super(IfTest, self).on_create(pid, inputs, saved_state)
if saved_state is None:
self.ctx.s1 = False
self.ctx.s2 = False

def condition(self):
return True

def step1(self):
self.ctx.s1 = True

def step2(self):
self.ctx.s2 = True


class TestWorkchain(AiidaTestCase):
def setUp(self):
super(TestWorkchain, self).setUp()
Expand Down Expand Up @@ -321,6 +349,29 @@ def run(self):

run(MainWorkChain)

def test_if_block_persistence(self):
""" This test was created to capture issue #902 """
wc = IfTest.new_instance()

while not wc.ctx.s1 and not wc.has_finished():
wc.tick()
self.assertTrue(wc.ctx.s1)
self.assertFalse(wc.ctx.s2)

# Now bundle the thing
b = Bundle()
wc.save_instance_state(b)
# Abort the current one
wc.stop()
wc.destroy(execute=True)

# Load from saved tate
wc = IfTest.create_from(b)
self.assertTrue(wc.ctx.s1)
self.assertFalse(wc.ctx.s2)

wc.run_until_complete()

def _run_with_checkpoints(self, wf_class, inputs=None):
finished_steps = {}

Expand All @@ -336,7 +387,6 @@ def _run_with_checkpoints(self, wf_class, inputs=None):


class TestWorkchainWithOldWorkflows(AiidaTestCase):

def setUp(self):
super(TestWorkchainWithOldWorkflows, self).setUp()
import logging
Expand Down Expand Up @@ -409,10 +459,12 @@ def test_get_proc_outputs(self):
self.assertEquals(outputs['a'], a)
self.assertEquals(outputs['b'], b)


class TestWorkChainAbort(AiidaTestCase):
"""
Test the functionality to abort a workchain
"""

class AbortableWorkChain(WorkChain):
@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -490,11 +542,13 @@ def test_simple_kill_through_process(self):
self.assertEquals(future.process.calc.has_aborted(), True)
engine.shutdown()


class TestWorkChainAbortChildren(AiidaTestCase):
"""
Test the functionality to abort a workchain and verify that children
are also aborted appropriately
"""

class SubWorkChain(WorkChain):
@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -575,4 +629,4 @@ def test_simple_kill_through_node(self):
self.assertEquals(future.process.calc.has_finished_ok(), False)
self.assertEquals(future.process.calc.has_failed(), False)
self.assertEquals(future.process.calc.has_aborted(), True)
engine.shutdown()
engine.shutdown()
40 changes: 25 additions & 15 deletions aiida/work/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def abort(self, msg=None, timeout=None):
self._aborted = True
self.stop()


def ToContext(**kwargs):
"""
Utility function that returns a list of UpdateContext Interstep instances
Expand All @@ -366,6 +367,7 @@ class _InterstepFactory(object):
Factory to create the appropriate Interstep instance based
on the class string that was written to the bundle
"""

def create(self, bundle):
class_string = bundle[Bundle.CLASS]
if class_string == get_class_string(ToContext):
Expand Down Expand Up @@ -567,22 +569,21 @@ class Stepper(Stepper):
def __init__(self, workflow, if_spec):
super(_If.Stepper, self).__init__(workflow)
self._if_spec = if_spec
self._pos = 0
self._pos = -1
self._current_stepper = None

def step(self):
if self._current_stepper is None:
stepper = self._get_next_stepper()
# If we can't get a stepper then no conditions match, return
if stepper is None:
return True, None
self._current_stepper = stepper
self._create_stepper()

# If we can't get a stepper then no conditions match, return
if self._current_stepper is None:
return True, None

finished, retval = self._current_stepper.step()
if finished:
self._current_stepper = None
else:
self._pos += 1
self._pos = -1

return finished, retval

Expand All @@ -596,15 +597,24 @@ def save_position(self, out_position):
def load_position(self, bundle):
self._pos = bundle[self._POSITION]
if self._STEPPER_POS in bundle:
self._current_stepper = self._get_next_stepper()
self._create_stepper()
self._current_stepper.load_position(bundle[self._STEPPER_POS])
else:
self._current_stepper = None

def _get_next_stepper(self):
# Check the conditions until we find that that is true
for conditional in self._if_spec.conditionals[self._pos:]:
if conditional.is_true(self._workflow):
return conditional.body.create_stepper(self._workflow)
return None
def _create_stepper(self):
if self._pos == -1:
self._current_stepper = None
# Check the conditions until we find one that is true
for idx, condition in enumerate(self._if_spec.conditionals):
if condition.is_true(self._workflow):
stepper = condition.body.create_stepper(self._workflow)
self._pos = idx
self._current_stepper = stepper
return
else:
branch = self._if_spec.conditionals[self._pos]
self._current_stepper = branch.body.create_stepper(self._workflow)

def __init__(self, condition):
super(_If, self).__init__()
Expand Down