From 1deda11929b75a2bf1e74ad404359021d4530832 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 3 Apr 2018 15:47:58 +0200 Subject: [PATCH 1/9] Cherry-picked backport of PyCifRW update to v4.2.1 This was merged into develop with PR #1073 but needs to be back ported into release v0.11.4 --- aiida/backends/tests/dataclasses.py | 90 +++++++++++++++++++++------- aiida/backends/tests/restapi.py | 5 +- aiida/backends/tests/tcodexporter.py | 12 ++-- aiida/common/utils.py | 19 ++++++ aiida/orm/data/array/trajectory.py | 4 +- aiida/orm/data/cif.py | 36 ++++++++--- aiida/tools/dbexporters/tcod.py | 4 +- setup_requirements.py | 4 +- 8 files changed, 128 insertions(+), 46 deletions(-) diff --git a/aiida/backends/tests/dataclasses.py b/aiida/backends/tests/dataclasses.py index 7a96816016..0dfcb00498 100644 --- a/aiida/backends/tests/dataclasses.py +++ b/aiida/backends/tests/dataclasses.py @@ -14,7 +14,7 @@ from aiida.common.exceptions import ModificationNotAllowed from aiida.backends.testbase import AiidaTestCase import unittest - +from aiida.common.utils import HiddenPrints def has_seekpath(): @@ -147,6 +147,7 @@ class TestCifData(AiidaTestCase): from aiida.orm.data.structure import has_ase, has_pymatgen, has_spglib, \ get_pymatgen_version from distutils.version import StrictVersion + valid_sample_cif_str = ''' data_test @@ -447,7 +448,8 @@ def test_pycifrw_from_datablocks(self): '_publ_section_title': 'Test CIF' } ] - lines = pycifrw_from_cif(datablocks).WriteOut().split('\n') + with HiddenPrints(): + lines = pycifrw_from_cif(datablocks).WriteOut().split('\n') non_comments = [] for line in lines: if not re.search('^#', line): @@ -471,7 +473,8 @@ def test_pycifrw_from_datablocks(self): ''')) loops = {'_atom_site': ['_atom_site_label', '_atom_site_occupancy']} - lines = pycifrw_from_cif(datablocks, loops).WriteOut().split('\n') + with HiddenPrints(): + lines = pycifrw_from_cif(datablocks, loops).WriteOut().split('\n') non_comments = [] for line in lines: if not 
re.search('^#', line): @@ -489,6 +492,49 @@ def test_pycifrw_from_datablocks(self): _publ_section_title 'Test CIF' ''')) + @unittest.skipIf(not has_pycifrw(), "Unable to import PyCifRW") + def test_pycifrw_syntax(self): + """ + Tests CifData.pycifrw_from_cif() - check syntax pb in PyCifRW 3.6 + """ + from aiida.orm.data.cif import pycifrw_from_cif + import re + + datablocks = [ + { + '_tag': '[value]', + } + ] + with HiddenPrints(): + lines = pycifrw_from_cif(datablocks).WriteOut().split('\n') + non_comments = [] + for line in lines: + if not re.search('^#', line): + non_comments.append(line) + self.assertEquals(simplify("\n".join(non_comments)), + simplify(''' +data_0 +_tag '[value]' +''')) + + @unittest.skipIf(not has_pycifrw(), "Unable to import PyCifRW") + def test_cif_with_long_line(self): + """ + Tests CifData - check that long lines (longer than 2048 characters) + are supported. + Should not raise any error. + """ + import tempfile + from aiida.orm.data.cif import CifData + + with tempfile.NamedTemporaryFile() as f: + f.write(''' +data_0 +_tag {} + '''.format('a'*5000)) + f.flush() + _ = CifData(file=f.name) + @unittest.skipIf(not has_ase(), "Unable to import ase") @unittest.skipIf(not has_pycifrw(), "Unable to import PyCifRW") def test_cif_roundtrip(self): @@ -1194,6 +1240,7 @@ class TestStructureData(AiidaTestCase): Tests the creation of StructureData objects (cell and pbc). 
""" from aiida.orm.data.structure import has_ase, has_spglib + from aiida.orm.data.cif import has_pycifrw def test_cell_ok_and_atoms(self): """ @@ -1603,11 +1650,14 @@ def test_get_formula(self): mode="count_compact"), 'BaTiO3') + @unittest.skipIf(not has_ase(), "Unable to import ase") + @unittest.skipIf(not has_pycifrw(), "Unable to import PyCifRW") def test_get_cif(self): """ Tests the conversion to CifData """ from aiida.orm.data.structure import StructureData + import re a = StructureData(cell=((2., 0., 0.), (0., 2., 0.), (0., 0., 2.))) @@ -1615,25 +1665,14 @@ def test_get_cif(self): a.append_atom(position=(0.5, 0.5, 0.5), symbols=['Ba']) a.append_atom(position=(1., 1., 1.), symbols=['Ti']) - try: - c = a._get_cif() - # Exception thrown if ase can't be found - except ImportError: - return - self.assertEquals(simplify(c._prepare_cif()[0]), - simplify("""#\#CIF1.1 -########################################################################## -# Crystallographic Information Format file -# Produced by PyCifRW module -# -# This is a CIF file. CIF has been adopted by the International -# Union of Crystallography as the standard for data archiving and -# transmission. 
-# -# For information on this file format, follow the CIF links at -# http://www.iucr.org -########################################################################## - + c = a._get_cif() + lines = c._prepare_cif()[0].split('\n') + non_comments = [] + for line in lines: + if not re.search('^#', line): + non_comments.append(line) + self.assertEquals(simplify("\n".join(non_comments)), + simplify(""" data_0 loop_ _atom_site_label @@ -2723,6 +2762,7 @@ def test_export_to_file(self): import os import tempfile from aiida.orm.data.array.trajectory import TrajectoryData + from aiida.orm.data.cif import has_pycifrw n = TrajectoryData() @@ -2775,7 +2815,11 @@ def test_export_to_file(self): os.close(handle) os.remove(filename) - for format in ['cif', 'xsf']: + if has_pycifrw(): + formats_to_test = ['cif', 'xsf'] + else: + formats_to_test = ['xsf'] + for format in formats_to_test: files_created = [] # In case there is an exception try: files_created = n.export(filename, fileformat=format) diff --git a/aiida/backends/tests/restapi.py b/aiida/backends/tests/restapi.py index 7ecaac1819..ae7cebeec7 100644 --- a/aiida/backends/tests/restapi.py +++ b/aiida/backends/tests/restapi.py @@ -771,14 +771,15 @@ def test_structure_visualization(self): """ Get the list of give calculation inputs """ + from aiida.backends.tests.dataclasses import simplify node_uuid = self.get_dummy_data()["structuredata"][0]["uuid"] url = self.get_url_prefix() + '/structures/' + str( node_uuid) + '/content/visualization?visformat=cif' with self.app.test_client() as client: rv = client.get(url) response = json.loads(rv.data) - expected_visdata = """#\\#CIF1.1\n##########################################################################\n# Crystallographic Information Format file \n# Produced by PyCifRW module\n# \n# This is a CIF file. 
CIF has been adopted by the International\n# Union of Crystallography as the standard for data archiving and \n# transmission.\n#\n# For information on this file format, follow the CIF links at\n# http://www.iucr.org\n##########################################################################\n\ndata_0\nloop_\n _atom_site_label\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_type_symbol\n Ba1 0.0 0.0 0.0 Ba\n \n_cell_angle_alpha 90.0\n_cell_angle_beta 90.0\n_cell_angle_gamma 90.0\n_cell_length_a 2.0\n_cell_length_b 2.0\n_cell_length_c 2.0\nloop_\n _symmetry_equiv_pos_as_xyz\n 'x, y, z'\n \n_symmetry_int_tables_number 1\n_symmetry_space_group_name_H-M 'P 1'\n""" - self.assertEquals(response["data"]["visualization"]["str_viz_info"]["data"],expected_visdata) + expected_visdata = """\n##########################################################################\n# Crystallographic Information Format file \n# Produced by PyCifRW module\n# \n# This is a CIF file. CIF has been adopted by the International\n# Union of Crystallography as the standard for data archiving and \n# transmission.\n#\n# For information on this file format, follow the CIF links at\n# http://www.iucr.org\n##########################################################################\n\ndata_0\nloop_\n _atom_site_label\n _atom_site_fract_x\n _atom_site_fract_y\n _atom_site_fract_z\n _atom_site_type_symbol\n Ba1 0.0 0.0 0.0 Ba\n \n_cell_angle_alpha 90.0\n_cell_angle_beta 90.0\n_cell_angle_gamma 90.0\n_cell_length_a 2.0\n_cell_length_b 2.0\n_cell_length_c 2.0\nloop_\n _symmetry_equiv_pos_as_xyz\n 'x, y, z'\n \n_symmetry_int_tables_number 1\n_symmetry_space_group_name_H-M 'P 1'\n""" + self.assertEquals(simplify(response["data"]["visualization"]["str_viz_info"]["data"]),simplify(expected_visdata)) self.assertEquals(response["data"]["visualization"]["str_viz_info"]["format"],"cif") self.assertEquals(response["data"]["visualization"]["dimensionality"], {u'dim': 3, u'value': 8.0, 
u'label': u'volume'}) diff --git a/aiida/backends/tests/tcodexporter.py b/aiida/backends/tests/tcodexporter.py index 11f12f70ea..6efd914196 100644 --- a/aiida/backends/tests/tcodexporter.py +++ b/aiida/backends/tests/tcodexporter.py @@ -464,24 +464,24 @@ def test_export_trajectory(self): '_cell_length_b', '_cell_length_c', '_chemical_formula_sum', - '_symmetry_Int_Tables_number', '_symmetry_equiv_pos_as_xyz', - '_symmetry_space_group_name_H-M', - '_symmetry_space_group_name_Hall' + '_symmetry_int_tables_number', + '_symmetry_space_group_name_h-m', + '_symmetry_space_group_name_hall' ] tcod_file_tags = [ '_tcod_content_encoding_id', '_tcod_content_encoding_layer_id', '_tcod_content_encoding_layer_type', - '_tcod_file_URI', '_tcod_file_content_encoding', '_tcod_file_contents', '_tcod_file_id', '_tcod_file_md5sum', '_tcod_file_name', '_tcod_file_role', - '_tcod_file_sha1sum' + '_tcod_file_sha1sum', + '_tcod_file_uri', ] # Not stored and not to be stored: @@ -500,7 +500,7 @@ def test_export_trajectory(self): v = export_values(td, trajectory_index=1, store=True) self.assertEqual(sorted(v['0'].keys()), expected_tags + tcod_file_tags) - + # Both stored and expected to be stored: td = TrajectoryData(structurelist=structurelist) td.store() diff --git a/aiida/common/utils.py b/aiida/common/utils.py index 77320819c9..a2bf7b2441 100644 --- a/aiida/common/utils.py +++ b/aiida/common/utils.py @@ -1265,3 +1265,22 @@ def get_mode_string(mode): else: perm.append("-") return "".join(perm) + + +class HiddenPrints: + """ + Class to prevent any print to the std output. 
+ Usage: + + with HiddenPrints(): + print("I won't print this") + """ + + def __enter__(self): + from os import devnull + self._original_stdout = sys.stdout + sys.stdout = open(devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout = self._original_stdout + diff --git a/aiida/orm/data/array/trajectory.py b/aiida/orm/data/array/trajectory.py index 43f2537944..178b5ce75a 100644 --- a/aiida/orm/data/array/trajectory.py +++ b/aiida/orm/data/array/trajectory.py @@ -476,6 +476,7 @@ def _prepare_cif(self, trajectory_index=None, main_file_name=""): import CifFile from aiida.orm.data.cif \ import ase_loops, cif_from_ase, pycifrw_from_cif + from aiida.common.utils import HiddenPrints cif = "" indices = range(self.numsteps) @@ -485,7 +486,8 @@ def _prepare_cif(self, trajectory_index=None, main_file_name=""): structure = self.get_step_structure(idx) ciffile = pycifrw_from_cif(cif_from_ase(structure.get_ase()), ase_loops) - cif = cif + ciffile.WriteOut() + with HiddenPrints(): + cif = cif + ciffile.WriteOut() return cif.encode('utf-8'), {} def _prepare_tcod(self, main_file_name="", **kwargs): diff --git a/aiida/orm/data/cif.py b/aiida/orm/data/cif.py index ff89318fa5..8f087e9bdf 100644 --- a/aiida/orm/data/cif.py +++ b/aiida/orm/data/cif.py @@ -11,6 +11,7 @@ # pylint: disable=invalid-name,too-many-locals,too-many-statements from aiida.orm.data.singlefile import SinglefileData from aiida.orm.calculation.inline import optional_inline +from aiida.common.utils import HiddenPrints ase_loops = { '_atom_site': [ @@ -52,6 +53,7 @@ def has_pycifrw(): # pylint: disable=unused-variable try: import CifFile + from CifFile import CifBlock except ImportError: return False return True @@ -216,11 +218,18 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): :return: CifFile """ import CifFile + from CifFile import CifBlock if loops is None: loops = dict() cif = CifFile.CifFile() + try: + cif.set_grammar("1.1") + except AttributeError: + # if no grammar can be set, 
we assume it's 1.1 (widespread standard) + pass + if names and len(names) < len(datablocks): raise ValueError("Not enough names supplied for " "datablocks: {} (names) < " @@ -229,8 +238,8 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): name = str(i) if names: name = names[i] - cif.NewBlock(name) - datablock = cif[name] + datablock = CifBlock() + cif[name] = datablock for loopname in loops.keys(): loopdata = ([[]], [[]]) row_size = None @@ -252,6 +261,9 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): datablock.AddCifItem(loopdata) for tag in sorted(values.keys()): datablock[tag] = values[tag] + # create automatically a loop for non-scalar values + if isinstance(values[tag],(tuple,list)) and tag not in loops.keys(): + datablock.CreateLoop([tag]) return cif @@ -286,7 +298,8 @@ def refine_inline(node): refined_atoms, symmetry = ase_refine_cell(original_atoms) cif = CifData(ase=refined_atoms) - cif.values.dictionary[name] = cif.values.dictionary.pop(str(0)) + if name != str(0): + cif.values.rename(str(0),name) # Remove all existing symmetry tags before overwriting: for tag in symmetry_tags: @@ -472,7 +485,8 @@ def set_ase(self, aseatoms): import tempfile cif = cif_from_ase(aseatoms) with tempfile.NamedTemporaryFile() as f: - f.write(pycifrw_from_cif(cif, loops=ase_loops).WriteOut()) + with HiddenPrints(): + f.write(pycifrw_from_cif(cif, loops=ase_loops).WriteOut()) f.flush() self.set_file(f.name) @@ -490,11 +504,14 @@ def values(self): if self._values is None: try: import CifFile + from CifFile import CifBlock except ImportError as e: - raise ImportError( - str(e) + '. You need to install the PyCifRW package.') - self._values = CifFile.ReadCif( - self.get_file_abs_path(), scantype=self.get_attr('scan_type')) + raise ImportError(str(e) + '. 
You need to install the PyCifRW package.') + c = CifFile.ReadCif(self.get_file_abs_path()) + # change all StarBlocks into CifBlocks + for k,v in c.items(): + c.dictionary[k] = CifBlock(v) + self._values = c return self._values def set_values(self, values): @@ -509,7 +526,8 @@ def set_values(self, values): """ import tempfile with tempfile.NamedTemporaryFile() as f: - f.write(values.WriteOut()) + with HiddenPrints(): + f.write(values.WriteOut()) f.flush() self.set_file(f.name) diff --git a/aiida/tools/dbexporters/tcod.py b/aiida/tools/dbexporters/tcod.py index ccaab1820d..7b81675a0d 100644 --- a/aiida/tools/dbexporters/tcod.py +++ b/aiida/tools/dbexporters/tcod.py @@ -879,8 +879,8 @@ def add_metadata_inline(what, node=None, parameters=None, args=None): for tag in node.values[dataname].keys(): datablock[tag] = node.values[dataname][tag] datablocks.append(datablock) - for loop in node.values[dataname].loops: - loops[loop.keys()[0]] = loop.keys() + for loop in node.values[dataname].loops.values(): + loops[loop[0]] = loop # Unpacking the kwargs from ParameterData kwargs = {} diff --git a/setup_requirements.py b/setup_requirements.py index 6796a8dbf5..ed037469b6 100644 --- a/setup_requirements.py +++ b/setup_requirements.py @@ -113,11 +113,9 @@ 'pymatgen==4.5.3', # support for NWChem I/O 'ase==3.12.0', # support for crystal structure manipulation 'PyMySQL==0.7.9', # required by ICSD tools - 'PyCifRW==3.6.2.1', + 'PyCifRW==4.2.1', 'seekpath==1.8.0', 'qe-tools==1.0', - # support for the AiiDA CifData class. Update to version 4 does - # break tests ], # Requirements for jupyter notebook 'notebook': [ From 164ef7f032254ccc7646ce06249b983b187f72e7 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 29 Mar 2018 14:51:58 +0200 Subject: [PATCH 2/9] Serialize the context of a WorkChain before persisting The user is free to populate the context with Node instances, which means that if the WorkChain needs to be persisted that the nodes need to be serialized. 
Since the nodes are not necessarily stored upon calling save_instance_state, we also store them if they were not yet stored. We also use this opportunity to replace the ad hoc serialization in the pickle persister and use the new more complete serializer and deserializer. --- aiida/orm/utils.py | 38 ++++++++++++++++- aiida/utils/serialize.py | 87 +++++++++++++++++++++++++++++++++++++++ aiida/work/persistence.py | 57 ++++++------------------- aiida/work/workchain.py | 9 ++-- 4 files changed, 143 insertions(+), 48 deletions(-) create mode 100644 aiida/utils/serialize.py diff --git a/aiida/orm/utils.py b/aiida/orm/utils.py index abb397260e..071b251500 100644 --- a/aiida/orm/utils.py +++ b/aiida/orm/utils.py @@ -11,7 +11,7 @@ from aiida.common.pluginloader import BaseFactory from aiida.common.utils import abstractclassmethod -__all__ = ['CalculationFactory', 'DataFactory', 'WorkflowFactory', 'load_node', 'load_workflow'] +__all__ = ['CalculationFactory', 'DataFactory', 'WorkflowFactory', 'load_group', 'load_node', 'load_workflow'] def CalculationFactory(module, from_abstract=False): @@ -138,6 +138,42 @@ def create_node_id_qb(node_id=None, pk=None, uuid=None, return qb +def load_group(group_id=None, pk=None, uuid=None, query_with_dashes=True): + """ + Load a group by its pk or uuid + + :param group_id: pk (integer) or uuid (string) of a group + :param pk: pk of a group + :param uuid: uuid of a group, or the beginning of the uuid + :param bool query_with_dashes: allow to query for a uuid with dashes (default=True) + :returns: the requested group if existing and unique + :raise InputValidationError: if none or more than one of the arguments are supplied + :raise TypeError: if the wrong types are provided + :raise NotExistent: if no matching Node is found. 
+ :raise MultipleObjectsError: if more than one Node was found + """ + from aiida.orm import Group + + kwargs = { + 'node_id': group_id, + 'pk': pk, + 'uuid': uuid, + 'parent_class': Group, + 'query_with_dashes': query_with_dashes + } + + qb = create_node_id_qb(**kwargs) + qb.add_projection('node', '*') + qb.limit(2) + + try: + return qb.one()[0] + except MultipleObjectsError: + raise MultipleObjectsError('More than one group found. Provide longer starting pattern for uuid.') + except NotExistent: + raise NotExistent('No group was found') + + def load_node(node_id=None, pk=None, uuid=None, parent_class=None, query_with_dashes=True): """ Returns an AiiDA node given its PK or UUID. diff --git a/aiida/utils/serialize.py b/aiida/utils/serialize.py new file mode 100644 index 0000000000..59894637e7 --- /dev/null +++ b/aiida/utils/serialize.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +import collections +from ast import literal_eval +from aiida.common.extendeddicts import AttributeDict +from aiida.orm import Group, Node, load_group, load_node + + +_PREFIX_KEY_TUPLE = 'tuple():' +_PREFIX_VALUE_NODE = 'aiida_node:' +_PREFIX_VALUE_GROUP = 'aiida_group:' + + +def encode_key(key): + """ + Helper function for the serialize_data function which may need to serialize a + dictionary that uses tuples as keys. 
This function will encode the tuple into + a string such that it is JSON serializable + + :param key: the key to encode + :return: the encoded key + """ + if isinstance(key, tuple): + return '{}{}'.format(_PREFIX_KEY_TUPLE, key) + else: + return key + + +def decode_key(key): + """ + Helper function for the deserialize_data function which can undo the key encoding + of tuple keys done by the encode_key function + + :param key: the key to decode + :return: the decoded key + """ + if key.startswith(_PREFIX_KEY_TUPLE): + return literal_eval(key[len(_PREFIX_KEY_TUPLE):]) + else: + return key + + +def serialize_data(data): + """ + Serialize a value or collection that may potentially contain AiiDA nodes, which + will be serialized to their UUID. Keys encountered in any mappings, such as a dictionary, + will also be encoded if necessary. An example is where tuples are used as keys in the + pseudo potential input dictionaries. These operations will ensure that the returned data is + JSON serializable. + + :param data: a single value or collection + :return: the serialized data with the same internal structure + """ + if isinstance(data, Node): + return '{}{}'.format(_PREFIX_VALUE_NODE, data.uuid) + elif isinstance(data, Group): + return '{}{}'.format(_PREFIX_VALUE_GROUP, data.uuid) + elif isinstance(data, AttributeDict): + return AttributeDict({encode_key(key): serialize_data(value) for key, value in data.iteritems()}) + elif isinstance(data, collections.Mapping): + return {encode_key(key): serialize_data(value) for key, value in data.iteritems()} + elif isinstance(data, collections.Sequence) and not isinstance(data, (str, unicode)): + return [serialize_data(value) for value in data] + else: + return data + + +def deserialize_data(data): + """ + Deserialize a single value or a collection that may contain serialized AiiDA nodes. This is + essentially the inverse operation of serialize_data which will reload node instances from + the serialized UUID data. 
Encoded tuples that are used as dictionary keys will be decoded. + + :param data: serialized data + :return: the deserialized data with keys decoded and node instances loaded from UUID's + """ + if isinstance(data, AttributeDict): + return AttributeDict({decode_key(key): deserialize_data(value) for key, value in data.iteritems()}) + elif isinstance(data, collections.Mapping): + return {decode_key(key): deserialize_data(value) for key, value in data.iteritems()} + elif isinstance(data, collections.Sequence) and not isinstance(data, (str, unicode)): + return [deserialize_data(value) for value in data] + elif isinstance(data, (str, unicode)) and data.startswith(_PREFIX_VALUE_NODE): + return load_node(uuid=data[len(_PREFIX_VALUE_NODE):]) + elif isinstance(data, (str, unicode)) and data.startswith(_PREFIX_VALUE_GROUP): + return load_group(uuid=data[len(_PREFIX_VALUE_GROUP):]) + else: + return data diff --git a/aiida/work/persistence.py b/aiida/work/persistence.py index b9b40af2d9..ac0840172f 100644 --- a/aiida/work/persistence.py +++ b/aiida/work/persistence.py @@ -16,6 +16,7 @@ import plum.persistence.pickle_persistence from plum.process import Process from aiida.common.lang import override +from aiida.utils.serialize import serialize_data, deserialize_data from aiida.work.defaults import class_loader import glob @@ -397,9 +398,13 @@ def _load_checkpoint(self, pid): def load_checkpoint_from_file_object(self, file_object): cp = pickle.load(file_object) - inputs = cp[Process.BundleKeys.INPUTS.value] + inputs = cp[Process.BundleKeys.INPUTS_RAW.value] if inputs: - cp[Process.BundleKeys.INPUTS.value] = self._load_nodes_from(inputs) + cp[Process.BundleKeys.INPUTS_RAW.value] = deserialize_data(inputs) + + inputs = cp[Process.BundleKeys.INPUTS_PARSED.value] + if inputs: + cp[Process.BundleKeys.INPUTS_PARSED.value] = deserialize_data(inputs) cp.set_class_loader(class_loader) return cp @@ -412,51 +417,15 @@ def get_checkpoint_state(self, pid): def create_bundle(self, process): 
bundle = Bundle() process.save_instance_state(bundle) - inputs = bundle[Process.BundleKeys.INPUTS.value] + inputs = bundle[Process.BundleKeys.INPUTS_RAW.value] if inputs: - bundle[Process.BundleKeys.INPUTS.value] = self._convert_to_ids(inputs) - - return bundle - - def _convert_to_ids(self, nodes): - from aiida.orm import Node - - input_ids = {} - for label, node in nodes.iteritems(): - if node is None: - continue - elif isinstance(node, Node): - if node.is_stored: - input_ids[label] = node.pk - else: - # Try using the UUID, but there's probably no chance of - # being abel to recover the node from this if not stored - # (for the time being) - input_ids[label] = node.uuid - elif isinstance(node, collections.Mapping): - input_ids[label] = self._convert_to_ids(node) - - return input_ids - - def _load_nodes_from(self, pks_mapping): - """ - Take a dictionary of of {label: pk} or nested dictionary i.e. - {label: {label: pk}} and convert to the equivalent dictionary but - with nodes instead of the ids. + bundle[Process.BundleKeys.INPUTS_RAW.value] = serialize_data(inputs) - :param pks_mapping: The dictionary of node pks. - :return: A dictionary with the loaded nodes. 
- :rtype: dict - """ - from aiida.orm import load_node + inputs = bundle[Process.BundleKeys.INPUTS_PARSED.value] + if inputs: + bundle[Process.BundleKeys.INPUTS_PARSED.value] = serialize_data(inputs) - nodes = {} - for label, pk in pks_mapping.iteritems(): - if isinstance(pk, collections.Mapping): - nodes[label] = self._load_nodes_from(pk) - else: - nodes[label] = load_node(pk=pk) - return nodes + return bundle def _clear(self, fileobj): """ diff --git a/aiida/work/workchain.py b/aiida/work/workchain.py index f79faedc91..1c66948836 100644 --- a/aiida/work/workchain.py +++ b/aiida/work/workchain.py @@ -18,7 +18,8 @@ from aiida.common.lang import override from aiida.common.utils import get_class_string, get_object_string, \ get_object_from_string -from aiida.orm import load_node, load_workflow +from aiida.orm import load_node, load_workflow, Node +from aiida.utils.serialize import serialize_data, deserialize_data from plum.wait_ons import Checkpoint, WaitOnAll, WaitOnProcess from plum.wait import WaitOn from plum.persistence.bundle import Bundle @@ -124,7 +125,9 @@ def setdefault(self, key, default=None): def save_instance_state(self, out_state): for k, v in self._content.iteritems(): - out_state[k] = v + if isinstance(v, Node) and not v.is_stored: + v.store() + out_state[k] = serialize_data(v) def __init__(self): super(WorkChain, self).__init__() @@ -283,7 +286,7 @@ def on_create(self, pid, inputs, saved_state): self._context = self.Context() else: # Recreate the context - self._context = self.Context(saved_state[self._CONTEXT]) + self._context = self.Context(deserialize_data(saved_state[self._CONTEXT])) # Recreate the stepper if self._STEPPER_STATE in saved_state: From 6b41461de14bd6163aad44aba1a46a7be7ff1ca8 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Thu, 29 Mar 2018 17:36:46 +0200 Subject: [PATCH 3/9] Update the version of plumpy This release of plumpy fixes an issue where the inputs of a process are recreated each time when loaded from a persisted 
state, meaning that inputs that were not explicitly specified but taken from the spec's default will be recreated. For default nodes in aiida-core this means that the nodes are duplicated multiple times --- docs/requirements_for_rtd.txt | 2 +- setup_requirements.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/requirements_for_rtd.txt b/docs/requirements_for_rtd.txt index 6574246e78..cbea28c31a 100644 --- a/docs/requirements_for_rtd.txt +++ b/docs/requirements_for_rtd.txt @@ -23,7 +23,7 @@ paramiko==2.4.0 passlib==1.7.1 pathlib2==2.3.0 pip==9.0.1 -plumpy==0.7.10 +plumpy==0.7.12 portalocker==1.1.0 psutil==5.4.0 pycrypto==2.6.1 diff --git a/setup_requirements.py b/setup_requirements.py index ed037469b6..76472133e6 100644 --- a/setup_requirements.py +++ b/setup_requirements.py @@ -41,7 +41,7 @@ 'psutil==5.4.0', 'meld3==1.0.0', 'numpy==1.12.0', - 'plumpy==0.7.10', + 'plumpy==0.7.12', 'portalocker==1.1.0', 'SQLAlchemy==1.0.19', # upgrade to SQLalchemy 1.1.5 does break tests, see #465 'SQLAlchemy-Utils==0.33.0', From 3c1c2a0e122e20c9003e5a8af3cb29b1052c7479 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Fri, 23 Feb 2018 13:57:11 +0100 Subject: [PATCH 4/9] Sort output of verdi code list by code pk to have consistent results The verdi work list test that compared the output with and without the -a flag was failing sometimes, simply because the order was wrong but the exact string was compared. 
Sorting by code id will fix the order and this breaking test --- aiida/cmdline/commands/code.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aiida/cmdline/commands/code.py b/aiida/cmdline/commands/code.py index faa2eef07f..2982afad03 100644 --- a/aiida/cmdline/commands/code.py +++ b/aiida/cmdline/commands/code.py @@ -720,6 +720,7 @@ def code_list(self, *args): qb.append(Computer, computer_of="code", project=["name"], filters=qb_computer_filters) + qb.order_by({Code: {'id': 'asc'}}) self.print_list_res(qb, show_owner) # If there is no filter on computers @@ -737,6 +738,7 @@ def code_list(self, *args): filters=qb_user_filters) qb.append(Computer, computer_of="code", project=["name"]) + qb.order_by({Code: {'id': 'asc'}}) self.print_list_res(qb, show_owner) # Now print all the local codes. To get the local codes we ask @@ -757,6 +759,7 @@ def code_list(self, *args): qb.append(User, creator_of="code", project=["email"], filters=qb_user_filters) + qb.order_by({Code: {'id': 'asc'}}) self.print_list_res(qb, show_owner) @staticmethod From edbfcd2d370a725fd7a8f2556b783eff3eec5e8d Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 3 Apr 2018 18:10:44 +0200 Subject: [PATCH 5/9] Save and load the parsed inputs from the persisted state This is an issue in plumpy that was also fixed there. The parsed inputs of a Process, which are returned by calling self.inputs, were being rebuilt from the raw inputs every time the process was loaded from a persisted state. This meant that inputs that were not explicitly passed by the user and were populated with the defaults specified by the port, were being recreated upon reloading the instance. However, they should have been the ones that were created when the process was created the first time around. 
Therefore we persist the parsed inputs to the saved state and reload them instead of recreating them with `create_input_args` --- aiida/utils/serialize.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aiida/utils/serialize.py b/aiida/utils/serialize.py index 59894637e7..2431819828 100644 --- a/aiida/utils/serialize.py +++ b/aiida/utils/serialize.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import collections from ast import literal_eval +from plumpy.util import AttributesFrozendict from aiida.common.extendeddicts import AttributeDict from aiida.orm import Group, Node, load_group, load_node @@ -56,6 +57,8 @@ def serialize_data(data): return '{}{}'.format(_PREFIX_VALUE_GROUP, data.uuid) elif isinstance(data, AttributeDict): return AttributeDict({encode_key(key): serialize_data(value) for key, value in data.iteritems()}) + elif isinstance(data, AttributesFrozendict): + return AttributesFrozendict({encode_key(key): serialize_data(value) for key, value in data.iteritems()}) elif isinstance(data, collections.Mapping): return {encode_key(key): serialize_data(value) for key, value in data.iteritems()} elif isinstance(data, collections.Sequence) and not isinstance(data, (str, unicode)): @@ -75,6 +78,8 @@ def deserialize_data(data): """ if isinstance(data, AttributeDict): return AttributeDict({decode_key(key): deserialize_data(value) for key, value in data.iteritems()}) + elif isinstance(data, AttributesFrozendict): + return AttributesFrozendict({decode_key(key): deserialize_data(value) for key, value in data.iteritems()}) elif isinstance(data, collections.Mapping): return {decode_key(key): deserialize_data(value) for key, value in data.iteritems()} elif isinstance(data, collections.Sequence) and not isinstance(data, (str, unicode)): From f6662a71fbb162809cae36001415edd86de4b678 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 3 Apr 2018 18:11:40 +0200 Subject: [PATCH 6/9] Run the pre-release helper scripts --- aiida/utils/serialize.py | 8 ++++++++ 1 file 
changed, 8 insertions(+) diff --git a/aiida/utils/serialize.py b/aiida/utils/serialize.py index 2431819828..652f50f3c7 100644 --- a/aiida/utils/serialize.py +++ b/aiida/utils/serialize.py @@ -1,4 +1,12 @@ # -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### import collections from ast import literal_eval from plumpy.util import AttributesFrozendict From 47607c8268e49dbe1cd05dd202290d7f83c11fdb Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 3 Apr 2018 18:12:44 +0200 Subject: [PATCH 7/9] Update the CHANGELOG.md --- CHANGELOG.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3bee5a2d8e..8ca38a9405 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +## v0.11.4 + +### Improvements +- PyCifRW upgraded to 4.2.1 [[#1073]](https://github.com/aiidateam/aiida_core/pull/1073) + +### Critical bug fixes +- Persist and load parsed workchain inputs and do not recreate to avoid creating duplicates for default inputs [[#1362]](https://github.com/aiidateam/aiida_core/pull/1362) +- Serialize `WorkChain` context before persisting [[#1354]](https://github.com/aiidateam/aiida_core/pull/1354) + + ## v0.11.3 ### Improvements From 152e13e399f0f295e3e713d054724d842f917274 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 3 Apr 2018 18:22:10 +0200 Subject: [PATCH 8/9] Run pre-commit hooks --- .prospector.yaml | 3 +++ aiida/orm/data/cif.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/.prospector.yaml b/.prospector.yaml index 25e0290850..07d7db638c 
100644 --- a/.prospector.yaml +++ b/.prospector.yaml @@ -9,3 +9,6 @@ pylint: pyflakes: run: false + +mccabe: + run: false diff --git a/aiida/orm/data/cif.py b/aiida/orm/data/cif.py index 8f087e9bdf..3d24b1489d 100644 --- a/aiida/orm/data/cif.py +++ b/aiida/orm/data/cif.py @@ -208,6 +208,7 @@ def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): return datablocks +# pylint: disable=too-many-branches def pycifrw_from_cif(datablocks, loops=None, names=None): """ Constructs PyCifRW's CifFile from an array of CIF datablocks. @@ -229,7 +230,7 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): except AttributeError: # if no grammar can be set, we assume it's 1.1 (widespread standard) pass - + if names and len(names) < len(datablocks): raise ValueError("Not enough names supplied for " "datablocks: {} (names) < " @@ -262,7 +263,8 @@ def pycifrw_from_cif(datablocks, loops=None, names=None): for tag in sorted(values.keys()): datablock[tag] = values[tag] # create automatically a loop for non-scalar values - if isinstance(values[tag],(tuple,list)) and tag not in loops.keys(): + if isinstance(values[tag], + (tuple, list)) and tag not in loops.keys(): datablock.CreateLoop([tag]) return cif @@ -299,7 +301,7 @@ def refine_inline(node): cif = CifData(ase=refined_atoms) if name != str(0): - cif.values.rename(str(0),name) + cif.values.rename(str(0), name) # Remove all existing symmetry tags before overwriting: for tag in symmetry_tags: @@ -506,10 +508,11 @@ def values(self): import CifFile from CifFile import CifBlock except ImportError as e: - raise ImportError(str(e) + '. You need to install the PyCifRW package.') + raise ImportError( + str(e) + '. 
You need to install the PyCifRW package.') c = CifFile.ReadCif(self.get_file_abs_path()) # change all StarBlocks into CifBlocks - for k,v in c.items(): + for k, v in c.items(): c.dictionary[k] = CifBlock(v) self._values = c return self._values From fe86f8cac2bbbde59841b39677230c70afb62295 Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Tue, 3 Apr 2018 18:22:46 +0200 Subject: [PATCH 9/9] Update version number to v0.11.4 --- aiida/__init__.py | 2 +- aiida/utils/serialize.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aiida/__init__.py b/aiida/__init__.py index 73a599b3a4..732da5b3bc 100644 --- a/aiida/__init__.py +++ b/aiida/__init__.py @@ -13,7 +13,7 @@ __copyright__ = u"Copyright (c), This file is part of the AiiDA platform. For further information please visit http://www.aiida.net/. All rights reserved." __license__ = "MIT license, see LICENSE.txt file." -__version__ = "0.11.3" +__version__ = "0.11.4" __authors__ = "The AiiDA team." __paper__ = """G. Pizzi, A. Cepellotti, R. Sabatini, N. Marzari, and B. Kozinsky, "AiiDA: automated interactive infrastructure and database for computational science", Comp. Mat. Sci 111, 218-230 (2016); http://dx.doi.org/10.1016/j.commatsci.2015.09.013 - http://www.aiida.net.""" __paper_short__ = """G. Pizzi et al., Comp. Mat. Sci 111, 218 (2016).""" diff --git a/aiida/utils/serialize.py b/aiida/utils/serialize.py index 652f50f3c7..d02983a394 100644 --- a/aiida/utils/serialize.py +++ b/aiida/utils/serialize.py @@ -9,7 +9,7 @@ ########################################################################### import collections from ast import literal_eval -from plumpy.util import AttributesFrozendict +from plum.util import AttributesFrozendict from aiida.common.extendeddicts import AttributeDict from aiida.orm import Group, Node, load_group, load_node