diff --git a/baseplate/lib/secrets.py b/baseplate/lib/secrets.py index 41b5d04c4..820fb6c3a 100644 --- a/baseplate/lib/secrets.py +++ b/baseplate/lib/secrets.py @@ -3,14 +3,15 @@ import binascii import json import logging +import os from pathlib import Path from typing import Any -from typing import Callable from typing import Dict from typing import Iterator from typing import NamedTuple from typing import Optional +from typing import Protocol from typing import Tuple from baseplate import Span @@ -119,7 +120,9 @@ def _decode_secret(path: str, encoding: str, value: str) -> bytes: raise CorruptSecretError(path, f"unknown encoding: {encoding!r}") -SecretParser = Callable[[Dict[str, Any], str], Dict[str, str]] +class SecretParser(Protocol): + def __call__(self, data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]: + ... def parse_secrets_fetcher(data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]: @@ -382,86 +385,75 @@ def _get_data(self) -> Tuple[Dict, float]: return self._data -class DirectorySecretsStore(SecretsStore): - """Access to secret tokens with automatic refresh when changed. +class VaultCSIEntry(NamedTuple): + mtime: float + data: Any - This local vault allows access to the secrets cached on disk given a path to - a directory. It will automatically reload the cache when it is changed. Do not - cache or store the values returned by this class's methods but rather get - them from this class each time you need them. The secrets are served from - memory so there's little performance impact to doing so and you will be + +class VaultCSISecretsStore(SecretsStore): + """Access to secret tokens using a vault CSI mount with automatic refresh. + + This store allows access to secrets stored in vault through the CSI interface. + It performs caching and will automatically reload secrets from disk. Generally + do not cache or store the values returned by this class's methods, rather get + them from this class each time you need them. 
The secrets are served from memory + when possible, so there's little performance impact in doing so, and you will be sure to always have the current version in the face of key rotation etc. """ + path: Path + data_symlink: Path + cache: Dict[str, VaultCSIEntry] + def __init__( self, path: str, parser: SecretParser, - timeout: Optional[int] = None, - backoff: Optional[float] = None, ): # pylint: disable=super-init-not-called + self.path = Path(path) self.parser = parser - self._filewatchers = {} - root = Path(path) - for p in root.glob("**"): - file_path = str(p.relative_to(path)) - self._filewatchers[file_path] = FileWatcher( - file_path, json.load, timeout=timeout, backoff=backoff + self.cache = {} + self.data_symlink = self.path.joinpath("..data") + if not self.path.is_dir(): + raise ValueError(f"Expected {self.path} to be a directory.") + if not self.data_symlink.is_dir(): + raise ValueError( + f"Expected {self.data_symlink} to be a directory. Verify {self.path} is the root of the Vault CSI mount." ) def get_vault_url(self) -> str: + """Deprecated and will be removed in 3.0.0""" raise NotImplementedError def get_vault_token(self) -> str: + """Deprecated and will be removed in 3.0.0""" raise NotImplementedError - def _get_file_data(self, filename: str) -> Tuple[Any, float]: + def _get_mtime(self) -> float: + """Modification time is store-wide for CSI secrets.""" + return os.path.getmtime(self.data_symlink) + + def _raw_secret(self, name: str) -> Any: try: - return self._filewatchers[filename].get_data_and_mtime() - except KeyError: - raise SecretNotFoundError(filename) - except WatchedFileNotAvailableError as exc: - raise SecretsNotAvailableError(exc) + with open(self.data_symlink.joinpath(name), "r", encoding="UTF-8") as fp: + return self.parser(json.load(fp), name) + except FileNotFoundError as exc: + raise SecretNotFoundError(name) from exc def get_raw_and_mtime(self, secret_path: str) -> Tuple[Dict[str, str], float]: - """Return raw secret and modification time. 
- This returns the same data as :py:meth:`get_raw` as well as a UNIX - epoch timestamp indicating the last time the secrets data was updated. - This modification time can be used to know when to invalidate - downstream caching. - .. versionadded:: 1.5 - """ - data, mtime = self._get_file_data(secret_path) - return self.parser(data, secret_path), mtime - - def make_object_for_context(self, name: str, span: Span) -> "DirectorySecretsStore": - """Return an object that can be added to the context object. - This allows the secret store to be used with - :py:meth:`~baseplate.Baseplate.add_to_context`:: - secrets = SecretsStore("/var/local/secrets.json") - baseplate.add_to_context("secrets", secrets) - """ - return _CachingDirectorySecretsStore(self._filewatchers, self.parser) - - -# pylint: disable=abstract-method -class _CachingDirectorySecretsStore(DirectorySecretsStore): - """Lazily load and cache the parsed data until the server span ends.""" - - def __init__( - self, filewatchers: Dict[str, FileWatcher], parser: SecretParser - ): # pylint: disable=super-init-not-called - self._filewatchers = filewatchers - self.parser = parser - self._cached_data: Dict[str, Tuple[Dict, float]] = {} + mtime = self._get_mtime() + if cache_entry := self.cache.get(secret_path): + if cache_entry.mtime == mtime: + return cache_entry.data, mtime + secret_data = self._raw_secret(secret_path) + self.cache[secret_path] = VaultCSIEntry( + mtime=mtime, + data=secret_data, + ) + return secret_data, mtime - def _get_file_data(self, filename: str) -> Tuple[Dict, float]: - try: - result = self._cached_data[filename] - except KeyError: - result = super()._get_file_data(filename) - self._cached_data[filename] = result - return result + def make_object_for_context(self, name: str, span: Span) -> SecretsStore: + return self def secrets_store_from_config( @@ -487,7 +479,6 @@ def secrets_store_from_config( :param provider: The secrets provider, acceptable values are 'vault' and 'vault_csi'. 
Defaults to 'vault' """ - parser: SecretParser assert prefix.endswith(".") config_prefix = prefix[:-1] @@ -509,8 +500,7 @@ def secrets_store_from_config( backoff = None if options.provider == "vault_csi": - parser = parse_vault_csi - return DirectorySecretsStore(options.path, parser, timeout=timeout, backoff=backoff) + return VaultCSISecretsStore(options.path, parser=parse_vault_csi) return SecretsStore( options.path, timeout=timeout, backoff=backoff, parser=parse_secrets_fetcher diff --git a/tests/unit/lib/secrets/store_directory_tests.py b/tests/unit/lib/secrets/store_directory_tests.py deleted file mode 100644 index d4b5b07b0..000000000 --- a/tests/unit/lib/secrets/store_directory_tests.py +++ /dev/null @@ -1,222 +0,0 @@ -import unittest - -from baseplate.lib.secrets import CorruptSecretError -from baseplate.lib.secrets import CredentialSecret -from baseplate.lib.secrets import DirectorySecretsStore -from baseplate.lib.secrets import parse_vault_csi -from baseplate.lib.secrets import SecretNotFoundError -from baseplate.lib.secrets import secrets_store_from_config -from baseplate.testing.lib.file_watcher import FakeFileWatcher - - -class StoreDirectoryTests(unittest.TestCase): - def setUp(self): - self.fake_filewatcher_1 = FakeFileWatcher() - self.fake_filewatcher_2 = FakeFileWatcher() - self.fake_filewatcher_3 = FakeFileWatcher() - self.store = DirectorySecretsStore("/whatever", parse_vault_csi) - self.store._filewatchers["secret1"] = self.fake_filewatcher_1 - self.store._filewatchers["secret2"] = self.fake_filewatcher_2 - self.store._filewatchers["secret3"] = self.fake_filewatcher_3 - - def test_file_not_found(self): - with self.assertRaises(SecretNotFoundError): - self.store.get_raw("test") - - def test_vault_info(self): - with self.assertRaises(NotImplementedError): - self.store.get_vault_token() - - with self.assertRaises(NotImplementedError): - self.store.get_vault_url() - - def test_raw_secrets(self): - self.fake_filewatcher_1.data = { - "data": 
{"something": "exists"}, - } - - self.assertEqual(self.store.get_raw("secret1"), {"something": "exists"}) - - with self.assertRaises(SecretNotFoundError): - self.store.get_raw("secret0") - - def test_simple_secrets(self): - # simple test - self.fake_filewatcher_1.data = { - "data": {"type": "simple", "value": "easy"}, - } - self.assertEqual(self.store.get_simple("secret1"), b"easy") - - # test base64 - self.fake_filewatcher_2.data = { - "data": {"type": "simple", "value": "aHVudGVyMg==", "encoding": "base64"}, - } - self.assertEqual(self.store.get_simple("secret2"), b"hunter2") - - # test unknown encoding - self.fake_filewatcher_3.data = { - "data": {"type": "simple", "value": "sdlfkj", "encoding": "mystery"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_simple("secret3") - - # test not simple - self.fake_filewatcher_1.data = { - "data": {"something": "else"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_simple("secret1") - - # test no value - self.fake_filewatcher_2.data = { - "data": {"type": "simple"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_simple("secret2") - - # test bad base64 - self.fake_filewatcher_3.data = { - "data": {"type": "simple", "value": "aHVudGVyMg", "encoding": "base64"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_simple("secret3") - - def test_versioned_secrets(self): - # simple test - self.fake_filewatcher_1.data = { - "data": {"type": "versioned", "current": "easy"}, - } - simple = self.store.get_versioned("secret1") - self.assertEqual(simple.current, b"easy") - self.assertEqual(list(simple.all_versions), [b"easy"]) - - # test base64 - self.fake_filewatcher_2.data = { - "data": { - "type": "versioned", - "previous": "aHVudGVyMQ==", - "current": "aHVudGVyMg==", - "next": "aHVudGVyMw==", - "encoding": "base64", - }, - } - encoded = self.store.get_versioned("secret2") - self.assertEqual(encoded.previous, b"hunter1") - self.assertEqual(encoded.current, 
b"hunter2") - self.assertEqual(encoded.next, b"hunter3") - self.assertEqual(list(encoded.all_versions), [b"hunter2", b"hunter1", b"hunter3"]) - - # test unknown encoding - self.fake_filewatcher_3.data = { - "data": {"type": "versioned", "current": "sdlfkj", "encoding": "mystery"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_versioned("secret3") - - # test not versioned - self.fake_filewatcher_1.data = { - "data": {"something": "else"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_versioned("secret1") - - # test no value - self.fake_filewatcher_2.data = { - "data": {"type": "versioned"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_versioned("secret2") - - # test bad base64 - self.fake_filewatcher_3.data = { - "data": {"type": "simple", "value": "aHVudGVyMg", "encoding": "base64"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_versioned("secret3") - - def test_credential_secrets(self): - # simple test - self.fake_filewatcher_1.data = { - "data": {"type": "credential", "username": "user", "password": "password"}, - } - self.assertEqual( - self.store.get_credentials("secret1"), CredentialSecret("user", "password") - ) - - # test identiy - self.fake_filewatcher_2.data = { - "data": { - "type": "credential", - "username": "spez", - "password": "hunter2", - "encoding": "identity", - }, - } - self.assertEqual(self.store.get_credentials("secret2"), CredentialSecret("spez", "hunter2")) - - # test base64 - self.fake_filewatcher_2.data = { - "data": { - "type": "credential", - "username": "foo", - "password": "aHVudGVyMg==", - "encoding": "base64", - }, - } - with self.assertRaises(CorruptSecretError): - self.store.get_credentials("secret2") - - # test unknkown encoding - self.fake_filewatcher_3.data = { - "data": { - "type": "credential", - "username": "fizz", - "password": "buzz", - "encoding": "something", - }, - } - with self.assertRaises(CorruptSecretError): - 
self.store.get_credentials("secret3") - - # test not credentials - self.fake_filewatcher_1.data = { - "data": {"type": "versioned", "current": "easy"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_credentials("secret1") - - # test no values - self.fake_filewatcher_2.data = { - "data": {"type": "credential"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_credentials("secret2") - - # test no username - self.fake_filewatcher_3.data = { - "data": {"type": "credential", "password": "password"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_credentials("secret3") - - # test no password - self.fake_filewatcher_1.data = { - "data": {"type": "credential", "username": "user"}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_credentials("secret1") - - # test bad type - self.fake_filewatcher_2.data = { - "data": {"type": "credential", "username": "user", "password": 100}, - } - with self.assertRaises(CorruptSecretError): - self.store.get_credentials("secret2") - - -class StoreFromConfigTests(unittest.TestCase): - def test_make_store(self): - secrets = secrets_store_from_config( - {"secrets.path": "/tmp", "secrets.provider": "vault_csi"} - ) - self.assertIsInstance(secrets, DirectorySecretsStore) diff --git a/tests/unit/lib/secrets/vault_csi_tests.py b/tests/unit/lib/secrets/vault_csi_tests.py new file mode 100644 index 000000000..a55880022 --- /dev/null +++ b/tests/unit/lib/secrets/vault_csi_tests.py @@ -0,0 +1,267 @@ +import datetime +import json +import shutil +import string +import tempfile +import typing +import unittest + +from pathlib import Path +from unittest.mock import mock_open +from unittest.mock import patch + +import gevent +import pytest +import typing_extensions + +from baseplate.lib.secrets import SecretNotFoundError +from baseplate.lib.secrets import secrets_store_from_config +from baseplate.lib.secrets import SecretsStore +from baseplate.lib.secrets import VaultCSISecretsStore 
+ +SecretType: typing_extensions.TypeAlias = typing.Dict[str, typing.Any] + + +def write_secrets(secrets_data_path: Path, data: typing.Dict[str, SecretType]) -> None: + """Write secrets to the current data directory.""" + for key, value in data.items(): + secret_path = secrets_data_path.joinpath(key) + secret_path.parent.mkdir(parents=True, exist_ok=True) + secret_path.write_text(json.dumps(value)) + + +def write_symlinks(data_path: Path) -> None: + csi_path = data_path.parent + # This path can be monitored for changes + # https://github.com/kubernetes-sigs/secrets-store-csi-driver/blob/c697863c35d5431ec048b440d36550eb3ceb338f/pkg/util/fileutil/atomic_writer.go#L60-L62 + data_link = Path(csi_path, "..data") + # Simulate atomic update + new_data_link = Path(csi_path, "..data-new") + new_data_link.symlink_to(data_path) + new_data_link.rename(data_link) + human_path = Path(csi_path, "secret") + if not human_path.exists(): + human_path.symlink_to(csi_path.joinpath("..data/secret")) + + +def new_fake_csi(data: typing.Dict[str, SecretType]) -> Path: + """Creates a simulated CSI directory with data and symlinks. 
+ Note that this would already be configured before the pod starts.""" + csi_dir = Path(tempfile.mkdtemp()) + # Closely resembles but doesn't precisely match the actual CSI plugin + data_path = Path(csi_dir, f'..{datetime.datetime.today().strftime("%Y_%m_%d_%H_%M_%S.%f")}') + write_secrets(data_path, data) + write_symlinks(data_path) + return csi_dir + + +def simulate_secret_update( + csi_dir: Path, updated_data: typing.Optional[typing.Dict[str, SecretType]] = None +) -> None: + """Simulates either TTL expiry / a secret update.""" + old_data_path = csi_dir.joinpath("..data").resolve() + # Clone the data directory + new_data_path = Path(csi_dir, f'..{datetime.datetime.today().strftime("%Y_%m_%d_%H_%M_%S.%f")}') + # Update the secret + if updated_data: + write_secrets(new_data_path, updated_data) + else: + shutil.copytree(old_data_path, new_data_path) + write_symlinks(new_data_path) + shutil.rmtree(old_data_path) + + +def get_secrets_store(csi_dir: str) -> SecretsStore: + store = secrets_store_from_config({"secrets.path": csi_dir, "secrets.provider": "vault_csi"}) + assert isinstance(store, VaultCSISecretsStore) + return store + + +EXAMPLE_SECRETS_DATA = { + "secret/example-service/example-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": {"password": "password", "type": "credential", "username": "reddit"}, + "warnings": None, + }, + "secret/example-service/nested/example-nested-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": {"password": "password", "type": "credential", "username": "reddit"}, + "warnings": None, + }, + "secret/bare-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": {"password": "password", "type": "credential", "username": "reddit"}, + "warnings": None, + }, + 
"secret/simple-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": {"type": "simple", "value": "simply a secret"}, + "warnings": None, + }, + "secret/simple-encoded-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": {"type": "simple", "encoding": "base64", "value": "MTMzNw=="}, + "warnings": None, + }, + "secret/versioned-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": {"type": "versioned", "current": "current value", "previous": "previous value"}, + "warnings": None, + }, + "secret/versioned-encoded-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": { + "type": "versioned", + "encoding": "base64", + "current": "Y3VycmVudCBlbmNvZGVkIHZhbHVl", + "previous": "cHJldmlvdXMgZW5jb2RlZCB2YWx1ZQ==", + }, + "warnings": None, + }, +} + +EXAMPLE_UPDATED_SECRETS = EXAMPLE_SECRETS_DATA.copy() +EXAMPLE_UPDATED_SECRETS.update( + { + "secret/example-service/example-secret": { + "request_id": "8487d906-2154-0151-d07e-57f41447326a", + "lease_id": "", + "lease_duration": 2764800, + "renewable": False, + "data": { + "password": "new_password", + "type": "credential", + "username": "new_reddit", + }, + "warnings": None, + }, + } +) + + +class StoreTests(unittest.TestCase): + def setUp(self): + self.csi_dir = new_fake_csi(EXAMPLE_SECRETS_DATA) + + def tearDown(self): + shutil.rmtree(self.csi_dir) + + def test_can_load_credential_secret(self): + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_credentials("secret/example-service/example-secret") + assert data.username == "reddit" + assert data.password == "password" + + def test_get_raw_secret(self): + secrets_store = 
get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_raw("secret/example-service/nested/example-nested-secret") + assert data == {"password": "password", "type": "credential", "username": "reddit"} + + def test_get_simple_secret(self): + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_simple("secret/simple-secret") + assert data == b"simply a secret" + + def test_get_versioned_secret(self): + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_versioned("secret/versioned-secret") + assert data.current == b"current value" + assert data.previous == b"previous value" + + def test_simple_encoding(self): + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_simple("secret/simple-encoded-secret") + assert data == b"1337" + + def test_versioned_encoding(self): + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_versioned("secret/versioned-encoded-secret") + assert data.current == b"current encoded value" + assert data.previous == b"previous encoded value" + + def test_symlink_updated(self): + original_data_path = self.csi_dir.joinpath("..data").resolve() + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_credentials("secret/bare-secret") + assert data.username == "reddit" + assert data.password == "password" + simulate_secret_update(self.csi_dir) + assert original_data_path != self.csi_dir.joinpath("..data").resolve() + data = secrets_store.get_credentials("secret/bare-secret") + assert data.username == "reddit" + assert data.password == "password" + + def test_secret_updated(self): + secrets_store = get_secrets_store(str(self.csi_dir)) + data = secrets_store.get_credentials("secret/example-service/example-secret") + gevent.sleep(0.1) # prevent gevent shenanigans + + assert data.username == "reddit" + assert data.password == "password" + + printable_ascii = list(string.printable)[:60] + for i in range(0, 
len(printable_ascii), 6): + chars = printable_ascii[i : i + 6] + expected_username = "".join(chars[:3]) + expected_password = "".join(chars[3:]) + new_secrets = json.loads(json.dumps(EXAMPLE_UPDATED_SECRETS)) + new_secrets["secret/example-service/example-secret"]["data"][ + "username" + ] = expected_username + new_secrets["secret/example-service/example-secret"]["data"][ + "password" + ] = expected_password + simulate_secret_update( + self.csi_dir, + updated_data=new_secrets, + ) + data = secrets_store.get_credentials("secret/example-service/example-secret") + assert data.username == expected_username, f"{data.username} != {expected_username}" + assert data.password == expected_password, f"{data.password} != {expected_password}" + gevent.sleep(0.075) # prevent gevent shenanigans + + def test_cache_works(self): + self.csi_dir.joinpath("..data").resolve() + secrets_store = get_secrets_store(str(self.csi_dir)) + with patch( + "builtins.open", + mock_open( + read_data=json.dumps(EXAMPLE_SECRETS_DATA["secret/example-service/example-secret"]) + ), + ) as mock_open_: + secrets_store.get_credentials("secret/example-service/example-secret") + secrets_store.get_credentials("secret/example-service/example-secret") + secrets_store.get_credentials("secret/example-service/example-secret") + secrets_store.get_credentials("secret/example-service/example-secret") + gevent.sleep(0.1) # prevent gevent shenanigans + assert mock_open_.call_count == 1 + + def test_invalid_secret_raises(self): + self.csi_dir.joinpath("..data").resolve() + secrets_store = get_secrets_store(str(self.csi_dir)) + with pytest.raises(SecretNotFoundError): + secrets_store.get_credentials("secret/example-service/does-not-exist") + # While cache is updating we should still fail + with pytest.raises(SecretNotFoundError): + secrets_store.get_credentials("secret/example-service/does-not-exist")