Skip to content

Commit

Permalink
refactor: auth token helper methods (#198)
Browse files Browse the repository at this point in the history
- feat: add connection options for federated authentication
- refactor: move auth token methods to a separate helper class
  • Loading branch information
karenc-bq authored Jul 31, 2024
1 parent 8b7bedd commit c114d75
Show file tree
Hide file tree
Showing 17 changed files with 438 additions and 165 deletions.
4 changes: 4 additions & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
SET(DRIVER_NAME_STATIC "awsmysqlodbc${CONNECTOR_DRIVER_TYPE_SHORT}-static")

SET(DRIVER_SRCS
adfs_proxy.cc
auth_util.cc
aws_sdk_helper.cc
base_metrics_holder.cc
catalog.cc
Expand Down Expand Up @@ -122,6 +124,8 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.def.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.def @ONLY)
CONFIGURE_FILE(${CMAKE_SOURCE_DIR}/driver/driver.rc.cmake ${CMAKE_SOURCE_DIR}/driver/driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc @ONLY)
SET(DRIVER_SRCS ${DRIVER_SRCS} driver${CONNECTOR_DRIVER_TYPE_SHORT}.def driver${CONNECTOR_DRIVER_TYPE_SHORT}.rc
adfs_proxy.h
auth_util.h
aws_sdk_helper.h
base_metrics_holder.h
catalog.h
Expand Down
58 changes: 58 additions & 0 deletions driver/adfs_proxy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "adfs_proxy.h"
#include "driver.h"

ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds) : ADFS_PROXY(dbc, ds, nullptr) {};

ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
if (ds->opt_AUTH_REGION) {
this->auth_util = std::make_shared<AUTH_UTIL>((const char*)ds->opt_AUTH_REGION);
}
else {
this->auth_util = std::make_shared<AUTH_UTIL>();
}
}

#ifdef UNIT_TEST_BUILD
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
std::shared_ptr<AUTH_UTIL> auth_util) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
}
#endif

ADFS_PROXY::~ADFS_PROXY() { this->auth_util.reset(); }

bool ADFS_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
unsigned int port, const char* socket, unsigned long flags) {
return true;
}
69 changes: 69 additions & 0 deletions driver/adfs_proxy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __ADFS_PROXY__
#define __ADFS_PROXY__

#include <unordered_map>
#include "auth_util.h"

class ADFS_PROXY : public CONNECTION_PROXY {
public:
ADFS_PROXY() = default;
ADFS_PROXY(DBC* dbc, DataSource* ds);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy);
#ifdef UNIT_TEST_BUILD
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util);
#endif
~ADFS_PROXY() override;
bool connect(
const char* host,
const char* user,
const char* password,
const char* database,
unsigned int port,
const char* socket,
unsigned long flags) override;

protected:
static std::unordered_map<std::string, TOKEN_INFO> token_cache;
static std::mutex token_cache_mutex;
std::shared_ptr<AUTH_UTIL> auth_util;
bool using_cached_token = false;

static void clear_token_cache();

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
#endif
};

#endif

64 changes: 64 additions & 0 deletions driver/auth_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "auth_util.h"
#include "aws_sdk_helper.h"
#include "driver.h"

namespace {
AWS_SDK_HELPER SDK_HELPER;
}

AUTH_UTIL::AUTH_UTIL(const char* region) {
++SDK_HELPER;

Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();

Aws::RDS::RDSClientConfiguration client_config;
if (region) {
client_config.region = region;
}

this->rds_client = std::make_shared<Aws::RDS::RDSClient>(credentials, client_config);
};

std::string AUTH_UTIL::get_auth_token(const char* host, const char* region, unsigned int port, const char* user) {
return this->rds_client->GenerateConnectAuthToken(host, region, port, user);
}

std::string AUTH_UTIL::build_cache_key(const char* host, const char* region, unsigned int port, const char* user) {
// Format should be "<region>:<host>:<port>:<user>"
return std::string(region).append(":").append(host).append(":").append(std::to_string(port)).append(":").append(user);
}

AUTH_UTIL::~AUTH_UTIL() {
this->rds_client.reset();
--SDK_HELPER;
}
78 changes: 78 additions & 0 deletions driver/auth_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#ifndef __AUTH_UTIL__
#define __AUTH_UTIL__

#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/rds/RDSClient.h>

#include "connection_proxy.h"

