Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add expose inputs / outputs feature #1170

Merged
merged 17 commits into from
Feb 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 173 additions & 1 deletion aiida/backends/tests/work/work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aiida.backends.testbase import AiidaTestCase
from aiida.common.links import LinkType
from aiida.daemon.workflowmanager import execute_steps
from aiida.orm.data.base import Int, Str, Bool
from aiida.orm.data.base import Int, Str, Bool, Float
from aiida.work.utils import ProcessStack
from aiida.workflows.wf_demo import WorkflowDemo
from aiida import work
Expand Down Expand Up @@ -772,3 +772,175 @@ def step_two(self):
x = Int(1)
y = Int(2)
work.launch.run(Wf, subspace={'one': Int(1), 'two': Int(2)})

class GrandParentExposeWorkChain(work.WorkChain):
@classmethod
def define(cls, spec):
super(GrandParentExposeWorkChain, cls).define(spec)

spec.expose_inputs(ParentExposeWorkChain, namespace='sub.sub')
spec.expose_outputs(ParentExposeWorkChain, namespace='sub.sub')

spec.outline(cls.do_run, cls.finalize)

def do_run(self):
return ToContext(child=self.submit(
ParentExposeWorkChain,
**self.exposed_inputs(ParentExposeWorkChain, namespace='sub.sub')
))

def finalize(self):
self.out_many(
self.exposed_outputs(
self.ctx.child,
ParentExposeWorkChain,
namespace='sub.sub'
)
)

class ParentExposeWorkChain(work.WorkChain):
@classmethod
def define(cls, spec):
super(ParentExposeWorkChain, cls).define(spec)

spec.expose_inputs(ChildExposeWorkChain, include=['a'])
spec.expose_inputs(
ChildExposeWorkChain,
exclude=['a'],
namespace='sub_1',
)
spec.expose_inputs(
ChildExposeWorkChain,
include=['b'],
namespace='sub_2',
)
spec.expose_inputs(
ChildExposeWorkChain,
include=['c'],
namespace='sub_2.sub_3',
)

spec.expose_outputs(ChildExposeWorkChain, include=['a'])
spec.expose_outputs(
ChildExposeWorkChain,
exclude=['a'],
namespace='sub_1'
)
spec.expose_outputs(
ChildExposeWorkChain,
include=['b'],
namespace='sub_2'
)
spec.expose_outputs(
ChildExposeWorkChain,
include=['c'],
namespace='sub_2.sub_3'
)

spec.outline(
cls.start_children,
cls.finalize
)

def start_children(self):
child_1 = self.submit(
ChildExposeWorkChain,
a=self.exposed_inputs(ChildExposeWorkChain)['a'],
**self.exposed_inputs(ChildExposeWorkChain, namespace='sub_1', agglomerate=False)
)
child_2 = self.submit(
ChildExposeWorkChain,
**self.exposed_inputs(
ChildExposeWorkChain,
namespace='sub_2.sub_3',
)
)
return ToContext(child_1=child_1, child_2=child_2)

def finalize(self):
exposed_1 = self.exposed_outputs(
self.ctx.child_1,
ChildExposeWorkChain,
namespace='sub_1',
agglomerate=False
)
self.out_many(exposed_1)
exposed_2 = self.exposed_outputs(
self.ctx.child_2,
ChildExposeWorkChain,
namespace='sub_2.sub_3'
)
self.out_many(exposed_2)

class ChildExposeWorkChain(work.WorkChain):
@classmethod
def define(cls, spec):
super(ChildExposeWorkChain, cls).define(spec)

spec.input('a', valid_type=Int)
spec.input('b', valid_type=Float)
spec.input('c', valid_type=Bool)

spec.output('a', valid_type=Float)
spec.output('b', valid_type=Float)
spec.output('c', valid_type=Bool)

spec.outline(cls.do_run)

def do_run(self):
self.out('a', self.inputs.a + self.inputs.b)
self.out('b', self.inputs.b)
self.out('c', self.inputs.c)

class TestWorkChainExpose(AiidaTestCase):
"""
Test the expose inputs / outputs functionality
"""

def setUp(self):
super(TestWorkChainExpose, self).setUp()
self.assertEquals(len(ProcessStack.stack()), 0)
self.runner = utils.create_test_runner()

