Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adfs authentication support #214

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/build-installer.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ jobs:
run: |
curl -L https://dev.mysql.com/get/Downloads/MySQL-8.3/mysql-${{ vars.MYSQL_VERSION }}-winx64.zip -o mysql.zip
unzip -d C:/ mysql.zip

- name: Install OpenSSL 3
run: |
curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip
unzip -d C:/ openssl3.zip
cp -r C:/openssl-3/x64/bin/libssl-3-x64.dll C:/Windows/System32/
cp -r C:/openssl-3/x64/bin/libcrypto-3-x64.dll C:/Windows/System32/

- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2
Expand Down
225 changes: 216 additions & 9 deletions driver/adfs_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,238 @@
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "adfs_proxy.h"
#include <regex>
#include "driver.h"

#define SIGN_IN_PAGE_URL "/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=urn:amazon:webservices"

std::unordered_map<std::string, TOKEN_INFO> ADFS_PROXY::token_cache;
std::mutex ADFS_PROXY::token_cache_mutex;

ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds) : ADFS_PROXY(dbc, ds, nullptr) {};

ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
if (ds->opt_AUTH_REGION) {
this->auth_util = std::make_shared<AUTH_UTIL>((const char*)ds->opt_AUTH_REGION);
std::string host{static_cast<const char*>(ds->opt_IDP_ENDPOINT)};
host += ":" + std::to_string(ds->opt_IDP_PORT);

const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT;
const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT;
const bool enable_ssl = ds->opt_ENABLE_SSL;
this->saml_util = std::make_shared<ADFS_SAML_UTIL>(host, client_connect_timeout, client_socket_timeout, enable_ssl);
}

void ADFS_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
}

ADFS_SAML_UTIL::ADFS_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client) { this->http_client = client; }

ADFS_SAML_UTIL::ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) {
this->http_client =
std::make_shared<SAML_HTTP_CLIENT>("https://" + host, connect_timeout, socket_timeout, enable_ssl);
}