constexpr auto DEFAULT_TOKEN_EXPIRATION_SEC = 15 * 60;

class TOKEN_INFO {
public:
TOKEN_INFO() {};
TOKEN_INFO(std::string token) : TOKEN_INFO(token, DEFAULT_TOKEN_EXPIRATION_SEC) {};
TOKEN_INFO(std::string token, unsigned int seconds_until_expiration) {
this->token = token;
this->expiration_time = std::chrono::system_clock::now() + std::chrono::seconds(seconds_until_expiration);
}

bool is_expired() {
std::chrono::system_clock::time_point current_time = std::chrono::system_clock::now();
return current_time > this->expiration_time;
}

std::string token;

private:
std::chrono::system_clock::time_point expiration_time;
};

class AUTH_UTIL {
public:
AUTH_UTIL() {};
AUTH_UTIL(const char* region);
~AUTH_UTIL();

virtual std::string get_auth_token(const char* host, const char* region, unsigned int port, const char* user);
static std::string build_cache_key(const char* host, const char* region, unsigned int port, const char* user);

private:
std::shared_ptr<Aws::RDS::RDSClient> rds_client;

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
#endif
};

#endif
49 changes: 13 additions & 36 deletions driver/iam_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,49 +29,36 @@

#include <functional>

#include "aws_sdk_helper.h"
#include "driver.h"
#include "iam_proxy.h"

namespace {
AWS_SDK_HELPER SDK_HELPER;
}

std::unordered_map<std::string, TOKEN_INFO> IAM_PROXY::token_cache;
std::mutex IAM_PROXY::token_cache_mutex;

IAM_PROXY::IAM_PROXY(DBC* dbc, DataSource* ds) : IAM_PROXY(dbc, ds, nullptr) {};

IAM_PROXY::IAM_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
++SDK_HELPER;

this->next_proxy = next_proxy;

Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();

Aws::RDS::RDSClientConfiguration client_config;
if (ds->opt_AUTH_REGION) {
client_config.region = (const char*) ds->opt_AUTH_REGION;
}
this->next_proxy = next_proxy;
if (ds->opt_AUTH_REGION) {
this->auth_util = std::make_shared<AUTH_UTIL>((const char*)ds->opt_AUTH_REGION);
} else {
this->auth_util = std::make_shared<AUTH_UTIL>();
}
}

this->token_generator = std::make_shared<TOKEN_GENERATOR>(credentials, client_config);
IAM_PROXY::~IAM_PROXY() {
this->auth_util.reset();
}

#ifdef UNIT_TEST_BUILD
IAM_PROXY::IAM_PROXY(DBC *dbc, DataSource *ds, CONNECTION_PROXY *next_proxy,
std::shared_ptr<TOKEN_GENERATOR> token_generator) : CONNECTION_PROXY(dbc, ds) {
std::shared_ptr<AUTH_UTIL> auth_util) : CONNECTION_PROXY(dbc, ds) {

this->next_proxy = next_proxy;
this->token_generator = token_generator;
this->auth_util = auth_util;
}
#endif

IAM_PROXY::~IAM_PROXY() {
this->token_generator.reset();
--SDK_HELPER;
}

bool IAM_PROXY::connect(const char* host, const char* user, const char* password,
const char* database, unsigned int port, const char* socket, unsigned long flags) {

Expand Down Expand Up @@ -101,7 +88,7 @@ std::string IAM_PROXY::get_auth_token(
}

std::string auth_token;
std::string cache_key = build_cache_key(host, region, port, user);
std::string cache_key = this->auth_util->build_cache_key(host, region, port, user);
using_cached_token = false;

{
Expand All @@ -125,24 +112,14 @@ std::string IAM_PROXY::get_auth_token(
}

// Generate new token
auth_token = token_generator->generate_auth_token(host, region, port, user);
auth_token = this->auth_util->get_auth_token(host, region, port, user);

token_cache[cache_key] = TOKEN_INFO(auth_token, time_until_expiration);
}

return auth_token;
}

std::string IAM_PROXY::build_cache_key(
const char* host, const char* region, unsigned int port, const char* user) {

// Format should be "<region>:<host>:<port>:<user>"
return std::string(region)
.append(":").append(host)
.append(":").append(std::to_string(port))
.append(":").append(user);
}

void IAM_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
Expand Down
Loading

0 comments on commit c114d75

Please sign in to comment.