From 9a4a79b24c70de1d546064fa1e3da6897a66f442 Mon Sep 17 00:00:00 2001 From: Mitchell Date: Fri, 3 Jul 2026 20:31:53 -0500 Subject: [PATCH] fix: handle null sub-arrays in flatten --- native/spark-expr/src/array_funcs/flatten.rs | 403 ++++++++++++++++++ native/spark-expr/src/array_funcs/mod.rs | 2 + native/spark-expr/src/comet_scalar_funcs.rs | 6 +- .../sql-tests/expressions/array/flatten.sql | 13 +- 4 files changed, 419 insertions(+), 5 deletions(-) create mode 100644 native/spark-expr/src/array_funcs/flatten.rs diff --git a/native/spark-expr/src/array_funcs/flatten.rs b/native/spark-expr/src/array_funcs/flatten.rs new file mode 100644 index 0000000000..77a5438fb9 --- /dev/null +++ b/native/spark-expr/src/array_funcs/flatten.rs @@ -0,0 +1,403 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Spark-compatible flatten(array>). +// +// DataFusion's flatten preserves the outer array null bitmap only. Spark returns NULL for a row +// when any sub-array inside that row is NULL, so Comet needs a Spark-specific null bitmap. + +use arrow::array::{Array, ArrayRef, GenericListArray, NullBufferBuilder, OffsetSizeTrait}; +use arrow::buffer::{NullBuffer, OffsetBuffer}; +use arrow::datatypes::{ + DataType, + DataType::{FixedSizeList, LargeList, List, Null}, + Field, FieldRef, +}; +use datafusion::common::cast::{as_large_list_array, as_list_array}; +use datafusion::common::{exec_err, utils::take_function_args, Result}; +use datafusion::logical_expr::{ + ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::sync::Arc; + +use super::arrays_zip::make_scalar_function; + +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct SparkFlatten { + signature: Signature, +} + +impl Default for SparkFlatten { + fn default() -> Self { + Self::new() + } +} + +impl SparkFlatten { + pub fn new() -> Self { + Self { + signature: Signature::array(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkFlatten { + fn name(&self) -> &str { + "flatten" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + spark_flatten_return_type(&arg_types[0]) + } + + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let [arg_field] = take_function_args(self.name(), args.arg_fields)?; + let data_type = spark_flatten_return_type(arg_field.data_type())?; + let nullable = match arg_field.data_type() { + List(field) | LargeList(field) => arg_field.is_nullable() || field.is_nullable(), + Null => true, + _ => { + return exec_err!( + "Not reachable, data_type should be List, LargeList or FixedSizeList" + ) + } + }; + + Ok(Arc::new(Field::new(self.name(), data_type, nullable))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(spark_flatten_inner)(&args.args) + } +} + +fn spark_flatten_return_type(arg_type: &DataType) -> Result { + let data_type = match arg_type { + List(field) => match field.data_type() { + List(field) | FixedSizeList(field, _) => List(Arc::clone(field)), + LargeList(field) => LargeList(Arc::clone(field)), + _ => arg_type.clone(), + }, + LargeList(field) => match field.data_type() { + List(field) | LargeList(field) | FixedSizeList(field, _) => { + LargeList(Arc::clone(field)) + } + _ => arg_type.clone(), + }, + Null => Null, + _ => exec_err!("Not reachable, data_type should be List, LargeList or FixedSizeList")?, + }; + + Ok(data_type) +} + +fn spark_flatten_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("flatten", args)?; + + match array.data_type() { + List(_) => { + let outer = as_list_array(array)?; + let (_field, offsets, values, _outer_nulls) = outer.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let inner = as_list_array(&values)?; + let nulls = spark_flatten_nulls(outer, inner); + let (inner_field, inner_offsets, inner_values, _) = inner.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, &offsets); + let flattened_array = + GenericListArray::::new(inner_field, offsets, inner_values, nulls); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + let inner = as_large_list_array(&values)?; + let nulls = spark_flatten_nulls(outer, inner); + let (inner_field, inner_offsets, inner_values, _) = inner.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, &offsets); + let flattened_array = + GenericListArray::::new(inner_field, offsets, inner_values, nulls); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + _ => Ok(Arc::clone(array) as ArrayRef), + } + } + LargeList(_) => { + let outer = as_large_list_array(array)?; + let (_field, offsets, values, _outer_nulls) = outer.clone().into_parts(); + let values = cast_fsl_to_list(values)?; + + match values.data_type() { + List(_) => { + let inner = as_list_array(&values)?; + let nulls = spark_flatten_nulls(outer, inner); + let (inner_field, inner_offsets, inner_values, _) = inner.clone().into_parts(); + let offsets = get_large_offsets_for_flatten(inner_offsets, &offsets); + let flattened_array = + GenericListArray::::new(inner_field, offsets, inner_values, nulls); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + LargeList(_) => { + let inner = as_large_list_array(&values)?; + let nulls = spark_flatten_nulls(outer, inner); + let (inner_field, inner_offsets, inner_values, _) = inner.clone().into_parts(); + let offsets = get_offsets_for_flatten::(inner_offsets, &offsets); + let flattened_array = + GenericListArray::::new(inner_field, offsets, inner_values, nulls); + + Ok(Arc::new(flattened_array) as ArrayRef) + } + _ => Ok(Arc::clone(array) as ArrayRef), + } + } + Null => Ok(Arc::clone(array)), + _ => { + exec_err!("flatten does not support type '{}'", array.data_type()) + } + } +} + +fn spark_flatten_nulls( + outer: &GenericListArray

