Skip to content

Commit

Permalink
Improve yaml fault tolerance and handle check_config border cases (#3159
Browse files Browse the repository at this point in the history
)
  • Loading branch information
kellerza authored Sep 8, 2016
1 parent 267cda4 commit e8ad76c
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 59 deletions.
8 changes: 6 additions & 2 deletions homeassistant/scripts/check_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,18 @@ def mock_except(ex, domain, config): # pylint: disable=unused-variable

try:
bootstrap.from_config_file(config_path, skip_pip=True)
res['secret_cache'] = yaml.__SECRET_CACHE
return res
res['secret_cache'] = dict(yaml.__SECRET_CACHE)
except Exception as err: # pylint: disable=broad-except
print(color('red', 'Fatal error while loading config:'), str(err))
finally:
# Stop all patches
for pat in PATCHES.values():
pat.stop()
# Ensure !secrets point to the original function
yaml.yaml.SafeLoader.add_constructor('!secret', yaml._secret_yaml)
bootstrap.clear_secret_cache()

return res


def dump_dict(layer, indent_count=1, listi=False, **kwargs):
Expand Down
10 changes: 10 additions & 0 deletions homeassistant/util/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,16 @@ def _ordered_dict(loader: SafeLineLoader,
line = getattr(node, '__line__', 'unknown')
if line != 'unknown' and (min_line is None or line < min_line):
min_line = line

try:
hash(key)
except TypeError:
fname = getattr(loader.stream, 'name', '')
raise yaml.MarkedYAMLError(
context="invalid key: \"{}\"".format(key),
context_mark=yaml.Mark(fname, 0, min_line, -1, None, None)
)

if key in seen:
fname = getattr(loader.stream, 'name', '')
first_mark = yaml.Mark(fname, 0, seen[key], -1, None, None)
Expand Down
11 changes: 7 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,27 +247,30 @@ def patch_yaml_files(files_dict, endswith=True):
"""Patch load_yaml with a dictionary of yaml files."""
# match using endswith, start search with longest string
matchlist = sorted(list(files_dict.keys()), key=len) if endswith else []
# matchlist.sort(key=len)

def mock_open_f(fname, **_):
"""Mock open() in the yaml module, used by load_yaml."""
# Return the mocked file on full match
if fname in files_dict:
_LOGGER.debug('patch_yaml_files match %s', fname)
return StringIO(files_dict[fname])
res = StringIO(files_dict[fname])
setattr(res, 'name', fname)
return res

# Match using endswith
for ends in matchlist:
if fname.endswith(ends):
_LOGGER.debug('patch_yaml_files end match %s: %s', ends, fname)
return StringIO(files_dict[ends])
res = StringIO(files_dict[ends])
setattr(res, 'name', fname)
return res

# Fallback for hass.components (i.e. services.yaml)
if 'homeassistant/components' in fname:
_LOGGER.debug('patch_yaml_files using real file: %s', fname)
return open(fname, encoding='utf-8')

# Not found
raise IOError('File not found: {}'.format(fname))
raise FileNotFoundError('File not found: {}'.format(fname))

return patch.object(yaml, 'open', mock_open_f, create=True)
9 changes: 6 additions & 3 deletions tests/scripts/test_check_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ def test_secrets(self):
self.maxDiff = None

with patch_yaml_files(files):
res = check_config.check(get_test_config_dir('secret.yaml'))
config_path = get_test_config_dir('secret.yaml')
secrets_path = get_test_config_dir('secrets.yaml')

res = check_config.check(config_path)
change_yaml_files(res)

# convert secrets OrderedDict to dict for assertequal
Expand All @@ -148,7 +151,7 @@ def test_secrets(self):
'components': {'http': {'api_password': 'abc123',
'server_port': 8123}},
'except': {},
'secret_cache': {'secrets.yaml': {'http_pw': 'abc123'}},
'secret_cache': {secrets_path: {'http_pw': 'abc123'}},
'secrets': {'http_pw': 'abc123'},
'yaml_files': ['.../secret.yaml', 'secrets.yaml']
'yaml_files': ['.../secret.yaml', '.../secrets.yaml']
}, res)
111 changes: 61 additions & 50 deletions tests/util/test_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,68 +3,77 @@
import unittest
import os
import tempfile
from unittest.mock import patch

from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import yaml
import homeassistant.config as config_util
from tests.common import get_test_config_dir
from homeassistant.config import YAML_CONFIG_FILE, load_yaml_config_file
from tests.common import get_test_config_dir, patch_yaml_files


class TestYaml(unittest.TestCase):
"""Test util.yaml loader."""
# pylint: disable=no-self-use,invalid-name

def test_simple_list(self):
"""Test simple list."""
conf = "config:\n - simple\n - list"
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert doc['config'] == ["simple", "list"]

def test_simple_dict(self):
"""Test simple dict."""
conf = "key: value"
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert doc['key'] == 'value'

def test_duplicate_key(self):
"""Test simple dict."""
conf = "key: thing1\nkey: thing2"
try:
with io.StringIO(conf) as f:
yaml.yaml.safe_load(f)
except Exception:
pass
else:
assert 0
"""Test duplicate dict keys."""
files = {YAML_CONFIG_FILE: 'key: thing1\nkey: thing2'}
with self.assertRaises(HomeAssistantError):
with patch_yaml_files(files):
load_yaml_config_file(YAML_CONFIG_FILE)

def test_unhashable_key(self):
"""Test an unhasable key."""
files = {YAML_CONFIG_FILE: 'message:\n {{ states.state }}'}
with self.assertRaises(HomeAssistantError), \
patch_yaml_files(files):
load_yaml_config_file(YAML_CONFIG_FILE)

