Skip to content

Commit

Permalink
feat: add more configuration params for federated authentication (#212)
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq authored Sep 4, 2024
1 parent 81d33d4 commit 8bf63c4
Show file tree
Hide file tree
Showing 12 changed files with 106 additions and 66 deletions.
56 changes: 28 additions & 28 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0
# (GPLv2), as published by the Free Software Foundation, with the
# following additional permissions:
#
# This program is distributed with certain software that is licensed
# under separate terms, as designated in a particular file or component
# or in the license documentation. Without limiting your rights under
# the GPLv2, the authors of this program hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have included with the program.
#
# Without limiting the foregoing grant of rights under the GPLv2 and
# additional permission as to separately licensed software, this
# program is also subject to the Universal FOSS Exception, version 1.0,
# a copy of which can be found along with its FAQ at
# http://oss.oracle.com/licenses/universal-foss-exception.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see
# http://www.gnu.org/licenses/gpl-2.0.html.

Language: Cpp
# BasedOnStyle: Google
Expand Down
9 changes: 6 additions & 3 deletions driver/okta_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,10 @@ OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds) : OKTA_PROXY(dbc, ds, nullptr)
OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
const std::string idp_host{static_cast<const char*>(ds->opt_IDP_ENDPOINT)};
this->saml_util = std::make_shared<OKTA_SAML_UTIL>(idp_host);
const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT;
const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT;
const bool enable_ssl = ds->opt_ENABLE_SSL;
this->saml_util = std::make_shared<OKTA_SAML_UTIL>(idp_host, client_connect_timeout, client_socket_timeout, enable_ssl);
}

bool OKTA_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
Expand Down Expand Up @@ -122,8 +125,8 @@ void OKTA_PROXY::clear_token_cache() {

OKTA_SAML_UTIL::OKTA_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client) { this->http_client = client; }

OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host) {
this->http_client = std::make_shared<SAML_HTTP_CLIENT>("https://" + host);
OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) {
this->http_client = std::make_shared<SAML_HTTP_CLIENT>("https://" + host, connect_timeout, socket_timeout, enable_ssl);
}