std::string ADFS_SAML_UTIL::get_saml_assertion(DataSource* ds) {
nlohmann::json res;
try {
res = this->http_client->get(std::string(SIGN_IN_PAGE_URL));
} catch (SAML_HTTP_EXCEPTION& e) {
const std::string error =
"Failed to get sign-in page from ADFS: " + e.error_message() + ". Please verify your IDP endpoint.";
throw SAML_HTTP_EXCEPTION(error);
}

const auto body = std::string(res);
std::smatch m;
if (!std::regex_search(body, m, ADFS_REGEX::FORM_ACTION_PATTERN)) {
return std::string();
}
else {
this->auth_util = std::make_shared<AUTH_UTIL>();
std::string form_action = unescape_html_entity(m.str(1));
const std::string params = get_parameters_from_html(ds, body);
const std::string content = get_form_action_body(form_action, params);
if (std::regex_search(content, m, ADFS_REGEX::SAML_RESPONSE_PATTERN)) {
return m.str(1);
}
return std::string();
}

std::string ADFS_SAML_UTIL::unescape_html_entity(const std::string& html) {
std::string retval("");
int i = 0;
int length = html.length();
while (i < length) {
char c = html[i];
if (c != '&') {
retval.append(1, c);
i++;
continue;
}

if (html.substr(i, 4) == "&lt;") {
retval.append(1, '<');
i += 4;
} else if (html.substr(i, 4) == "&gt;") {
retval.append(1, '>');
i += 4;
} else if (html.substr(i, 5) == "&amp;") {
retval.append(1, '&');
i += 5;
} else if (html.substr(i, 6) == "&apos;") {
retval.append(1, '\'');
i += 6;
} else if (html.substr(i, 6) == "&quot;") {
retval.append(1, '"');
i += 6;
} else {
retval.append(1, c);
++i;
}
}
return retval;
}

std::vector<std::string> ADFS_SAML_UTIL::get_input_tags_from_html(const std::string& body) {
std::unordered_set<std::string> hashSet;
std::vector<std::string> retval;

std::smatch matches;
std::regex pattern(ADFS_REGEX::INPUT_TAG_PATTERN);
std::string source = body;
while (std::regex_search(source, matches, pattern)) {
std::string tag = matches.str(0);
std::string tagName = get_value_by_key(tag, std::string("name"));
std::transform(tagName.begin(), tagName.end(), tagName.begin(), [](unsigned char c) { return std::tolower(c); });
if (!tagName.empty() && hashSet.find(tagName) == hashSet.end()) {
hashSet.insert(tagName);
retval.push_back(tag);
}

source = matches.suffix().str();
}

return retval;
}

std::string ADFS_SAML_UTIL::get_value_by_key(const std::string& input, const std::string& key) {
std::string pattern("(");
pattern += key;
pattern += ")\\s*=\\s*\"(.*?)\"";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: std::regex pattern on line 129 above is using a const for the pattern. Does it make sense to define a const for this pattern as well?


std::smatch matches;
if (std::regex_search(input, matches, std::regex(pattern))) {
MYLOG_TRACE(init_log_file(), 0, "get_value_by_key");
return unescape_html_entity(matches.str(2));
}
return "";
}

std::string ADFS_SAML_UTIL::get_parameters_from_html(DataSource* ds, const std::string& body) {
std::map<std::string, std::string> parameters;
for (auto& inputTag : get_input_tags_from_html(body)) {
std::string name = get_value_by_key(inputTag, std::string("name"));
std::string value = get_value_by_key(inputTag, std::string("value"));
std::string nameLower = name;
std::transform(nameLower.begin(), nameLower.end(), nameLower.begin(),
[](unsigned char c) { return std::tolower(c); });

const std::string username = static_cast<const char*>(ds->opt_IDP_USERNAME);
const std::string password = static_cast<const char*>(ds->opt_IDP_PASSWORD);

if (nameLower.find("username") != std::string::npos) {
parameters.insert(std::pair<std::string, std::string>(name, username));
} else if ((nameLower.find("authmethod") != std::string::npos) && !value.empty()) {
parameters.insert(std::pair<std::string, std::string>(name, value));
} else if (nameLower.find("password") != std::string::npos) {
parameters.insert(std::pair<std::string, std::string>(name, password));
} else if (!name.empty()) {
parameters.insert(std::pair<std::string, std::string>(name, value));
}
}

// Convert parameters to a & delimited string, e.g. username=u&password=p
const std::string delimiter = "&";
const std::string result =
std::accumulate(parameters.begin(), parameters.end(), std::string(),
[delimiter](const std::string& s, const std::pair<const std::string, std::string>& p) {
return s + (s.empty() ? std::string() : delimiter) + p.first + "=" + p.second;
});

return result;
}

std::string ADFS_SAML_UTIL::get_form_action_body(const std::string& url, const std::string& params) {
nlohmann::json res;
try {
res = this->http_client->post(url, params, "application/x-www-form-urlencoded");
} catch (SAML_HTTP_EXCEPTION& e) {
const std::string error =
"Failed to get SAML Assertion from ADFS : " + e.error_message() + ". Please verify your ADFS credentials.";
throw SAML_HTTP_EXCEPTION(error);
}
return res.empty() ? "" : res;
}

#ifdef UNIT_TEST_BUILD
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
std::shared_ptr<AUTH_UTIL> auth_util) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
ADFS_PROXY::ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util,
const std::shared_ptr<SAML_HTTP_CLIENT>& client)
: CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
this->saml_util = std::make_shared<ADFS_SAML_UTIL>(client);
}
#endif

ADFS_PROXY::~ADFS_PROXY() { this->auth_util.reset(); }

bool ADFS_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
unsigned int port, const char* socket, unsigned long flags) {
return true;
auto func = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port,
socket, flags);
const char* region =
ds->opt_FED_AUTH_REGION ? static_cast<const char*>(ds->opt_FED_AUTH_REGION) : Aws::Region::US_EAST_1;
std::string assertion;
try {
assertion = this->saml_util->get_saml_assertion(ds);
} catch (SAML_HTTP_EXCEPTION& e) {
this->set_custom_error_message(e.error_message().c_str());
return false;
}

auto idp_host = static_cast<const char*>(ds->opt_IDP_ENDPOINT);
auto iam_role_arn = static_cast<const char*>(ds->opt_IAM_ROLE_ARN);
auto idp_arn = static_cast<const char*>(ds->opt_IAM_IDP_ARN);
const Aws::Auth::AWSCredentials credentials =
this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion);
this->auth_util = std::make_shared<AUTH_UTIL>(region, credentials);

