Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mapping name as mapping source in importer #423

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions spinedb_api/import_mapping/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def get_mapped_data(
default_column_convert_fn=None,
row_convert_fns=None,
unparse_value=identity,
mapping_names=[],
):
"""
Args:
Expand All @@ -68,6 +69,7 @@ def get_mapped_data(
default_column_convert_fn (Callable, optional): default convert function for surplus columns
row_convert_fns (dict(int,function), optional): mapping from row number to convert function
unparse_value (Callable): a callable that converts values to database format
mapping_names (list, optional): list of mapping names (order corresponds to order of mappings).

Returns:
dict: Mapped data, ready for ``import_data()``
Expand All @@ -91,10 +93,11 @@ def get_mapped_data(
row_convert_fns = {}
if default_column_convert_fn is None:
default_column_convert_fn = column_convert_fns[max(column_convert_fns)] if column_convert_fns else identity
for mapping in mappings:
_ensure_mapping_name_consistency(mappings, mapping_names)
for mapping, mapping_name in zip(mappings, mapping_names):
read_state = {}
mapping = deepcopy(mapping)
mapping.polish(table_name, data_header, column_count)
mapping.polish(table_name, data_header, mapping_name, column_count)
mapping_errors = check_validity(mapping)
if mapping_errors:
errors += mapping_errors
Expand Down Expand Up @@ -417,3 +420,18 @@ def _apply_index_names(map_, index_names):
for v in map_.values:
if isinstance(v, Map):
_apply_index_names(v, index_names[1:])


def _ensure_mapping_name_consistency(mappings, mapping_names):
"""Makes sure that there are as many mapping names as actual mappings.

Args:
mappings (list(ImportMapping)): list of mappings
mapping_names (list(str)): list of mapping names
"""
n_mappings = len(mappings)
n_mapping_names = len(mapping_names)
if n_mapping_names > n_mappings:
mapping_names = mapping_names[:n_mappings]
elif n_mapping_names < n_mappings:
mapping_names += [""] * (n_mappings - n_mapping_names)
13 changes: 9 additions & 4 deletions spinedb_api/import_mapping/import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,28 +156,29 @@ def check_for_invalid_column_refs(self, header, table_name):
return msg
return ""

def polish(self, table_name, source_header, column_count=0, for_preview=False):
def polish(self, table_name, source_header, mapping_name, column_count=0, for_preview=False):
"""Polishes the mapping before an import operation.
'Expands' transient ``position`` and ``value`` attributes into their final value.

