Skip to content

Commit

Permalink
Fix index ordering when merging values on import (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
soininen authored Oct 30, 2024
2 parents d076802 + 152d1d1 commit 4677999
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 35 deletions.
19 changes: 10 additions & 9 deletions spinedb_api/import_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""
from collections import defaultdict
from .helpers import _parse_metadata
from .parameter_value import fancy_type_to_type_and_rank, fix_conflict, to_database
from .parameter_value import fancy_type_to_type_and_rank, get_conflict_fixer, to_database


def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs):
Expand Down Expand Up @@ -79,8 +79,9 @@ def import_data(db_map, unparse_value=to_database, on_conflict="merge", **kwargs
"""
all_errors = []
num_imports = 0
conflict_fixer = get_conflict_fixer(on_conflict)
for item_type, items in get_data_for_import(
db_map, all_errors, unparse_value=unparse_value, on_conflict=on_conflict, **kwargs
db_map, all_errors, unparse_value=unparse_value, fix_value_conflict=conflict_fixer, **kwargs
):
added, updated, errors = db_map.add_update_items(item_type, *items, strict=False)
num_imports += len(added + updated)
Expand All @@ -92,7 +93,7 @@ def get_data_for_import(
db_map,
all_errors,
unparse_value=to_database,
on_conflict="merge",
fix_value_conflict=get_conflict_fixer("merge"),
entity_classes=(),
entities=(),
entity_groups=(),
Expand Down Expand Up @@ -136,7 +137,7 @@ def get_data_for_import(
db_map (DatabaseMapping): database mapping
all_errors (list of str): errors encountered during import
unparse_value (Callable): function to call when parsing parameter values
on_conflict (str): Conflict resolution strategy for :func:`~spinedb_api.parameter_value.fix_conflict`
fix_value_conflict (Callable): parameter value conflict resolution function
entity_classes (list(tuple(str,tuple,str,int)): tuples of
(name, dimension name tuple, description, display icon integer)
parameter_definitions (list(tuple(str,str,str,str)):
Expand Down Expand Up @@ -219,22 +220,22 @@ def get_data_for_import(
if parameter_values:
yield (
"parameter_value",
_get_parameter_values_for_import(db_map, parameter_values, all_errors, unparse_value, on_conflict),
_get_parameter_values_for_import(db_map, parameter_values, all_errors, unparse_value, fix_value_conflict),
)
if object_parameter_values: # Legacy
yield from get_data_for_import(
db_map,
all_errors,
unparse_value=unparse_value,
on_conflict=on_conflict,
fix_value_conflict=fix_value_conflict,
parameter_values=object_parameter_values,
)
if relationship_parameter_values: # Legacy
yield from get_data_for_import(
db_map,
all_errors,
unparse_value=unparse_value,
on_conflict=on_conflict,
fix_value_conflict=fix_value_conflict,
parameter_values=relationship_parameter_values,
)
if metadata:
Expand Down Expand Up @@ -627,7 +628,7 @@ def _get_parameter_definitions_for_import(data, unparse_value):
yield dict(zip(key, (class_name, parameter_name, value, type_, *optionals)))


def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, on_conflict):
def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, fix_conflict):
seen = set()
key = ("entity_class_name", "entity_byname", "parameter_definition_name", "alternative_name", "value", "type")
for class_name, entity_byname, parameter_name, value, *optionals in data:
Expand All @@ -648,7 +649,7 @@ def _get_parameter_values_for_import(db_map, data, all_errors, unparse_value, on
item = dict(zip(key, unique_values + (None, None)))
pv = db_map.mapped_table("parameter_value").find_item(item)
if pv:
value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]), on_conflict)
value, type_ = fix_conflict((value, type_), (pv["value"], pv["type"]))
item.update({"value": value, "type": type_})
yield item

Expand Down
49 changes: 23 additions & 26 deletions spinedb_api/parameter_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,31 +336,6 @@ def from_dict(value):
raise ParameterValueFormatError(f'"{error.args[0]}" is missing in the parameter value description') from error


def fix_conflict(new, old, on_conflict="merge"):
"""Resolves conflicts between parameter values:
:meta private:
Args:
new (:class:`ParameterValue`, float, str, bool or None): new parameter value to be written.
old (:class:`ParameterValue`, float, str, bool or None): an existing parameter value in the db.
on_conflict (str): conflict resolution strategy:
- 'merge': Merge indexes if possible, otherwise replace.
- 'replace': Replace old with new.
- 'keep': Keep old.
Returns:
:class:`ParameterValue`, float, str, bool or None: a new parameter value with conflicts resolved.
"""
funcs = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge}
func = funcs.get(on_conflict)
if func is None:
raise RuntimeError(
f"Invalid conflict resolution strategy {on_conflict}, valid strategies are {', '.join(funcs)}"
)
return func(new, old)


def merge(value, other):
"""Merges the DB representation of two parameter values.
Expand All @@ -386,6 +361,28 @@ def merge_parsed(parsed_value, parsed_other):
return parsed_value.merge(parsed_other)


_MERGE_FUNCTIONS = {"keep": lambda new, old: old, "replace": lambda new, old: new, "merge": merge}


def get_conflict_fixer(on_conflict):
"""
:meta private:
Returns parameter value conflict resolution function.
Args:
on_conflict (str): resolution action name
Returns:
Callable: conflict resolution function
"""
try:
return _MERGE_FUNCTIONS[on_conflict]
except KeyError:
raise RuntimeError(
f"Invalid conflict resolution strategy {on_conflict}, valid strategies are {', '.join(_MERGE_FUNCTIONS)}"
)


def _break_dictionary(data):
"""Converts {"index": value} style dictionary into (list(indexes), numpy.ndarray(values)) tuple."""
if not isinstance(data, dict):
Expand Down Expand Up @@ -1065,7 +1062,7 @@ def merge(self, other):
# Avoid sorting when indices are arbitrary strings
existing = set(self.indexes)
additional = [x for x in other.indexes if x not in existing]
new_indexes = np.concat((self.indexes, additional))
new_indexes = np.concat((additional, self.indexes))

def _merge(value, other):
return other if value is None else merge_parsed(value, other)
Expand Down
34 changes: 34 additions & 0 deletions tests/test_import_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import_scenarios,
)
from spinedb_api.parameter_value import (
Map,
TimePattern,
TimeSeriesFixedResolution,
dump_db_value,
Expand Down Expand Up @@ -999,6 +1000,39 @@ def test_import_same_as_existing_value_on_conflict_merge_time_pattern(self):
self.assertTrue(value_item)
self.assertEqual(value_item["parsed_value"], value)

def test_import_existing_map_on_conflict_merge(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
self._assert_success(db_map.add_entity_class_item(name="Object"))
self._assert_success(db_map.add_parameter_definition_item(name="X", entity_class_name="Object"))
self._assert_success(db_map.add_entity_item(name="widget", entity_class_name="Object"))
value = Map(["T1", "T2"], [1.1, 1.2])
db_value, value_type = to_database(value)
self._assert_success(
db_map.add_parameter_value_item(
entity_class_name="Object",
entity_byname=("widget",),
parameter_definition_name="X",
alternative_name="Base",
value=db_value,
type=value_type,
)
)
extended_value = Map(["T3", "T4"], [1.3, 1.4])
self._assert_imports(
import_parameter_values(
db_map, [["Object", "widget", "X", extended_value, "Base"]], on_conflict="merge"
)
)
value_item = db_map.get_parameter_value_item(
entity_class_name="Object",
entity_byname=("widget",),
parameter_definition_name="X",
alternative_name="Base",
)
self.assertTrue(value_item)
merged_value = Map(["T1", "T2", "T3", "T4"], [1.1, 1.2, 1.3, 1.4])
self.assertEqual(value_item["parsed_value"], merged_value)

def test_import_duplicate_object_parameter_value(self):
with DatabaseMapping("sqlite://", create=True) as db_map:
self.populate(db_map)
Expand Down

0 comments on commit 4677999

Please sign in to comment.