Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from kubernetes import client as k8s_client_lib

from cloud_pipelines.orchestration.storage_providers import google_cloud_storage
from cloud_pipelines_backend.storage_providers import patched_google_cloud_storage

from . import kubernetes_launchers

Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
pod_labels=pod_labels,
pod_annotations={"gke-gcsfuse/volumes": "true"} | (pod_annotations or {}),
pod_postprocessor=final_pod_postporocessor,
_storage_provider=google_cloud_storage.GoogleCloudStorageProvider(
_storage_provider=patched_google_cloud_storage.PatchedGoogleCloudStorageProvider(
gcs_client
),
_create_volume_and_volume_mount=kubernetes_launchers._create_volume_and_volume_mount_google_cloud_storage,
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
pod_labels=pod_labels,
pod_annotations={"gke-gcsfuse/volumes": "true"} | (pod_annotations or {}),
pod_postprocessor=final_pod_postporocessor,
_storage_provider=google_cloud_storage.GoogleCloudStorageProvider(
_storage_provider=patched_google_cloud_storage.PatchedGoogleCloudStorageProvider(
gcs_client
),
_create_volume_and_volume_mount=kubernetes_launchers._create_volume_and_volume_mount_google_cloud_storage,
Expand Down
6 changes: 4 additions & 2 deletions cloud_pipelines_backend/launchers/kubernetes_launchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,15 +717,17 @@ def __init__(
pod_postprocessors.append(pod_postprocessor)
final_pod_postporocessor = _create_pod_postprocessor_stack(pod_postprocessors)

from cloud_pipelines.orchestration.storage_providers import google_cloud_storage
from cloud_pipelines_backend.storage_providers import (
patched_google_cloud_storage,
)

super().__init__(
namespace=namespace,
service_account_name=service_account_name,
api_client=api_client,
request_timeout=request_timeout,
pod_name_prefix=pod_name_prefix,
_storage_provider=google_cloud_storage.GoogleCloudStorageProvider(
_storage_provider=patched_google_cloud_storage.PatchedGoogleCloudStorageProvider(
gcs_client
),
pod_labels=pod_labels,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Patched GCS storage provider that handles composite objects missing MD5 hashes.

GCS composite objects (created via parallel uploads or compose operations) do not
have an MD5 hash. The upstream SDK crashes with `AttributeError: 'NoneType' object
has no attribute 'encode'` when it tries to read `blob.md5_hash` for these objects.

This subclass overrides `_get_info_from_uri` to fall back to a unique timestamp
when MD5 is missing, which forces a cache miss for any downstream task that consumes
these artifacts. See: https://cloud.google.com/storage/docs/composite-objects
"""

import base64
import datetime
import logging

from cloud_pipelines.orchestration.storage_providers import google_cloud_storage
from cloud_pipelines.orchestration.storage_providers import interfaces

_LOGGER = logging.getLogger(name=__name__)


def _blob_hash(
    *,
    blob,
) -> dict[str, str]:
    """Build the hash mapping for a single GCS blob, tolerating a missing MD5.

    A blob that carries an MD5 yields {"md5": "<hex digest>"}.
    A blob whose md5_hash is None (composite objects) yields
    {"md5": "no_md5_<ISO timestamp>"}, taken from the current UTC time and
    intended to force a cache miss for downstream consumers.

    The key is always "md5": the upstream SDK's _make_data_info_for_dir
    assumes every file in a directory shares the same hash key names and
    passes that key name to hashlib.new().
    """
    encoded_md5 = blob.md5_hash

    if encoded_md5 is None:
        _LOGGER.warning(
            f"Blob {blob.name} is a composite object (component_count={blob.component_count}) "
            f"with no MD5 hash. Using timestamp fallback — downstream tasks will not use caching for this artifact."
        )
        now_utc = datetime.datetime.now(tz=datetime.timezone.utc)
        return {"md5": f"no_md5_{now_utc.isoformat()}"}

    # The blob exposes its MD5 base64-encoded; callers expect a hex digest.
    digest = base64.decodebytes(encoded_md5.encode("ascii"))
    return {"md5": digest.hex()}


class PatchedGoogleCloudStorageProvider(
    google_cloud_storage.GoogleCloudStorageProvider,
):
    """GCS storage provider that tolerates blobs reporting no MD5 hash.

    NOTE: `_get_info_from_uri` mirrors the upstream SDK implementation
    (cloud-pipelines==0.26.3.12, google_cloud_storage.py lines 142-179).
    The single behavioral difference is that blob hashes are produced by
    `_blob_hash(blob=blob)` instead of `blob.md5_hash.encode("ascii")`, so a
    composite object whose md5_hash is None no longer raises.
    When the upstream SDK is upgraded, re-check this override against it.
    """

    def _get_info_from_uri(self, uri: str) -> interfaces.DataInfo:
        """Return size, kind and hashes for the file or directory at `uri`.

        If a blob exists at exactly `uri`, it is described as a single file.
        Otherwise `uri` is treated as a directory prefix and every blob listed
        under that prefix contributes one file entry.
        """
        from google.cloud import storage

        target = storage.Blob.from_string(uri=uri, client=self._client)

        if target.exists():
            # Reload the blob before reading its size and hash.
            target.reload()
            return interfaces.DataInfo(
                total_size=target.size,
                is_dir=False,
                hashes=_blob_hash(blob=target),
            )

        # No blob at the exact path: list everything below "<name>/".
        prefix = target.name.rstrip("/") + "/"
        prefix_length = len(prefix)
        entries = []
        listing = self._client.list_blobs(
            bucket_or_name=target.bucket,
            prefix=prefix,
        )
        for member in listing:
            member.reload()
            assert member.name.startswith(prefix)
            entry = interfaces._FileInfo(
                # Path relative to the directory prefix.
                path=member.name[prefix_length:],
                size=member.size,
                hashes=_blob_hash(blob=member),
            )
            entries.append(entry)

        dir_info = interfaces._make_data_info_for_dir(entries)
        # Keep the per-file entries available on the returned object.
        dir_info._file_info_list = entries
        return dir_info
128 changes: 128 additions & 0 deletions tests/test_patched_google_cloud_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import datetime
import sys
from unittest import mock

import pytest

# google-cloud-storage is not installed in the dev venv, so register stand-in
# modules in sys.modules before the module under test is imported. This lets
# `from google.cloud import storage` resolve without the actual GCS SDK.
# NOTE: setdefault() leaves an already-registered module untouched, so these
# stand-ins only take effect if the real modules have not been imported yet.
_mock_gcs_module = mock.MagicMock()
sys.modules.setdefault("google.cloud.storage", _mock_gcs_module)
sys.modules.setdefault("google.cloud", mock.MagicMock(storage=_mock_gcs_module))

from cloud_pipelines_backend.storage_providers import patched_google_cloud_storage


def _make_blob(
*,
name: str = "artifacts/test/output.txt",
size: int = 1024,
md5_hash: str | None = "1B2M2Y8AsgTpgAmY7PhCfg==",
component_count: int | None = None,
) -> mock.MagicMock:
blob = mock.MagicMock()
blob.name = name
blob.size = size
blob.md5_hash = md5_hash
blob.component_count = component_count
return blob


class TestBlobHash:
    """Unit tests for `_blob_hash` on blobs with and without an MD5 hash."""

    def test_normal_blob_returns_md5(self) -> None:
        blob = _make_blob(md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==")
        result = patched_google_cloud_storage._blob_hash(blob=blob)

        assert "md5" in result
        assert result["md5"] == "d41d8cd98f00b204e9800998ecf8427e"

    def test_composite_blob_returns_timestamp_under_md5_key(self) -> None:
        blob = _make_blob(md5_hash=None, component_count=2)
        result = patched_google_cloud_storage._blob_hash(blob=blob)

        assert "md5" in result
        assert result["md5"].startswith("no_md5_")
        # fromisoformat raises ValueError when the suffix is not a timestamp.
        parsed = datetime.datetime.fromisoformat(
            result["md5"].removeprefix("no_md5_")
        )
        # The fallback is taken with tz=UTC, so it must be timezone-aware.
        assert parsed.tzinfo is not None

    def test_composite_blob_timestamps_are_unique(self) -> None:
        blob = _make_blob(md5_hash=None, component_count=3)
        first = datetime.datetime(2024, 1, 1, tzinfo=datetime.timezone.utc)
        second = first + datetime.timedelta(microseconds=1)

        # Drive the clock explicitly rather than comparing two back-to-back
        # wall-clock reads, which can coincide on coarse-resolution clocks.
        # Only the name `datetime` inside the module under test is replaced.
        with mock.patch.object(
            patched_google_cloud_storage, "datetime"
        ) as fake_datetime_module:
            fake_datetime_module.datetime.now.side_effect = [first, second]
            result_1 = patched_google_cloud_storage._blob_hash(blob=blob)
            result_2 = patched_google_cloud_storage._blob_hash(blob=blob)

        # Each call must read the clock again instead of reusing a value.
        assert fake_datetime_module.datetime.now.call_count == 2
        assert result_1["md5"] == f"no_md5_{first.isoformat()}"
        assert result_2["md5"] == f"no_md5_{second.isoformat()}"
        assert result_1["md5"] != result_2["md5"]

    def test_composite_blob_logs_warning(self) -> None:
        blob = _make_blob(
            name="artifacts/test/model.bin",
            md5_hash=None,
            component_count=5,
        )
        with mock.patch.object(
            patched_google_cloud_storage._LOGGER, "warning"
        ) as mock_warn:
            patched_google_cloud_storage._blob_hash(blob=blob)

        mock_warn.assert_called_once()
        # The message is the first positional argument of the warning call.
        call_args = mock_warn.call_args[0][0]
        assert "artifacts/test/model.bin" in call_args
        assert "component_count=5" in call_args


class TestPatchedProvider:
    """Tests for `PatchedGoogleCloudStorageProvider._get_info_from_uri`."""

    def _make_provider(
        self,
    ) -> patched_google_cloud_storage.PatchedGoogleCloudStorageProvider:
        """Build a provider whose GCS client is a MagicMock."""
        return patched_google_cloud_storage.PatchedGoogleCloudStorageProvider(
            client=mock.MagicMock(),
        )

    def _fake_storage(self, *, from_string_result: mock.MagicMock):
        """Return a context manager installing a fresh fake `google.cloud.storage`.

        Inside the context, `from google.cloud import storage` resolves to a
        new MagicMock whose `Blob.from_string` returns `from_string_result`.
        Using a per-test fake keeps tests from sharing return values and works
        even if a real `google.cloud.storage` is already in sys.modules.
        sys.modules is restored when the context exits.
        """
        fake_storage = mock.MagicMock()
        fake_storage.Blob.from_string.return_value = from_string_result
        return mock.patch.dict(
            sys.modules,
            {
                "google.cloud": mock.MagicMock(storage=fake_storage),
                "google.cloud.storage": fake_storage,
            },
        )

    def test_single_file_with_md5(self) -> None:
        provider = self._make_provider()
        blob = _make_blob(md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==")
        blob.exists.return_value = True

        with self._fake_storage(from_string_result=blob):
            result = provider._get_info_from_uri("gs://bucket/file.txt")

        assert result.is_dir is False
        assert result.total_size == 1024
        assert result.hashes["md5"] == "d41d8cd98f00b204e9800998ecf8427e"

    def test_single_file_composite_no_md5(self) -> None:
        provider = self._make_provider()
        blob = _make_blob(md5_hash=None, component_count=2)
        blob.exists.return_value = True

        with self._fake_storage(from_string_result=blob):
            result = provider._get_info_from_uri("gs://bucket/file.txt")

        assert result.is_dir is False
        assert result.hashes["md5"].startswith("no_md5_")

    def test_directory_with_mixed_blobs(self) -> None:
        provider = self._make_provider()

        # exists() is False, so the URI is handled as a directory prefix.
        dir_blob = mock.MagicMock()
        dir_blob.exists.return_value = False
        dir_blob.name = "artifacts/output"
        dir_blob.bucket = "bucket"

        normal_blob = _make_blob(
            name="artifacts/output/data.csv",
            size=500,
            md5_hash="1B2M2Y8AsgTpgAmY7PhCfg==",
        )
        composite_blob = _make_blob(
            name="artifacts/output/model.bin",
            size=2000,
            md5_hash=None,
            component_count=2,
        )
        provider._client.list_blobs.return_value = [normal_blob, composite_blob]

        with self._fake_storage(from_string_result=dir_blob):
            result = provider._get_info_from_uri("gs://bucket/artifacts/output")

        assert result.is_dir is True
        assert result.total_size == 2500
        # The listing must be scoped to the directory prefix with a trailing "/".
        provider._client.list_blobs.assert_called_once_with(
            bucket_or_name="bucket",
            prefix="artifacts/output/",
        )
        # One entry per listed blob is attached to the returned info object.
        assert len(result._file_info_list) == 2
Loading