Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Device auth flow / Headless auth #1552

Merged
merged 10 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flytekit/clients/auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,11 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials:
raise ValueError('Expected "access_token" in response from oauth server')
if "refresh_token" in response_body:
refresh_token = response_body["refresh_token"]
if "expires_in" in response_body:
expires_in = response_body["expires_in"]
access_token = response_body["access_token"]

return Credentials(access_token, refresh_token, self._endpoint)
return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in)

def _request_access_token(self, auth_code) -> Credentials:
if self._state != auth_code.state:
Expand Down
102 changes: 56 additions & 46 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import base64
import json
import logging
import subprocess
import time
import typing
from abc import abstractmethod
from dataclasses import dataclass
from datetime import datetime, timedelta

import requests

from . import token_client
from .auth_client import AuthorizationClient
from .exceptions import AccessTokenNotFoundError, AuthenticationError
from .keyring import Credentials, KeyringStore
Expand All @@ -22,6 +26,7 @@ class ClientConfig:
authorization_endpoint: str
redirect_uri: str
client_id: str
device_authorization_endpoint: typing.Optional[str] = None
scopes: typing.List[str] = None
header_key: str = "authorization"

Expand Down Expand Up @@ -155,8 +160,6 @@ class ClientCredentialsAuthenticator(Authenticator):
This Authenticator uses ClientId and ClientSecret to authenticate
"""

_utf_8 = "utf-8"

def __init__(
self,
endpoint: str,
Expand All @@ -174,48 +177,6 @@ def __init__(
self._client_secret = client_secret
super().__init__(endpoint, cfg.header_key or header_key)

@staticmethod
def get_token(token_endpoint: str, authorization_header: str, scopes: typing.List[str]) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
in seconds
"""
headers = {
"Authorization": authorization_header,
"Cache-Control": "no-cache",
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
body = {
"grant_type": "client_credentials",
}
if scopes is not None:
body["scope"] = ",".join(scopes)
response = requests.post(token_endpoint, data=body, headers=headers)
if response.status_code != 200:
logging.error("Non-200 ({}) received from IDP: {}".format(response.status_code, response.text))
raise AuthenticationError("Non-200 received from IDP")

response = response.json()
return response["access_token"], response["expires_in"]

@staticmethod
def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
"""
This function transforms the client id and the client secret into a header that conforms with http basic auth.
It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text

:param client_id: str
:param client_secret: str
:rtype: str
"""
concated = "{}:{}".format(client_id, client_secret)
return "Basic {}".format(
base64.b64encode(concated.encode(ClientCredentialsAuthenticator._utf_8)).decode(
ClientCredentialsAuthenticator._utf_8
)
)

def refresh_credentials(self):
"""
This function is used by the _handle_rpc_error() decorator, depending on the AUTH_MODE config object. This handler
Expand All @@ -229,7 +190,56 @@ def refresh_credentials(self):

# Note that unlike the Pkce flow, the client ID does not come from Admin.
logging.debug(f"Basic authorization flow with client id {self._client_id} scope {scopes}")
authorization_header = self.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = self.get_token(token_endpoint, authorization_header, scopes)
authorization_header = token_client.get_basic_authorization_header(self._client_id, self._client_secret)
token, expires_in = token_client.get_token(token_endpoint, scopes, authorization_header)
logging.info("Retrieved new token, expires in {}".format(expires_in))
self._creds = Credentials(token)


class DeviceCodeAuthenticator(Authenticator):
"""
This Authenticator implements the Device Code authorization flow useful for headless user authentication.

Examples described
- https://developer.okta.com/docs/guides/device-authorization-grant/main/
- https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow#device-flow
"""

def __init__(
self,
endpoint: str,
cfg_store: ClientConfigStore,
header_key: typing.Optional[str] = None,
audience: typing.Optional[str] = None,
):
self._audience = audience
cfg = cfg_store.get_client_config()
self._client_id = cfg.client_id
self._device_auth_endpoint = cfg.device_authorization_endpoint
self._scope = cfg.scopes
self._token_endpoint = cfg.token_endpoint
if self._device_auth_endpoint is None:
raise AuthenticationError(
"Device Authentication is not available on the Flyte backend / authentication server"
)
super().__init__(
endpoint=endpoint, header_key=header_key or cfg.header_key, credentials=KeyringStore.retrieve(endpoint)
)

def refresh_credentials(self):
resp = token_client.get_device_code(self._device_auth_endpoint, self._client_id, self._audience, self._scope)
print(
f"""
To Authenticate navigate in a browser to the following URL: {resp.verification_uri} and enter code: {resp.user_code}
OR copy paste the following URL: {resp.verification_uri_complete}
"""
)
try:
# Currently the refresh token is not retreived. We may want to add support for refreshTokens so that
# access tokens can be refreshed for once authenticated machines
token, expires_in = token_client.poll_token_endpoint(resp, self._token_endpoint, client_id=self._client_id)
self._creds = Credentials(access_token=token, expires_in=expires_in, for_endpoint=self._endpoint)
KeyringStore.store(self._creds)
except Exception:
KeyringStore.delete(self._endpoint)
raise
8 changes: 8 additions & 0 deletions flytekit/clients/auth/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,11 @@ class AuthenticationError(RuntimeError):
"""

pass


class AuthenticationPending(RuntimeError):
"""
This is raised if the token endpoint returns authentication pending
"""

pass
1 change: 1 addition & 0 deletions flytekit/clients/auth/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Credentials(object):
access_token: str
refresh_token: str = "na"
for_endpoint: str = "flyte-default"
expires_in: typing.Optional[int] = None


class KeyringStore:
Expand Down
151 changes: 151 additions & 0 deletions flytekit/clients/auth/token_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import base64
import enum
import logging
import time
import typing
from dataclasses import dataclass
from datetime import datetime, timedelta

import requests

from flytekit.clients.auth.exceptions import AuthenticationError, AuthenticationPending

utf_8 = "utf-8"

# Errors that Token endpoint will return
error_slow_down = "slow_down"
error_auth_pending = "authorization_pending"


# Grant Types
class GrantType(str, enum.Enum):
CLIENT_CREDS = "client_credentials"
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"


@dataclass
class DeviceCodeResponse:
"""
Response from device auth flow endpoint
{'device_code': 'code',
'user_code': 'BNDJJFXL',
'verification_uri': 'url',
'verification_uri_complete': 'url',
'expires_in': 600,
'interval': 5}
"""

device_code: str
user_code: str
verification_uri: str
verification_uri_complete: str
expires_in: int
interval: int

@classmethod
def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
return cls(
device_code=j["device_code"],
user_code=j["user_code"],
verification_uri=j["verification_uri"],
verification_uri_complete=j["verification_uri_complete"],
expires_in=j["expires_in"],
interval=j["interval"],
)


def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
"""
This function transforms the client id and the client secret into a header that conforms with http basic auth.
It joins the id and the secret with a : then base64 encodes it, then adds the appropriate text

:param client_id: str
:param client_secret: str
:rtype: str
"""
concated = "{}:{}".format(client_id, client_secret)
return "Basic {}".format(base64.b64encode(concated.encode(utf_8)).decode(utf_8))


def get_token(
token_endpoint: str,
scopes: typing.Optional[typing.List[str]] = None,
authorization_header: typing.Optional[str] = None,
client_id: typing.Optional[str] = None,
device_code: typing.Optional[str] = None,
grant_type: GrantType = GrantType.CLIENT_CREDS,
) -> typing.Tuple[str, int]:
"""
:rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration
in seconds
"""
headers = {
"Cache-Control": "no-cache",
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
if authorization_header:
headers["Authorization"] = authorization_header
body = {
"grant_type": grant_type.value,
}
if client_id:
body["client_id"] = client_id
if device_code:
body["device_code"] = device_code
if scopes is not None:
body["scope"] = ",".join(scopes)

response = requests.post(token_endpoint, data=body, headers=headers)
if not response.ok:
j = response.json()
if "error" in j:
err = j["error"]
if err == error_auth_pending or err == error_slow_down:
raise AuthenticationPending(f"Token not yet available, try again in some time {err}")
logging.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))
raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))

