From 8bf63c435af91256017f4a22a89b5da0a80457f5 Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Wed, 4 Sep 2024 19:27:19 +0000 Subject: [PATCH] feat: add more configuration params for federated authentication (#212) --- .clang-format | 56 ++++++++++++++-------------- driver/okta_proxy.cc | 9 +++-- driver/okta_proxy.h | 2 +- driver/saml_http_client.cc | 16 ++++++-- driver/saml_http_client.h | 6 ++- setupgui/callbacks.cc | 3 ++ setupgui/windows/odbcdialogparams.rc | 53 ++++++++++++++------------ setupgui/windows/resource.h | 3 ++ unit_testing/mock_objects.h | 2 +- unit_testing/okta_proxy_test.cc | 2 +- util/installer.cc | 8 +++- util/installer.h | 12 ++++-- 12 files changed, 106 insertions(+), 66 deletions(-) diff --git a/.clang-format b/.clang-format index 67a0c54a5..63982a654 100644 --- a/.clang-format +++ b/.clang-format @@ -1,31 +1,31 @@ -// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -// -// This program is free software; you can redistribute it and/or modify -// it under the terms of the GNU General Public License, version 2.0 -// (GPLv2), as published by the Free Software Foundation, with the -// following additional permissions: -// -// This program is distributed with certain software that is licensed -// under separate terms, as designated in a particular file or component -// or in the license documentation. Without limiting your rights under -// the GPLv2, the authors of this program hereby grant you an additional -// permission to link the program and your derivative works with the -// separately licensed software that they have included with the program. -// -// Without limiting the foregoing grant of rights under the GPLv2 and -// additional permission as to separately licensed software, this -// program is also subject to the Universal FOSS Exception, version 1.0, -// a copy of which can be found along with its FAQ at -// http://oss.oracle.com/licenses/universal-foss-exception. -// -// This program is distributed in the hope that it will be useful, but -// WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. -// See the GNU General Public License, version 2.0, for more details. -// -// You should have received a copy of the GNU General Public License -// along with this program. If not, see -// http://www.gnu.org/licenses/gpl-2.0.html. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# This program is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License, version 2.0 +# (GPLv2), as published by the Free Software Foundation, with the +# following additional permissions: +# +# This program is distributed with certain software that is licensed +# under separate terms, as designated in a particular file or component +# or in the license documentation. Without limiting your rights under +# the GPLv2, the authors of this program hereby grant you an additional +# permission to link the program and your derivative works with the +# separately licensed software that they have included with the program. +# +# Without limiting the foregoing grant of rights under the GPLv2 and +# additional permission as to separately licensed software, this +# program is also subject to the Universal FOSS Exception, version 1.0, +# a copy of which can be found along with its FAQ at +# http://oss.oracle.com/licenses/universal-foss-exception. +# +# This program is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +# See the GNU General Public License, version 2.0, for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see +# http://www.gnu.org/licenses/gpl-2.0.html. Language: Cpp # BasedOnStyle: Google diff --git a/driver/okta_proxy.cc b/driver/okta_proxy.cc index a804732bf..822d9791c 100644 --- a/driver/okta_proxy.cc +++ b/driver/okta_proxy.cc @@ -43,7 +43,10 @@ OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds) : OKTA_PROXY(dbc, ds, nullptr) OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) { this->next_proxy = next_proxy; const std::string idp_host{static_cast(ds->opt_IDP_ENDPOINT)}; - this->saml_util = std::make_shared(idp_host); + const int client_connect_timeout = ds->opt_CLIENT_CONNECT_TIMEOUT; + const int client_socket_timeout = ds->opt_CLIENT_SOCKET_TIMEOUT; + const bool enable_ssl = ds->opt_ENABLE_SSL; + this->saml_util = std::make_shared(idp_host, client_connect_timeout, client_socket_timeout, enable_ssl); } bool OKTA_PROXY::connect(const char* host, const char* user, const char* password, const char* database, @@ -122,8 +125,8 @@ void OKTA_PROXY::clear_token_cache() { OKTA_SAML_UTIL::OKTA_SAML_UTIL(const std::shared_ptr& client) { this->http_client = client; } -OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host) { - this->http_client = std::make_shared("https://" + host); +OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) { + this->http_client = std::make_shared("https://" + host, connect_timeout, socket_timeout, enable_ssl); } std::string OKTA_SAML_UTIL::get_saml_url(DataSource* ds) { diff --git a/driver/okta_proxy.h b/driver/okta_proxy.h index 936f3ea23..15d9097f4 100644 --- a/driver/okta_proxy.h +++ b/driver/okta_proxy.h @@ -43,7 +43,7 @@ const std::regex SAML_RESPONSE_PATTERN(R"#(name=\"SAMLResponse\".+value=\"(.+)\" class OKTA_SAML_UTIL : public SAML_UTIL { public: OKTA_SAML_UTIL(const std::shared_ptr& client); - OKTA_SAML_UTIL(std::string host); + OKTA_SAML_UTIL(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl); std::string get_saml_assertion(DataSource* ds) override; std::string get_session_token(DataSource* ds) const; static std::string get_saml_url(DataSource* ds); diff --git a/driver/saml_http_client.cc b/driver/saml_http_client.cc index cdb5b4240..6dbe0eb10 100644 --- a/driver/saml_http_client.cc +++ b/driver/saml_http_client.cc @@ -30,10 +30,20 @@ #include "saml_http_client.h" #include -SAML_HTTP_CLIENT::SAML_HTTP_CLIENT(std::string host) : host{std::move(host)} {} +SAML_HTTP_CLIENT::SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) + : host{std::move(host)}, connect_timeout(connect_timeout), socket_timeout(socket_timeout), enable_ssl(enable_ssl) {} -nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::json& value) { +httplib::Client SAML_HTTP_CLIENT::get_client() const { httplib::Client client(host); + client.enable_server_certificate_verification(enable_ssl); + client.set_connection_timeout(connect_timeout); + client.set_read_timeout(socket_timeout); + client.set_write_timeout(socket_timeout); + return client; +} + +nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::json& value) { + httplib::Client client = this->get_client(); if (auto res = client.Post(path.c_str(), value.dump(), "application/json")) { if (res->status == httplib::StatusCode::OK_200) { nlohmann::json json_object = nlohmann::json::parse(res->body); @@ -46,7 +56,7 @@ nlohmann::json SAML_HTTP_CLIENT::post(const std::string& path, const nlohmann::j } nlohmann::json SAML_HTTP_CLIENT::get(const std::string& path) { - httplib::Client client(host); + httplib::Client client = this->get_client(); client.set_follow_location(true); if (auto res = client.Get(path.c_str())) { if (res->status == httplib::StatusCode::OK_200) { diff --git a/driver/saml_http_client.h b/driver/saml_http_client.h index e65cac636..6c68eac49 100644 --- a/driver/saml_http_client.h +++ b/driver/saml_http_client.h @@ -47,13 +47,17 @@ class SAML_HTTP_EXCEPTION: public std::exception { class SAML_HTTP_CLIENT { public: - SAML_HTTP_CLIENT(std::string host); + SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl); ~SAML_HTTP_CLIENT() = default; virtual nlohmann::json post(const std::string& path, const nlohmann::json& value); virtual nlohmann::json get(const std::string& path); private: const std::string host; + const int connect_timeout; + const int socket_timeout; + const bool enable_ssl; + httplib::Client get_client() const; }; #endif diff --git a/setupgui/callbacks.cc b/setupgui/callbacks.cc index bb47dd877..1204a450a 100644 --- a/setupgui/callbacks.cc +++ b/setupgui/callbacks.cc @@ -494,6 +494,9 @@ void syncTabs(HWND hwnd, DataSource *params) SET_STRING_TAB(FED_AUTH_TAB, AUTH_HOST); SET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_PORT); SET_UNSIGNED_TAB(FED_AUTH_TAB, AUTH_EXPIRATION); + SET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_CONNECT_TIMEOUT); + SET_UNSIGNED_TAB(FED_AUTH_TAB, CLIENT_SOCKET_TIMEOUT); + SET_BOOL_TAB(FED_AUTH_TAB, ENABLE_SSL); /* 5 - Failover */ SET_BOOL_TAB(FAILOVER_TAB, ENABLE_CLUSTER_FAILOVER); diff --git a/setupgui/windows/odbcdialogparams.rc b/setupgui/windows/odbcdialogparams.rc index fd9967d83..2a5709e85 100644 --- a/setupgui/windows/odbcdialogparams.rc +++ b/setupgui/windows/odbcdialogparams.rc @@ -221,30 +221,35 @@ IDD_TAB4 DIALOGEX 0, 0, 307, 245 STYLE DS_SETFONT | DS_FIXEDSYS | WS_CHILD FONT 8, "MS Shell Dlg", 400, 0, 0x1 BEGIN - LTEXT "Federated Authentication Mode:", IDC_STATIC, 207, 7, 80, 18 - COMBOBOX IDC_EDIT_FED_AUTH_MODE, 207, 27, 79, 12, CBS_DROPDOWN | CBS_AUTOHSCROLL | CBS_SORT | WS_VSCROLL | WS_TABSTOP - RTEXT "IDP Username:", IDC_STATIC, 4, 6, 58, 18 - EDITTEXT IDC_EDIT_IDP_USERNAME, 65, 6, 136, 12, ES_AUTOHSCROLL - RTEXT "IDP Password:", IDC_STATIC, 4, 27, 58, 18 - EDITTEXT IDC_EDIT_IDP_PASSWORD, 65, 27, 136, 12, ES_PASSWORD | ES_AUTOHSCROLL - RTEXT "IDP Endpoint:", IDC_STATIC, 4, 47, 58, 18 - EDITTEXT IDC_EDIT_IDP_ENDPOINT, 65, 46, 136, 12, ES_AUTOHSCROLL - RTEXT "App ID:", IDC_STATIC, 4, 67, 58, 18 - EDITTEXT IDC_EDIT_APP_ID, 65, 67, 136, 12, ES_AUTOHSCROLL - RTEXT "IAM Role ARN:", IDC_STATIC, 3, 88, 58, 18 - EDITTEXT IDC_EDIT_IAM_ROLE_ARN, 65, 87, 136, 12, ES_AUTOHSCROLL - RTEXT "IAM IDP ARN:", IDC_STATIC, 3, 108, 58, 18 - EDITTEXT IDC_EDIT_IAM_IDP_ARN, 65, 107, 136, 12, ES_AUTOHSCROLL - LTEXT "IDP Port:", IDC_STATIC, 207, 125, 36, 10 - EDITTEXT IDC_EDIT_IDP_PORT, 243, 124, 51, 12, ES_AUTOHSCROLL | ES_NUMBER - RTEXT "AWS Region:", IDC_STATIC, 3, 126, 58, 18 - EDITTEXT IDC_EDIT_AUTH_REGION, 65, 125, 136, 12, ES_AUTOHSCROLL - RTEXT "IAM Host:", IDC_STATIC, 3, 145, 58, 18 - EDITTEXT IDC_EDIT_AUTH_HOST, 65, 144, 136, 12, ES_AUTOHSCROLL - LTEXT "IAM Port:", IDC_STATIC, 207, 144, 36, 18 - EDITTEXT IDC_EDIT_AUTH_PORT, 244, 143, 51, 12, ES_AUTOHSCROLL | ES_NUMBER - RTEXT "IAM Expire Time:", IDC_STATIC, 3, 163, 58, 18 - EDITTEXT IDC_EDIT_AUTH_EXPIRATION, 65, 162, 136, 12, ES_AUTOHSCROLL | ES_NUMBER + LTEXT "Federated Authentication Mode:",IDC_STATIC,207,7,80,18 + COMBOBOX IDC_EDIT_FED_AUTH_MODE,207,27,79,12,CBS_DROPDOWN | CBS_AUTOHSCROLL | CBS_SORT | WS_VSCROLL | WS_TABSTOP + RTEXT "IDP Username:",IDC_STATIC,4,6,58,18 + EDITTEXT IDC_EDIT_IDP_USERNAME,65,6,136,12,ES_AUTOHSCROLL + RTEXT "IDP Password:",IDC_STATIC,4,27,58,18 + EDITTEXT IDC_EDIT_IDP_PASSWORD,65,27,136,12,ES_PASSWORD | ES_AUTOHSCROLL + RTEXT "IDP Endpoint:",IDC_STATIC,4,47,58,18 + EDITTEXT IDC_EDIT_IDP_ENDPOINT,65,46,136,12,ES_AUTOHSCROLL + RTEXT "App ID:",IDC_STATIC,4,67,58,18 + EDITTEXT IDC_EDIT_APP_ID,65,67,136,12,ES_AUTOHSCROLL + RTEXT "IAM Role ARN:",IDC_STATIC,3,88,58,18 + EDITTEXT IDC_EDIT_IAM_ROLE_ARN,65,87,136,12,ES_AUTOHSCROLL + RTEXT "IAM IDP ARN:",IDC_STATIC,3,108,58,18 + EDITTEXT IDC_EDIT_IAM_IDP_ARN,65,107,136,12,ES_AUTOHSCROLL + LTEXT "IDP Port:",IDC_STATIC,207,125,36,10 + EDITTEXT IDC_EDIT_IDP_PORT,243,124,51,12,ES_AUTOHSCROLL | ES_NUMBER + RTEXT "AWS Region:",IDC_STATIC,3,126,58,18 + EDITTEXT IDC_EDIT_AUTH_REGION,65,125,136,12,ES_AUTOHSCROLL + RTEXT "IAM Host:",IDC_STATIC,3,145,58,18 + EDITTEXT IDC_EDIT_AUTH_HOST,65,144,136,12,ES_AUTOHSCROLL + LTEXT "IAM Port:",IDC_STATIC,207,144,36,18 + EDITTEXT IDC_EDIT_AUTH_PORT,244,143,51,12,ES_AUTOHSCROLL | ES_NUMBER + RTEXT "IAM Expire Time:",IDC_STATIC,3,163,58,18 + EDITTEXT IDC_EDIT_AUTH_EXPIRATION,65,162,136,12,ES_AUTOHSCROLL | ES_NUMBER + LTEXT "Client Connect Timeout:",IDC_STATIC,207,47,86,10 + EDITTEXT IDC_EDIT_CLIENT_CONNECT_TIMEOUT,207,59,51,12,ES_AUTOHSCROLL | ES_NUMBER + LTEXT "Client Socket Timeout:",IDC_STATIC,207,76,75,10 + EDITTEXT IDC_EDIT_CLIENT_SOCKET_TIMEOUT,207,88,51,12,ES_AUTOHSCROLL | ES_NUMBER + CONTROL "&Enable SSL",IDC_CHECK_ENABLE_SSL,"Button",BS_AUTOCHECKBOX | WS_TABSTOP,207,108,47,10 END IDD_TAB5 DIALOGEX 0, 0, 209, 281 diff --git a/setupgui/windows/resource.h b/setupgui/windows/resource.h index 686f240bb..524976faf 100644 --- a/setupgui/windows/resource.h +++ b/setupgui/windows/resource.h @@ -192,6 +192,9 @@ #define IDC_EDIT_IAM_IDP_ARN 11025 #define IDC_EDIT_FED_AUTH_MODE 11026 #define IDC_EDIT_IDP_PORT 11027 +#define IDC_EDIT_CLIENT_CONNECT_TIMEOUT 11028 +#define IDC_EDIT_CLIENT_SOCKET_TIMEOUT 11029 +#define IDC_CHECK_ENABLE_SSL 11030 #define MYSQL_ADMIN_PORT 33062 #define IDC_STATIC -1 diff --git a/unit_testing/mock_objects.h b/unit_testing/mock_objects.h index a35514d4f..2185c2856 100644 --- a/unit_testing/mock_objects.h +++ b/unit_testing/mock_objects.h @@ -231,7 +231,7 @@ class MOCK_AUTH_UTIL : public AUTH_UTIL { class MOCK_SAML_HTTP_CLIENT : public SAML_HTTP_CLIENT { public: - MOCK_SAML_HTTP_CLIENT(std::string host) : SAML_HTTP_CLIENT(host) {}; + MOCK_SAML_HTTP_CLIENT(std::string host, int connect_timeout, int socket_timeout, bool enable_ssl) : SAML_HTTP_CLIENT(host, connect_timeout, socket_timeout, enable_ssl) {}; MOCK_METHOD(nlohmann::json, post, (const std::string&, const nlohmann::json&)); MOCK_METHOD(nlohmann::json, get, (const std::string&)); }; diff --git a/unit_testing/okta_proxy_test.cc b/unit_testing/okta_proxy_test.cc index b13291489..0f1051eb5 100644 --- a/unit_testing/okta_proxy_test.cc +++ b/unit_testing/okta_proxy_test.cc @@ -97,7 +97,7 @@ class OktaProxyTest : public testing::Test { ds->opt_AUTH_PORT = TEST_PORT; ds->opt_AUTH_EXPIRATION = TEST_EXPIRATION; - mock_saml_http_client = std::make_shared(TEST_ENDPOINT); + mock_saml_http_client = std::make_shared(TEST_ENDPOINT, 10, 10, true); mock_auth_util = std::make_shared(); } diff --git a/util/installer.cc b/util/installer.cc index 8e2c38bd7..d8c8c551b 100644 --- a/util/installer.cc +++ b/util/installer.cc @@ -253,6 +253,9 @@ static SQLWCHAR W_IAM_ROLE_ARN[] = { 'I', 'A', 'M', '_', 'R', 'O', 'L', 'E', '_' static SQLWCHAR W_IAM_IDP_ARN[] = { 'I', 'A', 'M', '_', 'I', 'D', 'P', '_', 'A', 'R', 'N', 0 }; static SQLWCHAR W_APP_ID[] = { 'A', 'P', 'P', '_', 'I', 'D', 0 }; static SQLWCHAR W_IDP_PORT[] = { 'I', 'D', 'P', '_', 'P', 'O', 'R', 'T', 0 }; +static SQLWCHAR W_CLIENT_CONNECT_TIMEOUT[] = {'C', 'L', 'I', 'E', 'N', 'T', '_', 'C', 'O', 'N', 'N', 'E', 'C', 'T', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0}; +static SQLWCHAR W_CLIENT_SOCKET_TIMEOUT[] = {'C', 'L', 'I', 'E', 'N', 'T', '_', 'S', 'O', 'C', 'K', 'E', 'T', '_', 'T', 'I', 'M', 'E', 'O', 'U', 'T', 0}; +static SQLWCHAR W_ENABLE_SSL[] = {'E', 'N', 'A', 'B', 'L', 'E', '_', 'S', 'S', 'L', 0}; /* Failover */ static SQLWCHAR W_ENABLE_CLUSTER_FAILOVER[] = { 'E', 'N', 'A', 'B', 'L', 'E', '_', 'C', 'L', 'U', 'S', 'T', 'E', 'R', '_', 'F', 'A', 'I', 'L', 'O', 'V', 'E', 'R', 0 }; @@ -321,6 +324,7 @@ SQLWCHAR *dsnparams[]= {W_DSN, W_DRIVER, W_DESCRIPTION, W_SERVER, W_AUTH_MODE, W_AUTH_REGION, W_AUTH_HOST, W_AUTH_PORT, W_AUTH_EXPIRATION, W_AUTH_SECRET_ID, /* FED Auth*/ W_IDP_USERNAME, W_IDP_PASSWORD, W_IDP_ENDPOINT, W_IDP_PORT, W_APP_ID, W_IAM_ROLE_ARN, W_IAM_IDP_ARN, + W_CLIENT_CONNECT_TIMEOUT, W_CLIENT_SOCKET_TIMEOUT, W_ENABLE_SSL, /* Failover */ W_ENABLE_CLUSTER_FAILOVER, W_FAILOVER_MODE, W_GATHER_PERF_METRICS, W_GATHER_PERF_METRICS_PER_INSTANCE, @@ -1051,9 +1055,11 @@ void DataSource::reset() { this->opt_MONITOR_DISPOSAL_TIME.set_default(MONITOR_DISPOSAL_TIME_MS); this->opt_FAILURE_DETECTION_TIMEOUT.set_default(FAILURE_DETECTION_TIMEOUT_SECS); - this->opt_IDP_PORT.set_default(-1); this->opt_AUTH_PORT.set_default(opt_PORT); this->opt_AUTH_EXPIRATION.set_default(900); // 15 minutes + this->opt_CLIENT_CONNECT_TIMEOUT.set_default(60); + this->opt_CLIENT_SOCKET_TIMEOUT.set_default(60); + this->opt_ENABLE_SSL.set_default(true); } SQLWSTRING DataSource::to_kvpair(SQLWCHAR delim) { diff --git a/util/installer.h b/util/installer.h index 795e82350..22cd711da 100644 --- a/util/installer.h +++ b/util/installer.h @@ -325,12 +325,17 @@ unsigned int get_network_timeout(unsigned int seconds); X(APP_ID) #define FED_AUTH_INT_OPTIONS_LIST(X) \ - X(IDP_PORT) + X(IDP_PORT) \ + X(CLIENT_SOCKET_TIMEOUT) \ + X(CLIENT_CONNECT_TIMEOUT) + +#define FED_AUTH_BOOL_OPTIONS_LIST(X) \ + X(ENABLE_SSL) #define FAILOVER_BOOL_OPTIONS_LIST(X) \ X(ENABLE_CLUSTER_FAILOVER) \ X(GATHER_PERF_METRICS) \ - X(GATHER_PERF_METRICS_PER_INSTANCE) + X(GATHER_PERF_METRICS_PER_INSTANCE) \ #define FAILOVER_STR_OPTIONS_LIST(X) \ X(HOST_PATTERN) \ @@ -396,7 +401,8 @@ unsigned int get_network_timeout(unsigned int seconds); X(ENABLE_LOCAL_INFILE) X(ENABLE_DNS_SRV) \ X(MULTI_HOST) \ FAILOVER_BOOL_OPTIONS_LIST(X) \ - MONITORING_BOOL_OPTIONS_LIST(X) + MONITORING_BOOL_OPTIONS_LIST(X) \ + FED_AUTH_BOOL_OPTIONS_LIST(X) #define FULL_OPTIONS_LIST(X) \ STR_OPTIONS_LIST(X) INT_OPTIONS_LIST(X) BOOL_OPTIONS_LIST(X)