Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] username Username to use
/// \param[in] password Password to use
/// \return Arrow result with bearer token and status OK if client authenticated
/// sucessfully
/// successfully
arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
const FlightCallOptions& options, const std::string& username,
const std::string& password);
Expand Down
24 changes: 24 additions & 0 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,30 @@ TEST(TestFlight, DISABLED_IpV6Port) {
ASSERT_OK(client->ListFlights());
}

TEST(TestFlight, ServerCallContextIncomingHeaders) {
auto server = ExampleTestServer();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));

ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server->location()));
Action action;
action.type = "list-incoming-headers";
action.body = Buffer::FromString("test-header");
FlightCallOptions call_options;
call_options.headers.emplace_back("test-header1", "value1");
call_options.headers.emplace_back("test-header2", "value2");
ASSERT_OK_AND_ASSIGN(auto stream, client->DoAction(call_options, action));
ASSERT_OK_AND_ASSIGN(auto result, stream->Next());
ASSERT_NE(result.get(), nullptr);
ASSERT_EQ(result->body->ToString(), "test-header1: value1");
ASSERT_OK_AND_ASSIGN(result, stream->Next());
ASSERT_NE(result.get(), nullptr);
ASSERT_EQ(result->body->ToString(), "test-header2: value2");
ASSERT_OK_AND_ASSIGN(result, stream->Next());
ASSERT_EQ(result.get(), nullptr);
}

// ----------------------------------------------------------------------
// Client tests

Expand Down
8 changes: 1 addition & 7 deletions cpp/src/arrow/flight/middleware.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,17 @@

#pragma once

#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "arrow/flight/visibility.h" // IWYU pragma: keep
#include "arrow/flight/types.h"
#include "arrow/status.h"

namespace arrow {
namespace flight {

/// \brief Headers sent from the client or server.
///
/// Header values are ordered.
using CallHeaders = std::multimap<std::string_view, std::string_view>;

/// \brief A write-only wrapper around headers for an RPC call.
class ARROW_FLIGHT_EXPORT AddCallHeaders {
public:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
/// \brief Check if the current RPC has been cancelled (by the client, by
/// a network error, etc.).
virtual bool is_cancelled() const = 0;
/// \brief The headers sent by the client for this call.
virtual const CallHeaders& incoming_headers() const = 0;
};

class ARROW_FLIGHT_EXPORT FlightServerOptions {
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/arrow/flight/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,31 @@ class FlightTestServer : public FlightServerBase {
return Status::OK();
}

Status ListIncomingHeaders(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* out) {
std::vector<Result> results;
std::string_view prefix(*action.body);
for (const auto& header : context.incoming_headers()) {
if (header.first.substr(0, prefix.size()) != prefix) {
continue;
}
Result result;
result.body = Buffer::FromString(std::string(header.first) + ": " +
std::string(header.second));
results.push_back(result);
}
*out = std::make_unique<SimpleResultStream>(std::move(results));
return Status::OK();
}

Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* out) override {
if (action.type == "action1") {
return RunAction1(action, out);
} else if (action.type == "action2") {
return RunAction2(out);
} else if (action.type == "list-incoming-headers") {
return ListIncomingHeaders(context, action, out);
} else {
return Status::NotImplemented(action.type);
}
Expand Down
19 changes: 11 additions & 8 deletions cpp/src/arrow/flight/transport/grpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,18 @@ class GrpcServerAuthSender : public ServerAuthSender {

class GrpcServerCallContext : public ServerCallContext {
explicit GrpcServerCallContext(::grpc::ServerContext* context)
: context_(context), peer_(context_->peer()) {}
: context_(context), peer_(context_->peer()) {
for (const auto& entry : context->client_metadata()) {
incoming_headers_.insert(
{std::string_view(entry.first.data(), entry.first.length()),
std::string_view(entry.second.data(), entry.second.length())});
}
}

const std::string& peer_identity() const override { return peer_identity_; }
const std::string& peer() const override { return peer_; }
bool is_cancelled() const override { return context_->IsCancelled(); }
const CallHeaders& incoming_headers() const override { return incoming_headers_; }

// Helper method that runs interceptors given the result of an RPC,
// then returns the final gRPC status to send to the client
Expand Down Expand Up @@ -156,6 +163,7 @@ class GrpcServerCallContext : public ServerCallContext {
std::string peer_identity_;
std::vector<std::shared_ptr<ServerMiddleware>> middleware_;
std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_;
CallHeaders incoming_headers_;
};

class GrpcAddServerHeaders : public AddCallHeaders {
Expand Down Expand Up @@ -310,17 +318,12 @@ class GrpcServiceHandler final : public FlightService::Service {
GrpcServerCallContext& flight_context) {
// Run server middleware
const CallInfo info{method};
CallHeaders incoming_headers;
for (const auto& entry : context->client_metadata()) {
incoming_headers.insert(
{std::string_view(entry.first.data(), entry.first.length()),
std::string_view(entry.second.data(), entry.second.length())});
}

GrpcAddServerHeaders outgoing_headers(context);
for (const auto& factory : middleware_) {
std::shared_ptr<ServerMiddleware> instance;
Status result = factory.second->StartCall(info, incoming_headers, &instance);
Status result =
factory.second->StartCall(info, flight_context.incoming_headers(), &instance);
if (!result.ok()) {
// Interceptor rejected call, end the request on all existing
// interceptors
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/transport/ucx/ucx_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ class UcxServerCallContext : public flight::ServerCallContext {
return nullptr;
}
bool is_cancelled() const override { return false; }
const CallHeaders& incoming_headers() const override { return incoming_headers_; }

private:
std::string peer_;
CallHeaders incoming_headers_;
};

class UcxServerStream : public internal::ServerDataStream {
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -123,6 +124,11 @@ ARROW_FLIGHT_EXPORT
Status MakeFlightError(FlightStatusCode code, std::string message,
std::string extra_info = {});

/// \brief Headers sent from the client or server.
///
/// Header values are ordered.
using CallHeaders = std::multimap<std::string_view, std::string_view>;

/// \brief A TLS certificate plus key.
struct ARROW_FLIGHT_EXPORT CertKeyPair {
/// \brief The certificate in PEM format.
Expand Down