Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions native/core/src/execution/columnar_to_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1028,6 +1028,17 @@ impl ColumnarToRowContext {
}

match (actual_type, schema_type) {
// Spark's StringType / BinaryType map to Utf8 / Binary at the JVM-declared
// schema layer, but the upstream native plan may legitimately emit the Large*
// variants (e.g. CometHashAggregate group-key promotion under
// `spark.comet.exec.useLargeDataTypes`). The downstream typed dispatch handles
// both i32 and i64 offsets natively (`TypedArray::String`/`LargeString` and
// `Binary`/`LargeBinary`), so pass the array through unchanged. The cast
// kernel would refuse a Large->small downcast whose absolute offsets exceed
// i32::MAX even when the logical slice would fit, so attempting the cast here
// is both unnecessary and incorrect.
(DataType::LargeUtf8, DataType::Utf8)
| (DataType::LargeBinary, DataType::Binary) => Ok(Arc::clone(array)),
(DataType::Dictionary(_, _), schema)
if !matches!(schema, DataType::Dictionary(_, _)) =>
{
Expand Down
63 changes: 45 additions & 18 deletions native/core/src/execution/operators/shuffle_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,36 +177,61 @@ impl ShuffleScanExec {

let num_rows = batch.num_rows();

// Extract column arrays, unpacking any dictionary-encoded columns.
// Native shuffle may dictionary-encode string/binary columns for efficiency,
// but downstream DataFusion operators expect the value types declared in the
// schema (e.g. Utf8, not Dictionary<Int32, Utf8>).
let columns: Vec<ArrayRef> = batch
.columns()
.iter()
.map(|col| unpack_dictionary(col))
.collect();

debug_assert_eq!(
columns.len(),
batch.num_columns(),
data_types.len(),
"Shuffle block column count mismatch: got {} but expected {}",
columns.len(),
batch.num_columns(),
data_types.len()
);

// Coerce each decoded column to the catalyst-declared type:
// * unpack any dictionary-encoded columns to their value type (native shuffle
// may dictionary-encode string/binary columns for efficiency);
// * downcast LargeUtf8/LargeBinary back to Utf8/Binary when an upstream
// aggregate opted into `spark.comet.exec.useLargeDataTypes` and wrote
// Large* into the shuffle block while catalyst still declares the small
// variant. The mirror of this coercion on the write side lives in
// `SchemaAlignExec` (native/shuffle/src/schema_align.rs).
let columns: Vec<ArrayRef> = batch
.columns()
.iter()
.zip(data_types.iter())
.map(|(col, expected)| coerce_to_declared(col, expected))
.collect::<Result<Vec<_>, CometError>>()?;

Ok(InputBatch::new(columns, Some(num_rows)))
})
}
}