j = response.json()
return j["access_token"], j["expires_in"]


def get_device_code(
device_auth_endpoint: str,
client_id: str,
audience: typing.Optional[str] = None,
scope: typing.Optional[typing.List[str]] = None,
) -> DeviceCodeResponse:
"""
Retrieves the device Authentication code that can be done to authenticate the request using a browser on a
separate device
"""
payload = {"client_id": client_id, "scope": scope, "audience": audience}
resp = requests.post(device_auth_endpoint, payload)
if not resp.ok:
raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}")
return DeviceCodeResponse.from_json_response(resp.json())


def poll_token_endpoint(resp: DeviceCodeResponse, token_endpoint: str, client_id: str) -> typing.Tuple[str, int]:
tick = datetime.now()
interval = timedelta(seconds=resp.interval)
end_time = tick + timedelta(seconds=resp.expires_in)
while tick < end_time:
try:
access_token, expires_in = get_token(
token_endpoint,
grant_type=GrantType.DEVICE_CODE,
client_id=client_id,
device_code=resp.device_code,
)
print("Authentication successful!")
return access_token, expires_in
except AuthenticationPending:
...
except Exception:
raise
print(f"Authentication Pending...")
time.sleep(interval.total_seconds())
tick = tick + interval
raise AuthenticationError("Authentication failed!")
4 changes: 4 additions & 0 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ClientConfigStore,
ClientCredentialsAuthenticator,
CommandAuthenticator,
DeviceCodeAuthenticator,
PKCEAuthenticator,
)
from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor
Expand Down Expand Up @@ -41,6 +42,7 @@ def get_client_config(self) -> ClientConfig:
client_id=public_client_config.client_id,
scopes=public_client_config.scopes,
header_key=public_client_config.authorization_metadata_key or None,
device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
)


