Skip to content

Commit

Permalink
ENGPLAT-128: Agent can configure worker env and volume mounts
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 21982a9020e103286cc5e3ab6376325c0fbe8485
  • Loading branch information
drew committed Mar 26, 2022
1 parent 072fd03 commit 7045740
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 13 deletions.
21 changes: 16 additions & 5 deletions src/gretel_client/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
import threading

from dataclasses import asdict, dataclass
from typing import Callable, Dict, Generic, Iterator, Optional
from typing import Callable, Dict, Generic, Iterator, List, Optional

from backports.cached_property import cached_property

from gretel_client.agents.drivers.driver import ComputeUnit, Driver
from gretel_client.agents.drivers.registry import get_driver
from gretel_client.agents.logger import configure_logging
from gretel_client.config import configure_custom_logger, get_session_config
from gretel_client.docker import CloudCreds
from gretel_client.docker import CloudCreds, DataVolumeDef
from gretel_client.projects import get_project
from gretel_client.rest.apis import JobsApi, ProjectsApi, UsersApi

Expand Down Expand Up @@ -46,12 +46,18 @@ class AgentConfig:
jobs from all projects will be fetched.
"""

creds: Optional[CloudCreds] = None
creds: Optional[List[CloudCreds]] = None
"""Provide credentials to propagate to the worker"""

artifact_endpoint: Optional[str] = None
"""Configure an artifact endpoint for workers to store intermediate data on."""

volumes: Optional[List[DataVolumeDef]] = None
"""A list of volumes to mount into the worker container"""

env_vars: Optional[dict] = None
"""A list of environment variables to mount into the container"""

_max_runtime_seconds: Optional[int] = None
"""TODO: implement"""

Expand Down Expand Up @@ -100,8 +106,9 @@ class Job:
worker_token: str
max_runtime_seconds: int
log: Optional[Callable] = None
cloud_creds: Optional[CloudCreds] = None
cloud_creds: Optional[List[CloudCreds]] = None
artifact_endpoint: Optional[str] = None
env_vars: Optional[dict] = None

@classmethod
def from_dict(cls, source: dict, agent_config: AgentConfig) -> Job:
Expand All @@ -114,6 +121,7 @@ def from_dict(cls, source: dict, agent_config: AgentConfig) -> Job:
max_runtime_seconds=agent_config.max_runtime_seconds,
cloud_creds=agent_config.creds,
artifact_endpoint=agent_config.artifact_endpoint,
env_vars=agent_config.env_vars,
)

@property
Expand All @@ -130,7 +138,10 @@ def env(self) -> Dict[str, str]:
"GRETEL_STAGE": self.gretel_stage,
}
if self.cloud_creds:
params.update(self.cloud_creds.env)
for cred in self.cloud_creds:
params.update(cred.env)
if self.env_vars:
params.update(self.env_vars)
return params

@property
Expand Down
13 changes: 10 additions & 3 deletions src/gretel_client/agents/drivers/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,24 @@ class Docker(Driver):
VMs where a docker daemon is running.
"""

def __init__(self):
def __init__(self, agent_config: AgentConfig):
self._docker_client = docker.from_env()
self._agent_config = agent_config

@classmethod
def from_config(cls, config: AgentConfig) -> Docker:
return cls()
return cls(config)

def schedule(self, job: Job) -> Container:
volumes = []
if job.cloud_creds:
volumes.append(job.cloud_creds.volume)
for cred in job.cloud_creds:
volumes.append(cred.volume)

if self._agent_config.volumes:
for vol in self._agent_config.volumes:
volumes.append(vol)

