diff --git a/.travis-data/test_daemon.py b/.travis-data/test_daemon.py index 302939f0ec..fd9b3917bf 100644 --- a/.travis-data/test_daemon.py +++ b/.travis-data/test_daemon.py @@ -20,9 +20,10 @@ ParameterData = DataFactory('parameter') codename = 'doubler@torquessh' -timeout_secs = 4 * 60 # 4 minutes -number_calculations = 30 # Number of calculations to submit -number_workchains = 30 # Number of workchains to submit +timeout_secs = 4 * 60 # 4 minutes +number_calculations = 30 # Number of calculations to submit +number_workchains = 30 # Number of workchains to submit + def print_daemon_log(): home = os.environ['HOME'] @@ -35,22 +36,25 @@ def print_daemon_log(): except subprocess.CalledProcessError as e: print "Note: the command failed, message: {}".format(e.message) + def jobs_have_finished(pks): finished_list = [load_node(pk).has_finished() for pk in pks] num_finished = len([_ for _ in finished_list if _]) print "{}/{} finished".format(num_finished, len(finished_list)) return not (False in finished_list) + def print_logshow(pk): print "Output of 'verdi calculation logshow {}':".format(pk) try: print subprocess.check_output( ["verdi", "calculation", "logshow", "{}".format(pk)], stderr=subprocess.STDOUT, - ) + ) except subprocess.CalledProcessError as e2: print "Note: the command failed, message: {}".format(e2.message) + def validate_calculations(expected_results): valid = True actual_dict = {} @@ -81,6 +85,7 @@ def validate_calculations(expected_results): return valid + def validate_workchains(expected_results): valid = True for pk, expected_value in expected_results.iteritems(): @@ -98,6 +103,7 @@ def validate_workchains(expected_results): return valid + def validate_cached(cached_calcs): """ Check that the calculations with created with caching are indeed cached. 
@@ -108,21 +114,22 @@ def validate_cached(cached_calcs): for calc in cached_calcs ) + def create_calculation(code, counter, inputval, use_cache=False): parameters = ParameterData(dict={'value': inputval}) template = ParameterData(dict={ - ## The following line adds a significant sleep time. - ## I set it to 1 second to speed up tests - ## I keep it to a non-zero value because I want - ## To test the case when AiiDA finds some calcs - ## in a queued state - #'cmdline_params': ["{}".format(counter % 3)], # Sleep time - 'cmdline_params': ["1"], - 'input_file_template': "{value}", # File just contains the value to double - 'input_file_name': 'value_to_double.txt', - 'output_file_name': 'output.txt', - 'retrieve_temporary_files': ['triple_value.tmp'] - }) + ## The following line adds a significant sleep time. + ## I set it to 1 second to speed up tests + ## I keep it to a non-zero value because I want + ## To test the case when AiiDA finds some calcs + ## in a queued state + # 'cmdline_params': ["{}".format(counter % 3)], # Sleep time + 'cmdline_params': ["1"], + 'input_file_template': "{value}", # File just contains the value to double + 'input_file_name': 'value_to_double.txt', + 'output_file_name': 'output.txt', + 'retrieve_temporary_files': ['triple_value.tmp'] + }) calc = code.new_calc() calc.set_max_wallclock_seconds(5 * 60) # 5 min calc.set_resources({"num_machines": 1}) @@ -138,10 +145,10 @@ def create_calculation(code, counter, inputval, use_cache=False): 'triple_value.tmp': str(inputval * 3) } } - print "[{}] created calculation {}, pk={}".format( - counter, calc.uuid, calc.dbnode.pk) + print "[{}] created calculation {}, pk={}".format(counter, calc.uuid, calc.pk) return calc, expected_result + def submit_calculation(code, counter, inputval): calc, expected_result = create_calculation( code=code, counter=counter, inputval=inputval @@ -150,6 +157,7 @@ def submit_calculation(code, counter, inputval): print "[{}] calculation submitted.".format(counter) return 
calc, expected_result + def create_cache_calc(code, counter, inputval): calc, expected_result = create_calculation( code=code, counter=counter, inputval=inputval, use_cache=True @@ -157,8 +165,8 @@ def create_cache_calc(code, counter, inputval): print "[{}] created cached calculation.".format(counter) return calc, expected_result -def main(): +def main(): # Submitting the Calculations print "Submitting {} calculations to the daemon".format(number_calculations) code = Code.get_from_string(codename) @@ -178,7 +186,6 @@ def main(): future = submit(ParentWorkChain, inp=inp) expected_results_workchains[future.pid] = index * 2 - calculation_pks = sorted(expected_results_calculations.keys()) workchains_pks = sorted(expected_results_workchains.keys()) pks = calculation_pks + workchains_pks @@ -187,14 +194,14 @@ def main(): start_time = time.time() exited_with_timeout = True while time.time() - start_time < timeout_secs: - time.sleep(15) # Wait a few seconds + time.sleep(15) # Wait a few seconds # Print some debug info, both for debugging reasons and to avoid # that the test machine is shut down because there is no output - print "#"*78 + print "#" * 78 print "####### TIME ELAPSED: {} s".format(time.time() - start_time) - print "#"*78 + print "#" * 78 print "Output of 'verdi calculation list -a':" try: print subprocess.check_output( @@ -244,8 +251,8 @@ def main(): cached_calcs.append(calc) expected_results_calculations[calc.pk] = expected_result if (validate_calculations(expected_results_calculations) - and validate_workchains(expected_results_workchains) - and validate_cached(cached_calcs)): + and validate_workchains(expected_results_workchains) + and validate_cached(cached_calcs)): print_daemon_log() print "" print "OK, all calculations have the expected parsed result" diff --git a/aiida/backends/djsite/db/subtests/djangomigrations.py b/aiida/backends/djsite/db/subtests/djangomigrations.py index 1db891ea3b..5add41fd75 100644 --- 
a/aiida/backends/djsite/db/subtests/djangomigrations.py +++ b/aiida/backends/djsite/db/subtests/djangomigrations.py @@ -38,7 +38,7 @@ def test_unexpected_calc_states(self): job = JobCalculation(**calc_params) job.store() # Now save the errant state - DbCalcState(dbnode=job.dbnode, state=state).save() + DbCalcState(dbnode=job._dbnode, state=state).save() time_before_fix = timezone.now() diff --git a/aiida/backends/djsite/db/subtests/generic.py b/aiida/backends/djsite/db/subtests/generic.py index bdc5c60614..9f2312e49e 100644 --- a/aiida/backends/djsite/db/subtests/generic.py +++ b/aiida/backends/djsite/db/subtests/generic.py @@ -15,7 +15,6 @@ from aiida.orm.node import Node - class TestComputer(AiidaTestCase): """ Test the Computer class. @@ -42,8 +41,6 @@ def test_deletion(self): _ = JobCalculation(**calc_params).store() - #print "Node stored with pk:", _.dbnode.pk - # This should fail, because there is at least a calculation # using this computer (the one created just above) with self.assertRaises(InvalidOperation): @@ -123,31 +120,32 @@ class TestDbExtrasDjango(AiidaTestCase): """ Test DbAttributes. 
""" + def test_replacement_1(self): from aiida.backends.djsite.db.models import DbExtra n1 = Node().store() n2 = Node().store() - DbExtra.set_value_for_node(n1.dbnode, "pippo", [1, 2, 'a']) - DbExtra.set_value_for_node(n1.dbnode, "pippobis", [5, 6, 'c']) - DbExtra.set_value_for_node(n2.dbnode, "pippo2", [3, 4, 'b']) + DbExtra.set_value_for_node(n1._dbnode, "pippo", [1, 2, 'a']) + DbExtra.set_value_for_node(n1._dbnode, "pippobis", [5, 6, 'c']) + DbExtra.set_value_for_node(n2._dbnode, "pippo2", [3, 4, 'b']) self.assertEquals(n1.get_extras(), {'pippo': [1, 2, 'a'], - 'pippobis': [5, 6, 'c'], - '_aiida_hash': n1.get_hash() - }) + 'pippobis': [5, 6, 'c'], + '_aiida_hash': n1.get_hash() + }) self.assertEquals(n2.get_extras(), {'pippo2': [3, 4, 'b'], '_aiida_hash': n2.get_hash() }) new_attrs = {"newval1": "v", "newval2": [1, {"c": "d", "e": 2}]} - DbExtra.reset_values_for_node(n1.dbnode, attributes=new_attrs) + DbExtra.reset_values_for_node(n1._dbnode, attributes=new_attrs) self.assertEquals(n1.get_extras(), new_attrs) self.assertEquals(n2.get_extras(), {'pippo2': [3, 4, 'b'], '_aiida_hash': n2.get_hash()}) - DbExtra.del_value_for_node(n1.dbnode, key='newval2') + DbExtra.del_value_for_node(n1._dbnode, key='newval2') del new_attrs['newval2'] self.assertEquals(n1.get_extras(), new_attrs) # Also check that other nodes were not damaged diff --git a/aiida/backends/sqlalchemy/__init__.py b/aiida/backends/sqlalchemy/__init__.py index 56749c9cc1..3910365ea7 100644 --- a/aiida/backends/sqlalchemy/__init__.py +++ b/aiida/backends/sqlalchemy/__init__.py @@ -10,7 +10,7 @@ # The next two serve as 'global' variables, set in the load_dbenv # call. They are properly reset upon forking. 
-engine = None +engine = None scopedsessionclass = None @@ -28,4 +28,3 @@ def get_scoped_session(): s = scopedsessionclass() return s - diff --git a/aiida/backends/sqlalchemy/models/base.py b/aiida/backends/sqlalchemy/models/base.py index 965a9dd279..48d52cd5ac 100644 --- a/aiida/backends/sqlalchemy/models/base.py +++ b/aiida/backends/sqlalchemy/models/base.py @@ -16,12 +16,11 @@ import aiida.backends.sqlalchemy from aiida.common.exceptions import InvalidOperation + # Taken from # https://github.com/mitsuhiko/flask-sqlalchemy/blob/master/flask_sqlalchemy/__init__.py#L491 - - class _QueryProperty(object): def __init__(self, query_class=orm.Query): @@ -63,8 +62,8 @@ def __iter__(self): from aiida.backends.sqlalchemy import get_scoped_session -class Model(object): +class Model(object): query = _QueryProperty() session = _SessionProperty() @@ -81,4 +80,6 @@ def delete(self, commit=True): sess.delete(self) if commit: sess.commit() + + Base = declarative_base(cls=Model, name='Model') diff --git a/aiida/backends/sqlalchemy/models/node.py b/aiida/backends/sqlalchemy/models/node.py index 82de0f729a..d8e264620d 100644 --- a/aiida/backends/sqlalchemy/models/node.py +++ b/aiida/backends/sqlalchemy/models/node.py @@ -114,7 +114,7 @@ class DbNode(Base): # User user = relationship( 'DbUser', - backref=backref('dbnodes', passive_deletes='all', cascade='merge',) + backref=backref('dbnodes', passive_deletes='all', cascade='merge', ) ) # outputs via db_dblink table @@ -127,14 +127,6 @@ class DbNode(Base): passive_deletes=True ) - @property - def outputs(self): - return self.outputs_q.all() - - @property - def inputs(self): - return self.inputs_q.all() - def __init__(self, *args, **kwargs): super(DbNode, self).__init__(*args, **kwargs) @@ -144,6 +136,13 @@ def __init__(self, *args, **kwargs): if self.extras is None: self.extras = dict() + @property + def outputs(self): + return self.outputs_q.all() + + @property + def inputs(self): + return self.inputs_q.all() # XXX repetition between 
django/sqlalchemy here. def get_aiida_class(self): @@ -215,8 +214,7 @@ def del_extra(self, key): @staticmethod def _set_attr(d, key, value): if '.' in key: - raise ValueError( - "We don't know how to treat key with dot in it yet") + raise ValueError("We don't know how to treat key with dot in it yet") d[key] = value @@ -273,10 +271,9 @@ def computer_name(cls): database) """ return select([DbComputer.name]).where(DbComputer.id == - cls.dbcomputer_id).label( + cls.dbcomputer_id).label( 'computer_name') - @hybrid_property def state(self): """ @@ -286,10 +283,10 @@ def state(self): return None all_states = DbCalcState.query.filter(DbCalcState.dbnode_id == self.id).all() if all_states: - #return max((st.time, st.state) for st in all_states)[1] + # return max((st.time, st.state) for st in all_states)[1] return sort_states(((dbcalcstate.state, dbcalcstate.state.value) for dbcalcstate in all_states), - use_key=True)[0] + use_key=True)[0] else: return None @@ -306,7 +303,7 @@ def state(cls): in enumerate(_sorted_datastates[::-1], start=1)} custom_sort_order = case(value=DbCalcState.state, whens=whens, - else_=100) # else: high value to put it at the bottom + else_=100) # else: high value to put it at the bottom # Add numerical state to string, to allow to sort them states_with_num = select([ @@ -329,7 +326,7 @@ def state(cls): DbCalcState.state.label('state_string'), calc_state_num.c.recent_state.label('recent_state'), custom_sort_order.label('num_state'), - ]).select_from(#DbCalcState).alias().join( + ]).select_from( # DbCalcState).alias().join( join(DbCalcState, calc_state_num, DbCalcState.dbnode_id == calc_state_num.c.dbnode_id)).alias() # Get the association between each calc and only its corresponding most-recent-state row @@ -339,10 +336,10 @@ def state(cls): ]).select_from(all_states_q).where(all_states_q.c.num_state == all_states_q.c.recent_state).alias() # Final filtering for the actual query - return select([subq.c.state]).\ + return select([subq.c.state]). 
\ where( - subq.c.dbnode_id == cls.id, - ).\ + subq.c.dbnode_id == cls.id, + ). \ label('laststate') @@ -388,4 +385,3 @@ def __str__(self): self.output.get_simple_name(invalid_result="Unknown node"), self.output.pk ) - diff --git a/aiida/backends/sqlalchemy/models/workflow.py b/aiida/backends/sqlalchemy/models/workflow.py index 7421bdcd5b..981c425f1e 100644 --- a/aiida/backends/sqlalchemy/models/workflow.py +++ b/aiida/backends/sqlalchemy/models/workflow.py @@ -249,7 +249,7 @@ def set_value(self, arg): raise ValueError("Cannot add an unstored node as an " "attribute of a Workflow!") sess = get_scoped_session() - self.aiida_obj = sess.merge(arg.dbnode, load=True) + self.aiida_obj = sess.merge(arg._dbnode, load=True) self.value_type = wf_data_value_types.AIIDA self.save() else: @@ -325,7 +325,7 @@ def add_calculation(self, step_calculation): raise ValueError("Cannot add a non-Calculation object to a workflow step") try: - self.calculations.append(step_calculation.dbnode) + self.calculations.append(step_calculation._dbnode) except: raise ValueError("Error adding calculation to step") diff --git a/aiida/backends/sqlalchemy/tests/session.py b/aiida/backends/sqlalchemy/tests/session.py index f889bb3f66..08b1b8ef5a 100644 --- a/aiida/backends/sqlalchemy/tests/session.py +++ b/aiida/backends/sqlalchemy/tests/session.py @@ -12,8 +12,8 @@ """ import os - from sqlalchemy.orm import sessionmaker +import unittest import aiida.backends from aiida.backends.testbase import AiidaTestCase @@ -80,7 +80,7 @@ def test_session_update_and_expiration_1(self): code = Code() code.set_remote_computer_exec((computer, '/x.x')) - session.add(code.dbnode) + session.add(code._dbnode) session.commit() self.drop_connection() @@ -145,7 +145,7 @@ def test_session_update_and_expiration_3(self): code = Code() code.set_remote_computer_exec((computer, '/x.x')) - session.add(code.dbnode) + session.add(code._dbnode) session.commit() self.drop_connection() @@ -201,8 +201,8 @@ def test_session_wfdata(self): 
n.store() # Keep some useful information - n_id = n.dbnode.id - old_dbnode = n.dbnode + n_id = n.id + old_dbnode = n._dbnode # Get the session sess = get_scoped_session() @@ -219,10 +219,75 @@ # Remove everything from the session sess.expunge_all() - # Add the dbnode that was firstly added to the session + # Add the dbnode that was originally added to the session sess.add(old_dbnode) # Add as attribute the node that was added after the first cleanup # of the session # At this point the following command should not fail - wf.add_attribute('a', n_reloaded) \ No newline at end of file + wf.add_attribute('a', n_reloaded) + + def test_node_access_with_sessions(self): + from aiida.utils import timezone + from aiida.orm.node import Node + import aiida.backends.sqlalchemy as sa + from sqlalchemy.orm import sessionmaker + from aiida.orm.implementation.sqlalchemy.node import DbNode + + Session = sessionmaker(bind=sa.engine) + custom_session = Session() + + node = Node().store() + master_session = node._dbnode.session + self.assertIsNot(master_session, custom_session) + + # Manually load the DbNode in a different session + dbnode_reloaded = custom_session.query(DbNode).get(node.id) + + # Now, go through one by one changing the possible attributes (of the model) + # and check that they're updated when the user reads them from the aiida node + + def check_attrs_match(name): + node_attr = getattr(node, name) + dbnode_attr = getattr(dbnode_reloaded, name) + self.assertEqual( + node_attr, dbnode_attr, + "Values of '{}' don't match ({} != {})".format(name, node_attr, dbnode_attr)) + + def do_value_checks(attr_name, original, changed): + try: + setattr(node, attr_name, original) + except AttributeError: + # This may mean that it is immutable, but we should still be able to + # change it below directly through the dbnode + pass + # Refresh the custom session and make sure they match + custom_session.refresh(dbnode_reloaded, attribute_names=[attr_name]) + 
check_attrs_match(attr_name) + + # Change the value in the custom session via the DbNode + setattr(dbnode_reloaded, attr_name, changed) + custom_session.commit() + + # Check that the Node 'sees' the change + check_attrs_match(attr_name) + + for str_attr in ['label', 'description']: + do_value_checks(str_attr, 'original', 'changed') + + do_value_checks('nodeversion', 1, 2) + do_value_checks('public', True, False) + + # Attributes + self.assertDictEqual(node._attributes(), dbnode_reloaded.attributes) + dbnode_reloaded.attributes['test_attrs'] = 'Boo!' + custom_session.commit() + self.assertDictEqual(node._attributes(), dbnode_reloaded.attributes) + + # Extras + self.assertDictEqual(node.get_extras(), dbnode_reloaded.extras) + dbnode_reloaded.extras['test_extras'] = 'Boo!' + custom_session.commit() + self.assertDictEqual(node.get_extras(), dbnode_reloaded.extras) + + diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/backends/sqlalchemy/utils.py index 4b0a935526..b8355217d6 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/backends/sqlalchemy/utils.py @@ -9,16 +9,17 @@ ########################################################################### - try: import ultrajson as json from functools import partial + # double_precision = 15, to replicate what PostgreSQL numerical type is # using json_dumps = partial(json.dumps, double_precision=15) json_loads = partial(json.loads, precise_float=True) except ImportError: import json + json_dumps = json.dumps json_loads = json.loads @@ -41,6 +42,7 @@ ALEMBIC_FILENAME = "alembic.ini" ALEMBIC_REL_PATH = "migrations" + # def is_dbenv_loaded(): # """ # Return if the environment has already been loaded or not. 
@@ -57,6 +59,7 @@ def recreate_after_fork(engine): sa.engine.dispose() sa.scopedsessionclass = scoped_session(sessionmaker(bind=sa.engine, expire_on_commit=True)) + def reset_session(config): """ :param config: the configuration of the profile from the @@ -100,6 +103,7 @@ def _load_dbenv_noschemacheck(process=None, profile=None, connection=None): config = get_profile_config(settings.AIIDADB_PROFILE) reset_session(config) + _aiida_autouser_cache = None @@ -166,6 +170,7 @@ def f(v): return json_dumps(f(d)) + date_reg = re.compile(r'^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+(\+\d{2}:\d{2})?$') @@ -197,7 +202,6 @@ def f(d): return f(ret) - # XXX the code here isn't different from the one use in Django. We may be able # to refactor it in some way def install_tc(session): @@ -488,9 +492,9 @@ def get_db_version(rev, _): return [] with EnvironmentContext( - config, - script, - fn=get_db_version + config, + script, + fn=get_db_version ): script.run_env() return config.attributes['rev'] diff --git a/aiida/backends/tests/dataclasses.py b/aiida/backends/tests/dataclasses.py index 0cfafc0a60..bf9dfd321a 100644 --- a/aiida/backends/tests/dataclasses.py +++ b/aiida/backends/tests/dataclasses.py @@ -1843,7 +1843,7 @@ def test_reload(self): a.store() - b = StructureData(dbnode=a.dbnode) + b = load_node(uuid=a.uuid) for i in range(3): for j in range(3): @@ -2704,7 +2704,7 @@ def test_creation(self): self.assertEquals(second.shape, n.get_shape('second')) # Same checks, after reloading - n2 = ArrayData(dbnode=n.dbnode) + n2 = load_node(uuid=n.uuid) self.assertEquals(set(['first', 'second']), set(n2.arraynames())) self.assertAlmostEquals(abs(first - n2.get_array('first')).max(), 0.) self.assertAlmostEquals(abs(second - n2.get_array('second')).max(), 0.) 
diff --git a/aiida/backends/tests/nodes.py b/aiida/backends/tests/nodes.py index c1364873f9..1939aba5c6 100644 --- a/aiida/backends/tests/nodes.py +++ b/aiida/backends/tests/nodes.py @@ -23,10 +23,12 @@ from aiida.orm.node import Node from aiida.orm.utils import load_node + class TestNodeHashing(AiidaTestCase): """ Tests the functionality of hashing a node """ + @staticmethod def create_simple_node(a, b=0, c=0): n = Node() @@ -67,8 +69,8 @@ def test_folder_file_different(self): f2 = self.create_folderdata_with_empty_folder() assert ( - f1.folder.get_subfolder('path').get_content_list() == - f2.folder.get_subfolder('path').get_content_list() + f1.folder.get_subfolder('path').get_content_list() == + f2.folder.get_subfolder('path').get_content_list() ) assert f1.get_hash() != f2.get_hash() @@ -106,6 +108,7 @@ def test_unequal_arrays(self): (np.zeros(1001), np.zeros(1005)), (np.array([1, 2, 3]), np.array([2, 3, 4])) ] + def create_arraydata(arr): a = ArrayData() a.set_array('a', arr) @@ -130,6 +133,7 @@ def test_updatable_attributes(self): self.assertNotEquals(hash1, None) self.assertEquals(hash1, hash2) + class TestDataNode(AiidaTestCase): """ These tests check the features of Data nodes that differ from the base Node @@ -487,7 +491,6 @@ def test_get_attrs_after_storing(self): # Now I check if I can retrieve them, before the storage self.assertEquals(a.get_attrs(), target_attrs) - def test_store_object(self): """Trying to store objects should fail""" a = Node() @@ -523,7 +526,7 @@ def test_append_no_side_effects(self): self.assertEquals(a.get_attr('list'), [1, 2, 3, 4]) self.assertEquals(mylist, [1, 2, 3]) - #pylint: disable=no-self-use,unused-argument,unused-variable,no-member + # pylint: disable=no-self-use,unused-argument,unused-variable,no-member def DISABLED(self): """ This test routine is disabled for the time being; I will re-enable @@ -559,7 +562,7 @@ def test_very_deep_attributes(self): all_keys = models.DbAttribute.objects.filter( 
dbnode=n.dbnode).values_list( - 'key', flat=True) + 'key', flat=True) print max(len(i) for i in all_keys) @@ -1312,9 +1315,9 @@ def test_comments(self): self.assertEquals([(i['user__email'], i['content']) for i in comments], [ - (self.user_email, 'text'), - (self.user_email, 'text2'), - ]) + (self.user_email, 'text'), + (self.user_email, 'text2'), + ]) def test_code_loading_from_string(self): """ @@ -1378,22 +1381,22 @@ def test_get_subclass_from_pk(self): # Check that you can load it with a simple integer id. a2 = Node.get_subclass_from_pk(a1.id) self.assertEquals(a1.id, a2.id, "The ids of the stored and loaded node" - "should be equal (since it should be " - "the same node") + "should be equal (since it should be " + "the same node") # Check that you can load it with an id of type long. # a3 = Node.get_subclass_from_pk(long(a1.id)) a3 = Node.get_subclass_from_pk(long(a1.id)) self.assertEquals(a1.id, a3.id, "The ids of the stored and loaded node" - "should be equal (since it should be " - "the same node") + "should be equal (since it should be " + "the same node") # Check that it manages to load the node even if the id is # passed as a string. a4 = Node.get_subclass_from_pk(str(a1.id)) self.assertEquals(a1.id, a4.id, "The ids of the stored and loaded node" - "should be equal (since it should be " - "the same node") + "should be equal (since it should be " + "the same node") # Check that a ValueError exception is raised when a string that can # not be casted to integer is passed. @@ -1575,8 +1578,12 @@ def test_load_node(self): with self.assertRaises(NotExistent): load_node(spec, parent_class=ArrayData) - def test_load_plugin_safe(self): - from aiida.orm import (JobCalculation, CalculationFactory, DataFactory) + def test_load_unknown_calculation_type(self): + """ + Test that the loader will choose a common calculation ancestor for an unknown data type. + For the case where, e.g., the user doesn't have the necessary plugin. 
+ """ + from aiida.orm import (JobCalculation, CalculationFactory) ###### for calculation calc_params = { @@ -1586,37 +1593,47 @@ def test_load_plugin_safe(self): TemplateReplacerCalc = CalculationFactory('simpleplugins.templatereplacer') testcalc = TemplateReplacerCalc(**calc_params).store() - jobcalc = JobCalculation(**calc_params).store() # compare if plugin exist - obj = testcalc.dbnode.get_aiida_class() + obj = load_node(uuid=testcalc.uuid) self.assertEqual(type(testcalc), type(obj)) - # change node type and save in database again - testcalc.dbnode.type = "calculation.job.simpleplugins_tmp.templatereplacer.TemplatereplacerCalculation." - testcalc.dbnode.save() + # Create a custom calculation type that inherits from JobCalculation + class TestCalculation(JobCalculation): + pass + + jobcalc = JobCalculation(**calc_params).store() + testcalc = TestCalculation(**calc_params).store() # changed node should return job calc as its plugin is not exist - obj = testcalc.dbnode.get_aiida_class() + obj = load_node(uuid=testcalc.uuid) self.assertEqual(type(jobcalc), type(obj)) - ####### for data + def test_load_unknown_data_type(self): + """ + Test that the loader will choose a common data ancestor for an unknown data type. + For the case where, e.g., the user doesn't have the necessary plugin. + """ + from aiida.orm import DataFactory + KpointsData = DataFactory('array.kpoints') kpoint = KpointsData().store() Data = DataFactory("Data") data = Data().store() # compare if plugin exist - obj = kpoint.dbnode.get_aiida_class() + obj = load_node(uuid=kpoint.uuid) self.assertEqual(type(kpoint), type(obj)) + class TestKpointsData(KpointsData): + pass + # change node type and save in database again - kpoint.dbnode.type = "data.array.kpoints_tmp.KpointsData." 
- kpoint.dbnode.save() + test_kpoint = TestKpointsData().store() # changed node should return data node as its plugin is not exist - obj = kpoint.dbnode.get_aiida_class() - self.assertEqual(type(data), type(obj)) + obj = load_node(uuid=test_kpoint.uuid) + self.assertEqual(type(kpoint), type(obj)) ###### for node n1 = Node().store() @@ -1624,7 +1641,6 @@ def test_load_plugin_safe(self): self.assertEqual(type(n1), type(obj)) - class TestSubNodesAndLinks(AiidaTestCase): def test_cachelink(self): @@ -1668,7 +1684,7 @@ def test_cachelink(self): self.assertEqual( set([(i[0], i[1].uuid) for i in endnode.get_inputs(only_in_db=True, also_labels=True) - ]), set()) + ]), set()) self.assertEqual( set([(i[0], i[1].uuid) for i in endnode.get_inputs(also_labels=True)]), @@ -1681,7 +1697,7 @@ def test_cachelink(self): self.assertEqual( set([(i[0], i[1].uuid) for i in endnode.get_inputs(only_in_db=True, also_labels=True) - ]), + ]), set([("N1", n1.uuid), ("N2", n2.uuid), ("N3", n3.uuid), ("N4", n4.uuid)])) self.assertEqual( @@ -1716,7 +1732,7 @@ def test_store_with_unstored_parents(self): self.assertEqual( set([(i[0], i[1].uuid) for i in endnode.get_inputs(only_in_db=True, also_labels=True) - ]), set([("N1", n1.uuid), ("N2", n2.uuid)])) + ]), set([("N1", n1.uuid), ("N2", n2.uuid)])) self.assertEqual( set([(i[0], i[1].uuid) for i in endnode.get_inputs(also_labels=True)]), @@ -1765,13 +1781,13 @@ def test_has_children_has_parents(self): n2.add_link_from(n1, "N1", link_type=LinkType.CREATE) self.assertTrue(n1.has_children, "It should be true since n2 is the " - "child of n1.") + "child of n1.") self.assertFalse(n1.has_parents, "It should be false since n1 doesn't " - "have any parents.") + "have any parents.") self.assertFalse(n2.has_children, "It should be false since n2 " - "doesn't have any children.") + "doesn't have any children.") self.assertTrue(n2.has_parents, "It should be true since n1 is the " - "parent of n2.") + "parent of n2.") def test_use_code(self): from aiida.orm 
import JobCalculation @@ -1814,7 +1830,7 @@ def test_use_code(self): self.assertEqual(calc.get_code().uuid, code.uuid) self.assertEqual(unstoredcalc.get_code().uuid, code.uuid) - #pylint: disable=unused-variable,no-member,no-self-use + # pylint: disable=unused-variable,no-member,no-self-use def test_calculation_load(self): from aiida.orm import JobCalculation @@ -2190,13 +2206,16 @@ def test_node_get_inputs_link_type_unstored(self): self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.CALL)), 1) self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.INPUT)), 1) + class AnyValue(object): """ Helper class that compares equal to everything. """ + def __eq__(self, other): return True + class TestNodeDeletion(AiidaTestCase): def _check_existence(self, uuids_check_existence, uuids_check_deleted): @@ -2226,7 +2245,6 @@ def _create_calls_n_returns_graph(self): """ from aiida.common.links import LinkType - in1, in2, wf, slave1, outp1, outp2, slave2, outp3, outp4 = [Node().store() for i in range(9)] wf.add_link_from(in1, link_type=LinkType.INPUT) slave1.add_link_from(in1, link_type=LinkType.INPUT) @@ -2255,7 +2273,7 @@ def test_deletion_simple(self): # Now I am linking the nodes in a branched network # Connecting nodes 1,2,3 to 0 - for i in range(1,4): + for i in range(1, 4): nodes[i].add_link_from(nodes[0], link_type=LinkType.INPUT) # Connecting nodes 4,5,6 to 3 for i in range(4, 7): @@ -2267,11 +2285,10 @@ def test_deletion_simple(self): for i in range(9, 10): nodes[i].add_link_from(nodes[5], link_type=LinkType.INPUT) for i in range(10, 14): - nodes[i+1].add_link_from(nodes[i], link_type=LinkType.INPUT) + nodes[i + 1].add_link_from(nodes[i], link_type=LinkType.INPUT) delete_nodes((nodes[3].pk, nodes[10].pk), force=True, verbosity=0) self._check_existence(uuids_check_existence, uuids_check_deleted) - def test_deletion_with_calls_with_returns(self): """ Checking the case where I follow calls and return links for deletion @@ -2343,7 +2360,7 @@ def 
test_deletion_with_returns_n_loops(self): wf.add_link_from(in2, link_type=LinkType.INPUT) in2.add_link_from(wf, link_type=LinkType.RETURN) - uuids_check_existence = (in1.uuid, ) + uuids_check_existence = (in1.uuid,) uuids_check_deleted = [n.uuid for n in (wf, in2)] delete_nodes([wf.pk], verbosity=0, force=True, follow_returns=True) diff --git a/aiida/cmdline/commands/group.py b/aiida/cmdline/commands/group.py index 88b4e24eb4..fa6c7ae50c 100644 --- a/aiida/cmdline/commands/group.py +++ b/aiida/cmdline/commands/group.py @@ -220,11 +220,8 @@ def group_show(self, *args): if parsed_args.uuid: row.append(n.uuid) row.append(n.pk) - row.append(from_type_to_pluginclassname(n.dbnode.type). - rsplit(".", 1)[1]) - - row.append(str_timedelta(now - n.ctime, short=True, - negative_to_zero=True)) + row.append(from_type_to_pluginclassname(n.type).rsplit(".", 1)[1]) + row.append(str_timedelta(now - n.ctime, short=True, negative_to_zero=True)) table.append(row) print(tabulate(table, headers=header)) diff --git a/aiida/orm/data/array/kpoints.py b/aiida/orm/data/array/kpoints.py index 402df3498b..b9a4fc3f15 100644 --- a/aiida/orm/data/array/kpoints.py +++ b/aiida/orm/data/array/kpoints.py @@ -52,7 +52,7 @@ def get_desc(self): try: return '(Path of {} kpts)'.format(len(self.get_kpoints())) except OSError: - return self.dbnode.type + return self.type @property def cell(self): diff --git a/aiida/orm/implementation/django/node.py b/aiida/orm/implementation/django/node.py index 1113a1afe7..577bea0209 100644 --- a/aiida/orm/implementation/django/node.py +++ b/aiida/orm/implementation/django/node.py @@ -22,6 +22,7 @@ from aiida.common.links import LinkType from aiida.common.utils import get_new_uuid from aiida.orm.implementation.general.node import AbstractNode, _NO_DEFAULT, _HASH_EXTRA_KEY +from aiida.orm.implementation.django.computer import Computer from aiida.orm.mixins import Sealable # from aiida.orm.implementation.django.utils import get_db_columns from 
aiida.orm.implementation.general.utils import get_db_columns @@ -42,7 +43,6 @@ def get_subclass_from_uuid(cls, uuid): @staticmethod def get_db_columns(): - # from aiida.backends.djsite.db.models import DbNode from aiida.backends.djsite.querybuilder_django.dummy_model import DbNode return get_db_columns(DbNode) @@ -58,6 +58,29 @@ def get_subclass_from_pk(cls, pk): pk, cls.__name__)) return node + @classmethod + def query(cls, *args, **kwargs): + from aiida.backends.djsite.db.models import DbNode + if cls._plugin_type_string: + if not cls._plugin_type_string.endswith('.'): + raise InternalError("The plugin type string does not " + "finish with a dot??") + + # If it is 'calculation.Calculation.', we want to filter + # for things that start with 'calculation.' and so on + plug_type = cls._plugin_type_string + + # Remove the implementation.django or sqla part. + if plug_type.startswith('implementation.'): + plug_type = '.'.join(plug_type.split('.')[2:]) + pre, sep, _ = plug_type[:-1].rpartition('.') + superclass_string = "".join([pre, sep]) + return DbNode.aiidaobjects.filter( + *args, type__startswith=superclass_string, **kwargs) + else: + # Base Node class, with empty string + return DbNode.aiidaobjects.filter(*args, **kwargs) + def __init__(self, **kwargs): from aiida.backends.djsite.db.models import DbNode super(Node, self).__init__() @@ -87,7 +110,7 @@ def __init__(self, **kwargs): # If this is changed, fix also the importer self._repo_folder = RepositoryFolder(section=self._section_name, - uuid=self._dbnode.uuid) + uuid=self.uuid) # NO VALIDATION ON __init__ BY DEFAULT, IT IS TOO SLOW SINCE IT OFTEN # REQUIRES MULTIPLE DB HITS @@ -122,38 +145,33 @@ def __init__(self, **kwargs): # stop self._set_with_defaults(**kwargs) - @classmethod - def query(cls, *args, **kwargs): - from aiida.backends.djsite.db.models import DbNode - if cls._plugin_type_string: - if not cls._plugin_type_string.endswith('.'): - raise InternalError("The plugin type string does not " - "finish 
with a dot??") + @property + def type(self): + return self._dbnode.type - # If it is 'calculation.Calculation.', we want to filter - # for things that start with 'calculation.' and so on - plug_type = cls._plugin_type_string + @property + def ctime(self): + return self._dbnode.ctime - # Remove the implementation.django or sqla part. - if plug_type.startswith('implementation.'): - plug_type = '.'.join(plug_type.split('.')[2:]) - pre, sep, _ = plug_type[:-1].rpartition('.') - superclass_string = "".join([pre, sep]) - return DbNode.aiidaobjects.filter( - *args, type__startswith=superclass_string, **kwargs) - else: - # Base Node class, with empty string - return DbNode.aiidaobjects.filter(*args, **kwargs) + @property + def mtime(self): + return self._dbnode.mtime + + def _get_db_label_field(self): + return self._dbnode.label def _update_db_label_field(self, field_value): - self.dbnode.label = field_value + self._dbnode.label = field_value if not self._to_be_stored: with transaction.atomic(): self._dbnode.save() self._increment_version_number_db() + def _get_db_description_field(self): + return self._dbnode.description + def _update_db_description_field(self, field_value): - self.dbnode.description = field_value + self._dbnode.description = field_value if not self._to_be_stored: with transaction.atomic(): self._dbnode.save() @@ -169,7 +187,7 @@ def _replace_dblink_from(self, src, label, link_type): self._add_dblink_from(src, label, link_type) def _remove_dblink_from(self, label): - DbLink.objects.filter(output=self.dbnode, label=label).delete() + DbLink.objects.filter(output=self._dbnode, label=label).delete() def _add_dblink_from(self, src, label=None, link_type=LinkType.UNSPECIFIED): from aiida.orm.querybuilder import QueryBuilder @@ -195,8 +213,8 @@ def _add_dblink_from(self, src, label=None, link_type=LinkType.UNSPECIFIED): # I am linking src->self; a loop would be created if a DbPath exists already # in the TC table from self to src if QueryBuilder().append( - 
Node, filters={'id':self.pk}, tag='parent').append( - Node, filters={'id':src.pk}, tag='child', descendant_of='parent').count() > 0: + Node, filters={'id': self.pk}, tag='parent').append( + Node, filters={'id': src.pk}, tag='child', descendant_of='parent').count() > 0: raise ValueError( "The link you are attempting to create would generate a loop") @@ -204,7 +222,7 @@ def _add_dblink_from(self, src, label=None, link_type=LinkType.UNSPECIFIED): autolabel_idx = 1 existing_from_autolabels = list(DbLink.objects.filter( - output=self.dbnode, + output=self._dbnode, label__startswith="link_").values_list('label', flat=True)) while "link_{}".format(autolabel_idx) in existing_from_autolabels: autolabel_idx += 1 @@ -233,7 +251,7 @@ def _do_create_link(self, src, label, link_type): # transactions are needed here for Postgresql: # https://docs.djangoproject.com/en/1.5/topics/db/transactions/#handling-exceptions-within-postgresql-transactions sid = transaction.savepoint() - DbLink.objects.create(input=src.dbnode, output=self.dbnode, + DbLink.objects.create(input=src._dbnode, output=self._dbnode, label=label, type=link_type.value) transaction.savepoint_commit(sid) except IntegrityError as e: @@ -245,25 +263,35 @@ def _do_create_link(self, src, label, link_type): def _get_db_input_links(self, link_type): from aiida.backends.djsite.db.models import DbLink - link_filter = {'output': self.dbnode} + link_filter = {'output': self._dbnode} if link_type is not None: link_filter['type'] = link_type.value return [(i.label, i.input.get_aiida_class()) for i in DbLink.objects.filter(**link_filter).distinct()] - def _get_db_output_links(self, link_type): from aiida.backends.djsite.db.models import DbLink - link_filter = {'input': self.dbnode} + link_filter = {'input': self._dbnode} if link_type is not None: link_filter['type'] = link_type.value return ((i.label, i.output.get_aiida_class()) for i in DbLink.objects.filter(**link_filter).distinct()) + def get_computer(self): + """ + Get the 
computer associated to the node. + + :return: the Computer object or None. + """ + if self._dbnode.dbcomputer is None: + return None + else: + return Computer(dbcomputer=self._dbnode.dbcomputer) + def _set_db_computer(self, computer): from aiida.backends.djsite.db.models import DbComputer - self.dbnode.dbcomputer = DbComputer.get_dbcomputer(computer) + self._dbnode.dbcomputer = DbComputer.get_dbcomputer(computer) def _set_db_attr(self, key, value): """ @@ -277,26 +305,26 @@ def _set_db_attr(self, key, value): """ from aiida.backends.djsite.db.models import DbAttribute - DbAttribute.set_value_for_node(self.dbnode, key, value) + DbAttribute.set_value_for_node(self._dbnode, key, value) self._increment_version_number_db() def _del_db_attr(self, key): from aiida.backends.djsite.db.models import DbAttribute - if not DbAttribute.has_key(self.dbnode, key): + if not DbAttribute.has_key(self._dbnode, key): raise AttributeError("DbAttribute {} does not exist".format( key)) - DbAttribute.del_value_for_node(self.dbnode, key) + DbAttribute.del_value_for_node(self._dbnode, key) self._increment_version_number_db() def _get_db_attr(self, key): from aiida.backends.djsite.db.models import DbAttribute return DbAttribute.get_value_for_node( - dbnode=self.dbnode, key=key) + dbnode=self._dbnode, key=key) def _set_db_extra(self, key, value, exclusive=False): from aiida.backends.djsite.db.models import DbExtra - DbExtra.set_value_for_node(self.dbnode, key, value, + DbExtra.set_value_for_node(self._dbnode, key, value, stop_if_existing=exclusive) self._increment_version_number_db() @@ -306,27 +334,27 @@ def _reset_db_extras(self, new_extras): def _get_db_extra(self, key, *args): from aiida.backends.djsite.db.models import DbExtra - return DbExtra.get_value_for_node(dbnode=self.dbnode, + return DbExtra.get_value_for_node(dbnode=self._dbnode, key=key) def _del_db_extra(self, key): from aiida.backends.djsite.db.models import DbExtra - if not DbExtra.has_key(self.dbnode, key): + if not 
DbExtra.has_key(self._dbnode, key): raise AttributeError("DbExtra {} does not exist".format( key)) - return DbExtra.del_value_for_node(self.dbnode, key) + return DbExtra.del_value_for_node(self._dbnode, key) self._increment_version_number_db() def _db_iterextras(self): from aiida.backends.djsite.db.models import DbExtra - extraslist = DbExtra.list_all_node_elements(self.dbnode) + extraslist = DbExtra.list_all_node_elements(self._dbnode) for e in extraslist: yield (e.key, e.getvalue()) def _db_iterattrs(self): from aiida.backends.djsite.db.models import DbAttribute - all_attrs = DbAttribute.get_all_values_for_node(self.dbnode) + all_attrs = DbAttribute.get_all_values_for_node(self._dbnode) for attr in all_attrs: yield (attr, all_attrs[attr]) @@ -336,7 +364,7 @@ def _db_attrs(self): # calling iterattrs from here, because iterattrs is slow on each call # since it has to call .getvalue(). To improve! from aiida.backends.djsite.db.models import DbAttribute - attrlist = DbAttribute.list_all_node_elements(self.dbnode) + attrlist = DbAttribute.list_all_node_elements(self._dbnode) for attr in attrlist: yield attr.key @@ -437,11 +465,11 @@ def _increment_version_number_db(self): def copy(self, **kwargs): newobject = self.__class__() - newobject.dbnode.type = self.dbnode.type # Inherit type - newobject.dbnode.label = self.dbnode.label # Inherit label + newobject._dbnode.type = self._dbnode.type # Inherit type + newobject.label = self.label # Inherit label # TODO: add to the description the fact that this was a copy? 
- newobject.dbnode.description = self.dbnode.description # Inherit description - newobject.dbnode.dbcomputer = self.dbnode.dbcomputer # Inherit computer + newobject.description = self.description # Inherit description + newobject._dbnode.dbcomputer = self._dbnode.dbcomputer # Inherit computer for k, v in self.iterattrs(): if k != Sealable.SEALED_KEY: @@ -454,11 +482,11 @@ def copy(self, **kwargs): @property def uuid(self): - return unicode(self.dbnode.uuid) + return unicode(self._dbnode.uuid) @property def id(self): - return self.dbnode.id + return self._dbnode.id @property def dbnode(self): @@ -494,6 +522,8 @@ def _db_store_all(self, with_transaction=True, use_cache=None): return self + def get_user(self): + return self._dbnode.user def _store_cached_input_links(self, with_transaction=True): """ @@ -594,7 +624,7 @@ def _db_store(self, with_transaction=True): self._dbnode.save() # Save its attributes 'manually' without incrementing # the version for each add. - DbAttribute.reset_values_for_node(self.dbnode, + DbAttribute.reset_values_for_node(self._dbnode, attributes=self._attrs_cache, with_transaction=False) # This should not be used anymore: I delete it to @@ -619,6 +649,6 @@ def _db_store(self, with_transaction=True): from aiida.backends.djsite.db.models import DbExtra # I store the hash without cleaning and without incrementing the nodeversion number - DbExtra.set_value_for_node(self.dbnode, _HASH_EXTRA_KEY, self.get_hash()) + DbExtra.set_value_for_node(self._dbnode, _HASH_EXTRA_KEY, self.get_hash()) return self diff --git a/aiida/orm/implementation/general/node.py b/aiida/orm/implementation/general/node.py index 20466e9d34..14c0b32600 100644 --- a/aiida/orm/implementation/general/node.py +++ b/aiida/orm/implementation/general/node.py @@ -14,6 +14,7 @@ import logging import importlib import collections + try: import pathlib except ImportError: @@ -33,6 +34,7 @@ _NO_DEFAULT = tuple() _HASH_EXTRA_KEY = '_aiida_hash' + def clean_value(value): """ Get value from 
input and (recursively) replace, if needed, all occurrences @@ -151,8 +153,7 @@ def __new__(cls, name, bases, attrs): def get_desc(self): """ - Returns a string with infos retrieved from a node's - properties. + Returns a string with infos retrieved from a node's properties. This method is actually overwritten by the inheriting classes :return: a description string @@ -204,20 +205,6 @@ def get_subclass_from_pk(cls, pk): """ pass - @property - def ctime(self): - """ - Return the creation time of the node. - """ - return self.dbnode.ctime - - @property - def mtime(self): - """ - Return the modification time of the node. - """ - return self.dbnode.mtime - def __int__(self): if self._to_be_stored: return None @@ -248,6 +235,20 @@ def is_stored(self): """ return not self._to_be_stored + @abstractproperty + def ctime(self): + """ + Return the creation time of the node. + """ + pass + + @abstractproperty + def mtime(self): + """ + Return the modification time of the node. + """ + pass + def __repr__(self): return '<{}: {}>'.format(self.__class__.__name__, str(self)) @@ -259,8 +260,7 @@ def __str__(self): def _init_internal_params(self): """ - Set here the default values for this class; this method - is automatically called by the init. + Set the default values for this class; this method is automatically called by the init. :note: if you inherit this function, ALWAYS remember to call super()._init_internal_params() as the first thing @@ -348,7 +348,7 @@ def _set_internal(self, arguments, allow_hidden=False): raise ValueError("Cannot set {} directly when creating " "the node or using the .set() method; " "use the specific method instead.".format( - incomp[0])) + incomp[0])) else: raise ValueError("Cannot set {} at the same time".format( " and ".join(incomp))) @@ -367,6 +367,15 @@ def _set_internal(self, arguments, allow_hidden=False): "callable!".format(k)) method(v) + @abstractproperty + def type(self): + """ + Get the type of the node. + + :return: a string. 
+ """ + pass + @property def label(self): """ @@ -374,7 +383,7 @@ def label(self): :return: a string. """ - return self.dbnode.label + return self._get_db_label_field() @label.setter def label(self, label): @@ -385,10 +394,19 @@ def label(self, label): """ self._update_db_label_field(label) + @abstractmethod + def _get_db_label_field(self): + """ + Get the label field acting directly on the DB + + :return: a string. + """ + pass + @abstractmethod def _update_db_label_field(self, field_value): """ - Update the label field acting directly on the DB + Set the label field acting directly on the DB """ pass @@ -398,8 +416,9 @@ def description(self): Get the description of the node. :return: a string + :rtype: str """ - return self.dbnode.description + return self._get_db_description_field() @description.setter def description(self, desc): @@ -410,6 +429,13 @@ def description(self, desc): """ self._update_db_description_field(desc) + @abstractmethod + def _get_db_description_field(self): + """ + Get the description of this node, acting directly at the DB level + """ + pass + @abstractmethod def _update_db_description_field(self, field_value): """ @@ -432,13 +458,14 @@ def _validate(self): """ return True + @abstractmethod def get_user(self): """ Get the user. 
- :return: a Django DbUser model object + :return: a DbUser model object """ - return self.dbnode.user + pass def _has_cached_links(self): """ @@ -688,7 +715,7 @@ def get_inputs(self, node_type=None, also_labels=False, only_in_db=False, link_t input_link_type = v[1] if label in input_list_keys: raise InternalError("There exist a link with the same name '{}' both in the DB " - "and in the internal cache for node pk= {}!".format(label, self.pk)) + "and in the internal cache for node pk= {}!".format(label, self.pk)) if link_type is None or input_link_type is link_type: inputs_list.append((label, src)) @@ -753,17 +780,14 @@ def _get_db_output_links(self, link_type): """ pass + @abstractmethod def get_computer(self): """ Get the computer associated to the node. :return: the Computer object or None. """ - from aiida.orm.computer import Computer - if self.dbnode.dbcomputer is None: - return None - else: - return Computer(dbcomputer=self.dbnode.dbcomputer) + pass def set_computer(self, computer): """ @@ -1135,7 +1159,7 @@ def iterextras(self): # added (in particular, we do not even have an ID to use!) 
# Return without value, meaning that this is an empty generator return - yield # Needed after return to convert it to a generator + yield # Needed after return to convert it to a generator for extra in self._db_iterextras(): yield extra @@ -1432,7 +1456,7 @@ def get_abs_path(self, path=None, section=None): raise ValueError("The path in get_abs_path must be relative") return self.folder.get_subfolder( section, reset_limit=True).get_abs_path( - path, check_existence=True) + path, check_existence=True) def store_all(self, with_transaction=True, use_cache=None): """ @@ -1636,7 +1660,6 @@ def _db_store(self, with_transaction=True): """ pass - def __del__(self): """ Called only upon real object destruction from memory @@ -1646,7 +1669,6 @@ def __del__(self): if getattr(self, '_temp_folder', None) is not None: self._temp_folder.erase() - def get_hash(self, ignore_errors=True, **kwargs): """ Making a hash based on my attributes @@ -1674,7 +1696,7 @@ def _get_objects_to_hash(self): if ( (key not in self._hash_ignored_attributes) and (key not in getattr(self, '_updatable_attributes', tuple())) - ) + ) }, self.folder, computer.uuid if computer is not None else None @@ -1756,7 +1778,6 @@ def inp(self): """ return NodeInputManager(self) - @property def has_children(self): """ @@ -1766,11 +1787,10 @@ def has_children(self): from aiida.orm.querybuilder import QueryBuilder from aiida.orm import Node first_desc = QueryBuilder().append( - Node, filters={'id':self.pk}, tag='self').append( + Node, filters={'id': self.pk}, tag='self').append( Node, descendant_of='self', project='id').first() return bool(first_desc) - @property def has_parents(self): """ @@ -1780,7 +1800,7 @@ def has_parents(self): from aiida.orm.querybuilder import QueryBuilder from aiida.orm import Node first_ancestor = QueryBuilder().append( - Node, filters={'id':self.pk}, tag='self').append( + Node, filters={'id': self.pk}, tag='self').append( Node, ancestor_of='self', project='id').first() return 
bool(first_ancestor) diff --git a/aiida/orm/implementation/sqlalchemy/calculation/job/__init__.py b/aiida/orm/implementation/sqlalchemy/calculation/job/__init__.py index cc96b69210..ab44d46282 100644 --- a/aiida/orm/implementation/sqlalchemy/calculation/job/__init__.py +++ b/aiida/orm/implementation/sqlalchemy/calculation/job/__init__.py @@ -71,9 +71,9 @@ def _set_state(self, state): "to {}".format(old_state, state)) try: - new_state = DbCalcState(dbnode=self.dbnode, state=state).save() + new_state = DbCalcState(dbnode=self._dbnode, state=state).save() except SQLAlchemyError: - self.dbnode.session.rollback() + self._dbnode.session.rollback() raise ModificationNotAllowed("Calculation pk= {} already transited through " "the state {}".format(self.pk, state)) @@ -102,7 +102,7 @@ def get_state(self, from_attribute=False): state_to_return = calc_states.NEW else: # In the sqlalchemy model, the state - most_recent_state = self.dbnode.state + most_recent_state = self._dbnode.state if most_recent_state: state_to_return = most_recent_state.value else: diff --git a/aiida/orm/implementation/sqlalchemy/code.py b/aiida/orm/implementation/sqlalchemy/code.py index 45dcdf4ef6..ac7f446df2 100644 --- a/aiida/orm/implementation/sqlalchemy/code.py +++ b/aiida/orm/implementation/sqlalchemy/code.py @@ -10,7 +10,6 @@ import os - from aiida.backends.sqlalchemy.models.computer import DbComputer from aiida.common.exceptions import NotExistent, MultipleObjectsError, InvalidOperation @@ -19,7 +18,6 @@ from aiida.orm.implementation.sqlalchemy.computer import Computer - class Code(AbstractCode): @classmethod @@ -48,7 +46,7 @@ def set_remote_computer_exec(self, remote_computer_exec): """ if (not isinstance(remote_computer_exec, (list, tuple)) - or len(remote_computer_exec) != 2): + or len(remote_computer_exec) != 2): raise ValueError("remote_computer_exec must be a list or tuple " "of length 2, with machine and executable " "name") @@ -58,15 +56,11 @@ def set_remote_computer_exec(self, 
remote_computer_exec): if not os.path.isabs(remote_exec_path): raise ValueError("exec_path must be an absolute path (on the remote machine)") - remote_dbcomputer = computer - if isinstance(remote_dbcomputer, Computer): - remote_dbcomputer = remote_dbcomputer.dbcomputer - if not (isinstance(remote_dbcomputer, DbComputer)): - raise TypeError("computer must be either a Computer or DbComputer object") + if not isinstance(computer, Computer): + raise TypeError("Computer must be of type Computer, got '{}'".format(type(computer))) self._set_remote() - - self.dbnode.dbcomputer = remote_dbcomputer + self.set_computer(computer) self._set_attr('remote_exec_path', remote_exec_path) def _set_local(self): @@ -78,7 +72,7 @@ def _set_local(self): It also deletes the flags related to the local case (if any) """ self._set_attr('is_local', True) - self.dbnode.dbcomputer = None + self._dbnode.dbcomputer = None try: self._del_attr('remote_exec_path') except AttributeError: diff --git a/aiida/orm/implementation/sqlalchemy/node.py b/aiida/orm/implementation/sqlalchemy/node.py index 1c14b6712a..e48b2b2cd5 100644 --- a/aiida/orm/implementation/sqlalchemy/node.py +++ b/aiida/orm/implementation/sqlalchemy/node.py @@ -41,7 +41,6 @@ import aiida.orm.autogroup - class Node(AbstractNode): def __init__(self, **kwargs): super(Node, self).__init__() @@ -71,7 +70,7 @@ def __init__(self, **kwargs): # If this is changed, fix also the importer self._repo_folder = RepositoryFolder(section=self._section_name, - uuid=self._dbnode.uuid) + uuid=self.uuid) else: # TODO: allow to get the user from the parameters @@ -153,20 +152,81 @@ def query(cls, *args, **kwargs): raise NotImplementedError("The node query method is not supported in " "SQLAlchemy. Please use QueryBuilder.") + @property + def type(self): + # Type is immutable so no need to ensure the model is up to date + return self._dbnode.type + + @property + def ctime(self): + """ + Return the creation time of the node. 
+ """ + self._ensure_model_uptodate(attribute_names=['ctime']) + return self._dbnode.ctime + + @property + def mtime(self): + """ + Return the modification time of the node. + """ + self._ensure_model_uptodate(attribute_names=['mtime']) + return self._dbnode.mtime + + def get_user(self): + """ + Get the user. + + :return: a Django DbUser model object + """ + self._ensure_model_uptodate(attribute_names=['user']) + return self._dbnode.user + + def get_computer(self): + """ + Get the computer associated to the node. + + :return: the Computer object or None. + """ + self._ensure_model_uptodate(attribute_names=['dbcomputer']) + if self._dbnode.dbcomputer is None: + return None + else: + return Computer(dbcomputer=self._dbnode.dbcomputer) + + def _get_db_label_field(self): + """ + Get the label of the node. + + :return: a string. + """ + self._ensure_model_uptodate(attribute_names=['label']) + return self._dbnode.label + def _update_db_label_field(self, field_value): from aiida.backends.sqlalchemy import get_scoped_session session = get_scoped_session() - self.dbnode.label = field_value + self._dbnode.label = field_value if not self._to_be_stored: session.add(self._dbnode) self._increment_version_number_db() + def _get_db_description_field(self): + """ + Get the description of the node. 
+ + :return: a string + :rtype: str + """ + self._ensure_model_uptodate(attribute_names=['description']) + return self._dbnode.description + def _update_db_description_field(self, field_value): from aiida.backends.sqlalchemy import get_scoped_session session = get_scoped_session() - self.dbnode.description = field_value + self._dbnode.description = field_value if not self._to_be_stored: session.add(self._dbnode) self._increment_version_number_db() @@ -219,8 +279,8 @@ def _add_dblink_from(self, src, label=None, link_type=LinkType.UNSPECIFIED): # already in the TC table from self to src if link_type is LinkType.CREATE or link_type is LinkType.INPUT: if QueryBuilder().append( - Node, filters={'id':self.pk}, tag='parent').append( - Node, filters={'id':src.pk}, tag='child', descendant_of='parent').count() > 0: + Node, filters={'id': self.pk}, tag='parent').append( + Node, filters={'id': src.pk}, tag='child', descendant_of='parent').count() > 0: raise ValueError( "The link you are attempting to create would generate a loop") @@ -228,7 +288,7 @@ def _add_dblink_from(self, src, label=None, link_type=LinkType.UNSPECIFIED): autolabel_idx = 1 existing_from_autolabels = session.query(DbLink.label).filter( - DbLink.output_id == self.dbnode.id, + DbLink.output_id == self._dbnode.id, DbLink.label.like("link%") ) @@ -258,7 +318,7 @@ def _do_create_link(self, src, label, link_type): session = get_scoped_session() try: with session.begin_nested(): - link = DbLink(input_id=src.dbnode.id, output_id=self.dbnode.id, + link = DbLink(input_id=src.id, output_id=self.id, label=label, type=link_type.value) session.add(link) except SQLAlchemyError as e: @@ -267,22 +327,21 @@ def _do_create_link(self, src, label, link_type): "".format(e)) def _get_db_input_links(self, link_type): - link_filter = {'output': self.dbnode} + link_filter = {'output': self._dbnode} if link_type is not None: link_filter['type'] = link_type.value return [(i.label, i.input.get_aiida_class()) for i in 
DbLink.query.filter_by(**link_filter).distinct().all()] - def _get_db_output_links(self, link_type): - link_filter = {'input': self.dbnode} + link_filter = {'input': self._dbnode} if link_type is not None: link_filter['type'] = link_type.value return ((i.label, i.output.get_aiida_class()) for i in DbLink.query.filter_by(**link_filter).distinct().all()) def _set_db_computer(self, computer): - self.dbnode.dbcomputer = DbComputer.get_dbcomputer(computer) + self._dbnode.dbcomputer = DbComputer.get_dbcomputer(computer) def _set_db_attr(self, key, value): """ @@ -295,7 +354,7 @@ def _set_db_attr(self, key, value): :param value: its value """ try: - self.dbnode.set_attr(key, value) + self._dbnode.set_attr(key, value) self._increment_version_number_db() except: from aiida.backends.sqlalchemy import get_scoped_session @@ -305,7 +364,7 @@ def _set_db_attr(self, key, value): def _del_db_attr(self, key): try: - self.dbnode.del_attr(key) + self._dbnode.del_attr(key) self._increment_version_number_db() except: from aiida.backends.sqlalchemy import get_scoped_session @@ -315,17 +374,16 @@ def _del_db_attr(self, key): def _get_db_attr(self, key): try: - return get_attr(self.dbnode.attributes, key) + return get_attr(self._attributes(), key) except (KeyError, IndexError): raise AttributeError("Attribute '{}' does not exist".format(key)) def _set_db_extra(self, key, value, exclusive=False): - if exclusive: raise NotImplementedError("exclusive=True not implemented yet in SQLAlchemy backend") try: - self.dbnode.set_extra(key, value) + self._dbnode.set_extra(key, value) self._increment_version_number_db() except: from aiida.backends.sqlalchemy import get_scoped_session @@ -335,7 +393,7 @@ def _set_db_extra(self, key, value, exclusive=False): def _reset_db_extras(self, new_extras): try: - self.dbnode.reset_extras(new_extras) + self._dbnode.reset_extras(new_extras) self._increment_version_number_db() except: from aiida.backends.sqlalchemy import get_scoped_session @@ -345,14 +403,14 @@ 
def _reset_db_extras(self, new_extras): def _get_db_extra(self, key, default=None): try: - return get_attr(self.dbnode.extras, key) + return get_attr(self._extras(), key) except (KeyError, AttributeError): raise AttributeError("DbExtra {} does not exist".format( key)) def _del_db_extra(self, key): try: - self.dbnode.del_extra(key) + self._dbnode.del_extra(key) self._increment_version_number_db() except: from aiida.backends.sqlalchemy import get_scoped_session @@ -360,19 +418,19 @@ def _del_db_extra(self, key): session.rollback() raise - def _db_iterextras(self): - if self.dbnode.extras is None: + extras = self._extras() + if extras is None: return dict().iteritems() - return self.dbnode.extras.iteritems() + return extras.iteritems() def _db_iterattrs(self): - for k, v in self.dbnode.attributes.iteritems(): + for k, v in self._attributes().iteritems(): yield (k, v) def _db_attrs(self): - for k in self.dbnode.attributes.iterkeys(): + for k in self._attributes().iterkeys(): yield k def add_comment(self, content, user=None): @@ -417,7 +475,7 @@ def get_comments(self, pk=None): "ctime": c.ctime, "mtime": c.mtime, "content": c.content - } for c in comments ] + } for c in comments] def _get_dbcomments(self, pk=None, with_user=False): comments = DbComment.query.filter_by(dbnode=self._dbnode) @@ -473,7 +531,7 @@ def _remove_comment(self, comment_pk, user): raise def _increment_version_number_db(self): - self._dbnode.nodeversion = DbNode.nodeversion + 1 + self._dbnode.nodeversion = self.nodeversion + 1 try: self._dbnode.save() except: @@ -482,14 +540,15 @@ def _increment_version_number_db(self): session.rollback() raise - def copy(self, **kwargs): + # Make sure we have the latest version from the database + self._ensure_model_uptodate() newobject = self.__class__() - newobject.dbnode.type = self.dbnode.type # Inherit type - newobject.dbnode.label = self.dbnode.label # Inherit label + newobject._dbnode.type = self._dbnode.type # Inherit type + newobject._dbnode.label = 
self._dbnode.label # Inherit label # TODO: add to the description the fact that this was a copy? - newobject.dbnode.description = self.dbnode.description # Inherit description - newobject.dbnode.dbcomputer = self.dbnode.dbcomputer # Inherit computer + newobject._dbnode.description = self._dbnode.description # Inherit description + newobject._dbnode.dbcomputer = self._dbnode.dbcomputer # Inherit computer for k, v in self.iterattrs(): if k != Sealable.SEALED_KEY: @@ -502,12 +561,23 @@ def copy(self, **kwargs): @property def id(self): - return self.dbnode.id + return self._dbnode.id @property def dbnode(self): + self._ensure_model_uptodate() return self._dbnode + @property + def nodeversion(self): + self._ensure_model_uptodate(attribute_names=['nodeversion']) + return self._dbnode.nodeversion + + @property + def public(self): + self._ensure_model_uptodate(attribute_names=['public']) + return self._dbnode.public + def _db_store_all(self, with_transaction=True, use_cache=None): """ Store the node, together with all input links, if cached, and also the @@ -533,7 +603,6 @@ def _db_store_all(self, with_transaction=True, use_cache=None): return self - def _store_cached_input_links(self, with_transaction=True): """ Store all input links that are in the local cache, transferring them @@ -626,8 +695,8 @@ def _db_store(self, with_transaction=True): session.add(self._dbnode) # Save its attributes 'manually' without incrementing # the version for each add. - self.dbnode.attributes = self._attrs_cache - flag_modified(self.dbnode, "attributes") + self._dbnode.attributes = self._attrs_cache + flag_modified(self._dbnode, "attributes") # This should not be used anymore: I delete it to # possibly free memory del self._attrs_cache @@ -644,7 +713,7 @@ def _db_store(self, with_transaction=True): # aiida.backends.sqlalchemy.get_scoped_session().commit() session.commit() except SQLAlchemyError as e: - #print "Cannot store the node. Original exception: {" \ + # print "Cannot store the node. 
Original exception: {" \ # "}".format(e) session.rollback() raise @@ -658,10 +727,21 @@ def _db_store(self, with_transaction=True): self._repository_folder.abspath, move=True, overwrite=True) raise - self.dbnode.set_extra(_HASH_EXTRA_KEY, self.get_hash()) + self._dbnode.set_extra(_HASH_EXTRA_KEY, self.get_hash()) return self - @property def uuid(self): - return unicode(self.dbnode.uuid) + return unicode(self._dbnode.uuid) + + def _attributes(self): + self._ensure_model_uptodate(['attributes']) + return self._dbnode.attributes + + def _extras(self): + self._ensure_model_uptodate(['extras']) + return self._dbnode.extras + + def _ensure_model_uptodate(self, attribute_names=None): + if self.is_stored: + self._dbnode.session.expire(self._dbnode, attribute_names=attribute_names) diff --git a/aiida/restapi/translator/calculation/__init__.py b/aiida/restapi/translator/calculation/__init__.py index 2a4967347a..964343a6e2 100644 --- a/aiida/restapi/translator/calculation/__init__.py +++ b/aiida/restapi/translator/calculation/__init__.py @@ -63,7 +63,7 @@ def get_retrieved_inputs(node, filename=None, rtype=None): :return: the retrieved input files for job calculation """ - if node.dbnode.type.startswith("calculation.job."): + if node.type.startswith("calculation.job."): input_folder = node._raw_input_folder @@ -85,7 +85,7 @@ def get_retrieved_inputs(node, filename=None, rtype=None): response["filename"] = filename.replace("/", "_") else: - raise RestInputValidationError ("rtype is not supported") + raise RestInputValidationError("rtype is not supported") return response @@ -103,7 +103,7 @@ def get_retrieved_outputs(node, filename=None, rtype=None): :return: the retrieved output files for job calculation """ - if node.dbnode.type.startswith("calculation.job."): + if node.type.startswith("calculation.job."): retrieved_folder = node.out.retrieved response = {} diff --git a/aiida/restapi/translator/node.py b/aiida/restapi/translator/node.py index 90daaab472..e52014e7f3 100644 --- 
a/aiida/restapi/translator/node.py +++ b/aiida/restapi/translator/node.py @@ -419,7 +419,7 @@ def get_retrieved_inputs(self, node, filename=None, rtype=None): :returns: list of calc inputls command """ - if node.dbnode.type.startswith("calculation"): + if node.type.startswith("calculation"): from aiida.restapi.translator.calculation import CalculationTranslator return CalculationTranslator.get_retrieved_inputs(node, filename=filename, rtype=rtype) return [] @@ -435,7 +435,7 @@ def get_retrieved_outputs(self, node, filename=None, rtype=None): :returns: list of calc outputls command """ - if node.dbnode.type.startswith("calculation"): + if node.type.startswith("calculation"): from aiida.restapi.translator.calculation import CalculationTranslator return CalculationTranslator.get_retrieved_outputs(node, filename=filename, rtype=rtype) return [] @@ -514,11 +514,11 @@ def get_node_shape(ntype): mainNode = qb.first()[0] pk = mainNode.pk uuid = mainNode.uuid - nodetype = mainNode.dbnode.type + nodetype = mainNode.type display_type = nodetype.split('.')[-2] description = mainNode.get_desc() if description == '': - description = mainNode.dbnode.type.split('.')[-2] + description = mainNode.type.split('.')[-2] nodes.append({ "id": nodeCount, @@ -545,11 +545,11 @@ def get_node_shape(ntype): linktype = input['main--in']['label'] pk = node.pk uuid = node.uuid - nodetype = node.dbnode.type + nodetype = node.type display_type = nodetype.split('.')[-2] description = node.get_desc() if description == '': - description = node.dbnode.type.split('.')[-2] + description = node.type.split('.')[-2] nodes.append({ "id": nodeCount, @@ -583,11 +583,11 @@ def get_node_shape(ntype): linktype = output['main--out']['label'] pk = node.pk uuid = node.uuid - nodetype = node.dbnode.type + nodetype = node.type display_type = nodetype.split('.')[-2] description = node.get_desc() if description == '': - description = node.dbnode.type.split('.')[-2] + description = node.type.split('.')[-2] nodes.append({ 
"id": nodeCount,