# Largest object size in bytes (5 GiB) that ``_cp_file`` still copies with a
# single basic copy request; bigger objects go through a multipart copy.
MANAGED_COPY_THRESHOLD = 5 * 2 ** 30


async def _copy_etag_preserved(self, path1, path2, size, total_parts, **kwargs):
    """Copy a multipart object between S3 locations, preserving its ETag.

    The destination is written as a multipart upload whose parts have the
    same sizes as the corresponding parts of the source; reusing the source
    part layout is what keeps the resulting ETag identical.

    Parameters
    ----------
    path1 : str
        Source path; also passed verbatim as ``CopySource``.
    path2 : str
        Destination path.
    size : int
        Total size of the source object in bytes.
    total_parts : int
        Number of parts the source object was uploaded with.
    **kwargs
        Forwarded to ``create_multipart_upload`` for the destination.

    Raises
    ------
    Exception
        Whatever ``_call_s3`` raises. If a part copy or the completion call
        fails, the pending multipart upload is aborted before re-raising.
    """
    bucket1, key1, version1 = self.split_path(path1)
    bucket2, key2, _ = self.split_path(path2)

    # Ask for the part layout of the requested source version, not of
    # whatever version currently happens to be the latest one.
    source_version = {"VersionId": version1} if version1 else {}

    # Inspect the source before creating the upload, so that a failing
    # lookup cannot leave an orphaned multipart upload behind.
    part_infos = await asyncio.gather(
        *[
            self._call_s3(
                self.s3.head_object,
                Bucket=bucket1,
                Key=key1,
                PartNumber=number,
                **source_version,
            )
            for number in range(1, total_parts + 1)
        ]
    )

    mpu = await self._call_s3(
        self.s3.create_multipart_upload, Bucket=bucket2, Key=key2, **kwargs
    )
    upload_id = mpu["UploadId"]

    try:
        parts = []
        brange_first = 0
        # Part copies are issued one after another, in part order.
        for number, part_info in enumerate(part_infos, 1):
            part_size = part_info["ContentLength"]
            # Byte ranges are inclusive; never reach past the final byte.
            brange_last = min(brange_first + part_size, size) - 1

            part = await self._call_s3(
                self.s3.upload_part_copy,
                Bucket=bucket2,
                Key=key2,
                PartNumber=number,
                UploadId=upload_id,
                CopySource=path1,
                CopySourceRange="bytes=%i-%i" % (brange_first, brange_last),
            )
            parts.append(
                {"PartNumber": number, "ETag": part["CopyPartResult"]["ETag"]}
            )
            brange_first += part_size

        await self._call_s3(
            self.s3.complete_multipart_upload,
            Bucket=bucket2,
            Key=key2,
            UploadId=upload_id,
            MultipartUpload={"Parts": parts},
        )
    except Exception:
        # Do not leave a half-finished upload (and its parts) behind.
        await self._call_s3(
            self.s3.abort_multipart_upload,
            Bucket=bucket2,
            Key=key2,
            UploadId=upload_id,
        )
        raise

    self.invalidate_cache(path2)
async def _cp_file(self, path1, path2, preserve_etag=None, **kwargs):
    """Copy file between locations on S3.

    Parameters
    ----------
    path1 : str
        Source path; any protocol prefix is stripped before use.
    path2 : str
        Destination path, passed on unchanged to the copy routine.
    preserve_etag : bool
        Whether to preserve the etag while copying. If the file was uploaded
        as a single part, its etag is equivalent to the md5 hash of the
        file, hence the etag is always preserved. But if the file was
        uploaded in multiple parts, this option will try to reproduce the
        same multipart upload while copying and so preserve the generated
        etag.
    **kwargs
        Forwarded to whichever copy routine ends up being used.
    """
    path1 = self._strip_protocol(path1)
    bucket, key, vers = self.split_path(path1)

    info = await self._info(path1, bucket, key, version_id=vers)
    size = info["size"]

    # The etag is only inspected when the caller asked for it to be kept,
    # so a plain copy keeps working for info dicts that carry no etag.
    parts_suffix = ""
    if preserve_etag:
        # A multipart etag looks like "<hash>-<number of parts>".
        etag = (info.get("ETag") or "").strip('"')
        _, _, parts_suffix = etag.partition("-")

    if parts_suffix:
        await self._copy_etag_preserved(
            path1, path2, size, total_parts=int(parts_suffix), **kwargs
        )
    elif size <= MANAGED_COPY_THRESHOLD:
        # simple copy allowed for <5GB
        await self._copy_basic(path1, path2, **kwargs)
    else:
        # if preserve_etag is true, either the file was uploaded in
        # multiple parts or its size is below the threshold
        assert not preserve_etag, (
            "cannot preserve the etag of %r: it is not a multipart etag "
            "and the object is too large for a basic copy" % (path1,)
        )

        # serial multipart copy
        await self._copy_managed(path1, path2, size, **kwargs)
def test_s3fs_etag_preserving_multipart_copy(monkeypatch, s3):
    """A multipart object keeps its ETag only when copied with preserve_etag."""
    # Lower the threshold so the managed (multipart) copy path is exercised
    # without having to create giant objects in memory.
    monkeypatch.setattr(s3fs.core, "MANAGED_COPY_THRESHOLD", 5 * 2 ** 20)

    source = test_bucket_name + "/test/multipart-upload.txt"
    target = test_bucket_name + "/test/multipart-upload-copy.txt"

    # Five writes, each slightly larger than the block size by a random
    # amount.
    with s3.open(source, "wb", block_size=5 * 2 ** 21) as stream:
        for _ in range(5):
            payload_len = stream.blocksize + random.randrange(200)
            stream.write(b"b" * payload_len)

    source_info = s3.info(source)

    # A normal copy() uses a block size of 5GB, so the etag must change.
    s3.copy(source, target)
    plain_copy_info = s3.info(target)
    s3.rm(target)
    assert source_info["ETag"] != plain_copy_info["ETag"]

    # An etag preserving copy() sizes each destination part after the
    # matching part of the source, so the etag must stay the same.
    s3.copy(source, target, preserve_etag=True)
    preserved_copy_info = s3.info(target)
    s3.rm(target)
    assert source_info["ETag"] == preserved_copy_info["ETag"]

    s3.rm(source)