diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 2a88e5f8ec7..917c0c33211 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -199,6 +199,7 @@ set(ARROW_FLIGHT_SRCS serialization_internal.cc server.cc server_auth.cc + server_middleware.cc server_tracing_middleware.cc transport.cc transport_server.cc diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 09e9d8c1561..d56dc81e356 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -452,7 +452,7 @@ class TestTls : public ::testing::Test { // A server middleware that rejects all calls. class RejectServerMiddlewareFactory : public ServerMiddlewareFactory { - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are rejected"); } @@ -484,7 +484,7 @@ class CountingServerMiddlewareFactory : public ServerMiddlewareFactory { public: CountingServerMiddlewareFactory() : successful_(0), failed_(0) {} - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { *middleware = std::make_shared(&successful_, &failed_); return Status::OK(); @@ -517,10 +517,10 @@ class TracingTestServerMiddlewareFactory : public ServerMiddlewareFactory { public: TracingTestServerMiddlewareFactory() {} - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { const std::pair& iter_pair = - incoming_headers.equal_range("x-tracing-span-id"); + context.incoming_headers().equal_range("x-tracing-span-id"); if (iter_pair.first != iter_pair.second) { const std::string_view& value = (*iter_pair.first).second; *middleware = std::make_shared(std::string(value)); @@ -578,10 +578,10 @@ class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory { public: HeaderAuthServerMiddlewareFactory() {} - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { std::string username, password; - ParseBasicHeader(incoming_headers, username, password); + ParseBasicHeader(context.incoming_headers(), username, password); if ((username == kValidUsername) && (password == kValidPassword)) { *middleware = std::make_shared(); } else if ((username == kInvalidUsername) && (password == kInvalidPassword)) { @@ -619,13 +619,13 @@ class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory { public: BearerAuthServerMiddlewareFactory() : isValid_(false) {} - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { const std::pair& iter_pair = - incoming_headers.equal_range(kAuthHeader); + context.incoming_headers().equal_range(kAuthHeader); if (iter_pair.first != iter_pair.second) { - *middleware = - std::make_shared(incoming_headers, &isValid_); + *middleware = std::make_shared( + context.incoming_headers(), &isValid_); } return Status::OK(); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index f6af1429785..9a300d1bd25 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -143,10 +143,10 @@ class TestServerMiddleware : public ServerMiddleware { class TestServerMiddlewareFactory : public ServerMiddlewareFactory { public: - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { const std::pair& iter_pair = - incoming_headers.equal_range("x-middleware"); + context.incoming_headers().equal_range("x-middleware"); std::string received = ""; if (iter_pair.first != iter_pair.second) { const std::string_view& value = (*iter_pair.first).second; diff --git a/cpp/src/arrow/flight/server_auth.h b/cpp/src/arrow/flight/server_auth.h index 3d4787c0c71..93d3352ba20 100644 --- a/cpp/src/arrow/flight/server_auth.h +++ b/cpp/src/arrow/flight/server_auth.h @@ -21,6 +21,7 @@ #include +#include "arrow/flight/type_fwd.h" #include "arrow/flight/visibility.h" #include "arrow/status.h" @@ -28,8 +29,6 @@ namespace arrow { namespace flight { -class ServerCallContext; - /// \brief A reader for messages from the client during an /// authentication handshake. class ARROW_FLIGHT_EXPORT ServerAuthReader { diff --git a/cpp/src/arrow/flight/server_middleware.cc b/cpp/src/arrow/flight/server_middleware.cc new file mode 100644 index 00000000000..d7ace580dc6 --- /dev/null +++ b/cpp/src/arrow/flight/server_middleware.cc @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/server_middleware.h" +#include "arrow/flight/server.h" + +namespace arrow { +namespace flight { + +Status ServerMiddlewareFactory::StartCall(const CallInfo& info, + const ServerCallContext& context, + std::shared_ptr* middleware) { + // TODO: We can make this pure virtual function when we remove + // the deprecated version. + ARROW_SUPPRESS_DEPRECATION_WARNING + return StartCall(info, context.incoming_headers(), middleware); + ARROW_UNSUPPRESS_DEPRECATION_WARNING +} + +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/server_middleware.h b/cpp/src/arrow/flight/server_middleware.h index 26431aff01d..030f1a17c21 100644 --- a/cpp/src/arrow/flight/server_middleware.h +++ b/cpp/src/arrow/flight/server_middleware.h @@ -24,6 +24,7 @@ #include #include "arrow/flight/middleware.h" +#include "arrow/flight/type_fwd.h" #include "arrow/flight/visibility.h" // IWYU pragma: keep #include "arrow/status.h" @@ -61,6 +62,22 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory { public: virtual ~ServerMiddlewareFactory() = default; + /// \brief A callback for the start of a new call. + /// + /// Return a non-OK status to reject the call with the given status. + /// + /// \param[in] info Information about the call. + /// \param[in] context The call context. + /// \param[out] middleware The middleware instance for this call. If + /// null, no middleware will be added to this call instance from + /// this factory. + /// \return Status A non-OK status will reject the call with the + /// given status. Middleware previously in the chain will have + /// their CallCompleted callback called. Other middleware + /// factories will not be called. + virtual Status StartCall(const CallInfo& info, const ServerCallContext& context, + std::shared_ptr* middleware); + /// \brief A callback for the start of a new call. /// /// Return a non-OK status to reject the call with the given status. @@ -75,8 +92,13 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory { /// given status. Middleware previously in the chain will have /// their CallCompleted callback called. Other middleware /// factories will not be called. + /// \deprecated Deprecated in 13.0.0. Implement the StartCall() + /// with ServerCallContext version instead. + ARROW_DEPRECATED("Deprecated in 13.0.0. Use ServerCallContext overload instead.") virtual Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, - std::shared_ptr* middleware) = 0; + std::shared_ptr* middleware) { + return Status::NotImplemented(typeid(this).name(), "::StartCall() isn't implemented"); + } }; } // namespace flight diff --git a/cpp/src/arrow/flight/server_tracing_middleware.cc b/cpp/src/arrow/flight/server_tracing_middleware.cc index 6587db1d1dc..b5326d88a43 100644 --- a/cpp/src/arrow/flight/server_tracing_middleware.cc +++ b/cpp/src/arrow/flight/server_tracing_middleware.cc @@ -16,6 +16,7 @@ // under the License. #include "arrow/flight/server_tracing_middleware.h" +#include "arrow/flight/server.h" #include #include @@ -122,19 +123,19 @@ class TracingServerMiddleware::Impl { class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { public: virtual ~TracingServerMiddlewareFactory() = default; - Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers, + Status StartCall(const CallInfo& info, const ServerCallContext& context, std::shared_ptr* middleware) override { constexpr char kServiceName[] = "arrow.flight.protocol.FlightService"; - FlightServerCarrier carrier(incoming_headers); - auto context = otel::context::RuntimeContext::GetCurrent(); + FlightServerCarrier carrier(context.incoming_headers()); + auto otel_context = otel::context::RuntimeContext::GetCurrent(); auto propagator = otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator(); - auto new_context = propagator->Extract(carrier, context); + auto new_otel_context = propagator->Extract(carrier, otel_context); otel::trace::StartSpanOptions options; options.kind = otel::trace::SpanKind::kServer; - options.parent = otel::trace::GetSpan(new_context)->GetContext(); + options.parent = otel::trace::GetSpan(new_otel_context)->GetContext(); auto* tracer = arrow::internal::tracing::GetTracer(); auto method_name = ToString(info.method); @@ -167,7 +168,7 @@ class TracingServerMiddleware::Impl { class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { public: virtual ~TracingServerMiddlewareFactory() = default; - Status StartCall(const CallInfo&, const CallHeaders&, + Status StartCall(const CallInfo&, const ServerCallContext&, std::shared_ptr* middleware) override { std::unique_ptr impl( new TracingServerMiddleware::Impl()); diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index 09d702cd841..dcf9c3f8c9f 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -323,8 +323,7 @@ class GrpcServiceHandler final : public FlightService::Service { GrpcAddServerHeaders outgoing_headers(context); for (const auto& factory : middleware_) { std::shared_ptr instance; - Status result = - factory.second->StartCall(info, flight_context.incoming_headers(), &instance); + Status result = factory.second->StartCall(info, flight_context, &instance); if (!result.ok()) { // Interceptor rejected call, end the request on all existing // interceptors