-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Cosmos] AAD authentication async client (#23717)
* working authentication to get database account * working aad authentication for sync client with sample * readme and changelog * pylint and better comments on sample * working async aad * Delete access_cosmos_with_aad.py snuck its way into the async PR * Update _auth_policies.py * small changes * Update _cosmos_client_connection.py * removing changes made in sync * Update _auth_policy_async.py * Update _auth_policy_async.py * Update _auth_policy_async.py * added licenses to samples
- Loading branch information
Showing
23 changed files
with
392 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
175 changes: 175 additions & 0 deletions
175
sdk/cosmos/azure-cosmos/azure/cosmos/aio/_auth_policy_async.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See LICENSE.txt in the project root for | ||
# license information. | ||
# ------------------------------------------------------------------------- | ||
import asyncio | ||
import time | ||
|
||
from typing import Any, Awaitable, Optional, Dict, Union | ||
from azure.core.pipeline.policies import AsyncHTTPPolicy | ||
from azure.core.credentials import AccessToken | ||
from azure.core.pipeline import PipelineRequest, PipelineResponse | ||
from azure.cosmos import http_constants | ||
|
||
|
||
async def await_result(func, *args, **kwargs): | ||
"""If func returns an awaitable, await it.""" | ||
result = func(*args, **kwargs) | ||
if hasattr(result, "__await__"): | ||
# type ignore on await: https://github.com/python/mypy/issues/7587 | ||
return await result # type: ignore | ||
return result | ||
|
||
|
||
class _AsyncCosmosBearerTokenCredentialPolicyBase(object): | ||
"""Base class for a Bearer Token Credential Policy. | ||
:param credential: The credential. | ||
:type credential: ~azure.core.credentials.TokenCredential | ||
:param str scopes: Lets you specify the type of access needed. | ||
""" | ||
|
||
def __init__(self, credential, *scopes, **kwargs): # pylint:disable=unused-argument | ||
# type: (TokenCredential, *str, **Any) -> None | ||
super(_AsyncCosmosBearerTokenCredentialPolicyBase, self).__init__() | ||
self._scopes = scopes | ||
self._credential = credential | ||
self._token = None # type: Optional[AccessToken] | ||
self._lock = asyncio.Lock() | ||
|
||
@staticmethod | ||
def _enforce_https(request): | ||
# type: (PipelineRequest) -> None | ||
|
||
# move 'enforce_https' from options to context so it persists | ||
# across retries but isn't passed to a transport implementation | ||
option = request.context.options.pop("enforce_https", None) | ||
|
||
# True is the default setting; we needn't preserve an explicit opt in to the default behavior | ||
if option is False: | ||
request.context["enforce_https"] = option | ||
|
||
enforce_https = request.context.get("enforce_https", True) | ||
if enforce_https and not request.http_request.url.lower().startswith("https"): | ||
raise ValueError( | ||
"Bearer token authentication is not permitted for non-TLS protected (non-https) URLs." | ||
) | ||
|
||
@staticmethod | ||
def _update_headers(headers, token): | ||
# type: (Dict[str, str], str) -> None | ||
"""Updates the Authorization header with the cosmos signature and bearer token. | ||
This is the main method that differentiates this policy from core's BearerTokenCredentialPolicy and works | ||
to properly sign the authorization header for Cosmos' REST API. For more information: | ||
https://docs.microsoft.com/rest/api/cosmos-db/access-control-on-cosmosdb-resources#authorization-header | ||
:param dict headers: The HTTP Request headers | ||
:param str token: The OAuth token. | ||
""" | ||
headers[http_constants.HttpHeaders.Authorization] = "type=aad&ver=1.0&sig={}".format(token) | ||
|
||
@property | ||
def _need_new_token(self) -> bool: | ||
return not self._token or self._token.expires_on - time.time() < 300 | ||
|
||
|
||
class AsyncCosmosBearerTokenCredentialPolicy(_AsyncCosmosBearerTokenCredentialPolicyBase, AsyncHTTPPolicy): | ||
"""Adds a bearer token Authorization header to requests. | ||
:param credential: The credential. | ||
:type credential: ~azure.core.TokenCredential | ||
:param str scopes: Lets you specify the type of access needed. | ||
:raises ValueError: If https_enforce does not match with endpoint being used. | ||
""" | ||
|
||
async def on_request(self, request: "PipelineRequest") -> None: # pylint:disable=invalid-overridden-method | ||
"""Adds a bearer token Authorization header to request and sends request to next policy. | ||
:param request: The pipeline request object to be modified. | ||
:type request: ~azure.core.pipeline.PipelineRequest | ||
:raises: :class:`~azure.core.exceptions.ServiceRequestError` | ||
""" | ||
self._enforce_https(request) # pylint:disable=protected-access | ||
|
||
if self._token is None or self._need_new_token: | ||
async with self._lock: | ||
# double check because another coroutine may have acquired a token while we waited to acquire the lock | ||
if self._token is None or self._need_new_token: | ||
self._token = await self._credential.get_token(*self._scopes) | ||
self._update_headers(request.http_request.headers, self._token.token) | ||
|
||
async def authorize_request(self, request: "PipelineRequest", *scopes: str, **kwargs: "Any") -> None: | ||
"""Acquire a token from the credential and authorize the request with it. | ||
Keyword arguments are passed to the credential's get_token method. The token will be cached and used to | ||
authorize future requests. | ||
:param ~azure.core.pipeline.PipelineRequest request: the request | ||
:param str scopes: required scopes of authentication | ||
""" | ||
async with self._lock: | ||
self._token = await self._credential.get_token(*scopes, **kwargs) | ||
self._update_headers(request.http_request.headers, self._token.token) | ||
|
||
async def send(self, request: "PipelineRequest") -> "PipelineResponse": | ||
"""Authorize request with a bearer token and send it to the next policy | ||
:param request: The pipeline request object | ||
:type request: ~azure.core.pipeline.PipelineRequest | ||
""" | ||
await await_result(self.on_request, request) | ||
try: | ||
response = await self.next.send(request) | ||
await await_result(self.on_response, request, response) | ||
except Exception: # pylint:disable=broad-except | ||
handled = await await_result(self.on_exception, request) | ||
if not handled: | ||
raise | ||
else: | ||
if response.http_response.status_code == 401: | ||
self._token = None # any cached token is invalid | ||
if "WWW-Authenticate" in response.http_response.headers: | ||
request_authorized = await self.on_challenge(request, response) | ||
if request_authorized: | ||
try: | ||
response = await self.next.send(request) | ||
await await_result(self.on_response, request, response) | ||
except Exception: # pylint:disable=broad-except | ||
handled = await await_result(self.on_exception, request) | ||
if not handled: | ||
raise | ||
|
||
return response | ||
|
||
async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool: | ||
"""Authorize request according to an authentication challenge | ||
This method is called when the resource provider responds 401 with a WWW-Authenticate header. | ||
:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge | ||
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response | ||
:returns: a bool indicating whether the policy should send the request | ||
""" | ||
# pylint:disable=unused-argument,no-self-use | ||
return False | ||
|
||
def on_response(self, request: PipelineRequest, response: PipelineResponse) -> Union[None, Awaitable[None]]: | ||
"""Executed after the request comes back from the next policy. | ||
:param request: Request to be modified after returning from the policy. | ||
:type request: ~azure.core.pipeline.PipelineRequest | ||
:param response: Pipeline response object | ||
:type response: ~azure.core.pipeline.PipelineResponse | ||
""" | ||
|
||
def on_exception(self, request: PipelineRequest) -> None: | ||
"""Executed when an exception is raised while executing the next policy. | ||
This method is executed inside the exception handler. | ||
:param request: The Pipeline request object | ||
:type request: ~azure.core.pipeline.PipelineRequest | ||
""" | ||
# pylint: disable=no-self-use,unused-argument | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
116 changes: 116 additions & 0 deletions
116
sdk/cosmos/azure-cosmos/samples/access_cosmos_with_aad_async.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See LICENSE.txt in the project root for | ||
# license information. | ||
# ------------------------------------------------------------------------- | ||
from azure.cosmos.aio import CosmosClient | ||
import azure.cosmos.exceptions as exceptions | ||
from azure.cosmos.partition_key import PartitionKey | ||
from azure.identity.aio import ClientSecretCredential, DefaultAzureCredential | ||
import config | ||
import asyncio | ||
|
||
# ---------------------------------------------------------------------------------------------------------- | ||
# Prerequistes - | ||
# | ||
# 1. An Azure Cosmos account - | ||
# https://docs.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account | ||
# | ||
# 2. Microsoft Azure Cosmos | ||
# pip install azure-cosmos>=4.3.0b4 | ||
# ---------------------------------------------------------------------------------------------------------- | ||
# Sample - demonstrates how to authenticate and use your database account using AAD credentials | ||
# Read more about operations allowed for this authorization method: https://aka.ms/cosmos-native-rbac | ||
# ---------------------------------------------------------------------------------------------------------- | ||
# Note: | ||
# This sample creates a Container to your database account. | ||
# Each time a Container is created the account will be billed for 1 hour of usage based on | ||
# the provisioned throughput (RU/s) of that account. | ||
# ---------------------------------------------------------------------------------------------------------- | ||
# <configureConnectivity> | ||
HOST = config.settings["host"] | ||
MASTER_KEY = config.settings["master_key"] | ||
|
||
TENANT_ID = config.settings["tenant_id"] | ||
CLIENT_ID = config.settings["client_id"] | ||
CLIENT_SECRET = config.settings["client_secret"] | ||
|
||
DATABASE_ID = config.settings["database_id"] | ||
CONTAINER_ID = config.settings["container_id"] | ||
PARTITION_KEY = PartitionKey(path="/id") | ||
|
||
|
||
def get_test_item(num): | ||
test_item = { | ||
'id': 'Item_' + str(num), | ||
'test_object': True, | ||
'lastName': 'Smith' | ||
} | ||
return test_item | ||
|
||
|
||
async def create_sample_resources(): | ||
print("creating sample resources") | ||
async with CosmosClient(HOST, MASTER_KEY) as client: | ||
db = await client.create_database(DATABASE_ID) | ||
await db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) | ||
|
||
|
||
async def delete_sample_resources(): | ||
print("deleting sample resources") | ||
async with CosmosClient(HOST, MASTER_KEY) as client: | ||
await client.delete_database(DATABASE_ID) | ||
|
||
|
||
async def run_sample(): | ||
# Since Azure Cosmos DB data plane SDK does not cover management operations, we have to create our resources | ||
# with a master key authenticated client for this sample. | ||
await create_sample_resources() | ||
|
||
# With this done, you can use your AAD service principal id and secret to create your ClientSecretCredential. | ||
# The async ClientSecretCredentials, like the async client, also have a context manager, | ||
# and as such should be used with the `async with` keywords. | ||
async with ClientSecretCredential( | ||
tenant_id=TENANT_ID, | ||
client_id=CLIENT_ID, | ||
client_secret=CLIENT_SECRET) as aad_credentials: | ||
|
||
# Use your credentials to authenticate your client. | ||
async with CosmosClient(HOST, aad_credentials) as aad_client: | ||
print("Showed ClientSecretCredential, now showing DefaultAzureCredential") | ||
|
||
# You can also utilize DefaultAzureCredential rather than directly passing in the id's and secrets. | ||
# This is the recommended method of authentication, and uses environment variables rather than in-code strings. | ||
async with DefaultAzureCredential() as aad_credentials: | ||
|
||
# Use your credentials to authenticate your client. | ||
async with CosmosClient(HOST, aad_credentials) as aad_client: | ||
|
||
# Do any R/W data operations with your authorized AAD client. | ||
db = aad_client.get_database_client(DATABASE_ID) | ||
container = db.get_container_client(CONTAINER_ID) | ||
|
||
print("Container info: " + str(container.read())) | ||
await container.create_item(get_test_item(879)) | ||
print("Point read result: " + str(container.read_item(item='Item_0', partition_key='Item_0'))) | ||
query_results = [item async for item in | ||
container.query_items(query='select * from c', partition_key='Item_0')] | ||
assert len(query_results) == 1 | ||
print("Query result: " + str(query_results[0])) | ||
await container.delete_item(item='Item_0', partition_key='Item_0') | ||
|
||
# Attempting to do management operations will return a 403 Forbidden exception. | ||
try: | ||
await aad_client.delete_database(DATABASE_ID) | ||
except exceptions.CosmosHttpResponseError as e: | ||
assert e.status_code == 403 | ||
print("403 error assertion success") | ||
|
||
# To clean up the sample, we use a master key client again to get access to deleting containers/ databases. | ||
await delete_sample_resources() | ||
print("end of sample") | ||
|
||
|
||
if __name__ == "__main__": | ||
loop = asyncio.get_event_loop() | ||
loop.run_until_complete(run_sample()) |
5 changes: 5 additions & 0 deletions
5
sdk/cosmos/azure-cosmos/samples/access_cosmos_with_resource_token.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
sdk/cosmos/azure-cosmos/samples/access_cosmos_with_resource_token_async.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
5 changes: 5 additions & 0 deletions
5
sdk/cosmos/azure-cosmos/samples/change_feed_management_async.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.