Skip to content

Commit

Permalink
Fix index ordering when merging values on import
Browse files Browse the repository at this point in the history
When merging values that have string indices, the indexes of the newly
imported value are appended after those of the existing value.
  • Loading branch information
soininen committed Oct 30, 2024
1 parent d076802 commit 152d1d1
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 35 deletions.
19 changes: 10 additions & 9 deletions spinedb_api/import_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""
from collections import defaultdict
from .helpers import _parse_metadata
from .parameter_value import fancy_type_to_type_and_rank, fix_conflict, to_database
from .parameter_value import fancy_type_to_type_and_rank, get_conflict_fixer, to_database


def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs):
Expand Down Expand Up @@ -79,8 +79,9 @@ def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs
"""
all_errors = []
num_imports = 0
conflict_fixer = get_conflict_fixer(on_conflict)
for item_type, items in get_data_for_import(
db_map, all_errors, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs
db_map, all_errors, unparse_value=unparse_value, fix_value_conflict=conflict_fixer, **kwargs
):
added, updated, errors = db_map.add_update_items(item_type, *items, strict=False)
num_imports += len(added + updated)
Expand All @@ -92,7 +93,7 @@ def get_data_for_import(
db_map,
all_errors,
unparse_value=to_database,
on_conflict="merge",
fix_value_conflict=get_conflict_fixer("merge"),
entity_classes=(),
entities=(),
entity_groups=(),
Expand Down Expand Up @@ -136,7 +137,7 @@ def get_data_for_import(
db_map (DatabaseMapping): database mapping
all_errors (list of str): errors encountered during import
unparse_value (Callable): function to call when parsing parameter values
on_conflict (str): Conflict resolution strategy for :func:`~spinedb_api.parameter_value.fix_conflict`
fix_value_conflict (Callable): parameter value conflict resolution function
entity_classes (list(tuple(str,tuple,str,int)): tuples of
(name, dimension name tuple, description, display icon integer)
parameter_definitions (list(tuple(str,str,str,str)):
Expand Down Expand Up @@ -219,22 +220,22 @@ def get_data_for_import(
if parameter_values:
yield (
"parameter_value",
_get_parameter_values_for_import(db_map, parameter_values, all_errors, unparse_value, on_conflict),
_get_parameter_values_for_import(db_map, parameter_values, all_errors, unparse_value, fix_value_conflict),
)
if object_parameter_values: # Legacy
yield from get_data_for_import(
db_map,
all_errors,
unparse_value=unparse_value,
on_conflict=on_conflict,
fix_value_conflict=fix_value_conflict,
parameter_values=object_parameter_values,
)
if relationship_parameter_values: # Legacy
yield from get_data_for_import(
db_map,
all_errors,
unparse_value=unparse_value,
on_conflict=on_conflict,
fix_value_conflict=fix_value_conflict,
parameter_values=relationship_parameter_values,
)
if metadata:
Expand Down Expand Up @@ -627,7 +628,7 @@ def _get_parameter_definitions_for_import(data, unparse_value):
yield dict(zip(key, (class_name, parameter_name, value, type_, *optionals)))


def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, on_conflict):
def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, fix_conflict):
seen = set()
key = ("entity_class_name", "entity_byname", "parameter_definition_name", "alternative_name", "value", "type")
for class_name, entity_byname, parameter_name, value, *optionals in data:
Expand All @@ -648,7 +649,7 @@ def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, on
item = dict(zip(key, unique_values + (None, None)))
pv = db_map.mapped_table("parameter_value").find_item(item)
if pv:
value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict)
value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]))
item.update({"value": value, "type": type_})
yield item

Expand Down
49 changes: 23 additions & 26 deletions spinedb_api/parameter_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,31 +336,6 @@ def from_dict(value):
raise ParameterValueFormatError(f'"{error.args[0]}" is missing in the parameter value description') from error


def fix_conflict(new, old, on_conflict="merge"):
    """Picks the parameter value that should be stored when a new value collides with an existing one.

    :meta private:

    Args:
        new (:class:`ParameterValue`, float, str, bool or None): new parameter value to be written.
        old (:class:`ParameterValue`, float, str, bool or None): an existing parameter value in the db.
        on_conflict (str): conflict resolution strategy:

            - 'merge': Merge indexes if possible, otherwise replace.
            - 'replace': Replace old with new.
            - 'keep': Keep old.

    Returns:
        :class:`ParameterValue`, float, str, bool or None: a new parameter value with conflicts resolved.

    Raises:
        RuntimeError: if ``on_conflict`` is not one of the strategies listed above.
    """
    funcs = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge}
    # Guard clause: reject unknown strategy names before dispatching.
    if on_conflict not in funcs:
        raise RuntimeError(
            f"Invalid conflict resolution strategy {on_conflict}, valid strategies are {', '.join(funcs)}"
        )
    resolve = funcs[on_conflict]
    return resolve(new, old)


def merge(value, other):
"""Merges the DB representation of two parameter values.
Expand All @@ -386,6 +361,28 @@ def merge_parsed(parsed_value, parsed_other):
return parsed_value.merge(parsed_other)


