From 16de8afff9a68f688eeec53c53764bbed0e410bc Mon Sep 17 00:00:00 2001 From: Malte Isberner <2822367+misberner@users.noreply.github.com> Date: Thu, 18 Jan 2024 11:16:32 +0100 Subject: [PATCH] Include ClusterID in request headers from pipeline controller and agent GitOrigin-RevId: a784cacbca4a89ce57738288935417b8dfc9737f --- src/gretel_client/agents/agent.py | 17 +++++++++++++++-- src/gretel_client/config.py | 23 +++++++++++++++++++---- tests/gretel_client/test_agent.py | 12 ++++++++---- 3 files changed, 42 insertions(+), 10 deletions(-) diff --git a/src/gretel_client/agents/agent.py b/src/gretel_client/agents/agent.py index 8c5bd77d..6396e664 100644 --- a/src/gretel_client/agents/agent.py +++ b/src/gretel_client/agents/agent.py @@ -33,6 +33,9 @@ configure_logging() +_CLUSTERID_HEADER_KEY = "X-Gretel-Clusterid" + + class AgentError(Exception): ... @@ -520,8 +523,16 @@ def __init__(self, config: AgentConfig): self._driver = get_driver(config) self._jobs_manager = JobManager(self._driver) self._rate_limiter = RateLimiter(config.max_workers, self._jobs_manager, config) - self._jobs_api = self._client_config.get_api(JobsApi) - self._projects_api = self._client_config.get_api(ProjectsApi) + + default_headers = None + if config.cluster_guid: + default_headers = {_CLUSTERID_HEADER_KEY: config.cluster_guid} + self._jobs_api = self._client_config.get_api( + JobsApi, default_headers=default_headers + ) + self._projects_api = self._client_config.get_api( + ProjectsApi, default_headers=default_headers + ) self._jobs = Poller( self._jobs_api, self._rate_limiter, @@ -573,6 +584,8 @@ def _update_job_status(self, job: Job) -> None: worker_json = base64.standard_b64decode(job.worker_token).decode("ascii") worker_key = json.loads(worker_json)["model_key"] headers = {"Authorization": worker_key} + if self._config.cluster_guid: + headers[_CLUSTERID_HEADER_KEY] = self._config.cluster_guid url = f"{job.gretel_endpoint}/projects/models" params = {"uid": job.uid, "type": job.job_type} data = {"uid": job.uid, "status": "pending"} diff --git a/src/gretel_client/config.py b/src/gretel_client/config.py index 82a6ed6d..5328f742 100644 --- a/src/gretel_client/config.py +++ b/src/gretel_client/config.py @@ -7,7 +7,7 @@ from enum import Enum from getpass import getpass from pathlib import Path -from typing import Optional, Type, TypeVar, Union +from typing import Dict, Optional, Type, TypeVar, Union import certifi @@ -291,6 +291,8 @@ def _get_api_client_generic( config_cls: Type[ConfigT], max_retry_attempts: int = 3, backoff_factor: float = 1, + *, + default_headers: Optional[Dict[str, str]] = None, ) -> ClientT: # disable log warnings when the retry kicks in logging.getLogger("urllib3.connectionpool").setLevel(logging.ERROR) @@ -321,7 +323,10 @@ def _get_api_client_generic( max_retry_attempts=max_retry_attempts, backoff_factor=backoff_factor, ) - return client_cls(configuration, **client_kwargs) + client = client_cls(configuration, **client_kwargs) + if default_headers: + client.default_headers.update(default_headers) + return client def _get_api_client(self, *args, **kwargs) -> ApiClient: return self._get_api_client_generic(ApiClient, Configuration, *args, **kwargs) @@ -346,6 +351,8 @@ def get_api( api_interface: Type[T], max_retry_attempts: int = 5, backoff_factor: float = 1, + *, + default_headers: Optional[Dict[str, str]] = None, ) -> T: """Instantiates and configures an api client for a given component interface. @@ -358,16 +365,24 @@ def get_api( attempts. A base factor of 2 will applied to this value to determine the time between attempts. """ - return api_interface(self._get_api_client(max_retry_attempts, backoff_factor)) + return api_interface( + self._get_api_client( + max_retry_attempts, backoff_factor, default_headers=default_headers + ) + ) def get_v1_api( self, api_interface: Type[T], max_retry_attempts: int = 5, backoff_factor: float = 1, + *, + default_headers: Optional[Dict[str, str]] = None, ) -> T: return api_interface( - self._get_v1_api_client(max_retry_attempts, backoff_factor) + self._get_v1_api_client( + max_retry_attempts, backoff_factor, default_headers=default_headers + ) ) def _check_project(self, project_name: str = None) -> Optional[str]: diff --git a/tests/gretel_client/test_agent.py b/tests/gretel_client/test_agent.py index 16751ad6..a6b5da38 100644 --- a/tests/gretel_client/test_agent.py +++ b/tests/gretel_client/test_agent.py @@ -86,10 +86,14 @@ def test_agent_server_does_start( jobs_api = MagicMock() project_api = MagicMock() - get_session_config.return_value.get_api.side_effect = { - JobsApi: jobs_api, - ProjectsApi: project_api, - }.get + def get_api(api, *args, **kwargs): + if api == JobsApi: + return jobs_api + if api == ProjectsApi: + return project_api + assert False, "unexpected API requested" + + get_session_config.return_value.get_api.side_effect = get_api jobs_api.receive_one.side_effect = [ {"data": {"job": get_mock_job()}},