From 8b43aa3a90c318064796ab7ef4fad5fc036a075f Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 9 Aug 2024 22:50:02 +0100 Subject: [PATCH] rearrange can_cast_types and cast_with_options to get tests passing --- arrow-cast/src/cast/mod.rs | 304 +++++++++++++++++++------------------ 1 file changed, 160 insertions(+), 144 deletions(-) diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index 1ec5d4120805..3c735a9bfe34 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -150,7 +150,7 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (List(list_from) | LargeList(list_from), FixedSizeList(list_to, _)) => { can_cast_types(list_from.data_type(), list_to.data_type()) } - (List(_), _) => false, + (List(_) | LargeList(_), _) => false, (FixedSizeList(list_from,_), List(list_to)) | (FixedSizeList(list_from,_), LargeList(list_to)) => { can_cast_types(list_from.data_type(), list_to.data_type()) @@ -158,6 +158,10 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (FixedSizeList(inner, size), FixedSizeList(inner_to, size_to)) if size == size_to => { can_cast_types(inner.data_type(), inner_to.data_type()) } + + // TODO casting from unions to lists causes errors related to nullability, hence this block + (Union(_, _), List(_) | LargeList(_)) => false, + (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()), (_, FixedSizeList(list_to,size)) if *size == 1 => { @@ -170,6 +174,26 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { can_cast_types(from_key.data_type(), to_key.data_type()) && can_cast_types(from_value.data_type(), to_value.data_type()), _ => false }, + + (Struct(from_fields), Struct(to_fields)) => { + from_fields.len() == to_fields.len() && + from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + // Assume that nullability between two structs are compatible, if not, + // cast kernel will return error. + can_cast_types(f1.data_type(), f2.data_type()) + }) + } + (Struct(_), _) => false, + (_, Struct(_)) => false, + + (Union(_, _), Union(_, _)) => false, + (Union(from_fields, _), _) => { + from_fields.iter().any(|(_, f)| can_cast_types(f.data_type(), to_type)) + }, + (_, Union(to_fields, _)) => { + to_fields.iter().any(|(_, t)| can_cast_types(from_type, t.data_type())) + }, + // cast one decimal type to another decimal type (Decimal128(_, _), Decimal128(_, _)) => true, (Decimal256(_, _), Decimal256(_, _)) => true, @@ -189,24 +213,6 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Decimal128(_, _) | Decimal256(_, _), Utf8 | LargeUtf8) => true, // Utf8 to decimal (Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, - (Struct(from_fields), Struct(to_fields)) => { - from_fields.len() == to_fields.len() && - from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { - // Assume that nullability between two structs are compatible, if not, - // cast kernel will return error. - can_cast_types(f1.data_type(), f2.data_type()) - }) - } - (Struct(_), _) => false, - (_, Struct(_)) => false, - - (Union(_, _), Union(_, _)) => false, - (Union(from_fields, _), _) => { - from_fields.iter().any(|(_, f)| can_cast_types(f.data_type(), to_type)) - }, - (_, Union(to_fields, _)) => { - to_fields.iter().any(|(_, t)| can_cast_types(from_type, t.data_type())) - }, (_, Boolean) => { DataType::is_integer(from_type) || @@ -813,6 +819,12 @@ pub fn cast_with_options( array.nulls().cloned(), )?)) } + + // TODO casting from unions to lists causes errors related to nullability, hence this block + (Union(_, _), List(_) | LargeList(_)) => Err(ArrowError::CastError( + "cannot cast union to list or large list".into(), + )), + (_, List(ref to)) => cast_values_to_list::(array, to, cast_options), (_, LargeList(ref to)) => cast_values_to_list::(array, to, cast_options), (_, FixedSizeList(ref to, size)) if *size == 1 => { @@ -824,6 +836,115 @@ pub fn cast_with_options( (Map(_, ordered1), Map(_, ordered2)) if ordered1 == ordered2 => { cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) } + + // structs + (Struct(_), Struct(to_fields)) => { + let array = array.as_struct(); + let fields = array + .columns() + .iter() + .zip(to_fields.iter()) + .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) + .collect::, ArrowError>>()?; + let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?; + Ok(Arc::new(array) as ArrayRef) + } + (Struct(_), _) => Err(ArrowError::CastError( + "Cannot cast from struct to other types except struct".to_string(), + )), + (_, Struct(_)) => Err(ArrowError::CastError( + "Cannot cast to struct from other types except struct".to_string(), + )), + + // unions + // we might be able to support this, but it's complex + (Union(_, _), Union(_, _)) => Err(ArrowError::CastError( + "Cannot cast from union to union".to_string(), + )), + (Union(from_fields, _), _) => { + let Some((type_id, _)) = from_fields + .iter() + // try to find an exact match first + .find(|(_, f)| (f.data_type() == to_type)) + .or_else(|| { + // if no exact match, try to find a type that can be cast to + from_fields + .iter() + .find(|(_, f)| can_cast_types(f.data_type(), to_type)) + }) + else { + return Err(ArrowError::CastError(format!( + "Casting from union type to {to_type:?} not supported", + ))); + }; + + let union_array = array.as_any().downcast_ref::().unwrap(); + let child = union_field_array(union_array, type_id)?; + cast_with_options(child.as_ref(), to_type, cast_options) + } + (_, Union(to_fields, mode)) => { + let from_type = array.data_type(); + let Some((type_id, _)) = to_fields + .iter() + // try to find an exact match first + .find(|(_, f)| f.data_type() == from_type) + .or_else(|| { + // if no exact match, try to find a type that can be cast to + to_fields + .iter() + .find(|(_, f)| can_cast_types(from_type, f.data_type())) + }) + else { + return Err(ArrowError::CastError(format!( + "Casting from {from_type:?} to union type not supported", + ))); + }; + // type_ids is just type_id replicated for array.len() + let type_ids = std::iter::repeat(type_id) + .take(array.len()) + .collect::>(); + + let (offsets, children) = if mode == &UnionMode::Dense { + // offset ids are just `0..array.len()` + let offsets = (0i32..(array.len() as i32)).collect::>(); + + let children = to_fields + .iter() + .map(|(t, f)| { + if t == type_id { + // for the field with matching type_id, we cast the input data + cast_with_options(array, f.data_type(), cast_options) + } else { + // create empty ArrayRef's for other fields + Ok(new_empty_array(f.data_type())) + } + }) + .collect::, ArrowError>>()?; + (Some(offsets), children) + } else { + let children = to_fields + .iter() + .map(|(t, f)| { + if t == type_id { + // for the field with matching type_id, we cast the input data + cast_with_options(array, f.data_type(), cast_options) + } else { + // create empty ArrayRef's for other fields + Ok(new_null_array(f.data_type(), array.len())) + } + }) + .collect::, ArrowError>>()?; + (None, children) + }; + + Ok(Arc::new(UnionArray::try_new( + to_fields.clone(), + type_ids, + offsets, + children, + )?)) + } + (Decimal128(_, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), @@ -1168,111 +1289,6 @@ pub fn cast_with_options( ))), } } - (Struct(_), Struct(to_fields)) => { - let array = array.as_struct(); - let fields = array - .columns() - .iter() - .zip(to_fields.iter()) - .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) - .collect::, ArrowError>>()?; - let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?; - Ok(Arc::new(array) as ArrayRef) - } - (Struct(_), _) => Err(ArrowError::CastError( - "Cannot cast from struct to other types except struct".to_string(), - )), - (_, Struct(_)) => Err(ArrowError::CastError( - "Cannot cast to struct from other types except struct".to_string(), - )), - - // we might be able to support this, but it's complex - (Union(_, _), Union(_, _)) => Err(ArrowError::CastError( - "Cannot cast from union to union".to_string(), - )), - (Union(from_fields, _), _) => { - let Some((type_id, _)) = from_fields - .iter() - // try to find an exact match first - .find(|(_, f)| f.data_type() == to_type) - .or_else(|| { - // if no exact match, try to find a type that can be cast to - from_fields - .iter() - .find(|(_, f)| can_cast_types(f.data_type(), to_type)) - }) - else { - return Err(ArrowError::CastError(format!( - "Casting from union type to {to_type:?} not supported", - ))); - }; - - let union_array = array.as_any().downcast_ref::().unwrap(); - let child = union_field_array(union_array, type_id)?; - cast_with_options(child.as_ref(), to_type, cast_options) - } - (_, Union(to_fields, mode)) => { - let from_type = array.data_type(); - let Some((type_id, _)) = to_fields - .iter() - // try to find an exact match first - .find(|(_, f)| f.data_type() == from_type) - .or_else(|| { - // if no exact match, try to find a type that can be cast to - to_fields - .iter() - .find(|(_, f)| can_cast_types(from_type, f.data_type())) - }) - else { - return Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to union type not supported", - ))); - }; - // type_ids is just type_id replicated for array.len() - let type_ids = std::iter::repeat(type_id) - .take(array.len()) - .collect::>(); - - let (offsets, children) = if mode == &UnionMode::Dense { - // offset ids are just `0..array.len()` - let offsets = (0i32..(array.len() as i32)).collect::>(); - - let children = to_fields - .iter() - .map(|(t, f)| { - if t == type_id { - // for the field with matching type_id, we cast the input data - cast_with_options(array, f.data_type(), cast_options) - } else { - // create empty ArrayRef's for other fields - Ok(new_empty_array(f.data_type())) - } - }) - .collect::, ArrowError>>()?; - (Some(offsets), children) - } else { - let children = to_fields - .iter() - .map(|(t, f)| { - if t == type_id { - // for the field with matching type_id, we cast the input data - cast_with_options(array, f.data_type(), cast_options) - } else { - // create empty ArrayRef's for other fields - Ok(new_null_array(f.data_type(), array.len())) - } - }) - .collect::, ArrowError>>()?; - (None, children) - }; - - Ok(Arc::new(UnionArray::try_new( - to_fields.clone(), - type_ids, - offsets, - children, - )?)) - } (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), @@ -9698,23 +9714,23 @@ mod tests { let union_array = UnionArray::try_new(union_fields(), type_ids, None, children).unwrap(); - assert!(can_cast_types(&union_array.data_type(), &DataType::Int64)); - assert!(can_cast_types(&DataType::Int64, &union_array.data_type())); + assert!(can_cast_types(union_array.data_type(), &DataType::Int64)); + assert!(can_cast_types(&DataType::Int64, union_array.data_type())); - assert!(can_cast_types(&union_array.data_type(), &DataType::Float64)); - assert!(can_cast_types(&DataType::Float64, &union_array.data_type())); + assert!(can_cast_types(union_array.data_type(), &DataType::Float64)); + assert!(can_cast_types(&DataType::Float64, union_array.data_type())); - assert!(can_cast_types(&union_array.data_type(), &DataType::Utf8)); - assert!(can_cast_types(&DataType::Utf8, &union_array.data_type())); + assert!(can_cast_types(union_array.data_type(), &DataType::Utf8)); + assert!(can_cast_types(&DataType::Utf8, union_array.data_type())); assert!(!can_cast_types( - &union_array.data_type(), + union_array.data_type(), &DataType::Duration(TimeUnit::Second) )); // Duration to Utf8 is allowed assert!(can_cast_types( &DataType::Duration(TimeUnit::Second), - &union_array.data_type() + union_array.data_type() )); let cast_array = cast(&union_array, &DataType::Int64).unwrap(); @@ -9728,7 +9744,7 @@ mod tests { #[test] fn int_to_sparse_union_cast() { - let ints = Int64Array::from_iter_values(vec![1, 2, 3].into_iter()); + let ints = Int64Array::from_iter_values(vec![1, 2, 3]); let dt = DataType::Union(union_fields(), UnionMode::Sparse); let as_union = cast(&ints, &dt).unwrap(); @@ -9748,7 +9764,7 @@ mod tests { assert_eq!(as_int_vec::(&as_union.value(1)), vec![Some(2)]); assert_eq!(as_int_vec::(&as_union.value(2)), vec![Some(3)]); - let strings = StringArray::from_iter_values(vec!["a", "b", "c"].into_iter()); + let strings = StringArray::from_iter_values(vec!["a", "b", "c"]); let cast_array = cast(&strings, &dt).unwrap(); let as_union = cast_array.as_any().downcast_ref::().unwrap(); @@ -9779,23 +9795,23 @@ mod tests { let union_array = UnionArray::try_new(union_fields(), type_ids, Some(offsets), children).unwrap(); - assert!(can_cast_types(&union_array.data_type(), &DataType::Int64)); - assert!(can_cast_types(&DataType::Int64, &union_array.data_type())); + assert!(can_cast_types(union_array.data_type(), &DataType::Int64)); + assert!(can_cast_types(&DataType::Int64, union_array.data_type())); - assert!(can_cast_types(&union_array.data_type(), &DataType::Float64)); - assert!(can_cast_types(&DataType::Float64, &union_array.data_type())); + assert!(can_cast_types(union_array.data_type(), &DataType::Float64)); + assert!(can_cast_types(&DataType::Float64, union_array.data_type())); - assert!(can_cast_types(&union_array.data_type(), &DataType::Utf8)); - assert!(can_cast_types(&DataType::Utf8, &union_array.data_type())); + assert!(can_cast_types(union_array.data_type(), &DataType::Utf8)); + assert!(can_cast_types(&DataType::Utf8, union_array.data_type())); assert!(!can_cast_types( - &union_array.data_type(), + union_array.data_type(), &DataType::Duration(TimeUnit::Second) )); // Duration to Utf8 is allowed assert!(can_cast_types( &DataType::Duration(TimeUnit::Second), - &union_array.data_type() + union_array.data_type() )); let cast_array = cast(&union_array, &DataType::Int64).unwrap(); @@ -9809,7 +9825,7 @@ mod tests { #[test] fn int_to_dense_union_cast() { - let ints = Int64Array::from_iter_values(vec![1, 2, 3].into_iter()); + let ints = Int64Array::from_iter_values(vec![1, 2, 3]); let dt = DataType::Union(union_fields(), UnionMode::Dense); let cast_array = cast(&ints, &dt).unwrap(); @@ -9829,7 +9845,7 @@ mod tests { assert_eq!(as_int_vec::(&as_union.value(1)), vec![Some(2)]); assert_eq!(as_int_vec::(&as_union.value(2)), vec![Some(3)]); - let strings = StringArray::from_iter_values(vec!["a", "b", "c"].into_iter()); + let strings = StringArray::from_iter_values(vec!["a", "b", "c"]); let cast_array = cast(&strings, &dt).unwrap(); let as_union = cast_array.as_any().downcast_ref::().unwrap();