Skip to content

Commit

Permalink
Device auth flow / Headless auth (#1552)
Browse files Browse the repository at this point in the history
* Device auth flow

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Device AuthFlow is now available in flytekit

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* unit tests

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* test added

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* updated

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Fixed test

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Fixed unit test

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

* Fixed lint errors

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>

---------

Signed-off-by: Ketan Umare <ketan.umare@gmail.com>
  • Loading branch information
kumare3 committed Mar 28, 2023
1 parent 6371d31 commit 28ce939
Show file tree
Hide file tree
Showing 11 changed files with 351 additions and 71 deletions.
4 changes: 3 additions & 1 deletion flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,11 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
refresh_token = response_body["refresh_token"]
if "expires_in" in response_body:
expires_in = response_body["expires_in"]
access_token = response_body["access_token"]

return Credentials(access_token, refresh_token, self._endpoint)
return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)

def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
Expand Down
102 changes: 53 additions & 49 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import base64
import logging
import subprocess
import typing
from abc import abstractmethod
from dataclasses import dataclass

import requests

from . import token_client
from .auth_client import AuthorizationClient
from .exceptions import AccessTokenNotFoundError, AuthenticationError
from .keyring import Credentials, KeyringStore
Expand All @@ -22,6 +20,7 @@ class ClientConfig:
authorization_endpoint: str
redirect_uri: str
client_id: str
device_authorization_endpoint: typing.Optional[str] = None
scopes: typing.List[str] = None
header_key: str = "authorization"

Expand Down Expand Up @@ -155,8 +154,6 @@ class ClientCredentialsAuthenticator(Authenticator):
This Authenticator uses ClientId and ClientSecret to authenticate
"""

_utf_8 = "utf-8"

def __init__(
self,
endpoint: str,
Expand All @@ -176,48 +173,6 @@ def __init__(
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key)

@staticmethod
def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
in seconds
"""
headers = {
"Authorization": authorization_header,
"Cache-Control": "no-cache",
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
body = {
"grant_type": "client_credentials",
}
if scopes is not None:
body["scope"] = ",".join(scopes)
response = requests.post(token_endpoint, data=body, headers=headers)
if response.status_code != 200:
logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
raise AuthenticationError("Non-200 received from IDP")

response = response.json()
return response["access_token"], response["expires_in"]

@staticmethod
def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
"""
This function transforms the client id and the client secret into a header that conforms with http basic auth.
It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text
:param client_id: str
:param client_secret: str
:rtype: str
"""
concated = "{}:{}".format(client_id, client_secret)
return "Basic {}".format(
base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode(
ClientCredentialsAuthenticator._utf_8
)
)

def refresh_credentials(self):
"""
This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler
Expand All @@ -231,7 +186,56 @@ def refresh_credentials(self):

# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = self.get_token(token_endpoint, authorization_header, scopes)
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)


class DeviceCodeAuthenticator(Authenticator):
"""
This Authenticator implements the Device Code authorization flow useful for headless user authentication.
Examples described
- https://developer.okta.com/docs/guides/device-authorization-grant/main/
- https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow
"""

def __init__(
self,
endpoint: str,
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
self._client_id = cfg.client_id
self._device_auth_endpoint = cfg.device_authorization_endpoint
self._scope = cfg.scopes
self._token_endpoint = cfg.token_endpoint
if self._device_auth_endpoint is None:
raise AuthenticationError(
"Device Authentication is not available on the Flyte backend / authentication server"
)
super().__init__(
endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint)
)

def refresh_credentials(self):
resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope)
print(
f"""
To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code}
OR copy paste the following URL: {resp.verification_uri_complete}
"""
)
try:
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
except Exception:
KeyringStore.delete(self._endpoint)
raise
8 changes: 8 additions & 0 deletions flytekit/clients/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ class AuthenticationError(RuntimeError):
"""

pass


class AuthenticationPending(RuntimeError):
"""
This is raised if the token endpoint returns authentication pending
"""

pass
1 change: 1 addition & 0 deletions flytekit/clients/auth/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Credentials(object):
access_token: str
refresh_token: str = "na"
for_endpoint: str = "flyte-default"
expires_in: typing.Optional[int] = None