std::string OKTA_SAML_UTIL::get_saml_url(DataSource* ds) {
Expand Down
2 changes: 1 addition & 1 deletion driver/okta_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ const std::regex SAML_RESPONSE_PATTERN(R"#(name=\"SAMLResponse\".+value=\"(.+)\"
class OKTA_SAML_UTIL : public SAML_UTIL {
public:
OKTA_SAML_UTIL(const std::shared_ptr<SAML_HTTP_CLIENT>& client);
OKTA_SAML_UTIL(std::string host);
OKTA_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl);
std::string get_saml_assertion(DataSource* ds) override;
std::string get_session_token(DataSource* ds) const;
static std::string get_saml_url(DataSource* ds);
Expand Down
16 changes: 13 additions & 3 deletions driver/saml_http_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,20 @@
#include "saml_http_client.h"
#include <utility>

SAML_HTTP_CLIENT::SAML_HTTP_CLIENT(std::string host) : host{std::move(host)} {}
SAML_HTTP_CLIENT::SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl)
: host{std::move(host)}, connect_timeout(connect_timeout), socket_timeout(socket_timeout), enable_ssl(enable_ssl) {}

nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::json& value) {
httplib::Client SAML_HTTP_CLIENT::get_client() const {
httplib::Client client(host);
client.enable_server_certificate_verification(enable_ssl);
client.set_connection_timeout(connect_timeout);
client.set_read_timeout(socket_timeout);
client.set_write_timeout(socket_timeout);
return client;
}

nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::json& value) {
httplib::Client client = this->get_client();
if (auto res = client.Post(path.c_str(), value.dump(), "application/json")) {
if (res->status == httplib::StatusCode::OK_200) {
nlohmann::json json_object = nlohmann::json::parse(res->body);
Expand All @@ -46,7 +56,7 @@ nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::j
}

nlohmann::json SAML_HTTP_CLIENT::get(const std::string& path) {
httplib::Client client(host);
httplib::Client client = this->get_client();
client.set_follow_location(true);
if (auto res = client.Get(path.c_str())) {
if (res->status == httplib::StatusCode::OK_200) {
Expand Down
6 changes: 5 additions & 1 deletion driver/saml_http_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ class SAML_HTTP_EXCEPTION: public std::exception {

class SAML_HTTP_CLIENT {
public:
SAML_HTTP_CLIENT(std::string host);
SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl);
~SAML_HTTP_CLIENT() = default;
virtual nlohmann::json post(const std::string& path, const nlohmann::json& value);
virtual nlohmann::json get(const std::string& path);

private:
const std::string host;
const int connect_timeout;
const int socket_timeout;
const bool enable_ssl;
httplib::Client get_client() const;
};

#endif
3 changes: 3 additions & 0 deletions setupgui/callbacks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,9 @@ void syncTabs(HWND hwnd, DataSource *params)
SET_STRING_TAB(FED_AUTH_TAB, AUTH_HOST);
SET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_PORT);
SET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_EXPIRATION);
SET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_CONNECT_TIMEOUT);
SET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_SOCKET_TIMEOUT);
SET_BOOL_TAB(FED_AUTH_TAB, ENABLE_SSL);

/* 5 - Failover */
SET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER);
Expand Down
53 changes: 29 additions & 24 deletions setupgui/windows/odbcdialogparams.rc
Original file line number Diff line number Diff line change
Expand Up @@ -221,30 +221,35 @@ IDD_TAB4 DIALOGEX 0, 0, 307, 245
STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD
FONT 8, "MS Shell Dlg", 400, 0, 0x1
BEGIN
LTEXT "Federated Authentication Mode:", IDC_STATIC, 207, 7, 80, 18
COMBOBOX IDC_EDIT_FED_AUTH_MODE, 207, 27, 79, 12, CBS_DROPDOWN | CBS_AUTOHSCROLL | CBS_SORT | WS_VSCROLL | WS_TABSTOP
RTEXT "IDP Username:", IDC_STATIC, 4, 6, 58, 18
EDITTEXT IDC_EDIT_IDP_USERNAME, 65, 6, 136, 12, ES_AUTOHSCROLL
RTEXT "IDP Password:", IDC_STATIC, 4, 27, 58, 18
EDITTEXT IDC_EDIT_IDP_PASSWORD, 65, 27, 136, 12, ES_PASSWORD | ES_AUTOHSCROLL
RTEXT "IDP Endpoint:", IDC_STATIC, 4, 47, 58, 18
EDITTEXT IDC_EDIT_IDP_ENDPOINT, 65, 46, 136, 12, ES_AUTOHSCROLL
RTEXT "App ID:", IDC_STATIC, 4, 67, 58, 18
EDITTEXT IDC_EDIT_APP_ID, 65, 67, 136, 12, ES_AUTOHSCROLL
RTEXT "IAM Role ARN:", IDC_STATIC, 3, 88, 58, 18
EDITTEXT IDC_EDIT_IAM_ROLE_ARN, 65, 87, 136, 12, ES_AUTOHSCROLL
RTEXT "IAM IDP ARN:", IDC_STATIC, 3, 108, 58, 18
EDITTEXT IDC_EDIT_IAM_IDP_ARN, 65, 107, 136, 12, ES_AUTOHSCROLL
LTEXT "IDP Port:", IDC_STATIC, 207, 125, 36, 10
EDITTEXT IDC_EDIT_IDP_PORT, 243, 124, 51, 12, ES_AUTOHSCROLL | ES_NUMBER
RTEXT "AWS Region:", IDC_STATIC, 3, 126, 58, 18
EDITTEXT IDC_EDIT_AUTH_REGION, 65, 125, 136, 12, ES_AUTOHSCROLL
RTEXT "IAM Host:", IDC_STATIC, 3, 145, 58, 18
EDITTEXT IDC_EDIT_AUTH_HOST, 65, 144, 136, 12, ES_AUTOHSCROLL
LTEXT "IAM Port:", IDC_STATIC, 207, 144, 36, 18
EDITTEXT IDC_EDIT_AUTH_PORT, 244, 143, 51, 12, ES_AUTOHSCROLL | ES_NUMBER
RTEXT "IAM Expire Time:", IDC_STATIC, 3, 163, 58, 18
EDITTEXT IDC_EDIT_AUTH_EXPIRATION, 65, 162, 136, 12, ES_AUTOHSCROLL | ES_NUMBER
LTEXT "Federated Authentication Mode:",IDC_STATIC,207,7,80,18
COMBOBOX IDC_EDIT_FED_AUTH_MODE,207,27,79,12,CBS_DROPDOWN | CBS_AUTOHSCROLL | CBS_SORT | WS_VSCROLL | WS_TABSTOP
RTEXT "IDP Username:",IDC_STATIC,4,6,58,18
EDITTEXT IDC_EDIT_IDP_USERNAME,65,6,136,12,ES_AUTOHSCROLL
RTEXT "IDP Password:",IDC_STATIC,4,27,58,18
EDITTEXT IDC_EDIT_IDP_PASSWORD,65,27,136,12,ES_PASSWORD | ES_AUTOHSCROLL
RTEXT "IDP Endpoint:",IDC_STATIC,4,47,58,18
EDITTEXT IDC_EDIT_IDP_ENDPOINT,65,46,136,12,ES_AUTOHSCROLL
RTEXT "App ID:",IDC_STATIC,4,67,58,18
EDITTEXT IDC_EDIT_APP_ID,65,67,136,12,ES_AUTOHSCROLL
RTEXT "IAM Role ARN:",IDC_STATIC,3,88,58,18
EDITTEXT IDC_EDIT_IAM_ROLE_ARN,65,87,136,12,ES_AUTOHSCROLL
RTEXT "IAM IDP ARN:",IDC_STATIC,3,108,58,18
EDITTEXT IDC_EDIT_IAM_IDP_ARN,65,107,136,12,ES_AUTOHSCROLL
LTEXT "IDP Port:",IDC_STATIC,207,125,36,10
EDITTEXT IDC_EDIT_IDP_PORT,243,124,51,12,ES_AUTOHSCROLL | ES_NUMBER
RTEXT "AWS Region:",IDC_STATIC,3,126,58,18
EDITTEXT IDC_EDIT_AUTH_REGION,65,125,136,12,ES_AUTOHSCROLL
RTEXT "IAM Host:",IDC_STATIC,3,145,58,18
EDITTEXT IDC_EDIT_AUTH_HOST,65,144,136,12,ES_AUTOHSCROLL
LTEXT "IAM Port:",IDC_STATIC,207,144,36,18
EDITTEXT IDC_EDIT_AUTH_PORT,244,143,51,12,ES_AUTOHSCROLL | ES_NUMBER
RTEXT "IAM Expire Time:",IDC_STATIC,3,163,58,18
EDITTEXT IDC_EDIT_AUTH_EXPIRATION,65,162,136,12,ES_AUTOHSCROLL | ES_NUMBER
LTEXT "Client Connect Timeout:",IDC_STATIC,207,47,86,10
EDITTEXT IDC_EDIT_CLIENT_CONNECT_TIMEOUT,207,59,51,12,ES_AUTOHSCROLL | ES_NUMBER
LTEXT "Client Socket Timeout:",IDC_STATIC,207,76,75,10
EDITTEXT IDC_EDIT_CLIENT_SOCKET_TIMEOUT,207,88,51,12,ES_AUTOHSCROLL | ES_NUMBER
CONTROL "&Enable SSL",IDC_CHECK_ENABLE_SSL,"Button",BS_AUTOCHECKBOX | WS_TABSTOP,207,108,47,10
END

