Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
samansmink committed Oct 13, 2023
1 parent 830f0a6 commit fb68a82
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 58 deletions.
13 changes: 9 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@ find_package(azure-identity-cpp CONFIG REQUIRED)
find_package(azure-storage-blobs-cpp CONFIG REQUIRED)

# Static lib
target_link_libraries(${EXTENSION_NAME} Azure::azure-identity Azure::azure-storage-blobs)
target_include_directories(${EXTENSION_NAME} PRIVATE Azure::azure-identity Azure::azure-storage-blobs)
target_link_libraries(${EXTENSION_NAME} Azure::azure-identity
Azure::azure-storage-blobs)
target_include_directories(${EXTENSION_NAME} PRIVATE Azure::azure-identity
Azure::azure-storage-blobs)

# Loadable binary
target_link_libraries(${TARGET_NAME}_loadable_extension Azure::azure-identity Azure::azure-storage-blobs)
target_include_directories(${TARGET_NAME}_loadable_extension PRIVATE Azure::azure-identity Azure::azure-storage-blobs)
target_link_libraries(${TARGET_NAME}_loadable_extension Azure::azure-identity
Azure::azure-storage-blobs)
target_include_directories(
${TARGET_NAME}_loadable_extension PRIVATE Azure::azure-identity
Azure::azure-storage-blobs)

