From 93ac104dfc176976b3a7e392ec4a5fed0dc9a81f Mon Sep 17 00:00:00 2001
From: David Li
Date: Thu, 9 Dec 2021 13:28:04 -0500
Subject: [PATCH] ARROW-14958: [C++][Python][FlightRPC] Add OpenTelemetry
middleware
---
cpp/src/arrow/flight/CMakeLists.txt | 3 +
cpp/src/arrow/flight/api.h | 2 +
.../arrow/flight/client_tracing_middleware.cc | 102 ++++++++++
.../arrow/flight/client_tracing_middleware.h | 34 ++++
cpp/src/arrow/flight/flight_test.cc | 181 +++++++++++++++--
cpp/src/arrow/flight/middleware.cc | 53 +++++
cpp/src/arrow/flight/middleware.h | 6 +-
.../arrow/flight/server_tracing_middleware.cc | 183 ++++++++++++++++++
.../arrow/flight/server_tracing_middleware.h | 68 +++++++
python/pyarrow/_flight.pyx | 64 +++++-
python/pyarrow/flight.py | 1 +
python/pyarrow/includes/libarrow_flight.pxd | 21 +-
python/pyarrow/tests/test_flight.py | 30 +++
13 files changed, 726 insertions(+), 22 deletions(-)
create mode 100644 cpp/src/arrow/flight/client_tracing_middleware.cc
create mode 100644 cpp/src/arrow/flight/client_tracing_middleware.h
create mode 100644 cpp/src/arrow/flight/middleware.cc
create mode 100644 cpp/src/arrow/flight/server_tracing_middleware.cc
create mode 100644 cpp/src/arrow/flight/server_tracing_middleware.h
diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt
index 212db853d66..4cdb5fe127d 100644
--- a/cpp/src/arrow/flight/CMakeLists.txt
+++ b/cpp/src/arrow/flight/CMakeLists.txt
@@ -184,10 +184,13 @@ set(ARROW_FLIGHT_SRCS
"${CMAKE_CURRENT_BINARY_DIR}/Flight.pb.cc"
client.cc
client_cookie_middleware.cc
+ client_tracing_middleware.cc
cookie_internal.cc
+ middleware.cc
serialization_internal.cc
server.cc
server_auth.cc
+ server_tracing_middleware.cc
transport.cc
transport_server.cc
# Bundle the gRPC impl with libarrow_flight
diff --git a/cpp/src/arrow/flight/api.h b/cpp/src/arrow/flight/api.h
index c58a9d48afa..61c475dc204 100644
--- a/cpp/src/arrow/flight/api.h
+++ b/cpp/src/arrow/flight/api.h
@@ -20,8 +20,10 @@
#include "arrow/flight/client.h"
#include "arrow/flight/client_auth.h"
#include "arrow/flight/client_middleware.h"
+#include "arrow/flight/client_tracing_middleware.h"
#include "arrow/flight/middleware.h"
#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/server_tracing_middleware.h"
#include "arrow/flight/types.h"
diff --git a/cpp/src/arrow/flight/client_tracing_middleware.cc b/cpp/src/arrow/flight/client_tracing_middleware.cc
new file mode 100644
index 00000000000..a45784bd31e
--- /dev/null
+++ b/cpp/src/arrow/flight/client_tracing_middleware.cc
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/client_tracing_middleware.h"
+
+#include
+#include
+#include
+#include
+
+#include "arrow/util/tracing_internal.h"
+
+#ifdef ARROW_WITH_OPENTELEMETRY
+#include
+#include
+#endif
+
+namespace arrow {
+namespace flight {
+
+namespace {
+#ifdef ARROW_WITH_OPENTELEMETRY
+namespace otel = opentelemetry;
+class FlightClientCarrier : public otel::context::propagation::TextMapCarrier {
+ public:
+ FlightClientCarrier() = default;
+
+ otel::nostd::string_view Get(otel::nostd::string_view key) const noexcept override {
+ return "";
+ }
+
+ void Set(otel::nostd::string_view key,
+ otel::nostd::string_view value) noexcept override {
+ context_.emplace_back(key, value);
+ }
+
+ std::vector> context_;
+};
+
+class TracingClientMiddleware : public ClientMiddleware {
+ public:
+ explicit TracingClientMiddleware(FlightClientCarrier carrier)
+ : carrier_(std::move(carrier)) {}
+ virtual ~TracingClientMiddleware() = default;
+
+ void SendingHeaders(AddCallHeaders* outgoing_headers) override {
+ // The exact headers added are not arbitrary and are defined in
+ // the OpenTelemetry specification (see
+ // open-telemetry/opentelemetry-specification api-propagators.md)
+ for (const auto& pair : carrier_.context_) {
+ outgoing_headers->AddHeader(pair.first, pair.second);
+ }
+ }
+ void ReceivedHeaders(const CallHeaders&) override {}
+ void CallCompleted(const Status&) override {}
+
+ private:
+ FlightClientCarrier carrier_;
+};
+
+class TracingClientMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+ virtual ~TracingClientMiddlewareFactory() = default;
+ void StartCall(const CallInfo& info,
+ std::unique_ptr* middleware) override {
+ FlightClientCarrier carrier;
+ auto context = otel::context::RuntimeContext::GetCurrent();
+ auto propagator =
+ otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
+ propagator->Inject(carrier, context);
+ *middleware = std::make_unique(std::move(carrier));
+ }
+};
+#else
+class TracingClientMiddlewareFactory : public ClientMiddlewareFactory {
+ public:
+ virtual ~TracingClientMiddlewareFactory() = default;
+ void StartCall(const CallInfo&, std::unique_ptr*) override {}
+};
+#endif
+} // namespace
+
+std::shared_ptr MakeTracingClientMiddlewareFactory() {
+ return std::make_shared();
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/client_tracing_middleware.h b/cpp/src/arrow/flight/client_tracing_middleware.h
new file mode 100644
index 00000000000..3a8b665ed6c
--- /dev/null
+++ b/cpp/src/arrow/flight/client_tracing_middleware.h
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Middleware implementation for propagating OpenTelemetry spans.
+
+#pragma once
+
+#include
+
+#include "arrow/flight/client_middleware.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Returns a ClientMiddlewareFactory that handles sending OpenTelemetry spans.
+ARROW_FLIGHT_EXPORT std::shared_ptr
+MakeTracingClientMiddlewareFactory();
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc
index 217f910d640..db187013ec9 100644
--- a/cpp/src/arrow/flight/flight_test.cc
+++ b/cpp/src/arrow/flight/flight_test.cc
@@ -27,10 +27,13 @@
#include
#include
#include
+#include
#include
#include
#include "arrow/flight/api.h"
+#include "arrow/flight/client_tracing_middleware.h"
+#include "arrow/flight/server_tracing_middleware.h"
#include "arrow/ipc/test_common.h"
#include "arrow/status.h"
#include "arrow/testing/generator.h"
@@ -54,6 +57,26 @@
#include "arrow/flight/serialization_internal.h"
#include "arrow/flight/test_definitions.h"
#include "arrow/flight/test_util.h"
+// OTel includes must come after any gRPC includes, and
+// client_header_internal.h includes gRPC. See:
+// https://github.com/open-telemetry/opentelemetry-cpp/blob/main/examples/otlp/README.md
+//
+// > gRPC internally uses a different version of Abseil than
+// > OpenTelemetry C++ SDK.
+// > ...
+// > ...in case if you run into conflict between Abseil library and
+// > OpenTelemetry C++ absl::variant implementation, please include
+// > either grpcpp/grpcpp.h or
+// > opentelemetry/exporters/otlp/otlp_grpc_exporter.h BEFORE any
+// > other API headers. This approach efficiently avoids the conflict
+// > between the two different versions of Abseil.
+#include "arrow/util/tracing_internal.h"
+#ifdef ARROW_WITH_OPENTELEMETRY
+#include
+#include
+#include
+#include
+#endif
namespace arrow {
namespace flight {
@@ -441,21 +464,21 @@ static thread_local std::string current_span_id = "";
// A server middleware that stores the current span ID, in an
// emulation of OpenTracing style distributed tracing.
-class TracingServerMiddleware : public ServerMiddleware {
+class TracingTestServerMiddleware : public ServerMiddleware {
public:
- explicit TracingServerMiddleware(const std::string& current_span_id)
+ explicit TracingTestServerMiddleware(const std::string& current_span_id)
: span_id(current_span_id) {}
void SendingHeaders(AddCallHeaders* outgoing_headers) override {}
void CallCompleted(const Status& status) override {}
- std::string name() const override { return "TracingServerMiddleware"; }
+ std::string name() const override { return "TracingTestServerMiddleware"; }
std::string span_id;
};
-class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
+class TracingTestServerMiddlewareFactory : public ServerMiddlewareFactory {
public:
- TracingServerMiddlewareFactory() {}
+ TracingTestServerMiddlewareFactory() {}
Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
std::shared_ptr* middleware) override {
@@ -463,7 +486,7 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
incoming_headers.equal_range("x-tracing-span-id");
if (iter_pair.first != iter_pair.second) {
const std::string_view& value = (*iter_pair.first).second;
- *middleware = std::make_shared(std::string(value));
+ *middleware = std::make_shared(std::string(value));
}
return Status::OK();
}
@@ -627,10 +650,10 @@ class ReportContextTestServer : public FlightServerBase {
std::unique_ptr* result) override {
std::shared_ptr buf;
const ServerMiddleware* middleware = context.GetMiddleware("tracing");
- if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
+ if (middleware == nullptr || middleware->name() != "TracingTestServerMiddleware") {
buf = Buffer::FromString("");
} else {
- buf = Buffer::FromString(((const TracingServerMiddleware*)middleware)->span_id);
+ buf = Buffer::FromString(((const TracingTestServerMiddleware*)middleware)->span_id);
}
*result = std::make_unique(std::vector{Result{buf}});
return Status::OK();
@@ -658,10 +681,10 @@ class PropagatingTestServer : public FlightServerBase {
Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr* result) override {
const ServerMiddleware* middleware = context.GetMiddleware("tracing");
- if (middleware == nullptr || middleware->name() != "TracingServerMiddleware") {
+ if (middleware == nullptr || middleware->name() != "TracingTestServerMiddleware") {
current_span_id = "";
} else {
- current_span_id = ((const TracingServerMiddleware*)middleware)->span_id;
+ current_span_id = ((const TracingTestServerMiddleware*)middleware)->span_id;
}
return client_->DoAction(action).Value(result);
@@ -728,7 +751,7 @@ class TestCountingServerMiddleware : public ::testing::Test {
class TestPropagatingMiddleware : public ::testing::Test {
public:
void SetUp() {
- server_middleware_ = std::make_shared();
+ server_middleware_ = std::make_shared();
second_client_middleware_ = std::make_shared();
client_middleware_ = std::make_shared();
@@ -782,7 +805,7 @@ class TestPropagatingMiddleware : public ::testing::Test {
std::unique_ptr client_;
std::unique_ptr first_server_;
std::unique_ptr second_server_;
- std::shared_ptr server_middleware_;
+ std::shared_ptr server_middleware_;
std::shared_ptr second_client_middleware_;
std::shared_ptr client_middleware_;
};
@@ -1528,5 +1551,139 @@ TEST_F(TestCancel, DoExchange) {
ARROW_UNUSED(do_exchange_result.writer->Close());
}
+class TracingTestServer : public FlightServerBase {
+ public:
+ Status DoAction(const ServerCallContext& call_context, const Action&,
+ std::unique_ptr* result) override {
+ std::vector results;
+ auto* middleware =
+ reinterpret_cast(call_context.GetMiddleware("tracing"));
+ if (!middleware) return Status::Invalid("Could not find middleware");
+#ifdef ARROW_WITH_OPENTELEMETRY
+ // Ensure the trace context is present (but the value is random so
+ // we cannot assert any particular value)
+ EXPECT_FALSE(middleware->GetTraceContext().empty());
+ auto span = arrow::internal::tracing::GetTracer()->GetCurrentSpan();
+ const auto context = span->GetContext();
+ {
+ const auto& span_id = context.span_id();
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(span_id.Id().size()));
+ std::memcpy(buffer->mutable_data(), span_id.Id().data(), span_id.Id().size());
+ results.push_back({std::move(buffer)});
+ }
+ {
+ const auto& trace_id = context.trace_id();
+ ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(trace_id.Id().size()));
+ std::memcpy(buffer->mutable_data(), trace_id.Id().data(), trace_id.Id().size());
+ results.push_back({std::move(buffer)});
+ }
+#else
+ // Ensure the trace context is not present (as OpenTelemetry is not enabled)
+ EXPECT_TRUE(middleware->GetTraceContext().empty());
+#endif
+ *result = std::make_unique(std::move(results));
+ return Status::OK();
+ }
+};
+
+class TestTracing : public ::testing::Test {
+ public:
+ void SetUp() {
+#ifdef ARROW_WITH_OPENTELEMETRY
+ // The default tracer always generates no-op spans which have no
+ // span/trace ID. Set up a different tracer. Note, this needs to
+ // be run before Arrow uses OTel as GetTracer() gets a tracer once
+ // and keeps it in a static.
+ std::vector> processors;
+ auto provider =
+ opentelemetry::nostd::shared_ptr(
+ new opentelemetry::sdk::trace::TracerProvider(std::move(processors)));
+ opentelemetry::trace::Provider::SetTracerProvider(std::move(provider));
+
+ opentelemetry::context::propagation::GlobalTextMapPropagator::SetGlobalPropagator(
+ opentelemetry::nostd::shared_ptr<
+ opentelemetry::context::propagation::TextMapPropagator>(
+ new opentelemetry::trace::propagation::HttpTraceContext()));
+#endif
+
+ ASSERT_OK(MakeServer(
+ &server_, &client_,
+ [](FlightServerOptions* options) {
+ options->middleware.emplace_back("tracing",
+ MakeTracingServerMiddlewareFactory());
+ return Status::OK();
+ },
+ [](FlightClientOptions* options) {
+ options->middleware.push_back(MakeTracingClientMiddlewareFactory());
+ return Status::OK();
+ }));
+ }
+ void TearDown() { ASSERT_OK(server_->Shutdown()); }
+
+ protected:
+ std::unique_ptr client_;
+ std::unique_ptr server_;
+};
+
+#ifdef ARROW_WITH_OPENTELEMETRY
+// Must define it ourselves to avoid a linker error
+constexpr size_t kSpanIdSize = opentelemetry::trace::SpanId::kSize;
+constexpr size_t kTraceIdSize = opentelemetry::trace::TraceId::kSize;
+
+TEST_F(TestTracing, NoParentTrace) {
+ ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(Action{}));
+
+ ASSERT_OK_AND_ASSIGN(auto result, results->Next());
+ ASSERT_NE(result, nullptr);
+ ASSERT_NE(result->body, nullptr);
+ // Span ID should be a valid span ID, i.e. the server must have started a span
+ ASSERT_EQ(result->body->size(), kSpanIdSize);
+ opentelemetry::trace::SpanId span_id({result->body->data(), kSpanIdSize});
+ ASSERT_TRUE(span_id.IsValid());
+
+ ASSERT_OK_AND_ASSIGN(result, results->Next());
+ ASSERT_NE(result, nullptr);
+ ASSERT_NE(result->body, nullptr);
+ ASSERT_EQ(result->body->size(), kTraceIdSize);
+ opentelemetry::trace::TraceId trace_id({result->body->data(), kTraceIdSize});
+ ASSERT_TRUE(trace_id.IsValid());
+}
+TEST_F(TestTracing, WithParentTrace) {
+ auto* tracer = arrow::internal::tracing::GetTracer();
+ auto span = tracer->StartSpan("test");
+ auto scope = tracer->WithActiveSpan(span);
+
+ auto span_context = span->GetContext();
+ auto current_trace_id = span_context.trace_id().Id();
+
+ ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(Action{}));
+
+ ASSERT_OK_AND_ASSIGN(auto result, results->Next());
+ ASSERT_NE(result, nullptr);
+ ASSERT_NE(result->body, nullptr);
+ ASSERT_EQ(result->body->size(), kSpanIdSize);
+ opentelemetry::trace::SpanId span_id({result->body->data(), kSpanIdSize});
+ ASSERT_TRUE(span_id.IsValid());
+
+ ASSERT_OK_AND_ASSIGN(result, results->Next());
+ ASSERT_NE(result, nullptr);
+ ASSERT_NE(result->body, nullptr);
+ ASSERT_EQ(result->body->size(), kTraceIdSize);
+ opentelemetry::trace::TraceId trace_id({result->body->data(), kTraceIdSize});
+ // The server span should have the same trace ID as the client span.
+ ASSERT_EQ(std::string_view(reinterpret_cast(trace_id.Id().data()),
+ trace_id.Id().size()),
+ std::string_view(reinterpret_cast(current_trace_id.data()),
+ current_trace_id.size()));
+}
+#else
+TEST_F(TestTracing, NoOp) {
+ // The middleware should not cause any trouble when OTel is not enabled.
+ ASSERT_OK_AND_ASSIGN(auto results, client_->DoAction(Action{}));
+ ASSERT_OK_AND_ASSIGN(auto result, results->Next());
+ ASSERT_EQ(result, nullptr);
+}
+#endif
+
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/middleware.cc b/cpp/src/arrow/flight/middleware.cc
new file mode 100644
index 00000000000..ffbcb6aad20
--- /dev/null
+++ b/cpp/src/arrow/flight/middleware.cc
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/middleware.h"
+
+namespace arrow {
+namespace flight {
+
+std::string ToString(FlightMethod method) {
+ // Technically, we can get this via Protobuf reflection, but in
+ // practice we'd have to hardcode the method names to look up the
+ // method descriptor...
+ switch (method) {
+ case FlightMethod::Handshake:
+ return "Handshake";
+ case FlightMethod::ListFlights:
+ return "ListFlights";
+ case FlightMethod::GetFlightInfo:
+ return "GetFlightInfo";
+ case FlightMethod::GetSchema:
+ return "GetSchema";
+ case FlightMethod::DoGet:
+ return "DoGet";
+ case FlightMethod::DoPut:
+ return "DoPut";
+ case FlightMethod::DoAction:
+ return "DoAction";
+ case FlightMethod::ListActions:
+ return "ListActions";
+ case FlightMethod::DoExchange:
+ return "DoExchange";
+ case FlightMethod::Invalid:
+ default:
+ return "(unknown Flight method)";
+ }
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h
index b050e9cc6ed..dc1ad24bc5c 100644
--- a/cpp/src/arrow/flight/middleware.h
+++ b/cpp/src/arrow/flight/middleware.h
@@ -30,7 +30,6 @@
#include "arrow/status.h"
namespace arrow {
-
namespace flight {
/// \brief Headers sent from the client or server.
@@ -66,6 +65,10 @@ enum class FlightMethod : char {
DoExchange = 9,
};
+/// \brief Get a human-readable name for a Flight method.
+ARROW_FLIGHT_EXPORT
+std::string ToString(FlightMethod method);
+
/// \brief Information about an instance of a Flight RPC.
struct ARROW_FLIGHT_EXPORT CallInfo {
public:
@@ -74,5 +77,4 @@ struct ARROW_FLIGHT_EXPORT CallInfo {
};
} // namespace flight
-
} // namespace arrow
diff --git a/cpp/src/arrow/flight/server_tracing_middleware.cc b/cpp/src/arrow/flight/server_tracing_middleware.cc
new file mode 100644
index 00000000000..eac530efb8a
--- /dev/null
+++ b/cpp/src/arrow/flight/server_tracing_middleware.cc
@@ -0,0 +1,183 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/server_tracing_middleware.h"
+
+#include
+#include
+#include
+#include
+
+#include "arrow/flight/transport/grpc/util_internal.h"
+#include "arrow/util/tracing_internal.h"
+
+#ifdef ARROW_WITH_OPENTELEMETRY
+#include
+#include
+#include
+#include
+#include
+#endif
+
+namespace arrow {
+namespace flight {
+
+#ifdef ARROW_WITH_OPENTELEMETRY
+namespace otel = opentelemetry;
+namespace {
+class FlightServerCarrier : public otel::context::propagation::TextMapCarrier {
+ public:
+ explicit FlightServerCarrier(const CallHeaders& incoming_headers)
+ : incoming_headers_(incoming_headers) {}
+
+ otel::nostd::string_view Get(otel::nostd::string_view key) const noexcept override {
+ std::string_view arrow_key(key.data(), key.size());
+ auto it = incoming_headers_.find(arrow_key);
+ if (it == incoming_headers_.end()) return "";
+ std::string_view result = it->second;
+ return {result.data(), result.size()};
+ }
+
+ void Set(otel::nostd::string_view, otel::nostd::string_view) noexcept override {}
+
+ const CallHeaders& incoming_headers_;
+};
+class KeyValueCarrier : public otel::context::propagation::TextMapCarrier {
+ public:
+ explicit KeyValueCarrier(std::vector* items)
+ : items_(items) {}
+ otel::nostd::string_view Get(otel::nostd::string_view key) const noexcept override {
+ return {};
+ }
+ void Set(otel::nostd::string_view key,
+ otel::nostd::string_view value) noexcept override {
+ items_->push_back({std::string(key), std::string(value)});
+ }
+
+ private:
+ std::vector* items_;
+};
+} // namespace
+
+class TracingServerMiddleware::Impl {
+ public:
+ Impl(otel::trace::Scope scope, otel::nostd::shared_ptr span)
+ : scope_(std::move(scope)), span_(std::move(span)) {}
+ void CallCompleted(const Status& status) {
+ if (!status.ok()) {
+ auto grpc_status = transport::grpc::ToGrpcStatus(status, /*ctx=*/nullptr);
+ span_->SetStatus(otel::trace::StatusCode::kError, status.ToString());
+ span_->SetAttribute(OTEL_GET_TRACE_ATTR(AttrRpcGrpcStatusCode),
+ static_cast(grpc_status.error_code()));
+ } else {
+ span_->SetStatus(otel::trace::StatusCode::kOk, "");
+ span_->SetAttribute(OTEL_GET_TRACE_ATTR(AttrRpcGrpcStatusCode), int32_t(0));
+ }
+ span_->End();
+ }
+ std::vector GetTraceContext() const {
+ std::vector result;
+ KeyValueCarrier carrier(&result);
+ auto context = otel::context::RuntimeContext::GetCurrent();
+ otel::trace::propagation::HttpTraceContext propagator;
+ propagator.Inject(carrier, context);
+ return result;
+ }
+
+ private:
+ otel::trace::Scope scope_;
+ otel::nostd::shared_ptr span_;
+};
+
+class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ virtual ~TracingServerMiddlewareFactory() = default;
+ Status StartCall(const CallInfo& info, const CallHeaders& incoming_headers,
+ std::shared_ptr* middleware) override {
+ constexpr char kRpcSystem[] = "grpc";
+ constexpr char kServiceName[] = "arrow.flight.protocol.FlightService";
+
+ FlightServerCarrier carrier(incoming_headers);
+ auto context = otel::context::RuntimeContext::GetCurrent();
+ auto propagator =
+ otel::context::propagation::GlobalTextMapPropagator::GetGlobalPropagator();
+ auto new_context = propagator->Extract(carrier, context);
+
+ otel::trace::StartSpanOptions options;
+ options.kind = otel::trace::SpanKind::kServer;
+ options.parent = otel::trace::GetSpan(new_context)->GetContext();
+
+ auto* tracer = arrow::internal::tracing::GetTracer();
+ auto method_name = ToString(info.method);
+ auto span = tracer->StartSpan(
+ method_name,
+ {
+ // Attributes from experimental trace semantic conventions spec
+ // https://github.com/open-telemetry/opentelemetry-specification/blob/main/semantic_conventions/trace/rpc.yaml
+ {OTEL_GET_TRACE_ATTR(AttrRpcSystem), kRpcSystem},
+ {OTEL_GET_TRACE_ATTR(AttrRpcService), kServiceName},
+ {OTEL_GET_TRACE_ATTR(AttrRpcMethod), method_name},
+ },
+ options);
+ auto scope = tracer->WithActiveSpan(span);
+
+ std::unique_ptr impl(
+ new TracingServerMiddleware::Impl(std::move(scope), std::move(span)));
+ *middleware = std::shared_ptr(
+ new TracingServerMiddleware(std::move(impl)));
+ return Status::OK();
+ }
+};
+#else
+class TracingServerMiddleware::Impl {
+ public:
+ void CallCompleted(const Status&) {}
+ std::vector GetTraceContext() const { return {}; }
+};
+class TracingServerMiddlewareFactory : public ServerMiddlewareFactory {
+ public:
+ virtual ~TracingServerMiddlewareFactory() = default;
+ Status StartCall(const CallInfo&, const CallHeaders&,
+ std::shared_ptr* middleware) override {
+ std::unique_ptr impl(
+ new TracingServerMiddleware::Impl());
+ *middleware = std::shared_ptr(
+ new TracingServerMiddleware(std::move(impl)));
+ return Status::OK();
+ }
+};
+#endif
+
+TracingServerMiddleware::TracingServerMiddleware(std::unique_ptr impl)
+ : impl_(std::move(impl)) {}
+TracingServerMiddleware::~TracingServerMiddleware() = default;
+void TracingServerMiddleware::SendingHeaders(AddCallHeaders*) {}
+void TracingServerMiddleware::CallCompleted(const Status& status) {
+ impl_->CallCompleted(status);
+}
+std::vector TracingServerMiddleware::GetTraceContext()
+ const {
+ return impl_->GetTraceContext();
+}
+constexpr char const TracingServerMiddleware::kMiddlewareName[];
+
+std::shared_ptr MakeTracingServerMiddlewareFactory() {
+ return std::make_shared();
+}
+
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/server_tracing_middleware.h b/cpp/src/arrow/flight/server_tracing_middleware.h
new file mode 100644
index 00000000000..581c8354368
--- /dev/null
+++ b/cpp/src/arrow/flight/server_tracing_middleware.h
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Middleware implementation for propagating OpenTelemetry spans.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "arrow/flight/server_middleware.h"
+#include "arrow/flight/visibility.h"
+#include "arrow/status.h"
+
+namespace arrow {
+namespace flight {
+
+/// \brief Returns a ServerMiddlewareFactory that handles receiving OpenTelemetry spans.
+ARROW_FLIGHT_EXPORT std::shared_ptr
+MakeTracingServerMiddlewareFactory();
+
+/// \brief A server middleware that provides access to the
+/// OpenTelemetry context, if present.
+///
+/// Used to make the OpenTelemetry span available in Python.
+class ARROW_FLIGHT_EXPORT TracingServerMiddleware : public ServerMiddleware {
+ public:
+ ~TracingServerMiddleware();
+
+ static constexpr char const kMiddlewareName[] =
+ "arrow::flight::TracingServerMiddleware";
+
+ std::string name() const override { return kMiddlewareName; }
+ void SendingHeaders(AddCallHeaders*) override;
+ void CallCompleted(const Status&) override;
+
+ struct TraceKey {
+ std::string key;
+ std::string value;
+ };
+ /// \brief Get the trace context.
+ std::vector GetTraceContext() const;
+
+ private:
+ class Impl;
+ friend class TracingServerMiddlewareFactory;
+
+ explicit TracingServerMiddleware(std::unique_ptr impl);
+ std::unique_ptr impl_;
+};
+
+} // namespace flight
+} // namespace arrow
diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx
index 16e4aad5a00..b6c9177195a 100644
--- a/python/pyarrow/_flight.pyx
+++ b/python/pyarrow/_flight.pyx
@@ -1741,13 +1741,22 @@ cdef class ServerCallContext(_Weakrefable):
CServerMiddleware* c_middleware = \
self.context.GetMiddleware(CPyServerMiddlewareName)
CPyServerMiddleware* middleware
+ vector[CTracingServerMiddlewareTraceKey] c_trace_context
+ if c_middleware == NULL:
+ c_middleware = self.context.GetMiddleware(tobytes(key))
+
if c_middleware == NULL:
return None
- if c_middleware.name() != CPyServerMiddlewareName:
- return None
- middleware = c_middleware
- py_middleware = <_ServerMiddlewareWrapper> middleware.py_object()
- return py_middleware.middleware.get(key)
+ elif c_middleware.name() == CPyServerMiddlewareName:
+ middleware = c_middleware
+ py_middleware = <_ServerMiddlewareWrapper> middleware.py_object()
+ return py_middleware.middleware.get(key)
+ elif c_middleware.name() == CTracingServerMiddlewareName:
+ c_trace_context = ( c_middleware
+ ).GetTraceContext()
+ trace_context = {pair.key: pair.value for pair in c_trace_context}
+ return TracingServerMiddleware(trace_context)
+ return None
@staticmethod
cdef ServerCallContext wrap(const CServerCallContext& context):
@@ -2528,6 +2537,22 @@ cdef class ServerMiddlewareFactory(_Weakrefable):
"""
+cdef class TracingServerMiddlewareFactory(ServerMiddlewareFactory):
+ """A factory for tracing middleware instances.
+
+ This enables OpenTelemetry support in Arrow (if Arrow was compiled
+ with OpenTelemetry support enabled). A new span will be started on
+ each RPC call. The TracingServerMiddleware instance can then be
+ retrieved within an RPC handler to get the propagated context,
+ which can be used to start a new span on the Python side.
+
+ Because the Python/C++ OpenTelemetry libraries do not
+ interoperate, spans on the C++ side are not directly visible to
+ the Python side and vice versa.
+
+ """
+
+
cdef class ServerMiddleware(_Weakrefable):
"""Server-side middleware for a call, instantiated per RPC.
@@ -2574,6 +2599,13 @@ cdef class ServerMiddleware(_Weakrefable):
c_instance[0].reset(new CPyServerMiddleware(py_middleware, vtable))
+class TracingServerMiddleware(ServerMiddleware):
+ __slots__ = ["trace_context"]
+
+ def __init__(self, trace_context):
+ self.trace_context = trace_context
+
+
cdef class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory):
"""Wrapper to bundle server middleware into a single C++ one."""
@@ -2739,7 +2771,27 @@ cdef class FlightServerBase(_Weakrefable):
c_options.get().tls_certificates.push_back(c_cert)
if middleware:
- py_middleware = _ServerMiddlewareFactoryWrapper(middleware)
+ non_tracing_middleware = {}
+ enable_tracing = None
+ for key, factory in middleware.items():
+ if isinstance(factory, TracingServerMiddlewareFactory):
+ if enable_tracing is not None:
+ raise ValueError(
+ "Can only provide "
+ "TracingServerMiddlewareFactory once")
+ if tobytes(key) == CPyServerMiddlewareName:
+ raise ValueError(f"Middleware key cannot be {key}")
+ enable_tracing = key
+ else:
+ non_tracing_middleware[key] = factory
+
+ if enable_tracing:
+ c_middleware.first = tobytes(enable_tracing)
+ c_middleware.second = MakeTracingServerMiddlewareFactory()
+ c_options.get().middleware.push_back(c_middleware)
+
+ py_middleware = _ServerMiddlewareFactoryWrapper(
+ non_tracing_middleware)
c_middleware.first = CPyServerMiddlewareName
c_middleware.second.reset(new CPyServerMiddlewareFactory(
py_middleware,
diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py
index 0664ff2c992..8f9fa6fa7c9 100644
--- a/python/pyarrow/flight.py
+++ b/python/pyarrow/flight.py
@@ -60,4 +60,5 @@
ServerMiddleware,
ServerMiddlewareFactory,
Ticket,
+ TracingServerMiddlewareFactory,
)
diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd
index 6377459404c..3301c1b6360 100644
--- a/python/pyarrow/includes/libarrow_flight.pxd
+++ b/python/pyarrow/includes/libarrow_flight.pxd
@@ -22,8 +22,8 @@ from pyarrow.includes.libarrow cimport *
cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
- cdef char* CPyServerMiddlewareName\
- " arrow::py::flight::kPyServerMiddlewareName"
+ cdef char* CTracingServerMiddlewareName\
+ " arrow::flight::TracingServerMiddleware::kMiddlewareName"
cdef cppclass CActionType" arrow::flight::ActionType":
c_string type
@@ -322,6 +322,20 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
" arrow::flight::ClientMiddlewareFactory":
pass
+ cpdef cppclass CTracingServerMiddlewareTraceKey\
+ " arrow::flight::TracingServerMiddleware::TraceKey":
+ CTracingServerMiddlewareTraceKey()
+ c_string key
+ c_string value
+
+ cdef cppclass CTracingServerMiddleware\
+ " arrow::flight::TracingServerMiddleware"(CServerMiddleware):
+ vector[CTracingServerMiddlewareTraceKey] GetTraceContext()
+
+ cdef shared_ptr[CServerMiddlewareFactory] \
+ MakeTracingServerMiddlewareFactory\
+ " arrow::flight::MakeTracingServerMiddlewareFactory"()
+
cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions":
CFlightServerOptions(const CLocation& location)
CLocation location
@@ -472,6 +486,9 @@ ctypedef CStatus cb_client_middleware_start_call(
unique_ptr[CClientMiddleware]*)
cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil:
+ cdef char* CPyServerMiddlewareName\
+ " arrow::py::flight::kPyServerMiddlewareName"
+
cdef cppclass PyFlightServerVtable:
PyFlightServerVtable()
function[cb_list_flights] list_flights
diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py
index 72d1fa5ec33..69318a5535b 100644
--- a/python/pyarrow/tests/test_flight.py
+++ b/python/pyarrow/tests/test_flight.py
@@ -2196,3 +2196,33 @@ def test_interpreter_shutdown():
See https://issues.apache.org/jira/browse/ARROW-16597.
"""
util.invoke_script("arrow_16597.py")
+
+
+class TracingFlightServer(FlightServerBase):
+ """A server that echoes back trace context values."""
+
+ def do_action(self, context, action):
+ trace_context = context.get_middleware("tracing").trace_context
+ # Don't turn this method into a generator since then
+ # trace_context will be evaluated after we've exited the scope
+ # of the OTel span (and so the value we want won't be present)
+ return ((f"{key}: {value}").encode("utf-8")
+ for (key, value) in trace_context.items())
+
+
+def test_tracing():
+ with TracingFlightServer(middleware={
+ "tracing": flight.TracingServerMiddlewareFactory(),
+ }) as server, \
+ FlightClient(('localhost', server.port)) as client:
+ # We can't tell if Arrow was built with OpenTelemetry support,
+ # so we can't count on any particular values being there; we
+ # can only ensure things don't blow up either way.
+ options = flight.FlightCallOptions(headers=[
+ # Pretend we have an OTel implementation
+ (b"traceparent", b"00-000ff00f00f0ff000f0f00ff0f00fff0-"
+ b"000f0000f0f00000-00"),
+ (b"tracestate", b""),
+ ])
+ for value in client.do_action((b"", b""), options=options):
+ pass