Skip to content

Commit

Permalink
ENGPLAT-131: add agent gpu plumbing
Browse files Browse the repository at this point in the history
GitOrigin-RevId: ea01cd86877c1dc6a4506c6c70a6b4650317a2a2
  • Loading branch information
drew committed Mar 29, 2022
1 parent ac8a74f commit 07be637
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
6 changes: 6 additions & 0 deletions src/gretel_client/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Job:
container_image: str
worker_token: str
max_runtime_seconds: int
instance_type: str
log: Optional[Callable] = None
cloud_creds: Optional[List[CloudCreds]] = None
artifact_endpoint: Optional[str] = None
Expand All @@ -115,6 +116,7 @@ def from_dict(cls, source: dict, agent_config: AgentConfig) -> Job:
return cls(
uid=source["run_id"] or source["model_id"],
job_type=source["job_type"],
instance_type=source["instance_type"],
container_image=source["container_image"],
worker_token=source["worker_token"],
log=agent_config.log_factory(source["run_id"]),
Expand Down Expand Up @@ -150,6 +152,10 @@ def gretel_stage(self) -> str:
return "dev"
return "prod"

@property
def needs_gpu(self) -> bool:
return "gpu" in self.instance_type.lower()


class RateLimiter:
"""Limits the amount of jobs the agent can place."""
Expand Down
6 changes: 6 additions & 0 deletions src/gretel_client/agents/drivers/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from gretel_client.agents.drivers.driver import Driver
from gretel_client.docker import build_container, Container
from gretel_client.projects.docker import DEFAULT_GPU_CONFIG

if TYPE_CHECKING:
from gretel_client.agents.agent import AgentConfig, Job
Expand Down Expand Up @@ -41,11 +42,16 @@ def schedule(self, job: Job) -> Container:
for vol in self._agent_config.volumes:
volumes.append(vol)

device_requests = []
if job.needs_gpu:
device_requests.append(DEFAULT_GPU_CONFIG)

container_run = build_container(
image=job.container_image,
params=job.params,
env=job.env,
volumes=volumes,
device_requests=device_requests,
detach=True,
)
container_run.start()
Expand Down
20 changes: 19 additions & 1 deletion tests/gretel_client/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from gretel_client.agents.agent import Agent, AgentConfig, Job, Poller
from gretel_client.agents.drivers.docker import Docker
from gretel_client.docker import CaCertFile
from gretel_client.projects.docker import DEFAULT_GPU_CONFIG
from gretel_client.rest.apis import JobsApi, ProjectsApi


def get_mock_job() -> dict:
def get_mock_job(instance_type: str = "cpu-standard") -> dict:
return {
"run_id": "run-id",
"job_type": "run",
"container_image": "gretelai/transforms",
"worker_token": "abcdef1243",
"instance_type": instance_type,
}


Expand Down Expand Up @@ -128,3 +130,19 @@ def test_job_with_ca_bundle(docker_client: MagicMock, get_session_config: MagicM
docker.schedule(job)

assert cert.volume in docker_client.mock_calls[0][2]["volumes"]


@patch("gretel_client.agents.agent.get_session_config")
@patch("gretel_client.agents.drivers.docker.build_container")
def test_job_needs_gpu(build_container: MagicMock, get_session_config: MagicMock):
config = AgentConfig(driver="docker")
job = Job.from_dict(get_mock_job(instance_type="gpu-standard"), config)

assert job.needs_gpu

docker = Docker.from_config(config)
docker.schedule(job)

assert build_container.call_args_list[0][1]["device_requests"] == [
DEFAULT_GPU_CONFIG
]

0 comments on commit 07be637

Please sign in to comment.