diff --git a/papermill/abs.py b/papermill/abs.py index 700164f5..2e86a0fa 100644 --- a/papermill/abs.py +++ b/papermill/abs.py @@ -3,6 +3,7 @@ import io from azure.storage.blob import BlobServiceClient +from azure.identity import EnvironmentCredential class AzureBlobStore(object): @@ -17,11 +18,13 @@ class AzureBlobStore(object): - write """ - def _blob_service_client(self, account_name, sas_token): + def _blob_service_client(self, account_name, sas_token=None): blob_service_client = BlobServiceClient( - "{account}.blob.core.windows.net".format(account=account_name), sas_token + account_url="{account}.blob.core.windows.net".format(account=account_name), + credential=sas_token or EnvironmentCredential(), ) + return blob_service_client @classmethod @@ -30,7 +33,7 @@ def _split_url(self, url): see: https://docs.microsoft.com/en-us/azure/storage/common/storage-dotnet-shared-access-signature-part-1 # noqa: E501 abs://myaccount.blob.core.windows.net/sascontainer/sasblob.txt?sastoken """ - match = re.match(r"abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/(.*)\?(.*)$", url) + match = re.match(r"abs://(.*)\.blob\.core\.windows\.net\/(.*?)\/([^\?]*)\??(.*)$", url) if not match: raise Exception("Invalid azure blob url '{0}'".format(url)) else: diff --git a/papermill/tests/test_abs.py b/papermill/tests/test_abs.py index d5b2124c..1307e8fc 100644 --- a/papermill/tests/test_abs.py +++ b/papermill/tests/test_abs.py @@ -1,7 +1,8 @@ +import os import unittest from unittest.mock import Mock, patch - +from azure.identity import EnvironmentCredential from ..abs import AzureBlobStore @@ -33,6 +34,9 @@ def setUp(self): ) self.abs = AzureBlobStore() self.abs._blob_service_client = Mock(return_value=self._blob_service_client) + os.environ["AZURE_TENANT_ID"] = "mytenantid" + os.environ["AZURE_CLIENT_ID"] = "myclientid" + os.environ["AZURE_CLIENT_SECRET"] = "myclientsecret" def test_split_url_raises_exception_on_invalid_url(self): with self.assertRaises(Exception) as context: @@ -50,6 +54,15 @@ def test_split_url_splits_valid_url(self): self.assertEqual(params["blob"], "sasblob.txt") self.assertEqual(params["sas_token"], "sastoken") + def test_split_url_splits_valid_url_no_sas(self): + params = AzureBlobStore._split_url( + "abs://myaccount.blob.core.windows.net/container/blob.txt" + ) + self.assertEqual(params["account"], "myaccount") + self.assertEqual(params["container"], "container") + self.assertEqual(params["blob"], "blob.txt") + self.assertEqual(params["sas_token"], "") + def test_split_url_splits_valid_url_with_prefix(self): params = AzureBlobStore._split_url( "abs://myaccount.blob.core.windows.net/sascontainer/A/B/sasblob.txt?sastoken" @@ -97,3 +110,12 @@ def test_blob_service_client(self): self.assertEqual(blob.account_name, "myaccount") # Credentials gets funky with v12.0.0, so I comment this out # self.assertEqual(blob.credential, "sastoken") + + def test_blob_service_client_environment_credentials(self): + abs = AzureBlobStore() + blob = abs._blob_service_client(account_name="myaccount", sas_token="") + self.assertEqual(blob.account_name, "myaccount") + self.assertIsInstance(blob.credential, EnvironmentCredential) + self.assertEqual(blob.credential._credential._tenant_id, "mytenantid") + self.assertEqual(blob.credential._credential._client_id, "myclientid") + self.assertEqual(blob.credential._credential._client_credential, "myclientsecret") diff --git a/requirements/azure.txt b/requirements/azure.txt index c2242ad1..82f60816 100644 --- a/requirements/azure.txt +++ b/requirements/azure.txt @@ -1,3 +1,4 @@ azure-datalake-store >= 0.0.30 azure-storage-blob >= 12.1.0 requests >= 2.21.0 +azure-identity>=1.3.1 \ No newline at end of file