diff --git a/CHANGELOG.md b/CHANGELOG.md index c4e7b620..4d90eea0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ - Added - Datajoint python CLI ([#940](https://github.com/datajoint/datajoint-python/issues/940)) PR [#1095](https://github.com/datajoint/datajoint-python/pull/1095) - Added - Ability to set hidden attributes on a table - PR [#1091](https://github.com/datajoint/datajoint-python/pull/1091) - Added - Ability to specify a list of keys to popuate - PR [#989](https://github.com/datajoint/datajoint-python/pull/989) +- Fixed - Topological sort [#1057](https://github.com/datajoint/datajoint-python/issues/1057) - PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184) +- Fixed - .parts() not always returning parts [#1103](https://github.com/datajoint/datajoint-python/issues/1103) - PR [#1184](https://github.com/datajoint/datajoint-python/pull/1184) ### 0.14.2 -- Aug 19, 2024 - Added - Migrate nosetests to pytest - PR [#1142](https://github.com/datajoint/datajoint-python/pull/1142) diff --git a/datajoint/dependencies.py b/datajoint/dependencies.py index d9c425d4..5a34dc15 100644 --- a/datajoint/dependencies.py +++ b/datajoint/dependencies.py @@ -5,28 +5,64 @@ from .errors import DataJointError -def unite_master_parts(lst): +def extract_master(part_table): """ - re-order a list of table names so that part tables immediately follow their master tables without breaking - the topological order. - Without this correction, a simple topological sort may insert other descendants between master and parts. - The input list must be topologically sorted. - :example: - unite_master_parts( - ['`s`.`a`', '`s`.`a__q`', '`s`.`b`', '`s`.`c`', '`s`.`c__q`', '`s`.`b__q`', '`s`.`d`', '`s`.`a__r`']) -> - ['`s`.`a`', '`s`.`a__q`', '`s`.`a__r`', '`s`.`b`', '`s`.`b__q`', '`s`.`c`', '`s`.`c__q`', '`s`.`d`'] + given a part table name, return the name of its master table.
None if not a part table """ - for i in range(2, len(lst)): - name = lst[i] - match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", name) - if match: # name is a part table - master = match.group("master") - for j in range(i - 1, -1, -1): - if lst[j] == master + "`" or lst[j].startswith(master + "__"): - # move from the ith position to the (j+1)th position - lst[j + 1 : i + 1] = [name] + lst[j + 1 : i] - break - return lst + match = re.match(r"(?P<master>`\w+`.`#?\w+)__\w+`", part_table) + return match["master"] + "`" if match else None + + +def topo_sort(graph): + """ + topological sort of a dependency graph that keeps part tables together with their masters + :return: list of table names in topological order + """ + + graph = nx.DiGraph(graph) # make a copy + + # collapse alias nodes + alias_nodes = [node for node in graph if node.isdigit()] + for node in alias_nodes: + try: + direct_edge = ( + next(x for x in graph.in_edges(node))[0], + next(x for x in graph.out_edges(node))[1], + ) + except StopIteration: + pass # a disconnected alias node + else: + graph.add_edge(*direct_edge) + graph.remove_nodes_from(alias_nodes) + + # Add parts' dependencies to their masters' dependencies + # to ensure correct topological ordering of the masters.
+ for part in graph: + # find the part's master + if (master := extract_master(part)) in graph: + for edge in graph.in_edges(part): + parent = edge[0] + if master not in (parent, extract_master(parent)): + # if parent is neither master nor part of master + graph.add_edge(parent, master) + sorted_nodes = list(nx.topological_sort(graph)) + + # bring parts up to their masters + pos = len(sorted_nodes) - 1 + placed = set() + while pos > 1: + part = sorted_nodes[pos] + if (master := extract_master(part)) not in graph or part in placed: + pos -= 1 + else: + placed.add(part) + j = sorted_nodes.index(master) + if pos > j + 1: + # move the part to its master + del sorted_nodes[pos] + sorted_nodes.insert(j + 1, part) + + return sorted_nodes class Dependencies(nx.DiGraph): @@ -131,6 +167,10 @@ def load(self, force=True): raise DataJointError("DataJoint can only work with acyclic dependencies") self._loaded = True + def topo_sort(self): + """:return: list of tables names in topological order""" + return topo_sort(self) + def parents(self, table_name, primary=None): """ :param table_name: `schema`.`table` @@ -167,10 +207,8 @@ def descendants(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. """ self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.descendants(self, full_table_name)) - return unite_master_parts( - [full_table_name] + list(nx.algorithms.dag.topological_sort(nodes)) - ) + nodes = self.subgraph(nx.descendants(self, full_table_name)) + return [full_table_name] + nodes.topo_sort() def ancestors(self, full_table_name): """ @@ -178,11 +216,5 @@ def ancestors(self, full_table_name): :return: all dependent tables sorted in topological order. Self is included. 
""" self.load(force=False) - nodes = self.subgraph(nx.algorithms.dag.ancestors(self, full_table_name)) - return list( - reversed( - unite_master_parts( - list(nx.algorithms.dag.topological_sort(nodes)) + [full_table_name] - ) - ) - ) + nodes = self.subgraph(nx.ancestors(self, full_table_name)) + return reversed(nodes.topo_sort() + [full_table_name]) diff --git a/datajoint/diagram.py b/datajoint/diagram.py index 7f47f746..aeced065 100644 --- a/datajoint/diagram.py +++ b/datajoint/diagram.py @@ -1,12 +1,11 @@ import networkx as nx -import re import functools import io import logging import inspect from .table import Table -from .dependencies import unite_master_parts -from .user_tables import Manual, Imported, Computed, Lookup, Part +from .dependencies import topo_sort +from .user_tables import Manual, Imported, Computed, Lookup, Part, _get_tier, _AliasNode from .errors import DataJointError from .table import lookup_class_name @@ -27,29 +26,6 @@ logger = logging.getLogger(__name__.split(".")[0]) -user_table_classes = (Manual, Lookup, Computed, Imported, Part) - - -class _AliasNode: - """ - special class to indicate aliased foreign keys - """ - - pass - - -def _get_tier(table_name): - if not table_name.startswith("`"): - return _AliasNode - else: - try: - return next( - tier - for tier in user_table_classes - if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) - ) - except StopIteration: - return None if not diagram_active: @@ -59,8 +35,7 @@ class Diagram: Entity relationship diagram, currently disabled due to the lack of required packages: matplotlib and pygraphviz. To enable Diagram feature, please install both matplotlib and pygraphviz. 
For instructions on how to install - these two packages, refer to http://docs.datajoint.io/setup/Install-and-connect.html#python and - http://tutorials.datajoint.io/setting-up/datajoint-python.html + these two packages, refer to https://datajoint.com/docs/core/datajoint-python/0.14/client/install/ """ def __init__(self, *args, **kwargs): @@ -72,19 +47,22 @@ def __init__(self, *args, **kwargs): class Diagram(nx.DiGraph): """ - Entity relationship diagram. + Schema diagram showing tables and foreign keys between them in the form of a directed + acyclic graph (DAG). The diagram is derived from the connection.dependencies object. Usage: >>> diag = Diagram(source) - source can be a base table object, a base table class, a schema, or a module that has a schema. + source can be a table object, a table class, a schema, or a module that has a schema. >>> diag.draw() draws the diagram using pyplot diag1 + diag2 - combines the two diagrams. + diag1 - diag2 - difference between diagrams + diag1 * diag2 - intersection of diagrams diag + n - expands n levels of successors diag - n - expands n levels of predecessors Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table @@ -94,6 +72,7 @@ class Diagram(nx.DiGraph): """ def __init__(self, source, context=None): + if isinstance(source, Diagram): # copy constructor self.nodes_to_show = set(source.nodes_to_show) @@ -154,7 +133,7 @@ def from_sequence(cls, sequence): def add_parts(self): """ - Adds to the diagram the part tables of tables already included in the diagram + Adds to the diagram the part tables of all master tables already in the diagram :return: """ @@ -179,16 +158,6 @@ def is_part(part, master): ) return self - def topological_sort(self): - """:return: list of nodes in topological order""" - return unite_master_parts( - list( - nx.algorithms.dag.topological_sort( - nx.DiGraph(self).subgraph(self.nodes_to_show) - ) - ) - ) - def __add__(self, arg): """ :param arg: either
another Diagram or a positive integer. @@ -256,6 +225,10 @@ def __mul__(self, arg): self.nodes_to_show.intersection_update(arg.nodes_to_show) return self + def topo_sort(self): + """return nodes in topological order, keeping part tables together with their masters""" + return topo_sort(self) + def _make_graph(self): """ Make the self.graph - a graph object ready for drawing diff --git a/datajoint/schemas.py b/datajoint/schemas.py index 62f45fa6..c3894ba2 100644 --- a/datajoint/schemas.py +++ b/datajoint/schemas.py @@ -2,17 +2,16 @@ import logging import inspect import re -import itertools import collections +import itertools from .connection import conn -from .diagram import Diagram, _get_tier from .settings import config from .errors import DataJointError, AccessError from .jobs import JobTable from .external import ExternalMapping from .heading import Heading from .utils import user_choice, to_camel_case -from .user_tables import Part, Computed, Imported, Manual, Lookup +from .user_tables import Part, Computed, Imported, Manual, Lookup, _get_tier from .table import lookup_class_name, Log, FreeTable import types @@ -413,6 +412,7 @@ def save(self, python_filename=None): :return: a string containing the body of a complete Python module defining this schema. """ + self.connection.dependencies.load() self._assert_exists() module_count = itertools.count() # add virtual modules for referenced modules with names vmod0, vmod1, ... @@ -451,10 +451,8 @@ def replace(s): ).replace("\n", "\n " + indent), ) - diagram = Diagram(self) - body = "\n\n".join( - make_class_definition(table) for table in diagram.topological_sort() - ) + tables = self.connection.dependencies.topo_sort() + body = "\n\n".join(make_class_definition(table) for table in tables) python_code = "\n\n".join( ( '"""This module was auto-generated by datajoint from an existing schema"""', @@ -480,11 +478,12 @@ def list_tables(self): :return: A list of table names from the database schema.
""" + self.connection.dependencies.load() return [ t for d, t in ( full_t.replace("`", "").split(".") - for full_t in Diagram(self).topological_sort() + for full_t in self.connection.dependencies.topo_sort() ) if d == self.database ] @@ -533,7 +532,6 @@ def __init__( def list_schemas(connection=None): """ - :param connection: a dj.Connection object :return: list of all accessible schemas on the server """ diff --git a/datajoint/table.py b/datajoint/table.py index 1ad4177a..db9eaffa 100644 --- a/datajoint/table.py +++ b/datajoint/table.py @@ -196,7 +196,6 @@ def parents(self, primary=None, as_objects=False, foreign_key_info=False): def children(self, primary=None, as_objects=False, foreign_key_info=False): """ - :param primary: if None, then all children are returned. If True, then only foreign keys composed of primary key attributes are considered. If False, return foreign keys including at least one secondary attribute. @@ -218,7 +217,6 @@ def children(self, primary=None, as_objects=False, foreign_key_info=False): def descendants(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. :return: list of tables descendants in topological order. """ @@ -230,7 +228,6 @@ def descendants(self, as_objects=False): def ancestors(self, as_objects=False): """ - :param as_objects: False - a list of table names; True - a list of table objects. :return: list of tables ancestors in topological order. """ @@ -246,6 +243,7 @@ def parts(self, as_objects=False): :param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects. 
""" + self.connection.dependencies.load(force=False) nodes = [ node for node in self.connection.dependencies.nodes @@ -427,7 +425,8 @@ def insert( self.connection.query(query) return - field_list = [] # collects the field list from first row (passed by reference) + # collects the field list from first row (passed by reference) + field_list = [] rows = list( self.__make_row_to_insert(row, field_list, ignore_extra_fields) for row in rows @@ -520,7 +519,8 @@ def cascade(table): delete_count = table.delete_quick(get_count=True) except IntegrityError as error: match = foreign_key_error_regexp.match(error.args[0]).groupdict() - if "`.`" not in match["child"]: # if schema name missing, use table + # if schema name missing, use table + if "`.`" not in match["child"]: match["child"] = "{}.{}".format( table.full_table_name.split(".")[0], match["child"] ) @@ -964,7 +964,8 @@ def lookup_class_name(name, context, depth=3): while nodes: node = nodes.pop(0) for member_name, member in node["context"].items(): - if not member_name.startswith("_"): # skip IPython's implicit variables + # skip IPython's implicit variables + if not member_name.startswith("_"): if inspect.isclass(member) and issubclass(member, Table): if member.full_table_name == name: # found it! return ".".join([node["context_name"], member_name]).lstrip(".") diff --git a/datajoint/user_tables.py b/datajoint/user_tables.py index bcb6a027..0a784560 100644 --- a/datajoint/user_tables.py +++ b/datajoint/user_tables.py @@ -2,6 +2,7 @@ Hosts the table tiers, user tables should be derived from. 
""" +import re from .table import Table from .autopopulate import AutoPopulate from .utils import from_camel_case, ClassProperty @@ -242,3 +243,29 @@ def drop(self, force=False): def alter(self, prompt=True, context=None): # without context, use declaration context which maps master keyword to master table super().alter(prompt=prompt, context=context or self.declaration_context) + + +user_table_classes = (Manual, Lookup, Computed, Imported, Part) + + +class _AliasNode: + """ + special class to indicate aliased foreign keys + """ + + pass + + +def _get_tier(table_name): + """given the table name, return""" + if not table_name.startswith("`"): + return _AliasNode + else: + try: + return next( + tier + for tier in user_table_classes + if re.fullmatch(tier.tier_regexp, table_name.split("`")[-2]) + ) + except StopIteration: + return None diff --git a/tests/test_blob.py b/tests/test_blob.py index 12039f7f..6c5a6f5a 100644 --- a/tests/test_blob.py +++ b/tests/test_blob.py @@ -185,7 +185,7 @@ def test_insert_longblob(schema_any): query_mym_blob = {"id": 1, "data": np.array([1, 2, 3])} Longblob.insert1(query_mym_blob) - assert (Longblob & "id=1").fetch1()["data"].all() == query_mym_blob["data"].all() + assert_array_equal((Longblob & "id=1").fetch1()["data"], query_mym_blob["data"]) (Longblob & "id=1").delete() @@ -218,7 +218,8 @@ def test_insert_longblob_32bit(schema_any, enable_feature_32bit_dims): ), } assert fetched["id"] == expected["id"] - assert np.array_equal(fetched["data"], expected["data"]) + for name in expected["data"][0][0].dtype.names: + assert_array_equal(expected["data"][0][0][name], fetched["data"][0][0][name]) (Longblob & "id=1").delete() @@ -248,4 +249,5 @@ def test_datetime_serialization_speed(): ) print(f"python time {baseline_exe_time}") - assert optimized_exe_time * 900 < baseline_exe_time + # The time savings were much greater (x1000) but use x10 for testing + assert optimized_exe_time * 10 < baseline_exe_time diff --git a/tests/test_cli.py 
b/tests/test_cli.py index 3f0fd00c..29fedf22 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -3,7 +3,6 @@ """ import json -import ast import subprocess import pytest import datajoint as dj diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index 987acc6c..5a4acd7d 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -1,51 +1,5 @@ from datajoint import errors from pytest import raises -from datajoint.dependencies import unite_master_parts - - -def test_unite_master_parts(): - assert unite_master_parts( - [ - "`s`.`a`", - "`s`.`a__q`", - "`s`.`b`", - "`s`.`c`", - "`s`.`c__q`", - "`s`.`b__q`", - "`s`.`d`", - "`s`.`a__r`", - ] - ) == [ - "`s`.`a`", - "`s`.`a__q`", - "`s`.`a__r`", - "`s`.`b`", - "`s`.`b__q`", - "`s`.`c`", - "`s`.`c__q`", - "`s`.`d`", - ] - assert unite_master_parts( - [ - "`lab`.`#equipment`", - "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method_task_type`", - "`cells`.`cell_analysis_method_users`", - "`cells`.`favorite_selection`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`lab`.`#equipment__config`", - "`cells`.`cell_analysis_method__field_detect_params`", - ] - ) == [ - "`lab`.`#equipment`", - "`lab`.`#equipment__config`", - "`cells`.`cell_analysis_method`", - "`cells`.`cell_analysis_method__cell_selection_params`", - "`cells`.`cell_analysis_method__field_detect_params`", - "`cells`.`cell_analysis_method_task_type`", - "`cells`.`cell_analysis_method_users`", - "`cells`.`favorite_selection`", - ] def test_nullable_dependency(thing_tables):