diff --git a/aiida/backends/djsite/db/models.py b/aiida/backends/djsite/db/models.py index 315bbd31d3..8070891df8 100644 --- a/aiida/backends/djsite/db/models.py +++ b/aiida/backends/djsite/db/models.py @@ -185,8 +185,7 @@ def get_aiida_class(self): """ from aiida.orm.node import Node from aiida.common.old_pluginloader import from_type_to_pluginclassname - from aiida.common.pluginloader import load_plugin - from aiida.common import aiidalogger + from aiida.common.pluginloader import load_plugin_safe try: pluginclassname = from_type_to_pluginclassname(self.type) @@ -194,12 +193,7 @@ def get_aiida_class(self): raise DbContentError("The type name of node with pk= {} is " "not valid: '{}'".format(self.pk, self.type)) - try: - PluginClass = load_plugin(Node, 'aiida.orm', pluginclassname) - except MissingPluginError: - aiidalogger.error("Unable to find plugin for type '{}' (node= {}), " - "will use base Node class".format(self.type, self.pk)) - PluginClass = Node + PluginClass = load_plugin_safe(Node, 'aiida.orm', pluginclassname, self.type, self.pk) return PluginClass(dbnode=self) diff --git a/aiida/backends/sqlalchemy/models/node.py b/aiida/backends/sqlalchemy/models/node.py index 2101801563..82de0f729a 100644 --- a/aiida/backends/sqlalchemy/models/node.py +++ b/aiida/backends/sqlalchemy/models/node.py @@ -28,7 +28,6 @@ from aiida.backends.sqlalchemy.models.utils import uuid_func from aiida.common import aiidalogger -from aiida.common.pluginloader import load_plugin from aiida.common.exceptions import DbContentError, MissingPluginError from aiida.common.datastructures import calc_states, _sorted_datastates, sort_states @@ -154,6 +153,7 @@ def get_aiida_class(self): """ from aiida.common.old_pluginloader import from_type_to_pluginclassname from aiida.orm.node import Node + from aiida.common.pluginloader import load_plugin_safe try: pluginclassname = from_type_to_pluginclassname(self.type) @@ -161,12 +161,7 @@ def get_aiida_class(self): raise DbContentError("The type name of node with pk= {} is " "not valid: '{}'".format(self.pk, self.type)) - try: - PluginClass = load_plugin(Node, 'aiida.orm', pluginclassname) - except MissingPluginError: - aiidalogger.error("Unable to find plugin for type '{}' (node= {}), " - "will use base Node class".format(self.type, self.pk)) - PluginClass = Node + PluginClass = load_plugin_safe(Node, 'aiida.orm', pluginclassname, self.type, self.pk) return PluginClass(dbnode=self) diff --git a/aiida/backends/tests/nodes.py b/aiida/backends/tests/nodes.py index 99a94c0802..368343d1a0 100644 --- a/aiida/backends/tests/nodes.py +++ b/aiida/backends/tests/nodes.py @@ -1410,6 +1410,55 @@ def test_load_node(self): with self.assertRaises(NotExistent): load_node(spec, parent_class=ArrayData) + def test_load_plugin_safe(self): + from aiida.orm import (JobCalculation, CalculationFactory, DataFactory) + + ###### for calculation + calc_params = { + 'computer': self.computer, + 'resources': {'num_machines': 1, 'num_mpiprocs_per_machine': 1} + } + + TemplateReplacerCalc = CalculationFactory('simpleplugins.templatereplacer') + testcalc = TemplateReplacerCalc(**calc_params).store() + jobcalc = JobCalculation(**calc_params).store() + + # compare if plugin exist + obj = testcalc.dbnode.get_aiida_class() + self.assertEqual(type(testcalc), type(obj)) + + # change node type and save in database again + testcalc.dbnode.type = "calculation.job.simpleplugins_tmp.templatereplacer.TemplatereplacerCalculation." + testcalc.dbnode.save() + + # changed node should return job calc as its plugin is not exist + obj = testcalc.dbnode.get_aiida_class() + self.assertEqual(type(jobcalc), type(obj)) + + ####### for data + KpointsData = DataFactory('array.kpoints') + kpoint = KpointsData().store() + Data = DataFactory("Data") + data = Data().store() + + # compare if plugin exist + obj = kpoint.dbnode.get_aiida_class() + self.assertEqual(type(kpoint), type(obj)) + + # change node type and save in database again + kpoint.dbnode.type = "data.array.kpoints_tmp.KpointsData." + kpoint.dbnode.save() + + # changed node should return data node as its plugin is not exist + obj = kpoint.dbnode.get_aiida_class() + self.assertEqual(type(data), type(obj)) + + ###### for node + n1 = Node().store() + obj = n1.dbnode.get_aiida_class() + self.assertEqual(type(n1), type(obj)) + + class TestSubNodesAndLinks(AiidaTestCase): diff --git a/aiida/common/pluginloader.py b/aiida/common/pluginloader.py index a85c15ddf4..fbbf5a5ef2 100644 --- a/aiida/common/pluginloader.py +++ b/aiida/common/pluginloader.py @@ -125,6 +125,62 @@ def get_plugin(category, name): return plugin +def load_plugin_safe(base_class, plugins_module, plugin_type, node_type, node_pk): + """ + It is a wrapper of load_plugin function to return closely related node class + if plugin is not available. By default it returns base Node class and does not + raise exception. + + params: Look at the docstring of aiida.common.old_pluginloader.load_plugin for more Info + + :param: node_type: type of the node + :param node_pk: node pk + + :return: The plugin class + """ + from aiida.common import aiidalogger + + try: + PluginClass = load_plugin(base_class, plugins_module, plugin_type) + except MissingPluginError: + node_parts = plugin_type.partition(".") + base_node_type = node_parts[0] + + ## data node: temporarily returning base data node. + # In future its better to check the closest available plugin and return it. + # For example if type is "aiida.orm.data.array.kpoints_tmp.KpointsData" + # it should return array data node and not base data node + if base_node_type == "data": + PluginClass = load_plugin(base_class, plugins_module, 'data.Data') + + ## code node + elif base_node_type == "code": + PluginClass = load_plugin(base_class, plugins_module, 'code.Code') + + ## calculation node: for calculation currently we are hardcoding cases + elif base_node_type == "calculation": + sub_node_parts = node_parts[2].partition(".") + sub_node_type = sub_node_parts[0] + if sub_node_type == "job": + PluginClass = load_plugin(base_class, plugins_module, 'calculation.job.JobCalculation') + elif sub_node_type == "inline": + PluginClass = load_plugin(base_class, plugins_module, 'calculation.inline.InlineCalculation') + elif sub_node_type == "work": + PluginClass = load_plugin(base_class, plugins_module, 'calculation.work.WorkCalculation') + else: + PluginClass = load_plugin(base_class, plugins_module, 'calculation.Calculation') + + ## for base node + elif base_node_type == "node": + PluginClass = base_class + + ## default case + else: + aiidalogger.error("Unable to find plugin for type '{}' (node= {}), " + "will use base Node class".format(node_type, node_pk)) + PluginClass = base_class + + return PluginClass + def load_plugin(base_class, plugins_module, plugin_type): """