diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 64ad10050d..4621b5f3bf 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -877,12 +877,17 @@ def upsert(
         # get list of rows that exist so we don't have to load the entire target table
         matched_predicate = upsert_util.create_match_filter(df, join_cols)
 
+        # When ``when_matched_update_all=False`` the consumer loop below
+        # only ever reads ``join_cols`` off each destination batch.
+        selected_fields: tuple[str, ...] = ("*",) if when_matched_update_all else tuple(join_cols)
+
         # We must use Transaction.table_metadata for the scan. This includes all uncommitted - but relevant - changes.
 
         matched_iceberg_record_batches_scan = DataScan(
             table_metadata=self.table_metadata,
             io=self._table.io,
             row_filter=matched_predicate,
+            selected_fields=selected_fields,
             case_sensitive=case_sensitive,
         )
 
diff --git a/tests/table/test_upsert.py b/tests/table/test_upsert.py
index 08f90c6600..83444c0c0e 100644
--- a/tests/table/test_upsert.py
+++ b/tests/table/test_upsert.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from pathlib import PosixPath
+from typing import Any
 
 import pyarrow as pa
 import pytest
@@ -888,3 +889,63 @@ def test_upsert_snapshot_properties(catalog: Catalog) -> None:
     for snapshot in snapshots[initial_snapshot_count:]:
         assert snapshot.summary is not None
         assert snapshot.summary.additional_properties.get("test_prop") == "test_value"
+
+
+def test_upsert_narrows_destination_scan_projection_to_join_cols(
+    catalog: Catalog,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    """Check that upsert projects the destination scan down to ``join_cols``.
+
+    With ``when_matched_update_all=False``, ``Transaction.upsert`` passes
+    ``selected_fields=tuple(join_cols)`` to its destination ``DataScan``.
+    The insert-on-no-match branch only reads ``join_cols`` from each
+    destination batch (to feed ``create_match_filter``), so projection
+    at the scan boundary lets the parquet reader skip wide non-key
+    columns. The ``("*",)`` fallback on the ``=True`` branch is
+    exercised by the rest of this module — ``get_rows_to_update``'s
+    value-drift detection would silently break if it ever regressed.
+    """
+    import functools
+
+    from pyiceberg.table import DataScan
+
+    identifier = "default.test_upsert_narrows_projection"
+    _drop_table(catalog, identifier)
+    table = catalog.create_table(
+        identifier,
+        schema=Schema(
+            NestedField(1, "id", IntegerType(), required=True),
+            NestedField(2, "payload", StringType(), required=True),
+        ),
+    )
+    arrow_schema = pa.schema([pa.field("id", pa.int32(), nullable=False), pa.field("payload", pa.string(), nullable=False)])
+    table.append(pa.Table.from_pylist([{"id": 1, "payload": "a"}], schema=arrow_schema))
+
+    # Spy on ``DataScan.__init__`` to capture each constructed scan's
+    # ``selected_fields``. ``functools.wraps`` preserves the original
+    # signature so ``DataScan.update()``'s reflective parameter lookup
+    # (used inside ``use_ref``) still resolves correctly.
+    captured: list[tuple[str, ...] | None] = []
+    original_init = DataScan.__init__
+
+    @functools.wraps(original_init)
+    def _spy(self: DataScan, *args: Any, **kwargs: Any) -> None:
+        original_init(self, *args, **kwargs)
+        captured.append(kwargs.get("selected_fields"))
+
+    monkeypatch.setattr(DataScan, "__init__", _spy)
+
+    table.upsert(
+        df=pa.Table.from_pylist(
+            [{"id": 1, "payload": "a-new"}, {"id": 2, "payload": "b"}],
+            schema=arrow_schema,
+        ),
+        join_cols=["id"],
+        when_matched_update_all=False,
+    )
+
+    assert captured, "upsert path constructed no DataScan — projection contract regression"
+    assert all(sf == ("id",) for sf in captured), (
+        f"expected every DataScan during upsert to use selected_fields=('id',); got {captured}"
+    )