container_run = build_container(
image=job.container_image,
params=job.params,
Expand Down
49 changes: 45 additions & 4 deletions src/gretel_client/cli/agent.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging

from typing import Callable
from pathlib import Path
from typing import Callable, List, Optional

import click

from gretel_client.agents.agent import AgentConfig, get_agent
from gretel_client.cli.common import pass_session, project_option, SessionContext
from gretel_client.config import get_session_config
from gretel_client.docker import AwsCredFile
from gretel_client.docker import AwsCredFile, CaCertFile, DataVolumeDef


@click.group(
Expand Down Expand Up @@ -46,6 +47,23 @@ def build_logger(job_id: str) -> Callable:
default=None,
envvar="GRETEL_ARTIFACT_ENDPOINT",
)
@click.option(
"--env",
metavar="KEY=VALUE",
help="Pass environment variables into the worker container",
multiple=True,
)
@click.option(
"--volume",
metavar="HOST:CONTAINER",
help="Mount single file into the worker container. HOST and CONTAINER must be files",
multiple=True,
)
@click.option(
"--ca-bundle",
metavar="PATH",
help="Mount custom CA into each worker container",
)
@pass_session
def start(
sc: SessionContext,
Expand All @@ -54,16 +72,39 @@ def start(
project: str = None,
aws_cred_path: str = None,
artifact_endpoint: str = None,
env: List[str] = None,
volume: List[str] = None,
ca_bundle: Optional[str] = None,
):
sc.log.info(f"Starting Gretel agent using driver {driver}")
aws_creds = AwsCredFile(cred_from_agent=aws_cred_path) if aws_cred_path else None
creds = []

if aws_cred_path:
creds.append(AwsCredFile(cred_from_agent=aws_cred_path))

if ca_bundle:
creds.append(CaCertFile(cred_from_agent=ca_bundle))

volumes = []
if volume:
for vol in volume:
host_path, target = vol.split(":", maxsplit=1)
target_path = Path(target)
volumes.append(
DataVolumeDef(str(target_path.parent), [(host_path, target_path.name)])
)

env_dict = dict(e.split("=", maxsplit=1) for e in env) if env else None

config = AgentConfig(
project=project,
max_workers=max_workers,
driver=driver,
creds=aws_creds,
creds=creds,
log_factory=build_logger,
artifact_endpoint=artifact_endpoint,
env_vars=env_dict,
volumes=volumes,
)
agent = get_agent(config)
sc.register_cleanup(lambda: agent.interupt())
Expand Down
24 changes: 24 additions & 0 deletions src/gretel_client/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,13 @@ class AuthStrategy(Enum):
class DataVolumeDef:

target_dir: str
"""Defines container directory to place host file"""

host_files: List[Tuple[str, Optional[str]]]
"""Specify what files to place into the container. The
first item in the tuple is the host file and the second
item optionally renames the host file on the container.
"""


class DataVolume:
Expand Down Expand Up @@ -477,6 +483,24 @@ def env(self) -> Dict[str, str]:
}


class CaCertFile(CloudCreds):

base_dir: str = "/etc/ssl"
credential_file = "agent_ca.crt"

@property
def volume(self) -> DataVolumeDef:
return DataVolumeDef(
self.base_dir, [(self.cred_from_agent, self.credential_file)]
)

@property
def env(self) -> Dict[str, str]:
return {
"REQUESTS_CA_BUNDLE": f"{self.base_dir}/{self.credential_file}",
}


def extract_container_path(
container: docker.models.containers.Container, container_path: str, host_path: str
):
Expand Down
22 changes: 21 additions & 1 deletion tests/gretel_client/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

import pytest

from gretel_client.agents.agent import Agent, AgentConfig, Poller
from gretel_client.agents.agent import Agent, AgentConfig, Job, Poller
from gretel_client.agents.drivers.docker import Docker
from gretel_client.docker import CaCertFile
from gretel_client.rest.apis import JobsApi, ProjectsApi


Expand Down Expand Up @@ -108,3 +110,21 @@ def test_agent_job_poller(agent_config: AgentConfig):
assert job.job_type == job_data["job_type"]
assert job.container_image == job_data["container_image"]
assert job.worker_token == job_data["worker_token"]


@patch("gretel_client.agents.agent.get_session_config")
@patch("gretel_client.agents.drivers.docker.build_container")
def test_job_with_ca_bundle(docker_client: MagicMock, get_session_config: MagicMock):
cert = CaCertFile(cred_from_agent="/bundle/path")
config = AgentConfig(
driver="docker",
creds=[cert],
)
job = Job.from_dict(get_mock_job(), config)

assert job.env["REQUESTS_CA_BUNDLE"] == "/etc/ssl/agent_ca.crt"

docker = Docker.from_config(config)
docker.schedule(job)

assert cert.volume in docker_client.mock_calls[0][2]["volumes"]

0 comments on commit 7045740

Please sign in to comment.