diff --git a/.cspell.json b/.cspell.json index 0697ccc4edd..b72ea55deb2 100644 --- a/.cspell.json +++ b/.cspell.json @@ -261,6 +261,7 @@ "ossl", "ccrng", "KEYWRAP", + "HKDF", "NVME", // EC2 "IMDS", diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h index 0d062be1e00..a5e1505b014 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h @@ -79,6 +79,9 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; }; void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; } + + std::shared_ptr GetChecksumHash() const { return m_checksumHash; } + void SetChecksumHash(std::shared_ptr hash) { m_checksumHash = hash; } private: int m_partId = 0; @@ -93,6 +96,7 @@ namespace Aws std::atomic m_downloadBuffer; bool m_lastPart = false; Aws::String m_checksum; + std::shared_ptr m_checksumHash; }; using PartPointer = std::shared_ptr< PartState >; @@ -389,6 +393,15 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; } void SetChecksum(const Aws::String& checksum) { this->m_checksum = checksum; } + void SetPartChecksum(int partId, const Aws::String& checksum, uint64_t size) { + m_partChecksums[partId] = std::make_pair(checksum, size); + } + std::pair GetPartChecksum(int partId) const { + auto it = m_partChecksums.find(partId); + return it != m_partChecksums.end() ? it->second : std::make_pair("", 0); + } + const Aws::Map>& GetPartChecksums() const { return m_partChecksums; } + private: void CleanupDownloadStream(); @@ -430,6 +443,8 @@ namespace Aws mutable std::condition_variable m_waitUntilFinishedSignal; mutable std::mutex m_getterSetterLock; Aws::String m_checksum; + // Map of part number to Hash instance for multipart download checksum validation + Aws::Map> m_partChecksums; }; AWS_TRANSFER_API Aws::OStream& operator << (Aws::OStream& s, TransferStatus status); diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h index a4b5580fd6e..725f14c1219 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h @@ -144,6 +144,13 @@ namespace Aws * upload. Defaults to CRC64-NVME. */ Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC64NVME; + + /** + * Enable checksum validation for downloads. When enabled, checksums will be + * calculated during download and validated against S3 response headers. + * Defaults to true. + */ + bool validateChecksums = true; }; /** diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp index 996e427e114..c5874501bcc 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -51,6 +52,28 @@ namespace Aws } } + + + template + static Aws::String GetChecksumFromResult(const ResultT& result, S3::Model::ChecksumAlgorithm algorithm) { + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) { + return result.GetChecksumCRC32(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { + return result.GetChecksumCRC32C(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC64NVME) { + return result.GetChecksumCRC64NVME(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) { + return result.GetChecksumSHA1(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) { + return result.GetChecksumSHA256(); + } + return ""; + } + struct TransferHandleAsyncContext : public Aws::Client::AsyncCallerContext { std::shared_ptr handle; @@ -664,26 +687,7 @@ namespace Aws { if (handle->ShouldContinue()) { - partState->SetChecksum([&]() -> Aws::String { - if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) - { - return outcome.GetResult().GetChecksumCRC32(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) - { - return outcome.GetResult().GetChecksumCRC32C(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1) - { - return outcome.GetResult().GetChecksumSHA1(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256) - { - return outcome.GetResult().GetChecksumSHA256(); - } - //Return empty checksum for not set. - return ""; - }()); + partState->SetChecksum(GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm)); handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId() << " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: [" @@ -938,6 +942,7 @@ namespace Aws handle->SetContentType(getObjectOutcome.GetResult().GetContentType()); handle->ChangePartToCompleted(partState, getObjectOutcome.GetResult().GetETag()); getObjectOutcome.GetResult().GetBody().flush(); + handle->UpdateStatus(TransferStatus::COMPLETED); } else @@ -1074,6 +1079,12 @@ namespace Aws { partState->SetDownloadBuffer(buffer); + // Initialize checksum Hash for this part if validation is enabled + if (m_transferConfig.validateChecksums) + { + handle->SetPartChecksum(partState->GetPartId(), partState->GetChecksum(), partState->GetSizeInBytes()); + } + auto getObjectRangeRequest = m_transferConfig.getObjectTemplate; getObjectRangeRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); getObjectRangeRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); @@ -1239,6 +1250,66 @@ namespace Aws { if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize()) { + if (m_transferConfig.validateChecksums) { + auto checksumType = outcome.GetResult().GetChecksumType(); + if (checksumType == S3::Model::ChecksumType::FULL_OBJECT) { + Aws::String expectedChecksum = GetChecksumFromResult(outcome.GetResult(),m_transferConfig.checksumAlgorithm); + if (!expectedChecksum.empty()) { + auto combinedChecksum = 0ULL; + bool isCRC64 = (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC64NVME); + + for (auto& partChecksum : handle->GetPartChecksums()) { + Aws::String checksumStr = partChecksum.second.first; + uint64_t partSize = partChecksum.second.second; + + auto decoded = Aws::Utils::HashingUtils::Base64Decode(checksumStr); + + if (combinedChecksum == 0) { + if (isCRC64) { + combinedChecksum = *reinterpret_cast(decoded.GetUnderlyingData()); + } + else { + combinedChecksum = *reinterpret_cast(decoded.GetUnderlyingData()); + } + } + else { + if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) { + auto partCrc = *reinterpret_cast(decoded.GetUnderlyingData()); + combinedChecksum = Aws::Crt::Checksum::CombineCRC32(static_cast(combinedChecksum), partCrc, partSize); + } else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C){ + auto partCrc = *reinterpret_cast(decoded.GetUnderlyingData()); + combinedChecksum = Aws::Crt::Checksum::CombineCRC32C(static_cast(combinedChecksum), partCrc, partSize); + } else if (isCRC64) { + auto partCrc = *reinterpret_cast(decoded.GetUnderlyingData()); + combinedChecksum = Aws::Crt::Checksum::CombineCRC64NVME(combinedChecksum, partCrc, partSize); + } + } + } + + Aws::Utils::ByteBuffer checksumBuffer(isCRC64 ? 8 : 4); + if (isCRC64) { + *reinterpret_cast(checksumBuffer.GetUnderlyingData()) = combinedChecksum; + } else { + *reinterpret_cast(checksumBuffer.GetUnderlyingData()) = static_cast(combinedChecksum); + } + Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(checksumBuffer); + + if (calculatedChecksum != expectedChecksum) { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Full-object checksum mismatch. Expected: " << expectedChecksum + << ", Calculated: " << calculatedChecksum); + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ChecksumMismatch", + "Full-object checksum validation failed", + false); + handle->SetError(error); + handle->UpdateStatus(TransferStatus::FAILED); + TriggerErrorCallback(handle, error); + return; + } + } + } + } outcome.GetResult().GetBody().flush(); handle->UpdateStatus(TransferStatus::COMPLETED); }