IDD_TAB5 DIALOGEX 0, 0, 209, 281
Expand Down
3 changes: 3 additions & 0 deletions setupgui/windows/resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@
#define IDC_EDIT_IAM_IDP_ARN 11025
#define IDC_EDIT_FED_AUTH_MODE 11026
#define IDC_EDIT_IDP_PORT 11027
#define IDC_EDIT_CLIENT_CONNECT_TIMEOUT 11028
#define IDC_EDIT_CLIENT_SOCKET_TIMEOUT 11029
#define IDC_CHECK_ENABLE_SSL 11030
#define MYSQL_ADMIN_PORT 33062
#define IDC_STATIC -1

Expand Down
2 changes: 1 addition & 1 deletion unit_testing/mock_objects.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class MOCK_AUTH_UTIL : public AUTH_UTIL {

class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT {
public:
MOCK_SAML_HTTP_CLIENT(std::string host) : SAML_HTTP_CLIENT(host) {};
MOCK_SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) : SAML_HTTP_CLIENT(host, connect_timeout, socket_timeout, enable_ssl) {};
MOCK_METHOD(nlohmann::json, post, (const std::string&, const nlohmann::json&));
MOCK_METHOD(nlohmann::json, get, (const std::string&));
};
Expand Down
2 changes: 1 addition & 1 deletion unit_testing/okta_proxy_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class OktaProxyTest : public testing::Test {
ds->opt_AUTH_PORT = TEST_PORT;
ds->opt_AUTH_EXPIRATION = TEST_EXPIRATION;

mock_saml_http_client = std::make_shared<MOCK_SAML_HTTP_CLIENT>(TEST_ENDPOINT);
mock_saml_http_client = std::make_shared<MOCK_SAML_HTTP_CLIENT>(TEST_ENDPOINT, 10, 10, true);
mock_auth_util = std::make_shared<MOCK_AUTH_UTIL>();
}

