Skip to content

Commit

Permalink
add test for changing dicts and lists (#2724)
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam-D-Lewis committed Sep 17, 2024
1 parent 671f542 commit a800a5b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 6 deletions.
13 changes: 9 additions & 4 deletions src/_nebari/stages/terraform_state/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import contextlib
import enum
import functools
import inspect
import os
import pathlib
Expand Down Expand Up @@ -266,9 +265,15 @@ def check_immutable_fields(self):
for keys, old, new in nebari_config_diff.modified():
bottom_level_schema = self.config
if len(keys) > 1:
bottom_level_schema = functools.reduce(
lambda m, k: getattr(m, k), keys[:-1], self.config
)
for key in keys[:-1]:
try:
bottom_level_schema = getattr(bottom_level_schema, key)
except AttributeError as e:
if isinstance(bottom_level_schema, dict):
# handle case where value is a dict
bottom_level_schema = bottom_level_schema[key]
else:
raise e
extra_field_schema = schema.ExtraFieldSchema(
**bottom_level_schema.model_fields[keys[-1]].json_schema_extra or {}
)
Expand Down
24 changes: 22 additions & 2 deletions tests/tests_unit/test_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_check_immutable_fields_no_changes(mock_get_state, terraform_state_stage
def test_check_immutable_fields_mutable_change(
mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy()
old_config = mock_config.model_copy(deep=True)
old_config.namespace = "old-namespace"
mock_get_state.return_value = old_config

Expand All @@ -52,7 +52,7 @@ def test_check_immutable_fields_mutable_change(
def test_check_immutable_fields_immutable_change(
mock_model_fields, mock_get_state, terraform_state_stage, mock_config
):
old_config = mock_config.model_copy()
old_config = mock_config.model_copy(deep=True)
old_config.provider = schema.ProviderEnum.gcp
mock_get_state.return_value = old_config

Expand All @@ -71,3 +71,23 @@ def test_check_immutable_fields_no_prior_state(mock_get_state, terraform_state_s

# This should not raise an exception
terraform_state_stage.check_immutable_fields()


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_dict_value_change(mock_get_state, terraform_state_stage, mock_config):
old_config = mock_config.model_copy(deep=True)
terraform_state_stage.config.local.node_selectors["worker"].value += "new_value"
mock_get_state.return_value = old_config

# should not throw an exception
terraform_state_stage.check_immutable_fields()


@patch.object(TerraformStateStage, "get_nebari_config_state")
def test_check_list_change(mock_get_state, terraform_state_stage, mock_config):
old_config = mock_config.model_copy(deep=True)
old_config.environments["environment-dask.yaml"].channels.append("defaults")
mock_get_state.return_value = old_config

# should not throw an exception
terraform_state_stage.check_immutable_fields()

0 comments on commit a800a5b

Please sign in to comment.