Skip to content

Commit

Permalink
Interpolate env in registry_auth (#1540)
Browse files Browse the repository at this point in the history
  • Loading branch information
r4victor authored Aug 12, 2024
1 parent 89350b1 commit 4976572
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 13 deletions.
24 changes: 20 additions & 4 deletions src/dstack/_internal/cli/services/configurators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ResourceNotExistsError,
ServerClientError,
)
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
AnyRunConfiguration,
ApplyConfigurationType,
Expand All @@ -33,7 +34,7 @@
)
from dstack._internal.core.models.runs import JobSubmission, JobTerminationReason, RunStatus
from dstack._internal.core.services.configs import ConfigManager
from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator
from dstack.api._public.runs import Run
from dstack.api.utils import load_profile

Expand Down Expand Up @@ -262,14 +263,29 @@ def apply_args(self, conf: BaseRunConfiguration, args: argparse.Namespace, unkno
conf.resources.disk = args.disk_spec

self.apply_env_vars(conf.env, args)

self.interpolate_env(conf)
self.interpolate_run_args(conf.setup, unknown)

def interpolate_run_args(self, value: List[str], unknown):
run_args = " ".join(unknown)
interpolator = VariablesInterpolator({"run": {"args": run_args}}, skip=["secrets"])
for i in range(len(value)):
value[i] = interpolator.interpolate(value[i])
try:
for i in range(len(value)):
value[i] = interpolator.interpolate_or_error(value[i])
except InterpolatorError as e:
raise ConfigurationError(e.args[0])

def interpolate_env(self, conf: BaseRunConfiguration):
env_dict = conf.env.as_dict()
interpolator = VariablesInterpolator({"env": env_dict}, skip=["secrets"])
try:
if conf.registry_auth is not None:
conf.registry_auth = RegistryAuth(
username=interpolator.interpolate_or_error(conf.registry_auth.username),
password=interpolator.interpolate_or_error(conf.registry_auth.password),
)
except InterpolatorError as e:
raise ConfigurationError(e.args[0])


class RunWithPortsConfigurator(BaseRunConfigurator):
Expand Down
14 changes: 12 additions & 2 deletions src/dstack/_internal/utils/interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ class Name:
char = first_char | set(string.digits + ".")


class InterpolatorError(ValueError):
pass


class VariablesInterpolator:
def __init__(
self, namespaces: Dict[str, Dict[str, str]], *, skip: Optional[Iterable[str]] = None
Expand Down Expand Up @@ -42,11 +46,11 @@ def interpolate(
tokens.append(s[start:opening])
closing = s.find(Pattern.closing, opening)
if closing == -1:
raise ValueError(f"No pattern closing: {s[opening:]}")
raise InterpolatorError(f"No pattern closing: {s[opening:]}")

name = s[opening + len(Pattern.opening) : closing].strip()
if not self.validate_name(name):
raise ValueError(f"Illegal reference name: {name}")
raise InterpolatorError(f"Illegal reference name: {name}")
if name.split(".")[0] in self.skip:
tokens.append(s[opening : closing + len(Pattern.closing)])
elif name in self.variables:
Expand All @@ -57,6 +61,12 @@ def interpolate(
s = "".join(tokens)
return (s, missing) if return_missing else s

def interpolate_or_error(self, s: str) -> str:
res, missing = self.interpolate(s, return_missing=True)
if len(missing) == 0:
return res
raise InterpolatorError(f"Failed to interpolate due to missing vars: {missing}")

@staticmethod
def validate_name(s: str) -> bool:
if s.count(".") != 1 or not (0 < s.index(".") < len(s) - 1):
Expand Down
21 changes: 21 additions & 0 deletions src/tests/_internal/cli/services/configurators/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dstack._internal.cli.services.configurators import get_run_configurator_class
from dstack._internal.core.errors import ConfigurationError
from dstack._internal.core.models.common import RegistryAuth
from dstack._internal.core.models.configurations import (
BaseRunConfiguration,
PortMapping,
Expand Down Expand Up @@ -61,6 +62,26 @@ def test_any_port(self):
conf.ports = [PortMapping(local_port=None, container_port=8000)]
assert modified.dict() == conf.dict()

def test_interpolates_env(self):
conf = TaskConfiguration(
image="my_image",
registry_auth=RegistryAuth(
username="${{ env.REGISTRY_USERNAME }}",
password="${{ env.REGISTRY_PASSWORD }}",
),
env=Env.parse_obj(
{
"REGISTRY_USERNAME": "test_user",
"REGISTRY_PASSWORD": "test_password",
}
),
)
modified, args = apply_args(conf, [])
assert modified.registry_auth == RegistryAuth(
username="test_user",
password="test_password",
)


def apply_args(
conf: BaseRunConfiguration, args: List[str]
Expand Down
14 changes: 7 additions & 7 deletions src/tests/_internal/utils/test_interpolator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from dstack._internal.utils.interpolator import VariablesInterpolator
from dstack._internal.utils.interpolator import InterpolatorError, VariablesInterpolator


def get_interpolator():
Expand Down Expand Up @@ -35,17 +35,17 @@ def test_missing(self):
assert ["env.name"] == missing

def test_unclosed_pattern(self):
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.password }")

def test_illegal_name(self):
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.pass-word }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ .password }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ password. }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.password.hash }}")
with pytest.raises(ValueError):
with pytest.raises(InterpolatorError):
get_interpolator().interpolate("${{ secrets.007 }}")

0 comments on commit 4976572

Please sign in to comment.