diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 61fa6e9d0c4..10858552500 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -223,7 +223,7 @@ class ARROW_FLIGHT_EXPORT FlightClient { /// \param[in] username Username to use /// \param[in] password Password to use /// \return Arrow result with bearer token and status OK if client authenticated - /// sucessfully + /// successfully arrow::Result> AuthenticateBasicToken( const FlightCallOptions& options, const std::string& username, const std::string& password); diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 5520dfc48f7..a2b69494b8e 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -214,6 +214,30 @@ TEST(TestFlight, DISABLED_IpV6Port) { ASSERT_OK(client->ListFlights()); } +TEST(TestFlight, ServerCallContextIncomingHeaders) { + auto server = ExampleTestServer(); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); + + ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server->location())); + Action action; + action.type = "list-incoming-headers"; + action.body = Buffer::FromString("test-header"); + FlightCallOptions call_options; + call_options.headers.emplace_back("test-header1", "value1"); + call_options.headers.emplace_back("test-header2", "value2"); + ASSERT_OK_AND_ASSIGN(auto stream, client->DoAction(call_options, action)); + ASSERT_OK_AND_ASSIGN(auto result, stream->Next()); + ASSERT_NE(result.get(), nullptr); + ASSERT_EQ(result->body->ToString(), "test-header1: value1"); + ASSERT_OK_AND_ASSIGN(result, stream->Next()); + ASSERT_NE(result.get(), nullptr); + ASSERT_EQ(result->body->ToString(), "test-header2: value2"); + ASSERT_OK_AND_ASSIGN(result, stream->Next()); + ASSERT_EQ(result.get(), nullptr); +} + // ---------------------------------------------------------------------- // Client tests diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h index dc1ad24bc5c..e936b9f0202 100644 --- a/cpp/src/arrow/flight/middleware.h +++ b/cpp/src/arrow/flight/middleware.h @@ -20,23 +20,17 @@ #pragma once -#include #include #include #include #include -#include "arrow/flight/visibility.h" // IWYU pragma: keep +#include "arrow/flight/types.h" #include "arrow/status.h" namespace arrow { namespace flight { -/// \brief Headers sent from the client or server. -/// -/// Header values are ordered. -using CallHeaders = std::multimap; - /// \brief A write-only wrapper around headers for an RPC call. class ARROW_FLIGHT_EXPORT AddCallHeaders { public: diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 1d1b1a50f37..6fb8ab12131 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -137,6 +137,8 @@ class ARROW_FLIGHT_EXPORT ServerCallContext { /// \brief Check if the current RPC has been cancelled (by the client, by /// a network error, etc.). virtual bool is_cancelled() const = 0; + /// \brief The headers sent by the client for this call. + virtual const CallHeaders& incoming_headers() const = 0; }; class ARROW_FLIGHT_EXPORT FlightServerOptions { diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 0d6c28b2968..7430a9b7dea 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -474,12 +474,31 @@ class FlightTestServer : public FlightServerBase { return Status::OK(); } + Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) { + std::vector results; + std::string_view prefix(*action.body); + for (const auto& header : context.incoming_headers()) { + if (header.first.substr(0, prefix.size()) != prefix) { + continue; + } + Result result; + result.body = Buffer::FromString(std::string(header.first) + ": " + + std::string(header.second)); + results.push_back(result); + } + *out = std::make_unique(std::move(results)); + return Status::OK(); + } + Status DoAction(const ServerCallContext& context, const Action& action, std::unique_ptr* out) override { if (action.type == "action1") { return RunAction1(action, out); } else if (action.type == "action2") { return RunAction2(out); + } else if (action.type == "list-incoming-headers") { + return ListIncomingHeaders(context, action, out); } else { return Status::NotImplemented(action.type); } diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index a643111e3b2..acf80462f1a 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -117,11 +117,18 @@ class GrpcServerAuthSender : public ServerAuthSender { class GrpcServerCallContext : public ServerCallContext { explicit GrpcServerCallContext(::grpc::ServerContext* context) - : context_(context), peer_(context_->peer()) {} + : context_(context), peer_(context_->peer()) { + for (const auto& entry : context->client_metadata()) { + incoming_headers_.insert( + {std::string_view(entry.first.data(), entry.first.length()), + std::string_view(entry.second.data(), entry.second.length())}); + } + } const std::string& peer_identity() const override { return peer_identity_; } const std::string& peer() const override { return peer_; } bool is_cancelled() const override { return context_->IsCancelled(); } + const CallHeaders& incoming_headers() const override { return incoming_headers_; } // Helper method that runs interceptors given the result of an RPC, // then returns the final gRPC status to send to the client @@ -156,6 +163,7 @@ class GrpcServerCallContext : public ServerCallContext { std::string peer_identity_; std::vector> middleware_; std::unordered_map> middleware_map_; + CallHeaders incoming_headers_; }; class GrpcAddServerHeaders : public AddCallHeaders { @@ -310,17 +318,12 @@ class GrpcServiceHandler final : public FlightService::Service { GrpcServerCallContext& flight_context) { // Run server middleware const CallInfo info{method}; - CallHeaders incoming_headers; - for (const auto& entry : context->client_metadata()) { - incoming_headers.insert( - {std::string_view(entry.first.data(), entry.first.length()), - std::string_view(entry.second.data(), entry.second.length())}); - } GrpcAddServerHeaders outgoing_headers(context); for (const auto& factory : middleware_) { std::shared_ptr instance; - Status result = factory.second->StartCall(info, incoming_headers, &instance); + Status result = + factory.second->StartCall(info, flight_context.incoming_headers(), &instance); if (!result.ok()) { // Interceptor rejected call, end the request on all existing // interceptors diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index 946b29383bf..4a573d74292 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -76,9 +76,11 @@ class UcxServerCallContext : public flight::ServerCallContext { return nullptr; } bool is_cancelled() const override { return false; } + const CallHeaders& incoming_headers() const override { return incoming_headers_; } private: std::string peer_; + CallHeaders incoming_headers_; }; class UcxServerStream : public internal::ServerDataStream { diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 39353bcb997..9d92f0be955 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -123,6 +124,11 @@ ARROW_FLIGHT_EXPORT Status MakeFlightError(FlightStatusCode code, std::string message, std::string extra_info = {}); +/// \brief Headers sent from the client or server. +/// +/// Header values are ordered. +using CallHeaders = std::multimap; + /// \brief A TLS certificate plus key. struct ARROW_FLIGHT_EXPORT CertKeyPair { /// \brief The certificate in PEM format.