diff --git a/src/_nebari/stages/terraform_state/__init__.py b/src/_nebari/stages/terraform_state/__init__.py index d9afff36e..4d93775fa 100644 --- a/src/_nebari/stages/terraform_state/__init__.py +++ b/src/_nebari/stages/terraform_state/__init__.py @@ -1,6 +1,5 @@ import contextlib import enum -import functools import inspect import os import pathlib @@ -266,9 +265,15 @@ def check_immutable_fields(self): for keys, old, new in nebari_config_diff.modified(): bottom_level_schema = self.config if len(keys) > 1: - bottom_level_schema = functools.reduce( - lambda m, k: getattr(m, k), keys[:-1], self.config - ) + for key in keys[:-1]: + try: + bottom_level_schema = getattr(bottom_level_schema, key) + except AttributeError as e: + if isinstance(bottom_level_schema, dict): + # handle case where value is a dict + bottom_level_schema = bottom_level_schema[key] + else: + raise e extra_field_schema = schema.ExtraFieldSchema( **bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {} ) diff --git a/tests/tests_unit/test_stages.py b/tests/tests_unit/test_stages.py index e0e254a7d..8c0facf8c 100644 --- a/tests/tests_unit/test_stages.py +++ b/tests/tests_unit/test_stages.py @@ -39,7 +39,7 @@ def test_check_immutable_fields_no_changes(mock_get_state, terraform_state_stage def test_check_immutable_fields_mutable_change( mock_get_state, terraform_state_stage, mock_config ): - old_config = mock_config.model_copy() + old_config = mock_config.model_copy(deep=True) old_config.namespace = "old-namespace" mock_get_state.return_value = old_config @@ -52,7 +52,7 @@ def test_check_immutable_fields_mutable_change( def test_check_immutable_fields_immutable_change( mock_model_fields, mock_get_state, terraform_state_stage, mock_config ): - old_config = mock_config.model_copy() + old_config = mock_config.model_copy(deep=True) old_config.provider = schema.ProviderEnum.gcp mock_get_state.return_value = old_config @@ -71,3 +71,23 @@ def test_check_immutable_fields_no_prior_state(mock_get_state, terraform_state_s # This should not raise an exception terraform_state_stage.check_immutable_fields() + + +@patch.object(TerraformStateStage, "get_nebari_config_state") +def test_check_dict_value_change(mock_get_state, terraform_state_stage, mock_config): + old_config = mock_config.model_copy(deep=True) + terraform_state_stage.config.local.node_selectors["worker"].value += "new_value" + mock_get_state.return_value = old_config + + # should not throw an exception + terraform_state_stage.check_immutable_fields() + + +@patch.object(TerraformStateStage, "get_nebari_config_state") +def test_check_list_change(mock_get_state, terraform_state_stage, mock_config): + old_config = mock_config.model_copy(deep=True) + old_config.environments["environment-dask.yaml"].channels.append("defaults") + mock_get_state.return_value = old_config + + # should not throw an exception + terraform_state_stage.check_immutable_fields()