diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs
index 26bd0bec3..4843689c1 100644
--- a/crates/iceberg/src/spec/schema.rs
+++ b/crates/iceberg/src/spec/schema.rs
@@ -955,9 +955,9 @@ mod tests {
     };
     use crate::spec::schema::Schema;
     use crate::spec::schema::_serde::{SchemaEnum, SchemaV1, SchemaV2};
-    use std::collections::HashMap;
+    use std::collections::{HashMap, HashSet};
 
-    use super::DEFAULT_SCHEMA_ID;
+    use super::{visit_schema, PruneColumn, DEFAULT_SCHEMA_ID};
 
     fn check_schema_serde(json: &str, expected_type: Schema, _expected_enum: SchemaEnum) {
         let desered_type: Schema = serde_json::from_str(json).unwrap();
@@ -1533,4 +1533,447 @@ table {
             );
         }
     }
+    // `PruneColumn::new(_, false)`: selecting id 1 keeps only the optional string field `foo`.
+    #[test]
+    fn test_schema_prune_columns_string() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::optional(
+            1,
+            "foo",
+            Type::Primitive(PrimitiveType::String),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([1]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, true)`: selecting id 1 keeps only the optional string field `foo`.
+    #[test]
+    fn test_schema_prune_columns_string_full() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::optional(
+            1,
+            "foo",
+            Type::Primitive(PrimitiveType::String),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([1]);
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting list element id 5 keeps list field `qux` (id 4).
+    #[test]
+    fn test_schema_prune_columns_list() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            4,
+            "qux",
+            Type::List(ListType {
+                element_field: NestedField::list_element(
+                    5,
+                    Type::Primitive(PrimitiveType::String),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([5]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting the list field's own id 4 is expected to fail.
+    #[test]
+    fn test_prune_columns_list_itself() {
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([4]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_err());
+    }
+
+    // `PruneColumn::new(_, true)`: selecting list element id 5 keeps list field `qux` (id 4).
+    #[test]
+    fn test_schema_prune_columns_list_full() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            4,
+            "qux",
+            Type::List(ListType {
+                element_field: NestedField::list_element(
+                    5,
+                    Type::Primitive(PrimitiveType::String),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([5]);
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting inner map key id 9 keeps the nested map `quux`.
+    #[test]
+    fn test_prune_columns_map() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            6,
+            "quux",
+            Type::Map(MapType {
+                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::String))
+                    .into(),
+                value_field: NestedField::map_value_element(
+                    8,
+                    Type::Map(MapType {
+                        key_field: NestedField::map_key_element(
+                            9,
+                            Type::Primitive(PrimitiveType::String),
+                        )
+                        .into(),
+                        value_field: NestedField::map_value_element(
+                            10,
+                            Type::Primitive(PrimitiveType::Int),
+                            true,
+                        )
+                        .into(),
+                    }),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([9]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting the map field's own id 6 is expected to fail.
+    #[test]
+    fn test_prune_columns_map_itself() {
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([6]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_err());
+    }
+
+    // `PruneColumn::new(_, true)`: selecting inner map key id 9 keeps the nested map `quux`.
+    #[test]
+    fn test_prune_columns_map_full() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            6,
+            "quux",
+            Type::Map(MapType {
+                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::String))
+                    .into(),
+                value_field: NestedField::map_value_element(
+                    8,
+                    Type::Map(MapType {
+                        key_field: NestedField::map_key_element(
+                            9,
+                            Type::Primitive(PrimitiveType::String),
+                        )
+                        .into(),
+                        value_field: NestedField::map_value_element(
+                            10,
+                            Type::Primitive(PrimitiveType::Int),
+                            true,
+                        )
+                        .into(),
+                    }),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([9]);
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting inner map value id 10 keeps the nested map `quux`.
+    #[test]
+    fn test_prune_columns_map_key() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            6,
+            "quux",
+            Type::Map(MapType {
+                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::String))
+                    .into(),
+                value_field: NestedField::map_value_element(
+                    8,
+                    Type::Map(MapType {
+                        key_field: NestedField::map_key_element(
+                            9,
+                            Type::Primitive(PrimitiveType::String),
+                        )
+                        .into(),
+                        value_field: NestedField::map_value_element(
+                            10,
+                            Type::Primitive(PrimitiveType::Int),
+                            true,
+                        )
+                        .into(),
+                    }),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([10]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting id 16 keeps `person` with just its `name` field.
+    #[test]
+    fn test_prune_columns_struct() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::optional(
+            15,
+            "person",
+            Type::Struct(StructType::new(vec![NestedField::optional(
+                16,
+                "name",
+                Type::Primitive(PrimitiveType::String),
+            )
+            .into()])),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([16]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, true)`: selecting id 16 keeps `person` with just its `name` field.
+    #[test]
+    fn test_prune_columns_struct_full() {
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::optional(
+            15,
+            "person",
+            Type::Struct(StructType::new(vec![NestedField::optional(
+                16,
+                "name",
+                Type::Primitive(PrimitiveType::String),
+            )
+            .into()])),
+        )
+        .into()]));
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = HashSet::from([16]);
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting empty struct `person` (id 15) keeps it as is.
+    #[test]
+    fn test_prune_columns_empty_struct() {
+        let empty_schema = Schema::builder()
+            .with_fields(vec![NestedField::optional(
+                15,
+                "person",
+                Type::Struct(StructType::new(vec![])),
+            )
+            .into()])
+            .build()
+            .unwrap();
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::optional(
+            15,
+            "person",
+            Type::Struct(StructType::new(vec![])),
+        )
+        .into()]));
+        let selected: HashSet<i32> = HashSet::from([15]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&empty_schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, true)`: selecting empty struct `person` (id 15) keeps it as is.
+    #[test]
+    fn test_prune_columns_empty_struct_full() {
+        let empty_schema = Schema::builder()
+            .with_fields(vec![NestedField::optional(
+                15,
+                "person",
+                Type::Struct(StructType::new(vec![])),
+            )
+            .into()])
+            .build()
+            .unwrap();
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::optional(
+            15,
+            "person",
+            Type::Struct(StructType::new(vec![])),
+        )
+        .into()]));
+        let selected: HashSet<i32> = HashSet::from([15]);
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&empty_schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, false)`: selecting id 11 keeps only `age` in the map's struct value.
+    #[test]
+    fn test_prune_columns_struct_in_map() {
+        let schema = Schema::builder()
+            .with_schema_id(1)
+            .with_fields(vec![NestedField::required(
+                6,
+                "id_to_person",
+                Type::Map(MapType {
+                    key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int))
+                        .into(),
+                    value_field: NestedField::map_value_element(
+                        8,
+                        Type::Struct(
+                            StructType::new(vec![
+                                NestedField::optional(10, "name", Primitive(PrimitiveType::String))
+                                    .into(),
+                                NestedField::required(11, "age", Primitive(PrimitiveType::Int))
+                                    .into(),
+                            ])
+                            .into(),
+                        ),
+                        true,
+                    )
+                    .into(),
+                }),
+            )
+            .into()])
+            .build()
+            .unwrap();
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            6,
+            "id_to_person",
+            Type::Map(MapType {
+                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int))
+                    .into(),
+                value_field: NestedField::map_value_element(
+                    8,
+                    Type::Struct(
+                        StructType::new(vec![NestedField::required(
+                            11,
+                            "age",
+                            Primitive(PrimitiveType::Int),
+                        )
+                        .into()])
+                        .into(),
+                    ),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let selected: HashSet<i32> = HashSet::from([11]);
+        let mut visitor = PruneColumn::new(selected, false);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, true)`: selecting id 11 keeps only `age` in the map's struct value.
+    #[test]
+    fn test_prune_columns_struct_in_map_full() {
+        let schema = Schema::builder()
+            .with_schema_id(1)
+            .with_fields(vec![NestedField::required(
+                6,
+                "id_to_person",
+                Type::Map(MapType {
+                    key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int))
+                        .into(),
+                    value_field: NestedField::map_value_element(
+                        8,
+                        Type::Struct(
+                            StructType::new(vec![
+                                NestedField::optional(10, "name", Primitive(PrimitiveType::String))
+                                    .into(),
+                                NestedField::required(11, "age", Primitive(PrimitiveType::Int))
+                                    .into(),
+                            ])
+                            .into(),
+                        ),
+                        true,
+                    )
+                    .into(),
+                }),
+            )
+            .into()])
+            .build()
+            .unwrap();
+        let expected_schema = Type::Struct(StructType::new(vec![NestedField::required(
+            6,
+            "id_to_person",
+            Type::Map(MapType {
+                key_field: NestedField::map_key_element(7, Type::Primitive(PrimitiveType::Int))
+                    .into(),
+                value_field: NestedField::map_value_element(
+                    8,
+                    Type::Struct(
+                        StructType::new(vec![NestedField::required(
+                            11,
+                            "age",
+                            Primitive(PrimitiveType::Int),
+                        )
+                        .into()])
+                        .into(),
+                    ),
+                    true,
+                )
+                .into(),
+            }),
+        )
+        .into()]));
+        let selected: HashSet<i32> = HashSet::from([11]);
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(result.unwrap().unwrap(), expected_schema);
+    }
+
+    // `PruneColumn::new(_, true)`: selecting ids 0..=highest_field_id returns the full schema.
+    #[test]
+    fn test_prune_columns_select_original_schema() {
+        let schema = table_schema_nested();
+        let selected: HashSet<i32> = (0..=schema.highest_field_id()).collect();
+        let mut visitor = PruneColumn::new(selected, true);
+        let result = visit_schema(&schema, &mut visitor);
+        assert!(result.is_ok());
+        assert_eq!(
+            result.unwrap().unwrap(),
+            Type::Struct(schema.as_struct().clone())
+        );
+    }
 }