👌 IMPROVE: Ensure QueryBuilder is passed Backend (#5186)
This PR ensures core code always calls `QueryBuilder` with a specific
`Backend`, rather than assuming the globally loaded `Backend`. This will
allow multiple backends to be used at the same time (for example, export
archives), for features including graph traversal and visualisation.
chrisjsewell authored Oct 22, 2021
1 parent d585d14 commit 2f2bdc3
Showing 18 changed files with 101 additions and 71 deletions.
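In essence, the change threads an explicit `backend` argument through to every `QueryBuilder` construction. Below is a minimal sketch of the resulting pattern, using only API calls that appear in this diff (the helper function and its name are illustrative, not part of the commit):

from aiida import orm

def count_workchains(backend=None):
    """Count stored WorkChainNodes, querying an explicit storage backend.

    Passing ``backend`` through, instead of assuming the loaded profile's
    backend, is what lets several backends (e.g. an export archive next to
    the main database) be queried side by side.
    """
    # Falls back to the globally loaded backend when ``backend`` is None.
    builder = orm.QueryBuilder(backend=backend)
    builder.append(orm.WorkChainNode)
    return builder.count()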
8 changes: 6 additions & 2 deletions aiida/cmdline/utils/common.py
@@ -11,11 +11,15 @@
import logging
import os
import sys
from typing import TYPE_CHECKING

from tabulate import tabulate

from . import echo

if TYPE_CHECKING:
from aiida.orm import WorkChainNode

__all__ = ('is_verbose',)


@@ -306,7 +310,7 @@ def get_process_function_report(node):
return '\n'.join(report)


def get_workchain_report(node, levelname, indent_size=4, max_depth=None):
def get_workchain_report(node: 'WorkChainNode', levelname, indent_size=4, max_depth=None):
"""
Return a multi line string representation of the log messages and output of a given workchain
@@ -333,7 +337,7 @@ def get_subtree(uuid, level=0):
Get a nested tree of work calculation nodes and their nesting level starting from this uuid.
The result is a list of uuid of these nodes.
"""
builder = orm.QueryBuilder()
builder = orm.QueryBuilder(backend=node.backend)
builder.append(cls=orm.WorkChainNode, filters={'uuid': uuid}, tag='workcalculation')
builder.append(
cls=orm.WorkChainNode,
2 changes: 1 addition & 1 deletion aiida/orm/implementation/django/comments.py
@@ -168,7 +168,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filters must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Comment, filters=filters, project='id').all()
builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id').all()
entities_to_delete = [_[0] for _ in builder]
for entity in entities_to_delete:
self.delete(entity)
2 changes: 1 addition & 1 deletion aiida/orm/implementation/django/logs.py
@@ -144,7 +144,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filters must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Log, filters=filters, project='id')
builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id')
entities_to_delete = builder.all(flat=True)
for entity in entities_to_delete:
self.delete(entity)
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/comments.py
@@ -171,7 +171,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filters must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Comment, filters=filters, project='id')
builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id')
entities_to_delete = builder.all(flat=True)
for entity in entities_to_delete:
self.delete(entity)
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/logs.py
@@ -153,7 +153,7 @@ def delete_many(self, filters):
raise exceptions.ValidationError('filter must not be empty')

# Apply filter and delete found entities
builder = QueryBuilder().append(Log, filters=filters, project='id')
builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id')
entities_to_delete = builder.all(flat=True)
for entity in entities_to_delete:
self.delete(entity)
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/array/bands.py
@@ -1803,7 +1803,7 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un
MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""")


def get_bands_and_parents_structure(args):
def get_bands_and_parents_structure(args, backend=None):
"""Search for bands and return bands and the closest structure that is a parent of the instance.
:returns:
@@ -1817,7 +1817,7 @@ def get_bands_and_parents_structure(args):
from aiida import orm
from aiida.common import timezone

q_build = orm.QueryBuilder()
q_build = orm.QueryBuilder(backend=backend)
if args.all_users is False:
q_build.append(orm.User, tag='creator', filters={'email': orm.User.objects.get_default().email})
else:
4 changes: 2 additions & 2 deletions aiida/orm/nodes/data/cif.py
@@ -329,15 +329,15 @@ def read_cif(fileobj, index=-1, **kwargs):
return struct_list[index]

@classmethod
def from_md5(cls, md5):
def from_md5(cls, md5, backend=None):
"""
Return a list of all CIF files that match a given MD5 hash.
.. note:: the hash has to be stored in a ``_md5`` attribute,
otherwise the CIF file will not be found.
"""
from aiida.orm.querybuilder import QueryBuilder
builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(cls, filters={'attributes.md5': {'==': md5}})
return builder.all(flat=True)

8 changes: 4 additions & 4 deletions aiida/orm/nodes/data/code.py
@@ -151,7 +151,7 @@ def get_description(self):
return f'{self.description}'

@classmethod
def get_code_helper(cls, label, machinename=None):
def get_code_helper(cls, label, machinename=None, backend=None):
"""
:param label: the code label identifying the code to load
:param machinename: the machine name where code is setup
@@ -164,7 +164,7 @@ def get_code_helper(cls, label, machinename=None):
from aiida.orm.computers import Computer
from aiida.orm.querybuilder import QueryBuilder

query = QueryBuilder()
query = QueryBuilder(backend=backend)
query.append(cls, filters={'label': label}, project='*', tag='code')
if machinename:
query.append(Computer, filters={'label': machinename}, with_node='code')
@@ -249,7 +249,7 @@ def get_from_string(cls, code_string):
raise MultipleObjectsError(f'{code_string} could not be uniquely resolved')

@classmethod
def list_for_plugin(cls, plugin, labels=True):
def list_for_plugin(cls, plugin, labels=True, backend=None):
"""
Return a list of valid code strings for a given plugin.
@@ -260,7 +260,7 @@ def list_for_plugin(cls, plugin, labels=True):
otherwise a list of integers with the code PKs.
"""
from aiida.orm.querybuilder import QueryBuilder
query = QueryBuilder()
query = QueryBuilder(backend=backend)
query.append(cls, filters={'attributes.input_plugin': {'==': plugin}})
valid_codes = query.all(flat=True)

14 changes: 7 additions & 7 deletions aiida/orm/nodes/data/upf.py
@@ -70,7 +70,7 @@ def get_pseudos_from_structure(structure, family_name):
return pseudo_list


def upload_upf_family(folder, group_label, group_description, stop_if_existing=True):
def upload_upf_family(folder, group_label, group_description, stop_if_existing=True, backend=None):
"""Upload a set of UPF files in a given group.
:param folder: a path containing all UPF files to be added.
@@ -120,7 +120,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T

for filename in filenames:
md5sum = md5_file(filename)
builder = orm.QueryBuilder()
builder = orm.QueryBuilder(backend=backend)
builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}})
existing_upf = builder.first()

@@ -321,7 +321,7 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs
return super().store(*args, **kwargs)

@classmethod
def from_md5(cls, md5):
def from_md5(cls, md5, backend=None):
"""Return a list of all `UpfData` that match the given md5 hash.
.. note:: assumes hash of stored `UpfData` nodes is stored in the `md5` attribute
@@ -330,7 +330,7 @@ def from_md5(cls, md5):
:return: list of existing `UpfData` nodes that have the same md5 hash
"""
from aiida.orm.querybuilder import QueryBuilder
builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(cls, filters={'attributes.md5': {'==': md5}})
return builder.all(flat=True)

@@ -366,7 +366,7 @@ def get_upf_family_names(self):
"""Get the list of all upf family names to which the pseudo belongs."""
from aiida.orm import QueryBuilder, UpfFamily

query = QueryBuilder()
query = QueryBuilder(backend=self.backend)
query.append(UpfFamily, tag='group', project='label')
query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group')
return query.all(flat=True)
@@ -448,7 +448,7 @@ def get_upf_group(cls, group_label):
return UpfFamily.get(label=group_label)

@classmethod
def get_upf_groups(cls, filter_elements=None, user=None):
def get_upf_groups(cls, filter_elements=None, user=None, backend=None):
"""Return all names of groups of type UpfFamily, possibly with some filters.
:param filter_elements: A string or a list of strings.
@@ -460,7 +460,7 @@ def get_upf_groups(cls, filter_elements=None, user=None):
"""
from aiida.orm import QueryBuilder, UpfFamily, User

builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(UpfFamily, tag='group', project='*')

if user:
8 changes: 4 additions & 4 deletions aiida/orm/nodes/node.py
@@ -456,11 +456,11 @@ def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str
"""
from aiida.orm.utils.links import validate_link

validate_link(source, self, link_type, link_label)
validate_link(source, self, link_type, link_label, backend=self.backend)

# Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules
if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]:
builder = QueryBuilder().append(
builder = QueryBuilder(backend=self.backend).append(
Node, filters={'id': self.pk}, tag='parent').append(
Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable
if builder.count() > 0:
@@ -537,7 +537,7 @@ def get_stored_link_triples(
if link_label_filter:
edge_filters['label'] = {'like': link_label_filter}

builder = QueryBuilder()
builder = QueryBuilder(backend=self.backend)
builder.append(Node, filters=node_filters, tag='main')

node_project = ['uuid'] if only_uuid else ['*']
@@ -894,7 +894,7 @@ def _iter_all_same_nodes(self, allow_before_store=False) -> Iterator['Node']:
if not node_hash or not self._cachable:
return iter(())

builder = QueryBuilder()
builder = QueryBuilder(backend=self.backend)
builder.append(self.__class__, filters={'extras._aiida_hash': node_hash}, project='*', subclassing=False)
nodes_identical = (n[0] for n in builder.iterall())

11 changes: 8 additions & 3 deletions aiida/orm/querybuilder.py
@@ -136,8 +136,8 @@ def __init__(
:param distinct: Whether to return de-duplicated rows
"""
backend = backend or get_manager().get_backend()
self._impl: BackendQueryBuilder = backend.query()
self._backend = backend or get_manager().get_backend()
self._impl: BackendQueryBuilder = self._backend.query()

# SERIALISABLE ATTRIBUTES
# A list storing the path being traversed by the query
@@ -189,6 +189,11 @@ def __init__(
if order_by:
self.order_by(order_by)

@property
def backend(self) -> 'Backend':
"""Return the backend used by the QueryBuilder."""
return self._backend

def as_dict(self, copy: bool = True) -> QueryDictType:
"""Convert to a JSON serialisable dictionary representation of the query."""
data: QueryDictType = {
@@ -225,7 +230,7 @@ def __str__(self) -> str:

def __deepcopy__(self, memo) -> 'QueryBuilder':
"""Create deep copy of the instance."""
return type(self)(**self.as_dict()) # type: ignore
return type(self)(backend=self.backend, **self.as_dict()) # type: ignore

def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]:
"""Returns a list of all the vertices that are being used.
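The `querybuilder.py` hunk above is the heart of the change: the backend chosen at construction is stored on the instance, exposed via the new `backend` property, and re-passed by `__deepcopy__`, so a copied query can no longer silently rebind to whatever backend happens to be loaded. A hedged sketch of the behaviour this guarantees (it assumes a configured AiiDA profile, so a default backend can be loaded):

from copy import deepcopy

from aiida import orm

# With no explicit argument the builder still falls back to the globally
# loaded backend, exactly as before this commit.
qb = orm.QueryBuilder()
qb.append(orm.Node, project='id')

# New in this commit: the backend is inspectable, and a deep copy keeps it.
assert deepcopy(qb).backend is qb.backend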
8 changes: 4 additions & 4 deletions aiida/orm/utils/links.py
@@ -21,7 +21,7 @@
LinkQuadruple = namedtuple('LinkQuadruple', ['source_id', 'target_id', 'link_type', 'link_label'])


def link_triple_exists(source, target, link_type, link_label):
def link_triple_exists(source, target, link_type, link_label, backend=None):
"""Return whether a link with the given type and label exists between the given source and target node.
:param source: node from which the link is outgoing
@@ -42,15 +42,15 @@ def link_triple_exists(source, target, link_type, link_label):

# Here we have two stored nodes, so we need to check if the same link already exists in the database.
# Finding just a single match is sufficient so we can use the `limit` clause for efficiency
builder = QueryBuilder()
builder = QueryBuilder(backend=backend)
builder.append(Node, filters={'id': source.id}, project=['id'])
builder.append(Node, filters={'id': target.id}, edge_filters={'type': link_type.value, 'label': link_label})
builder.limit(1)

return builder.count() != 0


def validate_link(source, target, link_type, link_label):
def validate_link(source, target, link_type, link_label, backend=None):
"""
Validate adding a link of the given type and label from a given node to ourself.
@@ -153,7 +153,7 @@ def validate_link(source, target, link_type, link_label):
if outdegree == 'unique_triple' or indegree == 'unique_triple':
# For a `unique_triple` degree we just have to check if an identical triple already exist, either in the cache
# or stored, in which case, the new proposed link is a duplicate and thus illegal
duplicate_link_triple = link_triple_exists(source, target, link_type, link_label)
duplicate_link_triple = link_triple_exists(source, target, link_type, link_label, backend)

# If the outdegree is `unique` there cannot already be any other outgoing link of that type
if outdegree == 'unique' and source.get_outgoing(link_type=link_type, only_uuid=True).all():
6 changes: 3 additions & 3 deletions aiida/orm/utils/remote.py
@@ -37,13 +37,13 @@ def clean_remote(transport, path):
pass


def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None):
def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None, backend=None):
"""
Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of
calcjobs will be determined by a query with filters based on the pks, past_days, older_than,
computers and user arguments.
:param pks: onlu include calcjobs with a pk in this list
:param pks: only include calcjobs with a pk in this list
:param past_days: only include calcjobs created since past_days
:param older_than: only include calcjobs older than
:param computers: only include calcjobs that were ran on these computers
@@ -74,7 +74,7 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer
if pks:
filters_calc['id'] = {'in': pks}

query = orm.QueryBuilder()
query = orm.QueryBuilder(backend=backend)
query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc)
query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer)
query.append(orm.User, with_node='calc', filters={'email': user.email})
11 changes: 0 additions & 11 deletions aiida/tools/graph/age_entities.py
@@ -225,17 +225,6 @@ def aiida_cls(self):
"""Class of nodes contained in the entity set (node or group)"""
return self._aiida_cls

def get_entities(self):
"""Iterator that returns the AiiDA entities"""
for entity, in orm.QueryBuilder().append(
self._aiida_cls, project='*', filters={
self._identifier: {
'in': self.keyset
}
}
).iterall():
yield entity


class DirectedEdgeSet(AbstractSetContainer):
"""Extension of AbstractSetContainer
7 changes: 4 additions & 3 deletions aiida/tools/graph/age_rules.py
@@ -11,6 +11,7 @@

from abc import ABCMeta, abstractmethod
from collections import defaultdict
from copy import deepcopy

import numpy as np

@@ -65,7 +66,7 @@ class QueryRule(Operation, metaclass=ABCMeta):
found in the last iteration of the query (ReplaceRule).
"""

def __init__(self, querybuilder, max_iterations=1, track_edges=False):
def __init__(self, querybuilder: orm.QueryBuilder, max_iterations=1, track_edges=False):
"""Initialization method
:param querybuilder: an instance of the QueryBuilder class from which to take the
@@ -107,7 +108,7 @@ def get_spec_from_path(query_dict, idx):
for pathspec in query_dict['path']:
if not pathspec['entity_type']:
pathspec['entity_type'] = 'node.Node.'
self._qbtemplate = orm.QueryBuilder(**query_dict)
self._qbtemplate = deepcopy(querybuilder)
query_dict = self._qbtemplate.as_dict()
self._first_tag = query_dict['path'][0]['tag']
self._last_tag = query_dict['path'][-1]['tag']
@@ -163,7 +164,7 @@ def _init_run(self, operational_set):

# Copying qbtemplate so there's no problem if it is used again in a later run:
query_dict = self._qbtemplate.as_dict()
self._querybuilder = orm.QueryBuilder.from_dict(query_dict)
self._querybuilder = deepcopy(self._qbtemplate)

self._entity_to_identifier = operational_set[self._entity_to].identifier
