diff --git a/cloud_pipelines_backend/launchers/google_kubernetes_launchers.py b/cloud_pipelines_backend/launchers/google_kubernetes_launchers.py index 6bbb5510..1e4ee460 100644 --- a/cloud_pipelines_backend/launchers/google_kubernetes_launchers.py +++ b/cloud_pipelines_backend/launchers/google_kubernetes_launchers.py @@ -2,7 +2,7 @@ from kubernetes import client as k8s_client_lib -from cloud_pipelines.orchestration.storage_providers import google_cloud_storage +from cloud_pipelines_backend.storage_providers import patched_google_cloud_storage from . import kubernetes_launchers @@ -44,7 +44,7 @@ def __init__( pod_labels=pod_labels, pod_annotations={"gke-gcsfuse/volumes": "true"} | (pod_annotations or {}), pod_postprocessor=final_pod_postporocessor, - _storage_provider=google_cloud_storage.GoogleCloudStorageProvider( + _storage_provider=patched_google_cloud_storage.PatchedGoogleCloudStorageProvider( gcs_client ), _create_volume_and_volume_mount=kubernetes_launchers._create_volume_and_volume_mount_google_cloud_storage, @@ -85,7 +85,7 @@ def __init__( pod_labels=pod_labels, pod_annotations={"gke-gcsfuse/volumes": "true"} | (pod_annotations or {}), pod_postprocessor=final_pod_postporocessor, - _storage_provider=google_cloud_storage.GoogleCloudStorageProvider( + _storage_provider=patched_google_cloud_storage.PatchedGoogleCloudStorageProvider( gcs_client ), _create_volume_and_volume_mount=kubernetes_launchers._create_volume_and_volume_mount_google_cloud_storage, diff --git a/cloud_pipelines_backend/launchers/kubernetes_launchers.py b/cloud_pipelines_backend/launchers/kubernetes_launchers.py index 6c481baa..c433a47a 100644 --- a/cloud_pipelines_backend/launchers/kubernetes_launchers.py +++ b/cloud_pipelines_backend/launchers/kubernetes_launchers.py @@ -717,7 +717,9 @@ def __init__( pod_postprocessors.append(pod_postprocessor) final_pod_postporocessor = _create_pod_postprocessor_stack(pod_postprocessors) - from cloud_pipelines.orchestration.storage_providers import google_cloud_storage + from cloud_pipelines_backend.storage_providers import ( + patched_google_cloud_storage, + ) super().__init__( namespace=namespace, @@ -725,7 +727,7 @@ def __init__( api_client=api_client, request_timeout=request_timeout, pod_name_prefix=pod_name_prefix, - _storage_provider=google_cloud_storage.GoogleCloudStorageProvider( + _storage_provider=patched_google_cloud_storage.PatchedGoogleCloudStorageProvider( gcs_client ), pod_labels=pod_labels, diff --git a/cloud_pipelines_backend/storage_providers/patched_google_cloud_storage.py b/cloud_pipelines_backend/storage_providers/patched_google_cloud_storage.py new file mode 100644 index 00000000..09476624 --- /dev/null +++ b/cloud_pipelines_backend/storage_providers/patched_google_cloud_storage.py @@ -0,0 +1,94 @@ +"""Patched GCS storage provider that handles composite objects missing MD5 hashes. + +GCS composite objects (created via parallel uploads or compose operations) do not +have an MD5 hash. The upstream SDK crashes with `AttributeError: 'NoneType' object +has no attribute 'encode'` when it tries to read `blob.md5_hash` for these objects. + +This subclass overrides `_get_info_from_uri` to fall back to a unique timestamp +when MD5 is missing, which forces a cache miss for any downstream task that consumes +these artifacts. See: https://cloud.google.com/storage/docs/composite-objects +""" + +import base64 +import datetime +import logging + +from cloud_pipelines.orchestration.storage_providers import google_cloud_storage +from cloud_pipelines.orchestration.storage_providers import interfaces + +_LOGGER = logging.getLogger(name=__name__) + + +def _blob_hash( + *, + blob, +) -> dict[str, str]: + """Return a hash dict for a GCS blob, handling composite objects gracefully. + + For normal objects: returns {"md5": ""}. + For composite objects (md5_hash is None): returns {"md5": "no_md5_"} + which produces a unique, non-repeating value that forces a cache miss. + + The key must always be "md5" because the upstream SDK's _make_data_info_for_dir + assumes all files in a directory share the same hash key names and uses hashlib.new() + with that key name. + """ + if blob.md5_hash is not None: + return { + "md5": base64.decodebytes(blob.md5_hash.encode("ascii")).hex(), + } + + _LOGGER.warning( + f"Blob {blob.name} is a composite object (component_count={blob.component_count}) " + f"with no MD5 hash. Using timestamp fallback — downstream tasks will not use caching for this artifact." + ) + timestamp = datetime.datetime.now(tz=datetime.timezone.utc).isoformat() + return { + "md5": f"no_md5_{timestamp}", + } + + +class PatchedGoogleCloudStorageProvider( + google_cloud_storage.GoogleCloudStorageProvider, +): + """GCS provider that gracefully handles composite objects with no MD5 hash. + + NOTE: _get_info_from_uri is copied verbatim from the upstream SDK + (cloud-pipelines==0.26.3.12, google_cloud_storage.py lines 142-179) + with the only change being: blob.md5_hash.encode("ascii") replaced + by _blob_hash(blob=blob) to handle None md5_hash on composite objects. + If the upstream SDK is upgraded, this override must be kept in sync. + """ + + def _get_info_from_uri(self, uri: str) -> interfaces.DataInfo: + from google.cloud import storage + + blob_or_dir = storage.Blob.from_string(uri=uri, client=self._client) + if blob_or_dir.exists(): + blob = blob_or_dir + blob.reload() + return interfaces.DataInfo( + total_size=blob.size, + is_dir=False, + hashes=_blob_hash(blob=blob), + ) + + dir_prefix = blob_or_dir.name.rstrip("/") + "/" + file_info_list = [] + for blob in self._client.list_blobs( + bucket_or_name=blob_or_dir.bucket, + prefix=dir_prefix, + ): + blob.reload() + assert blob.name.startswith(dir_prefix) + relative_source_blob_name = blob.name[len(dir_prefix) :] + file_info_list.append( + interfaces._FileInfo( + path=relative_source_blob_name, + size=blob.size, + hashes=_blob_hash(blob=blob), + ) + ) + data_info = interfaces._make_data_info_for_dir(file_info_list) + data_info._file_info_list = file_info_list + return data_info diff --git a/tests/test_patched_google_cloud_storage.py b/tests/test_patched_google_cloud_storage.py new file mode 100644 index 00000000..8446f820 --- /dev/null +++ b/tests/test_patched_google_cloud_storage.py @@ -0,0 +1,128 @@ +import datetime +import sys +from unittest import mock + +import pytest + +# google-cloud-storage is not installed in the dev venv — mock the import chain +# so the module can be imported without the actual GCS SDK. +_mock_gcs_module = mock.MagicMock() +sys.modules.setdefault("google.cloud.storage", _mock_gcs_module) +sys.modules.setdefault("google.cloud", mock.MagicMock(storage=_mock_gcs_module)) + +from cloud_pipelines_backend.storage_providers import patched_google_cloud_storage + + +def _make_blob( + *, + name: str = "artifacts/test/output.txt", + size: int = 1024, + md5_hash: str | None = "1B2M2Y8AsgTpgAmY7PhCfg==", + component_count: int | None = None, +) -> mock.MagicMock: + blob = mock.MagicMock() + blob.name = name + blob.size = size + blob.md5_hash = md5_hash + blob.component_count = component_count + return blob + + +class TestBlobHash: + def test_normal_blob_returns_md5(self) -> None: + blob = _make_blob(md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==") + result = patched_google_cloud_storage._blob_hash(blob=blob) + + assert "md5" in result + assert result["md5"] == "d41d8cd98f00b204e9800998ecf8427e" + + def test_composite_blob_returns_timestamp_under_md5_key(self) -> None: + blob = _make_blob(md5_hash=None, component_count=2) + result = patched_google_cloud_storage._blob_hash(blob=blob) + + assert "md5" in result + assert result["md5"].startswith("no_md5_") + datetime.datetime.fromisoformat(result["md5"].removeprefix("no_md5_")) + + def test_composite_blob_timestamps_are_unique(self) -> None: + blob = _make_blob(md5_hash=None, component_count=3) + result_1 = patched_google_cloud_storage._blob_hash(blob=blob) + result_2 = patched_google_cloud_storage._blob_hash(blob=blob) + + assert result_1["md5"] != result_2["md5"] + + def test_composite_blob_logs_warning(self) -> None: + blob = _make_blob( + name="artifacts/test/model.bin", + md5_hash=None, + component_count=5, + ) + with mock.patch.object( + patched_google_cloud_storage._LOGGER, "warning" + ) as mock_warn: + patched_google_cloud_storage._blob_hash(blob=blob) + + mock_warn.assert_called_once() + call_args = mock_warn.call_args[0][0] + assert "artifacts/test/model.bin" in call_args + assert "component_count=5" in call_args + + +class TestPatchedProvider: + def _make_provider( + self, + ) -> patched_google_cloud_storage.PatchedGoogleCloudStorageProvider: + provider = patched_google_cloud_storage.PatchedGoogleCloudStorageProvider( + client=mock.MagicMock(), + ) + return provider + + def test_single_file_with_md5(self) -> None: + provider = self._make_provider() + blob = _make_blob(md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==") + blob.exists.return_value = True + + _mock_gcs_module.Blob.from_string.return_value = blob + result = provider._get_info_from_uri("gs://bucket/file.txt") + + assert result.is_dir is False + assert result.total_size == 1024 + assert result.hashes["md5"] == "d41d8cd98f00b204e9800998ecf8427e" + + def test_single_file_composite_no_md5(self) -> None: + provider = self._make_provider() + blob = _make_blob(md5_hash=None, component_count=2) + blob.exists.return_value = True + + _mock_gcs_module.Blob.from_string.return_value = blob + result = provider._get_info_from_uri("gs://bucket/file.txt") + + assert result.is_dir is False + assert result.hashes["md5"].startswith("no_md5_") + + def test_directory_with_mixed_blobs(self) -> None: + provider = self._make_provider() + + dir_blob = mock.MagicMock() + dir_blob.exists.return_value = False + dir_blob.name = "artifacts/output" + dir_blob.bucket = "bucket" + + normal_blob = _make_blob( + name="artifacts/output/data.csv", + size=500, + md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==", + ) + composite_blob = _make_blob( + name="artifacts/output/model.bin", + size=2000, + md5_hash=None, + component_count=2, + ) + + _mock_gcs_module.Blob.from_string.return_value = dir_blob + provider._client.list_blobs.return_value = [normal_blob, composite_blob] + result = provider._get_info_from_uri("gs://bucket/artifacts/output") + + assert result.is_dir is True + assert result.total_size == 2500