
Commit

Calculate checksum from local file if upload optimization succeeds (#3968)

Co-authored-by: Alexei Mochalov <[email protected]>
sir-sigurd and nl0 authored Jun 18, 2024
1 parent 988bc30 commit 358feaf
Showing 5 changed files with 311 additions and 62 deletions.
144 changes: 87 additions & 57 deletions api/python/quilt3/data_transfer.py
@@ -268,6 +268,10 @@ def get_checksum_chunksize(file_size: int) -> int:
     return chunksize
 
 
+def is_mpu(file_size: int) -> bool:
+    return file_size >= CHECKSUM_MULTIPART_THRESHOLD
+
+
 _EMPTY_STRING_SHA256 = hashlib.sha256(b'').digest()
 
 
@@ -303,7 +307,7 @@ def _copy_local_file(ctx: WorkerContext, size: int, src_path: str, dest_path: st
 def _upload_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_key: str):
     s3_client = ctx.s3_client_provider.standard_client
 
-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if not is_mpu(size):
         with ReadFileChunk.from_filename(src_path, 0, size, [ctx.progress]) as fd:
             resp = s3_client.put_object(
                 Body=fd,
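For the single-part branch shown above, the end-to-end flow can be sketched with plain boto3. This is an illustrative sketch, not the quilt3 implementation: the bucket and key are placeholders, and passing ChecksumAlgorithm="SHA256" is an assumption about how the object gets its SHA-256 checksum attached.

```python
import base64
import hashlib

import boto3


def upload_small_file(path: str, bucket: str, key: str) -> str:
    """Single-part upload: S3 stores a whole-object SHA-256 that we can verify locally."""
    s3 = boto3.client("s3")
    with open(path, "rb") as fd:
        body = fd.read()
    resp = s3.put_object(Body=body, Bucket=bucket, Key=key, ChecksumAlgorithm="SHA256")
    # For non-multipart uploads the returned checksum is just base64(sha256(object)).
    local = base64.b64encode(hashlib.sha256(body).digest()).decode()
    assert resp["ChecksumSHA256"] == local
    return local
```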
@@ -460,7 +464,7 @@ def _copy_remote_file(ctx: WorkerContext, size: int, src_bucket: str, src_key: s
 
     s3_client = ctx.s3_client_provider.standard_client
 
-    if size < CHECKSUM_MULTIPART_THRESHOLD:
+    if not is_mpu(size):
         params: Dict[str, Any] = dict(
             CopySource=src_params,
             Bucket=dest_bucket,
@@ -530,43 +534,62 @@ def upload_part(i, start, end):
             ctx.run(upload_part, i, start, end)
 
 
-def _upload_or_copy_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
+def _calculate_local_checksum(path: str, size: int):
+    chunksize = get_checksum_chunksize(size)
+
+    part_hashes = []
+    for start in range(0, size, chunksize):
+        end = min(start + chunksize, size)
+        part_hashes.append(_calculate_local_part_checksum(path, start, end - start))
+
+    return _make_checksum_from_parts(part_hashes)
+
+
+def _reuse_remote_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
     # Optimization: check if the remote file already exists and has the right ETag,
     # and skip the upload.
-    if size >= UPLOAD_ETAG_OPTIMIZATION_THRESHOLD:
-        try:
-            params = dict(Bucket=dest_bucket, Key=dest_path)
-            s3_client = ctx.s3_client_provider.find_correct_client(S3Api.HEAD_OBJECT, dest_bucket, params)
-            resp = s3_client.head_object(**params, ChecksumMode='ENABLED')
-        except ClientError:
-            # Destination doesn't exist, so fall through to the normal upload.
-            pass
-        except S3NoValidClientError:
-            # S3ClientProvider can't currently distinguish between a user that has PUT but not LIST permissions and a
-            # user that has no permissions. If we can't find a valid client, proceed to the upload stage anyway.
-            pass
-        else:
-            # Check the ETag.
-            dest_size = resp['ContentLength']
-            dest_etag = resp['ETag']
-            dest_version_id = resp.get('VersionId')
-            if size == dest_size and resp.get('ServerSideEncryption') != 'aws:kms':
-                src_etag = _calculate_etag(src_path)
-                if src_etag == dest_etag:
-                    # Nothing more to do. We should not attempt to copy the object because
-                    # that would cause the "copy object to itself" error.
-                    # TODO: Check SHA256 before checking ETag?
-                    s3_checksum = resp.get('ChecksumSHA256')
-                    if s3_checksum is None:
-                        checksum = None
-                    elif '-' in s3_checksum:
-                        checksum, _ = s3_checksum.split('-', 1)
-                    else:
-                        checksum = _simple_s3_to_quilt_checksum(s3_checksum)
-                    ctx.progress(size)
-                    ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id), checksum)
-                    return  # Optimization succeeded.
+    if size < UPLOAD_ETAG_OPTIMIZATION_THRESHOLD:
+        return None
+    try:
+        params = dict(Bucket=dest_bucket, Key=dest_path)
+        s3_client = ctx.s3_client_provider.find_correct_client(S3Api.HEAD_OBJECT, dest_bucket, params)
+        resp = s3_client.head_object(**params, ChecksumMode="ENABLED")
+    except ClientError:
+        # Destination doesn't exist, so fall through to the normal upload.
+        pass
+    except S3NoValidClientError:
+        # S3ClientProvider can't currently distinguish between a user that has PUT but not LIST permissions and a
+        # user that has no permissions. If we can't find a valid client, proceed to the upload stage anyway.
+        pass
+    else:
+        dest_size = resp["ContentLength"]
+        if dest_size != size:
+            return None
+        # TODO: we could check hashes of parts, to finish faster
+        s3_checksum = resp.get("ChecksumSHA256")
+        if s3_checksum is not None:
+            if "-" in s3_checksum:
+                checksum, num_parts_str = s3_checksum.split("-", 1)
+                num_parts = int(num_parts_str)
+            else:
+                checksum = _simple_s3_to_quilt_checksum(s3_checksum)
+                num_parts = None
+            expected_num_parts = math.ceil(size / get_checksum_chunksize(size)) if is_mpu(size) else None
+            if num_parts == expected_num_parts and checksum == _calculate_local_checksum(src_path, size):
+                return resp.get("VersionId"), checksum
+        elif resp.get("ServerSideEncryption") != "aws:kms" and resp["ETag"] == _calculate_etag(src_path):
+            return resp.get("VersionId"), _calculate_local_checksum(src_path, size)
+
+    return None
+
+
+def _upload_or_reuse_file(ctx: WorkerContext, size: int, src_path: str, dest_bucket: str, dest_path: str):
+    result = _reuse_remote_file(ctx, size, src_path, dest_bucket, dest_path)
+    if result is not None:
+        dest_version_id, checksum = result
+        ctx.progress(size)
+        ctx.done(PhysicalKey(dest_bucket, dest_path, dest_version_id), checksum)
+        return  # Optimization succeeded.
     # If the optimization didn't happen, do the normal upload.
     _upload_file(ctx, size, src_path, dest_bucket, dest_path)
 
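The new reuse path boils down to: HEAD the destination with ChecksumMode enabled, and skip the upload only when the size, the part count, and the SHA-256 checksum all match what is computed from the local file, with an ETag comparison as a fallback. A rough standalone sketch of the multipart case, assuming a caller-supplied chunk size and a precomputed local chunked checksum; the single-part conversion and the ETag fallback from the diff are left out:

```python
import math

import boto3
from botocore.exceptions import ClientError


def can_reuse_remote(bucket: str, key: str, size: int, chunksize: int, local_checksum: str):
    """Return (version_id, checksum) if the destination already holds identical bytes, else None."""
    s3 = boto3.client("s3")
    try:
        resp = s3.head_object(Bucket=bucket, Key=key, ChecksumMode="ENABLED")
    except ClientError:
        return None  # destination missing or not readable: do a normal upload instead
    if resp["ContentLength"] != size:
        return None
    s3_checksum = resp.get("ChecksumSHA256")  # e.g. "<base64 digest>-13" for a 13-part upload
    if s3_checksum is None or "-" not in s3_checksum:
        return None
    checksum, num_parts = s3_checksum.split("-", 1)
    if int(num_parts) == math.ceil(size / chunksize) and checksum == local_checksum:
        return resp.get("VersionId"), checksum
    return None
```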
@@ -648,7 +671,7 @@ def done_callback(value, checksum):
                 else:
                     if dest.version_id:
                         raise ValueError("Cannot set VersionId on destination")
-                    _upload_or_copy_file(ctx, size, src.path, dest.bucket, dest.path)
+                    _upload_or_reuse_file(ctx, size, src.path, dest.bucket, dest.path)
             else:
                 if dest.is_local():
                     _download_file(ctx, size, src.bucket, src.path, src.version_id, dest.path)
@@ -701,7 +724,7 @@ def _calculate_etag(file_path):
     """
     size = pathlib.Path(file_path).stat().st_size
     with open(file_path, 'rb') as fd:
-        if size < CHECKSUM_MULTIPART_THRESHOLD:
+        if not is_mpu(size):
            contents = fd.read()
            etag = hashlib.md5(contents).hexdigest()
        else:
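_calculate_etag follows S3's ETag convention: for a single-part upload the ETag is the MD5 hex digest of the object, and for a multipart upload it is the MD5 of the concatenated binary part MD5s followed by "-<number of parts>". A minimal local sketch, assuming 8 MiB parts (boto3's default chunk size) and ignoring the quoting S3 adds around ETags in responses:

```python
import hashlib

PART_SIZE = 8 * 1024 * 1024  # assumption: default boto3 multipart chunksize


def local_etag(path: str, multipart_threshold: int = PART_SIZE) -> str:
    with open(path, "rb") as fd:
        data = fd.read()  # fine for a sketch; the real code streams large files
    if len(data) < multipart_threshold:
        return hashlib.md5(data).hexdigest()
    part_md5s = [
        hashlib.md5(data[i:i + PART_SIZE]).digest()
        for i in range(0, len(data), PART_SIZE)
    ]
    return "%s-%d" % (hashlib.md5(b"".join(part_md5s)).hexdigest(), len(part_md5s))
```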
@@ -970,6 +993,28 @@ def wrapper(*args, **kwargs):
     return wrapper
 
 
+def _calculate_local_part_checksum(src: str, offset: int, length: int, callback=None) -> bytes:
+    hash_obj = hashlib.sha256()
+    bytes_remaining = length
+    with open(src, "rb") as fd:
+        fd.seek(offset)
+        while bytes_remaining > 0:
+            chunk = fd.read(min(s3_transfer_config.io_chunksize, bytes_remaining))
+            if not chunk:
+                # Should not happen, but let's not get stuck in an infinite loop.
+                raise QuiltException("Unexpected end of file")
+            hash_obj.update(chunk)
+            if callback is not None:
+                callback(len(chunk))
+            bytes_remaining -= len(chunk)
+
+    return hash_obj.digest()
+
+
+def _make_checksum_from_parts(parts: List[bytes]) -> str:
+    return binascii.b2a_base64(hashlib.sha256(b"".join(parts)).digest(), newline=False).decode()
+
+
 @retry(stop=stop_after_attempt(MAX_FIX_HASH_RETRIES),
        wait=wait_exponential(multiplier=1, min=1, max=10),
        retry=retry_if_result(lambda results: any(r is None or isinstance(r, Exception) for r in results)),
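Together these two helpers define the chunked SHA-256 used throughout this change: hash each part, then hash the concatenation of the binary part digests and base64-encode the result. A self-contained sketch of the same construction, with the chunk size fixed at 8 MiB for illustration (quilt3 derives it from the file size via get_checksum_chunksize):

```python
import base64
import hashlib


def chunked_sha256(path: str, chunksize: int = 8 * 1024 * 1024) -> str:
    part_digests = []
    with open(path, "rb") as fd:
        while True:
            part = fd.read(chunksize)  # the real helper reads in smaller io_chunksize pieces
            if not part:
                break
            part_digests.append(hashlib.sha256(part).digest())
    # Hash of the concatenated binary part digests, base64-encoded without a trailing newline.
    return base64.b64encode(hashlib.sha256(b"".join(part_digests)).digest()).decode()
```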
@@ -990,21 +1035,10 @@ def _calculate_checksum_internal(src_list, sizes, results) -> List[bytes]:
     progress_update = with_lock(progress.update)
 
     def _process_url_part(src: PhysicalKey, offset: int, length: int):
-        hash_obj = hashlib.sha256()
-
         if src.is_local():
-            bytes_remaining = length
-            with open(src.path, 'rb') as fd:
-                fd.seek(offset)
-                while bytes_remaining > 0:
-                    chunk = fd.read(min(s3_transfer_config.io_chunksize, bytes_remaining))
-                    if not chunk:
-                        # Should not happen, but let's not get stuck in an infinite loop.
-                        raise QuiltException("Unexpected end of file")
-                    hash_obj.update(chunk)
-                    progress_update(len(chunk))
-                    bytes_remaining -= len(chunk)
+            return _calculate_local_part_checksum(src.path, offset, length, progress_update)
         else:
+            hash_obj = hashlib.sha256()
             end = offset + length - 1
             params = dict(
                 Bucket=src.bucket,
@@ -1026,7 +1060,7 @@ def _process_url_part(src: PhysicalKey, offset: int, length: int):
             except (ConnectionError, HTTPClientError, ReadTimeoutError) as ex:
                 return ex
 
-        return hash_obj.digest()
+            return hash_obj.digest()
 
     futures: List[Tuple[int, List[Future]]] = []
 
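For remote sources, each part digest comes from a ranged GET instead of a local read. A hedged sketch of that branch with plain boto3; the retry handling, client selection, and VersionId plumbing from the real code are left out:

```python
import hashlib

import boto3


def remote_part_sha256(bucket: str, key: str, offset: int, length: int) -> bytes:
    """SHA-256 of one byte range of an S3 object, streamed chunk by chunk."""
    s3 = boto3.client("s3")
    end = offset + length - 1
    resp = s3.get_object(Bucket=bucket, Key=key, Range=f"bytes={offset}-{end}")
    hash_obj = hashlib.sha256()
    for chunk in resp["Body"].iter_chunks(chunk_size=64 * 1024):
        hash_obj.update(chunk)
    return hash_obj.digest()
```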
@@ -1046,11 +1080,7 @@ def _process_url_part(src: PhysicalKey, offset: int, length: int):
         for idx, future_list in futures:
             future_results = [future.result() for future in future_list]
             exceptions = [ex for ex in future_results if isinstance(ex, Exception)]
-            if exceptions:
-                results[idx] = exceptions[0]
-            else:
-                hashes_hash = hashlib.sha256(b''.join(future_results)).digest()
-                results[idx] = binascii.b2a_base64(hashes_hash, newline=False).decode()
+            results[idx] = exceptions[0] if exceptions else _make_checksum_from_parts(future_results)
     finally:
         stopped = True
         for _, future_list in futures:
4 changes: 2 additions & 2 deletions api/python/quilt3/packages.py
@@ -1540,9 +1540,9 @@ def check_hash_conficts(latest_hash):
             new_entry.hash = dict(type=SHA256_CHUNKED_HASH_NAME, value=checksum)
             pkg._set(logical_key, new_entry)
 
-        # Needed if the files already exist in S3, but were uploaded without ChecksumAlgorithm='SHA256'.
+        # Some entries may miss hash values (e.g because of selector_fn), so we need
+        # to fix them before calculating the top hash.
         pkg._fix_sha256()
 
         top_hash = pkg._calculate_top_hash(pkg._meta, pkg.walk())
 
         if dedupe and top_hash == latest_hash:
6 changes: 4 additions & 2 deletions api/python/tests/integration/test_packages.py
@@ -1914,10 +1914,11 @@ def test_push_selector_fn_false(self):
         selector_fn = mock.MagicMock(return_value=False)
         push_manifest_mock = self.patch_s3_registry('push_manifest')
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
-        with patch('quilt3.packages.calculate_checksum', return_value=[('SHA256', "a" * 64)]):
+        with patch('quilt3.packages.calculate_checksum', return_value=["a" * 64]) as calculate_checksum_mock:
             pkg.push(pkg_name, registry=f's3://{dst_bucket}', selector_fn=selector_fn, force=True)
 
         selector_fn.assert_called_once_with(lk, pkg[lk])
+        calculate_checksum_mock.assert_called_once_with([PhysicalKey(src_bucket, src_key, src_version)], [0])
         push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])
@@ -1960,10 +1961,11 @@ def test_push_selector_fn_true(self):
         )
         push_manifest_mock = self.patch_s3_registry('push_manifest')
         self.patch_s3_registry('shorten_top_hash', return_value='7a67ff4')
-        with patch('quilt3.packages.calculate_checksum', return_value=["a" * 64]):
+        with patch('quilt3.packages.calculate_checksum', return_value=[]) as calculate_checksum_mock:
             pkg.push(pkg_name, registry=f's3://{dst_bucket}', selector_fn=selector_fn, force=True)
 
         selector_fn.assert_called_once_with(lk, pkg[lk])
+        calculate_checksum_mock.assert_called_once_with([], [])
         push_manifest_mock.assert_called_once_with(pkg_name, mock.sentinel.top_hash, ANY)
         assert Package.load(
             BytesIO(push_manifest_mock.call_args[0][2])
(Diffs for the remaining 2 changed files are not shown here.)
