Skip to content

Commit

Permalink
Support GCP secrets (#1571)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored Mar 31, 2023
1 parent 6c52297 commit 9658b02
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 23 deletions.
27 changes: 14 additions & 13 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,13 @@ def __getattr__(self, item: str) -> _GroupSecrets:
"""
return self._GroupSecrets(item, self)

def get(self, group: str, key: str) -> str:
def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
"""
Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError
"""
self.check_group_key(group, key)
env_var = self.get_secrets_env_var(group, key)
fpath = self.get_secrets_file(group, key)
self.check_group_key(group)
env_var = self.get_secrets_env_var(group, key, group_version)
fpath = self.get_secrets_file(group, key, group_version)
v = os.environ.get(env_var)
if v is not None:
return v
Expand All @@ -346,26 +346,27 @@ def get(self, group: str, key: str) -> str:
f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}"
)

def get_secrets_env_var(self, group: str, key: str) -> str:
def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
"""
Returns a string that matches the ENV Variable to look for the secrets
"""
self.check_group_key(group, key)
return f"{self._env_prefix}{group.upper()}_{key.upper()}"
self.check_group_key(group)
l = [k.upper() for k in filter(None, (group, group_version, key))]
return f"{self._env_prefix}{'_'.join(l)}"

def get_secrets_file(self, group: str, key: str) -> str:
def get_secrets_file(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str:
"""
Returns a path that matches the file to look for the secrets
"""
self.check_group_key(group, key)
return os.path.join(self._base_dir, group.lower(), f"{self._file_prefix}{key.lower()}")
self.check_group_key(group)
l = [k.lower() for k in filter(None, (group, group_version, key))]
l[-1] = f"{self._file_prefix}{l[-1]}"
return os.path.join(self._base_dir, *l)

@staticmethod
def check_group_key(group: str, key: str):
def check_group_key(group: str):
if group is None or group == "":
raise ValueError("secrets group is a mandatory field.")
if key is None or key == "":
raise ValueError("secrets key is a mandatory field.")


@dataclass(frozen=True)
Expand Down
6 changes: 2 additions & 4 deletions flytekit/models/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ class MountType(Enum):
"""

group: str
key: str
key: Optional[str] = None
group_version: Optional[str] = None
mount_requirement: MountType = MountType.ANY

def __post_init__(self):
if self.group is None:
raise ValueError("Group is a required parameter")
if self.key is None:
raise ValueError("Key is also a required parameter")

def to_flyte_idl(self) -> _sec.Secret:
return _sec.Secret(
Expand All @@ -59,7 +57,7 @@ def from_flyte_idl(cls, pb2_object: _sec.Secret) -> "Secret":
return cls(
group=pb2_object.group,
group_version=pb2_object.group_version if pb2_object.group_version else None,
key=pb2_object.key,
key=pb2_object.key if pb2_object.key else None,
mount_requirement=Secret.MountType(pb2_object.mount_requirement),
)

Expand Down
15 changes: 9 additions & 6 deletions tests/flytekit/unit/core/test_context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,17 @@ def test_secrets_manager_default():

def test_secrets_manager_get_envvar():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_env_var("test", "")
with pytest.raises(ValueError):
sec.get_secrets_env_var("", "x")
cfg = SecretsConfig.auto()
assert sec.get_secrets_env_var("group", "test") == f"{cfg.env_prefix}GROUP_TEST"
assert sec.get_secrets_env_var("group", "test", "v1") == f"{cfg.env_prefix}GROUP_V1_TEST"
assert sec.get_secrets_env_var("group", group_version="v1") == f"{cfg.env_prefix}GROUP_V1"
assert sec.get_secrets_env_var("group") == f"{cfg.env_prefix}GROUP"


def test_secrets_manager_get_file():
sec = SecretsManager()
with pytest.raises(ValueError):
sec.get_secrets_file("test", "")
with pytest.raises(ValueError):
sec.get_secrets_file("", "x")
cfg = SecretsConfig.auto()
Expand All @@ -135,6 +134,12 @@ def test_secrets_manager_get_file():
"group",
f"{cfg.file_prefix}test",
)
assert sec.get_secrets_file("group", "test", "v1") == os.path.join(
cfg.default_dir,
"group",
"v1",
f"{cfg.file_prefix}test",
)


def test_secrets_manager_file(tmpdir: py.path.local):
Expand All @@ -145,8 +150,6 @@ def test_secrets_manager_file(tmpdir: py.path.local):
with open(f, "w+") as w:
w.write("my-password")

with pytest.raises(ValueError):
sec.get("test", "")
with pytest.raises(ValueError):
sec.get("", "x")
# Group dir not exists
Expand Down
13 changes: 13 additions & 0 deletions tests/flytekit/unit/models/core/test_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from flytekit.models.security import Secret


def test_secret():
obj = Secret("grp", "key")
obj2 = Secret.from_flyte_idl(obj.to_flyte_idl())
assert obj2.key == "key"
assert obj2.group_version is None

obj = Secret("grp", group_version="v1")
obj2 = Secret.from_flyte_idl(obj.to_flyte_idl())
assert obj2.key is None
assert obj2.group_version == "v1"

0 comments on commit 9658b02

Please sign in to comment.