# Conflict resolution strategies by name; each callable takes (new, old) and returns the resolved value.
_MERGE_FUNCTIONS = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge}


def get_conflict_fixer(on_conflict):
    """Returns parameter value conflict resolution function.

    :meta private:

    Args:
        on_conflict (str): resolution action name; must be a key of ``_MERGE_FUNCTIONS``
            ("keep", "replace" or "merge")

    Returns:
        Callable: conflict resolution function that takes the new and the old value, in that order,
        and returns the resolved value

    Raises:
        RuntimeError: if ``on_conflict`` is not a known resolution action name
    """
    try:
        return _MERGE_FUNCTIONS[on_conflict]
    except KeyError:
        # The KeyError itself carries no useful information for the caller,
        # so suppress it and report only the descriptive RuntimeError.
        raise RuntimeError(
            f"Invalid conflict resolution strategy {on_conflict}, valid strategies are {', '.join(_MERGE_FUNCTIONS)}"
        ) from None


def _break_dictionary(data):
"""Converts {"index": value} style dictionary into (list(indexes), numpy.ndarray(values)) tuple."""
if not isinstance(data, dict):
Expand Down Expand Up @@ -1065,7 +1062,7 @@ def merge(self, other):
# Avoid sorting when indices are arbitrary strings
existing = set(self.indexes)
additional = [x for x in other.indexes if x not in existing]
new_indexes = np.concat((self.indexes, additional))
new_indexes = np.concat((additional, self.indexes))

def _merge(value, other):
return other if value is None else merge_parsed(value, other)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_import_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import_scenarios,
)
from spinedb_api.parameter_value import (
Map,
TimePattern,
TimeSeriesFixedResolution,
dump_db_value,
Expand Down Expand Up @@ -999,6 +1000,39 @@ def test_import_same_as_existing_value_on_conflict_merge_time_pattern(self):
self.assertTrue(value_item)
self.assertEqual(value_item["parsed_value"], value)

def test_import_existing_map_on_conflict_merge(self):
    """Importing a Map over an existing Map with on_conflict="merge" keeps old indexes first, then new ones."""
    # Keyword arguments that identify the single parameter value used throughout the test.
    value_key = {
        "entity_class_name": "Object",
        "entity_byname": ("widget",),
        "parameter_definition_name": "X",
        "alternative_name": "Base",
    }
    with DatabaseMapping("sqlite://", create=True) as db_map:
        self._assert_success(db_map.add_entity_class_item(name="Object"))
        self._assert_success(db_map.add_parameter_definition_item(name="X", entity_class_name="Object"))
        self._assert_success(db_map.add_entity_item(name="widget", entity_class_name="Object"))
        # Store the initial map with indexes T1 and T2.
        existing_map = Map(["T1", "T2"], [1.1, 1.2])
        db_value, value_type = to_database(existing_map)
        add_result = db_map.add_parameter_value_item(value=db_value, type=value_type, **value_key)
        self._assert_success(add_result)
        # Import a second map with disjoint indexes T3 and T4 for the same parameter value.
        imported_map = Map(["T3", "T4"], [1.3, 1.4])
        import_result = import_parameter_values(
            db_map, [["Object", "widget", "X", imported_map, "Base"]], on_conflict="merge"
        )
        self._assert_imports(import_result)
        value_item = db_map.get_parameter_value_item(**value_key)
        self.assertTrue(value_item)
        # The existing indexes are expected before the imported ones.
        expected_map = Map(["T1", "T2", "T3", "T4"], [1.1, 1.2, 1.3, 1.4])
        self.assertEqual(value_item["parsed_value"], expected_map)

def test_import_duplicate_object_parameter_value(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
self.populate(db_map)
Expand Down

0 comments on commit 152d1d1

Please sign in to comment.