Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interpolate env in registry_auth #1540

Merged
merged 1 commit into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ResourceNotExistsError,
ServerClientError,
)
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
AnyRunConfiguration,
ApplyConfigurationType,
Expand All @@ -33,7 +34,7 @@
)
from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack.api._public.runs import Run
from dstack.api.utils import load_profile

Expand Down Expand Up @@ -262,14 +263,29 @@ def apply_args(self, conf: BaseRunConfiguration, args: argparse.Namespace, unkno
conf.resources.disk = args.disk_spec

self.apply_env_vars(conf.env, args)

self.interpolate_env(conf)
self.interpolate_run_args(conf.setup, unknown)

def interpolate_run_args(self, value: List[str], unknown):
run_args = " ".join(unknown)
interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"])
for i in range(len(value)):
value[i] = interpolator.interpolate(value[i])
try:
for i in range(len(value)):
value[i] = interpolator.interpolate_or_error(value[i])
except InterpolatorError as e:
raise ConfigurationError(e.args[0])

def interpolate_env(self, conf: BaseRunConfiguration):
env_dict = conf.env.as_dict()
interpolator = VariablesInterpolator({"env": env_dict}, skip=["secrets"])
try:
if conf.registry_auth is not None:
conf.registry_auth = RegistryAuth(
username=interpolator.interpolate_or_error(conf.registry_auth.username),
password=interpolator.interpolate_or_error(conf.registry_auth.password),
)
except InterpolatorError as e:
raise ConfigurationError(e.args[0])


class RunWithPortsConfigurator(BaseRunConfigurator):
Expand Down
14 changes: 12 additions & 2 deletions src/dstack/_internal/utils/interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class Name:
char = first_char | set(string.digits + ".")


class InterpolatorError(ValueError):
pass


class VariablesInterpolator:
def __init__(
self, namespaces: Dict[str, Dict[str, str]], *, skip: Optional[Iterable[str]] = None
Expand Down Expand Up @@ -42,11 +46,11 @@ def interpolate(
tokens.append(s[start:opening])
closing = s.find(Pattern.closing, opening)
if closing == -1:
raise ValueError(f"No pattern closing: {s[opening:]}")
raise InterpolatorError(f"No pattern closing: {s[opening:]}")

name = s[opening + len(Pattern.opening) : closing].strip()
if not self.validate_name(name):
raise ValueError(f"Illegal reference name: {name}")
raise InterpolatorError(f"Illegal reference name: {name}")
if name.split(".")[0] in self.skip:
tokens.append(s[opening : closing + len(Pattern.closing)])
elif name in self.variables:
Expand All @@ -57,6 +61,12 @@ def interpolate(
s = "".join(tokens)
return (s, missing) if return_missing else s

def interpolate_or_error(self, s: str) -> str:
res, missing = self.interpolate(s, return_missing=True)
if len(missing) == 0:
return res
raise InterpolatorError(f"Failed to interpolate due to missing vars: {missing}")

@staticmethod
def validate_name(s: str) -> bool:
if s.count(".") != 1 or not (0 < s.index(".") < len(s) - 1):
Expand Down
21 changes: 21 additions & 0 deletions src/tests/_internal/cli/services/configurators/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dstack._internal.cli.services.configurators import get_run_configurator_class
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
BaseRunConfiguration,
PortMapping,
Expand Down Expand Up @@ -61,6 +62,26 @@ def test_any_port(self):
conf.ports = [PortMapping(local_port=None, container_port=8000)]
assert modified.dict() == conf.dict()

def test_interpolates_env(self):
conf = TaskConfiguration(
image="my_image",
registry_auth=RegistryAuth(
username="${{ env.REGISTRY_USERNAME }}",
password="${{ env.REGISTRY_PASSWORD }}",
),
env=Env.parse_obj(
{
"REGISTRY_USERNAME": "test_user",
"REGISTRY_PASSWORD": "test_password",
}
),
)
modified, args = apply_args(conf, [])
assert modified.registry_auth == RegistryAuth(
username="test_user",
password="test_password",
)


def apply_args(
conf: BaseRunConfiguration, args: List[str]
Expand Down
14 changes: 7 additions & 7 deletions src/tests/_internal/utils/test_interpolator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator


def get_interpolator():
Expand Down Expand Up @@ -35,17 +35,17 @@ def test_missing(self):
assert ["env.name"] == missing

def test_unclosed_pattern(self):
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.password }")

def test_illegal_name(self):
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.pass-word }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ .password }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ password. }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.password.hash }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.007 }}")
Loading