diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index f0c23efb69f8e..86e3c510ebbf4 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -118,6 +118,7 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") # protobuf-internal.cc set(ARROW_FLIGHT_SRCS client.cc + client_header_internal.cc internal.cc protocol_internal.cc serialization_internal.cc diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index cdffa7f06567a..5c56e6409a75b 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -50,6 +50,7 @@ #include "arrow/util/uri.h" #include "arrow/flight/client_auth.h" +#include "arrow/flight/client_header_internal.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware.h" @@ -104,6 +105,9 @@ struct ClientRpc { std::chrono::system_clock::now() + options.timeout); context.set_deadline(deadline); } + for (auto header : options.headers) { + context.AddMetadata(header.first, header.second); + } } /// \brief Add an auth token via an auth handler @@ -994,6 +998,30 @@ class FlightClient::FlightClientImpl { return Status::OK(); } + arrow::Result> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password) { + // Add basic auth headers to outgoing headers. + ClientRpc rpc(options); + internal::AddBasicAuthHeaders(&rpc.context, username, password); + + std::shared_ptr> + stream = stub_->Handshake(&rpc.context); + GrpcClientAuthSender outgoing{stream}; + GrpcClientAuthReader incoming{stream}; + + // Explicitly close our side of the connection. + bool finished_writes = stream->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); + if (!finished_writes) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not finish writing before closing"); + } + + // Grab bearer token from incoming headers. + return internal::GetBearerTokenHeader(rpc.context); + } + Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, std::unique_ptr* listing) { pb::Criteria pb_criteria; @@ -1198,6 +1226,12 @@ Status FlightClient::Authenticate(const FlightCallOptions& options, return impl_->Authenticate(options, std::move(auth_handler)); } +arrow::Result> FlightClient::AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password) { + return impl_->AuthenticateBasicToken(options, username, password); +} + Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, std::unique_ptr* results) { return impl_->DoAction(options, action, results); diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 935e8fb92ba43..441f11467669e 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -29,6 +29,7 @@ #include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" +#include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/variant.h" @@ -65,6 +66,9 @@ class ARROW_FLIGHT_EXPORT FlightCallOptions { /// \brief IPC writer options, if applicable for the call. ipc::IpcWriteOptions write_options; + + /// \brief Headers for client to add to context. + std::vector> headers; }; /// \brief Indicate that the client attempted to write a message @@ -191,6 +195,16 @@ class ARROW_FLIGHT_EXPORT FlightClient { Status Authenticate(const FlightCallOptions& options, std::unique_ptr auth_handler); + /// \brief Authenticate to the server using basic HTTP style authentication. + /// \param[in] options Per-RPC options + /// \param[in] username Username to use + /// \param[in] password Password to use + /// \return Arrow result with bearer token and status OK if client authenticated + /// sucessfully + arrow::Result> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password); + /// \brief Perform the indicated action, returning an iterator to the stream /// of results, if any /// \param[in] options Per-RPC options diff --git a/cpp/src/arrow/flight/client_header_internal.cc b/cpp/src/arrow/flight/client_header_internal.cc new file mode 100644 index 0000000000000..2112b41f72f33 --- /dev/null +++ b/cpp/src/arrow/flight/client_header_internal.cc @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces for defining middleware for Flight clients. Currently +// experimental. + +#include "arrow/flight/client_header_internal.h" +#include "arrow/flight/client.h" +#include "arrow/flight/client_auth.h" +#include "arrow/util/base64.h" +#include "arrow/util/make_unique.h" + +#include +#include +#include +#include + +const char kAuthHeader[] = "authorization"; +const char kBearerPrefix[] = "Bearer "; +const char kBasicPrefix[] = "Basic "; + +namespace arrow { +namespace flight { +namespace internal { + +/// \brief Add base64 encoded credentials to the outbound headers. +/// +/// \param context Context object to add the headers to. +/// \param username Username to format and encode. +/// \param password Password to format and encode. +void AddBasicAuthHeaders(grpc::ClientContext* context, const std::string& username, + const std::string& password) { + const std::string credentials = username + ":" + password; + context->AddMetadata( + kAuthHeader, + kBasicPrefix + arrow::util::base64_encode( + reinterpret_cast(credentials.c_str()), + static_cast(credentials.size()))); +} + +/// \brief Get bearer token from inbound headers. +/// +/// \param context Incoming ClientContext that contains headers. +/// \return Arrow result with bearer token (empty if no bearer token found). +arrow::Result> GetBearerTokenHeader( + grpc::ClientContext& context) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + // Get the auth token if it exists, this can be in the initial or the trailing metadata. + auto trailing_headers = context.GetServerTrailingMetadata(); + auto initial_headers = context.GetServerInitialMetadata(); + auto bearer_iter = trailing_headers.find(kAuthHeader); + if (bearer_iter == trailing_headers.end()) { + bearer_iter = initial_headers.find(kAuthHeader); + if (bearer_iter == initial_headers.end()) { + return std::make_pair("", ""); + } + } + + // Check if the value of the auth token starts with the bearer prefix and latch it. + std::string bearer_val(bearer_iter->second.data(), bearer_iter->second.size()); + if (bearer_val.size() > strlen(kBearerPrefix)) { + if (std::equal(bearer_val.begin(), bearer_val.begin() + strlen(kBearerPrefix), + kBearerPrefix, char_compare)) { + return std::make_pair(kAuthHeader, bearer_val); + } + } + + // The server is not required to provide a bearer token. + return std::make_pair("", ""); +} + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/client_header_internal.h b/cpp/src/arrow/flight/client_header_internal.h new file mode 100644 index 0000000000000..718848a5ffd46 --- /dev/null +++ b/cpp/src/arrow/flight/client_header_internal.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Interfaces for defining middleware for Flight clients. Currently +// experimental. + +#pragma once + +#include "arrow/flight/client_middleware.h" +#include "arrow/result.h" + +#ifdef GRPCPP_PP_INCLUDE +#include +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) +#include +#endif +#else +#include +#endif + +namespace arrow { +namespace flight { +namespace internal { + +/// \brief Add basic authentication header key value pair to context. +/// +/// \param context grpc context variable to add header to. +/// \param username username to encode into header. +/// \param password password to to encode into header. +void ARROW_FLIGHT_EXPORT AddBasicAuthHeaders(grpc::ClientContext* context, + const std::string& username, + const std::string& password); + +/// \brief Get bearer token from incoming headers. +/// +/// \param context context that contains headers which hold the bearer token. +/// \return Bearer token, parsed from headers, empty if one is not present. +arrow::Result> ARROW_FLIGHT_EXPORT +GetBearerTokenHeader(grpc::ClientContext& context); + +} // namespace internal +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 95048f6684cee..2868f84e7c91f 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -36,6 +36,7 @@ #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/base64.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" @@ -43,6 +44,7 @@ #error "gRPC headers should not be in public API" #endif +#include "arrow/flight/client_header_internal.h" #include "arrow/flight/internal.h" #include "arrow/flight/middleware_internal.h" #include "arrow/flight/test_util.h" @@ -52,6 +54,15 @@ namespace pb = arrow::flight::protocol; namespace arrow { namespace flight { +const char kValidUsername[] = "flight_username"; +const char kValidPassword[] = "flight_password"; +const char kInvalidUsername[] = "invalid_flight_username"; +const char kInvalidPassword[] = "invalid_flight_password"; +const char kBearerToken[] = "bearertoken"; +const char kBasicPrefix[] = "Basic "; +const char kBearerPrefix[] = "Bearer "; +const char kAuthHeader[] = "authorization"; + void AssertEqual(const ActionType& expected, const ActionType& actual) { ASSERT_EQ(expected.type, actual.type); ASSERT_EQ(expected.description, actual.description); @@ -559,6 +570,14 @@ class OptionsTestServer : public FlightServerBase { } }; +class HeaderAuthTestServer : public FlightServerBase { + public: + Status ListFlights(const ServerCallContext& context, const Criteria* criteria, + std::unique_ptr* listings) override { + return Status::OK(); + } +}; + class TestMetadata : public ::testing::Test { public: void SetUp() { @@ -791,6 +810,113 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { } }; +// Function to look in CallHeaders for a key that has a value starting with prefix and +// return the rest of the value after the prefix. +std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers, + const std::string& key, + const std::string& prefix) { + // Lambda function to compare characters without case sensitivity. + auto char_compare = [](const char& char1, const char& char2) { + return (::toupper(char1) == ::toupper(char2)); + }; + + auto iter = incoming_headers.find(key); + if (iter == incoming_headers.end()) { + return ""; + } + const std::string val = iter->second.to_string(); + if (val.size() > prefix.length()) { + if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), + char_compare)) { + return val.substr(prefix.length()); + } + } + return ""; +} + +class HeaderAuthServerMiddleware : public ServerMiddleware { + public: + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + outgoing_headers->AddHeader(kAuthHeader, std::string(kBearerPrefix) + kBearerToken); + } + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "HeaderAuthServerMiddleware"; } +}; + +void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username, + std::string& password) { + std::string encoded_credentials = + FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix); + std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials)); + std::getline(decoded_stream, username, ':'); + std::getline(decoded_stream, password, ':'); +} + +// Factory for base64 header authentication testing. +class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + HeaderAuthServerMiddlewareFactory() {} + + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + std::string username, password; + ParseBasicHeader(incoming_headers, username, password); + if ((username == kValidUsername) && (password == kValidPassword)) { + *middleware = std::make_shared(); + } else if ((username == kInvalidUsername) && (password == kInvalidPassword)) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid credentials"); + } + return Status::OK(); + } +}; + +// A server middleware for validating incoming bearer header authentication. +class BearerAuthServerMiddleware : public ServerMiddleware { + public: + explicit BearerAuthServerMiddleware(const CallHeaders& incoming_headers, bool* isValid) + : isValid_(isValid) { + incoming_headers_ = incoming_headers; + } + + void SendingHeaders(AddCallHeaders* outgoing_headers) override { + std::string bearer_token = + FindKeyValPrefixInCallHeaders(incoming_headers_, kAuthHeader, kBearerPrefix); + *isValid_ = (bearer_token == std::string(kBearerToken)); + } + + void CallCompleted(const Status& status) override {} + + std::string name() const override { return "BearerAuthServerMiddleware"; } + + private: + CallHeaders incoming_headers_; + bool* isValid_; +}; + +// Factory for base64 header authentication testing. +class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { + public: + BearerAuthServerMiddlewareFactory() : isValid_(false) {} + + Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + std::shared_ptr* middleware) override { + const std::pair& iter_pair = + incoming_headers.equal_range(kAuthHeader); + if (iter_pair.first != iter_pair.second) { + *middleware = + std::make_shared(incoming_headers, &isValid_); + } + return Status::OK(); + } + + bool GetIsValid() { return isValid_; } + + private: + bool isValid_; +}; + // A client middleware that adds a thread-local "request ID" to // outgoing calls as a header, and keeps track of the status of // completed calls. NOT thread-safe. @@ -1010,6 +1136,56 @@ class TestErrorMiddleware : public ::testing::Test { std::unique_ptr server_; }; +class TestBasicHeaderAuthMiddleware : public ::testing::Test { + public: + void SetUp() { + header_middleware_ = std::make_shared(); + bearer_middleware_ = std::make_shared(); + std::pair bearer = make_pair( + kAuthHeader, std::string(kBearerPrefix) + " " + std::string(kBearerToken)); + ASSERT_OK(MakeServer( + &server_, &client_, + [&](FlightServerOptions* options) { + options->auth_handler = + std::unique_ptr(new NoOpAuthHandler()); + options->middleware.push_back({"header-auth-server", header_middleware_}); + options->middleware.push_back({"bearer-auth-server", bearer_middleware_}); + return Status::OK(); + }, + [&](FlightClientOptions* options) { return Status::OK(); })); + } + + void RunValidClientAuth() { + arrow::Result> bearer_result = + client_->AuthenticateBasicToken({}, kValidUsername, kValidPassword); + ASSERT_OK(bearer_result.status()); + ASSERT_EQ(bearer_result.ValueOrDie().first, kAuthHeader); + ASSERT_EQ(bearer_result.ValueOrDie().second, + (std::string(kBearerPrefix) + kBearerToken)); + std::unique_ptr listing; + FlightCallOptions call_options; + call_options.headers.push_back(bearer_result.ValueOrDie()); + ASSERT_OK(client_->ListFlights(call_options, {}, &listing)); + ASSERT_TRUE(bearer_middleware_->GetIsValid()); + } + + void RunInvalidClientAuth() { + arrow::Result> bearer_result = + client_->AuthenticateBasicToken({}, kInvalidUsername, kInvalidPassword); + ASSERT_RAISES(IOError, bearer_result.status()); + ASSERT_THAT(bearer_result.status().message(), + ::testing::HasSubstr("Invalid credentials")); + } + + void TearDown() { ASSERT_OK(server_->Shutdown()); } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; + std::shared_ptr header_middleware_; + std::shared_ptr bearer_middleware_; +}; + TEST_F(TestErrorMiddleware, TestMetadata) { Action action; std::unique_ptr stream; @@ -2193,5 +2369,9 @@ TEST_F(TestPropagatingMiddleware, DoPut) { ValidateStatus(status, FlightMethod::DoPut); } +TEST_F(TestBasicHeaderAuthMiddleware, ValidCredentials) { RunValidClientAuth(); } + +TEST_F(TestBasicHeaderAuthMiddleware, InvalidCredentials) { RunInvalidClientAuth(); } + } // namespace flight } // namespace arrow