Skip to content

Commit

Permalink
Try to order tables by class dimension (#2837)
Browse files Browse the repository at this point in the history
* Try to order tables by class dimension

- Rows in tables are sorted by ascending dimension count
- Switch from lexicographical to natural ordering.
  -> Now parameter2 belongs above parameter10.

* Fix helper test
  • Loading branch information
PiispaH authored Jun 12, 2024
1 parent b31813f commit 5d4cd24
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 26 deletions.
12 changes: 12 additions & 0 deletions spinetoolbox/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1825,3 +1825,15 @@ def add_action(self, text, slot, enabled=True, tooltip=None, icon=None):
action.setEnabled(enabled)
if tooltip is not None:
action.setToolTip(tooltip)


def order_key(name):
"""Splits the given string into a list of its substrings and digits
example: "David_1946_Gilmour" -> ["David_", 1946, "_Gilmour"]
If given a string that starts with a digit, a 'big' string (in comparisons) will be added to the start
of the order key that makes sure that every key starts with a str and alternates with int after that.
"""
key_list = [int(text) if text.isdigit() else text for text in re.split(r"(\d+)", name) if text]
if len(key_list) and isinstance(key_list[0], int):
key_list.insert(0, "\U0010FFFF")
return key_list
23 changes: 18 additions & 5 deletions spinetoolbox/spine_db_editor/mvcmodels/single_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

"""Single models for parameter definitions and values (as 'for a single entity')."""
from PySide6.QtCore import Qt
from spinetoolbox.helpers import DB_ITEM_SEPARATOR, plain_to_rich
from spinetoolbox.helpers import DB_ITEM_SEPARATOR, plain_to_rich, order_key
from ...mvcmodels.minimal_table_model import MinimalTableModel
from ..mvcmodels.single_and_empty_model_mixins import SplitValueAndTypeMixin, MakeEntityOnTheFlyMixin
from ...mvcmodels.shared import PARSED_ROLE, DB_MAP_ROLE
Expand Down Expand Up @@ -62,7 +62,15 @@ def __init__(self, parent, db_map, entity_class_id, committed, lazy=False):
def __lt__(self, other):
if self.entity_class_name == other.entity_class_name:
return self.db_map.codename < other.db_map.codename
return self.entity_class_name < other.entity_class_name
keys = {}
for side, model in {"left": self, "right": other}.items():
dim = len(model.dimension_id_list)
class_name = model.entity_class_name
keys[side] = (
dim,
class_name,
)
return keys["left"] < keys["right"]

@property
def item_type(self):
Expand Down Expand Up @@ -365,7 +373,7 @@ def item_type(self):

def _sort_key(self, element):
item = self.db_item_from_id(element)
return item.get("name", "")
return order_key(item.get("name", ""))

def _do_update_items_in_db(self, db_map_data):
self.db_mngr.update_parameter_definitions(db_map_data)
Expand All @@ -387,7 +395,10 @@ def item_type(self):

def _sort_key(self, element):
item = self.db_item_from_id(element)
return (item.get("entity_byname", ()), item.get("parameter_name", ""), item.get("alternative_name", ""))
byname = order_key("_".join(item.get("entity_byname", ())))
parameter_name = order_key(item.get("parameter_name", ""))
alt_name = order_key(item.get("alternative_name", ""))
return byname, parameter_name, alt_name

def _do_update_items_in_db(self, db_map_data):
self.db_mngr.update_parameter_values(db_map_data)
Expand All @@ -402,7 +413,9 @@ def item_type(self):

def _sort_key(self, element):
item = self.db_item_from_id(element)
return (item.get("entity_byname", ()), item.get("alternative_name", ""))
byname = order_key("_".join(item.get("entity_byname", ())))
alt_name = order_key(item.get("alternative_name", ""))
return byname, alt_name

@property
def _references(self):
Expand Down
142 changes: 121 additions & 21 deletions tests/spine_db_editor/widgets/test_custom_qtableview.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,18 +303,51 @@ def setUp(self):
self._temp_dir = TemporaryDirectory()
url = "sqlite:///" + os.path.join(self._temp_dir.name, "test_database.sqlite")
db_map = DatabaseMapping(url, create=True)
import_functions.import_object_classes(db_map, ("object_class",))
self._n_objects = 12
object_data = (("object_class", f"object_{n}") for n in range(self._n_objects))
import_functions.import_objects(db_map, object_data)
# 1-D entity class
self._n_entities = 12
self._n_parameters = 12
import_functions.import_entity_classes(db_map, (("object_class",),))
object_data = [("object_class", f"object_{n}") for n in range(self._n_entities)]
import_functions.import_entities(db_map, object_data)
parameter_definition_data = (("object_class", f"parameter_{n}") for n in range(self._n_parameters))
import_functions.import_object_parameters(db_map, parameter_definition_data)
parameter_value_data = (
("object_class", f"object_{object_n}", f"parameter_{parameter_n}", "a_value")
for object_n, parameter_n in itertools.product(range(self._n_objects), range(self._n_parameters))
for object_n, parameter_n in itertools.product(range(self._n_entities), range(self._n_parameters))
)
import_functions.import_object_parameter_values(db_map, parameter_value_data)
# 2-D entity class
self._n_ND_entities = 2
self._n_ND_parameters = 2
import_functions.import_entity_classes(
db_map,
(
(
"multi_d_class",
(
"object_class",
"object_class",
),
),
),
)
nd_entity_names = [
(f"object_{i}", f"object_{j}") for i, j in itertools.permutations(range(self._n_ND_entities), 2)
]
object_data = [("multi_d_class", byname) for byname in nd_entity_names]
import_functions.import_entities(db_map, object_data)
parameter_definition_data = (("multi_d_class", f"parameter_{n}") for n in range(self._n_ND_parameters))
import_functions.import_object_parameters(db_map, parameter_definition_data)
parameter_value_data = [
(
"multi_d_class",
byname,
f"parameter_{parameter_n}",
"a_value",
)
for byname, parameter_n in itertools.product(nd_entity_names, range(self._n_ND_parameters))
]
import_functions.import_parameter_values(db_map, parameter_value_data)
db_map.commit_session("Add test data.")
db_map.close()
self._common_setup(url, create=False)
Expand All @@ -327,6 +360,9 @@ def tearDown(self):
self._common_tear_down()
self._temp_dir.cleanup()

def _whole_model_rowcount(self):
return self._n_entities * self._n_parameters + self._n_ND_entities * self._n_ND_parameters + 1

def test_purging_value_data_removes_all_rows(self):
table_view = self._db_editor.ui.tableView_parameter_value
model = table_view.model()
Expand Down Expand Up @@ -355,9 +391,9 @@ def test_removing_fetched_rows_allows_still_fetching_more(self):
table_view = self._db_editor.ui.tableView_parameter_value
model = table_view.model()
self.assertEqual(model.rowCount(), self._CHUNK_SIZE + 1)
n_values = self._n_parameters * self._n_objects
n_values = self._whole_model_rowcount() - 1
self._db_mngr.remove_items({self._db_map: {"parameter_value": set(range(1, n_values, 2))}})
self.assertEqual(model.rowCount(), (self._CHUNK_SIZE) / 2 + 1)
self.assertEqual(model.rowCount(), self._CHUNK_SIZE / 2 + 1)

def test_undoing_purge(self):
table_view = self._db_editor.ui.tableView_parameter_value
Expand All @@ -366,19 +402,23 @@ def test_undoing_purge(self):
self._db_mngr.purge_items({self._db_map: ["parameter_value"]})
self.assertEqual(model.rowCount(), 1)
self._db_editor.undo_action.trigger()
while model.rowCount() != self._n_objects * self._n_parameters + 1:
while model.rowCount() != self._whole_model_rowcount():
# Fetch the entire model, because we want to validate all the data.
model.fetchMore(QModelIndex())
QApplication.processEvents()
expected = sorted(
expected = [
["object_class", f"object_{object_n}", f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for object_n, parameter_n in itertools.product(range(self._n_entities), range(self._n_parameters))
]
nd_entity_names = [f"object_{i} ǀ object_{j}" for i, j in itertools.permutations(range(self._n_ND_entities), 2)]
expected.extend(
[
["object_class", f"object_{object_n}", f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for object_n, parameter_n in itertools.product(range(self._n_objects), range(self._n_parameters))
],
key=lambda x: (x[1], x[2]),
["multi_d_class", entity_name, f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for entity_name, parameter_n in itertools.product(nd_entity_names, range(self._n_ND_parameters))
]
)
expected.append([None, None, None, None, None, self.db_codename])
self.assertEqual(model.rowCount(), self._n_objects * self._n_parameters + 1)
self.assertEqual(model.rowCount(), self._whole_model_rowcount())
for row, column in itertools.product(range(model.rowCount()), range(model.columnCount())):
self.assertEqual(model.index(row, column).data(), expected[row][column])

Expand All @@ -394,20 +434,80 @@ def test_rolling_back_purge(self):
instance.exec.return_value = QMessageBox.StandardButton.Ok
self._db_editor.ui.actionRollback.trigger()
self._db_editor.rollback_session()
while model.rowCount() != self._n_objects * self._n_parameters + 1:
while model.rowCount() != self._whole_model_rowcount():
# Fetch the entire model, because we want to validate all the data.
model.fetchMore(QModelIndex())
QApplication.processEvents()
expected = sorted(
expected = [
["object_class", f"object_{object_n}", f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for object_n, parameter_n in itertools.product(range(self._n_entities), range(self._n_parameters))
]
nd_entity_names = [f"object_{i} ǀ object_{j}" for i, j in itertools.permutations(range(self._n_ND_entities), 2)]
expected.extend(
[
["object_class", f"object_{object_n}", f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for object_n, parameter_n in itertools.product(range(self._n_objects), range(self._n_parameters))
],
key=lambda x: (x[1], x[2]),
["multi_d_class", entity_name, f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for entity_name, parameter_n in itertools.product(nd_entity_names, range(self._n_ND_parameters))
]
)
QApplication.processEvents()
expected.append([None, None, None, None, None, self.db_codename])
self.assertEqual(model.rowCount(), self._n_objects * self._n_parameters + 1)
self.assertEqual(model.rowCount(), self._whole_model_rowcount())
for row, column in itertools.product(range(model.rowCount()), range(model.columnCount())):
self.assertEqual(model.index(row, column).data(), expected[row][column])

def test_sorting(self):
"""Test that the parameter value table sorts in an expected order."""
url = "sqlite:///" + os.path.join(self._temp_dir.name, "test_database.sqlite")
db_map = DatabaseMapping(url)
parameter_definition_data = (
("object_class", f"0parameter_"),
("object_class", f"1parameter_"),
)
import_functions.import_object_parameters(db_map, parameter_definition_data)
parameter_value_data = (
("object_class", f"object_0", f"0parameter_", "a_value"),
("object_class", f"object_0", f"1parameter_", "a_value"),
("object_class", f"object_1", f"0parameter_", "a_value"),
("object_class", f"object_1", f"1parameter_", "a_value"),
)
import_functions.import_object_parameter_values(db_map, parameter_value_data)
db_map.commit_session("Add test data.")
db_map.close()
table_view = self._db_editor.ui.tableView_parameter_value
model = table_view.model()
self.assertEqual(model.rowCount(), self._CHUNK_SIZE + 1)
while model.rowCount() != self._whole_model_rowcount() + 4:
model.fetchMore(QModelIndex())
QApplication.processEvents()
expected = []
for object_n in range(self._n_entities):
for parameter_n in range(self._n_parameters):
expected.append(
[
"object_class",
f"object_{object_n}",
f"parameter_{parameter_n}",
"Base",
"a_value",
self.db_codename,
]
)
if object_n < 2:
expected.extend(
[
["object_class", f"object_{object_n}", f"0parameter_", "Base", "a_value", self.db_codename],
["object_class", f"object_{object_n}", f"1parameter_", "Base", "a_value", self.db_codename],
]
)
nd_entity_names = [f"object_{i} ǀ object_{j}" for i, j in itertools.permutations(range(self._n_ND_entities), 2)]
expected.extend(
[
["multi_d_class", entity_name, f"parameter_{parameter_n}", "Base", "a_value", self.db_codename]
for entity_name, parameter_n in itertools.product(nd_entity_names, range(self._n_ND_parameters))
]
)
expected.append([None, None, None, None, None, self.db_codename])
self.assertEqual(model.rowCount(), self._whole_model_rowcount() + 4)
for row, column in itertools.product(range(model.rowCount()), range(model.columnCount())):
self.assertEqual(model.index(row, column).data(), expected[row][column])

Expand Down
7 changes: 7 additions & 0 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
merge_dicts,
HTMLTagFilter,
home_dir,
order_key,
)


Expand Down Expand Up @@ -419,6 +420,12 @@ def test_merge_dicts_when_source_overwrites_data_in_target(self):
merge_dicts({"a": {"b": 2}}, target)
self.assertEqual(target, {"a": {"b": 2}})

def test_order_key(self):
self.assertEqual(["Humphrey_Bogart"], order_key("Humphrey_Bogart"))
self.assertEqual(["Wes_", 1969, "_Anderson"], order_key("Wes_1969_Anderson"))
self.assertEqual(["\U0010ffff", 1899, "_Alfred-", 1980, "Hitchcock"], order_key("1899_Alfred-1980Hitchcock"))
self.assertEqual([], order_key(""))


class TestHTMLTagFilter(unittest.TestCase):
def test_simple_log_line(self):
Expand Down

0 comments on commit 5d4cd24

Please sign in to comment.