Skip to content

Commit

Permalink
Towards #75: Further refactor for using GATs + clippy fixes for rust …
Browse files Browse the repository at this point in the history
…1.65-beta
  • Loading branch information
ncpenke committed Oct 10, 2022
1 parent 37f5500 commit 8d5f8dd
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 153 deletions.
137 changes: 66 additions & 71 deletions arrow2_convert/src/deserialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,36 @@ use crate::field::*;

#[doc(hidden)]
/// Type whose reference can be used to create an iterator.
pub trait IterRef {
pub trait RefIntoIterator: Sized {
/// Iterator type.
type Iter<'a>: Iterator
type Iterator<'a>: Iterator
where
Self: 'a;

/// Converts `&self` into an iterator.
fn iter_ref(&self) -> Self::Iter<'_>;
fn ref_into_iter(&self) -> Self::Iterator<'_>;
}

impl<T> IterRef for T
impl<T> RefIntoIterator for T
where
for<'a> &'a T: IntoIterator,
{
type Iter<'a> = <&'a T as IntoIterator>::IntoIter where Self: 'a;
type Iterator<'a> = <&'a T as IntoIterator>::IntoIter<> where Self: 'a;

#[inline]
fn iter_ref(&self) -> Self::Iter<'_> {
fn ref_into_iter(&self) -> Self::Iterator<'_> {
self.into_iter()
}
}

/// Implemented by [`ArrowField`] that can be deserialized from arrow
pub trait ArrowDeserialize: ArrowField + Sized {
/// The `arrow2::Array` type corresponding to this field
type ArrayType: ArrowArray;
type ArrayType: RefIntoIterator;

/// Deserialize this field from arrow
fn arrow_deserialize(
v: <<Self::ArrayType as IterRef>::Iter<'_> as Iterator>::Item,
v: <<Self::ArrayType as RefIntoIterator>::Iterator<'_> as Iterator>::Item,
) -> Option<<Self as ArrowField>::Type>;

#[inline]
Expand All @@ -48,23 +48,36 @@ pub trait ArrowDeserialize: ArrowField + Sized {
/// something like for<'a> &'a T::ArrayType: IntoIterator<Item=Option<E>>,
/// However, the E parameter seems to confuse the borrow checker if it's a reference.
fn arrow_deserialize_internal(
v: <<Self::ArrayType as IterRef>::Iter<'_> as Iterator>::Item,
v: <<Self::ArrayType as RefIntoIterator>::Iterator<'_> as Iterator>::Item,
) -> <Self as ArrowField>::Type {
Self::arrow_deserialize(v).unwrap()
}
}

/// Internal trait used to support deserialization and iteration of structs, and nested struct lists
///
/// Trivial pass-thru implementations are provided for arrow2 arrays that auto-implement IterRef.
///
/// The derive macro generates implementations for typed struct arrays.
#[doc(hidden)]
pub trait ArrowArray: IterRef {
type BaseArrayType: Array;

// Returns a typed iterator to the underlying elements of the array from an untyped Array reference.
fn iter_from_array_ref(b: &dyn Array) -> <Self as IterRef>::Iter<'_>;
#[inline]
#[doc(hidden)]
/// For internal use only
///
/// TODO: this can be removed up by using arrow2::array::StructArray and
/// arrow2::array::UnionArray to perform the iteration for unions and structs
/// which should be possible if structs and unions are deserialized via scalars.
///
/// Helper to return an iterator for elements from a [`arrow2::array::Array`].
///
/// Overridden by struct and enum arrays generated by the derive macro, to
/// downcast to the arrow2 array type.
fn arrow_array_ref_into_iter(
array: &dyn Array,
) -> Option<<Self::ArrayType as RefIntoIterator>::Iterator<'_>>
where
Self::ArrayType: 'static,
{
Some(
array
.as_any()
.downcast_ref::<Self::ArrayType>()?
.ref_into_iter(),
)
}
}

// Macro to facilitate implementation for numeric types and numeric arrays.
Expand All @@ -78,23 +91,6 @@ macro_rules! impl_arrow_deserialize_primitive {
v.map(|t| *t)
}
}

impl_arrow_array!(PrimitiveArray<$physical_type>);
};
}

macro_rules! impl_arrow_array {
($array:ty) => {
impl ArrowArray for $array {
type BaseArrayType = Self;

fn iter_from_array_ref(b: &dyn Array) -> <Self as IterRef>::Iter<'_> {
b.as_any()
.downcast_ref::<Self::BaseArrayType>()
.unwrap()
.iter_ref()
}
}
};
}

Expand All @@ -107,17 +103,26 @@ where

