Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@
"ossl",
"ccrng",
"KEYWRAP",
"HKDF",
"NVME",
// EC2
"IMDS",
Expand Down
13 changes: 13 additions & 0 deletions src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ namespace Aws

Aws::String GetChecksum() const { return m_checksum; };
void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; }

std::shared_ptr<Aws::Utils::Crypto::Hash> GetChecksumHash() const { return m_checksumHash; }
void SetChecksumHash(std::shared_ptr<Aws::Utils::Crypto::Hash> hash) { m_checksumHash = hash; }
private:

int m_partId = 0;
Expand All @@ -93,6 +96,7 @@ namespace Aws
std::atomic<unsigned char*> m_downloadBuffer;
bool m_lastPart = false;
Aws::String m_checksum;
std::shared_ptr<Aws::Utils::Crypto::Hash> m_checksumHash;
};

using PartPointer = std::shared_ptr< PartState >;
Expand Down Expand Up @@ -389,6 +393,13 @@ namespace Aws
Aws::String GetChecksum() const { return m_checksum; }
void SetChecksum(const Aws::String& checksum) { this->m_checksum = checksum; }

void SetPartChecksum(int partId, std::shared_ptr<Aws::Utils::Crypto::Hash> hash) { m_partChecksums[partId] = hash; }
std::shared_ptr<Aws::Utils::Crypto::Hash> GetPartChecksum(int partId) const {
auto it = m_partChecksums.find(partId);
return it != m_partChecksums.end() ? it->second : nullptr;
}
const Aws::Map<int, std::shared_ptr<Aws::Utils::Crypto::Hash>>& GetPartChecksums() const { return m_partChecksums; }

private:
void CleanupDownloadStream();

Expand Down Expand Up @@ -430,6 +441,8 @@ namespace Aws
mutable std::condition_variable m_waitUntilFinishedSignal;
mutable std::mutex m_getterSetterLock;
Aws::String m_checksum;
// Map of part number to Hash instance for multipart download checksum validation
Aws::Map<int, std::shared_ptr<Aws::Utils::Crypto::Hash>> m_partChecksums;
};

AWS_TRANSFER_API Aws::OStream& operator << (Aws::OStream& s, TransferStatus status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ namespace Aws
* upload. Defaults to CRC64-NVME.
*/
Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC64NVME;

/**
* Enable checksum validation for downloads. When enabled, checksums will be
* calculated during download and validated against S3 response headers.
* Defaults to true.
*/
bool validateChecksums = true;
};

