diff --git a/spinetoolbox/mvcmodels/compound_table_model.py b/spinetoolbox/mvcmodels/compound_table_model.py index cd00d794a..d1133db22 100644 --- a/spinetoolbox/mvcmodels/compound_table_model.py +++ b/spinetoolbox/mvcmodels/compound_table_model.py @@ -13,6 +13,7 @@ """Models that vertically concatenate two or more table models.""" import bisect from PySide6.QtCore import QModelIndex, Qt, QTimer, Slot +from ..helpers import rows_to_row_count_tuples from ..mvcmodels.minimal_table_model import MinimalTableModel @@ -20,8 +21,7 @@ class CompoundTableModel(MinimalTableModel): """A model that concatenates several sub table models vertically.""" def __init__(self, parent=None, header=None): - """Initializes model. - + """ Args: parent (QObject, optional): the parent object header (list of str, optional): header labels @@ -125,7 +125,7 @@ def _append_row_map(self, row_map): row_map (list): tuples (model, row number) """ for model_row_tup in row_map: - self._inv_row_map[model_row_tup] = self.rowCount() + self._inv_row_map[model_row_tup] = len(self._row_map) self._row_map.append(model_row_tup) def _row_map_iterator_for_model(self, model): @@ -331,20 +331,21 @@ def _handle_single_model_about_to_be_reset(self, model): row_map = self._row_map_for_model(model) if not row_map: return - try: - first = self._inv_row_map[row_map[0]] - except KeyError: - # Sometimes the submodel may get reset before it has been added to the inverted row map. - # In this case there are no rows to remove, so we can bail out here. - return - last = first + len(row_map) - 1 - tail_row_map = self._row_map[last + 1 :] - self.beginRemoveRows(QModelIndex(), first, last) - for key in self._row_map[first:]: - del self._inv_row_map[key] - self._row_map[first:] = [] - self._append_row_map(tail_row_map) - self.endRemoveRows() + removed_rows = [] + for mapped_row in row_map: + try: + removed_rows.append(self._inv_row_map[mapped_row]) + except KeyError: + pass + for first, count in sorted(rows_to_row_count_tuples(removed_rows), reverse=True): + last = first + count - 1 + tail_row_map = self._row_map[last + 1 :] + self.beginRemoveRows(QModelIndex(), first, last) + for key in self._row_map[first:]: + del self._inv_row_map[key] + del self._row_map[first:] + self._append_row_map(tail_row_map) + self.endRemoveRows() def _handle_single_model_reset(self, model): """Runs when given model is reset.""" diff --git a/spinetoolbox/spine_db_editor/mvcmodels/compound_models.py b/spinetoolbox/spine_db_editor/mvcmodels/compound_models.py index 735944a1a..3eee1bcc6 100644 --- a/spinetoolbox/spine_db_editor/mvcmodels/compound_models.py +++ b/spinetoolbox/spine_db_editor/mvcmodels/compound_models.py @@ -192,7 +192,7 @@ def _auto_filter_accepts_model(self, model): for db_map, entity_class_id in values: if model.db_map == db_map and (entity_class_id is None or model.entity_class_id == entity_class_id): break - else: # nobreak + else: return False return True @@ -391,7 +391,6 @@ def handle_items_removed(self, db_map_data): Args: db_map_data (dict): list of removed dict-items keyed by DatabaseMapping """ - self.layoutAboutToBeChanged.emit() for db_map, items in db_map_data.items(): if db_map not in self.db_maps: continue @@ -403,23 +402,77 @@ def handle_items_removed(self, db_map_data): removed_ids = {x["id"] for x in items_per_class.get(model.entity_class_id, {})} if not removed_ids: continue - removed_rows = [] + removed_invisible_rows = set() + removed_visible_rows = [] for row in range(model.rowCount()): id_ = model._main_data[row] if id_ in removed_ids: - removed_rows.append(row) removed_ids.remove(id_) - if not removed_ids: - break - for row, count in sorted(rows_to_row_count_tuples(removed_rows), reverse=True): - del model._main_data[row : row + count] + if (model, row) in self._inv_row_map: + removed_visible_rows.append(row) + else: + removed_invisible_rows.add(row) + removed_compound_rows = [self._inv_row_map[(model, row)] for row in removed_visible_rows] + if removed_invisible_rows: + new_kept_rows = self._delete_rows_from_single_model(model, removed_invisible_rows) + self._update_single_model_rows_in_row_map(model, new_kept_rows) + for first_compound_row, count in sorted(rows_to_row_count_tuples(removed_compound_rows), reverse=True): + self.beginRemoveRows(QModelIndex(), first_compound_row, first_compound_row + count - 1) + removed_model_rows = { + self._row_map[r][1] for r in range(first_compound_row, first_compound_row + count) + } + new_kept_rows = self._delete_rows_from_single_model(model, removed_model_rows) + for row in removed_model_rows: + del self._inv_row_map[(model, row)] + self._update_single_model_rows_in_row_map(model, new_kept_rows) + del self._row_map[first_compound_row : first_compound_row + count] + for row, mapped_row in enumerate(self._row_map[first_compound_row:]): + self._inv_row_map[mapped_row] = row + first_compound_row + self.endRemoveRows() if model.rowCount() == 0: emptied_single_model_indexes.append(model_index) for model_index in reversed(emptied_single_model_indexes): model = self.sub_models.pop(model_index) model.deleteLater() - self._do_refresh() - self.layoutChanged.emit() + + def _delete_rows_from_single_model(self, model, rows_to_remove): + """Removes rows from given single model and computes a map from original rows to retained rows. + + Args: + model (SingleModelBase): single model to delete data from + rows_to_remove (set of int): row index that should be removed + + Returns: + dict: mapping from original row index to post-removal row index + """ + new_kept_rows = {} + sorted_deleted_rows = [] + for row in range(model.rowCount()): + if row in rows_to_remove: + sorted_deleted_rows.append(row) + else: + new_kept_rows[row] = row - len(sorted_deleted_rows) + for row in reversed(sorted_deleted_rows): + del model._main_data[row] + return new_kept_rows + + def _update_single_model_rows_in_row_map(self, model, new_rows): + """Rewrites single model rows in row map. + + Args: + model (SingleModelBase): single model whose rows to update + new_rows (dict): mapping from old row index to updated index + """ + new_inv_row_map = {} + for row, new_row in new_rows.items(): + try: + compound_row = self._inv_row_map.pop((model, row)) + except KeyError: + continue + self._row_map[compound_row] = (model, new_row) + new_inv_row_map[(model, new_row)] = compound_row + for mapped_row, compound_row in new_inv_row_map.items(): + self._inv_row_map[mapped_row] = compound_row def db_item(self, index): sub_index = self.map_to_sub(index) diff --git a/spinetoolbox/spine_db_editor/widgets/custom_menus.py b/spinetoolbox/spine_db_editor/widgets/custom_menus.py index 5a3669c7c..90cd7e3d2 100644 --- a/spinetoolbox/spine_db_editor/widgets/custom_menus.py +++ b/spinetoolbox/spine_db_editor/widgets/custom_menus.py @@ -253,7 +253,7 @@ def __init__(self, parent): "Clear", self.clear_recents, enabled=self.has_recents(), - icon=QIcon(":icons/trash-alt.svg"), + icon=QIcon(":icons/menu_icons/trash-alt.svg"), ) def has_recents(self): diff --git a/spinetoolbox/widgets/custom_menus.py b/spinetoolbox/widgets/custom_menus.py index 9bbe7c8ad..ab6836c15 100644 --- a/spinetoolbox/widgets/custom_menus.py +++ b/spinetoolbox/widgets/custom_menus.py @@ -105,7 +105,7 @@ def __init__(self, parent): "Clear", lambda checked=False: self.call_clear_recents(checked), enabled=self.has_recents(), - icon=QIcon(":icons/trash-alt.svg"), + icon=QIcon(":icons/menu_icons/trash-alt.svg"), ) def has_recents(self): diff --git a/tests/spine_db_editor/mvcmodels/test_compound_models.py b/tests/spine_db_editor/mvcmodels/test_compound_models.py index 9e06845c5..19c4d02c8 100644 --- a/tests/spine_db_editor/mvcmodels/test_compound_models.py +++ b/tests/spine_db_editor/mvcmodels/test_compound_models.py @@ -11,6 +11,7 @@ ###################################################################################################################### """Unit tests for the models in ``compound_models`` module.""" +from itertools import product import unittest from spinedb_api import Array, to_database from spinetoolbox.spine_db_editor.mvcmodels.compound_models import ( @@ -73,8 +74,10 @@ def test_model_updates_when_entity_class_is_removed(self): model = CompoundParameterDefinitionModel(self._db_editor, self._db_mngr, self._db_map) model.init_model() fetch_model(model) - model.set_filter_class_ids({self._db_map: {entity_class_2["id"]}}) self.assertEqual(model.rowCount(), 4) + model.set_filter_class_ids({self._db_map: {entity_class_2["id"]}}) + model.refresh() + self.assertEqual(model.rowCount(), 3) self._db_mngr.remove_items({self._db_map: {"entity_class": [entity_class_2["id"]]}}) self.assertEqual(model.rowCount(), 1) @@ -189,6 +192,541 @@ def test_index_name_returns_sane_label(self): index = model.index(0, 3) self.assertEqual(model.index_name(index), "TestCompoundParameterValueModel_db - x - Base - mysterious cube") + def test_removing_first_of_two_rows(self): + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self.assert_success(self._db_map.add_alternative_item(name="not-Base")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + value_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + value_not_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="not-Base", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_in_base.remove() + value_not_in_base.remove() + expected = [ + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_not_in_base.restore() + value_in_base.restore() + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + + def test_removing_second_of_two_uncommitted_rows(self): + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self.assert_success(self._db_map.add_alternative_item(name="not-Base")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + value_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + value_not_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="not-Base", + value=value, + type=value_type, + ) + ) + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_not_in_base.remove() + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + + def test_restoring_removed_item_keeps_empty_row_last(self): + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self.assert_success(self._db_map.add_alternative_item(name="not-Base")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + value_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + value_not_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="not-Base", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_in_base.remove() + expected = [ + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_not_in_base.remove() + expected = [ + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + self.assertEqual(model.single_models, []) + + def test_removing_value_from_another_alternative_that_is_selected_for_filtering_works(self): + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + not_base_alternative = self.assert_success(self._db_map.add_alternative_item(name="not-Base")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + value_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="not-Base", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + self.assertEqual(model.rowCount(), 3) + self.assertEqual(model.columnCount(), 6) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + model.set_filter_alternative_ids({self._db_map: {not_base_alternative["id"]}}) + model.refresh() + self.assertEqual(model.rowCount(), 2) + expected = [ + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_in_base.remove() + self.assertEqual(model.rowCount(), 2) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + + def test_restoring_removed_value_from_another_alternative_that_is_selected_for_filtering_works(self): + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + not_base_alternative = self.assert_success(self._db_map.add_alternative_item(name="not-Base")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + value_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="not-Base", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add test data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + model.set_filter_alternative_ids({self._db_map: {not_base_alternative["id"]}}) + model.refresh() + self.assertEqual(model.rowCount(), 2) + expected = [ + ["Object", "curious sphere", "X", "not-Base", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_in_base.remove() + self.assertEqual(model.rowCount(), 2) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + value_in_base.restore() + self.assertEqual(model.rowCount(), 2) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + + def test_remove_every_other_row(self): + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self.assert_success(self._db_map.add_alternative_item(name="ctrl")) + self.assert_success(self._db_map.add_alternative_item(name="alt")) + self.assert_success(self._db_map.add_alternative_item(name="del")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="ctrl", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(23.0) + alt_value = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="alt", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-23.0) + del_value = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="del", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add test data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "alt", "23.0", self.db_codename], + ["Object", "curious sphere", "X", "ctrl", "-2.3", self.db_codename], + ["Object", "curious sphere", "X", "del", "-23.0", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + self._db_map.remove_items("parameter_value", alt_value["id"], del_value["id"]) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "ctrl", "-2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + + def test_remove_item_from_another_entity_class_than_selected(self): + object_class = self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + self.assert_success(self._db_map.add_entity_class_item(name="Immaterial")) + self.assert_success(self._db_map.add_parameter_definition_item(name="Y", entity_class_name="Immaterial")) + self.assert_success(self._db_map.add_parameter_definition_item(name="Z", entity_class_name="Immaterial")) + self.assert_success(self._db_map.add_entity_item(name="ghost", entity_class_name="Immaterial")) + value, value_type = to_database(-2.3) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Immaterial", + entity_byname=("ghost",), + parameter_definition_name="Y", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(23.0) + z_value = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Immaterial", + entity_byname=("ghost",), + parameter_definition_name="Z", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add test data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Immaterial", "ghost", "Y", "Base", "-2.3", self.db_codename], + ["Immaterial", "ghost", "Z", "Base", "23.0", self.db_codename], + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + model.set_filter_class_ids({self._db_map: {object_class["id"]}}) + model.refresh() + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + z_value.remove() + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + + def test_remove_visible_and_hidden_items(self): + alternative = self.assert_success(self._db_map.add_alternative_item(name="alt")) + self.assert_success(self._db_map.add_entity_class_item(name="Object")) + self.assert_success(self._db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self.assert_success(self._db_map.add_entity_item(name="mystic cube", entity_class_name="Object")) + self.assert_success(self._db_map.add_entity_item(name="curious sphere", entity_class_name="Object")) + value, value_type = to_database(2.3) + spherical_value_in_base = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-2.3) + spherical_value_in_alt = self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("curious sphere",), + parameter_definition_name="X", + alternative_name="alt", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(23.0) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("mystic cube",), + parameter_definition_name="X", + alternative_name="Base", + value=value, + type=value_type, + ) + ) + value, value_type = to_database(-23.0) + self.assert_success( + self._db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("mystic cube",), + parameter_definition_name="X", + alternative_name="alt", + value=value, + type=value_type, + ) + ) + self._db_map.commit_session("Add test data") + model = CompoundParameterValueModel(self._db_editor, self._db_mngr, self._db_map) + model.init_model() + fetch_model(model) + expected = [ + ["Object", "curious sphere", "X", "Base", "2.3", self.db_codename], + ["Object", "curious sphere", "X", "alt", "-2.3", self.db_codename], + ["Object", "mystic cube", "X", "Base", "23.0", self.db_codename], + ["Object", "mystic cube", "X", "alt", "-23.0", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + model.set_filter_alternative_ids({self._db_map: {alternative["id"]}}) + model.refresh() + expected = [ + ["Object", "curious sphere", "X", "alt", "-2.3", self.db_codename], + ["Object", "mystic cube", "X", "alt", "-23.0", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + self.assertEqual(model.columnCount(), 6) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + spherical_value_in_base.remove() + spherical_value_in_alt.remove() + expected = [ + ["Object", "mystic cube", "X", "alt", "-23.0", self.db_codename], + [None, None, None, None, None, None], + ] + self.assertEqual(model.rowCount(), len(expected)) + for row, column in product(range(model.rowCount()), range(model.columnCount())): + with self.subTest(row=row, column=column): + self.assertEqual(model.index(row, column).data(), expected[row][column]) + if __name__ == "__main__": unittest.main() diff --git a/tests/spine_db_editor/widgets/helpers.py b/tests/spine_db_editor/widgets/helpers.py index 5f8455df0..a75e10f30 100644 --- a/tests/spine_db_editor/widgets/helpers.py +++ b/tests/spine_db_editor/widgets/helpers.py @@ -42,6 +42,7 @@ def create_and_store_editor(instance, parent, option, target_index): view.edit(index) if self._cell_editor is None: # Native editor widget is being used, fall back to setting value directly in model. + view.closeEditor() view.model().setData(index, value) return if isinstance(self._cell_editor, SearchBarEditor): @@ -71,6 +72,8 @@ def create_and_store_editor(instance, parent, option, target_index): view.edit(index) def reset(self): + if self._cell_editor is not None: + self._cell_editor.deleteLater() self._cell_editor = None diff --git a/tests/test_SpineToolboxProject.py b/tests/test_SpineToolboxProject.py index 3384896e3..d4e638155 100644 --- a/tests/test_SpineToolboxProject.py +++ b/tests/test_SpineToolboxProject.py @@ -20,7 +20,7 @@ import networkx as nx from PySide6.QtCore import QVariantAnimation from PySide6.QtGui import QColor -from PySide6.QtWidgets import QApplication, QMessageBox +from PySide6.QtWidgets import QApplication, QGraphicsRectItem, QMessageBox from spine_engine.project_item.executable_item_base import ExecutableItemBase from spine_engine.project_item.project_item_specification import ProjectItemSpecification from spine_engine.spine_engine import ItemExecutionFinishState @@ -778,6 +778,9 @@ def set_rank(self, rank): def set_icon(self, icon): return + def get_icon(self): + return QGraphicsRectItem(0.0, 0.0, 23.0, 23.0) + class _MockItemFactoryForLocalDataTests(ProjectItemFactory): @staticmethod diff --git a/tests/widgets/test_custom_combobox.py b/tests/widgets/test_custom_combobox.py index a3dd29b49..b59e3e4e6 100644 --- a/tests/widgets/test_custom_combobox.py +++ b/tests/widgets/test_custom_combobox.py @@ -13,7 +13,7 @@ """Unit tests for the classes in ``custom_combobox`` module. OpenProjectDialogComboBox is tested in test_open_project_dialog module.""" import unittest -from PySide6.QtGui import QPaintEvent +from PySide6.QtGui import QColor, QImage, QPaintEvent from PySide6.QtWidgets import QWidget from spinetoolbox.widgets.custom_combobox import CustomQComboBox, ElidedCombobox from tests.mock_helpers import TestCaseWithQApplication @@ -21,14 +21,15 @@ class TestCustomComboBoxes(TestCaseWithQApplication): def test_custom_combobox(self): - parent = QWidget() - cb = CustomQComboBox(parent) + cb = CustomQComboBox(None) cb.addItems(["a", "b", "c"]) self.assertEqual("a", cb.itemText(0)) - parent.deleteLater() + cb.deleteLater() def test_elided_combobox(self): - parent = QWidget() - cb = ElidedCombobox(parent) + cb = ElidedCombobox(None) + image = QImage(cb.size(), QImage.Format.Format_RGB32) + image.fill(QColor("white")) + cb.paintEngine = image.paintEngine cb.paintEvent(QPaintEvent(cb.rect())) - parent.deleteLater() + cb.deleteLater()