diff --git a/generic_config_updater/patch_sorter.py b/generic_config_updater/patch_sorter.py index 83ed4a88cb..c374301729 100644 --- a/generic_config_updater/patch_sorter.py +++ b/generic_config_updater/patch_sorter.py @@ -399,10 +399,10 @@ def _get_paths_recursive(self, config, pattern_tokens, matching_tokens, idx, com if token == "*": matching_keys = config.keys() elif token.startswith("*|"): - suffix = token[2:] + suffix = token[1:] matching_keys = [key for key in config.keys() if key.endswith(suffix)] elif token.endswith("|*"): - prefix = token[:-2] + prefix = token[:-1] matching_keys = [key for key in config.keys() if key.startswith(prefix)] elif token in config: matching_keys = [token] diff --git a/tests/generic_config_updater/patch_sorter_test.py b/tests/generic_config_updater/patch_sorter_test.py index 68a6b09a54..2ef18e1fc4 100644 --- a/tests/generic_config_updater/patch_sorter_test.py +++ b/tests/generic_config_updater/patch_sorter_test.py @@ -753,6 +753,41 @@ def test_simulate__applies_move(self): # Assert self.assertIs(self.any_diff, actual) +class TestJsonPointerFilter(unittest.TestCase): + def test_get_paths__common_prefix__exact_match_returned(self): + config = { + "BUFFER_PG": { + "Ethernet1|0": {}, + "Ethernet12|0": {}, + "Ethernet120|0": {}, # 'Ethernet12' is a common prefix with the previous line + }, + } + + filter = ps.JsonPointerFilter([["BUFFER_PG", "Ethernet12|*"]], PathAddressing()) + + expected_paths = ["/BUFFER_PG/Ethernet12|0"] + + actual_paths = list(filter.get_paths(config)) + + self.assertCountEqual(expected_paths, actual_paths) + + def test_get_paths__common_suffix__exact_match_returned(self): + config = { + "QUEUE": { + "Ethernet1|0": {}, + "Ethernet1|10": {}, + "Ethernet1|110": {}, # 10 is a common suffix with the previous line + }, + } + + filter = ps.JsonPointerFilter([["QUEUE", "*|10"]], PathAddressing()) + + expected_paths = ["/QUEUE/Ethernet1|10"] + + actual_paths = list(filter.get_paths(config)) + + self.assertCountEqual(expected_paths, actual_paths) + class TestRequiredValueIdentifier(unittest.TestCase): def test_hard_coded_required_value_data(self): identifier = ps.RequiredValueIdentifier(PathAddressing())