/**
Expand Down
183 changes: 163 additions & 20 deletions src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/core/utils/memory/stl/AWSStringStream.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/crt/checksum/CRC.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/AbortMultipartUploadRequest.h>
#include <aws/s3/model/CompleteMultipartUploadRequest.h>
Expand Down Expand Up @@ -51,6 +52,42 @@ namespace Aws
}
}

static std::shared_ptr<Utils::Crypto::Hash> CreateHashForAlgorithm(S3::Model::ChecksumAlgorithm algorithm) {
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) {
return Aws::MakeShared<Utils::Crypto::CRC32>(CLASS_TAG);
}
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
return Aws::MakeShared<Utils::Crypto::CRC32C>(CLASS_TAG);
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) {
return Aws::MakeShared<Utils::Crypto::Sha1>(CLASS_TAG);
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) {
return Aws::MakeShared<Utils::Crypto::Sha256>(CLASS_TAG);
}
return Aws::MakeShared<Utils::Crypto::CRC64>(CLASS_TAG);
}

template <typename ResultT>
static Aws::String GetChecksumFromResult(const ResultT& result, S3::Model::ChecksumAlgorithm algorithm) {
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) {
return result.GetChecksumCRC32();
}
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
return result.GetChecksumCRC32C();
}
if (algorithm == S3::Model::ChecksumAlgorithm::CRC64NVME) {
return result.GetChecksumCRC64NVME();
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) {
return result.GetChecksumSHA1();
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) {
return result.GetChecksumSHA256();
}
return "";
}

struct TransferHandleAsyncContext : public Aws::Client::AsyncCallerContext
{
std::shared_ptr<TransferHandle> handle;
Expand Down Expand Up @@ -664,26 +701,7 @@ namespace Aws
{
if (handle->ShouldContinue())
{
partState->SetChecksum([&]() -> Aws::String {
if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32)
{
return outcome.GetResult().GetChecksumCRC32();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C)
{
return outcome.GetResult().GetChecksumCRC32C();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1)
{
return outcome.GetResult().GetChecksumSHA1();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256)
{
return outcome.GetResult().GetChecksumSHA256();
}
//Return empty checksum for not set.
return "";
}());
partState->SetChecksum(GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm));
handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag());
AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: ["
Expand Down Expand Up @@ -938,6 +956,61 @@ namespace Aws
handle->SetContentType(getObjectOutcome.GetResult().GetContentType());
handle->ChangePartToCompleted(partState, getObjectOutcome.GetResult().GetETag());
getObjectOutcome.GetResult().GetBody().flush();

// Validate checksum for single-part download by reading file
if (m_transferConfig.validateChecksums)
{
Aws::String expectedChecksum = GetChecksumFromResult(getObjectOutcome.GetResult(), m_transferConfig.checksumAlgorithm);

if (!expectedChecksum.empty() && !handle->GetTargetFilePath().empty())
{
auto hash = CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm);
Aws::IFStream fileStream(handle->GetTargetFilePath().c_str(), std::ios::binary);

if (fileStream.good())
{
const size_t bufferSize = 8192;
char buffer[bufferSize];
while (fileStream.good())
{
fileStream.read(buffer, bufferSize);
std::streamsize bytesRead = fileStream.gcount();
if (bytesRead > 0)
{
hash->Update(reinterpret_cast<unsigned char*>(buffer), static_cast<size_t>(bytesRead));
}
}
fileStream.close();

auto calculatedResult = hash->GetHash();
if (calculatedResult.IsSuccess())
{
Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(calculatedResult.GetResult());
if (calculatedChecksum != expectedChecksum)
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< "] Checksum mismatch for single-part download. Expected: "
<< expectedChecksum << ", Calculated: " << calculatedChecksum);

// Delete the corrupted file
Aws::FileSystem::RemoveFileIfExists(handle->GetTargetFilePath().c_str());

handle->ChangePartToFailed(partState);
handle->UpdateStatus(TransferStatus::FAILED);
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,
"ChecksumMismatch",
"Single-part download checksum validation failed",
false);
handle->SetError(error);
TriggerErrorCallback(handle, error);
TriggerTransferStatusUpdatedCallback(handle);
return;
}
}
}
}
}

handle->UpdateStatus(TransferStatus::COMPLETED);
}
else
Expand Down Expand Up @@ -1074,6 +1147,12 @@ namespace Aws
{
partState->SetDownloadBuffer(buffer);

// Initialize checksum Hash for this part if validation is enabled
if (m_transferConfig.validateChecksums)
{
handle->SetPartChecksum(partState->GetPartId(), CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm));
}

auto getObjectRangeRequest = m_transferConfig.getObjectTemplate;
getObjectRangeRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag);
getObjectRangeRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); });
Expand Down Expand Up @@ -1239,6 +1318,70 @@ namespace Aws
{
if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize())
{
// Combine part checksums and validate full-object checksum
if (m_transferConfig.validateChecksums)
{
Aws::String expectedChecksum = GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm);
if (!expectedChecksum.empty())
{
auto combinedChecksum = 0ULL;
bool isCRC64 = (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC64NVME);

for (auto& partChecksum : handle->GetPartChecksums())
{
int partNumber = partChecksum.first;
auto hash = partChecksum.second;

// Get part size from completed parts
auto partSize = handle->GetCompletedParts()[partNumber]->GetSizeInBytes();

auto partResult = hash->GetHash();
auto partData = partResult.GetResult();

if (combinedChecksum == 0) {
if (isCRC64) {
combinedChecksum = *reinterpret_cast<const unsigned long long*>(partData.GetUnderlyingData());
} else {
combinedChecksum = *reinterpret_cast<const unsigned int*>(partData.GetUnderlyingData());
}
} else {
if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) {
auto partCrc = *reinterpret_cast<const unsigned int*>(partData.GetUnderlyingData());
combinedChecksum = Aws::Crt::Checksum::CombineCRC32(static_cast<uint32_t>(combinedChecksum), partCrc, partSize);
} else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
auto partCrc = *reinterpret_cast<const unsigned int*>(partData.GetUnderlyingData());
combinedChecksum = Aws::Crt::Checksum::CombineCRC32C(static_cast<uint32_t>(combinedChecksum), partCrc, partSize);
} else if (isCRC64) {
auto partCrc = *reinterpret_cast<const unsigned long long*>(partData.GetUnderlyingData());
combinedChecksum = Aws::Crt::Checksum::CombineCRC64NVME(combinedChecksum, partCrc, partSize);
}
}
}

// Compare with expected checksum
Aws::Utils::ByteBuffer checksumBuffer(isCRC64 ? 8 : 4);
if (isCRC64) {
*reinterpret_cast<unsigned long long*>(checksumBuffer.GetUnderlyingData()) = combinedChecksum;
} else {
*reinterpret_cast<unsigned int*>(checksumBuffer.GetUnderlyingData()) = static_cast<unsigned int>(combinedChecksum);
}
Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(checksumBuffer);

if (calculatedChecksum != expectedChecksum) {
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< "] Full-object checksum mismatch. Expected: " << expectedChecksum
<< ", Calculated: " << calculatedChecksum);
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,
"ChecksumMismatch",
"Full-object checksum validation failed",
false);
handle->SetError(error);
handle->UpdateStatus(TransferStatus::FAILED);
TriggerErrorCallback(handle, error);
return;
}
}
}
outcome.GetResult().GetBody().flush();
handle->UpdateStatus(TransferStatus::COMPLETED);
}
Expand Down
Loading