Expand Down
8 changes: 7 additions & 1 deletion util/installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,9 @@ static SQLWCHAR W_IAM_ROLE_ARN[] = { 'I', 'A', 'M', '_', 'R', 'O', 'L', 'E', '_'
static SQLWCHAR W_IAM_IDP_ARN[] = { 'I', 'A', 'M', '_', 'I', 'D', 'P', '_', 'A', 'R', 'N', 0 };
static SQLWCHAR W_APP_ID[] = { 'A', 'P', 'P', '_', 'I', 'D', 0 };
static SQLWCHAR W_IDP_PORT[] = { 'I', 'D', 'P', '_', 'P', 'O', 'R', 'T', 0 };
static SQLWCHAR W_CLIENT_CONNECT_TIMEOUT[] = {'C', 'L', 'I', 'E', 'N', 'T', '_', 'C', 'O', 'N', 'N', 'E', 'C', 'T', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0};
static SQLWCHAR W_CLIENT_SOCKET_TIMEOUT[] = {'C', 'L', 'I', 'E', 'N', 'T', '_', 'S', 'O', 'C', 'K', 'E', 'T', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0};
static SQLWCHAR W_ENABLE_SSL[] = {'E', 'N', 'A', 'B', 'L', 'E', '_', 'S', 'S', 'L', 0};

/* Failover */
static SQLWCHAR W_ENABLE_CLUSTER_FAILOVER[] = { 'E', 'N', 'A', 'B', 'L', 'E', '_', 'C', 'L', 'U', 'S', 'T', 'E', 'R', '_', 'F', 'A', 'I', 'L', 'O', 'V', 'E', 'R', 0 };
Expand Down Expand Up @@ -321,6 +324,7 @@ SQLWCHAR *dsnparams[]= {W_DSN, W_DRIVER, W_DESCRIPTION, W_SERVER,
W_AUTH_MODE, W_AUTH_REGION, W_AUTH_HOST, W_AUTH_PORT, W_AUTH_EXPIRATION, W_AUTH_SECRET_ID,
/* FED Auth*/
W_IDP_USERNAME, W_IDP_PASSWORD, W_IDP_ENDPOINT, W_IDP_PORT, W_APP_ID, W_IAM_ROLE_ARN, W_IAM_IDP_ARN,
W_CLIENT_CONNECT_TIMEOUT, W_CLIENT_SOCKET_TIMEOUT, W_ENABLE_SSL,
/* Failover */
W_ENABLE_CLUSTER_FAILOVER, W_FAILOVER_MODE,
W_GATHER_PERF_METRICS, W_GATHER_PERF_METRICS_PER_INSTANCE,
Expand Down Expand Up @@ -1051,9 +1055,11 @@ void DataSource::reset() {
this->opt_MONITOR_DISPOSAL_TIME.set_default(MONITOR_DISPOSAL_TIME_MS);
this->opt_FAILURE_DETECTION_TIMEOUT.set_default(FAILURE_DETECTION_TIMEOUT_SECS);

this->opt_IDP_PORT.set_default(-1);
this->opt_AUTH_PORT.set_default(opt_PORT);
this->opt_AUTH_EXPIRATION.set_default(900); // 15 minutes
this->opt_CLIENT_CONNECT_TIMEOUT.set_default(60);
this->opt_CLIENT_SOCKET_TIMEOUT.set_default(60);
this->opt_ENABLE_SSL.set_default(true);
}

SQLWSTRING DataSource::to_kvpair(SQLWCHAR delim) {
Expand Down
12 changes: 9 additions & 3 deletions util/installer.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,17 @@ unsigned int get_network_timeout(unsigned int seconds);
X(APP_ID)

#define FED_AUTH_INT_OPTIONS_LIST(X) \
X(IDP_PORT)
X(IDP_PORT) \
X(CLIENT_SOCKET_TIMEOUT) \
X(CLIENT_CONNECT_TIMEOUT)

#define FED_AUTH_BOOL_OPTIONS_LIST(X) \
X(ENABLE_SSL)

#define FAILOVER_BOOL_OPTIONS_LIST(X) \
X(ENABLE_CLUSTER_FAILOVER) \
X(GATHER_PERF_METRICS) \
X(GATHER_PERF_METRICS_PER_INSTANCE)
X(GATHER_PERF_METRICS_PER_INSTANCE) \

#define FAILOVER_STR_OPTIONS_LIST(X) \
X(HOST_PATTERN) \
Expand Down Expand Up @@ -396,7 +401,8 @@ unsigned int get_network_timeout(unsigned int seconds);
X(ENABLE_LOCAL_INFILE) X(ENABLE_DNS_SRV) \
X(MULTI_HOST) \
FAILOVER_BOOL_OPTIONS_LIST(X) \
MONITORING_BOOL_OPTIONS_LIST(X)
MONITORING_BOOL_OPTIONS_LIST(X) \
FED_AUTH_BOOL_OPTIONS_LIST(X)

#define FULL_OPTIONS_LIST(X) \
STR_OPTIONS_LIST(X) INT_OPTIONS_LIST(X) BOOL_OPTIONS_LIST(X)
Expand Down

0 comments on commit 8bf63c4

Please sign in to comment.