From c2ba2c78cf48766b5501a3bdf67da457910c6b60 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 15:22:49 -0800 Subject: [PATCH 01/31] [1] Initial commit to support client header authentication in C++ --- cpp/src/arrow/flight/CMakeLists.txt | 1 + cpp/src/arrow/flight/client.cc | 41 +++++- cpp/src/arrow/flight/client.h | 8 ++ .../flight/client_header_auth_middleware.cc | 120 ++++++++++++++++++ .../flight/client_header_auth_middleware.h | 73 +++++++++++ 5 files changed, 240 insertions(+), 3 deletions(-) create mode 100644 cpp/src/arrow/flight/client_header_auth_middleware.cc create mode 100644 cpp/src/arrow/flight/client_header_auth_middleware.h diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index edd4fdf1c3391..e7758d123c80c 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -123,6 +123,7 @@ set(ARROW_FLIGHT_SRCS serialization_internal.cc server.cc server_auth.cc + client_header_auth_middleware.cc types.cc) add_arrow_lib(arrow_flight diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 821da214a59e9..c8700e11d2bfc 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -51,6 +51,7 @@ #include "arrow/flight/client_auth.h" #include "arrow/flight/client_middleware.h" +#include "arrow/flight/client_header_auth_middleware.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware.h" #include "arrow/flight/middleware_internal.h" @@ -104,6 +105,9 @@ struct ClientRpc { std::chrono::system_clock::now() + options.timeout); context.set_deadline(deadline); } + for (auto metadata: options.metadata) { + context.AddMetadata(metadata.first, metadata.second); + } } /// \brief Add an auth token via an auth handler @@ -328,7 +332,7 @@ class GrpcClientInterceptorAdapterFactory : public grpc::experimental::ClientInterceptorFactoryInterface { public: GrpcClientInterceptorAdapterFactory( - std::vector> middleware) + std::vector>& middleware) : middleware_(middleware) {} grpc::experimental::Interceptor* CreateClientInterceptor( @@ -371,7 +375,7 @@ class GrpcClientInterceptorAdapterFactory } private: - std::vector> middleware_; + std::vector>& middleware_; }; class GrpcClientAuthSender : public ClientAuthSender { @@ -963,8 +967,9 @@ class FlightClient::FlightClientImpl { std::vector> interceptors; + middleware = std::move(options.middleware); interceptors.emplace_back( - new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); + new GrpcClientInterceptorAdapterFactory(middleware)); stub_ = pb::FlightService::NewStub( grpc::experimental::CreateCustomChannelWithInterceptors( @@ -993,6 +998,29 @@ class FlightClient::FlightClientImpl { return Status::OK(); } + Status AuthenticateBasicToken(std::string username, std::string password, std::pair* bearer_token) { + // Add bearer token factory to middleware so it can intercept the bearer token. + middleware.push_back(std::make_shared(bearer_token)); + ClientRpc rpc({}); + AddBasicAuthHeaders(&rpc.context, username, password); + std::shared_ptr> + stream = stub_->Handshake(&rpc.context); + + GrpcClientAuthSender outgoing{stream}; + GrpcClientAuthReader incoming{stream}; + // Explicitly close our side of the connection + bool finished_writes = stream->WritesDone(); + middleware.pop_back(); + RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); + if (!finished_writes) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not finish writing before closing"); + } + return Status::OK(); + } + + + Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, std::unique_ptr* listing) { pb::Criteria pb_criteria; @@ -1174,6 +1202,7 @@ class FlightClient::FlightClientImpl { GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> noop_auth_check_; #endif + std::vector> middleware; int64_t write_size_limit_bytes_; }; @@ -1197,6 +1226,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, return impl_->Authenticate(options, std::move(auth_handler)); } +Status FlightClient::AuthenticateBasicToken( + std::string username, std::string password, + std::pair* bearer_token) { + return impl_->AuthenticateBasicToken(username, password, bearer_token); +} + Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, std::unique_ptr* results) { return impl_->DoAction(options, action, results); diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 935e8fb92ba43..0bcdbc4fcf98e 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -65,6 +65,8 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { /// \brief IPC writer options, if applicable for the call. ipc::IpcWriteOptions write_options; + + std::vector> metadata; }; /// \brief Indicate that the client attempted to write a message @@ -191,6 +193,12 @@ class ARROW_FLIGHT_EXPORT FlightClient { Status Authenticate(const FlightCallOptions& options, std::unique_ptr auth_handler); + /// \brief Authenticate to the server using the given handler. + /// \param[in] options Per-RPC options + /// \param[in] auth_handler The authentication mechanism to use + /// \return Status OK if the client authenticated successfully + Status AuthenticateBasicToken(std::string username, std::string password, std::pair* bearer_token); + /// \brief Perform the indicated action, returning an iterator to the stream /// of results, if any /// \param[in] options Per-RPC options diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.cc b/cpp/src/arrow/flight/client_header_auth_middleware.cc new file mode 100644 index 0000000000000..e51a684a284e0 --- /dev/null +++ b/cpp/src/arrow/flight/client_header_auth_middleware.cc @@ -0,0 +1,120 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces for defining middleware for Flight clients. Currently +// experimental. + +#include "client_header_auth_middleware.h" +#include "client_middleware.h" +#include "client_auth.h" +#include "client.h" + +namespace arrow { +namespace flight { + + std::string base64_encode(const std::string& input); + + ClientBearerTokenMiddleware::ClientBearerTokenMiddleware(std::pair* bearer_token_) + : bearer_token(bearer_token_) { } + + void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { } + + void ClientBearerTokenMiddleware::ReceivedHeaders(const CallHeaders& incoming_headers) { + // Grab the auth token if one exists. + auto bearer_iter = incoming_headers.find(AUTH_HEADER); + if (bearer_iter == incoming_headers.end()) { + return; + } + + // Check if the value of the auth token starts with the bearer prefix, latch the token. + std::string bearer_val = bearer_iter->second.to_string(); + if (bearer_val.size() > BEARER_PREFIX.size()) { + bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), + [] (const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); + } + ); + if (hasPrefix) { + *bearer_token = std::make_pair(AUTH_HEADER, bearer_val); + } + } + } + + void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { } + + void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr* middleware) { + *middleware = std::unique_ptr(new ClientBearerTokenMiddleware(bearer_token)); + } + + void ClientBearerTokenFactory::Reset() { + *bearer_token = std::make_pair("", ""); + } + + template + std::string string_format(const std::string& format, const Args... args) { + // Check size requirement for new string and increment by 1 for null terminator. + size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; + if(size <= 0){ + throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); + } + + // Create buffer for new string and write string in. + std::unique_ptr buf(new char[size]); + std::snprintf(buf.get(), size, format.c_str(), args...); + + // Convert to std::string, subtracting size by 1 to trim null terminator. + return std::string(buf.get(), buf.get() + size - 1); + } + + void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { + const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str()); + context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials)); + } + + std::string base64_encode(const std::string& input) { + static const std::string base64_chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + auto get_encoded_length = [] (const std::string& in) { + return 4 * ((in.size() + 2) / 3); + }; + auto get_overwrite_count = [] (const std::string& in) { + const std::string::size_type remainder = in.length() % 3; + return (remainder > 0) ? (3 - (remainder % 3)) : 0; + }; + + // Generate string with required length for encoding. + std::string encoded; + encoded.reserve(get_encoded_length(input)); + + // Loop through input writing base64 characters to string. + for (int i = 0; i < input.length();) { + uint32_t octet_1 = i < input.length() ? (unsigned char)input[i++] : 0; + uint32_t octet_2 = i < input.length() ? (unsigned char)input[i++] : 0; + uint32_t octet_3 = i < input.length() ? (unsigned char)input[i++] : 0; + uint32_t octriple = (octet_1 << 0x10) + (octet_2 << 0x08) + octet_3; + for (int j = 3; j >= 0; j--) { + encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]); + } + } + + // Round up to nearest multiple of 3 and replace characters at end based on rounding. + int overwrite_count = get_overwrite_count(input); + encoded.replace(encoded.length() - overwrite_count, encoded.length(), overwrite_count, '='); + return encoded; + } +} // namespace flight +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.h b/cpp/src/arrow/flight/client_header_auth_middleware.h new file mode 100644 index 0000000000000..7602c69f89d4f --- /dev/null +++ b/cpp/src/arrow/flight/client_header_auth_middleware.h @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces for defining middleware for Flight clients. Currently +// experimental. + +#pragma once + +#include "client_middleware.h" +#include "client_auth.h" +#include "client.h" + +#ifdef GRPCPP_PP_INCLUDE +#include +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) +#include +#endif +#else +#include +#endif + +#include +#include +#include +#include + +const std::string AUTH_HEADER = "authorization"; +const std::string BEARER_PREFIX = "Bearer "; +const std::string BASIC_PREFIX = "Basic "; + +namespace arrow { +namespace flight { + +void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password); + +class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware { + public: + explicit ClientBearerTokenMiddleware(std::pair* bearer_token_); + + void SendingHeaders(AddCallHeaders* outgoing_headers); + void ReceivedHeaders(const CallHeaders& incoming_headers); + void CallCompleted(const Status& status); + + private: + std::pair* bearer_token; +}; + +class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { + public: + explicit ClientBearerTokenFactory(std::pair* bearer_token_) : bearer_token(bearer_token_) {} + + void StartCall(const CallInfo& info, std::unique_ptr* middleware); + void Reset(); + + private: + std::pair* bearer_token; +}; +} // namespace flight +} // namespace arrow From ab7d6a30a115ceb623ba1d4af8f0388d4ad66075 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 17:09:27 -0800 Subject: [PATCH 02/31] [1] Added integration test for client header authentication in C++ and server header authentication in Java --- cpp/src/arrow/flight/CMakeLists.txt | 10 +- .../test_integration_client_header_auth.cc | 115 ++++++++++++++++++ .../integration/Auth2IntegrationServer.java | 65 ++++++++++ 3 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/flight/test_integration_client_header_auth.cc create mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index e7758d123c80c..4aeabeadc2399 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -214,10 +214,16 @@ if(ARROW_BUILD_INTEGRATION) target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) + add_executable(flight-test-integration-client-header-auth test_integration_client_header_auth.cc) + target_link_libraries(flight-test-integration-client-header-auth ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) + add_dependencies(arrow_flight flight-test-integration-client - flight-test-integration-server) + flight-test-integration-server + flight-test-integration-client-header-auth) add_dependencies(arrow-integration flight-test-integration-client - flight-test-integration-server) + flight-test-integration-server + flight-test-integration-client-header-auth) endif() if(ARROW_BUILD_BENCHMARKS) diff --git a/cpp/src/arrow/flight/test_integration_client_header_auth.cc b/cpp/src/arrow/flight/test_integration_client_header_auth.cc new file mode 100644 index 0000000000000..15b0158b65bd2 --- /dev/null +++ b/cpp/src/arrow/flight/test_integration_client_header_auth.cc @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Client implementation for Flight integration testing. Loads +// RecordBatches from the given JSON file and uploads them to the +// Flight server, which stores the data and schema in memory. The +// client then requests the data from the server and compares it to +// the data originally uploaded. + +#include +#include +#include + +#include + +#include "arrow/io/file.h" +#include "arrow/io/test_common.h" +#include "arrow/ipc/dictionary.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/json_integration.h" +#include "arrow/util/logging.h" + +#include "arrow/flight/api.h" +#include "arrow/flight/test_integration.h" +#include "arrow/flight/test_util.h" + +DEFINE_string(host, "localhost", "Server port to connect to"); +DEFINE_int32(port, 31337, "Server port to connect to"); +DEFINE_string(username, "flight1", "Username to use in basic auth"); +DEFINE_string(password, "woohoo1", "Password to use in basic auth"); +DEFINE_string(username_invalid, "foooo", "Username to use in basic auth"); +DEFINE_string(password_invalid, "barrr", "Password to use in basic auth"); + +void TestValidCredentials() { + std::cout << "Testing with valid auth credentials." << std::endl; + auto get_uri = []() { + return "grpc+tcp://" + FLAGS_host + ":" + std::to_string(FLAGS_port); + }; + + // Generate Location with URI. + arrow::flight::Location location; + ABORT_NOT_OK(arrow::flight::Location::Parse(get_uri(), &location)); + + // Create client and connect to Location. + std::unique_ptr client; + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client)); + + // Authenticate credentials and retreive token. + std::pair bearer_token = std::make_pair("", ""); + ABORT_NOT_OK(client->AuthenticateBasicToken(FLAGS_username, FLAGS_password, &bearer_token)); + + // Validate token was received. + if (bearer_token == std::make_pair(std::string(""), std::string(""))) { + std::cout << "Testing valid credentials was unsuccessful: Failed to get token from basic authentication." << std::endl; + return; + } + + // Try to list flights, this will force the bearer token to be send and authenticated. + std::unique_ptr listing; + arrow::flight::FlightCallOptions options; + options.metadata.push_back(bearer_token); + ABORT_NOT_OK(client->ListFlights(options, {}, &listing)); + std::cout << "Test valid credentials was successful." << std::endl; +} + +void TestInvalidCredentials() { + auto get_uri = []() { + return "grpc+tcp://" + FLAGS_host + ":" + std::to_string(FLAGS_port); + }; + + // Generate Location with URI. + arrow::flight::Location location; + ABORT_NOT_OK(arrow::flight::Location::Parse(get_uri(), &location)); + + // Create client and connect to Location. + std::unique_ptr client; + ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client)); + + // Authenticate credentials and retreive token. + std::pair bearer_token = std::make_pair("", ""); + EXPECT_EQ(arrow::StatusCode::IOError, + client->AuthenticateBasicToken(FLAGS_username_invalid, FLAGS_password_invalid, &bearer_token).code()); + + // Validate token was received. + if (bearer_token != std::make_pair(std::string(""), std::string(""))) { + std::cout << "Testing invalid credentials was unsuccessful: Obtained token from basic authentication when using invalid credentials." << std::endl; + return; + } + + std::cout << "Testing invalid credentials was successful." << std::endl; +} + +int main(int argc, char** argv) { + std::cout << "Starting auth header based flight integration test." << std::endl; + TestValidCredentials(); + TestInvalidCredentials(); +} \ No newline at end of file diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java new file mode 100644 index 0000000000000..59aaa390e8e6e --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java @@ -0,0 +1,65 @@ +package org.apache.arrow.flight.example.integration; + +import com.google.common.base.Strings; +import org.apache.arrow.flight.*; +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; +import org.apache.arrow.flight.example.InMemoryStore; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; + +import java.io.IOException; + +public class Auth2IntegrationServer { + private static final int PORT = 31337; + private static final String USERNAME_1 = "flight1"; + private static final String PASSWORD_1 = "woohoo1"; + private static final String HOST = "localhost"; + private static final BufferAllocator ALLOCATOR = new RootAllocator(Long.MAX_VALUE); + private static FlightServer server; + + static void launchServer() throws IOException, InterruptedException { + final Location location = Location.forGrpcInsecure(HOST, PORT); + final InMemoryStore store = new InMemoryStore(ALLOCATOR, location); + server = FlightServer.builder(ALLOCATOR, location, store).headerAuthenticator( + new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator(Auth2IntegrationServer::validate)) + ).build().start(); + store.setLocation(Location.forGrpcInsecure("localhost", server.getPort())); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(server, ALLOCATOR); + } catch (Exception e) { + e.printStackTrace(); + } + })); + + System.out.println("Server running on " + server.getLocation()); + server.awaitTermination(); + } + + private static CallHeaderAuthenticator.AuthResult validate(String username, String password) { + if (Strings.isNullOrEmpty(username)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException(); + } + final String identity; + if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) { + identity = USERNAME_1; + } else { + throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); + } + return () -> identity; + } + + public static void main(String[] args) { + try { + launchServer(); + } catch (Exception e) { + System.out.println("Launching server failed " + e); + } + } +} From 73be3e72b9e1d8d8bae46fcda87a3064994ace1a Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 19:09:54 -0800 Subject: [PATCH 03/31] [1] Updates for linting --- cpp/src/arrow/flight/client.cc | 7 ++-- cpp/src/arrow/flight/client.h | 8 ++-- .../flight/client_header_auth_middleware.cc | 38 ++++++++++--------- .../flight/client_header_auth_middleware.h | 18 +++++---- 4 files changed, 41 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index c8700e11d2bfc..70fb3c8c73fbb 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -105,7 +105,7 @@ struct ClientRpc { std::chrono::system_clock::now() + options.timeout); context.set_deadline(deadline); } - for (auto metadata: options.metadata) { + for (auto metadata : options.metadata) { context.AddMetadata(metadata.first, metadata.second); } } @@ -998,7 +998,8 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status AuthenticateBasicToken(std::string username, std::string password, std::pair* bearer_token) { + Status AuthenticateBasicToken(std::string username, std::string password, + std::pair* bearer_token) { // Add bearer token factory to middleware so it can intercept the bearer token. middleware.push_back(std::make_shared(bearer_token)); ClientRpc rpc({}); @@ -1227,7 +1228,7 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, } Status FlightClient::AuthenticateBasicToken( - std::string username, std::string password, + std::string username, std::string password, std::pair* bearer_token) { return impl_->AuthenticateBasicToken(username, password, bearer_token); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 0bcdbc4fcf98e..d19a86f7576b3 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -194,10 +194,12 @@ class ARROW_FLIGHT_EXPORT FlightClient { std::unique_ptr auth_handler); /// \brief Authenticate to the server using the given handler. - /// \param[in] options Per-RPC options - /// \param[in] auth_handler The authentication mechanism to use + /// \param[in] username Username to use + /// \param[in] password Password to use + /// \param[in] bearer_token Bearer token retreived if applicable /// \return Status OK if the client authenticated successfully - Status AuthenticateBasicToken(std::string username, std::string password, std::pair* bearer_token); + Status AuthenticateBasicToken(std::string username, std::string password, + std::pair* bearer_token); /// \brief Perform the indicated action, returning an iterator to the stream /// of results, if any diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.cc b/cpp/src/arrow/flight/client_header_auth_middleware.cc index e51a684a284e0..26688a8c3c6a8 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware.cc +++ b/cpp/src/arrow/flight/client_header_auth_middleware.cc @@ -28,12 +28,14 @@ namespace flight { std::string base64_encode(const std::string& input); - ClientBearerTokenMiddleware::ClientBearerTokenMiddleware(std::pair* bearer_token_) + ClientBearerTokenMiddleware::ClientBearerTokenMiddleware( + std::pair* bearer_token_) : bearer_token(bearer_token_) { } void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { } - void ClientBearerTokenMiddleware::ReceivedHeaders(const CallHeaders& incoming_headers) { + void ClientBearerTokenMiddleware::ReceivedHeaders( + const CallHeaders& incoming_headers) { // Grab the auth token if one exists. auto bearer_iter = incoming_headers.find(AUTH_HEADER); if (bearer_iter == incoming_headers.end()) { @@ -43,9 +45,9 @@ namespace flight { // Check if the value of the auth token starts with the bearer prefix, latch the token. std::string bearer_val = bearer_iter->second.to_string(); if (bearer_val.size() > BEARER_PREFIX.size()) { - bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), - [] (const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); + bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), + [] (const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); } ); if (hasPrefix) { @@ -55,7 +57,7 @@ namespace flight { } void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { } - + void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr* middleware) { *middleware = std::unique_ptr(new ClientBearerTokenMiddleware(bearer_token)); } @@ -68,28 +70,28 @@ namespace flight { std::string string_format(const std::string& format, const Args... args) { // Check size requirement for new string and increment by 1 for null terminator. size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; - if(size <= 0){ - throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); + if(size <= 0){ + throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); } // Create buffer for new string and write string in. - std::unique_ptr buf(new char[size]); + std::unique_ptr buf(new char[size]); std::snprintf(buf.get(), size, format.c_str(), args...); - + // Convert to std::string, subtracting size by 1 to trim null terminator. return std::string(buf.get(), buf.get() + size - 1); } - + void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str()); context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials)); } - + std::string base64_encode(const std::string& input) { - static const std::string base64_chars = + static const std::string base64_chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - auto get_encoded_length = [] (const std::string& in) { - return 4 * ((in.size() + 2) / 3); + auto get_encoded_length = [] (const std::string& in) { + return 4 * ((in.size() + 2) / 3); }; auto get_overwrite_count = [] (const std::string& in) { const std::string::size_type remainder = in.length() % 3; @@ -110,10 +112,12 @@ namespace flight { encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]); } } - + // Round up to nearest multiple of 3 and replace characters at end based on rounding. int overwrite_count = get_overwrite_count(input); - encoded.replace(encoded.length() - overwrite_count, encoded.length(), overwrite_count, '='); + encoded.replace(encoded.length() - overwrite_count, + encoded.length(), + overwrite_count, '='); return encoded; } } // namespace flight diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.h b/cpp/src/arrow/flight/client_header_auth_middleware.h index 7602c69f89d4f..3e83d4c180d61 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware.h +++ b/cpp/src/arrow/flight/client_header_auth_middleware.h @@ -20,9 +20,9 @@ #pragma once -#include "client_middleware.h" -#include "client_auth.h" -#include "client.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/client_auth.h" +#include "arrow/flight/client.h" #ifdef GRPCPP_PP_INCLUDE #include @@ -45,11 +45,14 @@ const std::string BASIC_PREFIX = "Basic "; namespace arrow { namespace flight { -void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password); +void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, + const std::string& username, + const std::string& password); class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware { public: - explicit ClientBearerTokenMiddleware(std::pair* bearer_token_); + explicit ClientBearerTokenMiddleware( + std::pair* bearer_token_); void SendingHeaders(AddCallHeaders* outgoing_headers); void ReceivedHeaders(const CallHeaders& incoming_headers); @@ -61,11 +64,12 @@ class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { public: - explicit ClientBearerTokenFactory(std::pair* bearer_token_) : bearer_token(bearer_token_) {} + explicit ClientBearerTokenFactory(std::pair* bearer_token_) + : bearer_token(bearer_token_) {} void StartCall(const CallInfo& info, std::unique_ptr* middleware); void Reset(); - + private: std::pair* bearer_token; }; From 74ef8ea8a0383640a3e808b6d98837792ed05266 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 19:11:38 -0800 Subject: [PATCH 04/31] [1] Adding missed file. --- .../test_integration_client_header_auth.cc | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/flight/test_integration_client_header_auth.cc b/cpp/src/arrow/flight/test_integration_client_header_auth.cc index 15b0158b65bd2..c3086e9d9c12f 100644 --- a/cpp/src/arrow/flight/test_integration_client_header_auth.cc +++ b/cpp/src/arrow/flight/test_integration_client_header_auth.cc @@ -65,11 +65,13 @@ void TestValidCredentials() { // Authenticate credentials and retreive token. std::pair bearer_token = std::make_pair("", ""); - ABORT_NOT_OK(client->AuthenticateBasicToken(FLAGS_username, FLAGS_password, &bearer_token)); + ABORT_NOT_OK( + client->AuthenticateBasicToken(FLAGS_username, FLAGS_password, &bearer_token)); // Validate token was received. if (bearer_token == std::make_pair(std::string(""), std::string(""))) { - std::cout << "Testing valid credentials was unsuccessful: Failed to get token from basic authentication." << std::endl; + std::cout << "Testing valid credentials was unsuccessful: " + << "Failed to get token from basic authentication." << std::endl; return; } @@ -97,14 +99,16 @@ void TestInvalidCredentials() { // Authenticate credentials and retreive token. std::pair bearer_token = std::make_pair("", ""); EXPECT_EQ(arrow::StatusCode::IOError, - client->AuthenticateBasicToken(FLAGS_username_invalid, FLAGS_password_invalid, &bearer_token).code()); + client->AuthenticateBasicToken( + FLAGS_username_invalid, FLAGS_password_invalid, &bearer_token).code()); // Validate token was received. if (bearer_token != std::make_pair(std::string(""), std::string(""))) { - std::cout << "Testing invalid credentials was unsuccessful: Obtained token from basic authentication when using invalid credentials." << std::endl; - return; + std::cout << "Testing invalid credentials was unsuccessful: " + << "Obtained token from basic authentication when using " + << "invalid credentials." << std::endl; } - + std::cout << "Testing invalid credentials was successful." << std::endl; } @@ -112,4 +116,4 @@ int main(int argc, char** argv) { std::cout << "Starting auth header based flight integration test." << std::endl; TestValidCredentials(); TestInvalidCredentials(); -} \ No newline at end of file +} From 29a3192f06fea287f4726633483999f5fa5d5348 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 19:17:14 -0800 Subject: [PATCH 05/31] [1] Adding fix for Java lint errors. --- .../integration/Auth2IntegrationServer.java | 81 +++++++++---------- 1 file changed, 40 insertions(+), 41 deletions(-) diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java index 59aaa390e8e6e..ba4bcb262cb2c 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java @@ -13,53 +13,52 @@ import java.io.IOException; public class Auth2IntegrationServer { - private static final int PORT = 31337; - private static final String USERNAME_1 = "flight1"; - private static final String PASSWORD_1 = "woohoo1"; - private static final String HOST = "localhost"; - private static final BufferAllocator ALLOCATOR = new RootAllocator(Long.MAX_VALUE); - private static FlightServer server; + private static final int PORT = 31337; + private static final String USERNAME_1 = "flight1"; + private static final String PASSWORD_1 = "woohoo1"; + private static final String HOST = "localhost"; + private static final BufferAllocator ALLOCATOR = new RootAllocator(Long.MAX_VALUE); + private static FlightServer server; - static void launchServer() throws IOException, InterruptedException { - final Location location = Location.forGrpcInsecure(HOST, PORT); - final InMemoryStore store = new InMemoryStore(ALLOCATOR, location); - server = FlightServer.builder(ALLOCATOR, location, store).headerAuthenticator( - new GeneratedBearerTokenAuthenticator( - new BasicCallHeaderAuthenticator(Auth2IntegrationServer::validate)) + static void launchServer() throws IOException, InterruptedException { + final Location location = Location.forGrpcInsecure(HOST, PORT); + final InMemoryStore store = new InMemoryStore(ALLOCATOR, location); + server = FlightServer.builder(ALLOCATOR, location, store).headerAuthenticator( + new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator(Auth2IntegrationServer::validate)) ).build().start(); - store.setLocation(Location.forGrpcInsecure("localhost", server.getPort())); + store.setLocation(Location.forGrpcInsecure("localhost", server.getPort())); - Runtime.getRuntime().addShutdownHook(new Thread(() -> { - try { - System.out.println("\nExiting..."); - AutoCloseables.close(server, ALLOCATOR); - } catch (Exception e) { - e.printStackTrace(); - } - })); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(server, ALLOCATOR); + } catch (Exception e) { + e.printStackTrace(); + }})); - System.out.println("Server running on " + server.getLocation()); - server.awaitTermination(); - } + System.out.println("Server running on " + server.getLocation()); + server.awaitTermination(); + } - private static CallHeaderAuthenticator.AuthResult validate(String username, String password) { - if (Strings.isNullOrEmpty(username)) { - throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException(); - } - final String identity; - if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) { - identity = USERNAME_1; - } else { - throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); - } - return () -> identity; + private static CallHeaderAuthenticator.AuthResult validate(String username, String password) { + if (Strings.isNullOrEmpty(username)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException(); } + final String identity; + if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) { + identity = USERNAME_1; + } else { + throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); + } + return () -> identity; + } - public static void main(String[] args) { - try { - launchServer(); - } catch (Exception e) { - System.out.println("Launching server failed " + e); - } + public static void main(String[] args) { + try { + launchServer(); + } catch (Exception e) { + System.out.println("Launching server failed " + e); } + } } From fac0bd02b46562c6196e5b1f7612701ae2d78cb0 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 19:19:51 -0800 Subject: [PATCH 06/31] [1] Added a couple comments --- cpp/src/arrow/flight/client_header_auth_middleware.h | 1 + .../flight/example/integration/Auth2IntegrationServer.java | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.h b/cpp/src/arrow/flight/client_header_auth_middleware.h index 3e83d4c180d61..481220f3efdd8 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware.h +++ b/cpp/src/arrow/flight/client_header_auth_middleware.h @@ -45,6 +45,7 @@ const std::string BASIC_PREFIX = "Basic "; namespace arrow { namespace flight { +// TODO: Need to add documentation in this file. void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password); diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java index ba4bcb262cb2c..76c9b60cbb9bd 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java @@ -12,6 +12,10 @@ import java.io.IOException; +/** + * Java server for running integration tests - this is currently setup to run + * against the cpp test 'flight-test-integration-client-header-auth'. + */ public class Auth2IntegrationServer { private static final int PORT = 31337; private static final String USERNAME_1 = "flight1"; From 7fd1279426efd13cd8482b003ec3849bfbaa0a35 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 18 Nov 2020 19:44:03 -0800 Subject: [PATCH 07/31] [1] Minor comment fixes --- cpp/src/arrow/flight/client.h | 1 + cpp/src/arrow/flight/client_header_auth_middleware.cc | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index d19a86f7576b3..9d883bf577ef7 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -66,6 +66,7 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { /// \brief IPC writer options, if applicable for the call. ipc::IpcWriteOptions write_options; + /// \brief Metadata for client to add to context. std::vector> metadata; }; diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.cc b/cpp/src/arrow/flight/client_header_auth_middleware.cc index 26688a8c3c6a8..c70c67922e859 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware.cc +++ b/cpp/src/arrow/flight/client_header_auth_middleware.cc @@ -121,4 +121,4 @@ namespace flight { return encoded; } } // namespace flight -} // namespace arrow \ No newline at end of file +} // namespace arrow From 0b5a08e72c23f3239819cd0ea5734b014c750b07 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 10:58:32 -0800 Subject: [PATCH 08/31] [1] Addressed pull request comments, still need to address comments around unit tests but checking in WIP. --- cpp/src/arrow/flight/CMakeLists.txt | 2 +- cpp/src/arrow/flight/client.cc | 45 +++++-- cpp/src/arrow/flight/client.h | 4 +- .../flight/client_header_auth_middleware.cc | 124 ------------------ .../client_header_auth_middleware_internal.cc | 117 +++++++++++++++++ ... client_header_auth_middleware_internal.h} | 47 +++---- 6 files changed, 171 insertions(+), 168 deletions(-) delete mode 100644 cpp/src/arrow/flight/client_header_auth_middleware.cc create mode 100644 cpp/src/arrow/flight/client_header_auth_middleware_internal.cc rename cpp/src/arrow/flight/{client_header_auth_middleware.h => client_header_auth_middleware_internal.h} (64%) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 5691d83b691e2..c1fe302b20e35 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -123,7 +123,7 @@ set(ARROW_FLIGHT_SRCS serialization_internal.cc server.cc server_auth.cc - client_header_auth_middleware.cc + client_header_auth_middleware_internal.cc types.cc) add_arrow_lib(arrow_flight diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 70fb3c8c73fbb..69fcd58406829 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -50,8 +50,8 @@ #include "arrow/util/uri.h" #include "arrow/flight/client_auth.h" +#include "arrow/flight/client_header_auth_middleware_internal.h" #include "arrow/flight/client_middleware.h" -#include "arrow/flight/client_header_auth_middleware.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware.h" #include "arrow/flight/middleware_internal.h" @@ -332,7 +332,7 @@ class GrpcClientInterceptorAdapterFactory : public grpc::experimental::ClientInterceptorFactoryInterface { public: GrpcClientInterceptorAdapterFactory( - std::vector>& middleware) + const std::vector> middleware) : middleware_(middleware) {} grpc::experimental::Interceptor* CreateClientInterceptor( @@ -364,6 +364,7 @@ class GrpcClientInterceptorAdapterFactory } const CallInfo flight_info{flight_method}; + std::lock_guard lock(middleware_lock_); for (auto& factory : middleware_) { std::unique_ptr instance; factory->StartCall(flight_info, &instance); @@ -374,8 +375,18 @@ class GrpcClientInterceptorAdapterFactory return new GrpcClientInterceptorAdapter(std::move(middleware)); } + void AddMiddlewareFactory(std::shared_ptr middleware_factory) { + std::lock_guard lock(middleware_lock_); + middleware_.push_back(middleware_factory); + } + + void RemoveMiddlewareFactory() { + middleware_.pop_back(); + } + private: - std::vector>& middleware_; + std::mutex middleware_lock_; + std::vector> middleware_; }; class GrpcClientAuthSender : public ClientAuthSender { @@ -874,6 +885,8 @@ constexpr char BLANK_ROOT_PEM[] = } // namespace class FlightClient::FlightClientImpl { public: + FlightClientImpl() : interceptor_pointer(NULLPTR) {} + Status Connect(const Location& location, const FlightClientOptions& options) { const std::string& scheme = location.scheme(); @@ -967,9 +980,8 @@ class FlightClient::FlightClientImpl { std::vector> interceptors; - middleware = std::move(options.middleware); - interceptors.emplace_back( - new GrpcClientInterceptorAdapterFactory(middleware)); + interceptor_pointer = new GrpcClientInterceptorAdapterFactory(options.middleware); + interceptors.emplace_back(interceptor_pointer); stub_ = pb::FlightService::NewStub( grpc::experimental::CreateCustomChannelWithInterceptors( @@ -998,12 +1010,20 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status AuthenticateBasicToken(std::string username, std::string password, + Status AuthenticateBasicToken(const std::string& username, const std::string& password, std::pair* bearer_token) { // Add bearer token factory to middleware so it can intercept the bearer token. - middleware.push_back(std::make_shared(bearer_token)); + if (interceptor_pointer != NULLPTR) { + std::cout << "Adding middleware." << std::endl; + interceptor_pointer->AddMiddlewareFactory(std::make_shared(bearer_token)); + } else { + std::cout << "NULLPTR" << std::endl; + return MakeFlightError(FlightStatusCode::Internal, + "Connect must be called before AuthenticateBasicToken."); + } ClientRpc rpc({}); - AddBasicAuthHeaders(&rpc.context, username, password); + std::cout << "AddBasicAuthHeaders" << std::endl; + internal::AddBasicAuthHeaders(&rpc.context, username, password); std::shared_ptr> stream = stub_->Handshake(&rpc.context); @@ -1011,7 +1031,8 @@ class FlightClient::FlightClientImpl { GrpcClientAuthReader incoming{stream}; // Explicitly close our side of the connection bool finished_writes = stream->WritesDone(); - middleware.pop_back(); + std::cout << "Middleware popback" << std::endl; + interceptor_pointer->RemoveMiddlewareFactory(); RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); if (!finished_writes) { return MakeFlightError(FlightStatusCode::Internal, @@ -1203,8 +1224,8 @@ class FlightClient::FlightClientImpl { GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> noop_auth_check_; #endif - std::vector> middleware; int64_t write_size_limit_bytes_; + GrpcClientInterceptorAdapterFactory* interceptor_pointer; }; FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); } @@ -1228,7 +1249,7 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, } Status FlightClient::AuthenticateBasicToken( - std::string username, std::string password, + const std::string& username, const std::string& password, std::pair* bearer_token) { return impl_->AuthenticateBasicToken(username, password, bearer_token); } diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 9d883bf577ef7..1810c8f837efa 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -194,12 +194,12 @@ class ARROW_FLIGHT_EXPORT FlightClient { Status Authenticate(const FlightCallOptions& options, std::unique_ptr auth_handler); - /// \brief Authenticate to the server using the given handler. + /// \brief Authenticate to the server using basic authentication with base 64 encoding. /// \param[in] username Username to use /// \param[in] password Password to use /// \param[in] bearer_token Bearer token retreived if applicable /// \return Status OK if the client authenticated successfully - Status AuthenticateBasicToken(std::string username, std::string password, + Status AuthenticateBasicToken(const std::string& username, const std::string& password, std::pair* bearer_token); /// \brief Perform the indicated action, returning an iterator to the stream diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.cc b/cpp/src/arrow/flight/client_header_auth_middleware.cc deleted file mode 100644 index c70c67922e859..0000000000000 --- a/cpp/src/arrow/flight/client_header_auth_middleware.cc +++ /dev/null @@ -1,124 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Interfaces for defining middleware for Flight clients. Currently -// experimental. - -#include "client_header_auth_middleware.h" -#include "client_middleware.h" -#include "client_auth.h" -#include "client.h" - -namespace arrow { -namespace flight { - - std::string base64_encode(const std::string& input); - - ClientBearerTokenMiddleware::ClientBearerTokenMiddleware( - std::pair* bearer_token_) - : bearer_token(bearer_token_) { } - - void ClientBearerTokenMiddleware::SendingHeaders(AddCallHeaders* outgoing_headers) { } - - void ClientBearerTokenMiddleware::ReceivedHeaders( - const CallHeaders& incoming_headers) { - // Grab the auth token if one exists. - auto bearer_iter = incoming_headers.find(AUTH_HEADER); - if (bearer_iter == incoming_headers.end()) { - return; - } - - // Check if the value of the auth token starts with the bearer prefix, latch the token. - std::string bearer_val = bearer_iter->second.to_string(); - if (bearer_val.size() > BEARER_PREFIX.size()) { - bool hasPrefix = std::equal(bearer_val.begin(), bearer_val.begin() + BEARER_PREFIX.size(), BEARER_PREFIX.begin(), - [] (const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); - } - ); - if (hasPrefix) { - *bearer_token = std::make_pair(AUTH_HEADER, bearer_val); - } - } - } - - void ClientBearerTokenMiddleware::CallCompleted(const Status& status) { } - - void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr* middleware) { - *middleware = std::unique_ptr(new ClientBearerTokenMiddleware(bearer_token)); - } - - void ClientBearerTokenFactory::Reset() { - *bearer_token = std::make_pair("", ""); - } - - template - std::string string_format(const std::string& format, const Args... args) { - // Check size requirement for new string and increment by 1 for null terminator. - size_t size = std::snprintf(nullptr, 0, format.c_str(), args ...) + 1; - if(size <= 0){ - throw std::runtime_error("Error during string formatting. Format: '" + format + "'."); - } - - // Create buffer for new string and write string in. - std::unique_ptr buf(new char[size]); - std::snprintf(buf.get(), size, format.c_str(), args...); - - // Convert to std::string, subtracting size by 1 to trim null terminator. - return std::string(buf.get(), buf.get() + size - 1); - } - - void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { - const std::string formatted_credentials = string_format("%s:%s", username.c_str(), password.c_str()); - context->AddMetadata(AUTH_HEADER, BASIC_PREFIX + base64_encode(formatted_credentials)); - } - - std::string base64_encode(const std::string& input) { - static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; - auto get_encoded_length = [] (const std::string& in) { - return 4 * ((in.size() + 2) / 3); - }; - auto get_overwrite_count = [] (const std::string& in) { - const std::string::size_type remainder = in.length() % 3; - return (remainder > 0) ? (3 - (remainder % 3)) : 0; - }; - - // Generate string with required length for encoding. - std::string encoded; - encoded.reserve(get_encoded_length(input)); - - // Loop through input writing base64 characters to string. - for (int i = 0; i < input.length();) { - uint32_t octet_1 = i < input.length() ? (unsigned char)input[i++] : 0; - uint32_t octet_2 = i < input.length() ? (unsigned char)input[i++] : 0; - uint32_t octet_3 = i < input.length() ? (unsigned char)input[i++] : 0; - uint32_t octriple = (octet_1 << 0x10) + (octet_2 << 0x08) + octet_3; - for (int j = 3; j >= 0; j--) { - encoded.push_back(base64_chars[(octriple >> j * 6) & 0x3F]); - } - } - - // Round up to nearest multiple of 3 and replace characters at end based on rounding. - int overwrite_count = get_overwrite_count(input); - encoded.replace(encoded.length() - overwrite_count, - encoded.length(), - overwrite_count, '='); - return encoded; - } -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc new file mode 100644 index 0000000000000..03830eb566682 --- /dev/null +++ b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces for defining middleware for Flight clients. Currently +// experimental. + +#include "client_header_auth_middleware_internal.h" +#include "arrow/flight/client_auth.h" +#include "arrow/flight/client.h" +#include "arrow/util/base64.h" +#include "arrow/util/make_unique.h" + +#include +#include +#include +#include +#include + +const std::string kAuthHeader = "authorization"; +const std::string kBearerPrefix = "Bearer "; +const std::string kBasicPrefix = "Basic "; + +namespace arrow { +namespace flight { +namespace internal { + +// Add base64 encoded credentials to the outbound headers. +// +// @param context Context object to add the headers to. +// @param username Username to format and encode. +// @param password Password to format and encode. +void AddBasicAuthHeaders(grpc::ClientContext* context, + const std::string& username, const std::string& password) { + const std::string credentials = username + ":" + password; + context->AddMetadata(kAuthHeader, kBasicPrefix + + arrow::util::base64_encode((const unsigned char*)credentials.c_str(), + credentials.size())); +} + +class ClientBearerTokenFactory::Impl { + public: + Impl(std::pair* bearer_token) + : bearer_token_(bearer_token) { } + + void StartCall(const CallInfo& info, std::unique_ptr* middleware) { + ARROW_UNUSED(info); + *middleware = arrow::internal::make_unique(bearer_token_); + } + + private: + class ClientBearerTokenMiddleware : public ClientMiddleware { + public: + explicit ClientBearerTokenMiddleware(std::pair* bearer_token) + : bearer_token_(bearer_token) { } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { } + + void ReceivedHeaders(const CallHeaders& incoming_headers) override { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [] (const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); + }; + + // Grab the auth token if one exists. + const auto bearer_iter = incoming_headers.find(kAuthHeader); + if (bearer_iter == incoming_headers.end()) { + return; + } + + // Check if the value of the auth token starts with the bearer prefix, latch the token. + const std::string bearer_val = bearer_iter->second.to_string(); + if (bearer_val.size() > kBearerPrefix.size()) { + if (std::equal(bearer_val.begin(), bearer_val.begin() + kBearerPrefix.size(), + kBearerPrefix.begin(), char_compare)) { + *bearer_token_ = std::make_pair(kAuthHeader, bearer_val); + } + } + } + + void CallCompleted(const Status& status) override { } + + private: + std::pair* bearer_token_; + }; + + private: + std::pair* bearer_token_; +}; + +ClientBearerTokenFactory::ClientBearerTokenFactory( + std::pair* bearer_token) + : impl_(new ClientBearerTokenFactory::Impl(bearer_token)) { } + +ClientBearerTokenFactory::~ClientBearerTokenFactory() { } + +void ClientBearerTokenFactory::StartCall(const CallInfo& info, + std::unique_ptr* middleware) { + impl_->StartCall(info, middleware); +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/client_header_auth_middleware.h b/cpp/src/arrow/flight/client_header_auth_middleware_internal.h similarity index 64% rename from cpp/src/arrow/flight/client_header_auth_middleware.h rename to cpp/src/arrow/flight/client_header_auth_middleware_internal.h index 481220f3efdd8..20f311d08603f 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware.h +++ b/cpp/src/arrow/flight/client_header_auth_middleware_internal.h @@ -21,8 +21,6 @@ #pragma once #include "arrow/flight/client_middleware.h" -#include "arrow/flight/client_auth.h" -#include "arrow/flight/client.h" #ifdef GRPCPP_PP_INCLUDE #include @@ -33,46 +31,37 @@ #include #endif -#include -#include -#include -#include - -const std::string AUTH_HEADER = "authorization"; -const std::string BEARER_PREFIX = "Bearer "; -const std::string BASIC_PREFIX = "Basic "; - namespace arrow { namespace flight { +namespace internal { -// TODO: Need to add documentation in this file. +/// \brief Add basic authentication header key value pair to context. +/// +/// \param context grpc context variable to add header to. +/// \param username username to encode into header. +/// \param password password to to encode into header. void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password); -class ARROW_FLIGHT_EXPORT ClientBearerTokenMiddleware : public ClientMiddleware { - public: - explicit ClientBearerTokenMiddleware( - std::pair* bearer_token_); - - void SendingHeaders(AddCallHeaders* outgoing_headers); - void ReceivedHeaders(const CallHeaders& incoming_headers); - void CallCompleted(const Status& status); - - private: - std::pair* bearer_token; -}; - +/// \brief Client-side middleware for receiving and latching a bearer token. class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { public: - explicit ClientBearerTokenFactory(std::pair* bearer_token_) - : bearer_token(bearer_token_) {} + /// \brief Constructor for factory. + /// + /// \param[out] bearer_token_ pointer to a std::pair of std::strings that the factory + /// will populate with the bearer token that is received from the server. + ClientBearerTokenFactory(std::pair* bearer_token_); + + ~ClientBearerTokenFactory(); void StartCall(const CallInfo& info, std::unique_ptr* middleware); - void Reset(); private: - std::pair* bearer_token; + class Impl; + std::unique_ptr impl_; }; + +} // namespace internal } // namespace flight } // namespace arrow From 6861cf00966f804ef7d6a03eac77167b45357e83 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 12:56:35 -0800 Subject: [PATCH 09/31] [1] Added unit test. [2] Still need to update integration test to run automatically. --- cpp/src/arrow/flight/client.cc | 4 - cpp/src/arrow/flight/flight_test.cc | 127 ++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 69fcd58406829..b6fe4d8d03d7f 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -1014,15 +1014,12 @@ class FlightClient::FlightClientImpl { std::pair* bearer_token) { // Add bearer token factory to middleware so it can intercept the bearer token. if (interceptor_pointer != NULLPTR) { - std::cout << "Adding middleware." << std::endl; interceptor_pointer->AddMiddlewareFactory(std::make_shared(bearer_token)); } else { - std::cout << "NULLPTR" << std::endl; return MakeFlightError(FlightStatusCode::Internal, "Connect must be called before AuthenticateBasicToken."); } ClientRpc rpc({}); - std::cout << "AddBasicAuthHeaders" << std::endl; internal::AddBasicAuthHeaders(&rpc.context, username, password); std::shared_ptr> stream = stub_->Handshake(&rpc.context); @@ -1031,7 +1028,6 @@ class FlightClient::FlightClientImpl { GrpcClientAuthReader incoming{stream}; // Explicitly close our side of the connection bool finished_writes = stream->WritesDone(); - std::cout << "Middleware popback" << std::endl; interceptor_pointer->RemoveMiddlewareFactory(); RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); if (!finished_writes) { diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index a5fdfe8184e15..49a5df8567e25 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -36,6 +36,7 @@ #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/base64.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" @@ -43,6 +44,7 @@ #error "gRPC headers should not be in public API" #endif +#include "arrow/flight/client_header_auth_middleware_internal.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware_internal.h" #include "arrow/flight/test_util.h" @@ -52,6 +54,15 @@ namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { +const std::string kValidUsername = "flight_username"; +const std::string kValidPassword = "flight_password"; +const std::string kInvalidUsername = "invalid_flight_username"; +const std::string kInvalidPassword = "invalid_flight_password"; +const std::string kBearerToken = "bearertoken"; +const std::string kBasicPrefix = "Basic "; +const std::string kBearerPrefix = "Bearer "; +const std::string kAuthHeader = "authorization"; + void AssertEqual(const ActionType& expected, const ActionType& actual) { ASSERT_EQ(expected.type, actual.type); ASSERT_EQ(expected.description, actual.description); @@ -774,6 +785,69 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { } }; +// A server middleware for validating incoming base64 header authentication. +class HeaderAuthServerMiddleware : public ServerMiddleware { + public: + explicit HeaderAuthServerMiddleware(const CallHeaders& incoming_headers) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [] (const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); + }; + + std::string username; + std::string password; + for (auto& iter : incoming_headers_) { + const std::string key = iter.first.to_string(); + const std::string val = iter.second.to_string(); + if (key == kAuthHeader) { + if (val.size() > kBasicPrefix.size()) { + if (std::equal(val.begin(), val.begin() + kBasicPrefix.size(), + kBasicPrefix.begin(), char_compare)) { + const std::string encoded_credentials = val.substr(kBasicPrefix.size()); + const std::string decoded_credentials = arrow::util::base64_decode( + encoded_credentials); + std::stringstream decoded_stream(decoded_credentials); + std::getline(decoded_stream, username, ':'); + std::getline(decoded_stream, password, ':'); + break; + } + } + } + } + + if (username == kValidUsername && + password == kValidPassword) { + outgoing_headers->AddHeader(kAuthHeader, kBearerPrefix + kBearerToken); + } + } + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "HeaderAuthServerMiddleware"; } + + CallHeaders incoming_headers_; +}; + +// Factory for base64 header authentication testing. +class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + HeaderAuthServerMiddlewareFactory() {} + + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + const std::pair& iter_pair = + incoming_headers.equal_range(kAuthHeader); + if (iter_pair.first != iter_pair.second) { + *middleware = std::make_shared(incoming_headers); + } + return Status::OK(); + } +}; + // A client middleware that adds a thread-local "request ID" to // outgoing calls as a header, and keeps track of the status of // completed calls. NOT thread-safe. @@ -993,6 +1067,51 @@ class TestErrorMiddleware : public ::testing::Test { std::unique_ptr server_; }; +class TestBasicHeaderAuthMiddleware : public ::testing::Test { + public: + void SetUp() { + server_middleware_ = std::make_shared(); + ASSERT_OK(MakeServer( + &server_, &client_, [&](FlightServerOptions* options) { + options->auth_handler = std::unique_ptr( + new TestServerAuthHandler("", "")); + options->middleware.push_back({"header-auth-server", server_middleware_}); + return Status::OK(); + }, + [&](FlightClientOptions* options) { return Status::OK(); })); + } + + void RunValidClientAuth() { + std::pair bearer_token; + // Note: Status intentionally ignored because it requires C++ server implementation of + // header auth. For now it returns an IOError. + arrow::Status status = + client_->AuthenticateBasicToken(kValidUsername, kValidPassword, &bearer_token); + ASSERT_EQ(bearer_token.first, kAuthHeader); + ASSERT_EQ(bearer_token.second, (kBearerPrefix + kBearerToken)); + } + + void RunInvalidClientAuth() { + std::pair bearer_token; + // Note: Status intentionally ignored because it requires C++ server implementation of + // header auth. For now it returns an IOError. + arrow::Status status = client_->AuthenticateBasicToken( + kInvalidUsername, kInvalidPassword, &bearer_token); + ASSERT_EQ(bearer_token.first, std::string("")); + ASSERT_EQ(bearer_token.second, std::string("")); + } + + void TearDown() { + ASSERT_OK(server_->Shutdown()); + } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; + std::shared_ptr server_middleware_; + std::shared_ptr client_middleware_; +}; + TEST_F(TestErrorMiddleware, TestMetadata) { Action action; std::unique_ptr stream; @@ -2172,5 +2291,13 @@ TEST_F(TestPropagatingMiddleware, DoPut) { ValidateStatus(status, FlightMethod::DoPut); } +TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { + RunValidClientAuth(); +} + +TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { + RunInvalidClientAuth(); +} + } // namespace flight } // namespace arrow From 01a134b4e8e8e8532349d33bf453b9a6e7044bbe Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 14:00:39 -0800 Subject: [PATCH 10/31] [1] Fixed linting issues --- cpp/src/arrow/flight/client.h | 2 +- .../client_header_auth_middleware_internal.cc | 32 +++++++++-------- .../client_header_auth_middleware_internal.h | 6 ++-- cpp/src/arrow/flight/flight_test.cc | 28 +++++++-------- .../integration/Auth2IntegrationServer.java | 34 +++++++++++++++---- 5 files changed, 62 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 1810c8f837efa..af68d9c9c759b 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -199,7 +199,7 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \param[in] password Password to use /// \param[in] bearer_token Bearer token retreived if applicable /// \return Status OK if the client authenticated successfully - Status AuthenticateBasicToken(const std::string& username, const std::string& password, + Status AuthenticateBasicToken(const std::string& username, const std::string& password, std::pair* bearer_token); /// \brief Perform the indicated action, returning an iterator to the stream diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc index 03830eb566682..92af4393abb4e 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc +++ b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc @@ -30,9 +30,9 @@ #include #include -const std::string kAuthHeader = "authorization"; -const std::string kBearerPrefix = "Bearer "; -const std::string kBasicPrefix = "Basic "; +const char kAuthHeader[] = "authorization"; +const char kBearerPrefix[] = "Bearer "; +const char kBasicPrefix[] = "Basic "; namespace arrow { namespace flight { @@ -43,28 +43,30 @@ namespace internal { // @param context Context object to add the headers to. // @param username Username to format and encode. // @param password Password to format and encode. -void AddBasicAuthHeaders(grpc::ClientContext* context, +void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { const std::string credentials = username + ":" + password; - context->AddMetadata(kAuthHeader, kBasicPrefix + - arrow::util::base64_encode((const unsigned char*)credentials.c_str(), + context->AddMetadata(kAuthHeader, kBasicPrefix + + arrow::util::base64_encode((const unsigned char*)credentials.c_str(), credentials.size())); } class ClientBearerTokenFactory::Impl { public: - Impl(std::pair* bearer_token) + explicit Impl(std::pair* bearer_token) : bearer_token_(bearer_token) { } void StartCall(const CallInfo& info, std::unique_ptr* middleware) { ARROW_UNUSED(info); - *middleware = arrow::internal::make_unique(bearer_token_); + *middleware = + arrow::internal::make_unique(bearer_token_); } private: class ClientBearerTokenMiddleware : public ClientMiddleware { public: - explicit ClientBearerTokenMiddleware(std::pair* bearer_token) + explicit ClientBearerTokenMiddleware( + std::pair* bearer_token) : bearer_token_(bearer_token) { } void SendingHeaders(AddCallHeaders* outgoing_headers) override { } @@ -81,11 +83,11 @@ class ClientBearerTokenFactory::Impl { return; } - // Check if the value of the auth token starts with the bearer prefix, latch the token. + // Check if the value of the auth token starts with the bearer prefix and latch it. const std::string bearer_val = bearer_iter->second.to_string(); - if (bearer_val.size() > kBearerPrefix.size()) { - if (std::equal(bearer_val.begin(), bearer_val.begin() + kBearerPrefix.size(), - kBearerPrefix.begin(), char_compare)) { + if (bearer_val.size() > strlen(kBearerPrefix)) { + if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), + kBearerPrefix, char_compare)) { *bearer_token_ = std::make_pair(kAuthHeader, bearer_val); } } @@ -102,12 +104,12 @@ class ClientBearerTokenFactory::Impl { }; ClientBearerTokenFactory::ClientBearerTokenFactory( - std::pair* bearer_token) + std::pair* bearer_token) : impl_(new ClientBearerTokenFactory::Impl(bearer_token)) { } ClientBearerTokenFactory::~ClientBearerTokenFactory() { } -void ClientBearerTokenFactory::StartCall(const CallInfo& info, +void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr* middleware) { impl_->StartCall(info, middleware); } diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.h b/cpp/src/arrow/flight/client_header_auth_middleware_internal.h index 20f311d08603f..4bf5259aa0012 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware_internal.h +++ b/cpp/src/arrow/flight/client_header_auth_middleware_internal.h @@ -40,8 +40,8 @@ namespace internal { /// \param context grpc context variable to add header to. /// \param username username to encode into header. /// \param password password to to encode into header. -void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, - const std::string& username, +void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, + const std::string& username, const std::string& password); /// \brief Client-side middleware for receiving and latching a bearer token. @@ -51,7 +51,7 @@ class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFact /// /// \param[out] bearer_token_ pointer to a std::pair of std::strings that the factory /// will populate with the bearer token that is received from the server. - ClientBearerTokenFactory(std::pair* bearer_token_); + explicit ClientBearerTokenFactory(std::pair* bearer_token_); ~ClientBearerTokenFactory(); diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 49a5df8567e25..f63428547ac72 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -54,14 +54,14 @@ namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { -const std::string kValidUsername = "flight_username"; -const std::string kValidPassword = "flight_password"; -const std::string kInvalidUsername = "invalid_flight_username"; -const std::string kInvalidPassword = "invalid_flight_password"; -const std::string kBearerToken = "bearertoken"; -const std::string kBasicPrefix = "Basic "; -const std::string kBearerPrefix = "Bearer "; -const std::string kAuthHeader = "authorization"; +const char kValidUsername[] = "flight_username"; +const char kValidPassword[] = "flight_password"; +const char kInvalidUsername[] = "invalid_flight_username"; +const char kInvalidPassword[] = "invalid_flight_password"; +const char kBearerToken[] = "bearertoken"; +const char kBasicPrefix[] = "Basic "; +const char kBearerPrefix[] = "Bearer "; +const char kAuthHeader[] = "authorization"; void AssertEqual(const ActionType& expected, const ActionType& actual) { ASSERT_EQ(expected.type, actual.type); @@ -804,10 +804,10 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { const std::string key = iter.first.to_string(); const std::string val = iter.second.to_string(); if (key == kAuthHeader) { - if (val.size() > kBasicPrefix.size()) { - if (std::equal(val.begin(), val.begin() + kBasicPrefix.size(), - kBasicPrefix.begin(), char_compare)) { - const std::string encoded_credentials = val.substr(kBasicPrefix.size()); + if (val.size() > strlen(kBasicPrefix)) { + if (std::equal(val.begin(), val.begin() + strlen(kBasicPrefix), + kBasicPrefix, char_compare)) { + const std::string encoded_credentials = val.substr(strlen(kBasicPrefix)); const std::string decoded_credentials = arrow::util::base64_decode( encoded_credentials); std::stringstream decoded_stream(decoded_credentials); @@ -821,7 +821,7 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { if (username == kValidUsername && password == kValidPassword) { - outgoing_headers->AddHeader(kAuthHeader, kBearerPrefix + kBearerToken); + outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); } } @@ -1088,7 +1088,7 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { arrow::Status status = client_->AuthenticateBasicToken(kValidUsername, kValidPassword, &bearer_token); ASSERT_EQ(bearer_token.first, kAuthHeader); - ASSERT_EQ(bearer_token.second, (kBearerPrefix + kBearerToken)); + ASSERT_EQ(bearer_token.second, (std::string(kBearerPrefix) + kBearerToken)); } void RunInvalidClientAuth() { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java index 76c9b60cbb9bd..6d839c986f074 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java @@ -1,7 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.arrow.flight.example.integration; import com.google.common.base.Strings; -import org.apache.arrow.flight.*; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Location; import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; @@ -34,12 +53,13 @@ static void launchServer() throws IOException, InterruptedException { store.setLocation(Location.forGrpcInsecure("localhost", server.getPort())); Runtime.getRuntime().addShutdownHook(new Thread(() -> { - try { - System.out.println("\nExiting..."); - AutoCloseables.close(server, ALLOCATOR); - } catch (Exception e) { - e.printStackTrace(); - }})); + try { + System.out.println("\nExiting..."); + AutoCloseables.close(server, ALLOCATOR); + } catch (Exception e) { + e.printStackTrace(); + } + })); System.out.println("Server running on " + server.getLocation()); server.awaitTermination(); From de78c6b51d04c0a1e82aac7b9ce2a4874d6fd659 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 15:22:33 -0800 Subject: [PATCH 11/31] [1] Fixing linting issues --- cpp/src/arrow/flight/client.cc | 3 ++- .../arrow/flight/client_header_auth_middleware_internal.cc | 4 ++-- .../flight/example/integration/Auth2IntegrationServer.java | 7 ++++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index b6fe4d8d03d7f..2b29d73f40437 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -1014,7 +1014,8 @@ class FlightClient::FlightClientImpl { std::pair* bearer_token) { // Add bearer token factory to middleware so it can intercept the bearer token. if (interceptor_pointer != NULLPTR) { - interceptor_pointer->AddMiddlewareFactory(std::make_shared(bearer_token)); + interceptor_pointer->AddMiddlewareFactory( + std::make_shared(bearer_token)); } else { return MakeFlightError(FlightStatusCode::Internal, "Connect must be called before AuthenticateBasicToken."); diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc index 92af4393abb4e..54bf001e0a730 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc +++ b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc @@ -18,8 +18,8 @@ // Interfaces for defining middleware for Flight clients. Currently // experimental. -#include "client_header_auth_middleware_internal.h" #include "arrow/flight/client_auth.h" +#include "arrow/flight/client_header_auth_middleware_internal.h" #include "arrow/flight/client.h" #include "arrow/util/base64.h" #include "arrow/util/make_unique.h" @@ -58,7 +58,7 @@ class ClientBearerTokenFactory::Impl { void StartCall(const CallInfo& info, std::unique_ptr* middleware) { ARROW_UNUSED(info); - *middleware = + *middleware = arrow::internal::make_unique(bearer_token_); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java index 6d839c986f074..0fcaf75952ef0 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java @@ -17,9 +17,10 @@ package org.apache.arrow.flight.example.integration; -import com.google.common.base.Strings; -import org.apache.arrow.flight.FlightServer; +import java.io.IOException; + import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.Location; import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; @@ -29,7 +30,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; -import java.io.IOException; +import com.google.common.base.Strings; /** * Java server for running integration tests - this is currently setup to run From c44698f9093a8403889693e635fc9ba806e01109 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 15:32:48 -0800 Subject: [PATCH 12/31] [1] Removed some extra spaces at the end of some lines. --- cpp/src/arrow/flight/flight_test.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index f63428547ac72..0ff5086e184a3 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -805,7 +805,7 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { const std::string val = iter.second.to_string(); if (key == kAuthHeader) { if (val.size() > strlen(kBasicPrefix)) { - if (std::equal(val.begin(), val.begin() + strlen(kBasicPrefix), + if (std::equal(val.begin(), val.begin() + strlen(kBasicPrefix), kBasicPrefix, char_compare)) { const std::string encoded_credentials = val.substr(strlen(kBasicPrefix)); const std::string decoded_credentials = arrow::util::base64_decode( @@ -819,8 +819,7 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { } } - if (username == kValidUsername && - password == kValidPassword) { + if ((username == kValidUsername) && (password == kValidPassword)) { outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); } } @@ -841,7 +840,7 @@ class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { std::shared_ptr* middleware) override { const std::pair& iter_pair = incoming_headers.equal_range(kAuthHeader); - if (iter_pair.first != iter_pair.second) { + if (iter_pair.first != iter_pair.second) { *middleware = std::make_shared(incoming_headers); } return Status::OK(); @@ -1072,11 +1071,11 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { void SetUp() { server_middleware_ = std::make_shared(); ASSERT_OK(MakeServer( - &server_, &client_, [&](FlightServerOptions* options) { + &server_, &client_, [&](FlightServerOptions* options) { options->auth_handler = std::unique_ptr( new TestServerAuthHandler("", "")); options->middleware.push_back({"header-auth-server", server_middleware_}); - return Status::OK(); + return Status::OK(); }, [&](FlightClientOptions* options) { return Status::OK(); })); } @@ -1085,7 +1084,7 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { std::pair bearer_token; // Note: Status intentionally ignored because it requires C++ server implementation of // header auth. For now it returns an IOError. - arrow::Status status = + arrow::Status status = client_->AuthenticateBasicToken(kValidUsername, kValidPassword, &bearer_token); ASSERT_EQ(bearer_token.first, kAuthHeader); ASSERT_EQ(bearer_token.second, (std::string(kBearerPrefix) + kBearerToken)); From e975fd8d136583b44a00c3dfef176bbca261784c Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 17:43:06 -0800 Subject: [PATCH 13/31] [1] Correcting linting issues. --- cpp/src/arrow/flight/client.cc | 6 +--- .../client_header_auth_middleware_internal.cc | 32 +++++++++---------- cpp/src/arrow/flight/flight_test.cc | 31 ++++++++---------- .../test_integration_client_header_auth.cc | 8 +++-- 4 files changed, 35 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 2b29d73f40437..60dd9c810ae6a 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -380,9 +380,7 @@ class GrpcClientInterceptorAdapterFactory middleware_.push_back(middleware_factory); } - void RemoveMiddlewareFactory() { - middleware_.pop_back(); - } + void RemoveMiddlewareFactory() { middleware_.pop_back(); } private: std::mutex middleware_lock_; @@ -1038,8 +1036,6 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - - Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, std::unique_ptr* listing) { pb::Criteria pb_criteria; diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc index 54bf001e0a730..3796ae1cdb7b7 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc +++ b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc @@ -18,17 +18,16 @@ // Interfaces for defining middleware for Flight clients. Currently // experimental. -#include "arrow/flight/client_auth.h" #include "arrow/flight/client_header_auth_middleware_internal.h" #include "arrow/flight/client.h" +#include "arrow/flight/client_auth.h" #include "arrow/util/base64.h" #include "arrow/util/make_unique.h" #include -#include #include -#include #include +#include const char kAuthHeader[] = "authorization"; const char kBearerPrefix[] = "Bearer "; @@ -43,18 +42,19 @@ namespace internal { // @param context Context object to add the headers to. // @param username Username to format and encode. // @param password Password to format and encode. -void AddBasicAuthHeaders(grpc::ClientContext* context, - const std::string& username, const std::string& password) { +void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, + const std::string& password) { const std::string credentials = username + ":" + password; - context->AddMetadata(kAuthHeader, kBasicPrefix + - arrow::util::base64_encode((const unsigned char*)credentials.c_str(), - credentials.size())); + context->AddMetadata( + kAuthHeader, + kBasicPrefix + arrow::util::base64_encode((const unsigned char*)credentials.c_str(), + credentials.size())); } class ClientBearerTokenFactory::Impl { public: explicit Impl(std::pair* bearer_token) - : bearer_token_(bearer_token) { } + : bearer_token_(bearer_token) {} void StartCall(const CallInfo& info, std::unique_ptr* middleware) { ARROW_UNUSED(info); @@ -67,14 +67,14 @@ class ClientBearerTokenFactory::Impl { public: explicit ClientBearerTokenMiddleware( std::pair* bearer_token) - : bearer_token_(bearer_token) { } + : bearer_token_(bearer_token) {} - void SendingHeaders(AddCallHeaders* outgoing_headers) override { } + void SendingHeaders(AddCallHeaders* outgoing_headers) override {} void ReceivedHeaders(const CallHeaders& incoming_headers) override { // Lambda function to compare characters without case sensitivity. - auto char_compare = [] (const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); + auto char_compare = [](const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); }; // Grab the auth token if one exists. @@ -93,7 +93,7 @@ class ClientBearerTokenFactory::Impl { } } - void CallCompleted(const Status& status) override { } + void CallCompleted(const Status& status) override {} private: std::pair* bearer_token_; @@ -105,9 +105,9 @@ class ClientBearerTokenFactory::Impl { ClientBearerTokenFactory::ClientBearerTokenFactory( std::pair* bearer_token) - : impl_(new ClientBearerTokenFactory::Impl(bearer_token)) { } + : impl_(new ClientBearerTokenFactory::Impl(bearer_token)) {} -ClientBearerTokenFactory::~ClientBearerTokenFactory() { } +ClientBearerTokenFactory::~ClientBearerTokenFactory() {} void ClientBearerTokenFactory::StartCall(const CallInfo& info, std::unique_ptr* middleware) { diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 0ff5086e184a3..490a6db582f40 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -794,8 +794,8 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { void SendingHeaders(AddCallHeaders* outgoing_headers) override { // Lambda function to compare characters without case sensitivity. - auto char_compare = [] (const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); + auto char_compare = [](const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); }; std::string username; @@ -805,11 +805,11 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { const std::string val = iter.second.to_string(); if (key == kAuthHeader) { if (val.size() > strlen(kBasicPrefix)) { - if (std::equal(val.begin(), val.begin() + strlen(kBasicPrefix), - kBasicPrefix, char_compare)) { + if (std::equal(val.begin(), val.begin() + strlen(kBasicPrefix), kBasicPrefix, + char_compare)) { const std::string encoded_credentials = val.substr(strlen(kBasicPrefix)); - const std::string decoded_credentials = arrow::util::base64_decode( - encoded_credentials); + const std::string decoded_credentials = + arrow::util::base64_decode(encoded_credentials); std::stringstream decoded_stream(decoded_credentials); std::getline(decoded_stream, username, ':'); std::getline(decoded_stream, password, ':'); @@ -1071,9 +1071,10 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { void SetUp() { server_middleware_ = std::make_shared(); ASSERT_OK(MakeServer( - &server_, &client_, [&](FlightServerOptions* options) { - options->auth_handler = std::unique_ptr( - new TestServerAuthHandler("", "")); + &server_, &client_, + [&](FlightServerOptions* options) { + options->auth_handler = + std::unique_ptr(new TestServerAuthHandler("", "")); options->middleware.push_back({"header-auth-server", server_middleware_}); return Status::OK(); }, @@ -1100,9 +1101,7 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { ASSERT_EQ(bearer_token.second, std::string("")); } - void TearDown() { - ASSERT_OK(server_->Shutdown()); - } + void TearDown() { ASSERT_OK(server_->Shutdown()); } protected: std::unique_ptr client_; @@ -2290,13 +2289,9 @@ TEST_F(TestPropagatingMiddleware, DoPut) { ValidateStatus(status, FlightMethod::DoPut); } -TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { - RunValidClientAuth(); -} +TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { RunValidClientAuth(); } -TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { - RunInvalidClientAuth(); -} +TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { RunInvalidClientAuth(); } } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test_integration_client_header_auth.cc b/cpp/src/arrow/flight/test_integration_client_header_auth.cc index c3086e9d9c12f..c39fc1a449b62 100644 --- a/cpp/src/arrow/flight/test_integration_client_header_auth.cc +++ b/cpp/src/arrow/flight/test_integration_client_header_auth.cc @@ -66,7 +66,7 @@ void TestValidCredentials() { // Authenticate credentials and retreive token. std::pair bearer_token = std::make_pair("", ""); ABORT_NOT_OK( - client->AuthenticateBasicToken(FLAGS_username, FLAGS_password, &bearer_token)); + client->AuthenticateBasicToken(FLAGS_username, FLAGS_password, &bearer_token)); // Validate token was received. if (bearer_token == std::make_pair(std::string(""), std::string(""))) { @@ -99,8 +99,10 @@ void TestInvalidCredentials() { // Authenticate credentials and retreive token. std::pair bearer_token = std::make_pair("", ""); EXPECT_EQ(arrow::StatusCode::IOError, - client->AuthenticateBasicToken( - FLAGS_username_invalid, FLAGS_password_invalid, &bearer_token).code()); + client + ->AuthenticateBasicToken(FLAGS_username_invalid, FLAGS_password_invalid, + &bearer_token) + .code()); // Validate token was received. if (bearer_token != std::make_pair(std::string(""), std::string(""))) { From 37889fa3b9fed132211fc63c9fe4340a1faeff7a Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 18:01:13 -0800 Subject: [PATCH 14/31] [1] Minor cmake fix --- cpp/src/arrow/flight/CMakeLists.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index c1fe302b20e35..eb9f4a27aaaf3 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -214,9 +214,10 @@ if(ARROW_BUILD_INTEGRATION) target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) - add_executable(flight-test-integration-client-header-auth test_integration_client_header_auth.cc) - target_link_libraries(flight-test-integration-client-header-auth ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) + add_executable(flight-test-integration-client-header-auth + test_integration_client_header_auth.cc) + target_link_libraries(flight-test-integration-client-header-auth + ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) add_dependencies(arrow_flight flight-test-integration-client flight-test-integration-server From e7ac27cbe0a14be78dab6f3256da64c2898e08ec Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 19:00:49 -0800 Subject: [PATCH 15/31] [1] Trying different cmake spacing. --- cpp/src/arrow/flight/CMakeLists.txt | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index eb9f4a27aaaf3..fdf30c8463c03 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -118,12 +118,12 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") # protobuf-internal.cc set(ARROW_FLIGHT_SRCS client.cc + client_header_auth_middleware_internal.cc internal.cc protocol_internal.cc serialization_internal.cc server.cc server_auth.cc - client_header_auth_middleware_internal.cc types.cc) add_arrow_lib(arrow_flight @@ -214,17 +214,16 @@ if(ARROW_BUILD_INTEGRATION) target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) - add_executable(flight-test-integration-client-header-auth - test_integration_client_header_auth.cc) - target_link_libraries(flight-test-integration-client-header-auth - ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) + add_executable(flight-test-integration-client-header-auth test_integration_client_header_auth.cc) + target_link_libraries(flight-test-integration-client-header-auth ${ARROW_FLIGHT_TEST_LINK_LIBS} + ${GFLAGS_LIBRARIES} GTest::gtest) add_dependencies(arrow_flight flight-test-integration-client - flight-test-integration-server - flight-test-integration-client-header-auth) + flight-test-integration-client-header-auth + flight-test-integration-server) add_dependencies(arrow-integration flight-test-integration-client - flight-test-integration-server - flight-test-integration-client-header-auth) + flight-test-integration-client-header-auth + flight-test-integration-server) endif() if(ARROW_BUILD_BENCHMARKS) From ba7cb9f77abdf6e769851c2287a83311351a2b31 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Mon, 23 Nov 2020 19:16:39 -0800 Subject: [PATCH 16/31] [1] Trying different cmake --- cpp/src/arrow/flight/CMakeLists.txt | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index fdf30c8463c03..23015ab88d019 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -214,16 +214,16 @@ if(ARROW_BUILD_INTEGRATION) target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) - add_executable(flight-test-integration-client-header-auth test_integration_client_header_auth.cc) - target_link_libraries(flight-test-integration-client-header-auth ${ARROW_FLIGHT_TEST_LINK_LIBS} + add_executable(flight-test-integration-auth test_integration_client_header_auth.cc) + target_link_libraries(flight-test-integration-auth ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) - add_dependencies(arrow_flight flight-test-integration-client - flight-test-integration-client-header-auth - flight-test-integration-server) - add_dependencies(arrow-integration flight-test-integration-client - flight-test-integration-client-header-auth - flight-test-integration-server) + add_dependencies(arrow_flight flight-test-integration-auth + flight-test-integration-client + flight-test-integration-server) + add_dependencies(arrow-integration flight-test-integration-auth + flight-test-integration-client + flight-test-integration-server) endif() if(ARROW_BUILD_BENCHMARKS) From 1426252242e2578e9a20e1ebf8b2e96fcc534dce Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 11:14:22 -0800 Subject: [PATCH 17/31] [1] Addressed code review comments. --- cpp/src/arrow/flight/CMakeLists.txt | 2 +- cpp/src/arrow/flight/client.cc | 40 +++--- cpp/src/arrow/flight/client.h | 9 +- .../client_header_auth_middleware_internal.cc | 119 ------------------ .../arrow/flight/client_header_internal.cc | 89 +++++++++++++ ...re_internal.h => client_header_internal.h} | 24 ++-- cpp/src/arrow/flight/flight_test.cc | 11 +- .../test_integration_client_header_auth.cc | 8 +- 8 files changed, 129 insertions(+), 173 deletions(-) delete mode 100644 cpp/src/arrow/flight/client_header_auth_middleware_internal.cc create mode 100644 cpp/src/arrow/flight/client_header_internal.cc rename cpp/src/arrow/flight/{client_header_auth_middleware_internal.h => client_header_internal.h} (72%) diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 23015ab88d019..fb60b602e8f97 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -118,7 +118,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") # protobuf-internal.cc set(ARROW_FLIGHT_SRCS client.cc - client_header_auth_middleware_internal.cc + client_header_internal.cc internal.cc protocol_internal.cc serialization_internal.cc diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 60dd9c810ae6a..636159a688032 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -50,7 +50,7 @@ #include "arrow/util/uri.h" #include "arrow/flight/client_auth.h" -#include "arrow/flight/client_header_auth_middleware_internal.h" +#include "arrow/flight/client_header_internal.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware.h" @@ -105,8 +105,8 @@ struct ClientRpc { std::chrono::system_clock::now() + options.timeout); context.set_deadline(deadline); } - for (auto metadata : options.metadata) { - context.AddMetadata(metadata.first, metadata.second); + for (auto header : options.headers) { + context.AddMetadata(header.first, header.second); } } @@ -883,7 +883,7 @@ constexpr char BLANK_ROOT_PEM[] = } // namespace class FlightClient::FlightClientImpl { public: - FlightClientImpl() : interceptor_pointer(NULLPTR) {} + FlightClientImpl() {} Status Connect(const Location& location, const FlightClientOptions& options) { const std::string& scheme = location.scheme(); @@ -978,8 +978,8 @@ class FlightClient::FlightClientImpl { std::vector> interceptors; - interceptor_pointer = new GrpcClientInterceptorAdapterFactory(options.middleware); - interceptors.emplace_back(interceptor_pointer); + interceptors.emplace_back( + new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); stub_ = pb::FlightService::NewStub( grpc::experimental::CreateCustomChannelWithInterceptors( @@ -1008,31 +1008,28 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status AuthenticateBasicToken(const std::string& username, const std::string& password, - std::pair* bearer_token) { - // Add bearer token factory to middleware so it can intercept the bearer token. - if (interceptor_pointer != NULLPTR) { - interceptor_pointer->AddMiddlewareFactory( - std::make_shared(bearer_token)); - } else { - return MakeFlightError(FlightStatusCode::Internal, - "Connect must be called before AuthenticateBasicToken."); - } + Status AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password, std::pair* bearer_token) { + // Add basic auth headers to outgoing headers. ClientRpc rpc({}); internal::AddBasicAuthHeaders(&rpc.context, username, password); + std::shared_ptr> stream = stub_->Handshake(&rpc.context); - GrpcClientAuthSender outgoing{stream}; GrpcClientAuthReader incoming{stream}; - // Explicitly close our side of the connection + + // Explicitly close our side of the connection. bool finished_writes = stream->WritesDone(); - interceptor_pointer->RemoveMiddlewareFactory(); RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); if (!finished_writes) { return MakeFlightError(FlightStatusCode::Internal, "Could not finish writing before closing"); } + + // Grab bearer token from incoming headers. + internal::GetBearerTokenHeader(rpc.context, bearer_token); return Status::OK(); } @@ -1218,7 +1215,6 @@ class FlightClient::FlightClientImpl { noop_auth_check_; #endif int64_t write_size_limit_bytes_; - GrpcClientInterceptorAdapterFactory* interceptor_pointer; }; FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); } @@ -1241,10 +1237,10 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, return impl_->Authenticate(options, std::move(auth_handler)); } -Status FlightClient::AuthenticateBasicToken( +Status FlightClient::AuthenticateBasicToken(const FlightCallOptions& options, const std::string& username, const std::string& password, std::pair* bearer_token) { - return impl_->AuthenticateBasicToken(username, password, bearer_token); + return impl_->AuthenticateBasicToken(options, username, password, bearer_token); } Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index af68d9c9c759b..2ef007286859c 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -66,8 +66,8 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { /// \brief IPC writer options, if applicable for the call. ipc::IpcWriteOptions write_options; - /// \brief Metadata for client to add to context. - std::vector> metadata; + /// \brief Headers for client to add to context. + std::vector> headers; }; /// \brief Indicate that the client attempted to write a message @@ -194,12 +194,13 @@ class ARROW_FLIGHT_EXPORT FlightClient { Status Authenticate(const FlightCallOptions& options, std::unique_ptr auth_handler); - /// \brief Authenticate to the server using basic authentication with base 64 encoding. + /// \brief Authenticate to the server using basic HTTP style authentication. /// \param[in] username Username to use /// \param[in] password Password to use /// \param[in] bearer_token Bearer token retreived if applicable /// \return Status OK if the client authenticated successfully - Status AuthenticateBasicToken(const std::string& username, const std::string& password, + Status AuthenticateBasicToken(const FlightCallOptions& options, + const std::string& username, const std::string& password, std::pair* bearer_token); /// \brief Perform the indicated action, returning an iterator to the stream diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc b/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc deleted file mode 100644 index 3796ae1cdb7b7..0000000000000 --- a/cpp/src/arrow/flight/client_header_auth_middleware_internal.cc +++ /dev/null @@ -1,119 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Interfaces for defining middleware for Flight clients. Currently -// experimental. - -#include "arrow/flight/client_header_auth_middleware_internal.h" -#include "arrow/flight/client.h" -#include "arrow/flight/client_auth.h" -#include "arrow/util/base64.h" -#include "arrow/util/make_unique.h" - -#include -#include -#include -#include - -const char kAuthHeader[] = "authorization"; -const char kBearerPrefix[] = "Bearer "; -const char kBasicPrefix[] = "Basic "; - -namespace arrow { -namespace flight { -namespace internal { - -// Add base64 encoded credentials to the outbound headers. -// -// @param context Context object to add the headers to. -// @param username Username to format and encode. -// @param password Password to format and encode. -void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, - const std::string& password) { - const std::string credentials = username + ":" + password; - context->AddMetadata( - kAuthHeader, - kBasicPrefix + arrow::util::base64_encode((const unsigned char*)credentials.c_str(), - credentials.size())); -} - -class ClientBearerTokenFactory::Impl { - public: - explicit Impl(std::pair* bearer_token) - : bearer_token_(bearer_token) {} - - void StartCall(const CallInfo& info, std::unique_ptr* middleware) { - ARROW_UNUSED(info); - *middleware = - arrow::internal::make_unique(bearer_token_); - } - - private: - class ClientBearerTokenMiddleware : public ClientMiddleware { - public: - explicit ClientBearerTokenMiddleware( - std::pair* bearer_token) - : bearer_token_(bearer_token) {} - - void SendingHeaders(AddCallHeaders* outgoing_headers) override {} - - void ReceivedHeaders(const CallHeaders& incoming_headers) override { - // Lambda function to compare characters without case sensitivity. - auto char_compare = [](const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); - }; - - // Grab the auth token if one exists. - const auto bearer_iter = incoming_headers.find(kAuthHeader); - if (bearer_iter == incoming_headers.end()) { - return; - } - - // Check if the value of the auth token starts with the bearer prefix and latch it. - const std::string bearer_val = bearer_iter->second.to_string(); - if (bearer_val.size() > strlen(kBearerPrefix)) { - if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), - kBearerPrefix, char_compare)) { - *bearer_token_ = std::make_pair(kAuthHeader, bearer_val); - } - } - } - - void CallCompleted(const Status& status) override {} - - private: - std::pair* bearer_token_; - }; - - private: - std::pair* bearer_token_; -}; - -ClientBearerTokenFactory::ClientBearerTokenFactory( - std::pair* bearer_token) - : impl_(new ClientBearerTokenFactory::Impl(bearer_token)) {} - -ClientBearerTokenFactory::~ClientBearerTokenFactory() {} - -void ClientBearerTokenFactory::StartCall(const CallInfo& info, - std::unique_ptr* middleware) { - impl_->StartCall(info, middleware); -} - -} // namespace internal -} // namespace flight -} // namespace arrow diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc new file mode 100644 index 0000000000000..3c9d470d807d9 --- /dev/null +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces for defining middleware for Flight clients. Currently +// experimental. + +#include "arrow/flight/client_header_internal.h" +#include "arrow/flight/client.h" +#include "arrow/flight/client_auth.h" +#include "arrow/util/base64.h" +#include "arrow/util/make_unique.h" + +#include +#include +#include +#include + +const char kAuthHeader[] = "authorization"; +const char kBearerPrefix[] = "Bearer "; +const char kBasicPrefix[] = "Basic "; + +namespace arrow { +namespace flight { +namespace internal { + +// Add base64 encoded credentials to the outbound headers. +// +// @param context Context object to add the headers to. +// @param username Username to format and encode. +// @param password Password to format and encode. +void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, + const std::string& password) { + const std::string credentials = username + ":" + password; + context->AddMetadata( + kAuthHeader, + kBasicPrefix + arrow::util::base64_encode((const unsigned char*)credentials.c_str(), + credentials.size())); +} + +// Get bearer token from inbound headers. +// +// @param headers Incoming headers. +// @param[out] Bearer token pointer to set. +void GetBearerTokenHeader(grpc::ClientContext& context, + std::pair* bearer_token) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); + }; + + // Grab the auth token if one exists. + auto trailing_headers = context.GetServerTrailingMetadata(); + auto initial_headers = context.GetServerInitialMetadata(); + auto bearer_iter = trailing_headers.find(kAuthHeader); + if (bearer_iter == trailing_headers.end()) { + bearer_iter = initial_headers.find(kAuthHeader); + if (bearer_iter == initial_headers.end()) { + return; + } + } + + // Check if the value of the auth token starts with the bearer prefix and latch it. + std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size()); + if (bearer_val.size() > strlen(kBearerPrefix)) { + if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), + kBearerPrefix, char_compare)) { + *bearer_token = std::make_pair(kAuthHeader, bearer_val); + } + } +} + + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/client_header_auth_middleware_internal.h b/cpp/src/arrow/flight/client_header_internal.h similarity index 72% rename from cpp/src/arrow/flight/client_header_auth_middleware_internal.h rename to cpp/src/arrow/flight/client_header_internal.h index 4bf5259aa0012..1a0b0243cc642 100644 --- a/cpp/src/arrow/flight/client_header_auth_middleware_internal.h +++ b/cpp/src/arrow/flight/client_header_internal.h @@ -44,23 +44,13 @@ void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password); -/// \brief Client-side middleware for receiving and latching a bearer token. -class ARROW_FLIGHT_EXPORT ClientBearerTokenFactory : public ClientMiddlewareFactory { - public: - /// \brief Constructor for factory. - /// - /// \param[out] bearer_token_ pointer to a std::pair of std::strings that the factory - /// will populate with the bearer token that is received from the server. - explicit ClientBearerTokenFactory(std::pair* bearer_token_); - - ~ClientBearerTokenFactory(); - - void StartCall(const CallInfo& info, std::unique_ptr* middleware); - - private: - class Impl; - std::unique_ptr impl_; -}; +/// \brief Get bearer token from incoming headers. +/// +/// \param headers headers to check for bearer token. +/// \param[out] bearer_token_ pointer to a std::pair of std::strings that the factory +/// will populate with the bearer token that is received from the server. +void ARROW_FLIGHT_EXPORT GetBearerTokenHeader( + grpc::ClientContext& context, std::pair* bearer_token); } // namespace internal } // namespace flight diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 490a6db582f40..4c5914adfbaf3 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -44,7 +44,7 @@ #error "gRPC headers should not be in public API" #endif -#include "arrow/flight/client_header_auth_middleware_internal.h" +#include "arrow/flight/client_header_internal.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware_internal.h" #include "arrow/flight/test_util.h" @@ -1074,7 +1074,7 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { &server_, &client_, [&](FlightServerOptions* options) { options->auth_handler = - std::unique_ptr(new TestServerAuthHandler("", "")); + std::unique_ptr(new NoOpAuthHandler()); options->middleware.push_back({"header-auth-server", server_middleware_}); return Status::OK(); }, @@ -1085,8 +1085,8 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { std::pair bearer_token; // Note: Status intentionally ignored because it requires C++ server implementation of // header auth. For now it returns an IOError. - arrow::Status status = - client_->AuthenticateBasicToken(kValidUsername, kValidPassword, &bearer_token); + arrow::Status status = client_->AuthenticateBasicToken({}, kValidUsername, + kValidPassword, &bearer_token); ASSERT_EQ(bearer_token.first, kAuthHeader); ASSERT_EQ(bearer_token.second, (std::string(kBearerPrefix) + kBearerToken)); } @@ -1096,7 +1096,7 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { // Note: Status intentionally ignored because it requires C++ server implementation of // header auth. For now it returns an IOError. arrow::Status status = client_->AuthenticateBasicToken( - kInvalidUsername, kInvalidPassword, &bearer_token); + {}, kInvalidUsername, kInvalidPassword, &bearer_token); ASSERT_EQ(bearer_token.first, std::string("")); ASSERT_EQ(bearer_token.second, std::string("")); } @@ -1107,7 +1107,6 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { std::unique_ptr client_; std::unique_ptr server_; std::shared_ptr server_middleware_; - std::shared_ptr client_middleware_; }; TEST_F(TestErrorMiddleware, TestMetadata) { diff --git a/cpp/src/arrow/flight/test_integration_client_header_auth.cc b/cpp/src/arrow/flight/test_integration_client_header_auth.cc index c39fc1a449b62..0a137ef9429f6 100644 --- a/cpp/src/arrow/flight/test_integration_client_header_auth.cc +++ b/cpp/src/arrow/flight/test_integration_client_header_auth.cc @@ -66,7 +66,7 @@ void TestValidCredentials() { // Authenticate credentials and retreive token. std::pair bearer_token = std::make_pair("", ""); ABORT_NOT_OK( - client->AuthenticateBasicToken(FLAGS_username, FLAGS_password, &bearer_token)); + client->AuthenticateBasicToken({}, FLAGS_username, FLAGS_password, &bearer_token)); // Validate token was received. if (bearer_token == std::make_pair(std::string(""), std::string(""))) { @@ -78,7 +78,7 @@ void TestValidCredentials() { // Try to list flights, this will force the bearer token to be send and authenticated. std::unique_ptr listing; arrow::flight::FlightCallOptions options; - options.metadata.push_back(bearer_token); + options.headers.push_back(bearer_token); ABORT_NOT_OK(client->ListFlights(options, {}, &listing)); std::cout << "Test valid credentials was successful." << std::endl; } @@ -100,8 +100,8 @@ void TestInvalidCredentials() { std::pair bearer_token = std::make_pair("", ""); EXPECT_EQ(arrow::StatusCode::IOError, client - ->AuthenticateBasicToken(FLAGS_username_invalid, FLAGS_password_invalid, - &bearer_token) + ->AuthenticateBasicToken({}, FLAGS_username_invalid, + FLAGS_password_invalid, &bearer_token) .code()); // Validate token was received. From 516d993a9136ac3a0fe7189717cb08de5f89cf5f Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 11:57:44 -0800 Subject: [PATCH 18/31] [1] Removing integration test and reverting some cmake changes. --- cpp/src/arrow/flight/CMakeLists.txt | 14 +- .../test_integration_client_header_auth.cc | 121 ------------------ .../integration/Auth2IntegrationServer.java | 89 ------------- 3 files changed, 4 insertions(+), 220 deletions(-) delete mode 100644 cpp/src/arrow/flight/test_integration_client_header_auth.cc delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index fb60b602e8f97..86e3c510ebbf4 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -214,16 +214,10 @@ if(ARROW_BUILD_INTEGRATION) target_link_libraries(flight-test-integration-client ${ARROW_FLIGHT_TEST_LINK_LIBS} ${GFLAGS_LIBRARIES} GTest::gtest) - add_executable(flight-test-integration-auth test_integration_client_header_auth.cc) - target_link_libraries(flight-test-integration-auth ${ARROW_FLIGHT_TEST_LINK_LIBS} - ${GFLAGS_LIBRARIES} GTest::gtest) - - add_dependencies(arrow_flight flight-test-integration-auth - flight-test-integration-client - flight-test-integration-server) - add_dependencies(arrow-integration flight-test-integration-auth - flight-test-integration-client - flight-test-integration-server) + add_dependencies(arrow_flight flight-test-integration-client + flight-test-integration-server) + add_dependencies(arrow-integration flight-test-integration-client + flight-test-integration-server) endif() if(ARROW_BUILD_BENCHMARKS) diff --git a/cpp/src/arrow/flight/test_integration_client_header_auth.cc b/cpp/src/arrow/flight/test_integration_client_header_auth.cc deleted file mode 100644 index 0a137ef9429f6..0000000000000 --- a/cpp/src/arrow/flight/test_integration_client_header_auth.cc +++ /dev/null @@ -1,121 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -// Client implementation for Flight integration testing. Loads -// RecordBatches from the given JSON file and uploads them to the -// Flight server, which stores the data and schema in memory. The -// client then requests the data from the server and compares it to -// the data originally uploaded. - -#include -#include -#include - -#include - -#include "arrow/io/file.h" -#include "arrow/io/test_common.h" -#include "arrow/ipc/dictionary.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/table.h" -#include "arrow/testing/extension_type.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/testing/json_integration.h" -#include "arrow/util/logging.h" - -#include "arrow/flight/api.h" -#include "arrow/flight/test_integration.h" -#include "arrow/flight/test_util.h" - -DEFINE_string(host, "localhost", "Server port to connect to"); -DEFINE_int32(port, 31337, "Server port to connect to"); -DEFINE_string(username, "flight1", "Username to use in basic auth"); -DEFINE_string(password, "woohoo1", "Password to use in basic auth"); -DEFINE_string(username_invalid, "foooo", "Username to use in basic auth"); -DEFINE_string(password_invalid, "barrr", "Password to use in basic auth"); - -void TestValidCredentials() { - std::cout << "Testing with valid auth credentials." << std::endl; - auto get_uri = []() { - return "grpc+tcp://" + FLAGS_host + ":" + std::to_string(FLAGS_port); - }; - - // Generate Location with URI. - arrow::flight::Location location; - ABORT_NOT_OK(arrow::flight::Location::Parse(get_uri(), &location)); - - // Create client and connect to Location. - std::unique_ptr client; - ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client)); - - // Authenticate credentials and retreive token. - std::pair bearer_token = std::make_pair("", ""); - ABORT_NOT_OK( - client->AuthenticateBasicToken({}, FLAGS_username, FLAGS_password, &bearer_token)); - - // Validate token was received. - if (bearer_token == std::make_pair(std::string(""), std::string(""))) { - std::cout << "Testing valid credentials was unsuccessful: " - << "Failed to get token from basic authentication." << std::endl; - return; - } - - // Try to list flights, this will force the bearer token to be send and authenticated. - std::unique_ptr listing; - arrow::flight::FlightCallOptions options; - options.headers.push_back(bearer_token); - ABORT_NOT_OK(client->ListFlights(options, {}, &listing)); - std::cout << "Test valid credentials was successful." << std::endl; -} - -void TestInvalidCredentials() { - auto get_uri = []() { - return "grpc+tcp://" + FLAGS_host + ":" + std::to_string(FLAGS_port); - }; - - // Generate Location with URI. - arrow::flight::Location location; - ABORT_NOT_OK(arrow::flight::Location::Parse(get_uri(), &location)); - - // Create client and connect to Location. - std::unique_ptr client; - ABORT_NOT_OK(arrow::flight::FlightClient::Connect(location, &client)); - - // Authenticate credentials and retreive token. - std::pair bearer_token = std::make_pair("", ""); - EXPECT_EQ(arrow::StatusCode::IOError, - client - ->AuthenticateBasicToken({}, FLAGS_username_invalid, - FLAGS_password_invalid, &bearer_token) - .code()); - - // Validate token was received. - if (bearer_token != std::make_pair(std::string(""), std::string(""))) { - std::cout << "Testing invalid credentials was unsuccessful: " - << "Obtained token from basic authentication when using " - << "invalid credentials." << std::endl; - } - - std::cout << "Testing invalid credentials was successful." << std::endl; -} - -int main(int argc, char** argv) { - std::cout << "Starting auth header based flight integration test." << std::endl; - TestValidCredentials(); - TestInvalidCredentials(); -} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java deleted file mode 100644 index 0fcaf75952ef0..0000000000000 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Auth2IntegrationServer.java +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.flight.example.integration; - -import java.io.IOException; - -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; -import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; -import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; -import org.apache.arrow.flight.example.InMemoryStore; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.util.AutoCloseables; - -import com.google.common.base.Strings; - -/** - * Java server for running integration tests - this is currently setup to run - * against the cpp test 'flight-test-integration-client-header-auth'. - */ -public class Auth2IntegrationServer { - private static final int PORT = 31337; - private static final String USERNAME_1 = "flight1"; - private static final String PASSWORD_1 = "woohoo1"; - private static final String HOST = "localhost"; - private static final BufferAllocator ALLOCATOR = new RootAllocator(Long.MAX_VALUE); - private static FlightServer server; - - static void launchServer() throws IOException, InterruptedException { - final Location location = Location.forGrpcInsecure(HOST, PORT); - final InMemoryStore store = new InMemoryStore(ALLOCATOR, location); - server = FlightServer.builder(ALLOCATOR, location, store).headerAuthenticator( - new GeneratedBearerTokenAuthenticator( - new BasicCallHeaderAuthenticator(Auth2IntegrationServer::validate)) - ).build().start(); - store.setLocation(Location.forGrpcInsecure("localhost", server.getPort())); - - Runtime.getRuntime().addShutdownHook(new Thread(() -> { - try { - System.out.println("\nExiting..."); - AutoCloseables.close(server, ALLOCATOR); - } catch (Exception e) { - e.printStackTrace(); - } - })); - - System.out.println("Server running on " + server.getLocation()); - server.awaitTermination(); - } - - private static CallHeaderAuthenticator.AuthResult validate(String username, String password) { - if (Strings.isNullOrEmpty(username)) { - throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException(); - } - final String identity; - if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) { - identity = USERNAME_1; - } else { - throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); - } - return () -> identity; - } - - public static void main(String[] args) { - try { - launchServer(); - } catch (Exception e) { - System.out.println("Launching server failed " + e); - } - } -} From 3000ecba7d02293969ec4602b8d071c4a39e8817 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 12:06:17 -0800 Subject: [PATCH 19/31] [1] Removed some no longer used functionality. [2] Fixed option passing. --- cpp/src/arrow/flight/client.cc | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 636159a688032..5e79b65a6cf98 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -332,7 +332,7 @@ class GrpcClientInterceptorAdapterFactory : public grpc::experimental::ClientInterceptorFactoryInterface { public: GrpcClientInterceptorAdapterFactory( - const std::vector> middleware) + std::vector> middleware) : middleware_(middleware) {} grpc::experimental::Interceptor* CreateClientInterceptor( @@ -364,7 +364,6 @@ class GrpcClientInterceptorAdapterFactory } const CallInfo flight_info{flight_method}; - std::lock_guard lock(middleware_lock_); for (auto& factory : middleware_) { std::unique_ptr instance; factory->StartCall(flight_info, &instance); @@ -374,17 +373,8 @@ class GrpcClientInterceptorAdapterFactory } return new GrpcClientInterceptorAdapter(std::move(middleware)); } - - void AddMiddlewareFactory(std::shared_ptr middleware_factory) { - std::lock_guard lock(middleware_lock_); - middleware_.push_back(middleware_factory); - } - - void RemoveMiddlewareFactory() { middleware_.pop_back(); } - private: - std::mutex middleware_lock_; - std::vector> middleware_; + std::vector> middleware_; }; class GrpcClientAuthSender : public ClientAuthSender { @@ -883,8 +873,6 @@ constexpr char BLANK_ROOT_PEM[] = } // namespace class FlightClient::FlightClientImpl { public: - FlightClientImpl() {} - Status Connect(const Location& location, const FlightClientOptions& options) { const std::string& scheme = location.scheme(); @@ -1012,7 +1000,7 @@ class FlightClient::FlightClientImpl { const FlightCallOptions& options, const std::string& username, const std::string& password, std::pair* bearer_token) { // Add basic auth headers to outgoing headers. - ClientRpc rpc({}); + ClientRpc rpc(options); internal::AddBasicAuthHeaders(&rpc.context, username, password); std::shared_ptr> From 1de10fa473622f45400662f1f10160feaab11cad Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 13:17:04 -0800 Subject: [PATCH 20/31] [1] Added improved testing and fixed linting. --- cpp/src/arrow/flight/client.cc | 15 +- .../arrow/flight/client_header_internal.cc | 1 - cpp/src/arrow/flight/flight_test.cc | 133 +++++++++++++----- 3 files changed, 107 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 5e79b65a6cf98..95a0dc67539cb 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -373,8 +373,9 @@ class GrpcClientInterceptorAdapterFactory } return new GrpcClientInterceptorAdapter(std::move(middleware)); } + private: - std::vector> middleware_; + std::vector> middleware_; }; class GrpcClientAuthSender : public ClientAuthSender { @@ -996,9 +997,9 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status AuthenticateBasicToken( - const FlightCallOptions& options, const std::string& username, - const std::string& password, std::pair* bearer_token) { + Status AuthenticateBasicToken(const FlightCallOptions& options, + const std::string& username, const std::string& password, + std::pair* bearer_token) { // Add basic auth headers to outgoing headers. ClientRpc rpc(options); internal::AddBasicAuthHeaders(&rpc.context, username, password); @@ -1225,9 +1226,9 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, return impl_->Authenticate(options, std::move(auth_handler)); } -Status FlightClient::AuthenticateBasicToken(const FlightCallOptions& options, - const std::string& username, const std::string& password, - std::pair* bearer_token) { +Status FlightClient::AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password, std::pair* bearer_token) { return impl_->AuthenticateBasicToken(options, username, password, bearer_token); } diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 3c9d470d807d9..17a13b83c065d 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -83,7 +83,6 @@ void GetBearerTokenHeader(grpc::ClientContext& context, } } - } // namespace internal } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 4c5914adfbaf3..005bedeb0cf19 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -553,6 +553,14 @@ class OptionsTestServer : public FlightServerBase { } }; +class HeaderAuthTestServer : public FlightServerBase { + public: + Status ListFlights(const ServerCallContext& context, const Criteria* criteria, + std::unique_ptr* listings) override { + return Status::OK(); + } +}; + class TestMetadata : public ::testing::Test { public: void SetUp() { @@ -785,6 +793,30 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { } }; +// Function to look in CallHeaders for a key that has a value starting with prefix and +// return the rest of the value after the prefix. +std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers, + const std::string& key, + const std::string& prefix) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (std::toupper(char1) == std::toupper(char2)); + }; + + auto iter = incoming_headers.find(key); + if (iter == incoming_headers.end()) { + return ""; + } + const std::string val = iter->second.to_string(); + if (val.size() > prefix.length()) { + if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), + char_compare)) { + return val.substr(prefix.length()); + } + } + return ""; +} + // A server middleware for validating incoming base64 header authentication. class HeaderAuthServerMiddleware : public ServerMiddleware { public: @@ -793,32 +825,12 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { } void SendingHeaders(AddCallHeaders* outgoing_headers) override { - // Lambda function to compare characters without case sensitivity. - auto char_compare = [](const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); - }; - - std::string username; - std::string password; - for (auto& iter : incoming_headers_) { - const std::string key = iter.first.to_string(); - const std::string val = iter.second.to_string(); - if (key == kAuthHeader) { - if (val.size() > strlen(kBasicPrefix)) { - if (std::equal(val.begin(), val.begin() + strlen(kBasicPrefix), kBasicPrefix, - char_compare)) { - const std::string encoded_credentials = val.substr(strlen(kBasicPrefix)); - const std::string decoded_credentials = - arrow::util::base64_decode(encoded_credentials); - std::stringstream decoded_stream(decoded_credentials); - std::getline(decoded_stream, username, ':'); - std::getline(decoded_stream, password, ':'); - break; - } - } - } - } - + std::string encoded_credentials = + FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBasicPrefix); + std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials)); + std::string username, password; + std::getline(decoded_stream, username, ':'); + std::getline(decoded_stream, password, ':'); if ((username == kValidUsername) && (password == kValidPassword)) { outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); } @@ -831,6 +843,29 @@ class HeaderAuthServerMiddleware : public ServerMiddleware { CallHeaders incoming_headers_; }; +// A server middleware for validating incoming bearer header authentication. +class BearerAuthServerMiddleware : public ServerMiddleware { + public: + explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid) + : isValid_(isValid) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + std::string bearer_token = + FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix); + *isValid_ = (bearer_token == std::string(kBearerToken)); + } + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "BearerAuthServerMiddleware"; } + + private: + CallHeaders incoming_headers_; + bool* isValid_; +}; + // Factory for base64 header authentication testing. class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { public: @@ -847,6 +882,28 @@ class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { } }; +// Factory for base64 header authentication testing. +class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + BearerAuthServerMiddlewareFactory() : isValid_(false) {} + + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + const std::pair& iter_pair = + incoming_headers.equal_range(kAuthHeader); + if (iter_pair.first != iter_pair.second) { + *middleware = + std::make_shared(incoming_headers, &isValid_); + } + return Status::OK(); + } + + bool GetIsValid() { return isValid_; }; + + private: + bool isValid_; +}; + // A client middleware that adds a thread-local "request ID" to // outgoing calls as a header, and keeps track of the status of // completed calls. NOT thread-safe. @@ -1069,13 +1126,17 @@ class TestErrorMiddleware : public ::testing::Test { class TestBasicHeaderAuthMiddleware : public ::testing::Test { public: void SetUp() { - server_middleware_ = std::make_shared(); - ASSERT_OK(MakeServer( + header_middleware_ = std::make_shared(); + bearer_middleware_ = std::make_shared(); + std::pair bearer = make_pair( + kAuthHeader, std::string(kBearerPrefix) + " " + std::string(kBearerToken)); + ASSERT_OK(MakeServer( &server_, &client_, [&](FlightServerOptions* options) { options->auth_handler = std::unique_ptr(new NoOpAuthHandler()); - options->middleware.push_back({"header-auth-server", server_middleware_}); + options->middleware.push_back({"header-auth-server", header_middleware_}); + options->middleware.push_back({"bearer-auth-server", bearer_middleware_}); return Status::OK(); }, [&](FlightClientOptions* options) { return Status::OK(); })); @@ -1083,12 +1144,15 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { void RunValidClientAuth() { std::pair bearer_token; - // Note: Status intentionally ignored because it requires C++ server implementation of - // header auth. For now it returns an IOError. - arrow::Status status = client_->AuthenticateBasicToken({}, kValidUsername, - kValidPassword, &bearer_token); + ASSERT_OK(client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword, + &bearer_token)); ASSERT_EQ(bearer_token.first, kAuthHeader); ASSERT_EQ(bearer_token.second, (std::string(kBearerPrefix) + kBearerToken)); + std::unique_ptr listing; + FlightCallOptions call_options; + call_options.headers.push_back(bearer_token); + ASSERT_OK(client_->ListFlights(call_options, {}, &listing)); + ASSERT_TRUE(bearer_middleware_->GetIsValid()); } void RunInvalidClientAuth() { @@ -1106,7 +1170,8 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { protected: std::unique_ptr client_; std::unique_ptr server_; - std::shared_ptr server_middleware_; + std::shared_ptr header_middleware_; + std::shared_ptr bearer_middleware_; }; TEST_F(TestErrorMiddleware, TestMetadata) { From 065af4a77e09e796c633a45d137da5c4bfcf7a7e Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 14:38:15 -0800 Subject: [PATCH 21/31] [1] Fixed lint issue --- cpp/src/arrow/flight/flight_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 005bedeb0cf19..8012b8484baca 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -898,7 +898,7 @@ class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { return Status::OK(); } - bool GetIsValid() { return isValid_; }; + bool GetIsValid() { return isValid_; } private: bool isValid_; From 911fcc746d435f7dd061fe359651c15e4bfdead2 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 16:43:52 -0800 Subject: [PATCH 22/31] [1] Updating submodule --- testing | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing b/testing index 3ab0d53528a05..d6c4deb22c4b4 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit 3ab0d53528a050c370a31c3741574250a6e88a4d +Subproject commit d6c4deb22c4b4e9e3247a2f291046e3c671ad235 From f41edce63892d778b80017f3223ee7ce05e2660e Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 17:17:50 -0800 Subject: [PATCH 23/31] [1] Minor documentation fixes. --- cpp/src/arrow/flight/client_header_internal.cc | 6 +++--- cpp/src/arrow/flight/client_header_internal.h | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 17a13b83c065d..34f913f176516 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -53,8 +53,8 @@ void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& userna // Get bearer token from inbound headers. // -// @param headers Incoming headers. -// @param[out] Bearer token pointer to set. +// @param context Incoming ClientContext that contains headers. +// @param bearer_token[out] Bearer token pointer to set. void GetBearerTokenHeader(grpc::ClientContext& context, std::pair* bearer_token) { // Lambda function to compare characters without case sensitivity. @@ -62,7 +62,7 @@ void GetBearerTokenHeader(grpc::ClientContext& context, return (std::toupper(char1) == std::toupper(char2)); }; - // Grab the auth token if one exists. + // Get the auth token if it exists, this can be in the initial or the trailing metadata. auto trailing_headers = context.GetServerTrailingMetadata(); auto initial_headers = context.GetServerInitialMetadata(); auto bearer_iter = trailing_headers.find(kAuthHeader); diff --git a/cpp/src/arrow/flight/client_header_internal.h b/cpp/src/arrow/flight/client_header_internal.h index 1a0b0243cc642..3c7e993980dfb 100644 --- a/cpp/src/arrow/flight/client_header_internal.h +++ b/cpp/src/arrow/flight/client_header_internal.h @@ -46,8 +46,8 @@ void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, /// \brief Get bearer token from incoming headers. /// -/// \param headers headers to check for bearer token. -/// \param[out] bearer_token_ pointer to a std::pair of std::strings that the factory +/// \param context context that contains headers which hold the bearer token. +/// \param[out] bearer_token pointer to a std::pair of std::strings that the factory /// will populate with the bearer token that is received from the server. void ARROW_FLIGHT_EXPORT GetBearerTokenHeader( grpc::ClientContext& context, std::pair* bearer_token); From 6b6fbbe704279b8d8e90c847cdf61bb983bcc656 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 17:53:24 -0800 Subject: [PATCH 24/31] [1] Fixed casting issue on some builds --- cpp/src/arrow/flight/client_header_internal.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 34f913f176516..8b29d9f6ba7f8 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -47,8 +47,9 @@ void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& userna const std::string credentials = username + ":" + password; context->AddMetadata( kAuthHeader, - kBasicPrefix + arrow::util::base64_encode((const unsigned char*)credentials.c_str(), - credentials.size())); + kBasicPrefix + arrow::util::base64_encode( + static_cast(credentials.c_str()), + static_cast(credentials.size())); } // Get bearer token from inbound headers. From 199b655591ce8428a5f270dada745aca5251c080 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 18:07:47 -0800 Subject: [PATCH 25/31] [1] Added missing parameter for documentation --- cpp/src/arrow/flight/client.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 2ef007286859c..2f721b5fca0d3 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -195,6 +195,7 @@ class ARROW_FLIGHT_EXPORT FlightClient { std::unique_ptr auth_handler); /// \brief Authenticate to the server using basic HTTP style authentication. + /// \param[in] options Per-RPC options /// \param[in] username Username to use /// \param[in] password Password to use /// \param[in] bearer_token Bearer token retreived if applicable From 47aa581e24bae230a9dadce469491686c194e5f3 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 18:19:19 -0800 Subject: [PATCH 26/31] [1] Fixing cast. --- cpp/src/arrow/flight/client_header_internal.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 8b29d9f6ba7f8..9bd843706f87a 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -48,8 +48,8 @@ void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& userna context->AddMetadata( kAuthHeader, kBasicPrefix + arrow::util::base64_encode( - static_cast(credentials.c_str()), - static_cast(credentials.size())); + reinterpret_cast(credentials.c_str()), + static_cast(credentials.size()))); } // Get bearer token from inbound headers. From 477d865b8c48e735923bc313f81a37e1dbfb7134 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 18:48:56 -0800 Subject: [PATCH 27/31] [1] Moving std:: from toupper call because it causes break in some builds. --- cpp/src/arrow/flight/client_header_internal.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 9bd843706f87a..5dfaf9f679a46 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -60,7 +60,7 @@ void GetBearerTokenHeader(grpc::ClientContext& context, std::pair* bearer_token) { // Lambda function to compare characters without case sensitivity. auto char_compare = [](const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); + return (::toupper(char1) == ::toupper(char2)); }; // Get the auth token if it exists, this can be in the initial or the trailing metadata. From 1cc3fdbde0c00380ee18405f2a5243d4e3c2ae9e Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 18:52:33 -0800 Subject: [PATCH 28/31] [1] Adding missed std remove --- cpp/src/arrow/flight/flight_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index d925bf3ebed57..f0ec94dbba8a1 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -817,7 +817,7 @@ std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers, const std::string& prefix) { // Lambda function to compare characters without case sensitivity. auto char_compare = [](const char& char1, const char& char2) { - return (std::toupper(char1) == std::toupper(char2)); + return (::toupper(char1) == ::toupper(char2)); }; auto iter = incoming_headers.find(key); From d27465d6908262df8df2bbea0881237df142cd47 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Tue, 24 Nov 2020 23:06:02 -0800 Subject: [PATCH 29/31] [1] Fixed linting issue. --- cpp/src/arrow/flight/client_header_internal.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 5dfaf9f679a46..4dd934943d24c 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -48,8 +48,8 @@ void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& userna context->AddMetadata( kAuthHeader, kBasicPrefix + arrow::util::base64_encode( - reinterpret_cast(credentials.c_str()), - static_cast(credentials.size()))); + reinterpret_cast(credentials.c_str()), + static_cast(credentials.size()))); } // Get bearer token from inbound headers. From d21006f747816c6416c86298ac704ac12ad4caa1 Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 25 Nov 2020 10:58:33 -0800 Subject: [PATCH 30/31] [1] Updated test return error properly and to check for error [2] Switched return type of AuthenticateBasicToken to use arrow::Result instead of Status and removed bearer token from parameter list --- cpp/src/arrow/flight/client.cc | 15 ++-- cpp/src/arrow/flight/client.h | 11 +-- .../arrow/flight/client_header_internal.cc | 29 ++++--- cpp/src/arrow/flight/client_header_internal.h | 8 +- cpp/src/arrow/flight/flight_test.cc | 83 +++++++++---------- 5 files changed, 73 insertions(+), 73 deletions(-) diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 9a10833b0caee..5c56e6409a75b 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -998,9 +998,9 @@ class FlightClient::FlightClientImpl { return Status::OK(); } - Status AuthenticateBasicToken(const FlightCallOptions& options, - const std::string& username, const std::string& password, - std::pair* bearer_token) { + arrow::Result> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password) { // Add basic auth headers to outgoing headers. ClientRpc rpc(options); internal::AddBasicAuthHeaders(&rpc.context, username, password); @@ -1019,8 +1019,7 @@ class FlightClient::FlightClientImpl { } // Grab bearer token from incoming headers. - internal::GetBearerTokenHeader(rpc.context, bearer_token); - return Status::OK(); + return internal::GetBearerTokenHeader(rpc.context); } Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, @@ -1227,10 +1226,10 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, return impl_->Authenticate(options, std::move(auth_handler)); } -Status FlightClient::AuthenticateBasicToken( +arrow::Result> FlightClient::AuthenticateBasicToken( const FlightCallOptions& options, const std::string& username, - const std::string& password, std::pair* bearer_token) { - return impl_->AuthenticateBasicToken(options, username, password, bearer_token); + const std::string& password) { + return impl_->AuthenticateBasicToken(options, username, password); } Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 2f721b5fca0d3..441f11467669e 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -29,6 +29,7 @@ #include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/variant.h" @@ -198,11 +199,11 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \param[in] options Per-RPC options /// \param[in] username Username to use /// \param[in] password Password to use - /// \param[in] bearer_token Bearer token retreived if applicable - /// \return Status OK if the client authenticated successfully - Status AuthenticateBasicToken(const FlightCallOptions& options, - const std::string& username, const std::string& password, - std::pair* bearer_token); + /// \return Arrow result with bearer token and status OK if client authenticated + /// sucessfully + arrow::Result> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password); /// \brief Perform the indicated action, returning an iterator to the stream /// of results, if any diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc index 4dd934943d24c..2112b41f72f33 100644 --- a/cpp/src/arrow/flight/client_header_internal.cc +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -37,11 +37,11 @@ namespace arrow { namespace flight { namespace internal { -// Add base64 encoded credentials to the outbound headers. -// -// @param context Context object to add the headers to. -// @param username Username to format and encode. -// @param password Password to format and encode. +/// \brief Add base64 encoded credentials to the outbound headers. +/// +/// \param context Context object to add the headers to. +/// \param username Username to format and encode. +/// \param password Password to format and encode. void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, const std::string& password) { const std::string credentials = username + ":" + password; @@ -52,12 +52,12 @@ void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& userna static_cast(credentials.size()))); } -// Get bearer token from inbound headers. -// -// @param context Incoming ClientContext that contains headers. -// @param bearer_token[out] Bearer token pointer to set. -void GetBearerTokenHeader(grpc::ClientContext& context, - std::pair* bearer_token) { +/// \brief Get bearer token from inbound headers. +/// +/// \param context Incoming ClientContext that contains headers. +/// \return Arrow result with bearer token (empty if no bearer token found). +arrow::Result> GetBearerTokenHeader( + grpc::ClientContext& context) { // Lambda function to compare characters without case sensitivity. auto char_compare = [](const char& char1, const char& char2) { return (::toupper(char1) == ::toupper(char2)); @@ -70,7 +70,7 @@ void GetBearerTokenHeader(grpc::ClientContext& context, if (bearer_iter == trailing_headers.end()) { bearer_iter = initial_headers.find(kAuthHeader); if (bearer_iter == initial_headers.end()) { - return; + return std::make_pair("", ""); } } @@ -79,9 +79,12 @@ void GetBearerTokenHeader(grpc::ClientContext& context, if (bearer_val.size() > strlen(kBearerPrefix)) { if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), kBearerPrefix, char_compare)) { - *bearer_token = std::make_pair(kAuthHeader, bearer_val); + return std::make_pair(kAuthHeader, bearer_val); } } + + // The server is not required to provide a bearer token. + return std::make_pair("", ""); } } // namespace internal diff --git a/cpp/src/arrow/flight/client_header_internal.h b/cpp/src/arrow/flight/client_header_internal.h index 3c7e993980dfb..718848a5ffd46 100644 --- a/cpp/src/arrow/flight/client_header_internal.h +++ b/cpp/src/arrow/flight/client_header_internal.h @@ -21,6 +21,7 @@ #pragma once #include "arrow/flight/client_middleware.h" +#include "arrow/result.h" #ifdef GRPCPP_PP_INCLUDE #include @@ -47,10 +48,9 @@ void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, /// \brief Get bearer token from incoming headers. /// /// \param context context that contains headers which hold the bearer token. -/// \param[out] bearer_token pointer to a std::pair of std::strings that the factory -/// will populate with the bearer token that is received from the server. -void ARROW_FLIGHT_EXPORT GetBearerTokenHeader( - grpc::ClientContext& context, std::pair* bearer_token); +/// \return Bearer token, parsed from headers, empty if one is not present. +arrow::Result> ARROW_FLIGHT_EXPORT +GetBearerTokenHeader(grpc::ClientContext& context); } // namespace internal } // namespace flight diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index f0ec94dbba8a1..f247059ef8ddd 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -834,30 +834,44 @@ std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers, return ""; } -// A server middleware for validating incoming base64 header authentication. class HeaderAuthServerMiddleware : public ServerMiddleware { public: - explicit HeaderAuthServerMiddleware(const CallHeaders& incoming_headers) { - incoming_headers_ = incoming_headers; - } + explicit HeaderAuthServerMiddleware() {} void SendingHeaders(AddCallHeaders* outgoing_headers) override { - std::string encoded_credentials = - FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBasicPrefix); - std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials)); - std::string username, password; - std::getline(decoded_stream, username, ':'); - std::getline(decoded_stream, password, ':'); - if ((username == kValidUsername) && (password == kValidPassword)) { - outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); - } + outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); } void CallCompleted(const Status& status) override {} std::string name() const override { return "HeaderAuthServerMiddleware"; } +}; - CallHeaders incoming_headers_; +void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username, + std::string& password) { + std::string encoded_credentials = + FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix); + std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials)); + std::getline(decoded_stream, username, ':'); + std::getline(decoded_stream, password, ':'); +} + +// Factory for base64 header authentication testing. +class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + HeaderAuthServerMiddlewareFactory() {} + + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + std::string username, password; + ParseBasicHeader(incoming_headers, username, password); + if ((username == kValidUsername) && (password == kValidPassword)) { + *middleware = std::make_shared(); + } else if ((username == kInvalidUsername) && (password == kInvalidPassword)) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials"); + } + return Status::OK(); + } }; // A server middleware for validating incoming bearer header authentication. @@ -883,22 +897,6 @@ class BearerAuthServerMiddleware : public ServerMiddleware { bool* isValid_; }; -// Factory for base64 header authentication testing. -class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { - public: - HeaderAuthServerMiddlewareFactory() {} - - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, - std::shared_ptr* middleware) override { - const std::pair& iter_pair = - incoming_headers.equal_range(kAuthHeader); - if (iter_pair.first != iter_pair.second) { - *middleware = std::make_shared(incoming_headers); - } - return Status::OK(); - } -}; - // Factory for base64 header authentication testing. class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { public: @@ -1160,26 +1158,25 @@ class TestBasicHeaderAuthMiddleware : public ::testing::Test { } void RunValidClientAuth() { - std::pair bearer_token; - ASSERT_OK(client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword, - &bearer_token)); - ASSERT_EQ(bearer_token.first, kAuthHeader); - ASSERT_EQ(bearer_token.second, (std::string(kBearerPrefix) + kBearerToken)); + arrow::Result> bearer_result = + client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword); + ASSERT_OK(bearer_result.status()); + ASSERT_EQ(bearer_result.ValueOrDie().first, kAuthHeader); + ASSERT_EQ(bearer_result.ValueOrDie().second, + (std::string(kBearerPrefix) + kBearerToken)); std::unique_ptr listing; FlightCallOptions call_options; - call_options.headers.push_back(bearer_token); + call_options.headers.push_back(bearer_result.ValueOrDie()); ASSERT_OK(client_->ListFlights(call_options, {}, &listing)); ASSERT_TRUE(bearer_middleware_->GetIsValid()); } void RunInvalidClientAuth() { - std::pair bearer_token; - // Note: Status intentionally ignored because it requires C++ server implementation of - // header auth. For now it returns an IOError. - arrow::Status status = client_->AuthenticateBasicToken( - {}, kInvalidUsername, kInvalidPassword, &bearer_token); - ASSERT_EQ(bearer_token.first, std::string("")); - ASSERT_EQ(bearer_token.second, std::string("")); + arrow::Result> bearer_result = + client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword); + ASSERT_RAISES(IOError, bearer_result.status()); + ASSERT_THAT(bearer_result.status().message(), + ::testing::HasSubstr("Invalid credentials")); } void TearDown() { ASSERT_OK(server_->Shutdown()); } From 6cd8a45cfd95771243a16be2cd53aaa5d69d94aa Mon Sep 17 00:00:00 2001 From: Lyndon Bauto Date: Wed, 25 Nov 2020 11:09:29 -0800 Subject: [PATCH 31/31] [1] Fixed linting issue. --- cpp/src/arrow/flight/flight_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index f247059ef8ddd..2868f84e7c91f 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -836,8 +836,6 @@ std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers, class HeaderAuthServerMiddleware : public ServerMiddleware { public: - explicit HeaderAuthServerMiddleware() {} - void SendingHeaders(AddCallHeaders* outgoing_headers) override { outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); }