def test_no_key(self):
"""Test item without an key."""
files = {YAML_CONFIG_FILE: 'a: a\nnokeyhere'}
with self.assertRaises(HomeAssistantError), \
patch_yaml_files(files):
yaml.load_yaml(YAML_CONFIG_FILE)

def test_enviroment_variable(self):
"""Test config file with enviroment variable."""
os.environ["PASSWORD"] = "secret_password"
conf = "password: !env_var PASSWORD"
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert doc['password'] == "secret_password"
del os.environ["PASSWORD"]

def test_invalid_enviroment_variable(self):
"""Test config file with no enviroment variable sat."""
conf = "password: !env_var PASSWORD"
try:
with io.StringIO(conf) as f:
yaml.yaml.safe_load(f)
except Exception:
pass
else:
assert 0
with self.assertRaises(HomeAssistantError):
with io.StringIO(conf) as file:
yaml.yaml.safe_load(file)

def test_include_yaml(self):
"""Test include yaml."""
with tempfile.NamedTemporaryFile() as include_file:
include_file.write(b"value")
include_file.seek(0)
conf = "key: !include {}".format(include_file.name)
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert doc["key"] == "value"

def test_include_dir_list(self):
Expand All @@ -79,8 +88,8 @@ def test_include_dir_list(self):
file_2.write(b"two")
file_2.close()
conf = "key: !include_dir_list {}".format(include_dir)
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert sorted(doc["key"]) == sorted(["one", "two"])

def test_include_dir_named(self):
Expand All @@ -98,8 +107,8 @@ def test_include_dir_named(self):
correct = {}
correct[os.path.splitext(os.path.basename(file_1.name))[0]] = "one"
correct[os.path.splitext(os.path.basename(file_2.name))[0]] = "two"
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert doc["key"] == correct

def test_include_dir_merge_list(self):
Expand All @@ -114,8 +123,8 @@ def test_include_dir_merge_list(self):
file_2.write(b"- two\n- three")
file_2.close()
conf = "key: !include_dir_merge_list {}".format(include_dir)
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert sorted(doc["key"]) == sorted(["one", "two", "three"])

def test_include_dir_merge_named(self):
Expand All @@ -130,23 +139,25 @@ def test_include_dir_merge_named(self):
file_2.write(b"key2: two\nkey3: three")
file_2.close()
conf = "key: !include_dir_merge_named {}".format(include_dir)
with io.StringIO(conf) as f:
doc = yaml.yaml.safe_load(f)
with io.StringIO(conf) as file:
doc = yaml.yaml.safe_load(file)
assert doc["key"] == {
"key1": "one",
"key2": "two",
"key3": "three"
}

FILES = {}


def load_yaml(fname, string):
"""Write a string to file and return the parsed yaml."""
with open(fname, 'w') as file:
file.write(string)
return config_util.load_yaml_config_file(fname)
FILES[fname] = string
with patch_yaml_files(FILES):
return load_yaml_config_file(fname)


class FakeKeyring():
class FakeKeyring(): # pylint: disable=too-few-public-methods
"""Fake a keyring class."""

def __init__(self, secrets_dict):
Expand All @@ -162,20 +173,16 @@ def get_password(self, domain, name):

class TestSecrets(unittest.TestCase):
"""Test the secrets parameter in the yaml utility."""
# pylint: disable=protected-access,invalid-name

def setUp(self): # pylint: disable=invalid-name
"""Create & load secrets file."""
config_dir = get_test_config_dir()
yaml.clear_secret_cache()
self._yaml_path = os.path.join(config_dir,
config_util.YAML_CONFIG_FILE)
self._yaml_path = os.path.join(config_dir, YAML_CONFIG_FILE)
self._secret_path = os.path.join(config_dir, yaml._SECRET_YAML)
self._sub_folder_path = os.path.join(config_dir, 'subFolder')
if not os.path.exists(self._sub_folder_path):
os.makedirs(self._sub_folder_path)
self._unrelated_path = os.path.join(config_dir, 'unrelated')
if not os.path.exists(self._unrelated_path):
os.makedirs(self._unrelated_path)

load_yaml(self._secret_path,
'http_pw: pwhttp\n'
Expand All @@ -194,12 +201,7 @@ def setUp(self): # pylint: disable=invalid-name
def tearDown(self): # pylint: disable=invalid-name
"""Clean up secrets."""
yaml.clear_secret_cache()
for path in [self._yaml_path, self._secret_path,
os.path.join(self._sub_folder_path, 'sub.yaml'),
os.path.join(self._sub_folder_path, yaml._SECRET_YAML),
os.path.join(self._unrelated_path, yaml._SECRET_YAML)]:
if os.path.isfile(path):
os.remove(path)
FILES.clear()

def test_secrets_from_yaml(self):
"""Did secrets load ok."""
Expand Down Expand Up @@ -263,3 +265,12 @@ def test_secrets_logger_removed(self):
"""Ensure logger: debug was removed."""
with self.assertRaises(yaml.HomeAssistantError):
load_yaml(self._yaml_path, 'api_password: !secret logger')

@patch('homeassistant.util.yaml._LOGGER.error')
def test_bad_logger_value(self, mock_error):
"""Ensure logger: debug was removed."""
yaml.clear_secret_cache()
load_yaml(self._secret_path, 'logger: info\npw: abc')
load_yaml(self._yaml_path, 'api_password: !secret pw')
assert mock_error.call_count == 1, \
"Expected an error about logger: value"

0 comments on commit e8ad76c

Please sign in to comment.