def tearDown(self):
super(TestWorkChainExpose, self).tearDown()
work.set_runner(None)
self.runner.close()
self.runner = None
self.assertEquals(len(ProcessStack.stack()), 0)

def test_expose(self):
res = work.launch.run(
ParentExposeWorkChain,
a=Int(1),
sub_1={'b': Float(2.3), 'c': Bool(True)},
sub_2={'b': Float(1.2), 'sub_3': {'c': Bool(False)}},
)
self.assertEquals(
res,
{
'a': Float(2.2),
'sub_1.b': Float(2.3), 'sub_1.c': Bool(True),
'sub_2.b': Float(1.2), 'sub_2.sub_3.c': Bool(False)
}
)

def test_nested_expose(self):
res = work.launch.run(
GrandParentExposeWorkChain,
sub=dict(
sub=dict(
a=Int(1),
sub_1={'b': Float(2.3), 'c': Bool(True)},
sub_2={'b': Float(1.2), 'sub_3': {'c': Bool(False)}},
)
)
)
self.assertEquals(
res,
{
'sub.sub.a': Float(2.2),
'sub.sub.sub_1.b': Float(2.3), 'sub.sub.sub_1.c': Bool(True),
'sub.sub.sub_2.b': Float(1.2), 'sub.sub.sub_2.sub_3.c': Bool(False)
}
)
89 changes: 87 additions & 2 deletions aiida/work/process_spec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from collections import defaultdict

import plumpy
import voluptuous

Expand All @@ -15,7 +17,7 @@ def __call__(self, value):
"""
Call this to validate the value against the schema.

:param value: a regular dictionary or a ParameterData instance
:param value: a regular dictionary or a ParameterData instance
:return: tuple (success, msg). success is True if the value is valid
and False otherwise, in which case msg will contain information about
the validation failure.
Expand Down Expand Up @@ -50,9 +52,92 @@ def _get_template(self, dict):


class ProcessSpec(plumpy.ProcessSpec):
"""
Contains the inputs, outputs and outline of a process.
"""

INPUT_PORT_TYPE = InputPort
PORT_NAMESPACE_TYPE = PortNamespace

def __init__(self):
super(ProcessSpec, self).__init__()
super(ProcessSpec, self).__init__()
self._exposed_inputs = defaultdict(lambda: defaultdict(list))
self._exposed_outputs = defaultdict(lambda: defaultdict(list))

def expose_inputs(self, process_class, namespace=None, exclude=(), include=None):
"""
This method allows one to automatically add the inputs from another
Process to this ProcessSpec. The optional namespace argument can be
used to group the exposed inputs in a separated PortNamespace

:param process_class: the Process class whose inputs to expose
:param namespace: a namespace in which to place the exposed inputs
:param exclude: list or tuple of input keys to exclude from being exposed
"""
self._expose_ports(
process_class=process_class,
source=process_class.spec().inputs,
destination=self.inputs,
expose_memory=self._exposed_inputs,
namespace=namespace,
exclude=exclude,
include=include
)

def expose_outputs(self, process_class, namespace=None, exclude=(), include=None):
"""
This method allows one to automatically add the ouputs from another
Process to this ProcessSpec. The optional namespace argument can be
used to group the exposed outputs in a separated PortNamespace.

:param process_class: the Process class whose inputs to expose
:param namespace: a namespace in which to place the exposed inputs
:param exclude: list or tuple of input keys to exclude from being exposed
"""
self._expose_ports(
process_class=process_class,
source=process_class.spec().outputs,
destination=self.outputs,
expose_memory=self._exposed_outputs,
namespace=namespace,
exclude=exclude,
include=include
)

def _expose_ports(
self,
process_class,
source,
destination,
expose_memory,
namespace,
exclude,
include
):
if namespace:
port_namespace = destination.create_port_namespace(namespace)
else:
port_namespace = destination
exposed_list = expose_memory[namespace][process_class]

for name, port in self._filter_names(
source.iteritems(),
exclude=exclude,
include=include
):
port_namespace[name] = port
exposed_list.append(name)

@staticmethod
def _filter_names(items, exclude, include):
if exclude and include is not None:
raise ValueError('exclude and include are mutually exclusive')

for name, port in items:
if include is not None:
if name not in include:
continue
else:
if name in exclude:
continue
yield name, port
Loading