diff --git a/.github/workflows/Linux.yml b/.github/workflows/Linux.yml index 2645f32..028de6c 100644 --- a/.github/workflows/Linux.yml +++ b/.github/workflows/Linux.yml @@ -105,7 +105,7 @@ jobs: - name: Setup vcpkg uses: lukka/run-vcpkg@v11 with: - vcpkgGitCommitId: 501db0f17ef6df184fcdbfbe0f87cde2313b6ab1 + vcpkgGitCommitId: 9edb1b8e590cc086563301d735cae4b6e732d2d2 # Build extension - name: Build extension diff --git a/.github/workflows/MacOS.yml b/.github/workflows/MacOS.yml index 46556ed..44040d9 100644 --- a/.github/workflows/MacOS.yml +++ b/.github/workflows/MacOS.yml @@ -54,7 +54,7 @@ jobs: - name: Setup vcpkg uses: lukka/run-vcpkg@v11 with: - vcpkgGitCommitId: 501db0f17ef6df184fcdbfbe0f87cde2313b6ab1 + vcpkgGitCommitId: 9edb1b8e590cc086563301d735cae4b6e732d2d2 - name: Build extension shell: bash diff --git a/.github/workflows/MainDistributionPipeline.yml b/.github/workflows/MainDistributionPipeline.yml index 21ecf2b..a416ee3 100644 --- a/.github/workflows/MainDistributionPipeline.yml +++ b/.github/workflows/MainDistributionPipeline.yml @@ -18,6 +18,7 @@ jobs: with: duckdb_version: v0.9.1 extension_name: azure + vcpkg_commit: 9edb1b8e590cc086563301d735cae4b6e732d2d2 # TODO: remove pinned vcpkg commit when updating duckdb version duckdb-stable-deploy: name: Deploy extension binaries diff --git a/.github/workflows/Windows.yml b/.github/workflows/Windows.yml index 1791987..82f1d82 100644 --- a/.github/workflows/Windows.yml +++ b/.github/workflows/Windows.yml @@ -33,7 +33,7 @@ jobs: - name: Setup vcpkg uses: lukka/run-vcpkg@v11 with: - vcpkgGitCommitId: 501db0f17ef6df184fcdbfbe0f87cde2313b6ab1 + vcpkgGitCommitId: 9edb1b8e590cc086563301d735cae4b6e732d2d2 - uses: actions/setup-python@v2 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d8694d..e0d8eab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,12 +20,17 @@ find_package(azure-identity-cpp CONFIG REQUIRED) find_package(azure-storage-blobs-cpp CONFIG REQUIRED) # Static lib 
-target_link_libraries(${EXTENSION_NAME} Azure::azure-identity Azure::azure-storage-blobs) -target_include_directories(${EXTENSION_NAME} PRIVATE Azure::azure-identity Azure::azure-storage-blobs) +target_link_libraries(${EXTENSION_NAME} Azure::azure-identity + Azure::azure-storage-blobs) +target_include_directories(${EXTENSION_NAME} PRIVATE Azure::azure-identity + Azure::azure-storage-blobs) # Loadable binary -target_link_libraries(${TARGET_NAME}_loadable_extension Azure::azure-identity Azure::azure-storage-blobs) -target_include_directories(${TARGET_NAME}_loadable_extension PRIVATE Azure::azure-identity Azure::azure-storage-blobs) +target_link_libraries(${TARGET_NAME}_loadable_extension Azure::azure-identity + Azure::azure-storage-blobs) +target_include_directories( + ${TARGET_NAME}_loadable_extension PRIVATE Azure::azure-identity + Azure::azure-storage-blobs) install( TARGETS ${EXTENSION_NAME} diff --git a/src/azure_extension.cpp b/src/azure_extension.cpp index b07bef9..d9fc34f 100644 --- a/src/azure_extension.cpp +++ b/src/azure_extension.cpp @@ -1,74 +1,123 @@ #define DUCKDB_EXTENSION_MAIN #include "azure_extension.hpp" + #include "duckdb.hpp" #include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/scalar_function.hpp" #include "duckdb/common/file_opener.hpp" +#include "duckdb/common/string_util.hpp" #include "duckdb/function/scalar/string_functions.hpp" - +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/main/extension_util.hpp" +#include +#include +#include +#include +#include +#include #include - #include -#include - namespace duckdb { -BlobClientWrapper::BlobClientWrapper(AzureAuthentication auth, const string& path) { - auto container_client = Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString(auth.connection_string, auth.container); - blob_client = make_uniq(container_client.GetBlockBlobClient(path)); +static Azure::Identity::ChainedTokenCredential::Sources 
+CreateCredentialChainFromSetting(const string &credential_chain) { + auto chain_list = StringUtil::Split(credential_chain, ';'); + Azure::Identity::ChainedTokenCredential::Sources result; + + for (const auto &item : chain_list) { + if (item == "cli") { + result.push_back(std::make_shared()); + } else if (item == "managed_identity") { + result.push_back(std::make_shared()); + } else if (item == "env") { + result.push_back(std::make_shared()); + } else if (item == "default") { + result.push_back(std::make_shared()); + } else if (item != "none") { + throw InvalidInputException("Unknown credential provider found: " + item); + } + } + + return result; } -BlobClientWrapper::~BlobClientWrapper() = default; -Azure::Storage::Blobs::BlobClient* BlobClientWrapper::GetClient() { - return blob_client.get(); -}; -unique_ptr AzureStorageFileSystem::CreateHandle(const string &path, uint8_t flags, FileLockType lock, - FileCompressionType compression, FileOpener *opener) { - D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); - auto parsed_url = ParseUrl(path); +static AzureAuthentication ParseAzureAuthSettings(FileOpener *opener) { + AzureAuthentication auth; - string connection_string; - Value value; - if (FileOpener::TryGetCurrentSetting(opener, "azure_storage_connection_string", value)) { - connection_string = value.ToString(); + Value connection_string_val; + if (FileOpener::TryGetCurrentSetting(opener, "azure_storage_connection_string", connection_string_val)) { + auth.connection_string = connection_string_val.ToString(); } - if (connection_string.empty()) { - throw IOException("No azure_storage_connection_string found, please set using SET azure_storage_connection_string='' "); + Value account_name_val; + if (FileOpener::TryGetCurrentSetting(opener, "azure_account_name", account_name_val)) { + auth.account_name = account_name_val.ToString(); } - AzureAuthentication auth{ - connection_string, - parsed_url.container - }; + Value endpoint_val; + if 
(FileOpener::TryGetCurrentSetting(opener, "azure_endpoint", endpoint_val)) { + auth.endpoint = endpoint_val.ToString(); + } + + if (!auth.account_name.empty()) { + string credential_chain; + Value credential_chain_val; + if (FileOpener::TryGetCurrentSetting(opener, "azure_credential_chain", credential_chain_val)) { + auth.credential_chain = credential_chain_val.ToString(); + } + } - return make_uniq(*this, path, flags, auth, parsed_url); + return auth; } -unique_ptr AzureStorageFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, - FileCompressionType compression, FileOpener *opener) { - D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); +static Azure::Storage::Blobs::BlobContainerClient GetContainerClient(AzureAuthentication &auth, AzureParsedUrl &url) { + if (!auth.connection_string.empty()) { + return Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString(auth.connection_string, + url.container); + } - if (flags & FileFlags::FILE_FLAGS_WRITE) { - throw NotImplementedException("Writing to Azure containers is currently not supported"); + // Build credential chain, from last to first + Azure::Identity::ChainedTokenCredential::Sources credential_chain; + if (!auth.credential_chain.empty()) { + credential_chain = CreateCredentialChainFromSetting(auth.credential_chain); } - auto handle = CreateHandle(path, flags, lock, compression, opener); - return std::move(handle); + auto accountURL = "https://" + auth.account_name + "." 
+ auth.endpoint; + if (!credential_chain.empty()) { + // A set of credentials providers was passed + auto chainedTokenCredential = std::make_shared(credential_chain); + Azure::Storage::Blobs::BlobServiceClient blob_service_client(accountURL, chainedTokenCredential); + return blob_service_client.GetBlobContainerClient(url.container); + } else if (!auth.account_name.empty()) { + return Azure::Storage::Blobs::BlobContainerClient(accountURL + "/" + url.container); + } else { + throw InvalidInputException( + "No valid Azure credentials found, use either the azure_connection_string or azure_account_name"); + } } -AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, uint8_t flags, AzureAuthentication auth, AzureParsedUrl parsed_url) - : FileHandle(fs, std::move(path_p)), flags(flags), length(0), buffer_available(0), buffer_idx(0), file_offset(0), - buffer_start(0), buffer_end(0), blob_client(std::move(auth), parsed_url.path) { +BlobClientWrapper::BlobClientWrapper(AzureAuthentication &auth, AzureParsedUrl &url) { + auto container_client = GetContainerClient(auth, url); + blob_client = make_uniq(container_client.GetBlockBlobClient(url.path)); +} + +BlobClientWrapper::~BlobClientWrapper() = default; +Azure::Storage::Blobs::BlobClient *BlobClientWrapper::GetClient() { + return blob_client.get(); +}; + +AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, uint8_t flags, AzureAuthentication &auth, + AzureParsedUrl parsed_url) + : FileHandle(fs, std::move(path_p)), flags(flags), length(0), last_modified(time_t()), buffer_available(0), + buffer_idx(0), file_offset(0), buffer_start(0), buffer_end(0), blob_client(auth, parsed_url) { try { auto client = *blob_client.GetClient(); auto res = client.GetProperties(); length = res.Value.BlobSize; } catch (Azure::Storage::StorageException &e) { - throw IOException("AzureStorageFileSystem open file " + path + " failed with " + e.ErrorCode + "Reason Phrase: " + e.ReasonPhrase); + throw 
IOException("AzureStorageFileSystem open file '" + path + "' failed with code'" + e.ErrorCode + + "',Reason Phrase: '" + e.ReasonPhrase + "', Message: '" + e.Message + "'"); } if (flags & FileFlags::FILE_FLAGS_READ) { @@ -76,44 +125,83 @@ AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, ui } } +unique_ptr AzureStorageFileSystem::CreateHandle(const string &path, uint8_t flags, + FileLockType lock, + FileCompressionType compression, + FileOpener *opener) { + D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); + + auto parsed_url = ParseUrl(path); + auto azure_auth = ParseAzureAuthSettings(opener); + + return make_uniq(*this, path, flags, azure_auth, parsed_url); +} + +unique_ptr AzureStorageFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, + FileCompressionType compression, FileOpener *opener) { + D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); + + if (flags & FileFlags::FILE_FLAGS_WRITE) { + throw NotImplementedException("Writing to Azure containers is currently not supported"); + } + + auto handle = CreateHandle(path, flags, lock, compression, opener); + return std::move(handle); +} + int64_t AzureStorageFileSystem::GetFileSize(FileHandle &handle) { - auto &afh = (AzureStorageFileHandle &)handle; - return afh.length; + auto &afh = (AzureStorageFileHandle &)handle; + return afh.length; } time_t AzureStorageFileSystem::GetLastModifiedTime(FileHandle &handle) { - auto &afh = (AzureStorageFileHandle &)handle; - return afh.last_modified; + auto &afh = (AzureStorageFileHandle &)handle; + return afh.last_modified; } // TODO: this is currently a bit weird: it should be az:// but that shit dont work bool AzureStorageFileSystem::CanHandleFile(const string &fpath) { - return fpath.rfind("azure://", 0) == 0; + return fpath.rfind("azure://", 0) == 0; } void AzureStorageFileSystem::Seek(FileHandle &handle, idx_t location) { - auto &sfh = (AzureStorageFileHandle &)handle; - sfh.file_offset = location; + 
auto &sfh = (AzureStorageFileHandle &)handle; + sfh.file_offset = location; } void AzureStorageFileSystem::FileSync(FileHandle &handle) { - throw NotImplementedException("FileSync for Azure Storage files not implemented"); + throw NotImplementedException("FileSync for Azure Storage files not implemented"); } static void LoadInternal(DatabaseInstance &instance) { - auto &fs = instance.GetFileSystem(); - fs.RegisterSubSystem(make_uniq()); + // Load filesystem + auto &fs = instance.GetFileSystem(); + fs.RegisterSubSystem(make_uniq()); + // Load extension config auto &config = DBConfig::GetConfig(instance); - config.AddExtensionOption("azure_storage_connection_string", "Azure connection string, used for authenticating and configuring azure requests", LogicalType::VARCHAR); + config.AddExtensionOption("azure_storage_connection_string", + "Azure connection string, used for authenticating and configuring azure requests", + LogicalType::VARCHAR); + config.AddExtensionOption( + "azure_account_name", + "Azure account name, when set, the extension will attempt to automatically detect credentials", + LogicalType::VARCHAR); + config.AddExtensionOption("azure_credential_chain", + "Ordered list of Azure credential providers, in string format separated by ';'. E.g. 
" + "'cli;managed_identity;env'", + LogicalType::VARCHAR, "none"); + config.AddExtensionOption("azure_endpoint", + "Override the azure endpoint for when the Azure credential providers are used.", + LogicalType::VARCHAR, "blob.core.windows.net"); } int64_t AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &hfh = (AzureStorageFileHandle &)handle; - idx_t max_read = hfh.length - hfh.file_offset; - nr_bytes = MinValue(max_read, nr_bytes); - Read(handle, buffer, nr_bytes, hfh.file_offset); - return nr_bytes; + auto &hfh = (AzureStorageFileHandle &)handle; + idx_t max_read = hfh.length - hfh.file_offset; + nr_bytes = MinValue(max_read, nr_bytes); + Read(handle, buffer, nr_bytes, hfh.file_offset); + return nr_bytes; } // taken from s3fs.cpp TODO: deduplicate! @@ -146,32 +234,30 @@ vector AzureStorageFileSystem::Glob(const string &path, FileOpener *open if (opener == nullptr) { throw InternalException("Cannot do Azure storage Glob without FileOpener"); } - auto parsed_azure_url = AzureStorageFileSystem::ParseUrl(path); + auto azure_url = AzureStorageFileSystem::ParseUrl(path); + auto azure_auth = ParseAzureAuthSettings(opener); // Azure matches on prefix, not glob pattern, so we take a substring until the first wildcard - auto first_wildcard_pos = parsed_azure_url.path.find_first_of("*[\\"); + auto first_wildcard_pos = azure_url.path.find_first_of("*[\\"); if (first_wildcard_pos == string::npos) { return {path}; } - string shared_path = parsed_azure_url.path.substr(0, first_wildcard_pos); - Value value; - string connection_string; - if (FileOpener::TryGetCurrentSetting(opener, "azure_storage_connection_string", value)) { - connection_string = value.ToString(); - } - - if (connection_string.empty()) { - throw IOException("No azure_storage_connection_string found, please set using SET azure_storage_connection_string='' "); - } - - auto container_client = 
Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString(connection_string, parsed_azure_url.container); + string shared_path = azure_url.path.substr(0, first_wildcard_pos); + auto container_client = GetContainerClient(azure_auth, azure_url); vector found_keys; Azure::Storage::Blobs::ListBlobsOptions options; options.Prefix = shared_path; - while(true) { - auto res = container_client.ListBlobs(options); + while (true) { + Azure::Storage::Blobs::ListBlobsPagedResponse res; + try { + res = container_client.ListBlobs(options); + } catch (Azure::Storage::StorageException &e) { + throw IOException("AzureStorageFileSystem Read to " + path + " failed with " + e.ErrorCode + + "Reason Phrase: " + e.ReasonPhrase); + } + found_keys.insert(found_keys.end(), res.Blobs.begin(), res.Blobs.end()); if (res.NextPageToken) { options.ContinuationToken = res.NextPageToken; @@ -180,14 +266,14 @@ vector AzureStorageFileSystem::Glob(const string &path, FileOpener *open } } - vector pattern_splits = StringUtil::Split(parsed_azure_url.path, "/"); + vector pattern_splits = StringUtil::Split(azure_url.path, "/"); vector result; for (const auto &key : found_keys) { vector key_splits = StringUtil::Split(key.Name, "/"); bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end()); if (is_match) { - auto result_full_url = "azure://" + parsed_azure_url.container + "/" + key.Name; + auto result_full_url = "azure://" + azure_url.container + "/" + key.Name; result.push_back(result_full_url); } } @@ -284,11 +370,12 @@ void AzureStorageFileSystem::ReadRange(FileHandle &handle, idx_t file_offset, ch auto res = blob_client.DownloadTo((uint8_t *)buffer_out, buffer_out_len, options); } catch (Azure::Storage::StorageException &e) { - throw IOException("AzureStorageFileSystem Read to " + afh.path + " failed with " + e.ErrorCode + "Reason Phrase: " + e.ReasonPhrase); + throw IOException("AzureStorageFileSystem Read to " + afh.path + " failed with " + 
e.ErrorCode + + "Reason Phrase: " + e.ReasonPhrase); } } -AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string& url) { +AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string &url) { string container, path; if (url.rfind("azure://", 0) != 0) { @@ -303,7 +390,7 @@ AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string& url) { throw IOException("URL needs to contain a bucket name"); } - path = url.substr(slash_pos+1); + path = url.substr(slash_pos + 1); return {container, path}; } diff --git a/src/include/azure_extension.hpp b/src/include/azure_extension.hpp index 71cd4fb..4d548e9 100644 --- a/src/include/azure_extension.hpp +++ b/src/include/azure_extension.hpp @@ -3,12 +3,12 @@ #include "duckdb.hpp" namespace Azure { - namespace Storage { - namespace Blobs { - class BlobClient; - } - } +namespace Storage { +namespace Blobs { +class BlobClient; } +} // namespace Storage +} // namespace Azure namespace duckdb { @@ -19,8 +19,13 @@ class AzureExtension : public Extension { }; struct AzureAuthentication { + //! Auth method #1: setting the connection string string connection_string; - string container; + + //! Auth method #2: setting account name + defining a credential chain. 
+ string account_name; + string credential_chain; + string endpoint; }; struct AzureParsedUrl { @@ -30,16 +35,18 @@ struct AzureParsedUrl { class BlobClientWrapper { public: - BlobClientWrapper(AzureAuthentication auth, const string& path); + BlobClientWrapper(AzureAuthentication &auth, AzureParsedUrl &url); ~BlobClientWrapper(); - Azure::Storage::Blobs::BlobClient* GetClient(); + Azure::Storage::Blobs::BlobClient *GetClient(); + protected: unique_ptr blob_client; }; class AzureStorageFileHandle : public FileHandle { public: - AzureStorageFileHandle(FileSystem &fs, string path, uint8_t flags, AzureAuthentication auth, AzureParsedUrl parsed_url); + AzureStorageFileHandle(FileSystem &fs, string path, uint8_t flags, AzureAuthentication &auth, + AzureParsedUrl parsed_url); ~AzureStorageFileHandle() override = default; public: @@ -98,7 +105,7 @@ class AzureStorageFileSystem : public FileSystem { static void Verify(); protected: - static AzureParsedUrl ParseUrl(const string& url); + static AzureParsedUrl ParseUrl(const string &url); static void ReadRange(FileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len); virtual duckdb::unique_ptr CreateHandle(const string &path, uint8_t flags, FileLockType lock, FileCompressionType compression, diff --git a/test/sql/azure.test b/test/sql/azure.test index 6aaa107..6cc9564 100644 --- a/test/sql/azure.test +++ b/test/sql/azure.test @@ -13,7 +13,7 @@ require-env AZURE_STORAGE_CONNECTION_STRING statement error SELECT sum(l_orderkey) FROM 'azure://testing-private/l.parquet'; ---- -IO Error: No azure_storage_connection_string found, please set using SET azure_storage_connection_string='' +Invalid Input Error: No valid Azure credentials found # Set connection string from env var statement ok diff --git a/test/sql/azure_auth_local.test b/test/sql/azure_auth_local.test new file mode 100644 index 0000000..3bc5d53 --- /dev/null +++ b/test/sql/azure_auth_local.test @@ -0,0 +1,93 @@ +# name: test/sql/azure_auth_local.test 
+# description: test azure extension authentication
+# group: [azure]
+
+require azure
+
+require parquet
+
+require-env DUCKDB_CLI_TEST_ENV_AVAILABLE
+
+# Note: this test is currently not run in CI as it requires quite a bit of setup.
+# For now, to run this test locally, ensure you have access to the duckdbtesting storage
+# account, then log in through the cli. Then running the test with DUCKDB_CLI_TEST_ENV_AVAILABLE=1
+# should give all green!
+#
+# TODO: We should set up a key in CI to automatically test this. Ideally that would also involve Managed identities and
+# service principals
+
+# Set the storage account name
+statement ok
+set azure_account_name='duckdbtesting';
+
+# Set the azure credential chain
+statement ok
+set azure_credential_chain = 'cli';
+
+query I
+SELECT count(*) FROM 'azure://testing-public/l1.parquet';
+----
+60175
+
+query I
+SELECT count(*) FROM 'azure://testing-public/l*.parquet';
+----
+180525
+
+# With the CLI credentials, private authentication should now work
+query I
+SELECT count(*) FROM 'azure://testing-private/l1.parquet';
+----
+60175
+
+query I
+SELECT count(*) FROM 'azure://testing-private/l*.parquet';
+----
+180525
+
+# No credential providers, public buckets should still work!
+statement ok
+set azure_credential_chain = '';
+
+query I
+SELECT count(*) FROM 'azure://testing-public/l1.parquet';
+----
+60175
+
+query I
+SELECT count(*) FROM 'azure://testing-public/l*.parquet';
+----
+180525
+
+# private without credentials don't work
+statement error
+SELECT count(*) FROM 'azure://testing-private/l.parquet';
+----
+IO Error: AzureStorageFileSystem
+
+# globbing neither
+statement error
+SELECT count(*) FROM 'azure://testing-private/l*.parquet';
+----
+IO Error: AzureStorageFileSystem
+
+# Note that we can construct a chain of credential providers:
+statement ok
+set azure_credential_chain = 'env;cli;managed_identity';
+
+# Still good!
+query I
+SELECT count(*) FROM 'azure://testing-private/l1.parquet';
+----
+60175
+
+query I
+SELECT count(*) FROM 'azure://testing-private/l*.parquet';
+----
+180525
+
+statement ok
+set azure_endpoint='nop.nop';
+
+statement error
+SELECT count(*) FROM 'azure://testing-private/l*.parquet';