class KeyringStore:
Expand Down
151 changes: 151 additions & 0 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import base64
import enum
import logging
import time
import typing
from dataclasses import dataclass
from datetime import datetime, timedelta

import requests

from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending

utf_8 = "utf-8"

# Errors that Token endpoint will return
error_slow_down = "slow_down"
error_auth_pending = "authorization_pending"


# Grant Types
class GrantType(str, enum.Enum):
CLIENT_CREDS = "client_credentials"
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"


@dataclass
class DeviceCodeResponse:
"""
Response from device auth flow endpoint
{'device_code': 'code',
'user_code': 'BNDJJFXL',
'verification_uri': 'url',
'verification_uri_complete': 'url',
'expires_in': 600,
'interval': 5}
"""

device_code: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_in: int
interval: int

@classmethod
def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
return cls(
device_code=j["device_code"],
user_code=j["user_code"],
verification_uri=j["verification_uri"],
verification_uri_complete=j["verification_uri_complete"],
expires_in=j["expires_in"],
interval=j["interval"],
)


def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
"""
This function transforms the client id and the client secret into a header that conforms with http basic auth.
It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text
:param client_id: str
:param client_secret: str
:rtype: str
"""
concated = "{}:{}".format(client_id, client_secret)
return "Basic {}".format(base64.b64encode(concated.encode(utf_8)).decode(utf_8))


def get_token(
token_endpoint: str,
scopes: typing.Optional[typing.List[str]] = None,
authorization_header: typing.Optional[str] = None,
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
in seconds
"""
headers = {
"Cache-Control": "no-cache",
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
if authorization_header:
headers["Authorization"] = authorization_header
body = {
"grant_type": grant_type.value,
}
if client_id:
body["client_id"] = client_id
if device_code:
body["device_code"] = device_code
if scopes is not None:
body["scope"] = ",".join(scopes)

response = requests.post(token_endpoint, data=body, headers=headers)
if not response.ok:
j = response.json()
if "error" in j:
err = j["error"]
if err == error_auth_pending or err == error_slow_down:
raise AuthenticationPending(f"Token not yet available, try again in some time {err}")
logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))
raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))

j = response.json()
return j["access_token"], j["expires_in"]


def get_device_code(
device_auth_endpoint: str,
client_id: str,
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
separate device
"""
payload = {"client_id": client_id, "scope": scope, "audience": audience}
resp = requests.post(device_auth_endpoint, payload)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
end_time = tick + timedelta(seconds=resp.expires_in)
while tick < end_time:
try:
access_token, expires_in = get_token(
token_endpoint,
grant_type=GrantType.DEVICE_CODE,
client_id=client_id,
device_code=resp.device_code,
)
print("Authentication successful!")
return access_token, expires_in
except AuthenticationPending:
...
except Exception:
raise
print("Authentication Pending...")
time.sleep(interval.total_seconds())
tick = tick + interval
raise AuthenticationError("Authentication failed!")
4 changes: 4 additions & 0 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ClientConfigStore,
ClientCredentialsAuthenticator,
CommandAuthenticator,
DeviceCodeAuthenticator,
PKCEAuthenticator,
)
from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor
Expand Down Expand Up @@ -41,6 +42,7 @@ def get_client_config(self) -> ClientConfig:
client_id=public_client_config.client_id,
scopes=public_client_config.scopes,
header_key=public_client_config.authorization_metadata_key or None,
device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
)


Expand Down Expand Up @@ -79,6 +81,8 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
command=cfg.command,
header_key=client_cfg.header_key if client_cfg else None,
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience)
else:
raise ValueError(
f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
Expand Down
Empty file.
2 changes: 2 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ class AuthType(enum.Enum):
CLIENTSECRET = "ClientSecret"
PKCE = "Pkce"
EXTERNALCOMMAND = "ExternalCommand"
DEVICEFLOW = "DeviceFlow"


@dataclass(init=True, repr=True, eq=True, frozen=True)
Expand Down Expand Up @@ -376,6 +377,7 @@ class PlatformConfig(object):
client_credentials_secret: typing.Optional[str] = None
scopes: List[str] = field(default_factory=list)
auth_mode: AuthType = AuthType.STANDARD
audience: typing.Optional[str] = None
rpc_retries: int = 3

@classmethod
Expand Down
Loading

0 comments on commit 28ce939

Please sign in to comment.