#[inline]
fn arrow_deserialize(
v: <<Self::ArrayType as IterRef>::Iter<'_> as Iterator>::Item,
v: <<Self::ArrayType as RefIntoIterator>::Iterator<'_> as Iterator>::Item,
) -> Option<<Self as ArrowField>::Type> {
Self::arrow_deserialize_internal(v).map(Some)
}

#[inline]
fn arrow_deserialize_internal(
v: <<Self::ArrayType as IterRef>::Iter<'_> as Iterator>::Item,
v: <<Self::ArrayType as RefIntoIterator>::Iterator<'_> as Iterator>::Item,
) -> <Self as ArrowField>::Type {
<T as ArrowDeserialize>::arrow_deserialize(v)
}

fn arrow_array_ref_into_iter(
array: &dyn Array,
) -> Option<<Self::ArrayType as RefIntoIterator>::Iterator<'_>>
where
Self::ArrayType: 'static,
{
<T as ArrowDeserialize>::arrow_array_ref_into_iter(array)
}
}

impl_arrow_deserialize_primitive!(u8);
Expand All @@ -140,8 +145,6 @@ impl<const PRECISION: usize, const SCALE: usize> ArrowDeserialize for I128<PRECI
}
}

impl_arrow_array!(PrimitiveArray<i128>);

impl ArrowDeserialize for String {
type ArrayType = Utf8Array<i32>;

Expand Down Expand Up @@ -221,10 +224,10 @@ where
T: ArrowDeserialize + ArrowEnableVecForType + 'static,
{
use std::ops::Deref;
v.map(|t| {
arrow_array_deserialize_iterator_internal::<<T as ArrowField>::Type, T>(t.deref())
.collect::<Vec<<T as ArrowField>::Type>>()
})
Some(
arrow_array_deserialize_iterator_internal::<<T as ArrowField>::Type, T>(v?.deref())?
.collect::<Vec<<T as ArrowField>::Type>>(),
)
}

// Blanket implementation for Vec
Expand Down Expand Up @@ -261,16 +264,6 @@ where
}
}

impl_arrow_array!(BooleanArray);
impl_arrow_array!(Utf8Array<i32>);
impl_arrow_array!(Utf8Array<i64>);
impl_arrow_array!(BinaryArray<i32>);
impl_arrow_array!(BinaryArray<i64>);
impl_arrow_array!(FixedSizeBinaryArray);
impl_arrow_array!(ListArray<i32>);
impl_arrow_array!(ListArray<i64>);
impl_arrow_array!(FixedSizeListArray);

/// Top-level API to deserialize from Arrow
pub trait TryIntoCollection<Collection, Element>
where
Expand All @@ -288,40 +281,42 @@ where
}

/// Helper to return an iterator for elements from a [`arrow2::array::Array`].
fn arrow_array_deserialize_iterator_internal<'a, Element, Field>(
b: &'a dyn Array,
) -> impl Iterator<Item = Element> + 'a
fn arrow_array_deserialize_iterator_internal<Element, Field>(
b: &dyn Array,
) -> Option<impl Iterator<Item = Element> + '_>
where
Field: ArrowDeserialize + ArrowField<Type = Element> + 'static,
{
<<Field as ArrowDeserialize>::ArrayType as ArrowArray>::iter_from_array_ref(b)
.map(<Field as ArrowDeserialize>::arrow_deserialize_internal)
Some(
<Field as ArrowDeserialize>::arrow_array_ref_into_iter(b)?
.map(<Field as ArrowDeserialize>::arrow_deserialize_internal),
)
}

/// Returns a typed iterator to a target type from an `arrow2::Array`
pub fn arrow_array_deserialize_iterator_as_type<'a, Element, ArrowType>(
arr: &'a dyn Array,
) -> arrow2::error::Result<impl Iterator<Item = Element> + 'a>
pub fn arrow_array_deserialize_iterator_as_type<Element, ArrowType>(
arr: &dyn Array,
) -> arrow2::error::Result<impl Iterator<Item = Element> + '_>
where
Element: 'static,
ArrowType: ArrowDeserialize + ArrowField<Type = Element> + 'static,
{
if &<ArrowType as ArrowField>::data_type() != arr.data_type() {
// TODO: use arrow2_convert error type here and include more detail
Err(arrow2::error::Error::InvalidArgumentError(
"Data type mismatch".to_string(),
))
} else {
Ok(arrow_array_deserialize_iterator_internal::<
Element,
ArrowType,
>(arr))
arrow_array_deserialize_iterator_internal::<Element, ArrowType>(arr).ok_or_else(||
// TODO: use arrow2_convert error type here and include more detail
arrow2::error::Error::InvalidArgumentError("Schema mismatch".to_string()))
}
}