Args:
table_name (str)
source_header (list(str))
mapping_name (str)
column_count (int, optional)
for_preview (bool, optional)
"""
self._polish_for_import(table_name, source_header, column_count)
self._polish_for_import(table_name, source_header, mapping_name, column_count)
if for_preview:
self._polish_for_preview(source_header)

def _polish_for_import(self, table_name, source_header, column_count, pivoted=None):
def _polish_for_import(self, table_name, source_header, mapping_name, column_count, pivoted=None):
# FIXME: Polish skip columns
if pivoted is None:
pivoted = self.is_pivoted()
if pivoted and self.parent and self.is_effective_leaf():
return
if self.child is not None:
self.child._polish_for_import(table_name, source_header, column_count, pivoted)
self.child._polish_for_import(table_name, source_header, mapping_name, column_count, pivoted)
if isinstance(self.position, str):
# Column mapping with string position, we need to find the index in the header
try:
Expand All @@ -190,6 +191,10 @@ def _polish_for_import(self, table_name, source_header, column_count, pivoted=No
# Table name mapping, we set the fixed value to the table name
self.value = table_name
return
if self.position == Position.mapping_name:
# Mapping name mapping, we set the fixed value to the mapping name
self.value = mapping_name
return
if self.position == Position.header:
if self.value is None:
# Row mapping from header, we handle this one separately
Expand Down
1 change: 1 addition & 0 deletions spinedb_api/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Position(Enum):
hidden = "hidden"
table_name = "table_name"
header = "header"
mapping_name = "mapping_name"


def is_pivoted(position):
Expand Down
5 changes: 4 additions & 1 deletion spinedb_api/spine_io/importers/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ def get_mapped_data(
table_max_rows = self._resolve_max_rows(options, max_rows)
data_source, header = self.get_data_iterator(table, options, table_max_rows)
mappings = []
mapping_names = []
for named_mapping_spec in named_mapping_specs:
_, mapping = parse_named_mapping_spec(named_mapping_spec)
name, mapping = parse_named_mapping_spec(named_mapping_spec)
mappings.append(mapping)
mapping_names.append(name)
try:
data, t_errors = get_mapped_data(
data_source,
Expand All @@ -152,6 +154,7 @@ def get_mapped_data(
default_column_convert_fn,
row_convert_fns,
unparse_value,
mapping_names,
)
except (ConnectorError, ParameterValueFormatError, InvalidMappingComponent) as error:
errors.append(str(error))
Expand Down
42 changes: 42 additions & 0 deletions tests/import_mapping/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,48 @@ def test_import_durations(self):
},
)

def test_import_with_one_mapping_name_for_two_mappings(self):
    # Two mappings but only one mapping name is supplied: the first mapping
    # takes its alternative from the mapping name, the second from column 0.
    rows = iter([["other_name"]])
    name_mapping = [{"map_type": "Alternative", "position": "mapping_name"}]
    column_mapping = [{"map_type": "Alternative", "position": 0}]
    converters = {0: value_to_convert_spec("string")}
    mapped_data, errors = get_mapped_data(
        rows,
        [name_mapping, column_mapping],
        column_convert_fns=converters,
        mapping_names=["some_name"],
    )
    self.assertEqual(errors, [])
    expected = {"alternatives": {"some_name", "other_name"}}
    self.assertEqual(mapped_data, expected)

def test_import_with_mapping_name_with_too_many_mapping_names(self):
    # Three mapping names are supplied for two mappings; the import is
    # expected to succeed without errors and ignore the surplus name.
    rows = iter([["other_name"]])
    name_mapping = [{"map_type": "Alternative", "position": "mapping_name"}]
    column_mapping = [{"map_type": "Alternative", "position": 0}]
    converters = {0: value_to_convert_spec("string")}
    mapped_data, errors = get_mapped_data(
        rows,
        [name_mapping, column_mapping],
        column_convert_fns=converters,
        mapping_names=["some_name", "other", "null"],
    )
    self.assertEqual(errors, [])
    expected = {"alternatives": {"some_name", "other_name"}}
    self.assertEqual(mapped_data, expected)


if __name__ == "__main__":
unittest.main()
27 changes: 18 additions & 9 deletions tests/import_mapping/test_import_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,39 +121,39 @@ def test_polish_null_mapping(self):
mapping = ImportMapping(Position.hidden, value=None)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.hidden)
self.assertIsNone(mapping.value)

def test_polish_column_mapping(self):
mapping = ImportMapping("B", value=None)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, 1)
self.assertIsNone(mapping.value)

def test_polish_column_header_mapping(self):
mapping = ImportMapping(Position.header, value=2)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.header)
self.assertEqual(mapping.value, "C")

def test_polish_column_header_mapping_str(self):
mapping = ImportMapping(Position.header, value="2")
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.header)
self.assertEqual(mapping.value, "C")

def test_polish_column_header_mapping_duplicates(self):
mapping = ImportMapping(Position.header, value=3)
table_name = "tablename"
header = ["A", "B", "C", "A"]
mapping.polish(table_name, header, for_preview=True)
mapping.polish(table_name, header, "", for_preview=True)
self.assertEqual(mapping.position, Position.header)
self.assertEqual(mapping.value, 3)

Expand All @@ -162,31 +162,40 @@ def test_polish_column_header_mapping_invalid_header(self):
table_name = "tablename"
header = ["A", "B", "C"]
with self.assertRaises(InvalidMapping):
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")

def test_polish_column_header_mapping_invalid_index(self):
mapping = ImportMapping(Position.header, value=4)
table_name = "tablename"
header = ["A", "B", "C"]
with self.assertRaises(InvalidMapping):
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")

def test_polish_table_name_mapping(self):
mapping = ImportMapping(Position.table_name)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.table_name)
self.assertEqual(mapping.value, "tablename")

def test_polish_row_header_mapping(self):
mapping = ImportMapping(Position.header, value=None)
table_name = "tablename"
header = ["A", "B", "C"]
mapping.polish(table_name, header)
mapping.polish(table_name, header, "")
self.assertEqual(mapping.position, Position.header)
self.assertIsNone(mapping.value)

def test_polish_mapping_name_mapping(self):
    # Polishing a mapping positioned at ``Position.mapping_name`` must keep
    # the position and set the value to the mapping name that was passed in.
    polished = ImportMapping(Position.mapping_name)
    polished.polish("tablename", ["A", "B", "C"], "some_mapping_name")
    self.assertEqual(polished.position, Position.mapping_name)
    self.assertEqual(polished.value, "some_mapping_name")


class TestImportMappingIO(unittest.TestCase):
def test_object_class_mapping(self):
Expand Down