ENGPLAT-3: Add S3 Connector

GitOrigin-RevId: 0e5f8ed5c5ec1dee95866830ada2b9e66d741544

drew committed Jan 19, 2022
1 parent e6a02ae commit b9fd267
Showing 27 changed files with 1,753 additions and 349 deletions.
83 changes: 83 additions & 0 deletions docs/rest/JobsApi.md
@@ -0,0 +1,83 @@
# gretel_client.rest.JobsApi

All URIs are relative to *https://api-dev.gretel.cloud*

Method | HTTP request | Description
------------- | ------------- | -------------
[**receive_one**](JobsApi.md#receive_one) | **POST** /jobs/receive_one | Get Gretel job for scheduling


# **receive_one**
> {str: (bool, date, datetime, dict, float, int, list, str, none_type)} receive_one()

Get Gretel job for scheduling

### Example

* Api Key Authentication (ApiKey):
```python
import gretel_client.rest
from gretel_client.rest.api import jobs_api
from pprint import pprint
# Defining the host is optional and defaults to https://api-dev.gretel.cloud
# See configuration.py for a list of all supported configuration parameters.
configuration = gretel_client.rest.Configuration(
host = "https://api-dev.gretel.cloud"
)

# The client must configure the authentication and authorization parameters
# in accordance with the API server security policy.
# Examples for each auth method are provided below, use the example that
# satisfies your auth use case.

# Configure API key authorization: ApiKey
configuration.api_key['ApiKey'] = 'YOUR_API_KEY'

# Uncomment below to setup prefix (e.g. Bearer) for API key, if needed
# configuration.api_key_prefix['ApiKey'] = 'Bearer'

# Enter a context with an instance of the API client
with gretel_client.rest.ApiClient(configuration) as api_client:
    # Create an instance of the API class
    api_instance = jobs_api.JobsApi(api_client)
    project_id = "project_id_example"  # str | (optional)

    # Example: passing the optional project_id parameter
    try:
        # Get Gretel job for scheduling
        api_response = api_instance.receive_one(project_id=project_id)
        pprint(api_response)
    except gretel_client.rest.ApiException as e:
        print("Exception when calling JobsApi->receive_one: %s\n" % e)
```


### Parameters

Name | Type | Description | Notes
------------- | ------------- | ------------- | -------------
**project_id** | **str** | | [optional]

### Return type

**{str: (bool, date, datetime, dict, float, int, list, str, none_type)}**

### Authorization

[ApiKey](../README.md#ApiKey)

### HTTP request headers

- **Content-Type**: Not defined
- **Accept**: application/json


### HTTP response details
| Status code | Description | Response headers |
|-------------|-------------|------------------|
| **200** | Job to schedule | - |

[[Back to top]](#) [[Back to API list]](../README.md#documentation-for-api-endpoints) [[Back to Model list]](../README.md#documentation-for-models) [[Back to README]](../README.md)
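
The response is a free-form dictionary rather than a typed model. The agent code added in this commit reads the scheduled job from the `data.job` key of the response; the field names in the sketch below are taken from `Poller.poll_endpoint` and `Job.from_dict` in `src/gretel_client/agents/agent.py`:

```python
response = api_instance.receive_one(project_id=project_id)

# The job payload lives under data.job; it is None when no work is available.
job = response["data"]["job"]
if job is not None:
    print(job["run_id"], job["job_type"], job["container_image"])
```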

1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
backports.cached-property==1.0.0.post2
click==7.1.2
docker==4.4.1
python_dateutil>=2.8.0
283 changes: 283 additions & 0 deletions src/gretel_client/agents/agent.py
@@ -0,0 +1,283 @@
"""
Classes responsible for running local Gretel worker agents.
"""
from __future__ import annotations

import logging
import threading

from dataclasses import asdict, dataclass
from typing import Callable, Dict, Generic, Iterator, List, Optional

from backports.cached_property import cached_property

from gretel_client.agents.drivers.driver import ComputeUnit, Driver
from gretel_client.agents.drivers.registry import get_driver
from gretel_client.agents.logger import configure_logging
from gretel_client.config import configure_custom_logger, get_session_config
from gretel_client.docker import CloudCreds, DataVolumeDef
from gretel_client.projects import get_project
from gretel_client.rest.apis import JobsApi, ProjectsApi, UsersApi

configure_logging()


class AgentError(Exception):
...


@dataclass
class AgentConfig:
"""Provides various configuration knobs for running a Gretel Agent."""

driver: str
"""Defines the driver used to launch containers from."""

max_workers: int = 1
"""The max number of workers the agent instance will launch."""

    log_factory: Callable = lambda _: None
    """A factory function to ship worker logs to. If none is provided,
    log messages from the worker will be suppressed, though they will
    still be logged to their respective artifact endpoint.
    """

    project: Optional[str] = None
    """Determines the project to pull jobs for. If none is provided,
    jobs from all projects will be fetched.
    """

creds: Optional[CloudCreds] = None
"""Provide credentials to propagate to the worker"""

artifact_endpoint: Optional[str] = None
"""Configure an artifact endpoint for workers to store intermediate data on."""

_max_runtime_seconds: Optional[int] = None
"""TODO: implement"""

def __post_init__(self):
if not self._max_runtime_seconds:
self._max_runtime_seconds = self._lookup_max_runtime()

@property
def max_runtime_seconds(self) -> int:
if not self._max_runtime_seconds:
raise AgentError("Could not fetch user config. Please restart the agent.")
return self._max_runtime_seconds

def _lookup_max_runtime(self) -> int:
user_api = get_session_config().get_api(UsersApi)
return (
user_api.users_me()
.get("data")
.get("me")
.get("service_limits")
.get("max_job_runtime")
)

@property
def as_dict(self) -> dict:
return asdict(self)

@cached_property
def project_id(self) -> str:
project = get_project(name=self.project)
return project.project_id


@dataclass
class Job:
    """Job container class.

    Contains various Gretel Job properties that are used by each
    driver to configure and run the respective job.
    """

uid: str
job_type: str
container_image: str
worker_token: str
max_runtime_seconds: int
log: Optional[Callable] = None
cloud_creds: Optional[CloudCreds] = None
artifact_endpoint: Optional[str] = None

@classmethod
def from_dict(cls, source: dict, agent_config: AgentConfig) -> Job:
return cls(
uid=source["run_id"],
job_type=source["job_type"],
container_image=source["container_image"],
worker_token=source["worker_token"],
log=agent_config.log_factory(source["run_id"]),
max_runtime_seconds=agent_config.max_runtime_seconds,
cloud_creds=agent_config.creds,
artifact_endpoint=agent_config.artifact_endpoint,
)

@property
def params(self) -> Dict[str, str]:
params = {"--worker-token": self.worker_token}
if self.artifact_endpoint:
params["--artifact-endpoint"] = self.artifact_endpoint
return params

@property
def env(self) -> Dict[str, str]:
if self.cloud_creds:
return self.cloud_creds.env
return {}


class RateLimiter:
"""Limits the amount of jobs the agent can place."""

    def __init__(self, max_active_jobs: int, job_manager: JobManager):
        self._job_manager = job_manager
        self._max_active_jobs = max_active_jobs

    def has_capacity(self) -> bool:
        return self._job_manager.active_jobs < self._max_active_jobs


class JobManager(Generic[ComputeUnit]):
    """Responsible for tracking the status of jobs as they are
    scheduled and completed.

    TODO: Add support for cleaning stuck jobs based on a config's
    max runtime.
    """

def __init__(self, driver: Driver):
self._active_jobs: Dict[str, ComputeUnit] = {}
self._driver = driver
self._logger = logging.getLogger(__name__)

def add_job(self, job: Job, unit: ComputeUnit):
self._active_jobs[job.uid] = unit

def _update_active_jobs(self):
for job in list(self._active_jobs):
if not self._driver.active(self._active_jobs[job]):
self._logger.info(f"Job {job} completed")
self._driver.clean(self._active_jobs[job])
self._active_jobs.pop(job)

@property
def active_jobs(self) -> int:
self._update_active_jobs()
return len(self._active_jobs)

def shutdown(self):
self._logger.info("Attemping to shutdown job manager")
self._update_active_jobs()
for job, unit in self._active_jobs.items():
self._logger.info(f"Shutting down job {job} unit {unit}")
self._driver.shutdown(unit)


class Poller(Iterator):
    """
    Provides an iterator interface for polling the API for new
    ``Job``s. If no jobs are available, the iterator will block
    until a new ``Job`` is available.

    Args:
        jobs_api: Jobs API client instance.
        rate_limiter: Used to ensure new jobs aren't returned until the
            agent has capacity.
        agent_config: Agent config used to configure the ``Job``.
    """

max_wait_secs = 16

def __init__(
self, jobs_api: JobsApi, rate_limiter: RateLimiter, agent_config: AgentConfig
):
self._agent_config = agent_config
self._jobs_api = jobs_api
self._rate_limiter = rate_limiter
self._logger = logging.getLogger(__name__)
self._interupt = threading.Event()

def __iter__(self):
return self

def interupt(self):
return self._interupt.set()

def poll_endpoint(self) -> Optional[Job]:
next_job = self._jobs_api.receive_one(project_id=self._agent_config.project_id)
if next_job["data"]["job"] is not None:
return Job.from_dict(next_job["data"]["job"], self._agent_config)

def __next__(self) -> Optional[Job]:
wait_secs = 2
        while not self._interupt.is_set():
if self._rate_limiter.has_capacity():
job = None
try:
job = self.poll_endpoint()
except Exception as ex:
self._logger.warning(
f"There was a problem calling the jobs endpoint {ex}"
)
if job:
return job
self._interupt.wait(wait_secs)
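                # Backoff between polls: waits grow 2s -> 6s -> 42s; once the
                # wait exceeds max_wait_secs it resets to 2s and a heartbeat
                # is logged.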
if wait_secs > Poller.max_wait_secs:
wait_secs = 2
self._logger.info("Heartbeat from poller, still here...")
else:
wait_secs += wait_secs ** 2


class Agent:
"""Starts an agent"""

def __init__(self, config: AgentConfig):
self._logger = logging.getLogger(__name__)
configure_custom_logger(self._logger)
self._config = config
self._client_config = get_session_config()
self._driver = get_driver(config)
self._jobs_manager = JobManager(self._driver)
        self._rate_limiter = RateLimiter(config.max_workers, self._jobs_manager)
self._jobs_api = self._client_config.get_api(JobsApi)
self._projects_api = self._client_config.get_api(ProjectsApi)
self._jobs = Poller(self._jobs_api, self._rate_limiter, self._config)
self._interupt = threading.Event()

def start(self, cooloff: float = 5):
"""Start the agent"""
self._logger.info("Agent started, waiting for work to arrive")
for job in self._jobs:
if not job:
if self._interupt.is_set():
return
else:
continue
self._logger.info(f"Got {job.job_type} job {job.uid}, scheduling now.")

self._jobs_manager.add_job(job, self._driver.schedule(job))
self._logger.info(f"Container for job {job.uid} scheduled")
            # TODO: add a read lock to the jobs endpoint. This sleep is a
            # stopgap until then; without it, the agent may try to launch
            # multiple containers for the same job.
self._interupt.wait(cooloff)

def interupt(self):
"""Shuts down the agent"""
self._jobs.interupt()
self._interupt.set()
self._logger.info("Server preparing to shutdown")
self._jobs_manager.shutdown()
self._logger.info("Server shutdown complete")


def get_agent(config: AgentConfig) -> Agent:
return Agent(config=config)
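
Taken together, a minimal sketch of running the agent added in this commit might look like the following. The `"docker"` driver name and project slug are assumptions for illustration; valid driver names are resolved by `get_driver` in `gretel_client.agents.drivers.registry`:

```python
import signal

from gretel_client.agents.agent import AgentConfig, get_agent

# "docker" is an assumed driver name; see the driver registry for valid values.
config = AgentConfig(driver="docker", max_workers=2, project="my-project")
agent = get_agent(config)

# Shut the agent down cleanly on SIGINT/SIGTERM.
signal.signal(signal.SIGINT, lambda *_: agent.interupt())
signal.signal(signal.SIGTERM, lambda *_: agent.interupt())

# Blocks, polling for jobs and scheduling them on the configured driver.
agent.start()
```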