/// If `array` is dictionary-encoded, cast it to the value type. Otherwise return as-is.
fn unpack_dictionary(array: &ArrayRef) -> ArrayRef {
if let DataType::Dictionary(_, value_type) = array.data_type() {
arrow::compute::cast(array, value_type.as_ref()).expect("failed to unpack dictionary array")
/// Bring a decoded shuffle column in line with the declared `expected` type.
///
/// The work happens in two stages that always compose:
///
/// 1. A dictionary-encoded input is first cast to its dictionary value type.
/// 2. If the (possibly unpacked) array's type still differs from `expected`
///    (e.g. `LargeUtf8`/`LargeBinary` where `Utf8`/`Binary` was declared), it is
///    cast to `expected` with the arrow cast kernel.
///
/// When neither stage applies, the input `Arc` is cloned and returned untouched.
///
/// # Errors
///
/// Any failure reported by the arrow cast kernel in either stage is wrapped in
/// `ExecutionError::DataFusionError` and converted into a `CometError`.
fn coerce_to_declared(array: &ArrayRef, expected: &DataType) -> Result<ArrayRef, CometError> {
    // Stage 1: strip dictionary encoding. Deliberately no early return here, so a
    // `Dictionary<_, LargeUtf8>` column with a declared `Utf8` type goes through both stages.
    let decoded: ArrayRef = match array.data_type() {
        DataType::Dictionary(_, value_type) => {
            arrow::compute::cast(array, value_type.as_ref()).map_err(|e| {
                CometError::from(ExecutionError::DataFusionError(format!(
                    "failed to unpack dictionary array: {e}"
                )))
            })?
        }
        _ => Arc::clone(array),
    };

    // Stage 2: reconcile whatever type drift is left against the declared type.
    if decoded.data_type() != expected {
        return arrow::compute::cast(&decoded, expected).map_err(|e| {
            CometError::from(ExecutionError::DataFusionError(format!(
                "failed to cast shuffle-scan column from {:?} to {expected:?}: {e}",
                decoded.data_type()
            )))
        });
    }

    Ok(decoded)
}

fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef {
Expand Down Expand Up @@ -465,11 +490,13 @@ mod tests {
)
.unwrap();

// Feed the decoded batch through unpack_dictionary (simulating get_next)
// Feed the decoded batch through coerce_to_declared (simulating get_next)
let expected_types = [DataType::Int32, DataType::Utf8];
let columns: Vec<ArrayRef> = decoded
.columns()
.iter()
.map(|col| super::unpack_dictionary(col))
.zip(expected_types.iter())
.map(|(col, expected)| super::coerce_to_declared(col, expected).unwrap())
.collect();
let input = InputBatch::new(columns, Some(decoded.num_rows()));
scan.set_input_batch(input);
Expand Down
98 changes: 91 additions & 7 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ use url::Url;

// For clippy error on type_complexity.
type PhyAggResult = Result<Vec<AggregateFunctionExpr>, ExecutionError>;
type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;
pub type PlanCreationResult =
Result<(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>), ExecutionError>;
Expand Down Expand Up @@ -174,6 +173,32 @@ fn strip_timestamp_tz(
}
}

/// Widen a `Utf8`/`Binary` group-by expression to `LargeUtf8`/`LargeBinary`.
///
/// Used when `spark.comet.exec.useLargeDataTypes` is enabled. The expression is
/// wrapped in a `CastExpr` that only changes the offset width (i32 → i64); the
/// intent is for DataFusion's HashAggregate to pick the i64-offset byte group
/// builder so the group-key byte buffer is not capped at `i32::MAX` (2 GiB).
///
/// # Returns
///
/// A pair `(expr, original_type)`:
/// * `(cast_expr, Some(Utf8 | Binary))` when the input evaluated to `Utf8` or
///   `Binary`; the second element is the pre-promotion type, which the caller
///   can use to restore the original type above the aggregate.
/// * `(expr, None)` with the expression returned unchanged for every other type.
///
/// # Errors
///
/// Propagates the error from `expr.data_type(schema)` if the expression's type
/// cannot be resolved against `schema`.
fn promote_byte_group_key(
    expr: Arc<dyn PhysicalExpr>,
    schema: &Schema,
) -> Result<(Arc<dyn PhysicalExpr>, Option<DataType>), ExecutionError> {
    let original = expr.data_type(schema)?;

    // Pick the wide-offset counterpart, or bail out early when no promotion applies.
    let widened = match original {
        DataType::Utf8 => DataType::LargeUtf8,
        DataType::Binary => DataType::LargeBinary,
        _ => return Ok((expr, None)),
    };

    let promoted: Arc<dyn PhysicalExpr> = Arc::new(CastExpr::new(expr, widened, None));
    Ok((promoted, Some(original)))
}

#[derive(Default)]
pub struct BinaryExprOptions {
pub is_integral_div: bool,
Expand Down Expand Up @@ -1092,17 +1117,39 @@ impl PhysicalPlanner {
let (scans, shuffle_scans, child) =
self.create_plan(&children[0], inputs, partition_count)?;

let group_exprs: PhyExprResult = agg
// When `spark.comet.exec.useLargeDataTypes` is on, wrap Utf8/Binary
// group keys in a Cast to LargeUtf8/LargeBinary so DataFusion dispatches
// to ByteGroupValueBuilder::<i64> and the per-task group-key byte buffer
// is no longer capped at i32::MAX (2 GiB). The promotion is reverted at
// the aggregate's output by the `SchemaAlignExec` wrapper built below, so
// LargeUtf8/LargeBinary never leaves this operator -- this keeps the FFI,
// JVM shuffle, and Spark consumer paths untouched.
let use_large = agg.use_large_data_types;
let child_schema_ref = child.schema();
let child_schema = child_schema_ref.as_ref();
// Per group column: `Some(original_dt)` when the column was promoted to a
// Large* variant, `None` when it was passed through as-is. Populated in
// lockstep with `group_exprs` and consumed below to build the target
// schema that casts promoted columns back to their original type.
let mut group_reverts: Vec<Option<DataType>> =
Vec::with_capacity(agg.grouping_exprs.len());
let group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = agg
.grouping_exprs
.iter()
.enumerate()
.map(|(idx, expr)| {
self.create_expr(expr, child.schema())
.map(|r| (r, format!("col_{idx}")))
let raw = self.create_expr(expr, Arc::clone(&child_schema_ref))?;
let (wrapped, revert) = if use_large {
promote_byte_group_key(raw, child_schema)?
} else {
(raw, None)
};
group_reverts.push(revert);
Ok((wrapped, format!("col_{idx}")))
})
.collect();
let group_by = PhysicalGroupBy::new_single(group_exprs?);
let schema = child.schema();
.collect::<Result<Vec<_>, ExecutionError>>()?;
let group_by = PhysicalGroupBy::new_single(group_exprs);
let schema = Arc::clone(&child_schema_ref);

let proto_mode = ProtoAggregateMode::try_from(agg.mode).map_err(|_| {
ExecutionError::GeneralError(format!(
Expand Down Expand Up @@ -1218,6 +1265,36 @@ impl PhysicalPlanner {
)?,
);

// Cast promoted group columns back to their original Utf8/Binary type
// so LargeUtf8/LargeBinary never crosses the FFI boundary, JVM columnar
// shuffle, or the Spark consumer path (all of which reject Large*).
// Uses `SchemaAlignExec` (not a plain `ProjectionExec` + `CastExpr`)
// because arrow's cast kernel rejects any single Large* array whose
// value bytes exceed `i32::MAX` -- which is precisely the regime this
// flag is used in. `SchemaAlignExec` splits each batch by row ranges
// so every emitted small-offset chunk fits under 2 GiB.
let aggregate = if use_large && group_reverts.iter().any(Option::is_some) {
let agg_schema = aggregate.schema();
let target_fields: Vec<Field> = agg_schema
.fields()
.iter()
.enumerate()
.map(|(idx, f)| {
let target_dt = group_reverts
.get(idx)
.and_then(|r| r.clone())
.unwrap_or_else(|| f.data_type().clone());
Field::new(f.name(), target_dt, f.is_nullable())
.with_metadata(f.metadata().clone())
})
.collect();
let target_schema: SchemaRef = Arc::new(Schema::new(target_fields));
SchemaAlignExec::try_new_or_passthrough(aggregate, &target_schema)
.map_err(|e| ExecutionError::DataFusionError(e.to_string()))?
} else {
aggregate
};

Ok((
scans,
shuffle_scans,
Expand Down Expand Up @@ -1599,6 +1676,7 @@ impl PhysicalPlanner {
let writer_input = align_shuffle_writer_input(
Arc::clone(&child.native_plan),
&writer.expected_output_schema,
writer.use_large_data_types,
)?;

let partitioning = self.create_partitioning(
Expand Down Expand Up @@ -3463,9 +3541,15 @@ fn convert_spark_types_to_arrow_schema(

/// Wrap `child` in a `SchemaAlignExec` when its output drifts from what Spark catalyst
/// declared. See <https://github.com/apache/datafusion-comet/issues/4515>.
///
/// `use_large_data_types` is accepted for backward compatibility with older proto payloads
/// but is now a no-op: the HashAggregate reverts LargeUtf8/LargeBinary group columns to
/// their original Utf8/Binary type in its own output projection, so the shuffle writer's
/// child never emits Large* variants and the writer's expected schema stays Utf8/Binary.
fn align_shuffle_writer_input(
child: Arc<dyn ExecutionPlan>,
expected_proto: &[spark_operator::SparkStructField],
_use_large_data_types: bool,
) -> Result<Arc<dyn ExecutionPlan>, ExecutionError> {
if expected_proto.is_empty() {
return Ok(child);
Expand Down
11 changes: 11 additions & 0 deletions native/proto/src/proto/operator.proto
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,11 @@ message HashAggregate {
// Offset in the child's output where aggregate buffer attributes start.
// Used by PartialMerge to locate state fields in the input.
int32 initial_input_buffer_offset = 7;
// When true, the native planner wraps Utf8/Binary group-by expressions in a Cast to
// LargeUtf8/LargeBinary (and casts the aggregate's output back to Utf8/Binary), which
// promotes the group-key byte buffer from i32 to i64 offsets and removes the 2 GiB cap.
// Driven by spark.comet.exec.useLargeDataTypes.
bool use_large_data_types = 9;
}

message Limit {
Expand Down Expand Up @@ -341,6 +346,12 @@ message ShuffleWriter {
// to absorb DataFusion-vs-Spark type drift. Empty when the child is a placeholder Scan;
// that path already has a cast point upstream.
repeated SparkStructField expected_output_schema = 9;
// Driven by `spark.comet.exec.useLargeDataTypes`. The native planner currently accepts
// this flag but does not act on it: HashAggregate casts any promoted
// LargeUtf8/LargeBinary group keys back to Utf8/Binary at its own output, so the
// shuffle writer's child never emits Large* variants and the fields in
// `expected_output_schema` stay Utf8/Binary.
bool use_large_data_types = 10;
}

message ParquetWriter {
Expand Down
Loading
Loading