diff --git a/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py b/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py index 845770c3a215..3a782c8135eb 100644 --- a/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py +++ b/packages/google-cloud-storage/google/cloud/storage/asyncio/retry/reads_resumption_strategy.py @@ -36,7 +36,11 @@ class _DownloadState: """A helper class to track the state of a single range download.""" def __init__( - self, initial_offset: int, initial_length: int, user_buffer: IO[bytes] + self, + initial_offset: int, + initial_length: int, + user_buffer: IO[bytes], + is_full_object_read: bool = False, ): self.initial_offset = initial_offset self.initial_length = initial_length @@ -44,6 +48,10 @@ def __init__( self.bytes_written = 0 self.next_expected_offset = initial_offset self.is_complete = False + self.is_full_object_read = is_full_object_read + self.rolling_checksum = ( + google_crc32c.Checksum() if is_full_object_read else None + ) class _ReadResumptionStrategy(_BaseResumptionStrategy): @@ -90,6 +98,7 @@ def update_state_from_response( ) download_states = state["download_states"] + checksum_enabled = state.get("enable_checksum", True) for object_data_range in proto.object_data_ranges: # Ignore empty ranges or ranges for IDs not in our state @@ -125,7 +134,7 @@ def update_state_from_response( checksummed_data = object_data_range.checksummed_data data = checksummed_data.content - if checksummed_data.HasField("crc32c"): + if checksum_enabled and checksummed_data.HasField("crc32c"): server_checksum = checksummed_data.crc32c client_checksum = google_crc32c.value(data) if server_checksum != client_checksum: @@ -138,10 +147,14 @@ def update_state_from_response( # Update State & Write Data chunk_size = len(data) read_state.user_buffer.write(data) + + # Commit updates only after the write succeeds + if checksum_enabled and read_state.rolling_checksum is not None: + read_state.rolling_checksum.update(data) read_state.bytes_written += chunk_size read_state.next_expected_offset += chunk_size - # Final Byte Count Verification + # Final Byte Count & Full Object Checksum Verification if object_data_range.range_end: read_state.is_complete = True if ( @@ -154,6 +167,22 @@ def update_state_from_response( f"Expected {read_state.initial_length}, got {read_state.bytes_written}", ) + # Perform full-object checksum verification once the stream finishes. + if read_state.is_full_object_read and checksum_enabled: + full_obj_server_crc32c = state.get("full_obj_server_crc32c") + if full_obj_server_crc32c is not None: + # Use standard big-endian byte conversion to retrieve the rolling checksum value. + client_checksum = int.from_bytes( + read_state.rolling_checksum.digest(), + byteorder="big", + ) + if client_checksum != full_obj_server_crc32c: + raise DataCorruption( + response, + f"Full object checksum mismatch for read_id {read_id}. " + f"Server authoritative crc32c: {full_obj_server_crc32c}, client calculated rolling: {client_checksum}.", + ) + async def recover_state_on_failure(self, error: Exception, state: Any) -> None: """Handles BidiReadObjectRedirectedError for reads.""" routing_token, read_handle = _handle_redirect(error) diff --git a/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py b/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py index dc27cb701974..4f7849801acd 100644 --- a/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py +++ b/packages/google-cloud-storage/tests/unit/asyncio/retry/test_reads_resumption_strategy.py @@ -45,6 +45,26 @@ def test_initialization(self): self.assertEqual(state.bytes_written, 0) self.assertEqual(state.next_expected_offset, initial_offset) self.assertFalse(state.is_complete) + self.assertFalse(state.is_full_object_read) + self.assertIsNone(state.rolling_checksum) + + def test_initialization_with_full_object_read(self): + """Test that _DownloadState initializes correctly when is_full_object_read is True.""" + initial_offset = 10 + initial_length = 100 + user_buffer = io.BytesIO() + state_full = _DownloadState( + initial_offset, initial_length, user_buffer, is_full_object_read=True + ) + + self.assertEqual(state_full.initial_offset, initial_offset) + self.assertEqual(state_full.initial_length, initial_length) + self.assertEqual(state_full.user_buffer, user_buffer) + self.assertEqual(state_full.bytes_written, 0) + self.assertEqual(state_full.next_expected_offset, initial_offset) + self.assertFalse(state_full.is_complete) + self.assertTrue(state_full.is_full_object_read) + self.assertIsNotNone(state_full.rolling_checksum) class TestReadResumptionStrategy(unittest.TestCase): @@ -53,12 +73,17 @@ def setUp(self): self.state = {"download_states": {}, "read_handle": None, "routing_token": None} - def _add_download(self, read_id, offset=0, length=100, buffer=None): + def _add_download( + self, read_id, offset=0, length=100, buffer=None, is_full_object_read=False + ): """Helper to inject a download state into the correct nested location.""" if buffer is None: buffer = io.BytesIO() state = _DownloadState( - initial_offset=offset, initial_length=length, user_buffer=buffer + initial_offset=offset, + initial_length=length, + user_buffer=buffer, + is_full_object_read=is_full_object_read, ) self.state["download_states"][read_id] = state return state @@ -358,3 +383,55 @@ async def run(): # Token should remain unchanged self.assertEqual(self.state["routing_token"], "existing-token") + + def test_update_state_full_object_checksum_success(self): + """Test that full object checksum verification succeeds on range_end.""" + read_state = self._add_download( + _READ_ID, offset=0, length=9, is_full_object_read=True + ) + self.state["enable_checksum"] = True + self.state["full_obj_server_crc32c"] = google_crc32c.value(b"testdata1") + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + resp2 = self._create_response(b"data1", _READ_ID, offset=4, range_end=True) + self.strategy.update_state_from_response(resp2, self.state) + + self.assertTrue(read_state.is_complete) + self.assertEqual(read_state.bytes_written, 9) + + def test_update_state_full_object_checksum_failure(self): + """Test that full object checksum verification raises DataCorruption on mismatch at range_end.""" + self._add_download(_READ_ID, offset=0, length=9, is_full_object_read=True) + self.state["enable_checksum"] = True + self.state["full_obj_server_crc32c"] = 111111 # Wrong server checksum! + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + resp2 = self._create_response(b"data1", _READ_ID, offset=4, range_end=True) + with self.assertRaisesRegex(DataCorruption, "Full object checksum mismatch"): + self.strategy.update_state_from_response(resp2, self.state) + + def test_update_state_checksum_mismatch_ignored_when_disabled(self): + """Test that a CRC32C mismatch is ignored when enable_checksum is False.""" + self._add_download(_READ_ID) + self.state["enable_checksum"] = False + response = self._create_response(b"data", _READ_ID, offset=0, crc=999999) + + # Should NOT raise DataCorruption! + self.strategy.update_state_from_response(response, self.state) + + def test_update_state_full_object_checksum_mismatch_ignored_when_disabled(self): + """Test that a full-object CRC32C mismatch is ignored when enable_checksum is False.""" + self._add_download(_READ_ID, offset=0, length=9, is_full_object_read=True) + self.state["enable_checksum"] = False + self.state["full_obj_server_crc32c"] = 111111 # Wrong server checksum! + + resp1 = self._create_response(b"test", _READ_ID, offset=0) + self.strategy.update_state_from_response(resp1, self.state) + + resp2 = self._create_response(b"data1", _READ_ID, offset=4, range_end=True) + # Should NOT raise DataCorruption! + self.strategy.update_state_from_response(resp2, self.state)