/// Return an iterator that deserializes an [`Array`] to an element of type T
pub fn arrow_array_deserialize_iterator<'a, T>(
arr: &'a dyn Array,
) -> arrow2::error::Result<impl Iterator<Item = T> + 'a>
pub fn arrow_array_deserialize_iterator<T>(
arr: &dyn Array,
) -> arrow2::error::Result<impl Iterator<Item = T> + '_>
where
T: ArrowDeserialize + ArrowField<Type = T> + 'static,
{
Expand Down
33 changes: 0 additions & 33 deletions arrow2_convert/tests/test_round_trip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,6 @@ use arrow2_convert::{
use std::borrow::Borrow;
use std::sync::Arc;

#[test]
fn test_nested_optional_struct_array() {
#[derive(Debug, Clone, ArrowField, PartialEq)]
struct Top {
child_array: Vec<Option<Child>>,
}
#[derive(Debug, Clone, ArrowField, PartialEq)]
struct Child {
a1: i64,
}

let original_array = vec![
Top {
child_array: vec![
Some(Child { a1: 10 }),
None,
Some(Child { a1: 12 }),
Some(Child { a1: 14 }),
],
},
Top {
child_array: vec![None, None, None, None],
},
Top {
child_array: vec![None, None, Some(Child { a1: 12 }), None],
},
];

let b: Box<dyn Array> = original_array.try_into_arrow().unwrap();
let round_trip: Vec<Top> = b.try_into_collection().unwrap();
assert_eq!(original_array, round_trip);
}

#[test]
fn test_large_string() {
let strs = vec!["1".to_string(), "2".to_string()];
Expand Down
14 changes: 14 additions & 0 deletions arrow2_convert/tests/test_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@ use arrow2_convert::deserialize::*;
use arrow2_convert::serialize::*;
use arrow2_convert::ArrowField;

#[test]
fn test_optional_struct_array() {
#[derive(Debug, Clone, ArrowField, PartialEq)]
struct Foo {
field: i32,
}

let original_array = vec![Some(Foo { field: 0 }), None, Some(Foo { field: 10 })];
let b: Box<dyn Array> = original_array.try_into_arrow().unwrap();
println!("{:?}", b.data_type());
let round_trip: Vec<Option<Foo>> = b.try_into_collection().unwrap();
assert_eq!(original_array, round_trip);
}

#[test]
fn test_nested_optional_struct_array() {
#[derive(Debug, Clone, ArrowField, PartialEq)]
Expand Down
40 changes: 17 additions & 23 deletions arrow2_convert_derive/src/derive_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -446,28 +446,6 @@ pub fn expand(input: DeriveEnum) -> TokenStream {
{}
};

let array_impl = quote! {
impl arrow2_convert::deserialize::ArrowArray for #array_name
{
type BaseArrayType = arrow2::array::UnionArray;

#[inline]
fn iter_from_array_ref<'a>(b: &'a dyn arrow2::array::Array) -> <&'a Self as IntoIterator>::IntoIter
{
use core::ops::Deref;
let arr = b.as_any().downcast_ref::<arrow2::array::UnionArray>().unwrap();
let fields = arr.fields();

#iterator_name {
#(
#variant_names: <<#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::ArrayType as arrow2_convert::deserialize::ArrowArray>::iter_from_array_ref(fields[#variant_indices].deref()),
)*
types_iter: arr.types().iter(),
}
}
}
};

let array_into_iterator_impl = quote! {
impl<'a> IntoIterator for &'a #array_name
{
Expand Down Expand Up @@ -517,12 +495,28 @@ pub fn expand(input: DeriveEnum) -> TokenStream {
fn arrow_deserialize<'a>(v: Option<Self>) -> Option<Self> {
v
}

#[inline]
fn arrow_array_ref_into_iter(
array: &dyn arrow2::array::Array
) -> Option<<#array_name as arrow2_convert::deserialize::RefIntoIterator>::Iterator<'_>>
{
use core::ops::Deref;
let arr = array.as_any().downcast_ref::<arrow2::array::UnionArray>()?;
let fields = arr.fields();

Some(#iterator_name {
#(
#variant_names: <#variant_types as arrow2_convert::deserialize::ArrowDeserialize>::arrow_array_ref_into_iter(fields[#variant_indices].deref())?,
)*
types_iter: arr.types().iter(),
})
}
}
};

generated.extend([
array_decl,
array_impl,
array_into_iterator_impl,
array_iterator_decl,
array_iterator_iterator_impl,
Expand Down
Loading

0 comments on commit 8d5f8dd

Please sign in to comment.