diff --git a/cpp/src/arrow/filesystem/s3_internal.h b/cpp/src/arrow/filesystem/s3_internal.h index 54da3d5987e..772387e5fb6 100644 --- a/cpp/src/arrow/filesystem/s3_internal.h +++ b/cpp/src/arrow/filesystem/s3_internal.h @@ -29,15 +29,38 @@ #include #include #include +#include #include #include "arrow/filesystem/filesystem.h" #include "arrow/filesystem/s3fs.h" #include "arrow/status.h" +#include "arrow/util/base64.h" #include "arrow/util/logging.h" #include "arrow/util/print.h" #include "arrow/util/string.h" +#ifndef ARROW_AWS_SDK_VERSION_CHECK +// AWS_SDK_VERSION_{MAJOR,MINOR,PATCH} are available since 1.9.7. +# if defined(AWS_SDK_VERSION_MAJOR) && defined(AWS_SDK_VERSION_MINOR) && \ + defined(AWS_SDK_VERSION_PATCH) +// Redundant "(...)" are for suppressing "Weird number of spaces at +// line-start. Are you using a 2-space indent? [whitespace/indent] +// [3]" errors... +# define ARROW_AWS_SDK_VERSION_CHECK(major, minor, patch) \ + ((AWS_SDK_VERSION_MAJOR > (major) || \ + (AWS_SDK_VERSION_MAJOR == (major) && AWS_SDK_VERSION_MINOR > (minor)) || \ + ((AWS_SDK_VERSION_MAJOR == (major) && AWS_SDK_VERSION_MINOR == (minor) && \ + AWS_SDK_VERSION_PATCH >= (patch))))) +# else +# define ARROW_AWS_SDK_VERSION_CHECK(major, minor, patch) 0 +# endif +#endif // !ARROW_AWS_SDK_VERSION_CHECK + +#if ARROW_AWS_SDK_VERSION_CHECK(1, 9, 201) +# define ARROW_S3_HAS_SSE_CUSTOMER_KEY +#endif + namespace arrow { namespace fs { namespace internal { @@ -291,6 +314,70 @@ class ConnectRetryStrategy : public Aws::Client::RetryStrategy { int32_t max_retry_duration_; }; +/// \brief calculate the MD5 of the input SSE-C key (raw key, not base64 encoded) +/// \param sse_customer_key is the input SSE-C key +/// \return the base64 encoded MD5 for the input key +inline Result CalculateSSECustomerKeyMD5( + const std::string& sse_customer_key) { + // The key needs to be 256 bits (32 bytes) according to + // 
https://docs.aws.amazon.com/AmazonS3/latest/userguide/ServerSideEncryptionCustomerKeys.html#specifying-s3-c-encryption + if (sse_customer_key.length() != 32) { + return Status::Invalid("32 bytes SSE-C key is expected"); + } + + // Convert the raw binary key to an Aws::String + Aws::String sse_customer_key_aws_string(sse_customer_key.data(), + sse_customer_key.length()); + + // Compute the MD5 hash of the raw binary key + Aws::Utils::ByteBuffer sse_customer_key_md5 = + Aws::Utils::HashingUtils::CalculateMD5(sse_customer_key_aws_string); + + // Base64-encode the MD5 hash + return arrow::util::base64_encode(std::string_view( + reinterpret_cast(sse_customer_key_md5.GetUnderlyingData()), + sse_customer_key_md5.GetLength())); +} + +struct SSECustomerKeyHeaders { + std::string sse_customer_key; + std::string sse_customer_key_md5; + std::string sse_customer_algorithm; +}; + +inline Result> GetSSECustomerKeyHeaders( + const std::string& sse_customer_key) { + if (sse_customer_key.empty()) { + return std::nullopt; + } +#ifdef ARROW_S3_HAS_SSE_CUSTOMER_KEY + ARROW_ASSIGN_OR_RAISE(auto md5, internal::CalculateSSECustomerKeyMD5(sse_customer_key)); + return SSECustomerKeyHeaders{arrow::util::base64_encode(sse_customer_key), md5, + "AES256"}; +#else + return Status::NotImplemented( + "SSE customer key not supported by this version of the AWS SDK"); +#endif +} + +template +Status SetSSECustomerKey(S3RequestType* request, const std::string& sse_customer_key) { + ARROW_ASSIGN_OR_RAISE(auto maybe_headers, GetSSECustomerKeyHeaders(sse_customer_key)); + if (!maybe_headers.has_value()) { + return Status::OK(); + } +#ifdef ARROW_S3_HAS_SSE_CUSTOMER_KEY + auto headers = std::move(maybe_headers).value(); + request->SetSSECustomerKey(headers.sse_customer_key); + request->SetSSECustomerKeyMD5(headers.sse_customer_key_md5); + request->SetSSECustomerAlgorithm(headers.sse_customer_algorithm); + return Status::OK(); +#else + return Status::NotImplemented( + "SSE customer key not supported by this 
version of the AWS SDK"); +#endif +} + } // namespace internal } // namespace fs } // namespace arrow diff --git a/cpp/src/arrow/filesystem/s3_test_cert_internal.h b/cpp/src/arrow/filesystem/s3_test_cert_internal.h new file mode 100644 index 00000000000..0a69ade7d0e --- /dev/null +++ b/cpp/src/arrow/filesystem/s3_test_cert_internal.h @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +namespace arrow::fs { +// The below two static strings are generated according to +// https://github.com/minio/minio/tree/RELEASE.2024-09-22T00-33-43Z/docs/tls#323-generate-a-self-signed-certificate +// `openssl req -new -x509 -nodes -days 36500 -keyout private.key -out public.crt -config +// openssl.conf` +static constexpr const char* kMinioPrivateKey = R"(-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCqwKYHsTSciGqP +uU3qkTWpnXIi3iC0eeW7JSzJHGFs880WdR5JdK4WufPK+1xzgiYjMEPfAcuSWz3b +qYyCI61q+a9Iu2nj7cFTW9bfZrmWlnI0YOLJc+q0AAdAjF1lvRKenH8tbjz/2jyl +i/cYQ+I5Tg4nngrX8OmOfluNzwD/nwGLq6/DVbzDUdPI9q1XtVT/0Vf7qwbDG1HD +NkIzKT5B+YdSLaOCRYNK3x7RPsfazKIBrTmRy1v454wKe8TjTmTB7+m5wKqfCJcq +lI253WHcK0lsw6zCNtX/kahPAvm/8mniPolW4qxoD6xwebgMVkrNTs3ztcPIG9O4 +pmCbATijAgMBAAECggEACL5swiAU7Z8etdVrZAOjl9f0LEzrp9JGLVst++50Hrwt +WGUO8/wBnjBPh6lvhoq3oT2rfBP/dLMva7w28cMZ8kxu6W6PcZiPOdGOI0qDXm69 +0mjTtDU3Y5hMxsVpUvhnp6+j45Otk/x89o1ATgHL59tTZjv1mjFABIf78DsVdgF9 +CMi2q6Lv7NLftieyWmz1K3p109z9+xkDNSOkVrv1JFChviKqWgIS0rdFjySvTgoy +rHYT+TweDliKJrZCeoUJmNB0uVW/dM9lXhcvkvkJZKPPurylx1oH5a7K/sWFPf7A +Ed1vjvZQFlaXu/bOUUSOZtkErAir/oCxrUDsHxGsAQKBgQDZghyy7jNGNdjZe1Xs +On1ZVgIS3Nt+OLGCVH7tTsfZsCOb+SkrhB1RQva3YzPMfgoZScI9+bN/pRVf49Pj +qGEHkW/wozutUve7UMzeTOm1aWxUuaKSrmYST7muvAnlYEtO7agd0wrcusYXlMoG +KQwghkufO9I7wXcrudMKXZalIwKBgQDI+FaUwhgfThkgq6bRbdMEeosgohrCM9Wm +E5JMePQq4VaGcgGveWUoNOgT8kvJa0qQwQOqLZj7kUIdj+SCRt0u+Wu3p5IMqdOq +6tMnLNQ3wzUC2KGFLSfISR3L/bo5Bo6Jqz4hVtjMk3PV9bu50MNTNaofYb2xlf/f +/WgiEG0WgQKBgAr8RVLMMQ7EvXUOg6Jwuc//Rg+J1BQl7OE2P0rhBbr66HGCPhAS +liB6j1dnzT/wxbXNQeA7clNqFRBIw3TmFjB5qfuvYt44KIbvZ8l6fPtKncwRrCJY +aJNYL3qhyKYrHOKZojoPZKcNT9/1BdcVz6T842jhbpbSCKDOu9f0Lh2dAoGATZeM +Hh0eISAPFY0QeDV1znnds3jC6g4HQ/q0dnAQnWmo9XmY6v3sr2xV2jWnSxnwjRjo +aFD4itBXfYBr0ly30wYbr6mz+s2q2oeVhL+LJAhrNDEdk4SOooaQSY0p1BCTAdYq +w8Z7J+kaRRZ+J0zRzROgHkOncKQgSYPWK6i55YECgYAC+ECrHhUlPsfusjKpFsEe +stW1HCt3wXtKQn6SJ6IAesbxwALZS6Da/ZC2x1mdBHS3GwWvtGLc0BPnPVfJjr9V 
+m82qkgJ+p5d7qp7pRA7SFD+5809yVqRnEF3rSLafgGet9ah0ZjZvQ3fwnYZNnNH9 +t9pJcv2E5xY7/nFNIorpKg== +-----END PRIVATE KEY----- +)"; + +static constexpr const char* kMinioCert = R"(-----BEGIN CERTIFICATE----- +MIIDiTCCAnGgAwIBAgIUXbHZ6FAhKSXg4WSGUQySlSyE4U0wDQYJKoZIhvcNAQEL +BQAwXzELMAkGA1UEBhMCVVMxCzAJBgNVBAgMAlZBMQ4wDAYDVQQHDAVBcnJvdzEO +MAwGA1UECgwFQXJyb3cxDjAMBgNVBAsMBUFycm93MRMwEQYDVQQDDApBcnJyb3dU +ZXN0MB4XDTI0MDkyNDA5MzUxNloXDTM0MDkyMjA5MzUxNlowXzELMAkGA1UEBhMC +VVMxCzAJBgNVBAgMAlZBMQ4wDAYDVQQHDAVBcnJvdzEOMAwGA1UECgwFQXJyb3cx +DjAMBgNVBAsMBUFycm93MRMwEQYDVQQDDApBcnJyb3dUZXN0MIIBIjANBgkqhkiG +9w0BAQEFAAOCAQ8AMIIBCgKCAQEAqsCmB7E0nIhqj7lN6pE1qZ1yIt4gtHnluyUs +yRxhbPPNFnUeSXSuFrnzyvtcc4ImIzBD3wHLkls926mMgiOtavmvSLtp4+3BU1vW +32a5lpZyNGDiyXPqtAAHQIxdZb0Snpx/LW48/9o8pYv3GEPiOU4OJ54K1/Dpjn5b +jc8A/58Bi6uvw1W8w1HTyPatV7VU/9FX+6sGwxtRwzZCMyk+QfmHUi2jgkWDSt8e +0T7H2syiAa05kctb+OeMCnvE405kwe/pucCqnwiXKpSNud1h3CtJbMOswjbV/5Go +TwL5v/Jp4j6JVuKsaA+scHm4DFZKzU7N87XDyBvTuKZgmwE4owIDAQABoz0wOzAa +BgNVHREEEzARhwR/AAABgglsb2NhbGhvc3QwHQYDVR0OBBYEFOUNqUSfROf1dz3o +hAVBhgd3UIvKMA0GCSqGSIb3DQEBCwUAA4IBAQBSwWJ2dSw3jlHU0l2V3ozqthTt +XFo07AyWGw8AWNCM6mQ+GKBf0JJ1d7e4lyTf2lCobknS94EgGPORWeiucKYAoCjS +dh1eKGsSevz1rNbp7wsO7DoiRPciK+S95DbsPowloGI6fvOeE12Cf1udeNIpEYWs +OBFwN0HxfYqdPALCtw7l0icpTrJ2Us06UfL9kbkdZwQhXvOscG7JDRtNjBxl9XNm +TFeMNKROmrEPCWaYr6MJ+ItHtb5Cawapea4THz9GCjR9eLq2CbMqLezZ8xBHPzc4 +ixI2l0uCfg7ZUSA+90yaScc7bhEQ8CMiPtJgNKaKIqB58DpY7028xJpW7Ma2 +-----END CERTIFICATE----- +)"; +} // namespace arrow::fs diff --git a/cpp/src/arrow/filesystem/s3_test_util.cc b/cpp/src/arrow/filesystem/s3_test_util.cc index db0c60f2e80..0cfe038599c 100644 --- a/cpp/src/arrow/filesystem/s3_test_util.cc +++ b/cpp/src/arrow/filesystem/s3_test_util.cc @@ -19,6 +19,7 @@ # include #endif +#include "arrow/filesystem/s3_test_cert_internal.h" #include "arrow/filesystem/s3_test_util.h" #include "arrow/filesystem/s3fs.h" #include "arrow/testing/process.h" @@ -31,6 +32,11 @@ namespace arrow { namespace fs { +using 
::arrow::internal::FileClose; +using ::arrow::internal::FileDescriptor; +using ::arrow::internal::FileOpenWritable; +using ::arrow::internal::FileWrite; +using ::arrow::internal::PlatformFilename; using ::arrow::internal::TemporaryDir; namespace { @@ -44,16 +50,16 @@ const char* kEnvConnectString = "ARROW_TEST_S3_CONNECT_STRING"; const char* kEnvAccessKey = "ARROW_TEST_S3_ACCESS_KEY"; const char* kEnvSecretKey = "ARROW_TEST_S3_SECRET_KEY"; -std::string GenerateConnectString() { return GetListenAddress(); } - } // namespace struct MinioTestServer::Impl { std::unique_ptr temp_dir_; + std::unique_ptr temp_dir_ca_; std::string connect_string_; std::string access_key_ = kMinioAccessKey; std::string secret_key_ = kMinioSecretKey; std::unique_ptr server_process_; + std::string scheme_ = "http"; }; MinioTestServer::MinioTestServer() : impl_(new Impl) {} @@ -69,7 +75,41 @@ std::string MinioTestServer::access_key() const { return impl_->access_key_; } std::string MinioTestServer::secret_key() const { return impl_->secret_key_; } -Status MinioTestServer::Start() { +std::string MinioTestServer::ca_dir_path() const { + return impl_->temp_dir_ca_->path().ToString(); +} + +std::string MinioTestServer::ca_file_path() const { + return impl_->temp_dir_ca_->path().ToString() + "/public.crt"; +} + +std::string MinioTestServer::scheme() const { return impl_->scheme_; } + +Status MinioTestServer::GenerateCertificateFile() { + // Create a dedicated folder for the certificate files rather than reusing the data + // folder, since there is a test case that checks whether that folder is empty. 
+ ARROW_ASSIGN_OR_RAISE(impl_->temp_dir_ca_, TemporaryDir::Make("s3fs-test-ca-")); + + ARROW_ASSIGN_OR_RAISE(auto public_crt_file, + PlatformFilename::FromString(ca_dir_path() + "/public.crt")); + ARROW_ASSIGN_OR_RAISE(auto public_cert_fd, FileOpenWritable(public_crt_file)); + ARROW_RETURN_NOT_OK(FileWrite(public_cert_fd.fd(), + reinterpret_cast(kMinioCert), + strlen(kMinioCert))); + ARROW_RETURN_NOT_OK(public_cert_fd.Close()); + + ARROW_ASSIGN_OR_RAISE(auto private_key_file, + PlatformFilename::FromString(ca_dir_path() + "/private.key")); + ARROW_ASSIGN_OR_RAISE(auto private_key_fd, FileOpenWritable(private_key_file)); + ARROW_RETURN_NOT_OK(FileWrite(private_key_fd.fd(), + reinterpret_cast(kMinioPrivateKey), + strlen(kMinioPrivateKey))); + ARROW_RETURN_NOT_OK(private_key_fd.Close()); + + return Status::OK(); +} + +Status MinioTestServer::Start(bool enable_tls) { const char* connect_str = std::getenv(kEnvConnectString); const char* access_key = std::getenv(kEnvAccessKey); const char* secret_key = std::getenv(kEnvSecretKey); @@ -88,12 +128,27 @@ Status MinioTestServer::Start() { impl_->server_process_->SetEnv("MINIO_SECRET_KEY", kMinioSecretKey); // Disable the embedded console (one less listening address to care about) impl_->server_process_->SetEnv("MINIO_BROWSER", "off"); - impl_->connect_string_ = GenerateConnectString(); - ARROW_RETURN_NOT_OK(impl_->server_process_->SetExecutable(kMinioExecutableName)); // NOTE: --quiet makes startup faster by suppressing remote version check - impl_->server_process_->SetArgs({"server", "--quiet", "--compat", "--address", - impl_->connect_string_, - impl_->temp_dir_->path().ToString()}); + std::vector minio_args({"server", "--quiet", "--compat"}); + if (enable_tls) { + ARROW_RETURN_NOT_OK(GenerateCertificateFile()); + minio_args.emplace_back("--certs-dir"); + minio_args.emplace_back(ca_dir_path()); + impl_->scheme_ = "https"; + // With TLS enabled, we need the connection hostname to match the certificate's + // subject name. 
This also constrains the actual listening IP address. + impl_->connect_string_ = GetListenAddress("localhost"); + } else { + // Without TLS enabled, we want to minimize the likelihood of address collisions + // by varying the listening IP address (note that most tests don't enable TLS). + impl_->connect_string_ = GetListenAddress(); + } + minio_args.emplace_back("--address"); + minio_args.emplace_back(impl_->connect_string_); + minio_args.emplace_back(impl_->temp_dir_->path().ToString()); + + ARROW_RETURN_NOT_OK(impl_->server_process_->SetExecutable(kMinioExecutableName)); + impl_->server_process_->SetArgs(minio_args); ARROW_RETURN_NOT_OK(impl_->server_process_->Execute()); return Status::OK(); } @@ -105,24 +160,29 @@ Status MinioTestServer::Stop() { struct MinioTestEnvironment::Impl { std::function>()> server_generator_; + bool enable_tls_; + + explicit Impl(bool enable_tls) : enable_tls_(enable_tls) {} Result> LaunchOneServer() { auto server = std::make_shared(); - RETURN_NOT_OK(server->Start()); + RETURN_NOT_OK(server->Start(enable_tls_)); return server; } }; -MinioTestEnvironment::MinioTestEnvironment() : impl_(new Impl) {} +MinioTestEnvironment::MinioTestEnvironment(bool enable_tls) + : impl_(new Impl(enable_tls)) {} MinioTestEnvironment::~MinioTestEnvironment() = default; void MinioTestEnvironment::SetUp() { auto pool = ::arrow::internal::GetCpuThreadPool(); - auto launch_one_server = []() -> Result> { + auto launch_one_server = + [enable_tls = impl_->enable_tls_]() -> Result> { auto server = std::make_shared(); - RETURN_NOT_OK(server->Start()); + RETURN_NOT_OK(server->Start(enable_tls)); return server; }; impl_->server_generator_ = [pool, launch_one_server]() { diff --git a/cpp/src/arrow/filesystem/s3_test_util.h b/cpp/src/arrow/filesystem/s3_test_util.h index e270a6e1c46..0a89a7a9d5a 100644 --- a/cpp/src/arrow/filesystem/s3_test_util.h +++ b/cpp/src/arrow/filesystem/s3_test_util.h @@ -40,7 +40,7 @@ class MinioTestServer { MinioTestServer(); 
~MinioTestServer(); - Status Start(); + Status Start(bool enable_tls = false); Status Stop(); @@ -50,7 +50,14 @@ class MinioTestServer { std::string secret_key() const; + std::string ca_dir_path() const; + + std::string ca_file_path() const; + + std::string scheme() const; + private: + Status GenerateCertificateFile(); struct Impl; std::unique_ptr impl_; }; @@ -60,7 +67,7 @@ class MinioTestServer { class MinioTestEnvironment : public ::testing::Environment { public: - MinioTestEnvironment(); + explicit MinioTestEnvironment(bool enable_tls = false); ~MinioTestEnvironment(); void SetUp() override; diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 13d6ead6ef6..ee47e1c7020 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -160,6 +160,7 @@ using internal::IsNotFound; using internal::OutcomeToResult; using internal::OutcomeToStatus; using internal::S3Backend; +using internal::SetSSECustomerKey; using internal::ToAwsString; using internal::ToURLEncodedAwsString; @@ -403,6 +404,13 @@ Result S3Options::FromUri(const Uri& uri, std::string* out_path) { } else if (kv.first == "allow_bucket_deletion") { ARROW_ASSIGN_OR_RAISE(options.allow_bucket_deletion, ::arrow::internal::ParseBoolean(kv.second)); + } else if (kv.first == "tls_ca_file_path") { + options.tls_ca_file_path = kv.second; + } else if (kv.first == "tls_ca_dir_path") { + options.tls_ca_dir_path = kv.second; + } else if (kv.first == "tls_verify_certificates") { + ARROW_ASSIGN_OR_RAISE(options.tls_verify_certificates, + ::arrow::internal::ParseBoolean(kv.second)); } else { return Status::Invalid("Unexpected query parameter in S3 URI: '", kv.first, "'"); } @@ -439,7 +447,11 @@ bool S3Options::Equals(const S3Options& other) const { background_writes == other.background_writes && allow_bucket_creation == other.allow_bucket_creation && allow_bucket_deletion == other.allow_bucket_deletion && - default_metadata_equals && GetAccessKey() == 
other.GetAccessKey() && + tls_ca_file_path == other.tls_ca_file_path && + tls_ca_dir_path == other.tls_ca_dir_path && + tls_verify_certificates == other.tls_verify_certificates && + sse_customer_key == other.sse_customer_key && default_metadata_equals && + GetAccessKey() == other.GetAccessKey() && GetSecretKey() == other.GetSecretKey() && GetSessionToken() == other.GetSessionToken()); } @@ -1125,12 +1137,17 @@ class ClientBuilder { } else { client_config_.retryStrategy = std::make_shared(); } - if (!internal::global_options.tls_ca_file_path.empty()) { + if (!options_.tls_ca_file_path.empty()) { + client_config_.caFile = ToAwsString(options_.tls_ca_file_path); + } else if (!internal::global_options.tls_ca_file_path.empty()) { client_config_.caFile = ToAwsString(internal::global_options.tls_ca_file_path); } - if (!internal::global_options.tls_ca_dir_path.empty()) { + if (!options_.tls_ca_dir_path.empty()) { + client_config_.caPath = ToAwsString(options_.tls_ca_dir_path); + } else if (!internal::global_options.tls_ca_dir_path.empty()) { client_config_.caPath = ToAwsString(internal::global_options.tls_ca_dir_path); } + client_config_.verifySSL = options_.tls_verify_certificates; // Set proxy options if provided if (!options_.proxy_options.scheme.empty()) { @@ -1292,11 +1309,14 @@ Aws::IOStreamFactory AwsWriteableStreamFactory(void* data, int64_t nbytes) { } Result GetObjectRange(Aws::S3::S3Client* client, - const S3Path& path, int64_t start, - int64_t length, void* out) { + const S3Path& path, + const std::string& sse_customer_key, + int64_t start, int64_t length, + void* out) { S3Model::GetObjectRequest req; req.SetBucket(ToAwsString(path.bucket)); req.SetKey(ToAwsString(path.key)); + RETURN_NOT_OK(SetSSECustomerKey(&req, sse_customer_key)); req.SetRange(ToAwsString(FormatRange(start, length))); req.SetResponseStreamFactory(AwsWriteableStreamFactory(out, length)); return OutcomeToResult("GetObject", client->GetObject(req)); @@ -1433,11 +1453,13 @@ bool 
IsDirectory(std::string_view key, const S3Model::HeadObjectResult& result) class ObjectInputFile final : public io::RandomAccessFile { public: ObjectInputFile(std::shared_ptr holder, const io::IOContext& io_context, - const S3Path& path, int64_t size = kNoSize) + const S3Path& path, int64_t size = kNoSize, + const std::string& sse_customer_key = "") : holder_(std::move(holder)), io_context_(io_context), path_(path), - content_length_(size) {} + content_length_(size), + sse_customer_key_(sse_customer_key) {} Status Init() { // Issue a HEAD Object to get the content-length and ensure any @@ -1450,6 +1472,7 @@ class ObjectInputFile final : public io::RandomAccessFile { S3Model::HeadObjectRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); + RETURN_NOT_OK(SetSSECustomerKey(&req, sse_customer_key_)); ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); auto outcome = client_lock.Move()->HeadObject(req); @@ -1534,9 +1557,9 @@ class ObjectInputFile final : public io::RandomAccessFile { // Read the desired range of bytes ARROW_ASSIGN_OR_RAISE(auto client_lock, holder_->Lock()); - ARROW_ASSIGN_OR_RAISE( - S3Model::GetObjectResult result, - GetObjectRange(client_lock.get(), path_, position, nbytes, out)); + ARROW_ASSIGN_OR_RAISE(S3Model::GetObjectResult result, + GetObjectRange(client_lock.get(), path_, sse_customer_key_, + position, nbytes, out)); auto& stream = result.GetBody(); stream.ignore(nbytes); @@ -1584,6 +1607,7 @@ class ObjectInputFile final : public io::RandomAccessFile { int64_t pos_ = 0; int64_t content_length_ = kNoSize; std::shared_ptr metadata_; + std::string sse_customer_key_; }; // Upload size per part. 
While AWS and Minio support different sizes for each @@ -1620,7 +1644,8 @@ class ObjectOutputStream final : public io::OutputStream { metadata_(metadata), default_metadata_(options.default_metadata), background_writes_(options.background_writes), - allow_delayed_open_(options.allow_delayed_open) {} + allow_delayed_open_(options.allow_delayed_open), + sse_customer_key_(options.sse_customer_key) {} ~ObjectOutputStream() override { // For compliance with the rest of the IO stack, Close rather than Abort, @@ -1668,6 +1693,7 @@ class ObjectOutputStream final : public io::OutputStream { S3Model::CreateMultipartUploadRequest req; req.SetBucket(ToAwsString(path_.bucket)); req.SetKey(ToAwsString(path_.key)); + RETURN_NOT_OK(SetSSECustomerKey(&req, sse_customer_key_)); RETURN_NOT_OK(SetMetadataInRequest(&req)); auto outcome = client_lock.Move()->CreateMultipartUpload(req); @@ -1771,6 +1797,7 @@ class ObjectOutputStream final : public io::OutputStream { req.SetKey(ToAwsString(path_.key)); req.SetUploadId(multipart_upload_id_); req.SetMultipartUpload(std::move(completed_upload)); + RETURN_NOT_OK(SetSSECustomerKey(&req, sse_customer_key_)); auto outcome = client_lock.Move()->CompleteMultipartUploadWithErrorFixup(std::move(req)); @@ -1958,6 +1985,7 @@ class ObjectOutputStream final : public io::OutputStream { req.SetKey(ToAwsString(path_.key)); req.SetBody(std::make_shared(data, nbytes)); req.SetContentLength(nbytes); + RETURN_NOT_OK(SetSSECustomerKey(&req, sse_customer_key_)); if (!background_writes_) { req.SetBody(std::make_shared(data, nbytes)); @@ -2171,6 +2199,7 @@ class ObjectOutputStream final : public io::OutputStream { Future<> pending_uploads_completed = Future<>::MakeFinished(Status::OK()); }; std::shared_ptr upload_state_; + std::string sse_customer_key_; }; // This function assumes info->path() is already set @@ -2339,6 +2368,17 @@ class S3FileSystem::Impl : public std::enable_shared_from_this(holder_, fs->io_context(), path); + auto ptr = std::make_shared(holder_, 
fs->io_context(), path, kNoSize, + fs->options().sse_customer_key); RETURN_NOT_OK(ptr->Init()); return ptr; } @@ -3002,8 +3043,8 @@ class S3FileSystem::Impl : public std::enable_shared_from_this(holder_, fs->io_context(), path, info.size()); + auto ptr = std::make_shared( + holder_, fs->io_context(), path, info.size(), fs->options().sse_customer_key); RETURN_NOT_OK(ptr->Init()); return ptr; } diff --git a/cpp/src/arrow/filesystem/s3fs.h b/cpp/src/arrow/filesystem/s3fs.h index 85d5ff8fed5..ac6342f26a3 100644 --- a/cpp/src/arrow/filesystem/s3fs.h +++ b/cpp/src/arrow/filesystem/s3fs.h @@ -196,6 +196,37 @@ struct ARROW_EXPORT S3Options { /// delay between retries. std::shared_ptr retry_strategy; + /// Optional customer-provided key for server-side encryption (SSE-C). + /// + /// This should be the 32-byte AES-256 key, unencoded. + std::string sse_customer_key; + + /// Optional path to a single PEM file holding all TLS CA certificates + /// + /// If empty, global filesystem options will be used (see FileSystemGlobalOptions); + /// if the corresponding global filesystem option is also empty, the underlying + /// TLS library's defaults will be used. + /// + /// Note this option may be ignored on some systems (Windows, macOS). + std::string tls_ca_file_path; + + /// Optional path to a directory holding TLS CA certificates + /// + /// The given directory should contain CA certificates as individual PEM files + /// named according to the OpenSSL "hashed" format. + /// + /// If empty, global filesystem options will be used (see FileSystemGlobalOptions); + /// if the corresponding global filesystem option is also empty, the underlying + /// TLS library's defaults will be used. + /// + /// Note this option may be ignored on some systems (Windows, macOS). + std::string tls_ca_dir_path; + + /// Whether to verify the S3 endpoint's TLS certificate + /// + /// This option applies if the scheme is "https". 
+ bool tls_verify_certificates = true; + S3Options(); /// Configure with the default AWS credentials provider chain. diff --git a/cpp/src/arrow/filesystem/s3fs_benchmark.cc b/cpp/src/arrow/filesystem/s3fs_benchmark.cc index 21216429639..b7b6dda6419 100644 --- a/cpp/src/arrow/filesystem/s3fs_benchmark.cc +++ b/cpp/src/arrow/filesystem/s3fs_benchmark.cc @@ -61,7 +61,7 @@ class MinioFixture : public benchmark::Fixture { public: void SetUp(const ::benchmark::State& state) override { minio_.reset(new MinioTestServer()); - ASSERT_OK(minio_->Start()); + ASSERT_OK(minio_->Start(/*enable_tls=*/false)); const char* region_str = std::getenv(kEnvAwsRegion); if (region_str) { @@ -110,7 +110,7 @@ class MinioFixture : public benchmark::Fixture { void MakeFileSystem() { options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key()); - options_.scheme = "http"; + options_.scheme = minio_->scheme(); if (!region_.empty()) { options_.region = region_; } diff --git a/cpp/src/arrow/filesystem/s3fs_test.cc b/cpp/src/arrow/filesystem/s3fs_test.cc index 43091aaa986..3082ecb7843 100644 --- a/cpp/src/arrow/filesystem/s3fs_test.cc +++ b/cpp/src/arrow/filesystem/s3fs_test.cc @@ -71,6 +71,12 @@ #include "arrow/util/range.h" #include "arrow/util/string.h" +// TLS tests require the ability to set a custom CA file when initiating S3 client +// connections, which the AWS SDK currently only supports on Linux. 
+#if defined(__linux__) +# define ENABLE_TLS_TESTS +#endif // Linux + namespace arrow { namespace fs { @@ -80,6 +86,7 @@ using ::arrow::internal::ToChars; using ::arrow::internal::Zip; using ::arrow::util::UriEscape; +using ::arrow::fs::internal::CalculateSSECustomerKeyMD5; using ::arrow::fs::internal::ConnectRetryStrategy; using ::arrow::fs::internal::ErrorToStatus; using ::arrow::fs::internal::OutcomeToStatus; @@ -94,8 +101,15 @@ ::testing::Environment* s3_env = ::testing::AddGlobalTestEnvironment(new S3Envir ::testing::Environment* minio_env = ::testing::AddGlobalTestEnvironment(new MinioTestEnvironment); -MinioTestEnvironment* GetMinioEnv() { - return ::arrow::internal::checked_cast(minio_env); +::testing::Environment* minio_env_https = + ::testing::AddGlobalTestEnvironment(new MinioTestEnvironment(/*enable_tls=*/true)); + +MinioTestEnvironment* GetMinioEnv(bool enable_tls) { + if (enable_tls) { + return ::arrow::internal::checked_cast(minio_env_https); + } else { + return ::arrow::internal::checked_cast(minio_env); + } } class ShortRetryStrategy : public S3RetryStrategy { @@ -202,10 +216,15 @@ class S3TestMixin : public AwsTestMixin { protected: Status InitServerAndClient() { - ARROW_ASSIGN_OR_RAISE(minio_, GetMinioEnv()->GetOneServer()); + ARROW_ASSIGN_OR_RAISE(minio_, GetMinioEnv(enable_tls_)->GetOneServer()); client_config_.reset(new Aws::Client::ClientConfiguration()); client_config_->endpointOverride = ToAwsString(minio_->connect_string()); - client_config_->scheme = Aws::Http::Scheme::HTTP; + if (minio_->scheme() == "https") { + client_config_->scheme = Aws::Http::Scheme::HTTPS; + client_config_->caFile = ToAwsString(minio_->ca_file_path()); + } else { + client_config_->scheme = Aws::Http::Scheme::HTTP; + } client_config_->retryStrategy = std::make_shared(kRetryInterval, kMaxRetryDuration); credentials_ = {ToAwsString(minio_->access_key()), ToAwsString(minio_->secret_key())}; @@ -224,6 +243,11 @@ class S3TestMixin : public AwsTestMixin { std::unique_ptr 
client_config_; Aws::Auth::AWSCredentials credentials_; std::unique_ptr client_; + // Use plain HTTP by default, as this allows us to listen on different loopback + // addresses and thus minimize the risk of address reuse (HTTPS requires the + // hostname to match the certificate's subject name, constraining us to a + // single loopback address). + bool enable_tls_ = false; }; void AssertGetObject(Aws::S3::Model::GetObjectResult& result, @@ -249,6 +273,27 @@ void AssertObjectContents(Aws::S3::S3Client* client, const std::string& bucket, AssertGetObject(result, expected); } +//////////////////////////////////////////////////////////////////////////// +// Misc tests + +class InternalsTest : public AwsTestMixin {}; + +TEST_F(InternalsTest, CalculateSSECustomerKeyMD5) { + ASSERT_RAISES(Invalid, CalculateSSECustomerKeyMD5("")); // invalid length + ASSERT_RAISES(Invalid, + CalculateSSECustomerKeyMD5( + "1234567890123456789012345678901234567890")); // invalid length + // valid case, with some non-ASCII character and a null byte in the sse_customer_key + char sse_customer_key[32] = {}; + sse_customer_key[0] = '\x40'; // '@' character + sse_customer_key[1] = '\0'; // null byte + sse_customer_key[2] = '\xFF'; // non-ASCII + sse_customer_key[31] = '\xFA'; // non-ASCII + std::string sse_customer_key_string(sse_customer_key, sizeof(sse_customer_key)); + ASSERT_OK_AND_ASSIGN(auto md5, CalculateSSECustomerKeyMD5(sse_customer_key_string)) + ASSERT_EQ(md5, "97FTa6lj0hE7lshKdBy61g=="); // valid case +} + //////////////////////////////////////////////////////////////////////////// // S3Options tests @@ -300,6 +345,17 @@ TEST_F(S3OptionsTest, FromUri) { ASSERT_EQ(options.scheme, "http"); ASSERT_EQ(options.endpoint_override, "localhost"); ASSERT_EQ(path, "mybucket/foo/bar"); + ASSERT_EQ(options.tls_verify_certificates, true); + + // Explicit tls related configuration + ASSERT_OK_AND_ASSIGN( + options, + 
S3Options::FromUri("s3://mybucket/foo/bar/?tls_ca_dir_path=/test&tls_ca_file_path=/" + "test/test.pem&tls_verify_certificates=false", + &path)); + ASSERT_EQ(options.tls_ca_dir_path, "/test"); + ASSERT_EQ(options.tls_ca_file_path, "/test/test.pem"); + ASSERT_EQ(options.tls_verify_certificates, false); // Missing bucket name ASSERT_RAISES(Invalid, S3Options::FromUri("s3:///foo/bar/", &path)); @@ -443,6 +499,9 @@ class TestS3FS : public S3TestMixin { // Most tests will create buckets options_.allow_bucket_creation = true; options_.allow_bucket_deletion = true; + if (enable_tls_) { + options_.tls_ca_file_path = minio_->ca_file_path(); + } MakeFileSystem(); // Set up test bucket { @@ -532,7 +591,7 @@ class TestS3FS : public S3TestMixin { Result> MakeNewFileSystem( io::IOContext io_context = io::default_io_context()) { options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key()); - options_.scheme = "http"; + options_.scheme = minio_->scheme(); options_.endpoint_override = minio_->connect_string(); if (!options_.retry_strategy) { options_.retry_strategy = std::make_shared(); @@ -1298,6 +1357,82 @@ TEST_F(TestS3FS, OpenInputFile) { ASSERT_RAISES(IOError, file->Seek(10)); } +// Minio only allows Server Side Encryption on HTTPS client connections. 
+#ifdef ENABLE_TLS_TESTS +class TestS3FSHTTPS : public TestS3FS { + public: + void SetUp() override { + enable_tls_ = true; + TestS3FS::SetUp(); + } +}; + +TEST_F(TestS3FSHTTPS, SSECustomerKeyMatch) { + // normal write/read with correct SSE-C key + std::shared_ptr stream; + options_.sse_customer_key = "12345678123456781234567812345678"; + for (const auto& allow_delayed_open : {false, true}) { + ARROW_SCOPED_TRACE("allow_delayed_open = ", allow_delayed_open); + options_.allow_delayed_open = allow_delayed_open; + MakeFileSystem(); + ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile_with_sse_c")); + ASSERT_OK(stream->Write("some")); + ASSERT_OK(stream->Close()); + ASSERT_OK_AND_ASSIGN(auto file, fs_->OpenInputFile("bucket/newfile_with_sse_c")); + ASSERT_OK_AND_ASSIGN(auto buf, file->Read(5)); + AssertBufferEqual(*buf, "some"); + ASSERT_OK(RestoreTestBucket()); + } +} + +TEST_F(TestS3FSHTTPS, SSECustomerKeyMismatch) { + std::shared_ptr stream; + for (const auto& allow_delayed_open : {false, true}) { + ARROW_SCOPED_TRACE("allow_delayed_open = ", allow_delayed_open); + options_.allow_delayed_open = allow_delayed_open; + options_.sse_customer_key = "12345678123456781234567812345678"; + MakeFileSystem(); + ASSERT_OK_AND_ASSIGN(stream, fs_->OpenOutputStream("bucket/newfile_with_sse_c")); + ASSERT_OK(stream->Write("some")); + ASSERT_OK(stream->Close()); + options_.sse_customer_key = "87654321876543218765432187654321"; + MakeFileSystem(); + ASSERT_RAISES(IOError, fs_->OpenInputFile("bucket/newfile_with_sse_c")); + ASSERT_OK(RestoreTestBucket()); + } +} + +TEST_F(TestS3FSHTTPS, SSECustomerKeyMissing) { + std::shared_ptr stream; + for (const auto& allow_delayed_open : {false, true}) { + ARROW_SCOPED_TRACE("allow_delayed_open = ", allow_delayed_open); + options_.allow_delayed_open = allow_delayed_open; + options_.sse_customer_key = "12345678123456781234567812345678"; + MakeFileSystem(); + ASSERT_OK_AND_ASSIGN(stream, 
fs_->OpenOutputStream("bucket/newfile_with_sse_c")); + ASSERT_OK(stream->Write("some")); + ASSERT_OK(stream->Close()); + + options_.sse_customer_key = {}; + MakeFileSystem(); + ASSERT_RAISES(IOError, fs_->OpenInputFile("bucket/newfile_with_sse_c")); + ASSERT_OK(RestoreTestBucket()); + } +} + +TEST_F(TestS3FSHTTPS, SSECustomerKeyCopyFile) { + ASSERT_OK_AND_ASSIGN(auto stream, fs_->OpenOutputStream("bucket/newfile_with_sse_c")); + ASSERT_OK(stream->Write("some")); + ASSERT_OK(stream->Close()); + ASSERT_OK(fs_->CopyFile("bucket/newfile_with_sse_c", "bucket/copied_with_sse_c")); + + ASSERT_OK_AND_ASSIGN(auto file, fs_->OpenInputFile("bucket/copied_with_sse_c")); + ASSERT_OK_AND_ASSIGN(auto buf, file->Read(5)); + AssertBufferEqual(*buf, "some"); + ASSERT_OK(RestoreTestBucket()); +} +#endif // ENABLE_TLS_TESTS + struct S3OptionsTestParameters { bool background_writes{false}; bool allow_delayed_open{false}; @@ -1420,7 +1555,8 @@ TEST_F(TestS3FS, FileSystemFromUri) { std::stringstream ss; ss << "s3://" << minio_->access_key() << ":" << minio_->secret_key() << "@bucket/somedir/subdir/subfile" - << "?scheme=http&endpoint_override=" << UriEscape(minio_->connect_string()); + << "?scheme=" << minio_->scheme() + << "&endpoint_override=" << UriEscape(minio_->connect_string()); std::string path; ASSERT_OK_AND_ASSIGN(auto fs, FileSystemFromUri(ss.str(), &path)); @@ -1522,7 +1658,7 @@ class TestS3FSGeneric : public S3TestMixin, public GenericFileSystemTest { } options_.ConfigureAccessKey(minio_->access_key(), minio_->secret_key()); - options_.scheme = "http"; + options_.scheme = minio_->scheme(); options_.endpoint_override = minio_->connect_string(); options_.retry_strategy = std::make_shared(); ASSERT_OK_AND_ASSIGN(s3fs_, S3FileSystem::Make(options_)); diff --git a/cpp/src/arrow/testing/util.cc b/cpp/src/arrow/testing/util.cc index 7bef9f7d475..e5e53801df9 100644 --- a/cpp/src/arrow/testing/util.cc +++ b/cpp/src/arrow/testing/util.cc @@ -206,6 +206,12 @@ std::string 
GetListenAddress() { return ss.str(); } +std::string GetListenAddress(const std::string& host) { + std::stringstream ss; + ss << host << ":" << GetListenPort(); + return ss.str(); +} + const std::vector>& all_dictionary_index_types() { static std::vector> types = { int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()}; diff --git a/cpp/src/arrow/testing/util.h b/cpp/src/arrow/testing/util.h index b4b2785a362..8cc28a8b073 100644 --- a/cpp/src/arrow/testing/util.h +++ b/cpp/src/arrow/testing/util.h @@ -128,6 +128,10 @@ ARROW_TESTING_EXPORT int GetListenPort(); // port conflicts. ARROW_TESTING_EXPORT std::string GetListenAddress(); +// Get a "host:port" to listen on. Unlike GetListenAddress(), this overload uses the +// host passed in by the caller. +ARROW_TESTING_EXPORT std::string GetListenAddress(const std::string& host); + ARROW_TESTING_EXPORT const std::vector>& all_dictionary_index_types();