Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/flight/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ set(ARROW_FLIGHT_SRCS
serialization_internal.cc
server.cc
server_auth.cc
server_middleware.cc
server_tracing_middleware.cc
transport.cc
transport_server.cc
Expand Down
20 changes: 10 additions & 10 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ class TestTls : public ::testing::Test {

// A server middleware that rejects all calls.
class RejectServerMiddlewareFactory : public ServerMiddlewareFactory {
Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
return MakeFlightError(FlightStatusCode::Unauthenticated, "All calls are rejected");
}
Expand Down Expand Up @@ -484,7 +484,7 @@ class CountingServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
CountingServerMiddlewareFactory() : successful_(0), failed_(0) {}

Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
*middleware = std::make_shared<CountingServerMiddleware>(&successful_, &failed_);
return Status::OK();
Expand Down Expand Up @@ -517,10 +517,10 @@ class TracingTestServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
TracingTestServerMiddlewareFactory() {}

Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
incoming_headers.equal_range("x-tracing-span-id");
context.incoming_headers().equal_range("x-tracing-span-id");
if (iter_pair.first != iter_pair.second) {
const std::string_view& value = (*iter_pair.first).second;
*middleware = std::make_shared<TracingTestServerMiddleware>(std::string(value));
Expand Down Expand Up @@ -578,10 +578,10 @@ class HeaderAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
HeaderAuthServerMiddlewareFactory() {}

Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
std::string username, password;
ParseBasicHeader(incoming_headers, username, password);
ParseBasicHeader(context.incoming_headers(), username, password);
if ((username == kValidUsername) && (password == kValidPassword)) {
*middleware = std::make_shared<HeaderAuthServerMiddleware>();
} else if ((username == kInvalidUsername) && (password == kInvalidPassword)) {
Expand Down Expand Up @@ -619,13 +619,13 @@ class BearerAuthServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
BearerAuthServerMiddlewareFactory() : isValid_(false) {}

Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
incoming_headers.equal_range(kAuthHeader);
context.incoming_headers().equal_range(kAuthHeader);
if (iter_pair.first != iter_pair.second) {
*middleware =
std::make_shared<BearerAuthServerMiddleware>(incoming_headers, &isValid_);
*middleware = std::make_shared<BearerAuthServerMiddleware>(
context.incoming_headers(), &isValid_);
}
return Status::OK();
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,10 @@ class TestServerMiddleware : public ServerMiddleware {

class TestServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>& iter_pair =
incoming_headers.equal_range("x-middleware");
context.incoming_headers().equal_range("x-middleware");
std::string received = "";
if (iter_pair.first != iter_pair.second) {
const std::string_view& value = (*iter_pair.first).second;
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/arrow/flight/server_auth.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@

#include <string>

#include "arrow/flight/type_fwd.h"
#include "arrow/flight/visibility.h"
#include "arrow/status.h"

namespace arrow {

namespace flight {

class ServerCallContext;

/// \brief A reader for messages from the client during an
/// authentication handshake.
class ARROW_FLIGHT_EXPORT ServerAuthReader {
Expand Down
35 changes: 35 additions & 0 deletions cpp/src/arrow/flight/server_middleware.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/flight/server_middleware.h"
#include "arrow/flight/server.h"

namespace arrow {
namespace flight {

Status ServerMiddlewareFactory::StartCall(const CallInfo& info,
const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) {
// TODO: We can make this pure virtual function when we remove
// the deprecated version.
ARROW_SUPPRESS_DEPRECATION_WARNING
return StartCall(info, context.incoming_headers(), middleware);
ARROW_UNSUPPRESS_DEPRECATION_WARNING
}

} // namespace flight
} // namespace arrow
24 changes: 23 additions & 1 deletion cpp/src/arrow/flight/server_middleware.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <string>

#include "arrow/flight/middleware.h"
#include "arrow/flight/type_fwd.h"
#include "arrow/flight/visibility.h" // IWYU pragma: keep
#include "arrow/status.h"

Expand Down Expand Up @@ -61,6 +62,22 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
public:
virtual ~ServerMiddlewareFactory() = default;

/// \brief A callback for the start of a new call.
///
/// Return a non-OK status to reject the call with the given status.
///
/// \param[in] info Information about the call.
/// \param[in] context The call context.
/// \param[out] middleware The middleware instance for this call. If
/// null, no middleware will be added to this call instance from
/// this factory.
/// \return Status A non-OK status will reject the call with the
/// given status. Middleware previously in the chain will have
/// their CallCompleted callback called. Other middleware
/// factories will not be called.
virtual Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware);

/// \brief A callback for the start of a new call.
///
/// Return a non-OK status to reject the call with the given status.
Expand All @@ -75,8 +92,13 @@ class ARROW_FLIGHT_EXPORT ServerMiddlewareFactory {
/// given status. Middleware previously in the chain will have
/// their CallCompleted callback called. Other middleware
/// factories will not be called.
/// \deprecated Deprecated in 13.0.0. Implement the StartCall()
/// with ServerCallContext version instead.
ARROW_DEPRECATED("Deprecated in 13.0.0. Use ServerCallContext overload instead.")
virtual Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
std::shared_ptr<ServerMiddleware>* middleware) = 0;
std::shared_ptr<ServerMiddleware>* middleware) {
return Status::NotImplemented(typeid(this).name(), "::StartCall() isn't implemented");
}
};

} // namespace flight
Expand Down
13 changes: 7 additions & 6 deletions cpp/src/arrow/flight/server_tracing_middleware.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

#include "arrow/flight/server_tracing_middleware.h"
#include "arrow/flight/server.h"

#include <string>
#include <string_view>
Expand Down Expand Up @@ -122,19 +123,19 @@ class TracingServerMiddleware::Impl {
class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
virtual ~TracingServerMiddlewareFactory() = default;
Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
Status StartCall(const CallInfo& info, const ServerCallContext& context,
std::shared_ptr<ServerMiddleware>* middleware) override {
constexpr char kServiceName[] = "arrow.flight.protocol.FlightService";

FlightServerCarrier carrier(incoming_headers);
auto context = otel::context::RuntimeContext::GetCurrent();
FlightServerCarrier carrier(context.incoming_headers());
auto otel_context = otel::context::RuntimeContext::GetCurrent();
auto propagator =
otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
auto new_context = propagator->Extract(carrier, context);
auto new_otel_context = propagator->Extract(carrier, otel_context);

otel::trace::StartSpanOptions options;
options.kind = otel::trace::SpanKind::kServer;
options.parent = otel::trace::GetSpan(new_context)->GetContext();
options.parent = otel::trace::GetSpan(new_otel_context)->GetContext();

auto* tracer = arrow::internal::tracing::GetTracer();
auto method_name = ToString(info.method);
Expand Down Expand Up @@ -167,7 +168,7 @@ class TracingServerMiddleware::Impl {
class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
virtual ~TracingServerMiddlewareFactory() = default;
Status StartCall(const CallInfo&, const CallHeaders&,
Status StartCall(const CallInfo&, const ServerCallContext&,
std::shared_ptr<ServerMiddleware>* middleware) override {
std::unique_ptr<TracingServerMiddleware::Impl> impl(
new TracingServerMiddleware::Impl());
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/arrow/flight/transport/grpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,7 @@ class GrpcServiceHandler final : public FlightService::Service {
GrpcAddServerHeaders outgoing_headers(context);
for (const auto& factory : middleware_) {
std::shared_ptr<ServerMiddleware> instance;
Status result =
factory.second->StartCall(info, flight_context.incoming_headers(), &instance);
Status result = factory.second->StartCall(info, flight_context, &instance);
if (!result.ok()) {
// Interceptor rejected call, end the request on all existing
// interceptors
Expand Down