Skip to content

Commit

Permalink
feat: add Struct Accessors to BoundReferences (#317)
Browse files Browse the repository at this point in the history
  • Loading branch information
sdd authored Apr 6, 2024
1 parent 4e89ac7 commit ca9de89
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 7 deletions.
120 changes: 120 additions & 0 deletions crates/iceberg/src/expr/accessor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::spec::{Datum, Literal, PrimitiveType, Struct};
use crate::{Error, ErrorKind};
use serde_derive::{Deserialize, Serialize};
use std::sync::Arc;

#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct StructAccessor {
position: usize,
r#type: PrimitiveType,
inner: Option<Box<StructAccessor>>,
}

pub(crate) type StructAccessorRef = Arc<StructAccessor>;

impl StructAccessor {
pub(crate) fn new(position: usize, r#type: PrimitiveType) -> Self {
StructAccessor {
position,
r#type,
inner: None,
}
}

pub(crate) fn wrap(position: usize, inner: Box<StructAccessor>) -> Self {
StructAccessor {
position,
r#type: inner.r#type().clone(),
inner: Some(inner),
}
}

pub(crate) fn position(&self) -> usize {
self.position
}

pub(crate) fn r#type(&self) -> &PrimitiveType {
&self.r#type
}

pub(crate) fn get<'a>(&'a self, container: &'a Struct) -> crate::Result<Datum> {
match &self.inner {
None => {
if let Literal::Primitive(literal) = &container[self.position] {
Ok(Datum::new(self.r#type().clone(), literal.clone()))
} else {
Err(Error::new(
ErrorKind::Unexpected,
"Expected Literal to be Primitive",
))
}
}
Some(inner) => {
if let Literal::Struct(wrapped) = &container[self.position] {
inner.get(wrapped)
} else {
Err(Error::new(
ErrorKind::Unexpected,
"Nested accessor should only be wrapping a Struct",
))
}
}
}
}
}

#[cfg(test)]
mod tests {
use crate::expr::accessor::StructAccessor;
use crate::spec::{Datum, Literal, PrimitiveType, Struct};

#[test]
fn test_single_level_accessor() {
let accessor = StructAccessor::new(1, PrimitiveType::Boolean);

assert_eq!(accessor.r#type(), &PrimitiveType::Boolean);
assert_eq!(accessor.position(), 1);

let test_struct =
Struct::from_iter(vec![Some(Literal::bool(false)), Some(Literal::bool(true))]);

assert_eq!(accessor.get(&test_struct).unwrap(), Datum::bool(true));
}

#[test]
fn test_nested_accessor() {
let nested_accessor = StructAccessor::new(1, PrimitiveType::Boolean);
let accessor = StructAccessor::wrap(2, Box::new(nested_accessor));

assert_eq!(accessor.r#type(), &PrimitiveType::Boolean);
//assert_eq!(accessor.position(), 1);

let nested_test_struct =
Struct::from_iter(vec![Some(Literal::bool(false)), Some(Literal::bool(true))]);

let test_struct = Struct::from_iter(vec![
Some(Literal::bool(false)),
Some(Literal::bool(false)),
Some(Literal::Struct(nested_test_struct)),
]);

assert_eq!(accessor.get(&test_struct).unwrap(), Datum::bool(true));
}
}
1 change: 1 addition & 0 deletions crates/iceberg/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ mod term;
use std::fmt::{Display, Formatter};

pub use term::*;
pub(crate) mod accessor;
mod predicate;

use crate::spec::SchemaRef;
Expand Down
33 changes: 31 additions & 2 deletions crates/iceberg/src/expr/term.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use std::fmt::{Display, Formatter};

use fnv::FnvHashSet;

use crate::expr::accessor::{StructAccessor, StructAccessorRef};
use crate::expr::Bind;
use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression};
use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef};
Expand Down Expand Up @@ -333,7 +334,19 @@ impl Bind for Reference {
format!("Field {} not found in schema", self.name),
)
})?;
Ok(BoundReference::new(self.name.clone(), field.clone()))

let accessor = schema.accessor_by_field_id(field.id).ok_or_else(|| {
Error::new(
ErrorKind::DataInvalid,
format!("Accessor for Field {} not found", self.name),
)
})?;

Ok(BoundReference::new(
self.name.clone(),
field.clone(),
accessor.clone(),
))
}
}

Expand All @@ -344,21 +357,32 @@ pub struct BoundReference {
// For example, if the field is `a.b.c`, then `field.name` is `c`, but `original_name` is `a.b.c`.
column_name: String,
field: NestedFieldRef,
accessor: StructAccessorRef,
}

impl BoundReference {
/// Creates a new bound reference.
pub fn new(name: impl Into<String>, field: NestedFieldRef) -> Self {
pub fn new(
name: impl Into<String>,
field: NestedFieldRef,
accessor: StructAccessorRef,
) -> Self {
Self {
column_name: name.into(),
field,
accessor,
}
}

/// Return the field of this reference.
pub fn field(&self) -> &NestedField {
&self.field
}

/// Get this BoundReference's Accessor
pub fn accessor(&self) -> &StructAccessor {
&self.accessor
}
}

impl Display for BoundReference {
Expand All @@ -374,6 +398,7 @@ pub type BoundTerm = BoundReference;
mod tests {
use std::sync::Arc;

use crate::expr::accessor::StructAccessor;
use crate::expr::{Bind, BoundReference, Reference};
use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};

Expand All @@ -397,9 +422,11 @@ mod tests {
let schema = table_schema_simple();
let reference = Reference::new("bar").bind(schema, true).unwrap();

let accessor_ref = Arc::new(StructAccessor::new(1, PrimitiveType::Int));
let expected_ref = BoundReference::new(
"bar",
NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
accessor_ref.clone(),
);

assert_eq!(expected_ref, reference);
Expand All @@ -410,9 +437,11 @@ mod tests {
let schema = table_schema_simple();
let reference = Reference::new("BAR").bind(schema, false).unwrap();

let accessor_ref = Arc::new(StructAccessor::new(1, PrimitiveType::Int));
let expected_ref = BoundReference::new(
"BAR",
NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
accessor_ref.clone(),
);

assert_eq!(expected_ref, reference);
Expand Down
Loading

0 comments on commit ca9de89

Please sign in to comment.