diff --git a/native/core/src/execution/columnar_to_row.rs b/native/core/src/execution/columnar_to_row.rs index 14e115cba0..65925a0f0c 100644 --- a/native/core/src/execution/columnar_to_row.rs +++ b/native/core/src/execution/columnar_to_row.rs @@ -1028,6 +1028,17 @@ impl ColumnarToRowContext { } match (actual_type, schema_type) { + // Spark's StringType / BinaryType map to Utf8 / Binary at the JVM-declared + // schema layer, but the upstream native plan may legitimately emit the Large* + // variants (e.g. CometHashAggregate group-key promotion under + // `spark.comet.exec.useLargeDataTypes`). The downstream typed dispatch handles + // both i32 and i64 offsets natively (`TypedArray::String`/`LargeString` and + // `Binary`/`LargeBinary`), so pass the array through unchanged. The cast + // kernel would refuse a Large->small downcast whose absolute offsets exceed + // i32::MAX even when the logical slice would fit, so attempting the cast here + // is both unnecessary and incorrect. + (DataType::LargeUtf8, DataType::Utf8) + | (DataType::LargeBinary, DataType::Binary) => Ok(Arc::clone(array)), (DataType::Dictionary(_, _), schema) if !matches!(schema, DataType::Dictionary(_, _)) => { diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index f3209b0c1b..cf47f62141 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -177,36 +177,61 @@ impl ShuffleScanExec { let num_rows = batch.num_rows(); - // Extract column arrays, unpacking any dictionary-encoded columns. - // Native shuffle may dictionary-encode string/binary columns for efficiency, - // but downstream DataFusion operators expect the value types declared in the - // schema (e.g. Utf8, not Dictionary). - let columns: Vec = batch - .columns() - .iter() - .map(|col| unpack_dictionary(col)) - .collect(); - debug_assert_eq!( - columns.len(), + batch.num_columns(), data_types.len(), "Shuffle block column count mismatch: got {} but expected {}", - columns.len(), + batch.num_columns(), data_types.len() ); + // Coerce each decoded column to the catalyst-declared type: + // * unpack any dictionary-encoded columns to their value type (native shuffle + // may dictionary-encode string/binary columns for efficiency); + // * downcast LargeUtf8/LargeBinary back to Utf8/Binary when an upstream + // aggregate opted into `spark.comet.exec.useLargeDataTypes` and wrote + // Large* into the shuffle block while catalyst still declares the small + // variant. The mirror of this coercion on the write side lives in + // `SchemaAlignExec` (native/shuffle/src/schema_align.rs). + let columns: Vec = batch + .columns() + .iter() + .zip(data_types.iter()) + .map(|(col, expected)| coerce_to_declared(col, expected)) + .collect::, CometError>>()?; + Ok(InputBatch::new(columns, Some(num_rows))) }) } } -/// If `array` is dictionary-encoded, cast it to the value type. Otherwise return as-is. -fn unpack_dictionary(array: &ArrayRef) -> ArrayRef { - if let DataType::Dictionary(_, value_type) = array.data_type() { - arrow::compute::cast(array, value_type.as_ref()).expect("failed to unpack dictionary array") +/// Coerce `array` to `expected`: unpack dictionary encoding when present, then downcast +/// any remaining type drift (e.g. `LargeUtf8`/`LargeBinary` → `Utf8`/`Binary`) so the +/// column matches what catalyst declared. Returns the input unchanged when no work is +/// needed. Propagates errors from the arrow cast kernel; the caller is `get_next` which +/// already returns `Result`. +fn coerce_to_declared(array: &ArrayRef, expected: &DataType) -> Result { + // Step 1: unpack any dictionary encoding, then fall through to the type-mismatch check + // so a `Dictionary<_, LargeUtf8>` column with an expected `Utf8` type composes both + // steps rather than short-circuiting after the unpack. + let unpacked: ArrayRef = if let DataType::Dictionary(_, value_type) = array.data_type() { + arrow::compute::cast(array, value_type.as_ref()).map_err(|e| { + CometError::from(ExecutionError::DataFusionError(format!( + "failed to unpack dictionary array: {e}" + ))) + })? } else { Arc::clone(array) + }; + if unpacked.data_type() == expected { + return Ok(unpacked); } + arrow::compute::cast(&unpacked, expected).map_err(|e| { + CometError::from(ExecutionError::DataFusionError(format!( + "failed to cast shuffle-scan column from {:?} to {expected:?}: {e}", + unpacked.data_type() + ))) + }) } fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef { @@ -465,11 +490,13 @@ mod tests { ) .unwrap(); - // Feed the decoded batch through unpack_dictionary (simulating get_next) + // Feed the decoded batch through coerce_to_declared (simulating get_next) + let expected_types = [DataType::Int32, DataType::Utf8]; let columns: Vec = decoded .columns() .iter() - .map(|col| super::unpack_dictionary(col)) + .zip(expected_types.iter()) + .map(|(col, expected)| super::coerce_to_declared(col, expected).unwrap()) .collect(); let input = InputBatch::new(columns, Some(decoded.num_rows())); scan.set_input_batch(input); diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 553fe5215c..017aad3a54 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -142,7 +142,6 @@ use url::Url; // For clippy error on type_complexity. type PhyAggResult = Result, ExecutionError>; -type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; pub type PlanCreationResult = Result<(Vec, Vec, Arc), ExecutionError>; @@ -174,6 +173,32 @@ fn strip_timestamp_tz( } } +/// Promote a Utf8/Binary group-by expression to its Large* variant. +/// +/// Gated by `spark.comet.exec.useLargeDataTypes`. The promotion is an +/// offset-width-only cast (i32 → i64) that makes DataFusion's HashAggregate +/// dispatch to `ByteGroupValueBuilder::`, removing the per-task `i32::MAX` +/// (2 GiB) cap on the group-key byte buffer. Returns the wrapped expression +/// together with the original DataType, so the caller can build a matching +/// cast-back projection above the aggregate; returns `None` when the input +/// is not Utf8/Binary (no promotion needed). +fn promote_byte_group_key( + expr: Arc, + schema: &Schema, +) -> Result<(Arc, Option), ExecutionError> { + match expr.data_type(schema)? { + DataType::Utf8 => Ok(( + Arc::new(CastExpr::new(expr, DataType::LargeUtf8, None)), + Some(DataType::Utf8), + )), + DataType::Binary => Ok(( + Arc::new(CastExpr::new(expr, DataType::LargeBinary, None)), + Some(DataType::Binary), + )), + _ => Ok((expr, None)), + } +} + #[derive(Default)] pub struct BinaryExprOptions { pub is_integral_div: bool, @@ -1092,17 +1117,39 @@ impl PhysicalPlanner { let (scans, shuffle_scans, child) = self.create_plan(&children[0], inputs, partition_count)?; - let group_exprs: PhyExprResult = agg + // When `spark.comet.exec.useLargeDataTypes` is on, wrap Utf8/Binary + // group keys in a Cast to LargeUtf8/LargeBinary so DataFusion dispatches + // to ByteGroupValueBuilder:: and the per-task group-key byte buffer + // is no longer capped at i32::MAX (2 GiB). The promotion is reverted at + // the aggregate's output by a Projection below, so LargeUtf8/LargeBinary + // never leaves this operator -- keeps the FFI, JVM shuffle, and Spark + // consumer paths untouched. + let use_large = agg.use_large_data_types; + let child_schema_ref = child.schema(); + let child_schema = child_schema_ref.as_ref(); + // Per group column: `Some(original_dt)` when the column was promoted to a + // Large* variant, `None` when it was passed through as-is. Populated in + // lockstep with `group_exprs` and consumed below to build the revert + // projection. + let mut group_reverts: Vec> = + Vec::with_capacity(agg.grouping_exprs.len()); + let group_exprs: Vec<(Arc, String)> = agg .grouping_exprs .iter() .enumerate() .map(|(idx, expr)| { - self.create_expr(expr, child.schema()) - .map(|r| (r, format!("col_{idx}"))) + let raw = self.create_expr(expr, Arc::clone(&child_schema_ref))?; + let (wrapped, revert) = if use_large { + promote_byte_group_key(raw, child_schema)? + } else { + (raw, None) + }; + group_reverts.push(revert); + Ok((wrapped, format!("col_{idx}"))) }) - .collect(); - let group_by = PhysicalGroupBy::new_single(group_exprs?); - let schema = child.schema(); + .collect::, ExecutionError>>()?; + let group_by = PhysicalGroupBy::new_single(group_exprs); + let schema = Arc::clone(&child_schema_ref); let proto_mode = ProtoAggregateMode::try_from(agg.mode).map_err(|_| { ExecutionError::GeneralError(format!( @@ -1218,6 +1265,36 @@ impl PhysicalPlanner { )?, ); + // Cast promoted group columns back to their original Utf8/Binary type + // so LargeUtf8/LargeBinary never crosses the FFI boundary, JVM columnar + // shuffle, or the Spark consumer path (all of which reject Large*). + // Uses `SchemaAlignExec` (not a plain `ProjectionExec` + `CastExpr`) + // because arrow's cast kernel rejects any single Large* array whose + // value bytes exceed `i32::MAX` -- which is precisely the regime this + // flag is used in. `SchemaAlignExec` splits each batch by row ranges + // so every emitted small-offset chunk fits under 2 GiB. + let aggregate = if use_large && group_reverts.iter().any(Option::is_some) { + let agg_schema = aggregate.schema(); + let target_fields: Vec = agg_schema + .fields() + .iter() + .enumerate() + .map(|(idx, f)| { + let target_dt = group_reverts + .get(idx) + .and_then(|r| r.clone()) + .unwrap_or_else(|| f.data_type().clone()); + Field::new(f.name(), target_dt, f.is_nullable()) + .with_metadata(f.metadata().clone()) + }) + .collect(); + let target_schema: SchemaRef = Arc::new(Schema::new(target_fields)); + SchemaAlignExec::try_new_or_passthrough(aggregate, &target_schema) + .map_err(|e| ExecutionError::DataFusionError(e.to_string()))? + } else { + aggregate + }; + Ok(( scans, shuffle_scans, @@ -1599,6 +1676,7 @@ impl PhysicalPlanner { let writer_input = align_shuffle_writer_input( Arc::clone(&child.native_plan), &writer.expected_output_schema, + writer.use_large_data_types, )?; let partitioning = self.create_partitioning( @@ -3463,9 +3541,15 @@ fn convert_spark_types_to_arrow_schema( /// Wrap `child` in a `SchemaAlignExec` when its output drifts from what Spark catalyst /// declared. See . +/// +/// `use_large_data_types` is accepted for backward compatibility with older proto payloads +/// but is now a no-op: the HashAggregate reverts LargeUtf8/LargeBinary group columns to +/// their original Utf8/Binary type in its own output projection, so the shuffle writer's +/// child never emits Large* variants and the writer's expected schema stays Utf8/Binary. fn align_shuffle_writer_input( child: Arc, expected_proto: &[spark_operator::SparkStructField], + _use_large_data_types: bool, ) -> Result, ExecutionError> { if expected_proto.is_empty() { return Ok(child); diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 9d81d2853b..bdb93f07d4 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -311,6 +311,11 @@ message HashAggregate { // Offset in the child's output where aggregate buffer attributes start. // Used by PartialMerge to locate state fields in the input. int32 initial_input_buffer_offset = 7; + // When true, the native planner wraps Utf8/Binary group-by expressions in a Cast to + // LargeUtf8/LargeBinary (and casts the aggregate's output back to Utf8/Binary), which + // promotes the group-key byte buffer from i32 to i64 offsets and removes the 2 GiB cap. + // Driven by spark.comet.exec.useLargeDataTypes. + bool use_large_data_types = 9; } message Limit { @@ -341,6 +346,12 @@ message ShuffleWriter { // to absorb DataFusion-vs-Spark type drift. Empty when the child is a placeholder Scan; // that path already has a cast point upstream. repeated SparkStructField expected_output_schema = 9; + // When true, the native planner promotes any Utf8/Binary field in + // `expected_output_schema` to LargeUtf8/LargeBinary before invoking SchemaAlignExec, + // so an aggregate that opted into `spark.comet.exec.useLargeDataTypes` and emits + // LargeUtf8 group keys can shuffle without a wasted Large->small->Large round-trip. + // The JVM shuffle read side maps LargeUtf8 back to Spark StringType transparently. + bool use_large_data_types = 10; } message ParquetWriter { diff --git a/native/shuffle/src/schema_align.rs b/native/shuffle/src/schema_align.rs index b5566becef..d1dca0b032 100644 --- a/native/shuffle/src/schema_align.rs +++ b/native/shuffle/src/schema_align.rs @@ -35,9 +35,12 @@ //! for the running list of mismatched //! functions. -use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow::array::{ + Array, ArrayRef, BinaryBuilder, LargeBinaryArray, LargeStringArray, RecordBatch, + RecordBatchOptions, StringBuilder, +}; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::common::DataFusionError; use datafusion::physical_expr::EquivalenceProperties; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; @@ -50,7 +53,7 @@ use datafusion::{ }; use futures::{Stream, StreamExt}; use std::{ - collections::HashSet, + collections::{HashSet, VecDeque}, pin::Pin, sync::{Arc, Mutex, OnceLock}, task::{Context, Poll}, @@ -83,6 +86,14 @@ enum ColumnAction { Passthrough, /// Cast the input column to the target data_type. Cast, + /// Cast a LargeUtf8 input down to Utf8 (i64 -> i32 offsets). Listed separately because + /// the cast kernel rejects any column whose values buffer exceeds `i32::MAX`, so the + /// stream must pre-split input batches into row ranges whose per-batch byte total + /// stays under that cap. + CastLargeStringToString, + /// Cast a LargeBinary input down to Binary (i64 -> i32 offsets). Same shrink-split + /// constraint as `CastLargeStringToString`. + CastLargeBinaryToBinary, } impl SchemaAlignExec { @@ -130,7 +141,15 @@ impl SchemaAlignExec { expected_field.data_type() ); } - ColumnAction::Cast + match (actual_field.data_type(), expected_field.data_type()) { + (DataType::LargeUtf8, DataType::Utf8) => { + ColumnAction::CastLargeStringToString + } + (DataType::LargeBinary, DataType::Binary) => { + ColumnAction::CastLargeBinaryToBinary + } + _ => ColumnAction::Cast, + } }; let target_nullable = actual_field.is_nullable() || expected_field.is_nullable(); let field_changed = !matches!(action, ColumnAction::Passthrough) @@ -219,6 +238,7 @@ impl ExecutionPlan for SchemaAlignExec { child_stream, target_schema: Arc::clone(&self.target_schema), column_actions: Arc::clone(&self.column_actions), + pending: VecDeque::new(), })) } @@ -235,10 +255,40 @@ struct SchemaAlignStream { child_stream: SendableRecordBatchStream, target_schema: SchemaRef, column_actions: Arc>, + /// Sub-batches produced by the last input batch and not yet yielded. Used when a + /// `CastLargeStringToString` / `CastLargeBinaryToBinary` column would overflow the + /// destination `i32` offsets, so the input is split into multiple Utf8/Binary outputs. + pending: VecDeque, } +/// `i32::MAX` bytes — the cap on a Utf8/Binary values buffer (its offsets are `i32`). +const I32_BYTE_CAP: i64 = i32::MAX as i64; + impl SchemaAlignStream { - fn align(&self, batch: RecordBatch) -> Result { + /// Apply the per-column actions to `batch` and push the resulting (possibly multiple) + /// aligned batches into `out`. Splits the input by row ranges when any + /// `CastLargeStringToString` / `CastLargeBinaryToBinary` column would otherwise emit a + /// values buffer larger than `i32::MAX`. + fn align_into( + &self, + batch: RecordBatch, + out: &mut VecDeque, + ) -> Result<(), DataFusionError> { + let ranges = self.compute_row_ranges(&batch)?; + for (start, length) in ranges { + let slice = if start == 0 && length == batch.num_rows() { + batch.clone() + } else { + batch.slice(start, length) + }; + out.push_back(self.align_slice(slice)?); + } + Ok(()) + } + + /// Apply `column_actions` to a single row range that is already known to fit each + /// shrinking-cast column's destination offset width. + fn align_slice(&self, batch: RecordBatch) -> Result { let mut columns: Vec = Vec::with_capacity(batch.num_columns()); for (idx, action) in self.column_actions.iter().enumerate() { let column = batch.column(idx); @@ -249,6 +299,55 @@ impl SchemaAlignStream { self.target_schema.field(idx).data_type(), &CastOptions::default(), )?, + // Build a fresh Utf8/Binary array from the slice rather than calling + // arrow's cast kernel. `cast_byte_container` reads the underlying + // offsets buffer in full and verifies every absolute offset fits the + // destination offset type — slicing the source array does not rebase + // the offsets, so a slice that is logically small can still trip the + // i32::MAX check if its offsets sit far into the values buffer. We + // copy values explicitly so the new offsets start at 0. + ColumnAction::CastLargeStringToString => { + let arr = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "SchemaAlignExec: column[{idx}] expected LargeStringArray, \ + got {:?}", + column.data_type() + )) + })?; + let mut builder = StringBuilder::with_capacity(arr.len(), 0); + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(arr.value(i)); + } + } + Arc::new(builder.finish()) as ArrayRef + } + ColumnAction::CastLargeBinaryToBinary => { + let arr = column + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "SchemaAlignExec: column[{idx}] expected LargeBinaryArray, \ + got {:?}", + column.data_type() + )) + })?; + let mut builder = BinaryBuilder::with_capacity(arr.len(), 0); + for i in 0..arr.len() { + if arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(arr.value(i)); + } + } + Arc::new(builder.finish()) as ArrayRef + } }; columns.push(aligned); } @@ -256,15 +355,134 @@ impl SchemaAlignStream { RecordBatch::try_new_with_options(Arc::clone(&self.target_schema), columns, &options) .map_err(DataFusionError::from) } + + /// Compute `(start_row, length)` ranges that split `batch` so each shrinking-cast + /// column's per-slice values buffer stays under `i32::MAX`. Returns a single full-batch + /// range when no split is needed (the common case). + fn compute_row_ranges( + &self, + batch: &RecordBatch, + ) -> Result, DataFusionError> { + let num_rows = batch.num_rows(); + if num_rows == 0 { + return Ok(vec![(0, 0)]); + } + + // Per-column running-byte cursors used to decide each split point. Empty when no + // column requires shrink-splitting; in that case we always return a single range. + let mut shrinking_cols: Vec<&[i64]> = Vec::new(); + for (idx, action) in self.column_actions.iter().enumerate() { + match action { + ColumnAction::CastLargeStringToString => { + let arr = batch + .column(idx) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "SchemaAlignExec: column[{idx}] expected LargeStringArray for \ + LargeUtf8 -> Utf8 cast, got {:?}", + batch.column(idx).data_type() + )) + })?; + shrinking_cols.push(arr.value_offsets()); + } + ColumnAction::CastLargeBinaryToBinary => { + let arr = batch + .column(idx) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "SchemaAlignExec: column[{idx}] expected LargeBinaryArray for \ + LargeBinary -> Binary cast, got {:?}", + batch.column(idx).data_type() + )) + })?; + shrinking_cols.push(arr.value_offsets()); + } + _ => {} + } + } + + if shrinking_cols.is_empty() { + return Ok(vec![(0, num_rows)]); + } + + let mut ranges: Vec<(usize, usize)> = Vec::new(); + let mut chunk_start = 0usize; + // The base offsets at the start of the current chunk, one per shrinking col, used to + // measure each row's contribution to the chunk so far. + let mut chunk_base: Vec = shrinking_cols + .iter() + .map(|offs| offs[chunk_start]) + .collect(); + + for row in 0..num_rows { + // First pass: reject the whole batch if any single row already exceeds the + // destination cap on ANY shrinking column. Must scan every column even after + // a split fires later in this iteration -- otherwise a fat single value in a + // column past the one that triggered the split would slip through and later + // panic inside StringBuilder/BinaryBuilder when the row lands in a chunk. + for (col_idx, offsets) in shrinking_cols.iter().enumerate() { + let row_bytes = offsets[row + 1] - offsets[row]; + if row_bytes > I32_BYTE_CAP { + return Err(DataFusionError::Execution(format!( + "SchemaAlignExec: cannot cast Large variant down to small offsets — \ + row {row} of column[{col_idx}] is {row_bytes} bytes which exceeds the \ + i32 offset cap ({I32_BYTE_CAP} bytes)" + ))); + } + } + + // Second pass: decide whether to close the current chunk before this row. + // The oversized-row guard above already ensures the new chunk's first row fits. + for col_idx in 0..shrinking_cols.len() { + let projected = shrinking_cols[col_idx][row + 1] - chunk_base[col_idx]; + if projected > I32_BYTE_CAP { + let length = row - chunk_start; + debug_assert!(length > 0, "split would emit an empty chunk"); + ranges.push((chunk_start, length)); + chunk_start = row; + for (i, offs) in shrinking_cols.iter().enumerate() { + chunk_base[i] = offs[chunk_start]; + } + break; + } + } + } + + // Flush the final chunk (always non-empty: starts at chunk_start <= num_rows-1). + ranges.push((chunk_start, num_rows - chunk_start)); + Ok(ranges) + } } impl Stream for SchemaAlignStream { type Item = datafusion::common::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.child_stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(self.align(batch))), - other => other, + loop { + // Drain any sub-batches buffered from a prior split before pulling more input. + if let Some(ready) = self.pending.pop_front() { + return Poll::Ready(Some(Ok(ready))); + } + match self.child_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + let mut buf = std::mem::take(&mut self.pending); + if let Err(e) = self.align_into(batch, &mut buf) { + self.pending = buf; + return Poll::Ready(Some(Err(e))); + } + self.pending = buf; + // Loop back to pop_front and yield the first sub-batch (or pull again on + // an input batch that produced zero outputs, e.g. zero-row inputs). + continue; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } } } } diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 351ccfa777..985d39a6b9 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -29,7 +29,8 @@ use crate::conversion_funcs::numeric::{ use crate::conversion_funcs::string::{ cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int, cast_string_to_timestamp, cast_string_to_timestamp_ntz, - is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean, + is_df_cast_from_large_string_spark_compatible, is_df_cast_from_string_spark_compatible, + spark_cast_utf8_to_boolean, }; use crate::conversion_funcs::temporal::{ cast_date_to_timestamp, is_df_cast_from_date_spark_compatible, @@ -497,12 +498,19 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b is_df_cast_from_decimal_spark_compatible(to_type) } DataType::Utf8 => is_df_cast_from_string_spark_compatible(to_type), + DataType::LargeUtf8 => is_df_cast_from_large_string_spark_compatible(to_type), DataType::Date32 => is_df_cast_from_date_spark_compatible(to_type), DataType::Timestamp(_, _) => is_df_cast_from_timestamp_spark_compatible(to_type), DataType::Binary => { // note that this is not completely Spark compatible because // DataFusion only supports binary data containing valid UTF-8 strings - matches!(to_type, DataType::Utf8) + matches!(to_type, DataType::Utf8 | DataType::LargeBinary) + } + DataType::LargeBinary => { + // Symmetric with Binary: allow the offset-width narrowing back to + // Binary. Utf8 conversion is not offered here because arrow does not + // provide a direct LargeBinary -> Utf8 cast that validates the bytes. + matches!(to_type, DataType::Binary) } _ => false, } diff --git a/native/spark-expr/src/conversion_funcs/string.rs b/native/spark-expr/src/conversion_funcs/string.rs index adfd6e2390..74f7850e46 100644 --- a/native/spark-expr/src/conversion_funcs/string.rs +++ b/native/spark-expr/src/conversion_funcs/string.rs @@ -171,7 +171,20 @@ impl TimeStampInfo { } pub(crate) fn is_df_cast_from_string_spark_compatible(to_type: &DataType) -> bool { - matches!(to_type, DataType::Binary) + // Utf8 -> Binary is a zero-copy reinterpret; Utf8 -> LargeUtf8 is a pure + // offset-width widening handled by arrow's cast_byte_container. Both are + // Spark-equivalent (Spark's StringType and BinaryType are indifferent to + // Arrow's offset width). + matches!(to_type, DataType::Binary | DataType::LargeUtf8) +} + +pub(crate) fn is_df_cast_from_large_string_spark_compatible(to_type: &DataType) -> bool { + // Symmetric mirror of `is_df_cast_from_string_spark_compatible` for the + // LargeUtf8 source type. LargeUtf8 -> Utf8 is an i64 -> i32 offset narrowing + // that arrow's cast_byte_container performs with an overflow check per row; + // arrays whose values buffer exceeds `i32::MAX` bytes are pre-split by + // CometSchemaAlignExec before the cast reaches this code path. + matches!(to_type, DataType::LargeBinary | DataType::Utf8) } pub(crate) fn cast_string_to_float( diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 8e47151358..22058fd6c3 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -368,6 +368,26 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_AGG_USE_LARGE_DATATYPES: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.useLargeDataTypes") + .category(CATEGORY_EXEC) + .doc( + "When true, Comet wraps Utf8/Binary group-by expressions inside a Cast to " + + "LargeUtf8/LargeBinary immediately before each native HashAggregate, and wraps the " + + "aggregate's output in a Projection that casts those columns back to Utf8/Binary. " + + "This promotes DataFusion's per-task group-key byte buffer from `i32` offsets " + + "(2 GiB hard cap) to `i64` offsets, removing the `offset overflow, buffer size > " + + "2147483647` failure that can hit CUBE / GROUPING SETS / COUNT(DISTINCT) workloads " + + "where a single partition accumulates more than 2 GiB of distinct string keys. " + + "Scan, shuffle, and the JVM FFI boundary remain Utf8/Binary, so neither downstream " + + "Spark nor Comet's shuffle (which does not support LargeUtf8) is affected. The cast " + + "is an offset-width promotion and only rebuilds the offset buffer (the value bytes " + + "are shared), so per-batch overhead is O(rows) not O(bytes). Defaults to false " + + "because the cap is only reachable for very large per-partition group cardinalities; " + + "enable it when you see the offset-overflow error.") + .booleanConf + .createWithDefault(false) + val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.exec.scalaUDF.codegen.enabled") .category(CATEGORY_EXEC) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index a486e2e861..1490ccc716 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -348,6 +348,10 @@ class CometNativeShuffleWriter[K, V]( .map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)) .toArray schema2Proto(expectedFields).foreach(shuffleWriterBuilder.addExpectedOutputSchema) + // Mirror of `spark.comet.exec.useLargeDataTypes` so the native planner can decide + // whether SchemaAlignExec should downcast LargeUtf8/LargeBinary back to Utf8/Binary + // (default behaviour) or preserve the Large* variants end-to-end through shuffle. + shuffleWriterBuilder.setUseLargeDataTypes(CometConf.COMET_AGG_USE_LARGE_DATATYPES.get()) OperatorOuterClass.Operator .newBuilder() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index af9e1df8a3..86072ffe56 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1613,6 +1613,8 @@ trait CometBaseAggregate { if (aggregateExpressions.isEmpty) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) + hashAggBuilder.setUseLargeDataTypes( + CometConf.COMET_AGG_USE_LARGE_DATATYPES.get(aggregate.conf)) buildAggOp( builder, hashAggBuilder, @@ -1690,6 +1692,8 @@ trait CometBaseAggregate { hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) hashAggBuilder.setModeValue(mode.getNumber) + hashAggBuilder.setUseLargeDataTypes( + CometConf.COMET_AGG_USE_LARGE_DATATYPES.get(aggregate.conf)) // Send per-expression modes and buffer offset for PartialMerge handling if (hasPartialMerge) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 15e1e2c410..c6264c94c2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -97,7 +97,9 @@ object Utils extends CometTypeShim with Logging { case float: ArrowType.FloatingPoint if float.getPrecision == FloatingPointPrecision.DOUBLE => DoubleType case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.LargeUtf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType + case ArrowType.LargeBinary.INSTANCE => BinaryType case _: ArrowType.FixedSizeBinary => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType @@ -418,7 +420,8 @@ object Utils extends CometTypeShim with Logging { valueVector match { case v @ (_: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | _: Float4Vector | _: Float8Vector | _: VarCharVector | - _: DecimalVector | _: DateDayVector | _: TimeStampMicroTZVector | _: VarBinaryVector | + _: LargeVarCharVector | _: DecimalVector | _: DateDayVector | + _: TimeStampMicroTZVector | _: VarBinaryVector | _: LargeVarBinaryVector | _: FixedSizeBinaryVector | _: TimeStampMicroVector | _: StructVector | _: ListVector | _: MapVector | _: NullVector | _: TimeNanoVector) => v.asInstanceOf[FieldVector] diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae14c68207..63cb936bd3 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.functions.{avg, col, count_distinct, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometNativeException} import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT import org.apache.comet.CometSparkSessionExtensions.isSpark41Plus import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions} @@ -45,8 +45,19 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Several aggregate tests exercise overflow behavior expected to wrap around silently; // ANSI-mode variants opt in to ANSI explicitly via withSQLConf. + // + // For the CUBE(9) offset-overflow repro at the bottom of the suite we MUST disable + // off-heap (CometTestBase defaults to off-heap + 2 GiB cap, which spills long before + // i32::MAX) and force the on-heap memory pool to `unbounded`. Both knobs are read + // from SparkConf at SparkContext init (CometExecIterator.scala:295/310), NOT from + // per-query SQLConf, so they must live here. override protected def sparkConf: SparkConf = - super.sparkConf.set(SQLConf.ANSI_ENABLED.key, "false") + super.sparkConf + .set(SQLConf.ANSI_ENABLED.key, "false") + .set("spark.memory.offHeap.enabled", "false") + .set(CometConf.COMET_ONHEAP_MEMORY_POOL_TYPE.key, "unbounded") + .set("spark.ui.enabled", "true") + .set("spark.ui.port", "4040") test("min/max floating point with negative zero") { val r = new Random(42) @@ -2150,4 +2161,112 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("CUBE(9) + COUNT(DISTINCT) wide Utf8 keys: useLargeDataTypes preserves correctness") { + import org.apache.spark.sql.functions.{count_distinct, grouping} + + def writeData(path: String, numRows: Long, wideBytes: Int): Unit = { + val userPadLen = wideBytes - "user-".length + val measurePadLen = wideBytes - "meas-".length + withSQLConf( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + "parquet.enable.dictionary" -> "false", + CometConf.COMET_ENABLED.key -> "false") { + spark + .range(0, numRows, 1, 1) + .selectExpr( + s"concat('user-', lpad(cast(id as string), $userPadLen, '0')) as userId", + s"concat('meas-', lpad(cast(id as string), $measurePadLen, '0')) as distinctMeasure", + "cast(id % 4 as string) as dim1", + "cast(id % 3 as string) as dim2", + "cast(id % 6 as string) as dim3", + "cast(id % 5 as string) as dim4", + "cast(id % 7 as string) as dim5", + "cast(id % 2 as string) as dim6", + "cast(id % 8 as string) as dim7", + "cast(id % 9 as string) as dim8") + .write + .mode("overwrite") + .parquet(path) + } + } + + def buildAggDf(path: String): DataFrame = { + val events = spark.read.parquet(path).coalesce(1) + events.createOrReplaceTempView("events") + val dims = Seq("dim1", "dim2", "dim3", "dim4", "dim5", "dim6", "dim7", "dim8") + val cubeCols = dims :+ "userId" + val groupingFlags = cubeCols.map(c => grouping(col(c)).as(s"__g_$c")) + spark + .table("events") + .cube(cubeCols.map(col): _*) + .agg(count_distinct(col("distinctMeasure")).as("uniq_measure"), groupingFlags: _*) + .filter("__g_userId = 0") + } + + val baseConf: Seq[(String, String)] = Seq( + SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "false", + // Spark's whole-stage codegen generates a single doConsume method per Expand + // with two parameters per group column; with 10 group cols this exceeds + // Janino's per-method limits and the Comet-disabled baseline that + // checkSparkAnswerAndOperator runs to produce the reference answer fails to + // compile. Disable codegen so the baseline falls back to the interpreted path. + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_ONHEAP_MEMORY_POOL_TYPE.key -> "unbounded") + + withTempDir { dir => + // ---- Branch 1: useLargeDataTypes=false ---- + // __g_userId=0 filter pushes the per-task ByteGroupValueBuilder byte + // buffer over i32::MAX (~2.95 GiB cumulative) and the native partial + // aggregate trips bytes.rs:202-206 + // "offset overflow, buffer size > 2147483647" + val path = new Path(dir.toURI.toString, "events_heavy").toString + writeData(path, numRows = 30000L, wideBytes = 384) + withSQLConf((baseConf :+ (CometConf.COMET_AGG_USE_LARGE_DATATYPES.key -> "false")): _*) { + withTempView("events") { + val df = buildAggDf(path) + val plan = df.queryExecution.executedPlan + val cometAggs = collectWithSubqueries(plan) { case a: CometHashAggregateExec => a } + assert( + cometAggs.nonEmpty, + s"With useLargeDataTypes=false, expected at least one " + + s"CometHashAggregateExec in plan, got:\n$plan") + val e = intercept[Throwable] { + df.collect() + } + val chain = + Iterator.iterate(e: Throwable)(_.getCause).takeWhile(_ != null).toList ++ + Iterator + .iterate(e: Throwable)(_.getCause) + .takeWhile(_ != null) + .flatMap(c => Option(c.getSuppressed).toList.flatten) + assert( + chain.exists(t => Option(t.getMessage).exists(_.contains("offset overflow"))), + s"With useLargeDataTypes=false, expected an 'offset overflow' error, got:\n" + + chain.map(t => s" ${t.getClass.getName}: ${t.getMessage}").mkString("\n")) + } + } + + withSQLConf((baseConf :+ (CometConf.COMET_AGG_USE_LARGE_DATATYPES.key -> "true")): _*) { + withTempView("events") { + val df = buildAggDf(path) + // The full result set is ~15M rows × ~400 bytes ≈ 6 GB (CUBE(9) fan-out + // over 30 k rows), which drives `collect()` -> driver OOM at the default + // 4 GB heap. Compare via a fixed-size summary that still round-trips + // every group through the LargeUtf8 aggregate: row count, sum of the + // scalar measure, and a checksum over the 9 grouping flags. + val summary = df + .agg( + org.apache.spark.sql.functions.count("*").as("n"), + org.apache.spark.sql.functions.sum("uniq_measure").as("sum_uniq")) + checkSparkAnswerAndOperator(summary) + } + } + } + } }