From 9592111d1511198ad8470ec1ab3bd65b7ebf89b6 Mon Sep 17 00:00:00 2001 From: Antti Soininen Date: Wed, 30 Oct 2024 13:38:14 +0200 Subject: [PATCH] Fix index ordering when merging values on import In case of string indices, we want to append the new parameter after the old one. --- spinedb_api/import_functions.py | 19 +++++++------ spinedb_api/parameter_value.py | 49 ++++++++++++++++----------------- tests/test_import_functions.py | 34 +++++++++++++++++++++++ 3 files changed, 67 insertions(+), 35 deletions(-) diff --git a/spinedb_api/import_functions.py b/spinedb_api/import_functions.py index 9af97e80..03dda1ec 100644 --- a/spinedb_api/import_functions.py +++ b/spinedb_api/import_functions.py @@ -17,7 +17,7 @@ """ from collections import defaultdict from .helpers import _parse_metadata -from .parameter_value import fancy_type_to_type_and_rank, fix_conflict, to_database +from .parameter_value import fancy_type_to_type_and_rank, get_conflict_fixer, to_database def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs): @@ -79,8 +79,9 @@ def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs """ all_errors = [] num_imports = 0 + conflict_fixer = get_conflict_fixer(on_conflict) for item_type, items in get_data_for_import( - db_map, all_errors, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs + db_map, all_errors, unparse_value=unparse_value, fix_value_conflict=conflict_fixer, **kwargs ): added, updated, errors = db_map.add_update_items(item_type, *items, strict=False) num_imports += len(added + updated) @@ -92,7 +93,7 @@ def get_data_for_import( db_map, all_errors, unparse_value=to_database, - on_conflict="merge", + fix_value_conflict=get_conflict_fixer("merge"), entity_classes=(), entities=(), entity_groups=(), @@ -136,7 +137,7 @@ def get_data_for_import( db_map (DatabaseMapping): database mapping all_errors (list of str): errors encountered during import unparse_value (Callable): function to call when parsing parameter values - on_conflict (str): Conflict resolution strategy for :func:`~spinedb_api.parameter_value.fix_conflict` + fix_value_conflict (str): Conflict resolution strategy for :func:`~spinedb_api.parameter_value.fix_conflict` entity_classes (list(tuple(str,tuple,str,int)): tuples of (name, dimension name tuple, description, display icon integer) parameter_definitions (list(tuple(str,str,str,str)): @@ -219,14 +220,14 @@ def get_data_for_import( if parameter_values: yield ( "parameter_value", - _get_parameter_values_for_import(db_map, parameter_values, all_errors, unparse_value, on_conflict), + _get_parameter_values_for_import(db_map, parameter_values, all_errors, unparse_value, fix_value_conflict), ) if object_parameter_values: # Legacy yield from get_data_for_import( db_map, all_errors, unparse_value=unparse_value, - on_conflict=on_conflict, + fix_value_conflict=fix_value_conflict, parameter_values=object_parameter_values, ) if relationship_parameter_values: # Legacy @@ -234,7 +235,7 @@ def get_data_for_import( db_map, all_errors, unparse_value=unparse_value, - on_conflict=on_conflict, + fix_value_conflict=fix_value_conflict, parameter_values=relationship_parameter_values, ) if metadata: @@ -627,7 +628,7 @@ def _get_parameter_definitions_for_import(data, unparse_value): yield dict(zip(key, (class_name, parameter_name, value, type_, *optionals))) -def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, on_conflict): +def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, fix_conflict): seen = set() key = ("entity_class_name", "entity_byname", "parameter_definition_name", "alternative_name", "value", "type") for class_name, entity_byname, parameter_name, value, *optionals in data: @@ -648,7 +649,7 @@ def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, on item = dict(zip(key, unique_values + (None, None))) pv = db_map.mapped_table("parameter_value").find_item(item) if pv: - value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict) + value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"])) item.update({"value": value, "type": type_}) yield item diff --git a/spinedb_api/parameter_value.py b/spinedb_api/parameter_value.py index f6c72995..3ef031d7 100644 --- a/spinedb_api/parameter_value.py +++ b/spinedb_api/parameter_value.py @@ -336,31 +336,6 @@ def from_dict(value): raise ParameterValueFormatError(f'"{error.args[0]}" is missing in the parameter value description') from error -def fix_conflict(new, old, on_conflict="merge"): - """Resolves conflicts between parameter values: - - :meta private: - - Args: - new (:class:`ParameterValue`, float, str, bool or None): new parameter value to be written. - old (:class:`ParameterValue`, float, str, bool or None): an existing parameter value in the db. - on_conflict (str): conflict resolution strategy: - - 'merge': Merge indexes if possible, otherwise replace. - - 'replace': Replace old with new. - - 'keep': Keep old. - - Returns: - :class:`ParameterValue`, float, str, bool or None: a new parameter value with conflicts resolved. - """ - funcs = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge} - func = funcs.get(on_conflict) - if func is None: - raise RuntimeError( - f"Invalid conflict resolution strategy {on_conflict}, valid strategies are {', '.join(funcs)}" - ) - return func(new, old) - - def merge(value, other): """Merges the DB representation of two parameter values. @@ -386,6 +361,28 @@ def merge_parsed(parsed_value, parsed_other): return parsed_value.merge(parsed_other) +_MERGE_FUNCTIONS = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge} + + +def get_conflict_fixer(on_conflict): + """ + :meta private: + Returns parameter value conflict resolution function. + + Args: + on_conflict (str): resolution action name + + Returns: + Callable: conflict resolution function + """ + try: + return _MERGE_FUNCTIONS[on_conflict] + except KeyError: + raise RuntimeError( + f"Invalid conflict resolution strategy {on_conflict}, valid strategies are {', '.join(_MERGE_FUNCTIONS)}" + ) + + def _break_dictionary(data): """Converts {"index": value} style dictionary into (list(indexes), numpy.ndarray(values)) tuple.""" if not isinstance(data, dict): @@ -1065,7 +1062,7 @@ def merge(self, other): # Avoid sorting when indices are arbitrary strings existing = set(self.indexes) additional = [x for x in other.indexes if x not in existing] - new_indexes = np.concat((self.indexes, additional)) + new_indexes = np.concat((additional, self.indexes)) def _merge(value, other): return other if value is None else merge_parsed(value, other) diff --git a/tests/test_import_functions.py b/tests/test_import_functions.py index f4296507..74252b0a 100644 --- a/tests/test_import_functions.py +++ b/tests/test_import_functions.py @@ -40,6 +40,7 @@ import_scenarios, ) from spinedb_api.parameter_value import ( + Map, TimePattern, TimeSeriesFixedResolution, dump_db_value, @@ -999,6 +1000,39 @@ def test_import_same_as_existing_value_on_conflict_merge_time_pattern(self): self.assertTrue(value_item) self.assertEqual(value_item["parsed_value"], value) + def test_import_existing_map_on_conflict_merge(self): + with DatabaseMapping("sqlite://", create=True) as db_map: + self._assert_success(db_map.add_entity_class_item(name="Object")) + self._assert_success(db_map.add_parameter_definition_item(name="X", entity_class_name="Object")) + self._assert_success(db_map.add_entity_item(name="widget", entity_class_name="Object")) + value = Map(["T1", "T2"], [1.1, 1.2]) + db_value, value_type = to_database(value) + self._assert_success( + db_map.add_parameter_value_item( + entity_class_name="Object", + entity_byname=("widget",), + parameter_definition_name="X", + alternative_name="Base", + value=db_value, + type=value_type, + ) + ) + extended_value = Map(["T3", "T4"], [1.3, 1.4]) + self._assert_imports( + import_parameter_values( + db_map, [["Object", "widget", "X", extended_value, "Base"]], on_conflict="merge" + ) + ) + value_item = db_map.get_parameter_value_item( + entity_class_name="Object", + entity_byname=("widget",), + parameter_definition_name="X", + alternative_name="Base", + ) + self.assertTrue(value_item) + merged_value = Map(["T1", "T2", "T3", "T4"], [1.1, 1.2, 1.3, 1.4]) + self.assertEqual(value_item["parsed_value"], merged_value) + def test_import_duplicate_object_parameter_value(self): with DatabaseMapping("sqlite://", create=True) as db_map: self.populate(db_map)