Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for acquiring a token with a pre-signed JWT #271

Merged
merged 7 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,14 @@ def __init__(
"The provided signature value did not match the expected signature value",
you may try use only the leaf cert (in PEM/str format) instead.

*Added in version 1.13.0*:
It can also be a completly pre-signed assertion that you've assembled yourself.
Simply pass a container containing only the key "client_assertion", like this:
lochiiconnectivity marked this conversation as resolved.
Show resolved Hide resolved

{
"client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
}

:param dict client_claims:
*Added in version 0.5.0*:
It is a dictionary of extra claims that would be signed by
Expand Down Expand Up @@ -256,28 +264,32 @@ def _build_client(self, client_credential, authority):
default_headers['x-app-ver'] = self.app_version
default_body = {"client_info": 1}
if isinstance(client_credential, dict):
assert ("private_key" in client_credential
and "thumbprint" in client_credential)
headers = {}
if 'public_certificate' in client_credential:
headers["x5c"] = extract_certs(client_credential['public_certificate'])
if not client_credential.get("passphrase"):
unencrypted_private_key = client_credential['private_key']
else:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
unencrypted_private_key = serialization.load_pem_private_key(
_str2bytes(client_credential["private_key"]),
_str2bytes(client_credential["passphrase"]),
backend=default_backend(), # It was a required param until 2020
)
assertion = JwtAssertionCreator(
unencrypted_private_key, algorithm="RS256",
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
client_assertion = assertion.create_regenerative_assertion(
audience=authority.token_endpoint, issuer=self.client_id,
additional_claims=self.client_claims or {})
assert (("private_key" in client_credential
and "thumbprint" in client_credential) or
"client_assertion" in client_credential)
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
if 'client_assertion' in client_credential:
client_assertion = client_credential['client_assertion']
else:
headers = {}
if 'public_certificate' in client_credential:
headers["x5c"] = extract_certs(client_credential['public_certificate'])
if not client_credential.get("passphrase"):
unencrypted_private_key = client_credential['private_key']
else:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
unencrypted_private_key = serialization.load_pem_private_key(
_str2bytes(client_credential["private_key"]),
_str2bytes(client_credential["passphrase"]),
backend=default_backend(), # It was a required param until 2020
)
assertion = JwtAssertionCreator(
unencrypted_private_key, algorithm="RS256",
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
client_assertion = assertion.create_regenerative_assertion(
audience=authority.token_endpoint, issuer=self.client_id,
additional_claims=self.client_claims or {})
else:
default_body['client_secret'] = client_credential
server_configuration = {
Expand Down
134 changes: 134 additions & 0 deletions sample/vault_jwt_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
The configuration file would look like this (sans those // comments):
{
"tenant": "your_tenant_name",
// Your target tenant, DNS name
"client_id": "your_client_id",
// Target app ID in Azure AD
"scope": ["https://graph.microsoft.com/.default"],
// Specific to Client Credentials Grant i.e. acquire_token_for_client(),
// you don't specify, in the code, the individual scopes you want to access.
// Instead, you statically declared them when registering your application.
// Therefore the only possible scope is "resource/.default"
// (here "https://graph.microsoft.com/.default")
// which means "the static permissions defined in the application".
"vault_tenant": "your_vault_tenant_name",
// Your Vault tenant may be different to your target tenant
// If that's not the case, you can set this to the same
// as "tenant"
"vault_clientid": "your_vault_client_id",
// Client ID of your vault app in your vault tenant
"vault_clientsecret": "your_vault_client_secret",
// Secret for your vault app
"vault_url": "your_vault_url",
// URL of your vault app
"cert": "your_cert_name",
// Name of your certificate in your vault
"cert_thumb": "your_cert_thumbprint",
// Thumbprint of your certificate
"endpoint": "https://graph.microsoft.com/v1.0/users"
// For this resource to work, you need to visit Application Permissions
// page in portal, declare scope User.Read.All, which needs admin consent
// https://github.com/Azure-Samples/ms-identity-python-daemon/blob/master/2-Call-MsGraph-WithCertificate/README.md
}
You can then run this sample with a JSON configuration file:
python sample.py parameters.json
"""

import base64
import json
import logging
import requests
import sys
import time
import uuid
import msal

# Optional logging
# logging.basicConfig(level=logging.DEBUG) # Enable DEBUG log for entire script
# logging.getLogger("msal").setLevel(logging.INFO) # Optionally disable MSAL DEBUG logs

from azure.keyvault import KeyVaultClient, KeyVaultAuthentication
from azure.common.credentials import ServicePrincipalCredentials
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes

config = json.load(open(sys.argv[1]))

def auth_vault_callback(server, resource, scope):
credentials = ServicePrincipalCredentials(
client_id=config['vault_clientid'],
secret=config['vault_clientsecret'],
tenant=config['vault_tenant'],
resource='https://vault.azure.net'
)
token = credentials.token
return token['token_type'], token['access_token']


def make_vault_jwt():

header = {
'alg': 'RS256',
'typ': 'JWT',
'x5t': base64.b64encode(
config['cert_thumb'].decode('hex'))
}
header_b64 = base64.b64encode(json.dumps(header).encode('utf-8'))

body = {
'aud': "https://login.microsoftonline.com/%s/oauth2/token" %
config['tenant'],
'exp': (int(time.time()) + 600),
'iss': config['client_id'],
'jti': str(uuid.uuid4()),
'nbf': int(time.time()),
'sub': config['client_id']
}
body_b64 = base64.b64encode(json.dumps(body).encode('utf-8'))

full_b64 = b'.'.join([header_b64, body_b64])

client = KeyVaultClient(KeyVaultAuthentication(auth_vault_callback))
chosen_hash = hashes.SHA256()
hasher = hashes.Hash(chosen_hash, default_backend())
hasher.update(full_b64)
digest = hasher.finalize()
signed_digest = client.sign(config['vault_url'],
config['cert'], '', 'RS256',
digest).result

full_token = b'.'.join([full_b64, base64.b64encode(signed_digest)])

return full_token


authority = "https://login.microsoftonline.com/%s" % config['tenant']

app = msal.ConfidentialClientApplication(
config['client_id'], authority=authority, client_credential={"client_assertion": make_vault_jwt()}
)

# The pattern to acquire a token looks like this.
result = None

# Firstly, looks up a token from cache
# Since we are looking for token for the current app, NOT for an end user,
# notice we give account parameter as None.
result = app.acquire_token_silent(config["scope"], account=None)

if not result:
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
result = app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
lochiiconnectivity marked this conversation as resolved.
Show resolved Hide resolved
config["endpoint"],
headers={'Authorization': 'Bearer ' + result['access_token']},).json()
print("Graph API call result: %s" % json.dumps(graph_data, indent=2))
else:
print(result.get("error"))
print(result.get("error_description"))
print(result.get("correlation_id")) # You may need this when reporting a bug

10 changes: 9 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,15 @@ class TestClient(Oauth2TestCase):
@classmethod
def setUpClass(cls):
http_client = MinimalHttpClient()
if "client_certificate" in CONFIG:
if "client_assertion" in CONFIG:
cls.client = Client(
CONFIG["openid_configuration"],
CONFIG['client_id'],
http_client=http_client,
client_assertion=CONFIG["client_assertion"],
client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT,
)
elif "client_certificate" in CONFIG:
private_key_path = CONFIG["client_certificate"]["private_key_path"]
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
private_key = f.read() # Expecting PEM format
Expand Down
10 changes: 10 additions & 0 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,16 @@ def test_subject_name_issuer_authentication(self):
self.assertIn('access_token', result)
self.assertCacheWorksForApp(result, scope)

def test_pre_signed_jwt_authentication(self):
lochiiconnectivity marked this conversation as resolved.
Show resolved Hide resolved
self.skipUnlessWithConfig(["client_id", "client_assertion"])
self.app = msal.ConfidentialClientApplication(
self.config['client_id'], authority=self.config["authority"],
client_credential={"client_assertion": self.config["client_assertion"]},
http_client=MinimalHttpClient())
scope = self.config.get("scope", [])
result = self.app.acquire_token_for_client(scope)
self.assertIn('access_token', result)
self.assertCacheWorksForApp(result, scope)

@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
class DeviceFlowTestCase(E2eTestCase): # A leaf class so it will be run only once
Expand Down