, + inner: &GenericListArray, +) -> Option { + let mut nulls = NullBufferBuilder::new(outer.len()); + let inner_nulls = inner.nulls(); + + for (row, offset_window) in outer.offsets().windows(2).enumerate() { + if outer.is_null(row) { + nulls.append_null(); + continue; + } + + let start = offset_window[0].to_usize().unwrap(); + let end = offset_window[1].to_usize().unwrap(); + let has_null_subarray = + inner_nulls.is_some_and(|n| (start..end).any(|inner_row| n.is_null(inner_row))); + + if has_null_subarray { + nulls.append_null(); + } else { + nulls.append_non_null(); + } + } + + nulls.finish() +} + +fn get_offsets_for_flatten( + inner_offsets: OffsetBuffer, + outer_offsets: &OffsetBuffer

, +) -> OffsetBuffer { + let buffer = inner_offsets.into_inner(); + let offsets: Vec = outer_offsets + .iter() + .map(|i| buffer[i.to_usize().unwrap()]) + .collect(); + OffsetBuffer::new(offsets.into()) +} + +fn get_large_offsets_for_flatten( + inner_offsets: OffsetBuffer, + outer_offsets: &OffsetBuffer

, +) -> OffsetBuffer { + let buffer = inner_offsets.into_inner(); + let offsets: Vec = outer_offsets + .iter() + .map(|i| buffer[i.to_usize().unwrap()].to_i64().unwrap()) + .collect(); + OffsetBuffer::new(offsets.into()) +} + +fn cast_fsl_to_list(array: ArrayRef) -> Result { + match array.data_type() { + FixedSizeList(field, _) => Ok(arrow::compute::cast(&array, &List(Arc::clone(field)))?), + _ => Ok(array), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Int32Array, Int32Builder, ListArray, ListBuilder}; + use datafusion::common::ScalarValue; + + type IntElement = Option; + type InnerArray = Option>; + type OuterArray = Option>; + + #[test] + fn test_flatten_null_subarray_returns_null_row() { + let input = nested_int_list_array(vec![ + Some(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + ]), + Some(vec![Some(vec![Some(1)]), None]), + Some(vec![None, None]), + ]); + + let result = spark_flatten_inner(&[Arc::new(input)]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 3); + assert_list_value(result, 0, &[Some(1), Some(2), Some(3), Some(4), Some(5)]); + assert!(result.is_null(1)); + assert!(result.is_null(2)); + } + + #[test] + fn test_flatten_preserves_null_elements() { + let input = nested_int_list_array(vec![Some(vec![ + Some(vec![Some(1), None]), + Some(vec![None, Some(2)]), + ])]); + + let result = spark_flatten_inner(&[Arc::new(input)]).unwrap(); + let result = result.as_any().downcast_ref::().unwrap(); + + assert_eq!(result.len(), 1); + assert_list_value(result, 0, &[Some(1), None, None, Some(2)]); + } + + #[test] + fn test_flatten_allows_scalar_input() { + let input = nested_int_list_array(vec![Some(vec![Some(vec![Some(1), Some(2)])])]); + let scalar = ColumnarValue::Scalar(ScalarValue::try_from_array(&input, 0).unwrap()); + + let result = SparkFlatten::new() + .invoke_with_args(ScalarFunctionArgs { + args: vec![scalar], + arg_fields: vec![], + number_rows: 1, + return_field: Arc::new(arrow::datatypes::Field::new("flatten", Null, true)), + config_options: Arc::new(datafusion::config::ConfigOptions::default()), + }) + .unwrap(); + + assert!(matches!(result, ColumnarValue::Scalar(_))); + } + + fn nested_int_list_array(rows: Vec) -> ListArray { + let inner_builder = ListBuilder::new(Int32Builder::new()); + let mut outer_builder = ListBuilder::new(inner_builder); + + for row in rows { + match row { + Some(subarrays) => { + let inner_builder = outer_builder.values(); + for subarray in subarrays { + match subarray { + Some(values) => { + for value in values { + match value { + Some(value) => inner_builder.values().append_value(value), + None => inner_builder.values().append_null(), + } + } + inner_builder.append(true); + } + None => inner_builder.append(false), + } + } + outer_builder.append(true); + } + None => outer_builder.append(false), + } + } + + outer_builder.finish() + } + + fn assert_list_value(list_array: &ListArray, row: usize, expected: &[Option]) { + let values = list_array.value(row); + let values = values.as_any().downcast_ref::().unwrap(); + let actual = values.iter().collect::>(); + assert_eq!(actual, expected); + } + + #[test] + fn test_return_type() { + let input_type = List(Arc::new(arrow::datatypes::Field::new_list_field( + List(Arc::new(arrow::datatypes::Field::new_list_field( + DataType::Int32, + true, + ))), + true, + ))); + assert_eq!( + SparkFlatten::new().return_type(&[input_type]).unwrap(), + List(Arc::new(arrow::datatypes::Field::new_list_field( + DataType::Int32, + true, + ))) + ); + } + + #[test] + fn test_return_field_nullability_matches_spark() { + let array_of_arrays = |outer_nullable, contains_null_subarray| { + let element_field = Arc::new(arrow::datatypes::Field::new_list_field( + DataType::Int32, + true, + )); + Arc::new(arrow::datatypes::Field::new( + "arg", + List(Arc::new(arrow::datatypes::Field::new_list_field( + List(element_field), + contains_null_subarray, + ))), + outer_nullable, + )) + }; + + let scalar_args: [Option<&ScalarValue>; 1] = [None]; + let arg_fields = [array_of_arrays(false, false)]; + let field = SparkFlatten::new() + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + assert!(!field.is_nullable()); + + let arg_fields = [array_of_arrays(false, true)]; + let field = SparkFlatten::new() + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + assert!(field.is_nullable()); + + let arg_fields = [array_of_arrays(true, false)]; + let field = SparkFlatten::new() + .return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_args, + }) + .unwrap(); + assert!(field.is_nullable()); + } +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 057b0462ee..b1361e3b72 100644 --- a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -21,6 +21,7 @@ mod array_position; mod array_slice; mod arrays_overlap; mod arrays_zip; +mod flatten; mod get_array_struct_fields; mod list_extract; mod size; @@ -31,6 +32,7 @@ pub use array_position::SparkArrayPositionFunc; pub use array_slice::SparkArraySlice; pub use arrays_overlap::SparkArraysOverlap; pub use arrays_zip::SparkArraysZipFunc; +pub use flatten::SparkFlatten; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; pub use size::{spark_size, SparkSizeFunc}; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 42ee72c82a..c102b01849 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -27,8 +27,8 @@ use crate::{ spark_isnan, spark_lpad, spark_make_decimal, spark_month_name, spark_read_side_padding, spark_round, spark_rpad, spark_to_time, spark_unhex, spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayPositionFunc, SparkArraySlice, SparkArraysOverlap, SparkContains, - SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate, SparkMakeTime, - SparkNextDay, SparkSecondsToTimestamp, SparkSizeFunc, + SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkFlatten, SparkMakeDate, + SparkMakeTime, SparkNextDay, SparkSecondsToTimestamp, SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -224,6 +224,7 @@ pub fn create_comet_physical_fun_with_eval_mode( "to_time" => { make_comet_scalar_udf!("to_time", spark_to_time, without data_type, fail_on_error) } + "flatten" => Ok(Arc::new(ScalarUDF::new_from_impl(SparkFlatten::new()))), // make_date and next_day must be constructed with the ANSI flag (fail_on_error) so they // throw on invalid input under ANSI rather than returning NULL. "make_date" => Ok(Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::new( @@ -250,6 +251,7 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkDateDiff::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateFromUnixDate::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), + Arc::new(ScalarUDF::new_from_impl(SparkFlatten::default())), Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::default())), Arc::new(ScalarUDF::new_from_impl(SparkMakeTime::default())), Arc::new(ScalarUDF::new_from_impl(SparkNextDay::default())), diff --git a/spark/src/test/resources/sql-tests/expressions/array/flatten.sql b/spark/src/test/resources/sql-tests/expressions/array/flatten.sql index 8a9c60df24..b45e705a56 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/flatten.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/flatten.sql @@ -19,11 +19,18 @@ statement CREATE TABLE test_flatten(arr array>) USING parquet statement -INSERT INTO test_flatten VALUES (array(array(1, 2), array(3, 4))), (array(array(), array(1))), (array()), (NULL), (array(array(1, NULL), array(NULL))) +INSERT INTO test_flatten VALUES + (array(array(1, 2), array(3, 4))), + (array(array(), array(1))), + (array()), + (NULL), + (array(array(1, NULL), array(NULL))), + (array(array(1), CAST(NULL AS ARRAY))), + (array(CAST(NULL AS ARRAY), CAST(NULL AS ARRAY))) -query spark_answer_only +query SELECT flatten(arr) FROM test_flatten -- literal arguments -query spark_answer_only +query SELECT flatten(array(array(1, 2), array(3, 4)))