const char* auth_host = ds->opt_FED_AUTH_HOST ? static_cast<const char*>(ds->opt_FED_AUTH_HOST)
: static_cast<const char*>(ds->opt_SERVER);
const int auth_port = ds->opt_FED_AUTH_PORT;

std::string auth_token;
bool using_cached_token;
std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(
token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION);

bool connect_result = func(auth_token.c_str());
if (!connect_result) {
if (using_cached_token) {
// Retry func with a fresh token
std::tie(auth_token, using_cached_token) = this->auth_util->get_auth_token(
token_cache, token_cache_mutex, auth_host, region, auth_port, ds->opt_UID, ds->opt_AUTH_EXPIRATION, true);
if (func(auth_token.c_str())) {
return true;
}
}

if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Unable to generate temporary AWS credentials from the SAML assertion. Please ensure the ADFS identity "
"provider is correctly configured with AWS.");
}
}

return connect_result;
}
69 changes: 45 additions & 24 deletions driver/adfs_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,40 +30,61 @@
#ifndef __ADFS_PROXY__
#define __ADFS_PROXY__

#include <regex>
#include <unordered_map>
#include "auth_util.h"
#include "saml_http_client.h"
#include "saml_util.h"

namespace ADFS_REGEX {
const std::regex FORM_ACTION_PATTERN(R"#(<form.*?action=\"([^\"]+)\")#", std::regex_constants::icase);
const std::regex SAML_RESPONSE_PATTERN("\"SAMLResponse\"\\W+value=\"(.*?)\"(\\s*/>)", std::regex_constants::icase);
const std::regex URL_PATTERN(R"#(^(https)://[-a-zA-Z0-9+&@#/%?=~_!:,.']*[-a-zA-Z0-9+&@#/%=~_'])#",
std::regex_constants::icase);
const std::regex INPUT_TAG_PATTERN(R"#(<input id=(.*))#", std::regex_constants::icase);
} // namespace ADFS_REGEX

class ADFS_SAML_UTIL : public SAML_UTIL {
public:
ADFS_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client);
ADFS_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl);
std::string get_saml_assertion(DataSource* ds) override;
std::shared_ptr<SAML_HTTP_CLIENT> http_client;

private:
static std::string unescape_html_entity(const std::string& html);
std::vector<std::string> get_input_tags_from_html(const std::string& body);
std::string get_value_by_key(const std::string& input, const std::string& key);
std::string get_parameters_from_html(DataSource* ds, const std::string& body);
std::string get_form_action_body(const std::string& url, const std::string& params);
};

class ADFS_PROXY : public CONNECTION_PROXY {
public:
ADFS_PROXY() = default;
ADFS_PROXY(DBC* dbc, DataSource* ds);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy);
public:
ADFS_PROXY() = default;
ADFS_PROXY(DBC* dbc, DataSource* ds);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy);
#ifdef UNIT_TEST_BUILD
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util);
ADFS_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy, std::shared_ptr<AUTH_UTIL> auth_util,
const std::shared_ptr<SAML_HTTP_CLIENT>& client);
#endif
~ADFS_PROXY() override;
bool connect(
const char* host,
const char* user,
const char* password,
const char* database,
unsigned int port,
const char* socket,
unsigned long flags) override;

protected:
static std::unordered_map<std::string, TOKEN_INFO> token_cache;
static std::mutex token_cache_mutex;
std::shared_ptr<AUTH_UTIL> auth_util;
bool using_cached_token = false;
~ADFS_PROXY() override;
bool connect(const char* host, const char* user, const char* password, const char* database, unsigned int port,
const char* socket, unsigned long flags) override;

protected:
static std::unordered_map<std::string, TOKEN_INFO> token_cache;
static std::mutex token_cache_mutex;
std::shared_ptr<AUTH_UTIL> auth_util;
std::shared_ptr<ADFS_SAML_UTIL> saml_util;
bool using_cached_token = false;

static void clear_token_cache();
static void clear_token_cache();

#ifdef UNIT_TEST_BUILD
// Allows for testing private/protected methods
friend class TEST_UTILS;
// Allows for testing private/protected methods
friend class TEST_UTILS;
#endif
};

#endif

Loading
Loading