diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 4ec7a73afe..c4c4003db0 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -149,7 +149,7 @@ from pyiceberg.table.name_mapping import NameMapping, apply_name_mapping from pyiceberg.table.puffin import PuffinFile from pyiceberg.transforms import IdentityTransform, TruncateTransform -from pyiceberg.typedef import EMPTY_DICT, Properties, Record, TableVersion +from pyiceberg.typedef import EMPTY_DICT, ArrowStreamExportable, Properties, Record, TableVersion from pyiceberg.types import ( BinaryType, BooleanType, @@ -2680,30 +2680,45 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[list[ """Bin-pack ``tbl`` into groups of RecordBatches, each ~``target_file_size``. Note: - ``target_file_size`` is measured in **uncompressed in-memory** Arrow bytes - (``Table.nbytes`` / ``RecordBatch.nbytes``), not compressed on-disk Parquet - bytes. The resulting Parquet file after compression (zstd by default, - plus dictionary/RLE encoding) is typically 3-10× smaller than - ``target_file_size``. This is a coarse proxy for the spec-defined + ``target_file_size`` is measured in **uncompressed in-memory** Arrow + bytes, not compressed on-disk Parquet bytes. The size estimate uses + ``nbytes`` when available and falls back to referenced buffer size for + Arrow view types that do not support ``nbytes``. The resulting Parquet + file after compression (zstd by default, plus dictionary/RLE encoding) + is typically 3-10× smaller than ``target_file_size``. This is a coarse + proxy for the spec-defined ``write.target-file-size-bytes`` and will be tightened to true on-disk bytes once the writer is switched to a rolling-``ParquetWriter`` with ``OutputStream.tell()`` (#2998). """ from pyiceberg.utils.bin_packing import PackingIterator - avg_row_size_bytes = tbl.nbytes / tbl.num_rows + avg_row_size_bytes = _arrow_data_size(tbl) / tbl.num_rows target_rows_per_file = max(1, int(target_file_size / avg_row_size_bytes)) batches = tbl.to_batches(max_chunksize=target_rows_per_file) bin_packed_record_batches = PackingIterator( items=batches, target_weight=target_file_size, lookback=len(batches), # ignore lookback - weight_func=lambda x: x.nbytes, + weight_func=_arrow_data_size, largest_bin_first=False, ) return bin_packed_record_batches +def _arrow_data_size(data: pa.Table | pa.RecordBatch) -> int: + """Estimate Arrow data size for writer bin-packing. + + ``nbytes`` is the better logical-size estimate, but PyArrow can raise for + view types such as ``string_view`` exported by libraries like Polars. Fall + back to total referenced buffer size so those streams can still be written. + """ + try: + return data.nbytes + except pyarrow.lib.ArrowTypeError: + return data.get_total_buffer_size() + + def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size: int) -> Iterator[list[pa.RecordBatch]]: """Microbatch a single-pass stream of RecordBatches into target-sized groups. @@ -2719,9 +2734,11 @@ def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size: Note: ``target_file_size`` is measured in **uncompressed in-memory** Arrow - bytes (``RecordBatch.nbytes``), not compressed on-disk Parquet bytes. - The resulting Parquet file after compression is typically 3-10× - smaller than ``target_file_size``. Matches the existing + bytes, not compressed on-disk Parquet bytes. The size estimate uses + ``nbytes`` when available and falls back to referenced buffer size for + Arrow view types that do not support ``nbytes``. The resulting Parquet + file after compression is typically 3-10× smaller than + ``target_file_size``. Matches the existing :func:`bin_pack_arrow_table` semantics; both will be tightened to true on-disk bytes once the writer is switched to a rolling- ``ParquetWriter`` with ``OutputStream.tell()`` (#2998). @@ -2730,7 +2747,7 @@ def bin_pack_record_batches(batches: Iterable[pa.RecordBatch], target_file_size: buffer_bytes = 0 for batch in batches: buffer.append(batch) - buffer_bytes += batch.nbytes + buffer_bytes += _arrow_data_size(batch) if buffer_bytes >= target_file_size: yield buffer buffer = [] @@ -3033,3 +3050,23 @@ def _get_field_from_arrow_table(arrow_table: pa.Table, field_path: str) -> pa.Ar field_array = arrow_table[path_parts[0]] # Navigate into the struct using the remaining path parts return pc.struct_field(field_array, path_parts[1:]) + + +def _coerce_arrow_input(df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable) -> pa.Table | pa.RecordBatchReader: + """Normalize Arrow write input to a pa.Table or pa.RecordBatchReader. + + Native pyarrow inputs pass through unchanged; any object implementing the + Arrow PyCapsule stream interface (``__arrow_c_stream__``) is imported as a + streaming RecordBatchReader. + """ + if isinstance(df, (pa.Table, pa.RecordBatchReader)): + return df + + # Any object implementing the Arrow PyCapsule stream interface. + if hasattr(df, "__arrow_c_stream__"): + return pa.RecordBatchReader.from_stream(df) + + raise ValueError( + f"Expected pa.Table, pa.RecordBatchReader, or an object implementing the " + f"Arrow PyCapsule interface (__arrow_c_stream__), got: {df!r}" + ) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 64ad10050d..9846677f4b 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -82,6 +82,7 @@ from pyiceberg.transforms import IdentityTransform from pyiceberg.typedef import ( EMPTY_DICT, + ArrowStreamExportable, IcebergBaseModel, IcebergRootModel, Identifier, @@ -452,7 +453,7 @@ def update_statistics(self) -> UpdateStatistics: def append( self, - df: pa.Table | pa.RecordBatchReader, + df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH, ) -> None: @@ -505,10 +506,9 @@ def append( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _coerce_arrow_input, _dataframe_to_data_files - if not isinstance(df, (pa.Table, pa.RecordBatchReader)): - raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}") + df = _coerce_arrow_input(df) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( @@ -598,7 +598,7 @@ def dynamic_partition_overwrite( def overwrite( self, - df: pa.Table | pa.RecordBatchReader, + df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable, overwrite_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, @@ -662,10 +662,9 @@ def overwrite( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _dataframe_to_data_files + from pyiceberg.io.pyarrow import _check_pyarrow_schema_compatible, _coerce_arrow_input, _dataframe_to_data_files - if not isinstance(df, (pa.Table, pa.RecordBatchReader)): - raise ValueError(f"Expected pa.Table or pa.RecordBatchReader, got: {df}") + df = _coerce_arrow_input(df) downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False _check_pyarrow_schema_compatible( @@ -1472,7 +1471,7 @@ def upsert( def append( self, - df: pa.Table | pa.RecordBatchReader, + df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable, snapshot_properties: dict[str, str] = EMPTY_DICT, branch: str | None = MAIN_BRANCH, ) -> None: @@ -1507,7 +1506,7 @@ def dynamic_partition_overwrite( def overwrite( self, - df: pa.Table | pa.RecordBatchReader, + df: pa.Table | pa.RecordBatchReader | ArrowStreamExportable, overwrite_filter: BooleanExpression | str = ALWAYS_TRUE, snapshot_properties: dict[str, str] = EMPTY_DICT, case_sensitive: bool = True, @@ -1716,6 +1715,10 @@ def __datafusion_table_provider__(self, session: Any | None = None) -> IcebergDa ).__datafusion_table_provider__ return provider(session) + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: + """Export this Table as an Arrow C stream (PyCapsule interface).""" + return self.scan().to_arrow_batch_reader().__arrow_c_stream__(requested_schema) + class StaticTable(Table): """Load a table directly from a metadata file (i.e., without using a catalog).""" @@ -2252,6 +2255,10 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader: batches, ).cast(target_schema) + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: + """Export this scan's result as an Arrow C stream (PyCapsule interface).""" + return self.to_arrow_batch_reader().__arrow_c_stream__(requested_schema) + def to_pandas(self, **kwargs: Any) -> pd.DataFrame: """Read a Pandas DataFrame eagerly from this Iceberg table. diff --git a/pyiceberg/typedef.py b/pyiceberg/typedef.py index 6989144ef9..e965aebfe4 100644 --- a/pyiceberg/typedef.py +++ b/pyiceberg/typedef.py @@ -112,6 +112,19 @@ def __setitem__(self, pos: int, value: Any) -> None: """Assign a value to a StructProtocol.""" +@runtime_checkable +class ArrowStreamExportable(Protocol): # pragma: no cover + """Any object implementing the Arrow PyCapsule stream interface. + + Covers pa.Table, pa.RecordBatchReader, and third-party producers + (polars, arro3, nanoarrow, ...) without depending on any of them. + """ + + @abstractmethod + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: + """Export the object as an Arrow C stream PyCapsule.""" + + class IcebergBaseModel(BaseModel): """ This class extends the Pydantic BaseModel to set default values by overriding them. diff --git a/tests/catalog/test_catalog_behaviors.py b/tests/catalog/test_catalog_behaviors.py index b859e2d541..4c94c7d3c0 100644 --- a/tests/catalog/test_catalog_behaviors.py +++ b/tests/catalog/test_catalog_behaviors.py @@ -1318,7 +1318,7 @@ def test_append_invalid_input_type_raises(catalog: Catalog) -> None: identifier = f"default.append_invalid_input_{catalog.name}" pa_table = _simple_arrow_table() tbl = catalog.create_table(identifier=identifier, schema=pa_table.schema) - with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader"): + with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"): tbl.append("not an arrow object") diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py index 1d1488255f..eb391144fd 100644 --- a/tests/integration/test_writes/test_partitioned_writes.py +++ b/tests/integration/test_writes/test_partitioned_writes.py @@ -768,7 +768,7 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non properties={"format-version": "1"}, ) - with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"): + with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"): tbl.append("not a df") diff --git a/tests/integration/test_writes/test_writes.py b/tests/integration/test_writes/test_writes.py index 609c1863bc..d32385c3fb 100644 --- a/tests/integration/test_writes/test_writes.py +++ b/tests/integration/test_writes/test_writes.py @@ -791,10 +791,10 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_ identifier = "default.arrow_data_files" tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) - with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"): + with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"): tbl.overwrite("not a df") - with pytest.raises(ValueError, match="Expected pa.Table or pa.RecordBatchReader, got: not a df"): + with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader, or an object implementing"): tbl.append("not a df") diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 2f36661a1f..7bf0ca1f36 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -2435,6 +2435,17 @@ def test_bin_pack_arrow_table_target_size_smaller_than_row(arrow_table_with_null assert sum(batch.num_rows for bin_ in bin_packed for batch in bin_) == arrow_table_with_null.num_rows +def test_bin_pack_arrow_table_with_string_view() -> None: + if not hasattr(pa, "string_view"): + pytest.skip("pyarrow does not support string_view") + + table = pa.table({"region": pa.array(["ca", "mx"], type=pa.string_view())}) + + bins = list(bin_pack_arrow_table(table, target_file_size=1)) + + assert sum(batch.num_rows for bin_ in bins for batch in bin_) == table.num_rows + + def test_bin_pack_record_batches_single_bin(arrow_table_with_null: pa.Table) -> None: batches = arrow_table_with_null.to_batches() bins = list(bin_pack_record_batches(iter(batches), target_file_size=arrow_table_with_null.nbytes * 10)) diff --git a/tests/table/test_arrow_capsule.py b/tests/table/test_arrow_capsule.py new file mode 100644 index 0000000000..d62f3207f8 --- /dev/null +++ b/tests/table/test_arrow_capsule.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for Arrow PyCapsule interface support on the read and write paths. + +Covers the input/consumer side (write methods accept any object implementing +``__arrow_c_stream__``) and the output/producer side (``Table`` and ``DataScan`` +expose ``__arrow_c_stream__`` so any Arrow consumer can ingest them). +""" + +from collections.abc import Callable +from pathlib import PosixPath +from typing import Any + +import pyarrow as pa +import pytest + +from pyiceberg.catalog.memory import InMemoryCatalog +from pyiceberg.io.pyarrow import _coerce_arrow_input +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table import Table +from pyiceberg.transforms import IdentityTransform +from pyiceberg.types import IntegerType, NestedField, StringType + +SCHEMA = Schema( + NestedField(1, "id", IntegerType(), required=False), + NestedField(2, "region", StringType(), required=False), +) +ARROW_SCHEMA = pa.schema( + [ + pa.field("id", pa.int32(), nullable=True), + pa.field("region", pa.string(), nullable=True), + ] +) +PARTITION_SPEC = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="region")) + + +class _ArrowStreamWrapper: + """A minimal third-party-style Arrow producer. + + Exposes only ``__arrow_c_stream__`` -- it is deliberately *not* a + ``pa.Table``/``pa.RecordBatchReader`` -- to stand in for libraries such as + polars or arro3 without taking a dependency on them. + """ + + def __init__(self, data: pa.Table): + self._data = data + + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: + return self._data.__arrow_c_stream__(requested_schema) + + +@pytest.fixture +def catalog(tmp_path: PosixPath) -> InMemoryCatalog: + catalog = InMemoryCatalog("test.in_memory.catalog", warehouse=tmp_path.absolute().as_posix()) + catalog.create_namespace("default") + return catalog + + +def _data(ids: list[int], regions: list[str]) -> pa.Table: + return pa.table({"id": pa.array(ids, type=pa.int32()), "region": regions}, schema=ARROW_SCHEMA) + + +def _string_view_data(ids: list[int], regions: list[str]) -> pa.Table: + if not hasattr(pa, "string_view"): + pytest.skip("pyarrow does not support string_view") + return pa.table( + {"id": pa.array(ids, type=pa.int32()), "region": pa.array(regions, type=pa.string_view())}, + schema=pa.schema( + [ + pa.field("id", pa.int32(), nullable=True), + pa.field("region", pa.string_view(), nullable=True), + ] + ), + ) + + +def _rows(table: pa.Table) -> list[dict[str, Any]]: + return sorted(table.to_pylist(), key=lambda r: r["id"]) + + +def test_coerce_arrow_input() -> None: + """Unit coverage of every branch of the coercion helper.""" + table = _data([1, 2, 3], ["us", "eu", "us"]) + + # native types pass through unchanged (identity) + assert _coerce_arrow_input(table) is table + reader = table.to_reader() + assert _coerce_arrow_input(reader) is reader + + # a foreign capsule producer is imported as a RecordBatchReader (streaming preserved) + coerced = _coerce_arrow_input(_ArrowStreamWrapper(table)) + assert isinstance(coerced, pa.RecordBatchReader) + assert coerced.read_all().num_rows == 3 + + # anything else is rejected + with pytest.raises(ValueError, match="Expected pa.Table, pa.RecordBatchReader"): + _coerce_arrow_input(object()) + + +# --------------------------------------------------------------------------- +# Input / consumer side (issue #2680) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "make_input", + [ + pytest.param(lambda d: d, id="table"), + pytest.param(lambda d: d.to_reader(), id="reader"), + pytest.param(lambda d: _ArrowStreamWrapper(d), id="capsule"), + # A capsule whose stream yields multiple batches must be fully drained. + pytest.param( + lambda d: _ArrowStreamWrapper(pa.Table.from_batches(d.to_batches(max_chunksize=1))), + id="capsule_multi_batch", + ), + ], +) +def test_append_accepts_arrow_inputs(catalog: InMemoryCatalog, make_input: Callable[[pa.Table], object]) -> None: + tbl = catalog.create_table("default.append", schema=SCHEMA) + + tbl.append(make_input(_data([1, 2, 3], ["us", "eu", "us"]))) + + assert _rows(tbl.scan().to_arrow()) == _rows(_data([1, 2, 3], ["us", "eu", "us"])) + + +def test_overwrite_accepts_arrow_capsule(catalog: InMemoryCatalog) -> None: + tbl = catalog.create_table("default.overwrite_capsule", schema=SCHEMA) + tbl.append(_data([1, 2], ["us", "eu"])) + + tbl.overwrite(_ArrowStreamWrapper(_data([9], ["jp"]))) + + assert _rows(tbl.scan().to_arrow()) == _rows(_data([9], ["jp"])) + + +def test_append_accepts_arrow_capsule_with_string_view(catalog: InMemoryCatalog) -> None: + """Regression: Polars exports string columns as string_view over PyCapsule.""" + tbl = catalog.create_table("default.append_string_view", schema=SCHEMA) + + tbl.append(_ArrowStreamWrapper(_string_view_data([10, 11], ["ca", "mx"]))) + + assert _rows(tbl.scan().to_arrow()) == _rows(_data([10, 11], ["ca", "mx"])) + + +def test_append_pa_table_to_partitioned_table(catalog: InMemoryCatalog) -> None: + """Regression: a native pa.Table on a partitioned table must take the table + (partition-splitting) path, not be coerced into a RecordBatchReader (which + only supports unpartitioned writes).""" + tbl = catalog.create_table("default.append_partitioned", schema=SCHEMA, partition_spec=PARTITION_SPEC) + + tbl.append(_data([1, 2], ["us", "eu"])) + + assert tbl.scan().to_arrow().num_rows == 2 + + +# --------------------------------------------------------------------------- +# Output / producer side (issue #1655) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "produce", + [ + pytest.param(lambda tbl: tbl, id="table"), + pytest.param(lambda tbl: tbl.scan(), id="scan"), + ], +) +def test_supports_arrow_c_stream(catalog: InMemoryCatalog, produce: Callable[[Table], object]) -> None: + tbl = catalog.create_table("default.stream", schema=SCHEMA) + tbl.append(_data([1, 2, 3], ["us", "eu", "us"])) + + # A consumer ingests the Table/DataScan directly via the PyCapsule interface. + consumed = pa.table(produce(tbl)) + + assert _rows(consumed) == _rows(tbl.scan().to_arrow()) + + +def test_scan_arrow_c_stream_respects_filter_and_projection(catalog: InMemoryCatalog) -> None: + tbl = catalog.create_table("default.scan_stream_filtered", schema=SCHEMA) + tbl.append(_data([1, 2, 3], ["us", "eu", "us"])) + + scan = tbl.scan(row_filter="region == 'us'", selected_fields=("id",)) + consumed = pa.table(scan) + + assert consumed.column_names == ["id"] + assert sorted(consumed.column("id").to_pylist()) == [1, 3] + + +def test_capsule_roundtrip_scan_into_append(catalog: InMemoryCatalog) -> None: + """The two halves compose: a scan (producer) can be appended into another + table (consumer) with no explicit pyarrow conversion in between.""" + src = catalog.create_table("default.roundtrip_src", schema=SCHEMA) + src.append(_data([1, 2, 3], ["us", "eu", "us"])) + dst = catalog.create_table("default.roundtrip_dst", schema=SCHEMA) + + dst.append(src.scan()) + + assert _rows(dst.scan().to_arrow()) == _rows(src.scan().to_arrow())