From 86e92d040e9c442d0852a11aba54be2c65634cff Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 9 Mar 2023 16:32:28 -0800 Subject: [PATCH 1/8] Device auth flow Signed-off-by: Ketan Umare --- flytekit/clients/auth/authenticator.py | 52 ++++++++++++++++++++------ 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 183c1787cd..3b69a0c261 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -22,6 +22,7 @@ class ClientConfig: authorization_endpoint: str redirect_uri: str client_id: str + device_authorization_endpoint: typing.Optional[str] = None scopes: typing.List[str] = None header_key: str = "authorization" @@ -79,11 +80,11 @@ class PKCEAuthenticator(Authenticator): """ def __init__( - self, - endpoint: str, - cfg_store: ClientConfigStore, - header_key: typing.Optional[str] = None, - verify: typing.Optional[typing.Union[bool, str]] = None, + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + verify: typing.Optional[typing.Union[bool, str]] = None, ): """ Initialize with default creds from KeyStore using the endpoint name @@ -158,12 +159,12 @@ class ClientCredentialsAuthenticator(Authenticator): _utf_8 = "utf-8" def __init__( - self, - endpoint: str, - client_id: str, - client_secret: str, - cfg_store: ClientConfigStore, - header_key: str = None, + self, + endpoint: str, + client_id: str, + client_secret: str, + cfg_store: ClientConfigStore, + header_key: str = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -233,3 +234,32 @@ def refresh_credentials(self): token, expires_in = self.get_token(token_endpoint, authorization_header, scopes) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) + + +class DeviceCodeAuthenticator(Authenticator): + """ + This Authenticator implements the Device Code authorization flow useful for headless user authentication. + + Examples described + - https://developer.okta.com/docs/guides/device-authorization-grant/main/ + - https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow + """ + + def __init__(self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + audience: typing.Optional[str] = None): + pass + + def _get_code(self): + pass + + def _poll(self): + pass + + def _get_token(self): + pass + + def refresh_credentials(self): + pass From 20e1778295e063ecb5a87617bc19d20150700ebd Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 13 Mar 2023 21:52:34 -0700 Subject: [PATCH 2/8] Device AuthFlow is now available in flytekit Signed-off-by: Ketan Umare --- flytekit/clients/auth/auth_client.py | 4 +- flytekit/clients/auth/authenticator.py | 126 +++++++++----------- flytekit/clients/auth/exceptions.py | 8 ++ flytekit/clients/auth/keyring.py | 1 + flytekit/clients/auth/token_client.py | 149 ++++++++++++++++++++++++ flytekit/clients/auth_helper.py | 4 + flytekit/clis/sdk_in_container/utils.py | 0 flytekit/configuration/__init__.py | 2 + 8 files changed, 220 insertions(+), 74 deletions(-) create mode 100644 flytekit/clients/auth/token_client.py create mode 100644 flytekit/clis/sdk_in_container/utils.py diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index 94afa13612..ec1fd4d3e1 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -269,9 +269,11 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: refresh_token = response_body["refresh_token"] + if "expires_in" in response_body: + expires_in = response_body["expires_in"] access_token = response_body["access_token"] - return Credentials(access_token, refresh_token, self._endpoint) + return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in) def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 3b69a0c261..24ad904e2e 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -1,12 +1,16 @@ import base64 +import json import logging import subprocess +import time import typing from abc import abstractmethod from dataclasses import dataclass +from datetime import datetime, timedelta import requests +from . import token_client from .auth_client import AuthorizationClient from .exceptions import AccessTokenNotFoundError, AuthenticationError from .keyring import Credentials, KeyringStore @@ -80,11 +84,11 @@ class PKCEAuthenticator(Authenticator): """ def __init__( - self, - endpoint: str, - cfg_store: ClientConfigStore, - header_key: typing.Optional[str] = None, - verify: typing.Optional[typing.Union[bool, str]] = None, + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + verify: typing.Optional[typing.Union[bool, str]] = None, ): """ Initialize with default creds from KeyStore using the endpoint name @@ -156,15 +160,13 @@ class ClientCredentialsAuthenticator(Authenticator): This Authenticator uses ClientId and ClientSecret to authenticate """ - _utf_8 = "utf-8" - def __init__( - self, - endpoint: str, - client_id: str, - client_secret: str, - cfg_store: ClientConfigStore, - header_key: str = None, + self, + endpoint: str, + client_id: str, + client_secret: str, + cfg_store: ClientConfigStore, + header_key: str = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -175,48 +177,6 @@ def __init__( self._client_secret = client_secret super().__init__(endpoint, cfg.header_key or header_key) - @staticmethod - def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]: - """ - :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration - in seconds - """ - headers = { - "Authorization": authorization_header, - "Cache-Control": "no-cache", - "Accept": "application/json", - "Content-Type": "application/x-www-form-urlencoded", - } - body = { - "grant_type": "client_credentials", - } - if scopes is not None: - body["scope"] = ",".join(scopes) - response = requests.post(token_endpoint, data=body, headers=headers) - if response.status_code != 200: - logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text)) - raise AuthenticationError("Non-200 received from IDP") - - response = response.json() - return response["access_token"], response["expires_in"] - - @staticmethod - def get_basic_authorization_header(client_id: str, client_secret: str) -> str: - """ - This function transforms the client id and the client secret into a header that conforms with http basic auth. - It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text - - :param client_id: str - :param client_secret: str - :rtype: str - """ - concated = "{}:{}".format(client_id, client_secret) - return "Basic {}".format( - base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode( - ClientCredentialsAuthenticator._utf_8 - ) - ) - def refresh_credentials(self): """ This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler @@ -230,8 +190,8 @@ def refresh_credentials(self): # Note that unlike the Pkce flow, the client ID does not come from Admin. logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}") - authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret) - token, expires_in = self.get_token(token_endpoint, authorization_header, scopes) + authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret) + token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header) logging.info("Retrieved new token, expires in {}".format(expires_in)) self._creds = Credentials(token) @@ -245,21 +205,41 @@ class DeviceCodeAuthenticator(Authenticator): - https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow """ - def __init__(self, - endpoint: str, - cfg_store: ClientConfigStore, - header_key: typing.Optional[str] = None, - audience: typing.Optional[str] = None): - pass - - def _get_code(self): - pass - - def _poll(self): - pass + def __init__( + self, + endpoint: str, + cfg_store: ClientConfigStore, + header_key: typing.Optional[str] = None, + audience: typing.Optional[str] = None, + ): + self._audience = audience + cfg = cfg_store.get_client_config() + self._client_id = cfg.client_id + self._device_auth_endpoint = cfg.device_authorization_endpoint + self._scope = cfg.scopes + self._token_endpoint = cfg.token_endpoint + if self._device_auth_endpoint is None: + raise ValueError("Device Authentication is not available on the Flyte backend / authentication server") + super().__init__( + endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint) + ) - def _get_token(self): - pass - def refresh_credentials(self): - pass + resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope) + print( + f""" +To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code} +OR copy paste the following URL: {resp.verification_uri_complete} + """ + ) + try: + # Currently the refresh token is not retreived. We may want to add support for refreshTokens so that + # access tokens can be refreshed for once authenticated machines + token, expires_in = token_client.poll_token_endpoint( + resp, self._token_endpoint, device_code=resp.device_code, client_id=self._client_id + ) + self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) + KeyringStore.store(self._creds) + except Exception: + KeyringStore.delete(self._endpoint) + raise diff --git a/flytekit/clients/auth/exceptions.py b/flytekit/clients/auth/exceptions.py index 6e790e47a4..5086c5b6e1 100644 --- a/flytekit/clients/auth/exceptions.py +++ b/flytekit/clients/auth/exceptions.py @@ -12,3 +12,11 @@ class AuthenticationError(RuntimeError): """ pass + + +class AuthenticationPending(RuntimeError): + """ + This is raised if the token endpoint returns authentication pending + """ + + pass diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py index c2b19c46b6..831558f4c2 100644 --- a/flytekit/clients/auth/keyring.py +++ b/flytekit/clients/auth/keyring.py @@ -15,6 +15,7 @@ class Credentials(object): access_token: str refresh_token: str = "na" for_endpoint: str = "flyte-default" + expires_in: typing.Optional[int] = None class KeyringStore: diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py new file mode 100644 index 0000000000..4d6bd95499 --- /dev/null +++ b/flytekit/clients/auth/token_client.py @@ -0,0 +1,149 @@ +import base64 +import enum +import logging +import time +import typing +from dataclasses import dataclass +from datetime import datetime, timedelta + +import requests + +from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending + +utf_8 = "utf-8" + +# Errors that Token endpoint will return +error_slow_down = "slow_down" +error_auth_pending = "authorization_pending" + + +# Grant Types +class GrantType(str, enum.Enum): + CLIENT_CREDS = "client_credentials" + DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" + + +@dataclass +class DeviceCodeResponse: + """ + Response from device auth flow endpoint + {'device_code': 'code', + 'user_code': 'BNDJJFXL', + 'verification_uri': 'url', + 'verification_uri_complete': 'url', + 'expires_in': 600, + 'interval': 5} + """ + + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + @classmethod + def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse": + return cls( + device_code=j["device_code"], + user_code=j["user_code"], + verification_uri=j["verification_uri"], + verification_uri_complete=j["verification_uri_complete"], + expires_in=j["expires_in"], + interval=j["interval"], + ) + + +def get_basic_authorization_header(client_id: str, client_secret: str) -> str: + """ + This function transforms the client id and the client secret into a header that conforms with http basic auth. + It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text + + :param client_id: str + :param client_secret: str + :rtype: str + """ + concated = "{}:{}".format(client_id, client_secret) + return "Basic {}".format(base64.b64encode(concated.encode(utf_8)).decode(utf_8)) + + +def get_token( + token_endpoint: str, + scopes: typing.Optional[typing.List[str]] = None, + authorization_header: typing.Optional[str] = None, + client_id: typing.Optional[str] = None, + device_code: typing.Optional[str] = None, + grant_type: GrantType = GrantType.CLIENT_CREDS, +) -> typing.Tuple[str, int]: + """ + :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration + in seconds + """ + headers = { + "Cache-Control": "no-cache", + "Accept": "application/json", + "Content-Type": "application/x-www-form-urlencoded", + } + if authorization_header: + headers["Authorization"] = authorization_header + body = { + "grant_type": grant_type.value, + } + if client_id: + body["client_id"] = client_id + if device_code: + body["device_code"] = device_code + if scopes is not None: + body["scope"] = ",".join(scopes) + + response = requests.post(token_endpoint, data=body, headers=headers) + if not response.ok: + j = response.json() + if "error" in j: + err = j["error"] + if err == error_auth_pending or err == error_slow_down: + raise AuthenticationPending(f"Token not yet available, try again in some time {err}") + logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) + + response = response.json() + return response["access_token"], response["expires_in"] + + +def get_device_code( + device_auth_endpoint: str, + client_id: str, + audience: typing.Optional[str] = None, + scope: typing.Optional[typing.List[str]] = None, +) -> DeviceCodeResponse: + """ + Retrieves the device Authentication code that can be done to authenticate the request using a browser on a + separate device + """ + payload = {"client_id": client_id, "scope": scope, "audience": audience} + resp = requests.post(device_auth_endpoint, payload) + if not resp.ok: + raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") + return DeviceCodeResponse.from_json_response(resp.json()) + + +def poll_token_endpoint( + resp: DeviceCodeResponse, token_endpoint: str, client_id: str, device_code: str +) -> typing.Tuple[str, int]: + tick = datetime.now() + interval = timedelta(seconds=resp.interval) + end_time = tick + timedelta(seconds=resp.expires_in) + while tick < end_time: + try: + access_token, expires_in = get_token( + token_endpoint, grant_type=GrantType.DEVICE_CODE, client_id=client_id, device_code=device_code + ) + print("Authentication successful!") + return access_token, expires_in + except AuthenticationPending: + ... + except Exception: + raise + print(f"Authentication Pending...") + time.sleep(interval.total_seconds()) + tick = tick + interval diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 41fc5c025f..44aee9096c 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -12,6 +12,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @@ -41,6 +42,7 @@ def get_client_config(self) -> ClientConfig: client_id=public_client_config.client_id, scopes=public_client_config.scopes, header_key=public_client_config.authorization_metadata_key or None, + device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint, ) @@ -78,6 +80,8 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth command=cfg.command, header_key=client_cfg.header_key if client_cfg else None, ) + elif cfg_auth == AuthType.DEVICEFLOW: + return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience) else: raise ValueError( f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value" diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index afed857a26..84132c6f9b 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -344,6 +344,7 @@ class AuthType(enum.Enum): CLIENTSECRET = "ClientSecret" PKCE = "Pkce" EXTERNALCOMMAND = "ExternalCommand" + DEVICEFLOW = "DeviceFlow" @dataclass(init=True, repr=True, eq=True, frozen=True) @@ -376,6 +377,7 @@ class PlatformConfig(object): client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) auth_mode: AuthType = AuthType.STANDARD + audience: typing.Optional[str] = None rpc_retries: int = 3 @classmethod From a2104f6702157c21a2ff91d559679cc20d49b113 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 16 Mar 2023 22:13:58 -0700 Subject: [PATCH 3/8] unit tests Signed-off-by: Ketan Umare --- flytekit/clients/auth/authenticator.py | 8 +- flytekit/clients/auth/token_client.py | 14 ++-- .../unit/clients/auth/test_authenticator.py | 20 +---- .../unit/clients/auth/test_token_client.py | 78 +++++++++++++++++++ .../flytekit/unit/clients/test_auth_helper.py | 16 +++- 5 files changed, 109 insertions(+), 27 deletions(-) create mode 100644 tests/flytekit/unit/clients/auth/test_token_client.py diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 24ad904e2e..4000fe60ec 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -219,7 +219,9 @@ def __init__( self._scope = cfg.scopes self._token_endpoint = cfg.token_endpoint if self._device_auth_endpoint is None: - raise ValueError("Device Authentication is not available on the Flyte backend / authentication server") + raise AuthenticationError( + "Device Authentication is not available on the Flyte backend / authentication server" + ) super().__init__( endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint) ) @@ -235,9 +237,7 @@ def refresh_credentials(self): try: # Currently the refresh token is not retreived. We may want to add support for refreshTokens so that # access tokens can be refreshed for once authenticated machines - token, expires_in = token_client.poll_token_endpoint( - resp, self._token_endpoint, device_code=resp.device_code, client_id=self._client_id - ) + token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id) self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint) KeyringStore.store(self._creds) except Exception: diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index 4d6bd95499..7e9c13d080 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -106,8 +106,8 @@ def get_token( logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text)) - response = response.json() - return response["access_token"], response["expires_in"] + j = response.json() + return j["access_token"], j["expires_in"] def get_device_code( @@ -127,16 +127,17 @@ def get_device_code( return DeviceCodeResponse.from_json_response(resp.json()) -def poll_token_endpoint( - resp: DeviceCodeResponse, token_endpoint: str, client_id: str, device_code: str -) -> typing.Tuple[str, int]: +def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]: tick = datetime.now() interval = timedelta(seconds=resp.interval) end_time = tick + timedelta(seconds=resp.expires_in) while tick < end_time: try: access_token, expires_in = get_token( - token_endpoint, grant_type=GrantType.DEVICE_CODE, client_id=client_id, device_code=device_code + token_endpoint, + grant_type=GrantType.DEVICE_CODE, + client_id=client_id, + device_code=resp.device_code, ) print("Authentication successful!") return access_token, expires_in @@ -147,3 +148,4 @@ def poll_token_endpoint( print(f"Authentication Pending...") time.sleep(interval.total_seconds()) tick = tick + interval + raise AuthenticationError("Authentication failed!") diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 4c968cf0bd..ab1eb7356f 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -65,22 +65,6 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() -def test_get_basic_authorization_header(): - header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc") - assert header == "Basic Y2xpZW50X2lkOmFiYw==" - - -@patch("flytekit.clients.auth.authenticator.requests") -def test_get_token(mock_requests): - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response - access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"]) - assert access == "abc" - assert expiration == 60 - - @patch("flytekit.clients.auth.authenticator.requests") def test_client_creds_authenticator(mock_requests): authn = ClientCredentialsAuthenticator( @@ -93,3 +77,7 @@ def test_client_creds_authenticator(mock_requests): mock_requests.post.return_value = response authn.refresh_credentials() assert authn._creds + + +def test_device_flow_authenticator(): + pass diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py new file mode 100644 index 0000000000..4c2b90ed97 --- /dev/null +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -0,0 +1,78 @@ +import json +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import ( + DeviceCodeResponse, + error_auth_pending, + get_basic_authorization_header, + get_device_code, + get_token, + poll_token_endpoint, +) + + +def test_get_basic_authorization_header(): + header = get_basic_authorization_header("client_id", "abc") + assert header == "Basic Y2xpZW50X2lkOmFiYw==" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_token(mock_requests): + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + mock_requests.post.return_value = response + access, expiration = get_token("https://corp.idp.net", client_id="abc123", scopes=["my_scope"]) + assert access == "abc" + assert expiration == 60 + + +@patch("flytekit.clients.auth.token_client.requests") +def test_get_device_code(mock_requests): + response = MagicMock() + response.ok = False + mock_requests.post.return_value = response + with pytest.raises(AuthenticationError): + get_device_code("test.com", "test") + + response.ok = True + response.json.return_value = { + "device_code": "code", + "user_code": "BNDJJFXL", + "verification_uri": "url", + "verification_uri_complete": "url", + "expires_in": 600, + "interval": 5, + } + mock_requests.post.return_value = response + c = get_device_code("test.com", "test") + assert c + assert c.device_code == "code" + + +@patch("flytekit.clients.auth.token_client.requests") +def test_poll_token_endpoint(mock_requests): + response = MagicMock() + response.ok = False + response.json.return_value = {"error": error_auth_pending} + mock_requests.post.return_value = response + + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=1 + ) + with pytest.raises(AuthenticationError): + poll_token_endpoint(r, "test.com", "test") + + response = MagicMock() + response.ok = True + response.json.return_value = {"access_token": "abc", "expires_in": 60} + mock_requests.post.return_value = response + r = DeviceCodeResponse( + device_code="x", user_code="y", verification_uri="v", verification_uri_complete="v1", expires_in=1, interval=0 + ) + t, e = poll_token_endpoint(r, "test.com", "test") + assert t + assert e diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 8f14de730e..e3c87e033f 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -1,4 +1,5 @@ import os.path +import typing from unittest.mock import MagicMock, patch import pytest @@ -9,6 +10,7 @@ ClientConfigStore, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, ) from flytekit.clients.auth.exceptions import AuthenticationError @@ -31,6 +33,8 @@ OAUTH_AUTHORIZE = "https://your.domain.io/oauth2/authorize" +DEVICE_AUTH_ENDPOINT = "https://your.domain.io/..." + def get_auth_service_mock() -> MagicMock: auth_stub_mock = MagicMock() @@ -66,13 +70,14 @@ def test_remote_client_config_store(mock_auth_service: MagicMock): assert ccfg.authorization_endpoint == OAUTH_AUTHORIZE -def get_client_config() -> ClientConfigStore: +def get_client_config(**kwargs) -> ClientConfigStore: cfg_store = MagicMock() cfg_store.get_client_config.return_value = ClientConfig( token_endpoint=TOKEN_ENDPOINT, authorization_endpoint=OAUTH_AUTHORIZE, redirect_uri=REDIRECT_URI, client_id=CLIENT_ID, + **kwargs ) return cfg_store @@ -135,6 +140,15 @@ def test_get_authenticator_cmd(): assert authn._cmd == ["echo"] +def test_get_authenticator_deviceflow(): + cfg = PlatformConfig(auth_mode=AuthType.DEVICEFLOW) + with pytest.raises(AuthenticationError): + get_authenticator(cfg, get_client_config()) + + authn = get_authenticator(cfg, get_client_config(device_authorization_endpoint=DEVICE_AUTH_ENDPOINT)) + assert isinstance(authn, DeviceCodeAuthenticator) + + def test_wrap_exceptions_channel(): ch = MagicMock() out_ch = wrap_exceptions_channel(PlatformConfig(), ch) From 2269e6b0c1a1c6b0a676b3f17d42c9b6ba46162b Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Thu, 23 Mar 2023 22:25:59 -0700 Subject: [PATCH 4/8] test added Signed-off-by: Ketan Umare --- .../unit/clients/auth/test_authenticator.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index ab1eb7356f..e2e8eb115b 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -8,10 +8,12 @@ ClientConfig, ClientCredentialsAuthenticator, CommandAuthenticator, + DeviceCodeAuthenticator, PKCEAuthenticator, StaticClientConfigStore, ) from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.token_client import DeviceCodeResponse ENDPOINT = "example.com" @@ -79,5 +81,33 @@ def test_client_creds_authenticator(mock_requests): assert authn._creds -def test_device_flow_authenticator(): - pass +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.token_client.get_device_code") +@patch("flytekit.clients.auth.token_client.poll_token_endpoint") +def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): + with pytest.raises(AuthenticationError): + DeviceCodeAuthenticator( + ENDPOINT, + static_cfg_store, + audience="x", + ) + + cfg_store = StaticClientConfigStore( + ClientConfig( + token_endpoint="token_endpoint", + authorization_endpoint="auth_endpoint", + redirect_uri="redirect_uri", + client_id="client", + device_authorization_endpoint="dev", + ) + ) + authn = DeviceCodeAuthenticator( + ENDPOINT, + cfg_store, + audience="x", + ) + + device_mock.return_value = DeviceCodeResponse("x", "y", "s", "m", 1000, 0) + poll_mock.return_value = ("access", 100) + authn.refresh_credentials() + assert authn._creds From 7b89ff0bdaf0d2747753d977c75277a310da0643 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Sun, 26 Mar 2023 21:57:54 -0700 Subject: [PATCH 5/8] updated Signed-off-by: Ketan Umare --- tests/flytekit/unit/clients/auth/test_authenticator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index e2e8eb115b..a3f6df9c6b 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -67,7 +67,7 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() -@patch("flytekit.clients.auth.authenticator.requests") +@patch("flytekit.clients.auth.token_client.requests") def test_client_creds_authenticator(mock_requests): authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", client_secret="secret", cfg_store=static_cfg_store From c4c22838fd0db20220444ffd5231c97e41d91768 Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 27 Mar 2023 17:03:53 -0700 Subject: [PATCH 6/8] Fixed test Signed-off-by: Ketan Umare --- tests/flytekit/unit/clients/auth/test_authenticator.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 2a78e43bae..d4ea52edba 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -113,8 +113,6 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, poll_mock.return_value = ("access", 100) authn.refresh_credentials() assert authn._creds -======= - assert authn._scopes == expected_scopes @patch("flytekit.clients.auth.authenticator.requests") @@ -135,4 +133,3 @@ def test_client_creds_authenticator_with_custom_scopes(mock_requests): assert authn._creds assert authn._scopes == expected_scopes ->>>>>>> master From 728a81f99908e31fab16e22329f52f01a56ca33c Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 27 Mar 2023 17:32:31 -0700 Subject: [PATCH 7/8] Fixed unit test Signed-off-by: Ketan Umare --- tests/flytekit/unit/clients/auth/test_authenticator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index d4ea52edba..fdbddb2ebe 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -115,7 +115,7 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, assert authn._creds -@patch("flytekit.clients.auth.authenticator.requests") +@patch("flytekit.clients.auth.token_client.requests") def test_client_creds_authenticator_with_custom_scopes(mock_requests): expected_scopes = ["foo", "baz"] authn = ClientCredentialsAuthenticator( From 94a7bb5c7e7913853bb329f02b13d20b3976ae8f Mon Sep 17 00:00:00 2001 From: Ketan Umare Date: Mon, 27 Mar 2023 17:36:39 -0700 Subject: [PATCH 8/8] Fixed lint errors Signed-off-by: Ketan Umare --- flytekit/clients/auth/authenticator.py | 6 ------ flytekit/clients/auth/token_client.py | 2 +- tests/flytekit/unit/clients/auth/test_token_client.py | 2 +- tests/flytekit/unit/clients/test_auth_helper.py | 1 - 4 files changed, 2 insertions(+), 9 deletions(-) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index 82842b9368..9582c901d8 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -1,14 +1,8 @@ -import base64 -import json import logging import subprocess -import time import typing from abc import abstractmethod from dataclasses import dataclass -from datetime import datetime, timedelta - -import requests from . import token_client from .auth_client import AuthorizationClient diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index 7e9c13d080..e7e55f74a9 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -145,7 +145,7 @@ def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id ... except Exception: raise - print(f"Authentication Pending...") + print("Authentication Pending...") time.sleep(interval.total_seconds()) tick = tick + interval raise AuthenticationError("Authentication failed!") diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py index 4c2b90ed97..6e56c351bc 100644 --- a/tests/flytekit/unit/clients/auth/test_token_client.py +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -1,5 +1,5 @@ import json -from unittest.mock import MagicMock, PropertyMock, patch +from unittest.mock import MagicMock, patch import pytest diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index e3c87e033f..3bd57918f4 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -1,5 +1,4 @@ import os.path -import typing from unittest.mock import MagicMock, patch import pytest