Skip to content

Commit

Permalink
Add mapping name as mapping source in importer (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
PiispaH authored Jun 27, 2024
2 parents c630744 + 137ca24 commit afbe63a
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 16 deletions.
22 changes: 20 additions & 2 deletions spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_mapped_data(
default_column_convert_fn=None,
row_convert_fns=None,
unparse_value=identity,
mapping_names=[],
):
"""
Args:
Expand All @@ -68,6 +69,7 @@ def get_mapped_data(
default_column_convert_fn (Callable, optional): default convert function for surplus columns
row_convert_fns (dict(int,function), optional): mapping from row number to convert function
unparse_value (Callable): a callable that converts values to database format
mapping_names (list, optional): list of mapping names (order corresponds to order of mappings).
Returns:
dict: Mapped data, ready for ``import_data()``
Expand All @@ -91,10 +93,11 @@ def get_mapped_data(
row_convert_fns = {}
if default_column_convert_fn is None:
default_column_convert_fn = column_convert_fns[max(column_convert_fns)] if column_convert_fns else identity
for mapping in mappings:
_ensure_mapping_name_consistency(mappings, mapping_names)
for mapping, mapping_name in zip(mappings, mapping_names):
read_state = {}
mapping = deepcopy(mapping)
mapping.polish(table_name, data_header, column_count)
mapping.polish(table_name, data_header, mapping_name, column_count)
mapping_errors = check_validity(mapping)
if mapping_errors:
errors += mapping_errors
Expand Down Expand Up @@ -417,3 +420,18 @@ def _apply_index_names(map_, index_names):
for v in map_.values:
if isinstance(v, Map):
_apply_index_names(v, index_names[1:])


def _ensure_mapping_name_consistency(mappings, mapping_names):
"""Makes sure that there are as many mapping names as actual mappings.
Args:
mappings (list(ImportMapping)): list of mappings
mapping_names (list(str)): list of mapping names
"""
n_mappings = len(mappings)
n_mapping_names = len(mapping_names)
if n_mapping_names > n_mappings:
mapping_names = mapping_names[:n_mappings]
elif n_mapping_names < n_mappings:
mapping_names += [""] * (n_mappings - n_mapping_names)
13 changes: 9 additions & 4 deletions spinedb_api/import_mapping/import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,28 +156,29 @@ def check_for_invalid_column_refs(self, header, table_name):
return msg
return ""

def polish(self, table_name, source_header, column_count=0, for_preview=False):
def polish(self, table_name, source_header, mapping_name, column_count=0, for_preview=False):
"""Polishes the mapping before an import operation.
'Expands' transient ``position`` and ``value`` attributes into their final value.
Args:
table_name (str)
source_header (list(str))
mapping_name (str)
column_count (int, optional)
for_preview (bool, optional)
"""
self._polish_for_import(table_name, source_header, column_count)
self._polish_for_import(table_name, source_header, mapping_name, column_count)
if for_preview:
self._polish_for_preview(source_header)

def _polish_for_import(self, table_name, source_header, column_count, pivoted=None):
def _polish_for_import(self, table_name, source_header, mapping_name, column_count, pivoted=None):
# FIXME: Polish skip columns
if pivoted is None:
pivoted = self.is_pivoted()
if pivoted and self.parent and self.is_effective_leaf():
return
if self.child is not None:
self.child._polish_for_import(table_name, source_header, column_count, pivoted)
self.child._polish_for_import(table_name, source_header, mapping_name, column_count, pivoted)
if isinstance(self.position, str):
# Column mapping with string position, we need to find the index in the header
try:
Expand All @@ -190,6 +191,10 @@ def _polish_for_import(self, table_name, source_header, column_count, pivoted=No
# Table name mapping, we set the fixed value to the table name
self.value = table_name
return
if self.position == Position.mapping_name:
# Mapping name mapping, we set the fixed value to the mapping name
self.value = mapping_name
return
if self.position == Position.header:
if self.value is None:
# Row mapping from header, we handle this one separately
Expand Down
1 change: 1 addition & 0 deletions spinedb_api/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Position(Enum):
hidden = "hidden"
table_name = "table_name"
header = "header"
mapping_name = "mapping_name"


def is_pivoted(position):
Expand Down
5 changes: 4 additions & 1 deletion spinedb_api/spine_io/importers/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ def get_mapped_data(
table_max_rows = self._resolve_max_rows(options, max_rows)
data_source, header = self.get_data_iterator(table, options, table_max_rows)
mappings = []
mapping_names = []
for named_mapping_spec in named_mapping_specs:
_, mapping = parse_named_mapping_spec(named_mapping_spec)
name, mapping = parse_named_mapping_spec(named_mapping_spec)
mappings.append(mapping)
mapping_names.append(name)
try:
data, t_errors = get_mapped_data(
data_source,
Expand All @@ -152,6 +154,7 @@ def get_mapped_data(
default_column_convert_fn,
row_convert_fns,
unparse_value,
mapping_names,
)
except (ConnectorError, ParameterValueFormatError, InvalidMappingComponent) as error:
errors.append(str(error))
Expand Down
42 changes: 42 additions & 0 deletions tests/import_mapping/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,48 @@ def test_import_durations(self):
},
)

def test_import_with_one_mapping_name_for_two_mappings(self):
    """Imports alternatives with two mappings while supplying a name for the first mapping only."""
    name_mapping = [{"map_type": "Alternative", "position": "mapping_name"}]
    column_mapping = [{"map_type": "Alternative", "position": 0}]
    rows = iter([["other_name"]])
    converters = {0: value_to_convert_spec("string")}
    mapped_data, errors = get_mapped_data(
        rows,
        [name_mapping, column_mapping],
        column_convert_fns=converters,
        mapping_names=["some_name"],
    )
    self.assertEqual(errors, [])
    expected = {"alternatives": {"some_name", "other_name"}}
    self.assertEqual(mapped_data, expected)

def test_import_with_mapping_name_with_too_many_mapping_names(self):
    """Imports alternatives with two mappings while supplying three mapping names."""
    name_mapping = [{"map_type": "Alternative", "position": "mapping_name"}]
    column_mapping = [{"map_type": "Alternative", "position": 0}]
    rows = iter([["other_name"]])
    converters = {0: value_to_convert_spec("string")}
    surplus_names = ["some_name", "other", "null"]
    mapped_data, errors = get_mapped_data(
        rows,
        [name_mapping, column_mapping],
        column_convert_fns=converters,
        mapping_names=surplus_names,
    )
    self.assertEqual(errors, [])
    expected = {"alternatives": {"some_name", "other_name"}}
    self.assertEqual(mapped_data, expected)


if __name__ == "__main__":
unittest.main()
27 changes: 18 additions & 9 deletions tests/import_mapping/test_import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,39 +121,39 @@ def test_polish_null_mapping(self):
mapping = ImportMapping(Position.hidden, value=None)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.hidden)
self.assertIsNone(mapping.value)

def test_polish_column_mapping(self):
mapping = ImportMapping("B", value=None)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, 1)
self.assertIsNone(mapping.value)

def test_polish_column_header_mapping(self):
mapping = ImportMapping(Position.header, value=2)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.header)
self.assertEqual(mapping.value, "C")

def test_polish_column_header_mapping_str(self):
mapping = ImportMapping(Position.header, value="2")
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.header)
self.assertEqual(mapping.value, "C")

def test_polish_column_header_mapping_duplicates(self):
mapping = ImportMapping(Position.header, value=3)
table_name = "tablename"
header = ["A", "B", "C", "A"]
mapping.polish(table_name, header, for_preview=True)
mapping.polish(table_name, header, "", for_preview=True)
self.assertEqual(mapping.position, Position.header)
self.assertEqual(mapping.value, 3)

Expand All @@ -162,31 +162,40 @@ def test_polish_column_header_mapping_invalid_header(self):
table_name = "tablename"
header = ["A", "B", "C"]
with self.assertRaises(InvalidMapping):
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")

def test_polish_column_header_mapping_invalid_index(self):
mapping = ImportMapping(Position.header, value=4)
table_name = "tablename"
header = ["A", "B", "C"]
with self.assertRaises(InvalidMapping):
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")

def test_polish_table_name_mapping(self):
mapping = ImportMapping(Position.table_name)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.table_name)
self.assertEqual(mapping.value, "tablename")

def test_polish_row_header_mapping(self):
mapping = ImportMapping(Position.header, value=None)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.header)
self.assertIsNone(mapping.value)

def test_polish_mapping_name_mapping(self):
    """Polishing a mapping at the mapping name position sets its value to the given mapping name."""
    name_mapping = ImportMapping(Position.mapping_name)
    name_mapping.polish("tablename", ["A", "B", "C"], "some_mapping_name")
    self.assertEqual(name_mapping.position, Position.mapping_name)
    self.assertEqual(name_mapping.value, "some_mapping_name")


class TestImportMappingIO(unittest.TestCase):
def test_object_class_mapping(self):
Expand Down

0 comments on commit afbe63a

Please sign in to comment.