install(
TARGETS ${EXTENSION_NAME}
Expand Down
109 changes: 64 additions & 45 deletions src/azure_extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@

namespace duckdb {

static Azure::Identity::ChainedTokenCredential::Sources CreateCredentialChainFromSetting(const string& credential_chain) {
static Azure::Identity::ChainedTokenCredential::Sources
CreateCredentialChainFromSetting(const string &credential_chain) {
auto chain_list = StringUtil::Split(credential_chain, ';');
Azure::Identity::ChainedTokenCredential::Sources result;

for (const auto& item : chain_list) {
for (const auto &item : chain_list) {
if (item == "cli") {
result.push_back(std::make_shared<Azure::Identity::AzureCliCredential>());
} else if (item == "managed_identity") {
Expand All @@ -41,7 +42,7 @@ static Azure::Identity::ChainedTokenCredential::Sources CreateCredentialChainFro
return result;
}

static AzureAuthentication ParseAzureAuthSettings(FileOpener* opener) {
static AzureAuthentication ParseAzureAuthSettings(FileOpener *opener) {
AzureAuthentication auth;

Value connection_string_val;
Expand Down Expand Up @@ -70,9 +71,10 @@ static AzureAuthentication ParseAzureAuthSettings(FileOpener* opener) {
return auth;
}

static Azure::Storage::Blobs::BlobContainerClient GetContainerClient(AzureAuthentication& auth, AzureParsedUrl& url) {
static Azure::Storage::Blobs::BlobContainerClient GetContainerClient(AzureAuthentication &auth, AzureParsedUrl &url) {
if (!auth.connection_string.empty()) {
return Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString(auth.connection_string, url.container);
return Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString(auth.connection_string,
url.container);
}

// Build credential chain, from last to first
Expand All @@ -87,42 +89,47 @@ static Azure::Storage::Blobs::BlobContainerClient GetContainerClient(AzureAuthen
auto chainedTokenCredential = std::make_shared<Azure::Identity::ChainedTokenCredential>(credential_chain);
Azure::Storage::Blobs::BlobServiceClient blob_service_client(accountURL, chainedTokenCredential);
return blob_service_client.GetBlobContainerClient(url.container);
} else if (!auth.account_name.empty()){
} else if (!auth.account_name.empty()) {
return Azure::Storage::Blobs::BlobContainerClient(accountURL + "/" + url.container);
} else {
throw InvalidInputException("No valid Azure credentials found, use either the azure_connection_string or azure_account_name");
throw InvalidInputException(
"No valid Azure credentials found, use either the azure_connection_string or azure_account_name");
}
}

BlobClientWrapper::BlobClientWrapper(AzureAuthentication& auth, AzureParsedUrl& url) {
BlobClientWrapper::BlobClientWrapper(AzureAuthentication &auth, AzureParsedUrl &url) {
auto container_client = GetContainerClient(auth, url);
blob_client = make_uniq<Azure::Storage::Blobs::BlockBlobClient>(container_client.GetBlockBlobClient(url.path));
}

BlobClientWrapper::~BlobClientWrapper() = default;
Azure::Storage::Blobs::BlobClient* BlobClientWrapper::GetClient() {
return blob_client.get();
Azure::Storage::Blobs::BlobClient *BlobClientWrapper::GetClient() {
return blob_client.get();
};

AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, uint8_t flags, AzureAuthentication& auth, AzureParsedUrl parsed_url)
: FileHandle(fs, std::move(path_p)), flags(flags), length(0), last_modified(time_t()), buffer_available(0), buffer_idx(0), file_offset(0),
buffer_start(0), buffer_end(0), blob_client(auth, parsed_url) {
AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, uint8_t flags, AzureAuthentication &auth,
AzureParsedUrl parsed_url)
: FileHandle(fs, std::move(path_p)), flags(flags), length(0), last_modified(time_t()), buffer_available(0),
buffer_idx(0), file_offset(0), buffer_start(0), buffer_end(0), blob_client(auth, parsed_url) {
try {
auto client = *blob_client.GetClient();
auto res = client.GetProperties();
length = res.Value.BlobSize;
} catch (Azure::Storage::StorageException &e) {
throw IOException("AzureStorageFileSystem open file '" + path + "' failed with code'" + e.ErrorCode + "',Reason Phrase: '" + e.ReasonPhrase + "', Message: '" + e.Message + "'");
throw IOException("AzureStorageFileSystem open file '" + path + "' failed with code'" + e.ErrorCode +
"',Reason Phrase: '" + e.ReasonPhrase + "', Message: '" + e.Message + "'");
}

if (flags & FileFlags::FILE_FLAGS_READ) {
read_buffer = duckdb::unique_ptr<data_t[]>(new data_t[READ_BUFFER_LEN]);
}
}

unique_ptr<AzureStorageFileHandle> AzureStorageFileSystem::CreateHandle(const string &path, uint8_t flags, FileLockType lock,
FileCompressionType compression, FileOpener *opener) {
D_ASSERT(compression == FileCompressionType::UNCOMPRESSED);
unique_ptr<AzureStorageFileHandle> AzureStorageFileSystem::CreateHandle(const string &path, uint8_t flags,
FileLockType lock,
FileCompressionType compression,
FileOpener *opener) {
D_ASSERT(compression == FileCompressionType::UNCOMPRESSED);

auto parsed_url = ParseUrl(path);
auto azure_auth = ParseAzureAuthSettings(opener);
Expand All @@ -131,60 +138,70 @@ unique_ptr<AzureStorageFileHandle> AzureStorageFileSystem::CreateHandle(const st
}

unique_ptr<FileHandle> AzureStorageFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock,
FileCompressionType compression, FileOpener *opener) {
D_ASSERT(compression == FileCompressionType::UNCOMPRESSED);
FileCompressionType compression, FileOpener *opener) {
D_ASSERT(compression == FileCompressionType::UNCOMPRESSED);

if (flags & FileFlags::FILE_FLAGS_WRITE) {
throw NotImplementedException("Writing to Azure containers is currently not supported");
}

auto handle = CreateHandle(path, flags, lock, compression, opener);
return std::move(handle);
auto handle = CreateHandle(path, flags, lock, compression, opener);
return std::move(handle);
}

int64_t AzureStorageFileSystem::GetFileSize(FileHandle &handle) {
auto &afh = (AzureStorageFileHandle &)handle;
return afh.length;
auto &afh = (AzureStorageFileHandle &)handle;
return afh.length;
}

time_t AzureStorageFileSystem::GetLastModifiedTime(FileHandle &handle) {
auto &afh = (AzureStorageFileHandle &)handle;
return afh.last_modified;
auto &afh = (AzureStorageFileHandle &)handle;
return afh.last_modified;
}

// TODO: this is currently a bit odd: the scheme should be az://, but that does not work yet
bool AzureStorageFileSystem::CanHandleFile(const string &fpath) {
return fpath.rfind("azure://", 0) == 0;
return fpath.rfind("azure://", 0) == 0;
}

void AzureStorageFileSystem::Seek(FileHandle &handle, idx_t location) {
auto &sfh = (AzureStorageFileHandle &)handle;
sfh.file_offset = location;
auto &sfh = (AzureStorageFileHandle &)handle;
sfh.file_offset = location;
}

void AzureStorageFileSystem::FileSync(FileHandle &handle) {
throw NotImplementedException("FileSync for Azure Storage files not implemented");
throw NotImplementedException("FileSync for Azure Storage files not implemented");
}

static void LoadInternal(DatabaseInstance &instance) {
// Load filesystem
auto &fs = instance.GetFileSystem();
fs.RegisterSubSystem(make_uniq<AzureStorageFileSystem>());
auto &fs = instance.GetFileSystem();
fs.RegisterSubSystem(make_uniq<AzureStorageFileSystem>());

// Load extension config
auto &config = DBConfig::GetConfig(instance);
config.AddExtensionOption("azure_storage_connection_string", "Azure connection string, used for authenticating and configuring azure requests", LogicalType::VARCHAR);
config.AddExtensionOption("azure_account_name", "Azure account name, when set, the extension will attempt to automatically detect credentials", LogicalType::VARCHAR);
config.AddExtensionOption("azure_credential_chain", "Ordered list of Azure credential providers, in string format separated by ';'. E.g. 'cli;managed_identity;env'", LogicalType::VARCHAR, "none");
config.AddExtensionOption("azure_endpoint", "Override the azure endpoint for when the Azure credential providers are used.", LogicalType::VARCHAR, "blob.core.windows.net");
config.AddExtensionOption("azure_storage_connection_string",
"Azure connection string, used for authenticating and configuring azure requests",
LogicalType::VARCHAR);
config.AddExtensionOption(
"azure_account_name",
"Azure account name, when set, the extension will attempt to automatically detect credentials",
LogicalType::VARCHAR);
config.AddExtensionOption("azure_credential_chain",
"Ordered list of Azure credential providers, in string format separated by ';'. E.g. "
"'cli;managed_identity;env'",
LogicalType::VARCHAR, "none");
config.AddExtensionOption("azure_endpoint",
"Override the azure endpoint for when the Azure credential providers are used.",
LogicalType::VARCHAR, "blob.core.windows.net");
}

int64_t AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) {
auto &hfh = (AzureStorageFileHandle &)handle;
idx_t max_read = hfh.length - hfh.file_offset;
nr_bytes = MinValue<idx_t>(max_read, nr_bytes);
Read(handle, buffer, nr_bytes, hfh.file_offset);
return nr_bytes;
auto &hfh = (AzureStorageFileHandle &)handle;
idx_t max_read = hfh.length - hfh.file_offset;
nr_bytes = MinValue<idx_t>(max_read, nr_bytes);
Read(handle, buffer, nr_bytes, hfh.file_offset);
return nr_bytes;
}

// Taken from s3fs.cpp. TODO: deduplicate!
Expand Down Expand Up @@ -232,12 +249,13 @@ vector<string> AzureStorageFileSystem::Glob(const string &path, FileOpener *open
vector<Azure::Storage::Blobs::Models::BlobItem> found_keys;
Azure::Storage::Blobs::ListBlobsOptions options;
options.Prefix = shared_path;
while(true) {
while (true) {
Azure::Storage::Blobs::ListBlobsPagedResponse res;
try {
res = container_client.ListBlobs(options);
} catch (Azure::Storage::StorageException &e) {
throw IOException("AzureStorageFileSystem Read to " + path + " failed with " + e.ErrorCode + "Reason Phrase: " + e.ReasonPhrase);
throw IOException("AzureStorageFileSystem Read to " + path + " failed with " + e.ErrorCode +
"Reason Phrase: " + e.ReasonPhrase);
}

found_keys.insert(found_keys.end(), res.Blobs.begin(), res.Blobs.end());
Expand Down Expand Up @@ -352,11 +370,12 @@ void AzureStorageFileSystem::ReadRange(FileHandle &handle, idx_t file_offset, ch
auto res = blob_client.DownloadTo((uint8_t *)buffer_out, buffer_out_len, options);

} catch (Azure::Storage::StorageException &e) {
throw IOException("AzureStorageFileSystem Read to " + afh.path + " failed with " + e.ErrorCode + "Reason Phrase: " + e.ReasonPhrase);
throw IOException("AzureStorageFileSystem Read to " + afh.path + " failed with " + e.ErrorCode +
"Reason Phrase: " + e.ReasonPhrase);
}
}

AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string& url) {
AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string &url) {
string container, path;

if (url.rfind("azure://", 0) != 0) {
Expand All @@ -371,7 +390,7 @@ AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string& url) {
throw IOException("URL needs to contain a bucket name");
}

path = url.substr(slash_pos+1);
path = url.substr(slash_pos + 1);
return {container, path};
}

Expand Down
20 changes: 11 additions & 9 deletions src/include/azure_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
#include "duckdb.hpp"

namespace Azure {
namespace Storage {
namespace Blobs {
class BlobClient;
}
}
namespace Storage {
namespace Blobs {
class BlobClient;
}
} // namespace Storage
} // namespace Azure

namespace duckdb {

Expand All @@ -35,16 +35,18 @@ struct AzureParsedUrl {

class BlobClientWrapper {
public:
BlobClientWrapper(AzureAuthentication& auth, AzureParsedUrl& url);
BlobClientWrapper(AzureAuthentication &auth, AzureParsedUrl &url);
~BlobClientWrapper();
Azure::Storage::Blobs::BlobClient* GetClient();
Azure::Storage::Blobs::BlobClient *GetClient();

protected:
unique_ptr<Azure::Storage::Blobs::BlobClient> blob_client;
};

class AzureStorageFileHandle : public FileHandle {
public:
AzureStorageFileHandle(FileSystem &fs, string path, uint8_t flags, AzureAuthentication& auth, AzureParsedUrl parsed_url);
AzureStorageFileHandle(FileSystem &fs, string path, uint8_t flags, AzureAuthentication &auth,
AzureParsedUrl parsed_url);
~AzureStorageFileHandle() override = default;

public:
Expand Down Expand Up @@ -103,7 +105,7 @@ class AzureStorageFileSystem : public FileSystem {
static void Verify();

protected:
static AzureParsedUrl ParseUrl(const string& url);
static AzureParsedUrl ParseUrl(const string &url);
static void ReadRange(FileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len);
virtual duckdb::unique_ptr<AzureStorageFileHandle> CreateHandle(const string &path, uint8_t flags,
FileLockType lock, FileCompressionType compression,
Expand Down

0 comments on commit fb68a82

Please sign in to comment.