Expand Down Expand Up @@ -78,6 +80,8 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth
command=cfg.command,
header_key=client_cfg.header_key if client_cfg else None,
)
elif cfg_auth == AuthType.DEVICEFLOW:
return DeviceCodeAuthenticator(endpoint=cfg.endpoint, cfg_store=cfg_store, audience=cfg.audience)
else:
raise ValueError(
f"Invalid auth mode [{cfg_auth}] specified." f"Please update the creds config to use a valid value"
Expand Down
Empty file.
2 changes: 2 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ class AuthType(enum.Enum):
CLIENTSECRET = "ClientSecret"
PKCE = "Pkce"
EXTERNALCOMMAND = "ExternalCommand"
DEVICEFLOW = "DeviceFlow"


@dataclass(init=True, repr=True, eq=True, frozen=True)
Expand Down Expand Up @@ -376,6 +377,7 @@ class PlatformConfig(object):
client_credentials_secret: typing.Optional[str] = None
scopes: List[str] = field(default_factory=list)
auth_mode: AuthType = AuthType.STANDARD
audience: typing.Optional[str] = None
rpc_retries: int = 3

@classmethod
Expand Down
20 changes: 4 additions & 16 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ def test_command_authenticator(mock_subprocess: MagicMock):
authn.refresh_credentials()


def test_get_basic_authorization_header():
header = ClientCredentialsAuthenticator.get_basic_authorization_header("client_id", "abc")
assert header == "Basic Y2xpZW50X2lkOmFiYw=="


@patch("flytekit.clients.auth.authenticator.requests")
def test_get_token(mock_requests):
response = MagicMock()
response.status_code = 200
response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""")
mock_requests.post.return_value = response
access, expiration = ClientCredentialsAuthenticator.get_token("https://corp.idp.net", "abc123", ["my_scope"])
assert access == "abc"
assert expiration == 60


@patch("flytekit.clients.auth.authenticator.requests")
def test_client_creds_authenticator(mock_requests):
authn = ClientCredentialsAuthenticator(
Expand All @@ -93,3 +77,7 @@ def test_client_creds_authenticator(mock_requests):
mock_requests.post.return_value = response
authn.refresh_credentials()
assert authn._creds


def test_device_flow_authenticator():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can delete

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added test

pass
Loading