diff --git a/aiida/orm/implementation/sqlalchemy/groups.py b/aiida/orm/implementation/sqlalchemy/groups.py index 5302d4f0d6..ec6328c207 100644 --- a/aiida/orm/implementation/sqlalchemy/groups.py +++ b/aiida/orm/implementation/sqlalchemy/groups.py @@ -228,37 +228,56 @@ def check_node(given_node): # Commit everything as up till now we've just flushed session.commit() - def remove_nodes(self, nodes): + def remove_nodes(self, nodes, **kwargs): """Remove a node or a set of nodes from the group. :note: all the nodes *and* the group itself have to be stored. :param nodes: a list of `BackendNode` instance to be added to this group + :param kwargs: + skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL + DELETE statement to the group-node relationship table in order to improve speed. """ + from sqlalchemy import and_ + from aiida.backends.sqlalchemy import get_scoped_session + from aiida.backends.sqlalchemy.models.base import Base from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode super().remove_nodes(nodes) # Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database dbnodes = self._dbmodel.dbnodes + skip_orm = kwargs.get('skip_orm', False) - list_nodes = [] - - for node in nodes: + def check_node(node): if not isinstance(node, SqlaNode): raise TypeError('invalid type {}, has to be {}'.format(type(node), SqlaNode)) if node.id is None: raise ValueError('At least one of the provided nodes is unstored, stopping...') - # If we don't check first, SqlA might issue a DELETE statement for an unexisting key, resulting in an error - if node.dbmodel in dbnodes: - list_nodes.append(node.dbmodel) + list_nodes = [] - for node in list_nodes: - dbnodes.remove(node) + with utils.disable_expire_on_commit(get_scoped_session()) as session: + if not skip_orm: + for node in nodes: + check_node(node) + + # Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error + if node.dbmodel in dbnodes: + list_nodes.append(node.dbmodel) + + for node in list_nodes: + dbnodes.remove(node) + else: + table = Base.metadata.tables['db_dbgroup_dbnodes'] + for node in nodes: + check_node(node) + clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id) + statement = table.delete().where(clause) + session.execute(statement) - sa.get_scoped_session().commit() + session.commit() class SqlaGroupCollection(BackendGroupCollection): diff --git a/tests/backends/aiida_sqlalchemy/test_generic.py b/tests/backends/aiida_sqlalchemy/test_generic.py index cba65a7abe..254f7f8c96 100644 --- a/tests/backends/aiida_sqlalchemy/test_generic.py +++ b/tests/backends/aiida_sqlalchemy/test_generic.py @@ -164,3 +164,33 @@ def test_group_batch_size(self): group = Group(label='test_batches_' + str(batch_size)).store() group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size) self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + + def test_remove_nodes_bulk(self): + """Test node removal.""" + backend = self.backend + + node_01 = Data().store().backend_entity + node_02 = Data().store().backend_entity + node_03 = Data().store().backend_entity + node_04 = Data().store().backend_entity + nodes = [node_01, node_02, node_03] + group = backend.groups.create(label='test_remove_nodes', user=backend.users.create('simple2@ton.com')).store() + + # Add initial nodes + group.add_nodes(nodes) + self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + + # Remove a node that is not in the group: nothing should happen + group.remove_nodes([node_04], skip_orm=True) + self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + + # Remove one Node + nodes.remove(node_03) + group.remove_nodes([node_03], skip_orm=True) + self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes)) + + # Remove a list of Nodes and check + nodes.remove(node_01) + nodes.remove(node_02) + group.remove_nodes([node_01, node_02], skip_orm=True) + self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))