Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vault CSI support v2 #828

Merged
merged 17 commits into from
Mar 14, 2024
Merged
116 changes: 53 additions & 63 deletions baseplate/lib/secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
import binascii
import json
import logging
import os

from pathlib import Path
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import NamedTuple
from typing import Optional
from typing import Protocol
from typing import Tuple

from baseplate import Span
Expand Down Expand Up @@ -119,7 +120,9 @@ def _decode_secret(path: str, encoding: str, value: str) -> bytes:
raise CorruptSecretError(path, f"unknown encoding: {encoding!r}")


SecretParser = Callable[[Dict[str, Any], str], Dict[str, str]]
Copy link
Contributor Author

@pnovotnak pnovotnak Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second parameter has a default value, which can't be expressed with `Callable`, so `SecretParser` is now defined as a `Protocol` with a `__call__` signature:

class SecretParser(Protocol):
def __call__(self, data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]:
...


def parse_secrets_fetcher(data: Dict[str, Any], secret_path: str = "") -> Dict[str, str]:
Expand Down Expand Up @@ -382,86 +385,75 @@ def _get_data(self) -> Tuple[Dict, float]:
return self._data


class DirectorySecretsStore(SecretsStore):
"""Access to secret tokens with automatic refresh when changed.
class VaultCSIEntry(NamedTuple):
mtime: float
data: Any

This local vault allows access to the secrets cached on disk given a path to
a directory. It will automatically reload the cache when it is changed. Do not
cache or store the values returned by this class's methods but rather get
them from this class each time you need them. The secrets are served from
memory so there's little performance impact to doing so and you will be

class VaultCSISecretsStore(SecretsStore):
"""Access to secret tokens using a vault CSI mount with automatic refresh.

This store allows access to secrets stored in vault through the CSI interface.
It performs caching and will automatically reload secrets from disk. Generally
do not cache or store the values returned by this class's methods, rather get
them from this class each time you need them. The secrets are served from memory
when possible, so there's little performance impact in doing so, and you will be
sure to always have the current version in the face of key rotation etc.
"""

path: Path
data_symlink: Path
cache: Dict[str, VaultCSIEntry]

def __init__(
self,
path: str,
parser: SecretParser,
timeout: Optional[int] = None,
backoff: Optional[float] = None,
): # pylint: disable=super-init-not-called
self.path = Path(path)
self.parser = parser
self._filewatchers = {}
root = Path(path)
for p in root.glob("**"):
file_path = str(p.relative_to(path))
self._filewatchers[file_path] = FileWatcher(
file_path, json.load, timeout=timeout, backoff=backoff
self.cache = {}
self.data_symlink = self.path.joinpath("..data")
if not self.path.is_dir():
raise ValueError(f"Expected {self.path} to be a directory.")
if not self.data_symlink.is_dir():
raise ValueError(
f"Expected {self.data_symlink} to be a directory. Verify {self.path} is the root of the Vault CSI mount."
)

def get_vault_url(self) -> str:
"""Deprecated and will be removed in 3.0.0"""
raise NotImplementedError

def get_vault_token(self) -> str:
"""Deprecated and will be removed in 3.0.0"""
raise NotImplementedError

def _get_file_data(self, filename: str) -> Tuple[Any, float]:
def _get_mtime(self) -> float:
"""Modification time is store-wide for CSI secrets."""
return os.path.getmtime(self.data_symlink)
pnovotnak marked this conversation as resolved.
Show resolved Hide resolved

def _raw_secret(self, name: str) -> Any:
try:
return self._filewatchers[filename].get_data_and_mtime()
except KeyError:
raise SecretNotFoundError(filename)
except WatchedFileNotAvailableError as exc:
raise SecretsNotAvailableError(exc)
with open(self.data_symlink.joinpath(name), "r", encoding="UTF-8") as fp:
return self.parser(json.load(fp))
except FileNotFoundError as exc:
raise SecretNotFoundError(name) from exc

def get_raw_and_mtime(self, secret_path: str) -> Tuple[Dict[str, str], float]:
"""Return raw secret and modification time.
This returns the same data as :py:meth:`get_raw` as well as a UNIX
epoch timestamp indicating the last time the secrets data was updated.
This modification time can be used to know when to invalidate
downstream caching.
.. versionadded:: 1.5
"""
data, mtime = self._get_file_data(secret_path)
return self.parser(data, secret_path), mtime

def make_object_for_context(self, name: str, span: Span) -> "DirectorySecretsStore":
"""Return an object that can be added to the context object.
This allows the secret store to be used with
:py:meth:`~baseplate.Baseplate.add_to_context`::
secrets = SecretsStore("/var/local/secrets.json")
baseplate.add_to_context("secrets", secrets)
"""
return _CachingDirectorySecretsStore(self._filewatchers, self.parser)


# pylint: disable=abstract-method
class _CachingDirectorySecretsStore(DirectorySecretsStore):
"""Lazily load and cache the parsed data until the server span ends."""

def __init__(
self, filewatchers: Dict[str, FileWatcher], parser: SecretParser
): # pylint: disable=super-init-not-called
self._filewatchers = filewatchers
self.parser = parser
self._cached_data: Dict[str, Tuple[Dict, float]] = {}
mtime = self._get_mtime()
if cache_entry := self.cache.get(secret_path):
if cache_entry.mtime == mtime:
return cache_entry.data, mtime
secret_data = self._raw_secret(secret_path)
self.cache[secret_path] = VaultCSIEntry(
mtime=mtime,
data=secret_data,
)
return secret_data, mtime

def _get_file_data(self, filename: str) -> Tuple[Dict, float]:
try:
result = self._cached_data[filename]
except KeyError:
result = super()._get_file_data(filename)
self._cached_data[filename] = result
return result
def make_object_for_context(self, name: str, span: Span) -> SecretsStore:
return self


def secrets_store_from_config(
Expand All @@ -487,7 +479,6 @@ def secrets_store_from_config(
:param provider: The secrets provider, acceptable values are 'vault' and 'vault_csi'. Defaults to 'vault'

"""
parser: SecretParser
assert prefix.endswith(".")
config_prefix = prefix[:-1]

Expand All @@ -509,8 +500,7 @@ def secrets_store_from_config(
backoff = None

if options.provider == "vault_csi":
parser = parse_vault_csi
return DirectorySecretsStore(options.path, parser, timeout=timeout, backoff=backoff)
return VaultCSISecretsStore(options.path, parser=parse_vault_csi)

return SecretsStore(
options.path, timeout=timeout, backoff=backoff, parser=parse_secrets_fetcher
Expand Down
Loading
Loading