Commit
GitOrigin-RevId: 0e5f8ed5c5ec1dee95866830ada2b9e66d741544
Showing 27 changed files with 1,753 additions and 349 deletions.
@@ -0,0 +1,83 @@
# gretel_client.rest.JobsApi

All URIs are relative to *https://api-dev.gretel.cloud*

Method | HTTP request | Description
------------- | ------------- | -------------
[**receive_one**](JobsApi.md#receive_one) | **POST** /jobs/receive_one | Get Gretel job for scheduling

# **receive_one**
> {str: (bool, date, datetime, dict, float, int, list, str, none_type)} receive_one()

Get Gretel job for scheduling

### Example

* Api Key Authentication (ApiKey):
```python
import time
import gretel_client.rest
from gretel_client.rest.api import jobs_api
from pprint import pprint

# Defining the host is optional and defaults to https://api-dev.gretel.cloud
# See configuration.py for a list of all supported configuration parameters.
configuration = gretel_client.rest.Configuration(
    host="https://api-dev.gretel.cloud"
)

# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Examples for each auth method are provided below; use the example that
# satisfies your auth use case.

# Configure API key authorization: ApiKey
configuration.api_key['ApiKey'] = 'YOUR_API_KEY'

# Uncomment below to set up a prefix (e.g. Bearer) for the API key, if needed
# configuration.api_key_prefix['ApiKey'] = 'Bearer'

# Enter a context with an instance of the API client
with gretel_client.rest.ApiClient(configuration) as api_client:
    # Create an instance of the API class
    api_instance = jobs_api.JobsApi(api_client)
    project_id = "project_id_example"  # str | (optional)

    # example passing only required values which don't have defaults set
    # and optional values
    try:
        # Get Gretel job for scheduling
        api_response = api_instance.receive_one(project_id=project_id)
        pprint(api_response)
    except gretel_client.rest.ApiException as e:
        print("Exception when calling JobsApi->receive_one: %s\n" % e)
```

### Parameters

Name | Type | Description | Notes
------------- | ------------- | ------------- | -------------
**project_id** | **str** |  | [optional]

### Return type

**{str: (bool, date, datetime, dict, float, int, list, str, none_type)}**

### Authorization

[ApiKey](../README.md#ApiKey)

### HTTP request headers

- **Content-Type**: Not defined
- **Accept**: application/json

### HTTP response details

| Status code | Description | Response headers |
|-------------|-------------|------------------|
| **200** | Job to schedule | - |

[[Back to top]](#) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to Model list]](../README.md#documentation-for-models) [[Back to README]](../README.md)
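For orientation, the agent module added later in this commit reads the scheduled job out of a `data.job` envelope in the `receive_one` response. The sketch below, continuing the example above, shows that usage; the envelope shape and field names (`run_id`, `job_type`, `container_image`) are assumptions taken from `Poller.poll_endpoint` and `Job.from_dict` in this diff, not from the API doc itself.

```python
# Sketch only: unpack the receive_one response the way the new agent code does.
# The {"data": {"job": {...}}} envelope and these field names are assumptions
# borrowed from Poller.poll_endpoint / Job.from_dict below.
api_response = api_instance.receive_one(project_id=project_id)
next_job = api_response.get("data", {}).get("job")
if next_job is not None:
    print(next_job["run_id"], next_job["job_type"], next_job["container_image"])
```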
@@ -1,3 +1,4 @@
backports.cached-property==1.0.0.post2
click==7.1.2
docker==4.4.1
python_dateutil>=2.8.0
@@ -0,0 +1,283 @@
"""
Classes responsible for running local Gretel worker agents.
"""
from __future__ import annotations

import logging
import threading

from dataclasses import asdict, dataclass
from typing import Callable, Dict, Generic, Iterator, List, Optional

from backports.cached_property import cached_property

from gretel_client.agents.drivers.driver import ComputeUnit, Driver
from gretel_client.agents.drivers.registry import get_driver
from gretel_client.agents.logger import configure_logging
from gretel_client.config import configure_custom_logger, get_session_config
from gretel_client.docker import CloudCreds, DataVolumeDef
from gretel_client.projects import get_project
from gretel_client.rest.apis import JobsApi, ProjectsApi, UsersApi

configure_logging()


class AgentError(Exception):
    ...


@dataclass
class AgentConfig:
    """Provides various configuration knobs for running a Gretel Agent."""

    driver: str
    """Defines the driver used to launch containers from."""

    max_workers: int = 1
    """The max number of workers the agent instance will launch."""

    log_factory: Callable = lambda _: None
    """A factory function to ship worker logs to. If none is provided,
    log messages from the worker will be suppressed, though they will
    still be logged to their respective artifact endpoint.
    """

    project: Optional[str] = None
    """Determines the project to pull jobs for. If none is provided, then
    jobs from all projects will be fetched.
    """

    creds: Optional[CloudCreds] = None
    """Credentials to propagate to the worker."""

    artifact_endpoint: Optional[str] = None
    """Configure an artifact endpoint for workers to store intermediate data on."""

    _max_runtime_seconds: Optional[int] = None
    """TODO: implement"""

    def __post_init__(self):
        if not self._max_runtime_seconds:
            self._max_runtime_seconds = self._lookup_max_runtime()

    @property
    def max_runtime_seconds(self) -> int:
        if not self._max_runtime_seconds:
            raise AgentError("Could not fetch user config. Please restart the agent.")
        return self._max_runtime_seconds

    def _lookup_max_runtime(self) -> int:
        user_api = get_session_config().get_api(UsersApi)
        return (
            user_api.users_me()
            .get("data")
            .get("me")
            .get("service_limits")
            .get("max_job_runtime")
        )

    @property
    def as_dict(self) -> dict:
        return asdict(self)

    @cached_property
    def project_id(self) -> str:
        project = get_project(name=self.project)
        return project.project_id


@dataclass
class Job:
    """Job container class.

    Contains various Gretel Job properties that are used by each
    driver to configure and run the respective job.
    """

    uid: str
    job_type: str
    container_image: str
    worker_token: str
    max_runtime_seconds: int
    log: Optional[Callable] = None
    cloud_creds: Optional[CloudCreds] = None
    artifact_endpoint: Optional[str] = None

    @classmethod
    def from_dict(cls, source: dict, agent_config: AgentConfig) -> Job:
        return cls(
            uid=source["run_id"],
            job_type=source["job_type"],
            container_image=source["container_image"],
            worker_token=source["worker_token"],
            log=agent_config.log_factory(source["run_id"]),
            max_runtime_seconds=agent_config.max_runtime_seconds,
            cloud_creds=agent_config.creds,
            artifact_endpoint=agent_config.artifact_endpoint,
        )

    @property
    def params(self) -> Dict[str, str]:
        params = {"--worker-token": self.worker_token}
        if self.artifact_endpoint:
            params["--artifact-endpoint"] = self.artifact_endpoint
        return params

    @property
    def env(self) -> Dict[str, str]:
        if self.cloud_creds:
            return self.cloud_creds.env
        return {}


class RateLimiter:
    """Limits the number of jobs the agent can place."""

    def __init__(self, max_active_jobs: int, job_manager: JobManager):
        self._job_manager = job_manager
        self._max_active_jobs = max_active_jobs

    def has_capacity(self) -> bool:
        return self._job_manager.active_jobs < self._max_active_jobs


class JobManager(Generic[ComputeUnit]):
    """Responsible for tracking the status of jobs as they are
    scheduled and completed.

    TODO: Add support for cleaning stuck jobs based on a config's
    max runtime.
    """

    def __init__(self, driver: Driver):
        self._active_jobs: Dict[str, ComputeUnit] = {}
        self._driver = driver
        self._logger = logging.getLogger(__name__)

    def add_job(self, job: Job, unit: ComputeUnit):
        self._active_jobs[job.uid] = unit

    def _update_active_jobs(self):
        for job in list(self._active_jobs):
            if not self._driver.active(self._active_jobs[job]):
                self._logger.info(f"Job {job} completed")
                self._driver.clean(self._active_jobs[job])
                self._active_jobs.pop(job)

    @property
    def active_jobs(self) -> int:
        self._update_active_jobs()
        return len(self._active_jobs)

    def shutdown(self):
        self._logger.info("Attempting to shut down job manager")
        self._update_active_jobs()
        for job, unit in self._active_jobs.items():
            self._logger.info(f"Shutting down job {job} unit {unit}")
            self._driver.shutdown(unit)


class Poller(Iterator):
    """
    Provides an iterator interface for polling for new ``Job``s from
    the API. If no jobs are available, the iterator will block until
    a new ``Job`` is available.

    Args:
        jobs_api: Job api client instance.
        rate_limiter: Used to ensure new jobs aren't returned until the
            agent has capacity.
        agent_config: Agent config used to configure the ``Job``.
    """

    max_wait_secs = 16

    def __init__(
        self, jobs_api: JobsApi, rate_limiter: RateLimiter, agent_config: AgentConfig
    ):
        self._agent_config = agent_config
        self._jobs_api = jobs_api
        self._rate_limiter = rate_limiter
        self._logger = logging.getLogger(__name__)
        self._interupt = threading.Event()

    def __iter__(self):
        return self

    def interupt(self):
        return self._interupt.set()

    def poll_endpoint(self) -> Optional[Job]:
        next_job = self._jobs_api.receive_one(project_id=self._agent_config.project_id)
        if next_job["data"]["job"] is not None:
            return Job.from_dict(next_job["data"]["job"], self._agent_config)

    def __next__(self) -> Optional[Job]:
        wait_secs = 2
        while not self._interupt.is_set():
            if self._rate_limiter.has_capacity():
                job = None
                try:
                    job = self.poll_endpoint()
                except Exception as ex:
                    self._logger.warning(
                        f"There was a problem calling the jobs endpoint {ex}"
                    )
                if job:
                    return job
            self._interupt.wait(wait_secs)
            if wait_secs > Poller.max_wait_secs:
                wait_secs = 2
                self._logger.info("Heartbeat from poller, still here...")
            else:
                wait_secs += wait_secs ** 2


class Agent:
    """Starts an agent"""

    def __init__(self, config: AgentConfig):
        self._logger = logging.getLogger(__name__)
        configure_custom_logger(self._logger)
        self._config = config
        self._client_config = get_session_config()
        self._driver = get_driver(config)
        self._jobs_manager = JobManager(self._driver)
        # use the configured worker count from the instance, not the class default
        self._rate_limiter = RateLimiter(config.max_workers, self._jobs_manager)
        self._jobs_api = self._client_config.get_api(JobsApi)
        self._projects_api = self._client_config.get_api(ProjectsApi)
        self._jobs = Poller(self._jobs_api, self._rate_limiter, self._config)
        self._interupt = threading.Event()

    def start(self, cooloff: float = 5):
        """Start the agent"""
        self._logger.info("Agent started, waiting for work to arrive")
        for job in self._jobs:
            if not job:
                if self._interupt.is_set():
                    return
                else:
                    continue
            self._logger.info(f"Got {job.job_type} job {job.uid}, scheduling now.")

            self._jobs_manager.add_job(job, self._driver.schedule(job))
            self._logger.info(f"Container for job {job.uid} scheduled")
            # todo: add in read lock to jobs endpoint. this sleep is
            # a stopgap until then. without this the agent is going to
            # try and launch multiple containers of the same job.
            self._interupt.wait(cooloff)

    def interupt(self):
        """Shuts down the agent"""
        self._jobs.interupt()
        self._interupt.set()
        self._logger.info("Server preparing to shut down")
        self._jobs_manager.shutdown()
        self._logger.info("Server shutdown complete")


def get_agent(config: AgentConfig) -> Agent:
    return Agent(config=config)
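For context, a minimal sketch of how this agent module might be wired up. The module path, the `"docker"` driver name, and the project name below are assumptions for illustration; the diff does not list the registered drivers, and a configured Gretel session (API key) is assumed to already be in place.

```python
# Minimal usage sketch (not part of this commit). Assumptions: the module is
# importable at this path, a "docker" driver is registered with get_driver,
# and the Gretel session is already configured with valid credentials.
from gretel_client.agents.agent import AgentConfig, get_agent  # path assumed

config = AgentConfig(
    driver="docker",        # assumed driver name
    max_workers=2,          # launch at most two workers at a time
    project="my-project",   # hypothetical project; omit to pull jobs from all projects
)

agent = get_agent(config)
try:
    agent.start()           # blocks, polling receive_one until interrupted
except KeyboardInterrupt:
    agent.interupt()        # note: the method is spelled "interupt" in this module
```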