diff --git a/.github/actions/setup-builder/action.yml b/.github/actions/setup-builder/action.yml new file mode 100644 index 000000000..43de1cbaa --- /dev/null +++ b/.github/actions/setup-builder/action.yml @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This file is heavily inspired by +# [datafusion](https://github.com/apache/datafusion/blob/main/.github/actions/setup-builder/action.yaml). +name: Prepare Rust Builder +description: 'Prepare Rust Build Environment' +inputs: + rust-version: + description: 'version of rust to install (e.g. stable)' + required: true + default: 'stable' +runs: + using: "composite" + steps: + - name: Setup Rust toolchain + shell: bash + run: | + echo "Installing ${{ inputs.rust-version }}" + rustup toolchain install ${{ inputs.rust-version }} + rustup default ${{ inputs.rust-version }} + rustup component add rustfmt clippy + - name: Fixup git permissions + # https://github.com/actions/checkout/issues/766 + shell: bash + run: git config --global --add safe.directory "$GITHUB_WORKSPACE" \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f0f8f0f5b..d155b2949 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -29,15 +29,29 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event_name }} cancel-in-progress: true +env: + rust_msrv: "1.77.1" + jobs: check: - runs-on: ubuntu-latest + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest steps: - uses: actions/checkout@v4 - name: Check License Header uses: apache/skywalking-eyes/header@v0.6.0 + - name: Install cargo-sort + run: make install-cargo-sort + + - name: Install taplo-cli + run: make install-taplo-cli + - name: Cargo format run: make check-fmt @@ -50,8 +64,29 @@ jobs: - name: Cargo sort run: make cargo-sort + - name: Cargo Machete + run: make cargo-machete build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + steps: + - uses: actions/checkout@v4 + + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.rust_msrv }} + + - name: Build + run: make build + + build_with_no_default_features: runs-on: ${{ matrix.os }} strategy: matrix: @@ -62,15 +97,23 @@ jobs: steps: - uses: actions/checkout@v4 - name: Build - run: cargo build + run: cargo build -p iceberg --no-default-features unit: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.rust_msrv }} + - name: Test run: cargo test --no-fail-fast --all-targets --all-features --workspace - + + - name: Async-std Test + run: cargo test --no-fail-fast --all-targets 
--no-default-features --features "async-std" --features "storage-fs" --workspace + - name: Doc Test run: cargo test --no-fail-fast --doc --all-features --workspace diff --git a/.github/workflows/ci_typos.yml b/.github/workflows/ci_typos.yml index 51a6a7b91..0030cc8d1 100644 --- a/.github/workflows/ci_typos.yml +++ b/.github/workflows/ci_typos.yml @@ -41,7 +41,5 @@ jobs: FORCE_COLOR: 1 steps: - uses: actions/checkout@v4 - - run: curl -LsSf https://github.com/crate-ci/typos/releases/download/v1.14.8/typos-v1.14.8-x86_64-unknown-linux-musl.tar.gz | tar zxf - -C ${CARGO_HOME:-~/.cargo}/bin - - - name: do typos check with typos-cli - run: typos + - name: Check typos + uses: crate-ci/typos@v1.23.2 diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 9289d3e10..a57aa612f 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -21,13 +21,11 @@ on: push: tags: - '*' - pull_request: - branches: - - main - paths: - - ".github/workflows/publish.yml" workflow_dispatch: +env: + rust_msrv: "1.77.1" + jobs: publish: runs-on: ubuntu-latest @@ -42,9 +40,11 @@ jobs: - "crates/catalog/rest" steps: - uses: actions/checkout@v4 - - name: Dryrun ${{ matrix.package }} - working-directory: ${{ matrix.package }} - run: cargo publish --all-features --dry-run + + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + with: + rust-version: ${{ env.rust_msrv }} - name: Publish ${{ matrix.package }} working-directory: ${{ matrix.package }} diff --git a/Cargo.toml b/Cargo.toml index 75ab0fe45..3f4631ba0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,19 +38,22 @@ rust-version = "1.77.1" anyhow = "1.0.72" apache-avro = "0.16" array-init = "2" -arrow-arith = { version = "51" } -arrow-array = { version = "51" } -arrow-ord = { version = "51" } -arrow-schema = { version = "51" } -arrow-select = { version = "51" } +arrow-arith = { version = "52" } +arrow-array = { version = "52" } +arrow-ord = { version = "52" } +arrow-schema = { version = "52" } +arrow-select = { version = "52" } +arrow-string = { version = "52" } async-stream = "0.3.5" async-trait = "0.1" +async-std = "1.12.0" aws-config = "1.1.8" aws-sdk-glue = "1.21.0" bimap = "0.6" bitvec = "1.0.1" bytes = "1.5" chrono = "0.4.34" +ctor = "0.2.8" derive_builder = "0.20.0" either = "1" env_logger = "0.11.0" @@ -60,18 +63,17 @@ iceberg = { version = "0.2.0", path = "./crates/iceberg" } iceberg-catalog-rest = { version = "0.2.0", path = "./crates/catalog/rest" } iceberg-catalog-hms = { version = "0.2.0", path = "./crates/catalog/hms" } itertools = "0.13" -lazy_static = "1" log = "^0.4" mockito = "^1" murmur3 = "0.5.2" once_cell = "1" -opendal = "0.46" +opendal = "0.47" ordered-float = "4.0.0" -parquet = "51" -pilota = "0.11.0" +parquet = "52" +pilota = "0.11.2" pretty_assertions = "1.4.0" port_scanner = "0.1.5" -reqwest = { version = "^0.12", features = ["json"] } +reqwest = { version = "^0.12", default-features = false, features = ["json"] } rust_decimal = "1.31.0" serde = { version = "^1.0", features = ["rc"] } serde_bytes = "0.11.8" @@ -80,11 +82,10 @@ serde_json = "^1.0" serde_repr = "0.1.16" serde_with = "3.4.0" tempfile = "3.8" -tokio = { version = "1", features = ["macros"] } +tokio = { version = "1", default-features = false } typed-builder = "^0.18" url = "2" -urlencoding = "2" uuid = { version = "1.6.1", features = ["v7"] } -volo-thrift = "0.9.2" -hive_metastore = "0.0.2" +volo-thrift = "0.10" +hive_metastore = "0.1.0" tera = "1" diff --git a/Makefile b/Makefile index ff01e1807..ed10a8acd 100644 
--- a/Makefile +++ b/Makefile @@ -17,30 +17,37 @@ .EXPORT_ALL_VARIABLES: -RUST_LOG = debug - build: - cargo build + cargo build --all-targets --all-features --workspace check-fmt: - cargo fmt --all -- --check + cargo fmt --all -- --check check-clippy: - cargo clippy --all-targets --all-features --workspace -- -D warnings + cargo clippy --all-targets --all-features --workspace -- -D warnings + +install-cargo-sort: + cargo install cargo-sort@1.0.9 -cargo-sort: - cargo install cargo-sort +cargo-sort: install-cargo-sort cargo sort -c -w -fix-toml: - cargo install taplo-cli --locked +install-cargo-machete: + cargo install cargo-machete + +cargo-machete: install-cargo-machete + cargo machete + +install-taplo-cli: + cargo install taplo-cli@0.9.0 + +fix-toml: install-taplo-cli taplo fmt -check-toml: - cargo install taplo-cli --locked +check-toml: install-taplo-cli taplo check -check: check-fmt check-clippy cargo-sort check-toml +check: check-fmt check-clippy cargo-sort check-toml cargo-machete doc-test: cargo test --no-fail-fast --doc --all-features --workspace diff --git a/README.md b/README.md index 47da862f1..4f8265b79 100644 --- a/README.md +++ b/README.md @@ -17,89 +17,71 @@ ~ under the License. --> -# Apache Iceberg Rust - -Native Rust implementation of [Apache Iceberg](https://iceberg.apache.org/). - -## Roadmap - -### Catalog - -| Catalog Type | Status | -| ------------ | ----------- | -| Rest | Done | -| Hive | Done | -| Sql | In Progress | -| Glue | Done | -| DynamoDB | Not Started | - -### FileIO - -| FileIO Type | Status | -| ----------- | ----------- | -| S3 | Done | -| Local File | Done | -| GCS | Not Started | -| HDFS | Not Started | - -Our `FileIO` is powered by [Apache OpenDAL](https://github.com/apache/opendal), so it would be quite easy to -expand to other service. - -### Table API - -#### Reader - -| Feature | Status | -| ---------------------------------------------------------- | ----------- | -| File based task planning | Done | -| Size based task planning | Not started | -| Filter pushdown(manifest evaluation, partition prunning) | In Progress | -| Apply deletions, including equality and position deletions | Not started | -| Read into arrow record batch | In Progress | -| Parquet file support | Done | -| ORC file support | Not started | - -#### Writer - -| Feature | Status | -| ------------------------ | ----------- | -| Data writer | Not started | -| Equality deletion writer | Not started | -| Position deletion writer | Not started | -| Partitioned writer | Not started | -| Upsert writer | Not started | -| Parquet file support | Not started | -| ORC file support | Not started | - -#### Transaction - -| Feature | Status | -| --------------------- | ----------- | -| Schema evolution | Not started | -| Update partition spec | Not started | -| Update properties | Not started | -| Replace sort order | Not started | -| Update location | Not started | -| Append files | Not started | -| Rewrite files | Not started | -| Rewrite manifests | Not started | -| Overwrite files | Not started | -| Row level updates | Not started | -| Replace partitions | Not started | -| Snapshot management | Not started | - -### Integrations - -We will add integrations with other rust based data systems, such as polars, datafusion, etc. +# Apache Iceberg™ Rust + + + +Rust implementation of [Apache Iceberg™](https://iceberg.apache.org/). 
+ +Working on [v0.3.0 Release Milestone](https://github.com/apache/iceberg-rust/milestone/2) + +## Components + +The Apache Iceberg Rust project is composed of the following components: + +| Name | Release | Docs | +|------------------------|------------------------------------------------------------|------------------------------------------------------| +| [iceberg] | [![iceberg image]][iceberg link] | [![docs release]][iceberg release docs] | +| [iceberg-datafusion] | - | - | +| [iceberg-catalog-glue] | - | - | +| [iceberg-catalog-hms] | [![iceberg-catalog-hms image]][iceberg-catalog-hms link] | [![docs release]][iceberg-catalog-hms release docs] | +| [iceberg-catalog-rest] | [![iceberg-catalog-rest image]][iceberg-catalog-rest link] | [![docs release]][iceberg-catalog-rest release docs] | + +[docs release]: https://img.shields.io/badge/docs-release-blue +[iceberg]: crates/iceberg/README.md +[iceberg image]: https://img.shields.io/crates/v/iceberg.svg +[iceberg link]: https://crates.io/crates/iceberg +[iceberg release docs]: https://docs.rs/iceberg + +[iceberg-datafusion]: crates/integrations/datafusion/README.md + +[iceberg-catalog-glue]: crates/catalog/glue/README.md + +[iceberg-catalog-hms]: crates/catalog/hms/README.md +[iceberg-catalog-hms image]: https://img.shields.io/crates/v/iceberg-catalog-hms.svg +[iceberg-catalog-hms link]: https://crates.io/crates/iceberg-catalog-hms +[iceberg-catalog-hms release docs]: https://docs.rs/iceberg-catalog-hms + +[iceberg-catalog-rest]: crates/catalog/rest/README.md +[iceberg-catalog-rest image]: https://img.shields.io/crates/v/iceberg-catalog-rest.svg +[iceberg-catalog-rest link]: https://crates.io/crates/iceberg-catalog-rest +[iceberg-catalog-rest release docs]: https://docs.rs/iceberg-catalog-rest + +## Supported Rust Version + +Iceberg Rust is built and tested with stable rust, and will keep a rolling MSRV(minimum supported rust version). The +current MSRV is 1.77.1. + +Also, we use unstable rust to run linters, such as `clippy` and `rustfmt`. But this will not affect downstream users, +and only MSRV is required. + ## Contribute -Iceberg is an active open-source project. We are always open to people who want to use it or contribute to it. Here are some ways to go. +Apache Iceberg is an active open-source project, governed under the Apache Software Foundation (ASF). We are always open to people who want to use or contribute to it. Here are some ways to get involved. - Start with [Contributing Guide](CONTRIBUTING.md). - Submit [Issues](https://github.com/apache/iceberg-rust/issues/new) for bug report or feature requests. -- Discuss at [dev mailing list](mailto:dev@iceberg.apache.org) ([subscribe]() / [unsubscribe]() / [archives](https://lists.apache.org/list.html?dev@iceberg.apache.org)) -- Talk to community directly at [Slack #rust channel](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1zbov3k6e-KtJfoaxp97YfX6dPz1Bk7A). +- Discuss + at [dev mailing list](mailto:dev@iceberg.apache.org) ([subscribe]() / [unsubscribe]() / [archives](https://lists.apache.org/list.html?dev@iceberg.apache.org)) +- Talk to the community directly + at [Slack #rust channel](https://join.slack.com/t/apache-iceberg/shared_invite/zt-1zbov3k6e-KtJfoaxp97YfX6dPz1Bk7A). 
+ +The Apache Iceberg community is built on the principles described in the [Apache Way](https://www.apache.org/theapacheway/index.html) and all who engage with the community are expected to be respectful, open, come with the best interests of the community in mind, and abide by the Apache Foundation [Code of Conduct](https://www.apache.org/foundation/policies/conduct.html). +## Users + +- [Databend](https://github.com/datafuselabs/databend/): An open-source cloud data warehouse that serves as a cost-effective alternative to Snowflake. +- [iceberg-catalog](https://github.com/hansetag/iceberg-catalog): A Rust implementation of the Iceberg REST Catalog specification. ## License diff --git a/crates/catalog/glue/Cargo.toml b/crates/catalog/glue/Cargo.toml index 8e1c077f1..0d2e1f983 100644 --- a/crates/catalog/glue/Cargo.toml +++ b/crates/catalog/glue/Cargo.toml @@ -41,6 +41,6 @@ typed-builder = { workspace = true } uuid = { workspace = true } [dev-dependencies] +ctor = { workspace = true } iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } -opendal = { workspace = true, features = ["services-s3"] } port_scanner = { workspace = true } diff --git a/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml b/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml index a0be22e30..0a2c938a7 100644 --- a/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml +++ b/crates/catalog/glue/testdata/glue_catalog/docker-compose.yaml @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -version: '3.8' - services: minio: image: minio/minio:RELEASE.2024-03-07T00-43-48Z diff --git a/crates/catalog/glue/tests/glue_catalog_test.rs b/crates/catalog/glue/tests/glue_catalog_test.rs index eb0cd96b9..3edd8cdaf 100644 --- a/crates/catalog/glue/tests/glue_catalog_test.rs +++ b/crates/catalog/glue/tests/glue_catalog_test.rs @@ -18,7 +18,9 @@ //! Integration tests for glue catalog. 
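Note: the glue and hms integration-test hunks that follow replace the per-test `TestFixture` (which started a fresh docker-compose stack for every test) with a single environment shared by the whole test binary, brought up in a `#[ctor]` constructor and torn down in a `#[dtor]` destructor. A minimal sketch of that pattern, assuming the project's `DockerCompose` test helper and the `ctor` crate (helper path and argument types are taken from how the helper is used in the diff):

```rust
use std::sync::RwLock;

use ctor::{ctor, dtor};
// Project test helper; it starts the stack on `run()` and is expected to shut
// the containers down when dropped.
use iceberg_test_utils::docker::DockerCompose;

// One docker-compose environment for the whole test binary instead of one per test.
static DOCKER_COMPOSE_ENV: RwLock<Option<DockerCompose>> = RwLock::new(None);

#[ctor]
fn before_all() {
    // Runs once before any test in this binary: start the containers.
    let mut guard = DOCKER_COMPOSE_ENV.write().unwrap();
    let docker_compose = DockerCompose::new(
        "glue_catalog_test", // illustrative name; the tests derive it from module_path!()
        format!("{}/testdata/glue_catalog", env!("CARGO_MANIFEST_DIR")),
    );
    docker_compose.run();
    guard.replace(docker_compose);
}

#[dtor]
fn after_all() {
    // Runs once after all tests: dropping the environment tears the stack down.
    let mut guard = DOCKER_COMPOSE_ENV.write().unwrap();
    guard.take();
}
```

Because the environment is now shared, each test in the hunks below also switches to a namespace named after the test itself (e.g. `test_rename_table`) so the tests no longer collide on a common `my_database`/`default` namespace.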
use std::collections::HashMap; +use std::sync::RwLock; +use ctor::{ctor, dtor}; use iceberg::io::{S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY}; use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; use iceberg::{Catalog, Namespace, NamespaceIdent, Result, TableCreation, TableIdent}; @@ -32,26 +34,36 @@ use tokio::time::sleep; const GLUE_CATALOG_PORT: u16 = 5000; const MINIO_PORT: u16 = 9000; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); -#[derive(Debug)] -struct TestFixture { - _docker_compose: DockerCompose, - glue_catalog: GlueCatalog, -} - -async fn set_test_fixture(func: &str) -> TestFixture { - set_up(); - +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); let docker_compose = DockerCompose::new( - normalize_test_name(format!("{}_{func}", module_path!())), + normalize_test_name(module_path!()), format!("{}/testdata/glue_catalog", env!("CARGO_MANIFEST_DIR")), ); - docker_compose.run(); + guard.replace(docker_compose); +} + +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} - let glue_catalog_ip = docker_compose.get_container_ip("moto"); - let minio_ip = docker_compose.get_container_ip("minio"); +async fn get_catalog() -> GlueCatalog { + set_up(); + let (glue_catalog_ip, minio_ip) = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + ( + docker_compose.get_container_ip("moto"), + docker_compose.get_container_ip("minio"), + ) + }; let read_port = format!("{}:{}", glue_catalog_ip, GLUE_CATALOG_PORT); loop { if !scan_port_addr(&read_port) { @@ -84,21 +96,12 @@ async fn set_test_fixture(func: &str) -> TestFixture { .props(props.clone()) .build(); - let glue_catalog = GlueCatalog::new(config).await.unwrap(); - - TestFixture { - _docker_compose: docker_compose, - glue_catalog, - } + GlueCatalog::new(config).await.unwrap() } -async fn set_test_namespace(fixture: &TestFixture, namespace: &NamespaceIdent) -> Result<()> { +async fn set_test_namespace(catalog: &GlueCatalog, namespace: &NamespaceIdent) -> Result<()> { let properties = HashMap::new(); - - fixture - .glue_catalog - .create_namespace(namespace, properties) - .await?; + catalog.create_namespace(namespace, properties).await?; Ok(()) } @@ -124,33 +127,26 @@ fn set_table_creation(location: impl ToString, name: impl ToString) -> Result Result<()> { - let fixture = set_test_fixture("test_rename_table").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("my_database".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_rename_table".into())); - fixture - .glue_catalog + catalog .create_namespace(namespace.name(), HashMap::new()) .await?; - let table = fixture - .glue_catalog - .create_table(namespace.name(), creation) - .await?; + let table = catalog.create_table(namespace.name(), creation).await?; let dest = TableIdent::new(namespace.name().clone(), "my_table_rename".to_string()); - fixture - .glue_catalog - .rename_table(table.identifier(), &dest) - .await?; + catalog.rename_table(table.identifier(), &dest).await?; - let table = fixture.glue_catalog.load_table(&dest).await?; + let table = catalog.load_table(&dest).await?; assert_eq!(table.identifier(), &dest); let src = TableIdent::new(namespace.name().clone(), "my_table".to_string()); - let src_table_exists = fixture.glue_catalog.table_exists(&src).await?; + let src_table_exists = 
catalog.table_exists(&src).await?; assert!(!src_table_exists); Ok(()) @@ -158,29 +154,22 @@ async fn test_rename_table() -> Result<()> { #[tokio::test] async fn test_table_exists() -> Result<()> { - let fixture = set_test_fixture("test_table_exists").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("my_database".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_table_exists".into())); - fixture - .glue_catalog + catalog .create_namespace(namespace.name(), HashMap::new()) .await?; let ident = TableIdent::new(namespace.name().clone(), "my_table".to_string()); - let exists = fixture.glue_catalog.table_exists(&ident).await?; + let exists = catalog.table_exists(&ident).await?; assert!(!exists); - let table = fixture - .glue_catalog - .create_table(namespace.name(), creation) - .await?; + let table = catalog.create_table(namespace.name(), creation).await?; - let exists = fixture - .glue_catalog - .table_exists(table.identifier()) - .await?; + let exists = catalog.table_exists(table.identifier()).await?; assert!(exists); @@ -189,26 +178,19 @@ async fn test_table_exists() -> Result<()> { #[tokio::test] async fn test_drop_table() -> Result<()> { - let fixture = set_test_fixture("test_drop_table").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("my_database".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_drop_table".into())); - fixture - .glue_catalog + catalog .create_namespace(namespace.name(), HashMap::new()) .await?; - let table = fixture - .glue_catalog - .create_table(namespace.name(), creation) - .await?; + let table = catalog.create_table(namespace.name(), creation).await?; - fixture.glue_catalog.drop_table(table.identifier()).await?; + catalog.drop_table(table.identifier()).await?; - let result = fixture - .glue_catalog - .table_exists(table.identifier()) - .await?; + let result = catalog.table_exists(table.identifier()).await?; assert!(!result); @@ -217,22 +199,17 @@ async fn test_drop_table() -> Result<()> { #[tokio::test] async fn test_load_table() -> Result<()> { - let fixture = set_test_fixture("test_load_table").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("my_database".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_load_table".into())); - fixture - .glue_catalog + catalog .create_namespace(namespace.name(), HashMap::new()) .await?; - let expected = fixture - .glue_catalog - .create_table(namespace.name(), creation) - .await?; + let expected = catalog.create_table(namespace.name(), creation).await?; - let result = fixture - .glue_catalog + let result = catalog .load_table(&TableIdent::new( namespace.name().clone(), "my_table".to_string(), @@ -248,23 +225,19 @@ async fn test_load_table() -> Result<()> { #[tokio::test] async fn test_create_table() -> Result<()> { - let fixture = set_test_fixture("test_create_table").await; - let namespace = NamespaceIdent::new("my_database".to_string()); - set_test_namespace(&fixture, &namespace).await?; + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_create_table".to_string()); + set_test_namespace(&catalog, &namespace).await?; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - 
let result = fixture - .glue_catalog - .create_table(&namespace, creation) - .await?; + let result = catalog.create_table(&namespace, creation).await?; assert_eq!(result.identifier().name(), "my_table"); assert!(result .metadata_location() .is_some_and(|location| location.starts_with("s3a://warehouse/hive/metadata/00000-"))); assert!( - fixture - .glue_catalog + catalog .file_io() .is_exist("s3a://warehouse/hive/metadata/") .await? @@ -275,12 +248,12 @@ async fn test_create_table() -> Result<()> { #[tokio::test] async fn test_list_tables() -> Result<()> { - let fixture = set_test_fixture("test_list_tables").await; - let namespace = NamespaceIdent::new("my_database".to_string()); - set_test_namespace(&fixture, &namespace).await?; + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_list_tables".to_string()); + set_test_namespace(&catalog, &namespace).await?; let expected = vec![]; - let result = fixture.glue_catalog.list_tables(&namespace).await?; + let result = catalog.list_tables(&namespace).await?; assert_eq!(result, expected); @@ -289,16 +262,16 @@ async fn test_list_tables() -> Result<()> { #[tokio::test] async fn test_drop_namespace() -> Result<()> { - let fixture = set_test_fixture("test_drop_namespace").await; - let namespace = NamespaceIdent::new("my_database".to_string()); - set_test_namespace(&fixture, &namespace).await?; + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_drop_namespace".to_string()); + set_test_namespace(&catalog, &namespace).await?; - let exists = fixture.glue_catalog.namespace_exists(&namespace).await?; + let exists = catalog.namespace_exists(&namespace).await?; assert!(exists); - fixture.glue_catalog.drop_namespace(&namespace).await?; + catalog.drop_namespace(&namespace).await?; - let exists = fixture.glue_catalog.namespace_exists(&namespace).await?; + let exists = catalog.namespace_exists(&namespace).await?; assert!(!exists); Ok(()) @@ -306,23 +279,20 @@ async fn test_drop_namespace() -> Result<()> { #[tokio::test] async fn test_update_namespace() -> Result<()> { - let fixture = set_test_fixture("test_update_namespace").await; - let namespace = NamespaceIdent::new("my_database".into()); - set_test_namespace(&fixture, &namespace).await?; + let catalog = get_catalog().await; + let namespace = NamespaceIdent::new("test_update_namespace".into()); + set_test_namespace(&catalog, &namespace).await?; - let before_update = fixture.glue_catalog.get_namespace(&namespace).await?; + let before_update = catalog.get_namespace(&namespace).await?; let before_update = before_update.properties().get("description"); assert_eq!(before_update, None); let properties = HashMap::from([("description".to_string(), "my_update".to_string())]); - fixture - .glue_catalog - .update_namespace(&namespace, properties) - .await?; + catalog.update_namespace(&namespace, properties).await?; - let after_update = fixture.glue_catalog.get_namespace(&namespace).await?; + let after_update = catalog.get_namespace(&namespace).await?; let after_update = after_update.properties().get("description"); assert_eq!(after_update, Some("my_update".to_string()).as_ref()); @@ -332,16 +302,16 @@ async fn test_update_namespace() -> Result<()> { #[tokio::test] async fn test_namespace_exists() -> Result<()> { - let fixture = set_test_fixture("test_namespace_exists").await; + let catalog = get_catalog().await; - let namespace = NamespaceIdent::new("my_database".into()); + let namespace = NamespaceIdent::new("test_namespace_exists".into()); - let exists = 
fixture.glue_catalog.namespace_exists(&namespace).await?; + let exists = catalog.namespace_exists(&namespace).await?; assert!(!exists); - set_test_namespace(&fixture, &namespace).await?; + set_test_namespace(&catalog, &namespace).await?; - let exists = fixture.glue_catalog.namespace_exists(&namespace).await?; + let exists = catalog.namespace_exists(&namespace).await?; assert!(exists); Ok(()) @@ -349,16 +319,16 @@ async fn test_namespace_exists() -> Result<()> { #[tokio::test] async fn test_get_namespace() -> Result<()> { - let fixture = set_test_fixture("test_get_namespace").await; + let catalog = get_catalog().await; - let namespace = NamespaceIdent::new("my_database".into()); + let namespace = NamespaceIdent::new("test_get_namespace".into()); - let does_not_exist = fixture.glue_catalog.get_namespace(&namespace).await; + let does_not_exist = catalog.get_namespace(&namespace).await; assert!(does_not_exist.is_err()); - set_test_namespace(&fixture, &namespace).await?; + set_test_namespace(&catalog, &namespace).await?; - let result = fixture.glue_catalog.get_namespace(&namespace).await?; + let result = catalog.get_namespace(&namespace).await?; let expected = Namespace::new(namespace); assert_eq!(result, expected); @@ -368,17 +338,14 @@ async fn test_get_namespace() -> Result<()> { #[tokio::test] async fn test_create_namespace() -> Result<()> { - let fixture = set_test_fixture("test_create_namespace").await; + let catalog = get_catalog().await; let properties = HashMap::new(); - let namespace = NamespaceIdent::new("my_database".into()); + let namespace = NamespaceIdent::new("test_create_namespace".into()); let expected = Namespace::new(namespace.clone()); - let result = fixture - .glue_catalog - .create_namespace(&namespace, properties) - .await?; + let result = catalog.create_namespace(&namespace, properties).await?; assert_eq!(result, expected); @@ -387,18 +354,16 @@ async fn test_create_namespace() -> Result<()> { #[tokio::test] async fn test_list_namespace() -> Result<()> { - let fixture = set_test_fixture("test_list_namespace").await; + let catalog = get_catalog().await; - let expected = vec![]; - let result = fixture.glue_catalog.list_namespaces(None).await?; - assert_eq!(result, expected); + let namespace = NamespaceIdent::new("test_list_namespace".to_string()); + set_test_namespace(&catalog, &namespace).await?; - let namespace = NamespaceIdent::new("my_database".to_string()); - set_test_namespace(&fixture, &namespace).await?; + let result = catalog.list_namespaces(None).await?; + assert!(result.contains(&namespace)); - let expected = vec![namespace]; - let result = fixture.glue_catalog.list_namespaces(None).await?; - assert_eq!(result, expected); + let empty_result = catalog.list_namespaces(Some(&namespace)).await?; + assert!(empty_result.is_empty()); Ok(()) } diff --git a/crates/catalog/hms/Cargo.toml b/crates/catalog/hms/Cargo.toml index b53901552..e7d4ec2f3 100644 --- a/crates/catalog/hms/Cargo.toml +++ b/crates/catalog/hms/Cargo.toml @@ -43,6 +43,6 @@ uuid = { workspace = true } volo-thrift = { workspace = true } [dev-dependencies] +ctor = { workspace = true } iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } -opendal = { workspace = true, features = ["services-s3"] } port_scanner = { workspace = true } diff --git a/crates/catalog/hms/src/catalog.rs b/crates/catalog/hms/src/catalog.rs index 18fcacdfc..7a292c51b 100644 --- a/crates/catalog/hms/src/catalog.rs +++ b/crates/catalog/hms/src/catalog.rs @@ -15,10 +15,11 @@ // specific language governing permissions 
and limitations // under the License. -use crate::error::from_io_error; use crate::error::from_thrift_error; +use crate::error::{from_io_error, from_thrift_exception}; use super::utils::*; +use anyhow::anyhow; use async_trait::async_trait; use hive_metastore::ThriftHiveMetastoreClient; use hive_metastore::ThriftHiveMetastoreClientBuilder; @@ -36,7 +37,7 @@ use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::net::ToSocketAddrs; use typed_builder::TypedBuilder; -use volo_thrift::ResponseError; +use volo_thrift::MaybeException; /// Which variant of the thrift transport to communicate with HMS /// See: @@ -137,7 +138,8 @@ impl Catalog for HmsCatalog { .0 .get_all_databases() .await - .map_err(from_thrift_error)? + .map(from_thrift_exception) + .map_err(from_thrift_error)?? }; Ok(dbs @@ -195,7 +197,8 @@ impl Catalog for HmsCatalog { .0 .get_database(name.into()) .await - .map_err(from_thrift_error)?; + .map(from_thrift_exception) + .map_err(from_thrift_error)??; let ns = convert_to_namespace(&db)?; @@ -220,17 +223,16 @@ impl Catalog for HmsCatalog { let resp = self.client.0.get_database(name.into()).await; match resp { - Ok(_) => Ok(true), - Err(err) => { - if let ResponseError::UserException(ThriftHiveMetastoreGetDatabaseException::O1( - _, - )) = &err - { - Ok(false) - } else { - Err(from_thrift_error(err)) - } + Ok(MaybeException::Ok(_)) => Ok(true), + Ok(MaybeException::Exception(ThriftHiveMetastoreGetDatabaseException::O1(_))) => { + Ok(false) } + Ok(MaybeException::Exception(exception)) => Err(Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", exception))), + Err(err) => Err(from_thrift_error(err)), } } @@ -306,7 +308,8 @@ impl Catalog for HmsCatalog { .0 .get_all_tables(name.into()) .await - .map_err(from_thrift_error)?; + .map(from_thrift_exception) + .map_err(from_thrift_error)??; let tables = tables .iter() @@ -397,7 +400,8 @@ impl Catalog for HmsCatalog { .0 .get_table(db_name.clone().into(), table.name.clone().into()) .await - .map_err(from_thrift_error)?; + .map(from_thrift_exception) + .map_err(from_thrift_error)??; let metadata_location = get_metadata_location(&hive_table.parameters)?; @@ -457,16 +461,14 @@ impl Catalog for HmsCatalog { .await; match resp { - Ok(_) => Ok(true), - Err(err) => { - if let ResponseError::UserException(ThriftHiveMetastoreGetTableException::O2(_)) = - &err - { - Ok(false) - } else { - Err(from_thrift_error(err)) - } - } + Ok(MaybeException::Ok(_)) => Ok(true), + Ok(MaybeException::Exception(ThriftHiveMetastoreGetTableException::O2(_))) => Ok(false), + Ok(MaybeException::Exception(exception)) => Err(Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", exception))), + Err(err) => Err(from_thrift_error(err)), } } @@ -488,7 +490,8 @@ impl Catalog for HmsCatalog { .0 .get_table(src_dbname.clone().into(), src_tbl_name.clone().into()) .await - .map_err(from_thrift_error)?; + .map(from_thrift_exception) + .map_err(from_thrift_error)??; tbl.db_name = Some(dest_dbname.into()); tbl.table_name = Some(dest_tbl_name.into()); diff --git a/crates/catalog/hms/src/error.rs b/crates/catalog/hms/src/error.rs index a0f393c62..cee5e462f 100644 --- a/crates/catalog/hms/src/error.rs +++ b/crates/catalog/hms/src/error.rs @@ -19,12 +19,12 @@ use anyhow::anyhow; use iceberg::{Error, ErrorKind}; use std::fmt::Debug; use std::io; +use volo_thrift::MaybeException; /// Format a 
thrift error into iceberg error. -pub fn from_thrift_error(error: volo_thrift::error::ResponseError) -> Error -where - T: Debug, -{ +/// +/// Please only throw this error when you are sure that the error is caused by thrift. +pub fn from_thrift_error(error: impl std::error::Error) -> Error { Error::new( ErrorKind::Unexpected, "Operation failed for hitting thrift error".to_string(), @@ -32,6 +32,18 @@ where .with_source(anyhow!("thrift error: {:?}", error)) } +/// Format a thrift exception into iceberg error. +pub fn from_thrift_exception(value: MaybeException) -> Result { + match value { + MaybeException::Ok(v) => Ok(v), + MaybeException::Exception(err) => Err(Error::new( + ErrorKind::Unexpected, + "Operation failed for hitting thrift error".to_string(), + ) + .with_source(anyhow!("thrift error: {:?}", err))), + } +} + /// Format an io error into iceberg error. pub fn from_io_error(error: io::Error) -> Error { Error::new( diff --git a/crates/catalog/hms/src/utils.rs b/crates/catalog/hms/src/utils.rs index 04ee5d4b3..baaa004ed 100644 --- a/crates/catalog/hms/src/utils.rs +++ b/crates/catalog/hms/src/utils.rs @@ -74,11 +74,15 @@ pub(crate) fn convert_to_namespace(database: &Database) -> Result { properties.insert(HMS_DB_OWNER.to_string(), owner.to_string()); }; - if let Some(owner_type) = &database.owner_type { - let value = match owner_type { - PrincipalType::User => "User", - PrincipalType::Group => "Group", - PrincipalType::Role => "Role", + if let Some(owner_type) = database.owner_type { + let value = if owner_type == PrincipalType::USER { + "User" + } else if owner_type == PrincipalType::GROUP { + "Group" + } else if owner_type == PrincipalType::ROLE { + "Role" + } else { + unreachable!("Invalid owner type") }; properties.insert(HMS_DB_OWNER_TYPE.to_string(), value.to_string()); @@ -117,9 +121,9 @@ pub(crate) fn convert_to_database( HMS_DB_OWNER => db.owner_name = Some(v.clone().into()), HMS_DB_OWNER_TYPE => { let owner_type = match v.to_lowercase().as_str() { - "user" => PrincipalType::User, - "group" => PrincipalType::Group, - "role" => PrincipalType::Role, + "user" => PrincipalType::USER, + "group" => PrincipalType::GROUP, + "role" => PrincipalType::ROLE, _ => { return Err(Error::new( ErrorKind::DataInvalid, @@ -144,7 +148,7 @@ pub(crate) fn convert_to_database( // https://github.com/apache/iceberg/blob/main/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveHadoopUtil.java#L44 if db.owner_name.is_none() { db.owner_name = Some(HMS_DEFAULT_DB_OWNER.into()); - db.owner_type = Some(PrincipalType::User); + db.owner_type = Some(PrincipalType::USER); } Ok(db) @@ -504,7 +508,7 @@ mod tests { assert_eq!(db.name, Some(FastStr::from("my_namespace"))); assert_eq!(db.description, Some(FastStr::from("my_description"))); assert_eq!(db.owner_name, Some(FastStr::from("apache"))); - assert_eq!(db.owner_type, Some(PrincipalType::User)); + assert_eq!(db.owner_type, Some(PrincipalType::USER)); if let Some(params) = db.parameters { assert_eq!(params.get("key1"), Some(&FastStr::from("value1"))); @@ -522,7 +526,7 @@ mod tests { assert_eq!(db.name, Some(FastStr::from("my_namespace"))); assert_eq!(db.owner_name, Some(FastStr::from(HMS_DEFAULT_DB_OWNER))); - assert_eq!(db.owner_type, Some(PrincipalType::User)); + assert_eq!(db.owner_type, Some(PrincipalType::USER)); Ok(()) } diff --git a/crates/catalog/hms/testdata/hms_catalog/Dockerfile b/crates/catalog/hms/testdata/hms_catalog/Dockerfile index ff8c9fae6..8392e174a 100644 --- a/crates/catalog/hms/testdata/hms_catalog/Dockerfile +++ 
b/crates/catalog/hms/testdata/hms_catalog/Dockerfile @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -FROM openjdk:8-jre-slim AS build +FROM --platform=$BUILDPLATFORM openjdk:8-jre-slim AS build + +ARG BUILDPLATFORM RUN apt-get update -qq && apt-get -qq -y install curl diff --git a/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml b/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml index c9605868b..181fac149 100644 --- a/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml +++ b/crates/catalog/hms/testdata/hms_catalog/docker-compose.yaml @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -version: '3.8' - services: minio: image: minio/minio:RELEASE.2024-03-07T00-43-48Z @@ -43,6 +41,7 @@ services: hive-metastore: image: iceberg-hive-metastore build: ./ + platform: ${DOCKER_DEFAULT_PLATFORM} expose: - 9083 environment: diff --git a/crates/catalog/hms/tests/hms_catalog_test.rs b/crates/catalog/hms/tests/hms_catalog_test.rs index 3dd2c7d09..a109757fe 100644 --- a/crates/catalog/hms/tests/hms_catalog_test.rs +++ b/crates/catalog/hms/tests/hms_catalog_test.rs @@ -18,7 +18,9 @@ //! Integration tests for hms catalog. use std::collections::HashMap; +use std::sync::RwLock; +use ctor::{ctor, dtor}; use iceberg::io::{S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY}; use iceberg::spec::{NestedField, PrimitiveType, Schema, Type}; use iceberg::{Catalog, Namespace, NamespaceIdent, TableCreation, TableIdent}; @@ -30,29 +32,42 @@ use tokio::time::sleep; const HMS_CATALOG_PORT: u16 = 9083; const MINIO_PORT: u16 = 9000; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); type Result = std::result::Result; -struct TestFixture { - _docker_compose: DockerCompose, - hms_catalog: HmsCatalog, -} - -async fn set_test_fixture(func: &str) -> TestFixture { - set_up(); - +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); let docker_compose = DockerCompose::new( - normalize_test_name(format!("{}_{func}", module_path!())), + normalize_test_name(module_path!()), format!("{}/testdata/hms_catalog", env!("CARGO_MANIFEST_DIR")), ); - docker_compose.run(); + guard.replace(docker_compose); +} + +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} - let hms_catalog_ip = docker_compose.get_container_ip("hive-metastore"); - let minio_ip = docker_compose.get_container_ip("minio"); +async fn get_catalog() -> HmsCatalog { + set_up(); + + let (hms_catalog_ip, minio_ip) = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + ( + docker_compose.get_container_ip("hive-metastore"), + docker_compose.get_container_ip("minio"), + ) + }; let read_port = format!("{}:{}", hms_catalog_ip, HMS_CATALOG_PORT); loop { if !scan_port_addr(&read_port) { + log::info!("scan read_port {} check", read_port); log::info!("Waiting for 1s hms catalog to ready..."); sleep(std::time::Duration::from_millis(1000)).await; } else { @@ -77,12 +92,15 @@ async fn set_test_fixture(func: &str) -> TestFixture { .props(props) .build(); - let hms_catalog = HmsCatalog::new(config).unwrap(); + HmsCatalog::new(config).unwrap() +} + +async fn set_test_namespace(catalog: &HmsCatalog, namespace: &NamespaceIdent) -> Result<()> { + let properties = HashMap::new(); - TestFixture { - _docker_compose: docker_compose, - hms_catalog, - } + catalog.create_namespace(namespace, properties).await?; + + 
Ok(()) } fn set_table_creation(location: impl ToString, name: impl ToString) -> Result { @@ -106,23 +124,18 @@ fn set_table_creation(location: impl ToString, name: impl ToString) -> Result Result<()> { - let fixture = set_test_fixture("test_rename_table").await; - let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let catalog = get_catalog().await; + let creation: TableCreation = set_table_creation("s3a://warehouse/hive", "my_table")?; + let namespace = Namespace::new(NamespaceIdent::new("test_rename_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; - let table = fixture - .hms_catalog - .create_table(namespace.name(), creation) - .await?; + let table: iceberg::table::Table = catalog.create_table(namespace.name(), creation).await?; let dest = TableIdent::new(namespace.name().clone(), "my_table_rename".to_string()); - fixture - .hms_catalog - .rename_table(table.identifier(), &dest) - .await?; + catalog.rename_table(table.identifier(), &dest).await?; - let result = fixture.hms_catalog.table_exists(&dest).await?; + let result = catalog.table_exists(&dest).await?; assert!(result); @@ -131,16 +144,14 @@ async fn test_rename_table() -> Result<()> { #[tokio::test] async fn test_table_exists() -> Result<()> { - let fixture = set_test_fixture("test_table_exists").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_table_exists".into())); + set_test_namespace(&catalog, namespace.name()).await?; - let table = fixture - .hms_catalog - .create_table(namespace.name(), creation) - .await?; + let table = catalog.create_table(namespace.name(), creation).await?; - let result = fixture.hms_catalog.table_exists(table.identifier()).await?; + let result = catalog.table_exists(table.identifier()).await?; assert!(result); @@ -149,18 +160,16 @@ async fn test_table_exists() -> Result<()> { #[tokio::test] async fn test_drop_table() -> Result<()> { - let fixture = set_test_fixture("test_drop_table").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_drop_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; - let table = fixture - .hms_catalog - .create_table(namespace.name(), creation) - .await?; + let table = catalog.create_table(namespace.name(), creation).await?; - fixture.hms_catalog.drop_table(table.identifier()).await?; + catalog.drop_table(table.identifier()).await?; - let result = fixture.hms_catalog.table_exists(table.identifier()).await?; + let result = catalog.table_exists(table.identifier()).await?; assert!(!result); @@ -169,17 +178,14 @@ async fn test_drop_table() -> Result<()> { #[tokio::test] async fn test_load_table() -> Result<()> { - let fixture = set_test_fixture("test_load_table").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_load_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; - let expected = fixture - .hms_catalog - 
.create_table(namespace.name(), creation) - .await?; + let expected = catalog.create_table(namespace.name(), creation).await?; - let result = fixture - .hms_catalog + let result = catalog .load_table(&TableIdent::new( namespace.name().clone(), "my_table".to_string(), @@ -195,22 +201,19 @@ async fn test_load_table() -> Result<()> { #[tokio::test] async fn test_create_table() -> Result<()> { - let fixture = set_test_fixture("test_create_table").await; + let catalog = get_catalog().await; let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - let namespace = Namespace::new(NamespaceIdent::new("default".into())); + let namespace = Namespace::new(NamespaceIdent::new("test_create_table".into())); + set_test_namespace(&catalog, namespace.name()).await?; - let result = fixture - .hms_catalog - .create_table(namespace.name(), creation) - .await?; + let result = catalog.create_table(namespace.name(), creation).await?; assert_eq!(result.identifier().name(), "my_table"); assert!(result .metadata_location() .is_some_and(|location| location.starts_with("s3a://warehouse/hive/metadata/00000-"))); assert!( - fixture - .hms_catalog + catalog .file_io() .is_exist("s3a://warehouse/hive/metadata/") .await? @@ -221,18 +224,16 @@ async fn test_create_table() -> Result<()> { #[tokio::test] async fn test_list_tables() -> Result<()> { - let fixture = set_test_fixture("test_list_tables").await; - let ns = Namespace::new(NamespaceIdent::new("default".into())); - let result = fixture.hms_catalog.list_tables(ns.name()).await?; + let catalog = get_catalog().await; + let ns = Namespace::new(NamespaceIdent::new("test_list_tables".into())); + let result = catalog.list_tables(ns.name()).await?; + set_test_namespace(&catalog, ns.name()).await?; assert_eq!(result, vec![]); let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - fixture - .hms_catalog - .create_table(ns.name(), creation) - .await?; - let result = fixture.hms_catalog.list_tables(ns.name()).await?; + catalog.create_table(ns.name(), creation).await?; + let result = catalog.list_tables(ns.name()).await?; assert_eq!( result, @@ -244,17 +245,15 @@ async fn test_list_tables() -> Result<()> { #[tokio::test] async fn test_list_namespace() -> Result<()> { - let fixture = set_test_fixture("test_list_namespace").await; + let catalog = get_catalog().await; - let expected_no_parent = vec![NamespaceIdent::new("default".into())]; - let result_no_parent = fixture.hms_catalog.list_namespaces(None).await?; + let result_no_parent = catalog.list_namespaces(None).await?; - let result_with_parent = fixture - .hms_catalog + let result_with_parent = catalog .list_namespaces(Some(&NamespaceIdent::new("parent".into()))) .await?; - assert_eq!(expected_no_parent, result_no_parent); + assert!(result_no_parent.contains(&NamespaceIdent::new("default".into()))); assert!(result_with_parent.is_empty()); Ok(()) @@ -262,7 +261,7 @@ async fn test_list_namespace() -> Result<()> { #[tokio::test] async fn test_create_namespace() -> Result<()> { - let fixture = set_test_fixture("test_create_namespace").await; + let catalog = get_catalog().await; let properties = HashMap::from([ ("comment".to_string(), "my_description".to_string()), @@ -279,14 +278,11 @@ async fn test_create_namespace() -> Result<()> { ]); let ns = Namespace::with_properties( - NamespaceIdent::new("my_namespace".into()), + NamespaceIdent::new("test_create_namespace".into()), properties.clone(), ); - let result = fixture - .hms_catalog - .create_namespace(ns.name(), properties) - .await?; + let 
result = catalog.create_namespace(ns.name(), properties).await?; assert_eq!(result, ns); @@ -294,8 +290,8 @@ async fn test_create_namespace() -> Result<()> { } #[tokio::test] -async fn test_get_namespace() -> Result<()> { - let fixture = set_test_fixture("test_get_namespace").await; +async fn test_get_default_namespace() -> Result<()> { + let catalog = get_catalog().await; let ns = Namespace::new(NamespaceIdent::new("default".into())); let properties = HashMap::from([ @@ -313,7 +309,7 @@ async fn test_get_namespace() -> Result<()> { let expected = Namespace::with_properties(NamespaceIdent::new("default".into()), properties); - let result = fixture.hms_catalog.get_namespace(ns.name()).await?; + let result = catalog.get_namespace(ns.name()).await?; assert_eq!(expected, result); @@ -322,19 +318,13 @@ async fn test_get_namespace() -> Result<()> { #[tokio::test] async fn test_namespace_exists() -> Result<()> { - let fixture = set_test_fixture("test_namespace_exists").await; + let catalog = get_catalog().await; let ns_exists = Namespace::new(NamespaceIdent::new("default".into())); - let ns_not_exists = Namespace::new(NamespaceIdent::new("not_here".into())); + let ns_not_exists = Namespace::new(NamespaceIdent::new("test_namespace_exists".into())); - let result_exists = fixture - .hms_catalog - .namespace_exists(ns_exists.name()) - .await?; - let result_not_exists = fixture - .hms_catalog - .namespace_exists(ns_not_exists.name()) - .await?; + let result_exists = catalog.namespace_exists(ns_exists.name()).await?; + let result_not_exists = catalog.namespace_exists(ns_not_exists.name()).await?; assert!(result_exists); assert!(!result_not_exists); @@ -344,17 +334,15 @@ async fn test_namespace_exists() -> Result<()> { #[tokio::test] async fn test_update_namespace() -> Result<()> { - let fixture = set_test_fixture("test_update_namespace").await; + let catalog = get_catalog().await; - let ns = Namespace::new(NamespaceIdent::new("default".into())); + let ns = NamespaceIdent::new("test_update_namespace".into()); + set_test_namespace(&catalog, &ns).await?; let properties = HashMap::from([("comment".to_string(), "my_update".to_string())]); - fixture - .hms_catalog - .update_namespace(ns.name(), properties) - .await?; + catalog.update_namespace(&ns, properties).await?; - let db = fixture.hms_catalog.get_namespace(ns.name()).await?; + let db = catalog.get_namespace(&ns).await?; assert_eq!( db.properties().get("comment"), @@ -366,21 +354,18 @@ async fn test_update_namespace() -> Result<()> { #[tokio::test] async fn test_drop_namespace() -> Result<()> { - let fixture = set_test_fixture("test_drop_namespace").await; + let catalog = get_catalog().await; let ns = Namespace::new(NamespaceIdent::new("delete_me".into())); - fixture - .hms_catalog - .create_namespace(ns.name(), HashMap::new()) - .await?; + catalog.create_namespace(ns.name(), HashMap::new()).await?; - let result = fixture.hms_catalog.namespace_exists(ns.name()).await?; + let result = catalog.namespace_exists(ns.name()).await?; assert!(result); - fixture.hms_catalog.drop_namespace(ns.name()).await?; + catalog.drop_namespace(ns.name()).await?; - let result = fixture.hms_catalog.namespace_exists(ns.name()).await?; + let result = catalog.namespace_exists(ns.name()).await?; assert!(!result); Ok(()) diff --git a/crates/catalog/rest/Cargo.toml b/crates/catalog/rest/Cargo.toml index 43e589910..add57183b 100644 --- a/crates/catalog/rest/Cargo.toml +++ b/crates/catalog/rest/Cargo.toml @@ -32,6 +32,7 @@ keywords = ["iceberg", "rest", "catalog"] # async-trait = { 
workspace = true } async-trait = { workspace = true } chrono = { workspace = true } +http = "1.1.0" iceberg = { workspace = true } itertools = { workspace = true } log = "0.4.20" @@ -39,13 +40,13 @@ reqwest = { workspace = true } serde = { workspace = true } serde_derive = { workspace = true } serde_json = { workspace = true } +tokio = { workspace = true, features = ["sync"] } typed-builder = { workspace = true } -urlencoding = { workspace = true } uuid = { workspace = true, features = ["v4"] } [dev-dependencies] +ctor = { workspace = true } iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } mockito = { workspace = true } -opendal = { workspace = true, features = ["services-fs"] } port_scanner = { workspace = true } tokio = { workspace = true } diff --git a/crates/catalog/rest/src/catalog.rs b/crates/catalog/rest/src/catalog.rs index efea9cbf3..83cc0ec84 100644 --- a/crates/catalog/rest/src/catalog.rs +++ b/crates/catalog/rest/src/catalog.rs @@ -23,13 +23,15 @@ use std::str::FromStr; use async_trait::async_trait; use itertools::Itertools; use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue}; -use reqwest::{Client, Request, Response, StatusCode, Url}; -use serde::de::DeserializeOwned; +use reqwest::{Method, StatusCode, Url}; +use tokio::sync::OnceCell; use typed_builder::TypedBuilder; -use urlencoding::encode; -use crate::catalog::_serde::{ - CommitTableRequest, CommitTableResponse, CreateTableRequest, LoadTableResponse, +use crate::client::HttpClient; +use crate::types::{ + CatalogConfig, CommitTableRequest, CommitTableResponse, CreateTableRequest, ErrorResponse, + ListNamespaceResponse, ListTableResponse, LoadTableResponse, NamespaceSerde, + RenameTableRequest, NO_CONTENT, OK, }; use iceberg::io::FileIO; use iceberg::table::Table; @@ -38,17 +40,12 @@ use iceberg::{ Catalog, Error, ErrorKind, Namespace, NamespaceIdent, TableCommit, TableCreation, TableIdent, }; -use self::_serde::{ - CatalogConfig, ErrorResponse, ListNamespaceResponse, ListTableResponse, NamespaceSerde, - RenameTableRequest, TokenResponse, NO_CONTENT, OK, -}; - const ICEBERG_REST_SPEC_VERSION: &str = "0.14.1"; const CARGO_PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); const PATH_V1: &str = "v1"; /// Rest catalog configuration. -#[derive(Debug, TypedBuilder)] +#[derive(Clone, Debug, TypedBuilder)] pub struct RestCatalogConfig { uri: String, #[builder(default, setter(strip_option))] @@ -71,7 +68,7 @@ impl RestCatalogConfig { [&self.uri, PATH_V1, "config"].join("/") } - fn get_token_endpoint(&self) -> String { + pub(crate) fn get_token_endpoint(&self) -> String { if let Some(auth_url) = self.props.get("rest.authorization-url") { auth_url.to_string() } else { @@ -84,11 +81,11 @@ impl RestCatalogConfig { } fn namespace_endpoint(&self, ns: &NamespaceIdent) -> String { - self.url_prefixed(&["namespaces", &ns.encode_in_url()]) + self.url_prefixed(&["namespaces", &ns.to_url_string()]) } fn tables_endpoint(&self, ns: &NamespaceIdent) -> String { - self.url_prefixed(&["namespaces", &ns.encode_in_url(), "tables"]) + self.url_prefixed(&["namespaces", &ns.to_url_string(), "tables"]) } fn rename_table_endpoint(&self) -> String { @@ -98,13 +95,47 @@ impl RestCatalogConfig { fn table_endpoint(&self, table: &TableIdent) -> String { self.url_prefixed(&[ "namespaces", - &table.namespace.encode_in_url(), + &table.namespace.to_url_string(), "tables", - encode(&table.name).as_ref(), + &table.name, ]) } - fn http_headers(&self) -> Result { + /// Get the token from the config. 
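Note: the `credential()` accessor added a few hunks below splits the `credential` property into an optional client id and a client secret. A stand-alone sketch of that parsing rule (the free function and test names here are illustrative, not crate items):

```rust
/// Illustrative version of the rule used by `RestCatalogConfig::credential()`:
/// "client_id:client_secret" yields both parts; a bare value is the secret only.
fn parse_credential(cred: &str) -> (Option<String>, String) {
    match cred.split_once(':') {
        Some((client_id, client_secret)) => {
            (Some(client_id.to_string()), client_secret.to_string())
        }
        None => (None, cred.to_string()),
    }
}

#[cfg(test)]
mod credential_tests {
    use super::parse_credential;

    #[test]
    fn splits_id_and_secret() {
        assert_eq!(
            parse_credential("my-id:my-secret"),
            (Some("my-id".to_string()), "my-secret".to_string())
        );
        assert_eq!(parse_credential("my-secret"), (None, "my-secret".to_string()));
    }
}
```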
+ /// + /// Client will use `token` to send requests if exists. + pub(crate) fn token(&self) -> Option { + self.props.get("token").cloned() + } + + /// Get the credentials from the config. Client will use `credential` + /// to fetch a new token if exists. + /// + /// ## Output + /// + /// - `None`: No credential is set. + /// - `Some(None, client_secret)`: No client_id is set, use client_secret directly. + /// - `Some(Some(client_id), client_secret)`: Both client_id and client_secret are set. + pub(crate) fn credential(&self) -> Option<(Option, String)> { + let cred = self.props.get("credential")?; + + match cred.split_once(':') { + Some((client_id, client_secret)) => { + Some((Some(client_id.to_string()), client_secret.to_string())) + } + None => Some((None, cred.to_string())), + } + } + + /// Get the extra headers from config. + /// + /// We will include: + /// + /// - `content-type` + /// - `x-client-version` + /// - `user-agnet` + /// - all headers specified by `header.xxx` in props. + pub(crate) fn extra_headers(&self) -> Result { let mut headers = HeaderMap::from_iter([ ( header::CONTENT_TYPE, @@ -120,167 +151,160 @@ impl RestCatalogConfig { ), ]); - if let Some(token) = self.props.get("token") { + for (key, value) in self + .props + .iter() + .filter(|(k, _)| k.starts_with("header.")) + // The unwrap here is same since we are filtering the keys + .map(|(k, v)| (k.strip_prefix("header.").unwrap(), v)) + { headers.insert( - header::AUTHORIZATION, - HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| { + HeaderName::from_str(key).map_err(|e| { Error::new( ErrorKind::DataInvalid, - "Invalid token received from catalog server!", + format!("Invalid header name: {key}"), + ) + .with_source(e) + })?, + HeaderValue::from_str(value).map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid header value: {value}"), ) .with_source(e) })?, ); } - for (key, value) in self.props.iter() { - if let Some(stripped_key) = key.strip_prefix("header.") { - // Avoid overwriting default headers - if !headers.contains_key(stripped_key) { - headers.insert( - HeaderName::from_str(stripped_key).map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - format!("Invalid header name: {stripped_key}!"), - ) - .with_source(e) - })?, - HeaderValue::from_str(value).map_err(|e| { - Error::new( - ErrorKind::DataInvalid, - format!("Invalid header value: {value}!"), - ) - .with_source(e) - })?, - ); - } - } - } Ok(headers) } - fn try_create_rest_client(&self) -> Result { - // TODO: We will add ssl config, sigv4 later - let headers = self.http_headers()?; - - Ok(HttpClient( - Client::builder().default_headers(headers).build()?, - )) - } + /// Get the optional oauth headers from the config. 
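Note: `merge_with_config()` in the next hunk combines the user-supplied properties with the `CatalogConfig` fetched from the server. A small sketch of the intended precedence, using plain maps (the function name is illustrative):

```rust
use std::collections::HashMap;

/// Later `extend` calls win, so the effective order is:
/// server defaults < user-supplied props < server overrides.
fn merge_props(
    server_defaults: HashMap<String, String>,
    user_props: HashMap<String, String>,
    server_overrides: HashMap<String, String>,
) -> HashMap<String, String> {
    let mut props = server_defaults;
    props.extend(user_props);
    props.extend(server_overrides);
    props
}
```

The real method additionally treats a `uri` key in the server overrides specially, replacing the user-configured endpoint before the maps are merged.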
+ pub(crate) fn extra_oauth_params(&self) -> HashMap { + let mut params = HashMap::new(); - fn optional_oauth_params(&self) -> HashMap<&str, &str> { - let mut optional_oauth_param = HashMap::new(); if let Some(scope) = self.props.get("scope") { - optional_oauth_param.insert("scope", scope.as_str()); + params.insert("scope".to_string(), scope.to_string()); } else { - optional_oauth_param.insert("scope", "catalog"); + params.insert("scope".to_string(), "catalog".to_string()); } - let set_of_optional_params = ["audience", "resource"]; - for param_name in set_of_optional_params.iter() { - if let Some(value) = self.props.get(*param_name) { - optional_oauth_param.insert(param_name.to_owned(), value); + + let optional_params = ["audience", "resource"]; + for param_name in optional_params { + if let Some(value) = self.props.get(param_name) { + params.insert(param_name.to_string(), value.to_string()); } } - optional_oauth_param + params + } + + /// Merge the config with the given config fetched from rest server. + pub(crate) fn merge_with_config(mut self, mut config: CatalogConfig) -> Self { + if let Some(uri) = config.overrides.remove("uri") { + self.uri = uri; + } + + let mut props = config.defaults; + props.extend(self.props); + props.extend(config.overrides); + + self.props = props; + self } } #[derive(Debug)] -struct HttpClient(Client); - -impl HttpClient { - async fn query< - R: DeserializeOwned, - E: DeserializeOwned + Into, - const SUCCESS_CODE: u16, - >( - &self, - request: Request, - ) -> Result { - let resp = self.0.execute(request).await?; +struct RestContext { + client: HttpClient, - if resp.status().as_u16() == SUCCESS_CODE { - let text = resp.bytes().await?; - Ok(serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?) - } else { - let code = resp.status(); - let text = resp.bytes().await?; - let e = serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("json", String::from_utf8_lossy(&text)) - .with_context("code", code.to_string()) - .with_source(e) - })?; - Err(e.into()) + /// Runtime config is fetched from rest server and stored here. + /// + /// It's could be different from the user config. + config: RestCatalogConfig, +} + +impl RestContext {} + +/// Rest catalog implementation. +#[derive(Debug)] +pub struct RestCatalog { + /// User config is stored as-is and never be changed. + /// + /// It's could be different from the config fetched from the server and used at runtime. + user_config: RestCatalogConfig, + ctx: OnceCell, +} + +impl RestCatalog { + /// Creates a rest catalog from config. + pub fn new(config: RestCatalogConfig) -> Self { + Self { + user_config: config, + ctx: OnceCell::new(), } } - async fn execute, const SUCCESS_CODE: u16>( - &self, - request: Request, - ) -> Result<()> { - let resp = self.0.execute(request).await?; + /// Get the context from the catalog. 
+ async fn context(&self) -> Result<&RestContext> { + self.ctx + .get_or_try_init(|| async { + let catalog_config = RestCatalog::load_config(&self.user_config).await?; + let config = self.user_config.clone().merge_with_config(catalog_config); + let client = HttpClient::new(&config)?; - if resp.status().as_u16() == SUCCESS_CODE { - Ok(()) - } else { - let code = resp.status(); - let text = resp.bytes().await?; - let e = serde_json::from_slice::(&text).map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("json", String::from_utf8_lossy(&text)) - .with_context("code", code.to_string()) - .with_source(e) - })?; - Err(e.into()) + Ok(RestContext { config, client }) + }) + .await + } + + /// Load the runtime config from the server by user_config. + /// + /// It's required for a rest catalog to update it's config after creation. + async fn load_config(user_config: &RestCatalogConfig) -> Result { + let client = HttpClient::new(user_config)?; + + let mut request = client.request(Method::GET, user_config.config_endpoint()); + + if let Some(warehouse_location) = &user_config.warehouse { + request = request.query(&[("warehouse", warehouse_location)]); } + + let config = client + .query::(request.build()?) + .await?; + Ok(config) } - /// More generic logic handling for special cases like head. - async fn do_execute>( + async fn load_file_io( &self, - request: Request, - handler: impl FnOnce(&Response) -> Option, - ) -> Result { - let resp = self.0.execute(request).await?; + metadata_location: Option<&str>, + extra_config: Option>, + ) -> Result { + let mut props = self.context().await?.config.props.clone(); + if let Some(config) = extra_config { + props.extend(config); + } - if let Some(ret) = handler(&resp) { - Ok(ret) - } else { - let code = resp.status(); - let text = resp.bytes().await?; - let e = serde_json::from_slice::(&text).map_err(|e| { - Error::new( + // If the warehouse is a logical identifier instead of a URL we don't want + // to raise an exception + let warehouse_path = match self.context().await?.config.warehouse.as_deref() { + Some(url) if Url::parse(url).is_ok() => Some(url), + Some(_) => None, + None => None, + }; + + let file_io = match warehouse_path.or(metadata_location) { + Some(url) => FileIO::from_path(url)?.with_props(props).build()?, + None => { + return Err(Error::new( ErrorKind::Unexpected, - "Failed to parse response from rest catalog server!", - ) - .with_context("code", code.to_string()) - .with_context("json", String::from_utf8_lossy(&text)) - .with_source(e) - })?; - Err(e.into()) - } - } -} + "Unable to load file io, neither warehouse nor metadata location is set!", + ))? + } + }; -/// Rest catalog implementation. -#[derive(Debug)] -pub struct RestCatalog { - config: RestCatalogConfig, - client: HttpClient, + Ok(file_io) + } } #[async_trait] @@ -290,12 +314,17 @@ impl Catalog for RestCatalog { &self, parent: Option<&NamespaceIdent>, ) -> Result> { - let mut request = self.client.0.get(self.config.namespaces_endpoint()); + let mut request = self.context().await?.client.request( + Method::GET, + self.context().await?.config.namespaces_endpoint(), + ); if let Some(ns) = parent { - request = request.query(&[("parent", ns.encode_in_url())]); + request = request.query(&[("parent", ns.to_url_string())]); } let resp = self + .context() + .await? .client .query::(request.build()?) 
.await?; @@ -313,9 +342,13 @@ impl Catalog for RestCatalog { properties: HashMap, ) -> Result { let request = self + .context() + .await? .client - .0 - .post(self.config.namespaces_endpoint()) + .request( + Method::POST, + self.context().await?.config.namespaces_endpoint(), + ) .json(&NamespaceSerde { namespace: namespace.as_ref().clone(), properties: Some(properties), @@ -323,6 +356,8 @@ impl Catalog for RestCatalog { .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; @@ -333,12 +368,18 @@ impl Catalog for RestCatalog { /// Get a namespace information from the catalog. async fn get_namespace(&self, namespace: &NamespaceIdent) -> Result { let request = self + .context() + .await? .client - .0 - .get(self.config.namespace_endpoint(namespace)) + .request( + Method::GET, + self.context().await?.config.namespace_endpoint(namespace), + ) .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; @@ -363,12 +404,18 @@ impl Catalog for RestCatalog { async fn namespace_exists(&self, ns: &NamespaceIdent) -> Result { let request = self + .context() + .await? .client - .0 - .head(self.config.namespace_endpoint(ns)) + .request( + Method::HEAD, + self.context().await?.config.namespace_endpoint(ns), + ) .build()?; - self.client + self.context() + .await? + .client .do_execute::(request, |resp| match resp.status() { StatusCode::NO_CONTENT => Some(true), StatusCode::NOT_FOUND => Some(false), @@ -380,12 +427,18 @@ impl Catalog for RestCatalog { /// Drop a namespace from the catalog. async fn drop_namespace(&self, namespace: &NamespaceIdent) -> Result<()> { let request = self + .context() + .await? .client - .0 - .delete(self.config.namespace_endpoint(namespace)) + .request( + Method::DELETE, + self.context().await?.config.namespace_endpoint(namespace), + ) .build()?; - self.client + self.context() + .await? + .client .execute::(request) .await } @@ -393,12 +446,18 @@ impl Catalog for RestCatalog { /// List tables from namespace. async fn list_tables(&self, namespace: &NamespaceIdent) -> Result> { let request = self + .context() + .await? .client - .0 - .get(self.config.tables_endpoint(namespace)) + .request( + Method::GET, + self.context().await?.config.tables_endpoint(namespace), + ) .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; @@ -415,9 +474,13 @@ impl Catalog for RestCatalog { let table_ident = TableIdent::new(namespace.clone(), creation.name.clone()); let request = self + .context() + .await? .client - .0 - .post(self.config.tables_endpoint(namespace)) + .request( + Method::POST, + self.context().await?.config.tables_endpoint(namespace), + ) .json(&CreateTableRequest { name: creation.name, location: creation.location, @@ -435,11 +498,15 @@ impl Catalog for RestCatalog { .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; - let file_io = self.load_file_io(resp.metadata_location.as_deref(), resp.config)?; + let file_io = self + .load_file_io(resp.metadata_location.as_deref(), resp.config) + .await?; let table = Table::builder() .identifier(table_ident) @@ -459,17 +526,25 @@ impl Catalog for RestCatalog { /// Load table from the catalog. async fn load_table(&self, table: &TableIdent) -> Result { let request = self + .context() + .await? .client - .0 - .get(self.config.table_endpoint(table)) + .request( + Method::GET, + self.context().await?.config.table_endpoint(table), + ) .build()?; let resp = self + .context() + .await? 
.client .query::(request) .await?; - let file_io = self.load_file_io(resp.metadata_location.as_deref(), resp.config)?; + let file_io = self + .load_file_io(resp.metadata_location.as_deref(), resp.config) + .await?; let table_builder = Table::builder() .identifier(table.clone()) @@ -486,12 +561,18 @@ impl Catalog for RestCatalog { /// Drop a table from the catalog. async fn drop_table(&self, table: &TableIdent) -> Result<()> { let request = self + .context() + .await? .client - .0 - .delete(self.config.table_endpoint(table)) + .request( + Method::DELETE, + self.context().await?.config.table_endpoint(table), + ) .build()?; - self.client + self.context() + .await? + .client .execute::(request) .await } @@ -499,12 +580,18 @@ impl Catalog for RestCatalog { /// Check if a table exists in the catalog. async fn table_exists(&self, table: &TableIdent) -> Result { let request = self + .context() + .await? .client - .0 - .head(self.config.table_endpoint(table)) + .request( + Method::HEAD, + self.context().await?.config.table_endpoint(table), + ) .build()?; - self.client + self.context() + .await? + .client .do_execute::(request, |resp| match resp.status() { StatusCode::NO_CONTENT => Some(true), StatusCode::NOT_FOUND => Some(false), @@ -516,16 +603,22 @@ impl Catalog for RestCatalog { /// Rename a table in the catalog. async fn rename_table(&self, src: &TableIdent, dest: &TableIdent) -> Result<()> { let request = self + .context() + .await? .client - .0 - .post(self.config.rename_table_endpoint()) + .request( + Method::POST, + self.context().await?.config.rename_table_endpoint(), + ) .json(&RenameTableRequest { source: src.clone(), destination: dest.clone(), }) .build()?; - self.client + self.context() + .await? + .client .execute::(request) .await } @@ -533,9 +626,16 @@ impl Catalog for RestCatalog { /// Update table. async fn update_table(&self, mut commit: TableCommit) -> Result
{ let request = self + .context() + .await? .client - .0 - .post(self.config.table_endpoint(commit.identifier())) + .request( + Method::POST, + self.context() + .await? + .config + .table_endpoint(commit.identifier()), + ) .json(&CommitTableRequest { identifier: commit.identifier().clone(), requirements: commit.take_requirements(), @@ -544,11 +644,15 @@ impl Catalog for RestCatalog { .build()?; let resp = self + .context() + .await? .client .query::(request) .await?; - let file_io = self.load_file_io(Some(&resp.metadata_location), None)?; + let file_io = self + .load_file_io(Some(&resp.metadata_location), None) + .await?; Ok(Table::builder() .identifier(commit.identifier().clone()) .file_io(file_io) @@ -558,294 +662,6 @@ impl Catalog for RestCatalog { } } -impl RestCatalog { - /// Creates a rest catalog from config. - pub async fn new(config: RestCatalogConfig) -> Result { - let mut catalog = Self { - client: config.try_create_rest_client()?, - config, - }; - catalog.fetch_access_token().await?; - catalog.client = catalog.config.try_create_rest_client()?; - catalog.update_config().await?; - catalog.client = catalog.config.try_create_rest_client()?; - - Ok(catalog) - } - - async fn fetch_access_token(&mut self) -> Result<()> { - if self.config.props.contains_key("token") { - return Ok(()); - } - if let Some(credential) = self.config.props.get("credential") { - let (client_id, client_secret) = if credential.contains(':') { - let (client_id, client_secret) = credential.split_once(':').unwrap(); - (Some(client_id), client_secret) - } else { - (None, credential.as_str()) - }; - let mut params = HashMap::with_capacity(4); - params.insert("grant_type", "client_credentials"); - if let Some(client_id) = client_id { - params.insert("client_id", client_id); - } - params.insert("client_secret", client_secret); - let optional_oauth_params = self.config.optional_oauth_params(); - params.extend(optional_oauth_params); - let req = self - .client - .0 - .post(self.config.get_token_endpoint()) - .form(¶ms) - .build()?; - let res = self - .client - .query::(req) - .await - .map_err(|e| { - Error::new( - ErrorKind::Unexpected, - "Failed to fetch access token from catalog server!", - ) - .with_source(e) - })?; - let token = res.access_token; - self.config.props.insert("token".to_string(), token); - } - - Ok(()) - } - - async fn update_config(&mut self) -> Result<()> { - let mut request = self.client.0.get(self.config.config_endpoint()); - - if let Some(warehouse_location) = &self.config.warehouse { - request = request.query(&[("warehouse", warehouse_location)]); - } - - let mut config = self - .client - .query::(request.build()?) 
- .await?; - - let mut props = config.defaults; - props.extend(self.config.props.clone()); - if let Some(uri) = config.overrides.remove("uri") { - self.config.uri = uri; - } - props.extend(config.overrides); - - self.config.props = props; - - Ok(()) - } - - fn load_file_io( - &self, - metadata_location: Option<&str>, - extra_config: Option>, - ) -> Result { - let mut props = self.config.props.clone(); - if let Some(config) = extra_config { - props.extend(config); - } - - // If the warehouse is a logical identifier instead of a URL we don't want - // to raise an exception - let warehouse_path = match self.config.warehouse.as_deref() { - Some(url) if Url::parse(url).is_ok() => Some(url), - Some(_) => None, - None => None, - }; - - let file_io = match warehouse_path.or(metadata_location) { - Some(url) => FileIO::from_path(url)?.with_props(props).build()?, - None => { - return Err(Error::new( - ErrorKind::Unexpected, - "Unable to load file io, neither warehouse nor metadata location is set!", - ))? - } - }; - - Ok(file_io) - } -} - -/// Requests and responses for rest api. -mod _serde { - use std::collections::HashMap; - - use serde_derive::{Deserialize, Serialize}; - - use iceberg::spec::{Schema, SortOrder, TableMetadata, UnboundPartitionSpec}; - use iceberg::{Error, ErrorKind, Namespace, TableIdent, TableRequirement, TableUpdate}; - - pub(super) const OK: u16 = 200u16; - pub(super) const NO_CONTENT: u16 = 204u16; - - #[derive(Clone, Debug, Serialize, Deserialize)] - pub(super) struct CatalogConfig { - pub(super) overrides: HashMap, - pub(super) defaults: HashMap, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ErrorResponse { - error: ErrorModel, - } - - impl From for Error { - fn from(resp: ErrorResponse) -> Error { - resp.error.into() - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ErrorModel { - pub(super) message: String, - pub(super) r#type: String, - pub(super) code: u16, - pub(super) stack: Option>, - } - - impl From for Error { - fn from(value: ErrorModel) -> Self { - let mut error = Error::new(ErrorKind::DataInvalid, value.message) - .with_context("type", value.r#type) - .with_context("code", format!("{}", value.code)); - - if let Some(stack) = value.stack { - error = error.with_context("stack", stack.join("\n")); - } - - error - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct OAuthError { - pub(super) error: String, - pub(super) error_description: Option, - pub(super) error_uri: Option, - } - - impl From for Error { - fn from(value: OAuthError) -> Self { - let mut error = Error::new( - ErrorKind::DataInvalid, - format!("OAuthError: {}", value.error), - ); - - if let Some(desc) = value.error_description { - error = error.with_context("description", desc); - } - - if let Some(uri) = value.error_uri { - error = error.with_context("uri", uri); - } - - error - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct TokenResponse { - pub(super) access_token: String, - pub(super) token_type: String, - pub(super) expires_in: Option, - pub(super) issued_token_type: Option, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct NamespaceSerde { - pub(super) namespace: Vec, - pub(super) properties: Option>, - } - - impl TryFrom for super::Namespace { - type Error = Error; - fn try_from(value: NamespaceSerde) -> std::result::Result { - Ok(super::Namespace::with_properties( - super::NamespaceIdent::from_vec(value.namespace)?, - value.properties.unwrap_or_default(), - )) - } - } - - impl 
From<&Namespace> for NamespaceSerde { - fn from(value: &Namespace) -> Self { - Self { - namespace: value.name().as_ref().clone(), - properties: Some(value.properties().clone()), - } - } - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ListNamespaceResponse { - pub(super) namespaces: Vec>, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct UpdateNamespacePropsRequest { - removals: Option>, - updates: Option>, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct UpdateNamespacePropsResponse { - updated: Vec, - removed: Vec, - missing: Option>, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct ListTableResponse { - pub(super) identifiers: Vec, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct RenameTableRequest { - pub(super) source: TableIdent, - pub(super) destination: TableIdent, - } - - #[derive(Debug, Deserialize)] - #[serde(rename_all = "kebab-case")] - pub(super) struct LoadTableResponse { - pub(super) metadata_location: Option, - pub(super) metadata: TableMetadata, - pub(super) config: Option>, - } - - #[derive(Debug, Serialize, Deserialize)] - #[serde(rename_all = "kebab-case")] - pub(super) struct CreateTableRequest { - pub(super) name: String, - pub(super) location: Option, - pub(super) schema: Schema, - pub(super) partition_spec: Option, - pub(super) write_order: Option, - pub(super) stage_create: Option, - pub(super) properties: Option>, - } - - #[derive(Debug, Serialize, Deserialize)] - pub(super) struct CommitTableRequest { - pub(super) identifier: TableIdent, - pub(super) requirements: Vec, - pub(super) updates: Vec, - } - - #[derive(Debug, Serialize, Deserialize)] - #[serde(rename_all = "kebab-case")] - pub(super) struct CommitTableResponse { - pub(super) metadata_location: String, - pub(super) metadata: TableMetadata, - } -} - #[cfg(test)] mod tests { use chrono::{TimeZone, Utc}; @@ -882,12 +698,16 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); assert_eq!( - catalog.config.props.get("warehouse"), + catalog + .context() + .await + .unwrap() + .config + .props + .get("warehouse"), Some(&"s3://iceberg-catalog".to_string()) ); @@ -926,6 +746,7 @@ mod tests { "expires_in": 86400 }"#, ) + .expect(2) .create_async() .await } @@ -944,16 +765,12 @@ mod tests { .uri(server.url()) .props(props) .build(), - ) - .await - .unwrap(); + ); + let token = catalog.context().await.unwrap().client.token().await; oauth_mock.assert_async().await; config_mock.assert_async().await; - assert_eq!( - catalog.config.props.get("token"), - Some(&"ey000000000000".to_string()) - ); + assert_eq!(token, Some("ey000000000000".to_string())); } #[tokio::test] @@ -983,6 +800,7 @@ mod tests { "expires_in": 86400 }"#, ) + .expect(2) .create_async() .await; @@ -993,16 +811,13 @@ mod tests { .uri(server.url()) .props(props) .build(), - ) - .await - .unwrap(); + ); + + let token = catalog.context().await.unwrap().client.token().await; oauth_mock.assert_async().await; config_mock.assert_async().await; - assert_eq!( - catalog.config.props.get("token"), - Some(&"ey000000000000".to_string()) - ); + assert_eq!(token, Some("ey000000000000".to_string())); } #[tokio::test] @@ -1015,7 +830,7 @@ mod tests { .uri(server.url()) .props(props) .build(); - let headers: HeaderMap = config.http_headers().unwrap(); + let headers: HeaderMap = 
config.extra_headers().unwrap(); let expected_headers = HeaderMap::from_iter([ ( @@ -1052,12 +867,12 @@ mod tests { .uri(server.url()) .props(props) .build(); - let headers: HeaderMap = config.http_headers().unwrap(); + let headers: HeaderMap = config.extra_headers().unwrap(); let expected_headers = HeaderMap::from_iter([ ( header::CONTENT_TYPE, - HeaderValue::from_static("application/json"), + HeaderValue::from_static("application/yaml"), ), ( HeaderName::from_static("x-client-version"), @@ -1096,16 +911,13 @@ mod tests { .uri(server.url()) .props(props) .build(), - ) - .await - .unwrap(); + ); + + let token = catalog.context().await.unwrap().client.token().await; oauth_mock.assert_async().await; config_mock.assert_async().await; - assert_eq!( - catalog.config.props.get("token"), - Some(&"ey000000000000".to_string()) - ); + assert_eq!(token, Some("ey000000000000".to_string())); } #[tokio::test] @@ -1143,9 +955,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let _namespaces = catalog.list_namespaces(None).await.unwrap(); @@ -1172,9 +982,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let namespaces = catalog.list_namespaces(None).await.unwrap(); @@ -1208,9 +1016,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let namespaces = catalog .create_namespace( @@ -1250,9 +1056,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let namespaces = catalog .get_namespace(&NamespaceIdent::new("ns1".to_string())) @@ -1282,9 +1086,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); assert!(catalog .namespace_exists(&NamespaceIdent::new("ns1".to_string())) @@ -1307,9 +1109,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); catalog .drop_namespace(&NamespaceIdent::new("ns1".to_string())) @@ -1346,9 +1146,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let tables = catalog .list_tables(&NamespaceIdent::new("ns1".to_string())) @@ -1378,9 +1176,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); catalog .drop_table(&TableIdent::new( @@ -1406,9 +1202,7 @@ mod tests { .create_async() .await; - let 
catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); assert!(catalog .table_exists(&TableIdent::new( @@ -1434,9 +1228,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); catalog .rename_table( @@ -1467,9 +1259,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table = catalog .load_table(&TableIdent::new( @@ -1580,9 +1370,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table = catalog .load_table(&TableIdent::new( @@ -1619,9 +1407,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table_creation = TableCreation::builder() .name("test1".to_string()) @@ -1761,9 +1547,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table_creation = TableCreation::builder() .name("test1".to_string()) @@ -1816,9 +1600,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table1 = { let file = File::open(format!( @@ -1938,9 +1720,7 @@ mod tests { .create_async() .await; - let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()) - .await - .unwrap(); + let catalog = RestCatalog::new(RestCatalogConfig::builder().uri(server.url()).build()); let table1 = { let file = File::open(format!( diff --git a/crates/catalog/rest/src/client.rs b/crates/catalog/rest/src/client.rs new file mode 100644 index 000000000..43e14c731 --- /dev/null +++ b/crates/catalog/rest/src/client.rs @@ -0,0 +1,276 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
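[Editor's note, illustrative sketch, not part of the diff.] The catalog changes above replace the old eager async constructor with a synchronous RestCatalog::new that defers fetching and merging the server config until the first request, via tokio::sync::OnceCell::get_or_try_init. A minimal sketch of that lazy-initialization pattern follows; the ServerConfig and LazyCatalog names are invented for the example and are not APIs from this crate.

use tokio::sync::OnceCell;

// Invented stand-ins for the real RestContext / runtime config types.
struct ServerConfig {
    warehouse: String,
}

struct LazyCatalog {
    ctx: OnceCell<ServerConfig>,
}

impl LazyCatalog {
    // Construction is synchronous; nothing talks to the server yet.
    fn new() -> Self {
        Self { ctx: OnceCell::new() }
    }

    // The first caller runs the async initializer; later callers reuse the cached value.
    async fn context(&self) -> Result<&ServerConfig, String> {
        self.ctx
            .get_or_try_init(|| async {
                // Stand-in for GET /v1/config followed by merge_with_config.
                Ok(ServerConfig { warehouse: "s3://warehouse".to_string() })
            })
            .await
    }
}

Because every trait method now goes through this lazy context, the diff threads self.context().await? in front of each client and config access.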
+ +use crate::types::{ErrorResponse, TokenResponse, OK}; +use crate::RestCatalogConfig; +use iceberg::Result; +use iceberg::{Error, ErrorKind}; +use reqwest::header::HeaderMap; +use reqwest::{Client, IntoUrl, Method, Request, RequestBuilder, Response}; +use serde::de::DeserializeOwned; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use std::sync::Mutex; + +pub(crate) struct HttpClient { + client: Client, + + /// The token to be used for authentication. + /// + /// It's possible to fetch the token from the server while needed. + token: Mutex>, + /// The token endpoint to be used for authentication. + token_endpoint: String, + /// The credential to be used for authentication. + credential: Option<(Option, String)>, + /// Extra headers to be added to each request. + extra_headers: HeaderMap, + /// Extra oauth parameters to be added to each authentication request. + extra_oauth_params: HashMap, +} + +impl Debug for HttpClient { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("HttpClient") + .field("client", &self.client) + .field("extra_headers", &self.extra_headers) + .finish_non_exhaustive() + } +} + +impl HttpClient { + pub fn new(cfg: &RestCatalogConfig) -> Result { + Ok(HttpClient { + client: Client::new(), + + token: Mutex::new(cfg.token()), + token_endpoint: cfg.get_token_endpoint(), + credential: cfg.credential(), + extra_headers: cfg.extra_headers()?, + extra_oauth_params: cfg.extra_oauth_params(), + }) + } + + /// This API is testing only to assert the token. + #[cfg(test)] + pub(crate) async fn token(&self) -> Option { + let mut req = self + .request(Method::GET, &self.token_endpoint) + .build() + .unwrap(); + self.authenticate(&mut req).await.ok(); + self.token.lock().unwrap().clone() + } + + /// Authenticate the request by filling token. + /// + /// - If neither token nor credential is provided, this method will do nothing. + /// - If only credential is provided, this method will try to fetch token from the server. + /// - If token is provided, this method will use the token directly. + /// + /// # TODO + /// + /// Support refreshing token while needed. + async fn authenticate(&self, req: &mut Request) -> Result<()> { + // Clone the token from lock without holding the lock for entire function. + let token = { self.token.lock().expect("lock poison").clone() }; + + if self.credential.is_none() && token.is_none() { + return Ok(()); + } + + // Use token if provided. + if let Some(token) = &token { + req.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}").parse().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + return Ok(()); + } + + // Credential must exist here. 
+ let (client_id, client_secret) = self.credential.as_ref().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "Credential must be provided for authentication", + ) + })?; + + let mut params = HashMap::with_capacity(4); + params.insert("grant_type", "client_credentials"); + if let Some(client_id) = client_id { + params.insert("client_id", client_id); + } + params.insert("client_secret", client_secret); + params.extend( + self.extra_oauth_params + .iter() + .map(|(k, v)| (k.as_str(), v.as_str())), + ); + + let auth_req = self + .client + .request(Method::POST, &self.token_endpoint) + .form(¶ms) + .build()?; + let auth_resp = self.client.execute(auth_req).await?; + + let auth_res: TokenResponse = if auth_resp.status().as_u16() == OK { + let text = auth_resp.bytes().await?; + Ok(serde_json::from_slice(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?) + } else { + let code = auth_resp.status(); + let text = auth_resp.bytes().await?; + let e: ErrorResponse = serde_json::from_slice(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) + .with_source(e) + })?; + Err(Error::from(e)) + }?; + let token = auth_res.access_token; + // Update token. + *self.token.lock().expect("lock poison") = Some(token.clone()); + // Insert token in request. + req.headers_mut().insert( + http::header::AUTHORIZATION, + format!("Bearer {token}").parse().map_err(|e| { + Error::new( + ErrorKind::DataInvalid, + "Invalid token received from catalog server!", + ) + .with_source(e) + })?, + ); + + Ok(()) + } + + #[inline] + pub fn request(&self, method: Method, url: U) -> RequestBuilder { + self.client.request(method, url) + } + + pub async fn query< + R: DeserializeOwned, + E: DeserializeOwned + Into, + const SUCCESS_CODE: u16, + >( + &self, + mut request: Request, + ) -> Result { + self.authenticate(&mut request).await?; + + let resp = self.client.execute(request).await?; + + if resp.status().as_u16() == SUCCESS_CODE { + let text = resp.bytes().await?; + Ok(serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?) 
+ } else { + let code = resp.status(); + let text = resp.bytes().await?; + let e = serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) + .with_source(e) + })?; + Err(e.into()) + } + } + + pub async fn execute, const SUCCESS_CODE: u16>( + &self, + mut request: Request, + ) -> Result<()> { + self.authenticate(&mut request).await?; + + let resp = self.client.execute(request).await?; + + if resp.status().as_u16() == SUCCESS_CODE { + Ok(()) + } else { + let code = resp.status(); + let text = resp.bytes().await?; + let e = serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("json", String::from_utf8_lossy(&text)) + .with_context("code", code.to_string()) + .with_source(e) + })?; + Err(e.into()) + } + } + + /// More generic logic handling for special cases like head. + pub async fn do_execute>( + &self, + mut request: Request, + handler: impl FnOnce(&Response) -> Option, + ) -> Result { + self.authenticate(&mut request).await?; + + let resp = self.client.execute(request).await?; + + if let Some(ret) = handler(&resp) { + Ok(ret) + } else { + let code = resp.status(); + let text = resp.bytes().await?; + let e = serde_json::from_slice::(&text).map_err(|e| { + Error::new( + ErrorKind::Unexpected, + "Failed to parse response from rest catalog server!", + ) + .with_context("code", code.to_string()) + .with_context("json", String::from_utf8_lossy(&text)) + .with_source(e) + })?; + Err(e.into()) + } + } +} diff --git a/crates/catalog/rest/src/lib.rs b/crates/catalog/rest/src/lib.rs index 023fe7ab2..f94ee8781 100644 --- a/crates/catalog/rest/src/lib.rs +++ b/crates/catalog/rest/src/lib.rs @@ -20,4 +20,7 @@ #![deny(missing_docs)] mod catalog; +mod client; +mod types; + pub use catalog::*; diff --git a/crates/catalog/rest/src/types.rs b/crates/catalog/rest/src/types.rs new file mode 100644 index 000000000..c8d704b26 --- /dev/null +++ b/crates/catalog/rest/src/types.rs @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
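[Editor's note, illustrative sketch, not part of the diff.] The new HttpClient::authenticate keeps the bearer token in a std::sync::Mutex<Option<String>> and deliberately clones it out inside a small block instead of holding the guard across an .await. The sketch below shows only that lock-scoping idea; TokenCache and fetch_token are made-up names, not APIs from this crate.

use std::sync::Mutex;

struct TokenCache {
    token: Mutex<Option<String>>,
}

impl TokenCache {
    async fn bearer(&self) -> String {
        // Clone inside a block so the guard is dropped before awaiting;
        // a std MutexGuard held across `.await` would make this future !Send
        // and could block other callers for the whole request.
        let cached = { self.token.lock().expect("lock poisoned").clone() };
        if let Some(token) = cached {
            return token;
        }
        let fresh = fetch_token().await;
        *self.token.lock().expect("lock poisoned") = Some(fresh.clone());
        fresh
    }
}

// Stand-in for the OAuth client_credentials POST performed by the real client.
async fn fetch_token() -> String {
    "ey000000000000".to_string()
}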
+ +use std::collections::HashMap; + +use serde_derive::{Deserialize, Serialize}; + +use iceberg::spec::{Schema, SortOrder, TableMetadata, UnboundPartitionSpec}; +use iceberg::{ + Error, ErrorKind, Namespace, NamespaceIdent, TableIdent, TableRequirement, TableUpdate, +}; + +pub(super) const OK: u16 = 200u16; +pub(super) const NO_CONTENT: u16 = 204u16; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(super) struct CatalogConfig { + pub(super) overrides: HashMap, + pub(super) defaults: HashMap, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ErrorResponse { + error: ErrorModel, +} + +impl From for Error { + fn from(resp: ErrorResponse) -> Error { + resp.error.into() + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ErrorModel { + pub(super) message: String, + pub(super) r#type: String, + pub(super) code: u16, + pub(super) stack: Option>, +} + +impl From for Error { + fn from(value: ErrorModel) -> Self { + let mut error = Error::new(ErrorKind::DataInvalid, value.message) + .with_context("type", value.r#type) + .with_context("code", format!("{}", value.code)); + + if let Some(stack) = value.stack { + error = error.with_context("stack", stack.join("\n")); + } + + error + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct OAuthError { + pub(super) error: String, + pub(super) error_description: Option, + pub(super) error_uri: Option, +} + +impl From for Error { + fn from(value: OAuthError) -> Self { + let mut error = Error::new( + ErrorKind::DataInvalid, + format!("OAuthError: {}", value.error), + ); + + if let Some(desc) = value.error_description { + error = error.with_context("description", desc); + } + + if let Some(uri) = value.error_uri { + error = error.with_context("uri", uri); + } + + error + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct TokenResponse { + pub(super) access_token: String, + pub(super) token_type: String, + pub(super) expires_in: Option, + pub(super) issued_token_type: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct NamespaceSerde { + pub(super) namespace: Vec, + pub(super) properties: Option>, +} + +impl TryFrom for Namespace { + type Error = Error; + fn try_from(value: NamespaceSerde) -> std::result::Result { + Ok(Namespace::with_properties( + NamespaceIdent::from_vec(value.namespace)?, + value.properties.unwrap_or_default(), + )) + } +} + +impl From<&Namespace> for NamespaceSerde { + fn from(value: &Namespace) -> Self { + Self { + namespace: value.name().as_ref().clone(), + properties: Some(value.properties().clone()), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ListNamespaceResponse { + pub(super) namespaces: Vec>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct UpdateNamespacePropsRequest { + removals: Option>, + updates: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct UpdateNamespacePropsResponse { + updated: Vec, + removed: Vec, + missing: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct ListTableResponse { + pub(super) identifiers: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct RenameTableRequest { + pub(super) source: TableIdent, + pub(super) destination: TableIdent, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(super) struct LoadTableResponse { + pub(super) metadata_location: Option, + pub(super) metadata: TableMetadata, + pub(super) config: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] 
+#[serde(rename_all = "kebab-case")] +pub(super) struct CreateTableRequest { + pub(super) name: String, + pub(super) location: Option, + pub(super) schema: Schema, + pub(super) partition_spec: Option, + pub(super) write_order: Option, + pub(super) stage_create: Option, + pub(super) properties: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct CommitTableRequest { + pub(super) identifier: TableIdent, + pub(super) requirements: Vec, + pub(super) updates: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub(super) struct CommitTableResponse { + pub(super) metadata_location: String, + pub(super) metadata: TableMetadata, +} diff --git a/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml b/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml index 0152a22ca..b49b6c6c1 100644 --- a/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml +++ b/crates/catalog/rest/testdata/rest_catalog/docker-compose.yaml @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -version: '3.8' - services: rest: image: tabulario/iceberg-rest:0.10.0 diff --git a/crates/catalog/rest/tests/rest_catalog_test.rs b/crates/catalog/rest/tests/rest_catalog_test.rs index 205428d61..621536a73 100644 --- a/crates/catalog/rest/tests/rest_catalog_test.rs +++ b/crates/catalog/rest/tests/rest_catalog_test.rs @@ -17,6 +17,7 @@ //! Integration tests for rest catalog. +use ctor::{ctor, dtor}; use iceberg::spec::{FormatVersion, NestedField, PrimitiveType, Schema, Type}; use iceberg::transaction::Transaction; use iceberg::{Catalog, Namespace, NamespaceIdent, TableCreation, TableIdent}; @@ -25,26 +26,37 @@ use iceberg_test_utils::docker::DockerCompose; use iceberg_test_utils::{normalize_test_name, set_up}; use port_scanner::scan_port_addr; use std::collections::HashMap; +use std::sync::RwLock; use tokio::time::sleep; const REST_CATALOG_PORT: u16 = 8181; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); -struct TestFixture { - _docker_compose: DockerCompose, - rest_catalog: RestCatalog, -} - -async fn set_test_fixture(func: &str) -> TestFixture { - set_up(); +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); let docker_compose = DockerCompose::new( - normalize_test_name(format!("{}_{func}", module_path!())), + normalize_test_name(module_path!()), format!("{}/testdata/rest_catalog", env!("CARGO_MANIFEST_DIR")), ); - - // Start docker compose docker_compose.run(); + guard.replace(docker_compose); +} - let rest_catalog_ip = docker_compose.get_container_ip("rest"); +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} + +async fn get_catalog() -> RestCatalog { + set_up(); + + let rest_catalog_ip = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + docker_compose.get_container_ip("rest") + }; let read_port = format!("{}:{}", rest_catalog_ip, REST_CATALOG_PORT); loop { @@ -59,21 +71,15 @@ async fn set_test_fixture(func: &str) -> TestFixture { let config = RestCatalogConfig::builder() .uri(format!("http://{}:{}", rest_catalog_ip, REST_CATALOG_PORT)) .build(); - let rest_catalog = RestCatalog::new(config).await.unwrap(); - - TestFixture { - _docker_compose: docker_compose, - rest_catalog, - } + RestCatalog::new(config) } #[tokio::test] async fn test_get_non_exist_namespace() { - let fixture = set_test_fixture("test_get_non_exist_namespace").await; + let catalog = 
get_catalog().await; - let result = fixture - .rest_catalog - .get_namespace(&NamespaceIdent::from_strs(["demo"]).unwrap()) + let result = catalog + .get_namespace(&NamespaceIdent::from_strs(["test_get_non_exist_namespace"]).unwrap()) .await; assert!(result.is_err()); @@ -85,7 +91,7 @@ async fn test_get_non_exist_namespace() { #[tokio::test] async fn test_get_namespace() { - let fixture = set_test_fixture("test_get_namespace").await; + let catalog = get_catalog().await; let ns = Namespace::with_properties( NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), @@ -96,11 +102,10 @@ async fn test_get_namespace() { ); // Verify that namespace doesn't exist - assert!(fixture.rest_catalog.get_namespace(ns.name()).await.is_err()); + assert!(catalog.get_namespace(ns.name()).await.is_err()); // Create this namespace - let created_ns = fixture - .rest_catalog + let created_ns = catalog .create_namespace(ns.name(), ns.properties().clone()) .await .unwrap(); @@ -109,17 +114,17 @@ async fn test_get_namespace() { assert_map_contains(ns.properties(), created_ns.properties()); // Check that this namespace already exists - let get_ns = fixture.rest_catalog.get_namespace(ns.name()).await.unwrap(); + let get_ns = catalog.get_namespace(ns.name()).await.unwrap(); assert_eq!(ns.name(), get_ns.name()); assert_map_contains(ns.properties(), created_ns.properties()); } #[tokio::test] async fn test_list_namespace() { - let fixture = set_test_fixture("test_list_namespace").await; + let catalog = get_catalog().await; let ns1 = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_list_namespace", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -127,7 +132,7 @@ async fn test_list_namespace() { ); let ns2 = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "macos"]).unwrap(), + NamespaceIdent::from_strs(["test_list_namespace", "macos"]).unwrap(), HashMap::from([ ("owner".to_string(), "xuanwo".to_string()), ("community".to_string(), "apache".to_string()), @@ -135,42 +140,41 @@ async fn test_list_namespace() { ); // Currently this namespace doesn't exist, so it should return error. 
- assert!(fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + assert!(catalog + .list_namespaces(Some( + &NamespaceIdent::from_strs(["test_list_namespace"]).unwrap() + )) .await .is_err()); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns1.name(), ns1.properties().clone()) .await .unwrap(); - fixture - .rest_catalog + catalog .create_namespace(ns2.name(), ns1.properties().clone()) .await .unwrap(); // List namespace - let mut nss = fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + let nss = catalog + .list_namespaces(Some( + &NamespaceIdent::from_strs(["test_list_namespace"]).unwrap(), + )) .await .unwrap(); - nss.sort(); - assert_eq!(&nss[0], ns1.name()); - assert_eq!(&nss[1], ns2.name()); + assert!(nss.contains(ns1.name())); + assert!(nss.contains(ns2.name())); } #[tokio::test] async fn test_list_empty_namespace() { - let fixture = set_test_fixture("test_list_empty_namespace").await; + let catalog = get_catalog().await; let ns_apple = Namespace::with_properties( - NamespaceIdent::from_strs(["apple"]).unwrap(), + NamespaceIdent::from_strs(["test_list_empty_namespace", "apple"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -178,23 +182,20 @@ async fn test_list_empty_namespace() { ); // Currently this namespace doesn't exist, so it should return error. - assert!(fixture - .rest_catalog + assert!(catalog .list_namespaces(Some(ns_apple.name())) .await .is_err()); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns_apple.name(), ns_apple.properties().clone()) .await .unwrap(); // List namespace - let nss = fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + let nss = catalog + .list_namespaces(Some(ns_apple.name())) .await .unwrap(); assert!(nss.is_empty()); @@ -202,10 +203,10 @@ async fn test_list_empty_namespace() { #[tokio::test] async fn test_list_root_namespace() { - let fixture = set_test_fixture("test_list_root_namespace").await; + let catalog = get_catalog().await; let ns1 = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_list_root_namespace", "apple", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -213,7 +214,7 @@ async fn test_list_root_namespace() { ); let ns2 = Namespace::with_properties( - NamespaceIdent::from_strs(["google", "android"]).unwrap(), + NamespaceIdent::from_strs(["test_list_root_namespace", "google", "android"]).unwrap(), HashMap::from([ ("owner".to_string(), "xuanwo".to_string()), ("community".to_string(), "apache".to_string()), @@ -221,38 +222,34 @@ async fn test_list_root_namespace() { ); // Currently this namespace doesn't exist, so it should return error. 
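[Editor's note, illustrative sketch, not part of the diff.] The integration tests in this file stop spinning up one docker-compose environment per test and instead share a single environment for the whole test binary, held in a static RwLock and started and stopped with the ctor crate's #[ctor] and #[dtor] hooks. A stripped-down sketch of that pattern, with a made-up Env type standing in for DockerCompose:

use std::sync::RwLock;

use ctor::{ctor, dtor};

// Made-up stand-in for the DockerCompose handle used by the real tests.
struct Env(&'static str);

static SHARED_ENV: RwLock<Option<Env>> = RwLock::new(None);

#[ctor]
fn before_all() {
    // Runs once when the test binary starts: bring the environment up.
    SHARED_ENV.write().unwrap().replace(Env("rest_catalog"));
}

#[dtor]
fn after_all() {
    // Runs at process exit: drop the handle to tear the environment down.
    SHARED_ENV.write().unwrap().take();
}

Because the environment is now shared, the tests here also switch from reusable namespace names like "apple" to per-test names, so concurrent tests do not collide in the same catalog.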
- assert!(fixture - .rest_catalog - .list_namespaces(Some(&NamespaceIdent::from_strs(["apple"]).unwrap())) + assert!(catalog + .list_namespaces(Some( + &NamespaceIdent::from_strs(["test_list_root_namespace"]).unwrap() + )) .await .is_err()); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns1.name(), ns1.properties().clone()) .await .unwrap(); - fixture - .rest_catalog + catalog .create_namespace(ns2.name(), ns1.properties().clone()) .await .unwrap(); // List namespace - let mut nss = fixture.rest_catalog.list_namespaces(None).await.unwrap(); - nss.sort(); - - assert_eq!(&nss[0], &NamespaceIdent::from_strs(["apple"]).unwrap()); - assert_eq!(&nss[1], &NamespaceIdent::from_strs(["google"]).unwrap()); + let nss = catalog.list_namespaces(None).await.unwrap(); + assert!(nss.contains(&NamespaceIdent::from_strs(["test_list_root_namespace"]).unwrap())); } #[tokio::test] async fn test_create_table() { - let fixture = set_test_fixture("test_create_table").await; + let catalog = get_catalog().await; let ns = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_create_table", "apple", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -260,8 +257,7 @@ async fn test_create_table() { ); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns.name(), ns.properties().clone()) .await .unwrap(); @@ -282,8 +278,7 @@ async fn test_create_table() { .schema(schema.clone()) .build(); - let table = fixture - .rest_catalog + let table = catalog .create_table(ns.name(), table_creation) .await .unwrap(); @@ -310,10 +305,10 @@ async fn test_create_table() { #[tokio::test] async fn test_update_table() { - let fixture = set_test_fixture("test_update_table").await; + let catalog = get_catalog().await; let ns = Namespace::with_properties( - NamespaceIdent::from_strs(["apple", "ios"]).unwrap(), + NamespaceIdent::from_strs(["test_update_table", "apple", "ios"]).unwrap(), HashMap::from([ ("owner".to_string(), "ray".to_string()), ("community".to_string(), "apache".to_string()), @@ -321,8 +316,7 @@ async fn test_update_table() { ); // Create namespaces - fixture - .rest_catalog + catalog .create_namespace(ns.name(), ns.properties().clone()) .await .unwrap(); @@ -344,8 +338,7 @@ async fn test_update_table() { .schema(schema.clone()) .build(); - let table = fixture - .rest_catalog + let table = catalog .create_table(ns.name(), table_creation) .await .unwrap(); @@ -359,7 +352,7 @@ async fn test_update_table() { let table2 = Transaction::new(&table) .set_properties(HashMap::from([("prop1".to_string(), "v1".to_string())])) .unwrap() - .commit(&fixture.rest_catalog) + .commit(&catalog) .await .unwrap(); @@ -375,3 +368,39 @@ fn assert_map_contains(map1: &HashMap, map2: &HashMap, - field_ids: Vec, file_io: FileIO, - schema: SchemaRef, - predicates: Option, } impl ArrowReaderBuilder { /// Create a new ArrowReaderBuilder - pub fn new(file_io: FileIO, schema: SchemaRef) -> Self { + pub(crate) fn new(file_io: FileIO) -> Self { ArrowReaderBuilder { batch_size: None, - field_ids: vec![], file_io, - schema, - predicates: None, } } @@ -74,38 +69,20 @@ impl ArrowReaderBuilder { self } - /// Sets the desired column projection with a list of field ids. - pub fn with_field_ids(mut self, field_ids: impl IntoIterator) -> Self { - self.field_ids = field_ids.into_iter().collect(); - self - } - - /// Sets the predicates to apply to the scan. 
- pub fn with_predicates(mut self, predicates: BoundPredicate) -> Self { - self.predicates = Some(predicates); - self - } - /// Build the ArrowReader. pub fn build(self) -> ArrowReader { ArrowReader { batch_size: self.batch_size, - field_ids: self.field_ids, - schema: self.schema, file_io: self.file_io, - predicates: self.predicates, } } } /// Reads data from Parquet files +#[derive(Clone)] pub struct ArrowReader { batch_size: Option, - field_ids: Vec, - #[allow(dead_code)] - schema: SchemaRef, file_io: FileIO, - predicates: Option, } impl ArrowReader { @@ -114,16 +91,16 @@ impl ArrowReader { pub fn read(self, mut tasks: FileScanTaskStream) -> crate::Result { let file_io = self.file_io.clone(); - // Collect Parquet column indices from field ids - let mut collector = CollectFieldIdVisitor { - field_ids: HashSet::default(), - }; - if let Some(predicates) = &self.predicates { - visit(&mut collector, predicates)?; - } - Ok(try_stream! { while let Some(Ok(task)) = tasks.next().await { + // Collect Parquet column indices from field ids + let mut collector = CollectFieldIdVisitor { + field_ids: HashSet::default(), + }; + if let Some(predicates) = task.predicate() { + visit(&mut collector, predicates)?; + } + let parquet_file = file_io .new_input(task.data_file_path())?; let (parquet_metadata, parquet_reader) = try_join!(parquet_file.metadata(), parquet_file.reader())?; @@ -134,11 +111,11 @@ impl ArrowReader { let parquet_schema = batch_stream_builder.parquet_schema(); let arrow_schema = batch_stream_builder.schema(); - let projection_mask = self.get_arrow_projection_mask(parquet_schema, arrow_schema)?; + let projection_mask = self.get_arrow_projection_mask(task.project_field_ids(),task.schema(),parquet_schema, arrow_schema)?; batch_stream_builder = batch_stream_builder.with_projection(projection_mask); let parquet_schema = batch_stream_builder.parquet_schema(); - let row_filter = self.get_row_filter(parquet_schema, &collector)?; + let row_filter = self.get_row_filter(task.predicate(),parquet_schema, &collector)?; if let Some(row_filter) = row_filter { batch_stream_builder = batch_stream_builder.with_row_filter(row_filter); @@ -160,10 +137,12 @@ impl ArrowReader { fn get_arrow_projection_mask( &self, + field_ids: &[i32], + iceberg_schema_of_task: &Schema, parquet_schema: &SchemaDescriptor, arrow_schema: &ArrowSchemaRef, ) -> crate::Result { - if self.field_ids.is_empty() { + if field_ids.is_empty() { Ok(ProjectionMask::all()) } else { // Build the map between field id and column index in Parquet schema. 
@@ -183,11 +162,11 @@ impl ArrowReader { } let field_id = field_id.unwrap(); - if !self.field_ids.contains(&(field_id as usize)) { + if !field_ids.contains(&field_id) { return false; } - let iceberg_field = self.schema.field_by_id(field_id); + let iceberg_field = iceberg_schema_of_task.field_by_id(field_id); let parquet_iceberg_field = iceberg_schema.field_by_id(field_id); if iceberg_field.is_none() || parquet_iceberg_field.is_none() { @@ -202,19 +181,19 @@ impl ArrowReader { true }); - if column_map.len() != self.field_ids.len() { + if column_map.len() != field_ids.len() { return Err(Error::new( ErrorKind::DataInvalid, format!( "Parquet schema {} and Iceberg schema {} do not match.", - iceberg_schema, self.schema + iceberg_schema, iceberg_schema_of_task ), )); } let mut indices = vec![]; - for field_id in &self.field_ids { - if let Some(col_idx) = column_map.get(&(*field_id as i32)) { + for field_id in field_ids { + if let Some(col_idx) = column_map.get(field_id) { indices.push(*col_idx); } else { return Err(Error::new( @@ -229,10 +208,11 @@ impl ArrowReader { fn get_row_filter( &self, + predicates: Option<&BoundPredicate>, parquet_schema: &SchemaDescriptor, collector: &CollectFieldIdVisitor, ) -> Result> { - if let Some(predicates) = &self.predicates { + if let Some(predicates) = predicates { let field_id_map = build_field_id_map(parquet_schema)?; // Collect Parquet column indices from field ids. @@ -741,42 +721,98 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { fn starts_with( &mut self, - _reference: &BoundReference, - _literal: &Datum, + reference: &BoundReference, + literal: &Datum, _predicate: &BoundPredicate, ) -> Result> { - // TODO: Implement starts_with - self.build_always_true() + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + starts_with(&left, literal.as_ref()) + })) + } else { + // A missing column, treating it as null. + self.build_always_false() + } } fn not_starts_with( &mut self, - _reference: &BoundReference, - _literal: &Datum, + reference: &BoundReference, + literal: &Datum, _predicate: &BoundPredicate, ) -> Result> { - // TODO: Implement not_starts_with - self.build_always_true() + if let Some(idx) = self.bound_reference(reference)? { + let literal = get_arrow_datum(literal)?; + + Ok(Box::new(move |batch| { + let left = project_column(&batch, idx)?; + + // update here if arrow ever adds a native not_starts_with + not(&starts_with(&left, literal.as_ref())?) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } } fn r#in( &mut self, - _reference: &BoundReference, - _literals: &FnvHashSet, + reference: &BoundReference, + literals: &FnvHashSet, _predicate: &BoundPredicate, ) -> Result> { - // TODO: Implement in - self.build_always_true() + if let Some(idx) = self.bound_reference(reference)? { + let literals: Vec<_> = literals + .iter() + .map(|lit| get_arrow_datum(lit).unwrap()) + .collect(); + + Ok(Box::new(move |batch| { + // update this if arrow ever adds a native is_in kernel + let left = project_column(&batch, idx)?; + let mut acc = BooleanArray::from(vec![false; batch.num_rows()]); + for literal in &literals { + acc = or(&acc, &eq(&left, literal.as_ref())?)? + } + + Ok(acc) + })) + } else { + // A missing column, treating it as null. 
+ self.build_always_false() + } } fn not_in( &mut self, - _reference: &BoundReference, - _literals: &FnvHashSet, + reference: &BoundReference, + literals: &FnvHashSet, _predicate: &BoundPredicate, ) -> Result> { - // TODO: Implement not_in - self.build_always_true() + if let Some(idx) = self.bound_reference(reference)? { + let literals: Vec<_> = literals + .iter() + .map(|lit| get_arrow_datum(lit).unwrap()) + .collect(); + + Ok(Box::new(move |batch| { + // update this if arrow ever adds a native not_in kernel + let left = project_column(&batch, idx)?; + let mut acc = BooleanArray::from(vec![true; batch.num_rows()]); + for literal in &literals { + acc = and(&acc, &neq(&left, literal.as_ref())?)? + } + + Ok(acc) + })) + } else { + // A missing column, treating it as null. + self.build_always_true() + } } } @@ -784,7 +820,8 @@ impl<'a> BoundPredicateVisitor for PredicateConverter<'a> { /// /// # TODO /// -/// [ParquetObjectReader](https://docs.rs/parquet/latest/src/parquet/arrow/async_reader/store.rs.html#64) contains the following hints to speed up metadata loading, we can consider adding them to this struct: +/// [ParquetObjectReader](https://docs.rs/parquet/latest/src/parquet/arrow/async_reader/store.rs.html#64) +/// contains the following hints to speed up metadata loading, we can consider adding them to this struct: /// /// - `metadata_size_hint`: Provide a hint as to the size of the parquet file's footer. /// - `preload_column_index`: Load the Column Index as part of [`Self::get_metadata`]. diff --git a/crates/iceberg/src/arrow/schema.rs b/crates/iceberg/src/arrow/schema.rs index 172d4bb79..3102f6d33 100644 --- a/crates/iceberg/src/arrow/schema.rs +++ b/crates/iceberg/src/arrow/schema.rs @@ -26,6 +26,7 @@ use crate::{Error, ErrorKind}; use arrow_array::types::{validate_decimal_precision_and_scale, Decimal128Type}; use arrow_array::{ BooleanArray, Datum as ArrowDatum, Float32Array, Float64Array, Int32Array, Int64Array, + StringArray, }; use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; use bitvec::macros::internal::funty::Fundamental; @@ -34,6 +35,9 @@ use rust_decimal::prelude::ToPrimitive; use std::collections::HashMap; use std::sync::Arc; +/// When iceberg map type convert to Arrow map type, the default map field name is "key_value". +pub(crate) const DEFAULT_MAP_FIELD_NAME: &str = "key_value"; + /// A post order arrow schema visitor. /// /// For order of methods called, please refer to [`visit_schema`]. 
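[Editor's sketch] The `r#in`/`not_in` changes above compose existing arrow kernels because arrow-rs has no native `is_in` kernel at this version: per-literal equality results are folded together with boolean `or` (or `and`/`neq` for the negated case). A standalone sketch of the same idea, where `emulate_in` is a hypothetical helper and `eq`, `or`, `Scalar`, and the array types are the real arrow APIs:

    use arrow_arith::boolean::or;
    use arrow_array::{BooleanArray, Scalar, StringArray};
    use arrow_ord::cmp::eq;
    use arrow_schema::ArrowError;

    // Hypothetical helper mirroring the `r#in` visitor: start from all-false
    // and OR in one per-literal equality comparison at a time.
    fn emulate_in(col: &StringArray, literals: &[&str]) -> Result<BooleanArray, ArrowError> {
        let mut acc = BooleanArray::from(vec![false; col.len()]);
        for lit in literals {
            let scalar = Scalar::new(StringArray::from(vec![*lit]));
            acc = or(&acc, &eq(col, &scalar)?)?;
        }
        Ok(acc)
    }

    fn main() -> Result<(), ArrowError> {
        let col = StringArray::from(vec!["apple", "banana", "cherry"]);
        let mask = emulate_in(&col, &["banana", "cherry"])?;
        assert_eq!(
            mask.iter().collect::<Vec<_>>(),
            vec![Some(false), Some(true), Some(true)]
        );
        Ok(())
    }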
@@ -499,9 +503,10 @@ impl SchemaVisitor for ToArrowSchemaConverter { _ => unreachable!(), }; let field = Field::new( - "entries", + DEFAULT_MAP_FIELD_NAME, DataType::Struct(vec![key_field, value_field].into()), - map.value_field.required, + // Map field is always not nullable + false, ); Ok(ArrowSchemaOrFieldOrType::Type(DataType::Map( @@ -561,7 +566,7 @@ impl SchemaVisitor for ToArrowSchemaConverter { Ok(ArrowSchemaOrFieldOrType::Type(DataType::Date32)) } crate::spec::PrimitiveType::Time => Ok(ArrowSchemaOrFieldOrType::Type( - DataType::Time32(TimeUnit::Microsecond), + DataType::Time64(TimeUnit::Microsecond), )), crate::spec::PrimitiveType::Timestamp => Ok(ArrowSchemaOrFieldOrType::Type( DataType::Timestamp(TimeUnit::Microsecond, None), @@ -605,6 +610,7 @@ pub(crate) fn get_arrow_datum(datum: &Datum) -> Result Ok(Box::new(Int64Array::new_scalar(*value))), PrimitiveLiteral::Float(value) => Ok(Box::new(Float32Array::new_scalar(value.as_f32()))), PrimitiveLiteral::Double(value) => Ok(Box::new(Float64Array::new_scalar(value.as_f64()))), + PrimitiveLiteral::String(value) => Ok(Box::new(StringArray::new_scalar(value.as_str()))), l => Err(Error::new( ErrorKind::FeatureUnsupported, format!( @@ -657,10 +663,9 @@ mod tests { let r#struct = DataType::Struct(fields); let map = DataType::Map( Arc::new( - Field::new("entries", r#struct, false).with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "19".to_string(), - )])), + Field::new(DEFAULT_MAP_FIELD_NAME, r#struct, false).with_metadata(HashMap::from([ + (PARQUET_FIELD_ID_META_KEY.to_string(), "19".to_string()), + ])), ), false, ); @@ -1022,7 +1027,10 @@ mod tests { ]); let r#struct = DataType::Struct(fields); - let map = DataType::Map(Arc::new(Field::new("entries", r#struct, false)), false); + let map = DataType::Map( + Arc::new(Field::new(DEFAULT_MAP_FIELD_NAME, r#struct, false)), + false, + ); let fields = Fields::from(vec![ Field::new("aa", DataType::Int32, false).with_metadata(HashMap::from([( @@ -1086,7 +1094,7 @@ mod tests { PARQUET_FIELD_ID_META_KEY.to_string(), "8".to_string(), )])), - Field::new("i", DataType::Time32(TimeUnit::Microsecond), false).with_metadata( + Field::new("i", DataType::Time64(TimeUnit::Microsecond), false).with_metadata( HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "9".to_string())]), ), Field::new( diff --git a/crates/iceberg/src/avro/mod.rs b/crates/iceberg/src/avro/mod.rs index bdccb2ff4..f2a9310e7 100644 --- a/crates/iceberg/src/avro/mod.rs +++ b/crates/iceberg/src/avro/mod.rs @@ -16,6 +16,5 @@ // under the License. //! Avro related codes. 
-#[allow(dead_code)] mod schema; pub(crate) use schema::*; diff --git a/crates/iceberg/src/avro/schema.rs b/crates/iceberg/src/avro/schema.rs index 636f1283c..11d000cc5 100644 --- a/crates/iceberg/src/avro/schema.rs +++ b/crates/iceberg/src/avro/schema.rs @@ -19,10 +19,10 @@ use std::collections::BTreeMap; use crate::spec::{ - visit_schema, ListType, MapType, NestedField, NestedFieldRef, PrimitiveType, Schema, - SchemaVisitor, StructType, Type, + visit_schema, ListType, MapType, NestedFieldRef, PrimitiveType, Schema, SchemaVisitor, + StructType, }; -use crate::{ensure_data_valid, Error, ErrorKind, Result}; +use crate::{Error, ErrorKind, Result}; use apache_avro::schema::{ DecimalSchema, FixedSchema, Name, RecordField as AvroRecordField, RecordFieldOrder, RecordSchema, UnionSchema, @@ -272,261 +272,264 @@ fn avro_optional(avro_schema: AvroSchema) -> Result { ])?)) } -fn is_avro_optional(avro_schema: &AvroSchema) -> bool { - match avro_schema { - AvroSchema::Union(union) => union.is_nullable(), - _ => false, - } -} +#[cfg(test)] +mod tests { + use super::*; + use crate::ensure_data_valid; + use crate::spec::{ListType, MapType, NestedField, PrimitiveType, Schema, StructType, Type}; + use apache_avro::schema::{Namespace, UnionSchema}; + use apache_avro::Schema as AvroSchema; + use std::fs::read_to_string; -/// Post order avro schema visitor. -pub(crate) trait AvroSchemaVisitor { - type T; + fn is_avro_optional(avro_schema: &AvroSchema) -> bool { + match avro_schema { + AvroSchema::Union(union) => union.is_nullable(), + _ => false, + } + } - fn record(&mut self, record: &RecordSchema, fields: Vec) -> Result; + /// Post order avro schema visitor. + pub(crate) trait AvroSchemaVisitor { + type T; - fn union(&mut self, union: &UnionSchema, options: Vec) -> Result; + fn record(&mut self, record: &RecordSchema, fields: Vec) -> Result; - fn array(&mut self, array: &AvroSchema, item: Self::T) -> Result; - fn map(&mut self, map: &AvroSchema, value: Self::T) -> Result; + fn union(&mut self, union: &UnionSchema, options: Vec) -> Result; - fn primitive(&mut self, schema: &AvroSchema) -> Result; -} + fn array(&mut self, array: &AvroSchema, item: Self::T) -> Result; + fn map(&mut self, map: &AvroSchema, value: Self::T) -> Result; -/// Visit avro schema in post order visitor. 
-pub(crate) fn visit(schema: &AvroSchema, visitor: &mut V) -> Result { - match schema { - AvroSchema::Record(record) => { - let field_results = record - .fields - .iter() - .map(|f| visit(&f.schema, visitor)) - .collect::>>()?; - - visitor.record(record, field_results) - } - AvroSchema::Union(union) => { - let option_results = union - .variants() - .iter() - .map(|f| visit(f, visitor)) - .collect::>>()?; - - visitor.union(union, option_results) - } - AvroSchema::Array(item) => { - let item_result = visit(item, visitor)?; - visitor.array(schema, item_result) - } - AvroSchema::Map(inner) => { - let item_result = visit(inner, visitor)?; - visitor.map(schema, item_result) - } - schema => visitor.primitive(schema), + fn primitive(&mut self, schema: &AvroSchema) -> Result; } -} -struct AvroSchemaToSchema { - next_id: i32, -} + struct AvroSchemaToSchema { + next_id: i32, + } -impl AvroSchemaToSchema { - fn next_field_id(&mut self) -> i32 { - self.next_id += 1; - self.next_id + impl AvroSchemaToSchema { + fn next_field_id(&mut self) -> i32 { + self.next_id += 1; + self.next_id + } } -} -impl AvroSchemaVisitor for AvroSchemaToSchema { - // Only `AvroSchema::Null` will return `None` - type T = Option; + impl AvroSchemaVisitor for AvroSchemaToSchema { + // Only `AvroSchema::Null` will return `None` + type T = Option; + + fn record( + &mut self, + record: &RecordSchema, + field_types: Vec>, + ) -> Result> { + let mut fields = Vec::with_capacity(field_types.len()); + for (avro_field, typ) in record.fields.iter().zip_eq(field_types) { + let field_id = avro_field + .custom_attributes + .get(FILED_ID_PROP) + .and_then(Value::as_i64) + .ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Can't convert field, missing field id: {avro_field:?}"), + ) + })?; - fn record( - &mut self, - record: &RecordSchema, - field_types: Vec>, - ) -> Result> { - let mut fields = Vec::with_capacity(field_types.len()); - for (avro_field, typ) in record.fields.iter().zip_eq(field_types) { - let field_id = avro_field - .custom_attributes - .get(FILED_ID_PROP) - .and_then(Value::as_i64) - .ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - format!("Can't convert field, missing field id: {avro_field:?}"), - ) - })?; + let optional = is_avro_optional(&avro_field.schema); - let optional = is_avro_optional(&avro_field.schema); + let mut field = if optional { + NestedField::optional(field_id as i32, &avro_field.name, typ.unwrap()) + } else { + NestedField::required(field_id as i32, &avro_field.name, typ.unwrap()) + }; - let mut field = if optional { - NestedField::optional(field_id as i32, &avro_field.name, typ.unwrap()) - } else { - NestedField::required(field_id as i32, &avro_field.name, typ.unwrap()) - }; + if let Some(doc) = &avro_field.doc { + field = field.with_doc(doc); + } - if let Some(doc) = &avro_field.doc { - field = field.with_doc(doc); + fields.push(field.into()); } - fields.push(field.into()); + Ok(Some(Type::Struct(StructType::new(fields)))) } - Ok(Some(Type::Struct(StructType::new(fields)))) - } - - fn union( - &mut self, - union: &UnionSchema, - mut options: Vec>, - ) -> Result> { - ensure_data_valid!( - options.len() <= 2 && !options.is_empty(), - "Can't convert avro union type {:?} to iceberg.", - union - ); - - if options.len() > 1 { + fn union( + &mut self, + union: &UnionSchema, + mut options: Vec>, + ) -> Result> { ensure_data_valid!( - options[0].is_none(), + options.len() <= 2 && !options.is_empty(), "Can't convert avro union type {:?} to iceberg.", union ); - } - if options.len() == 1 { 
- Ok(Some(options.remove(0).unwrap())) - } else { - Ok(Some(options.remove(1).unwrap())) - } - } + if options.len() > 1 { + ensure_data_valid!( + options[0].is_none(), + "Can't convert avro union type {:?} to iceberg.", + union + ); + } - fn array(&mut self, array: &AvroSchema, item: Option) -> Result { - if let AvroSchema::Array(item_schema) = array { - let element_field = NestedField::list_element( - self.next_field_id(), - item.unwrap(), - !is_avro_optional(item_schema), - ) - .into(); - Ok(Some(Type::List(ListType { element_field }))) - } else { - Err(Error::new( - ErrorKind::Unexpected, - "Expected avro array schema, but {array}", - )) + if options.len() == 1 { + Ok(Some(options.remove(0).unwrap())) + } else { + Ok(Some(options.remove(1).unwrap())) + } } - } - fn map(&mut self, map: &AvroSchema, value: Option) -> Result> { - if let AvroSchema::Map(value_schema) = map { - // Due to avro rust implementation's limitation, we can't store attributes in map schema, - // we will fix it later when it has been resolved. - let key_field = NestedField::map_key_element( - self.next_field_id(), - Type::Primitive(PrimitiveType::String), - ); - let value_field = NestedField::map_value_element( - self.next_field_id(), - value.unwrap(), - !is_avro_optional(value_schema), - ); - Ok(Some(Type::Map(MapType { - key_field: key_field.into(), - value_field: value_field.into(), - }))) - } else { - Err(Error::new( - ErrorKind::Unexpected, - "Expected avro map schema, but {map}", - )) + fn array(&mut self, array: &AvroSchema, item: Option) -> Result { + if let AvroSchema::Array(item_schema) = array { + let element_field = NestedField::list_element( + self.next_field_id(), + item.unwrap(), + !is_avro_optional(item_schema), + ) + .into(); + Ok(Some(Type::List(ListType { element_field }))) + } else { + Err(Error::new( + ErrorKind::Unexpected, + "Expected avro array schema, but {array}", + )) + } } - } - fn primitive(&mut self, schema: &AvroSchema) -> Result> { - let typ = match schema { - AvroSchema::Decimal(decimal) => { - Type::decimal(decimal.precision as u32, decimal.scale as u32)? + fn map(&mut self, map: &AvroSchema, value: Option) -> Result> { + if let AvroSchema::Map(value_schema) = map { + // Due to avro rust implementation's limitation, we can't store attributes in map schema, + // we will fix it later when it has been resolved. 
+ let key_field = NestedField::map_key_element( + self.next_field_id(), + Type::Primitive(PrimitiveType::String), + ); + let value_field = NestedField::map_value_element( + self.next_field_id(), + value.unwrap(), + !is_avro_optional(value_schema), + ); + Ok(Some(Type::Map(MapType { + key_field: key_field.into(), + value_field: value_field.into(), + }))) + } else { + Err(Error::new( + ErrorKind::Unexpected, + "Expected avro map schema, but {map}", + )) } - AvroSchema::Date => Type::Primitive(PrimitiveType::Date), - AvroSchema::TimeMicros => Type::Primitive(PrimitiveType::Time), - AvroSchema::TimestampMicros => Type::Primitive(PrimitiveType::Timestamp), - AvroSchema::Boolean => Type::Primitive(PrimitiveType::Boolean), - AvroSchema::Int => Type::Primitive(PrimitiveType::Int), - AvroSchema::Long => Type::Primitive(PrimitiveType::Long), - AvroSchema::Float => Type::Primitive(PrimitiveType::Float), - AvroSchema::Double => Type::Primitive(PrimitiveType::Double), - AvroSchema::String | AvroSchema::Enum(_) => Type::Primitive(PrimitiveType::String), - AvroSchema::Fixed(fixed) => { - if let Some(logical_type) = fixed.attributes.get(LOGICAL_TYPE) { - let logical_type = logical_type.as_str().ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - "logicalType in attributes of avro schema is not a string type", - ) - })?; - match logical_type { - UUID_LOGICAL_TYPE => Type::Primitive(PrimitiveType::Uuid), - ty => { - return Err(Error::new( - ErrorKind::FeatureUnsupported, - format!( + } + + fn primitive(&mut self, schema: &AvroSchema) -> Result> { + let typ = match schema { + AvroSchema::Decimal(decimal) => { + Type::decimal(decimal.precision as u32, decimal.scale as u32)? + } + AvroSchema::Date => Type::Primitive(PrimitiveType::Date), + AvroSchema::TimeMicros => Type::Primitive(PrimitiveType::Time), + AvroSchema::TimestampMicros => Type::Primitive(PrimitiveType::Timestamp), + AvroSchema::Boolean => Type::Primitive(PrimitiveType::Boolean), + AvroSchema::Int => Type::Primitive(PrimitiveType::Int), + AvroSchema::Long => Type::Primitive(PrimitiveType::Long), + AvroSchema::Float => Type::Primitive(PrimitiveType::Float), + AvroSchema::Double => Type::Primitive(PrimitiveType::Double), + AvroSchema::String | AvroSchema::Enum(_) => Type::Primitive(PrimitiveType::String), + AvroSchema::Fixed(fixed) => { + if let Some(logical_type) = fixed.attributes.get(LOGICAL_TYPE) { + let logical_type = logical_type.as_str().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + "logicalType in attributes of avro schema is not a string type", + ) + })?; + match logical_type { + UUID_LOGICAL_TYPE => Type::Primitive(PrimitiveType::Uuid), + ty => { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + format!( "Logical type {ty} is not support in iceberg primitive type.", ), - )) + )) + } } + } else { + Type::Primitive(PrimitiveType::Fixed(fixed.size as u64)) } - } else { - Type::Primitive(PrimitiveType::Fixed(fixed.size as u64)) } + AvroSchema::Bytes => Type::Primitive(PrimitiveType::Binary), + AvroSchema::Null => return Ok(None), + _ => { + return Err(Error::new( + ErrorKind::Unexpected, + "Unable to convert avro {schema} to iceberg primitive type.", + )) + } + }; + + Ok(Some(typ)) + } + } + + /// Visit avro schema in post order visitor. 
+ pub(crate) fn visit( + schema: &AvroSchema, + visitor: &mut V, + ) -> Result { + match schema { + AvroSchema::Record(record) => { + let field_results = record + .fields + .iter() + .map(|f| visit(&f.schema, visitor)) + .collect::>>()?; + + visitor.record(record, field_results) + } + AvroSchema::Union(union) => { + let option_results = union + .variants() + .iter() + .map(|f| visit(f, visitor)) + .collect::>>()?; + + visitor.union(union, option_results) + } + AvroSchema::Array(item) => { + let item_result = visit(item, visitor)?; + visitor.array(schema, item_result) + } + AvroSchema::Map(inner) => { + let item_result = visit(inner, visitor)?; + visitor.map(schema, item_result) } - AvroSchema::Bytes => Type::Primitive(PrimitiveType::Binary), - AvroSchema::Null => return Ok(None), - _ => { - return Err(Error::new( + schema => visitor.primitive(schema), + } + } + /// Converts avro schema to iceberg schema. + pub(crate) fn avro_schema_to_schema(avro_schema: &AvroSchema) -> Result { + if let AvroSchema::Record(_) = avro_schema { + let mut converter = AvroSchemaToSchema { next_id: 0 }; + let typ = + visit(avro_schema, &mut converter)?.expect("Iceberg schema should not be none."); + if let Type::Struct(s) = typ { + Schema::builder() + .with_fields(s.fields().iter().cloned()) + .build() + } else { + Err(Error::new( ErrorKind::Unexpected, - "Unable to convert avro {schema} to iceberg primitive type.", + format!("Expected to convert avro record schema to struct type, but {typ}"), )) } - }; - - Ok(Some(typ)) - } -} - -/// Converts avro schema to iceberg schema. -pub(crate) fn avro_schema_to_schema(avro_schema: &AvroSchema) -> Result { - if let AvroSchema::Record(_) = avro_schema { - let mut converter = AvroSchemaToSchema { next_id: 0 }; - let typ = visit(avro_schema, &mut converter)?.expect("Iceberg schema should not be none."); - if let Type::Struct(s) = typ { - Schema::builder() - .with_fields(s.fields().iter().cloned()) - .build() } else { Err(Error::new( - ErrorKind::Unexpected, - format!("Expected to convert avro record schema to struct type, but {typ}"), + ErrorKind::DataInvalid, + "Can't convert non record avro schema to iceberg schema: {avro_schema}", )) } - } else { - Err(Error::new( - ErrorKind::DataInvalid, - "Can't convert non record avro schema to iceberg schema: {avro_schema}", - )) } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::avro::schema::AvroSchemaToSchema; - use crate::spec::{ListType, MapType, NestedField, PrimitiveType, Schema, StructType, Type}; - use apache_avro::schema::{Namespace, UnionSchema}; - use apache_avro::Schema as AvroSchema; - use std::fs::read_to_string; fn read_test_data_file_to_avro_schema(filename: &str) -> AvroSchema { let input = read_to_string(format!( diff --git a/crates/iceberg/src/catalog/mod.rs b/crates/iceberg/src/catalog/mod.rs index e2ea150ce..4d89aa9b4 100644 --- a/crates/iceberg/src/catalog/mod.rs +++ b/crates/iceberg/src/catalog/mod.rs @@ -30,7 +30,6 @@ use std::fmt::Debug; use std::mem::take; use std::ops::Deref; use typed_builder::TypedBuilder; -use urlencoding::encode; use uuid::Uuid; /// The catalog API for Iceberg Rust. @@ -123,9 +122,9 @@ impl NamespaceIdent { Self::from_vec(iter.into_iter().map(|s| s.to_string()).collect()) } - /// Returns url encoded format. - pub fn encode_in_url(&self) -> String { - encode(&self.as_ref().join("\u{1F}")).to_string() + /// Returns a string for used in url. + pub fn to_url_string(&self) -> String { + self.as_ref().join("\u{001f}") } /// Returns inner strings. 
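[Editor's sketch] A small sketch of the renamed `to_url_string` behaviour above, assuming a two-level namespace: levels are simply joined with the ASCII unit separator (U+001F), and any percent-encoding is now left to the HTTP layer that builds the request URL (hence the dropped `urlencoding` import). The free function below stands in for the method:

    // Illustrative only: stands in for `NamespaceIdent::to_url_string`,
    // which joins namespace levels with the unit separator U+001F.
    fn to_url_string(levels: &[&str]) -> String {
        levels.join("\u{001f}")
    }

    fn main() {
        assert_eq!(to_url_string(&["accounting", "tax"]), "accounting\u{001f}tax");
    }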
diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index 3d77c4df8..16f75b090 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -18,6 +18,7 @@ //! This module contains expressions. mod term; +use serde::{Deserialize, Serialize}; pub use term::*; pub(crate) mod accessor; mod predicate; @@ -32,7 +33,7 @@ use std::fmt::{Display, Formatter}; /// The discriminant of this enum is used for determining the type of the operator, see /// [`PredicateOperator::is_unary`], [`PredicateOperator::is_binary`], [`PredicateOperator::is_set`] #[allow(missing_docs)] -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] #[non_exhaustive] #[repr(u16)] pub enum PredicateOperator { diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 158ab135b..3a91d6bbc 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -25,6 +25,7 @@ use std::ops::Not; use array_init::array_init; use fnv::FnvHashSet; use itertools::Itertools; +use serde::{Deserialize, Serialize}; use crate::error::Result; use crate::expr::{Bind, BoundReference, PredicateOperator, Reference}; @@ -37,6 +38,29 @@ pub struct LogicalExpression { inputs: [Box; N], } +impl Serialize for LogicalExpression { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + self.inputs.serialize(serializer) + } +} + +impl<'de, T: Deserialize<'de>, const N: usize> Deserialize<'de> for LogicalExpression { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let inputs = Vec::>::deserialize(deserializer)?; + Ok(LogicalExpression::new( + array_init::from_iter(inputs.into_iter()).ok_or_else(|| { + serde::de::Error::custom(format!("Failed to deserialize LogicalExpression: the len of inputs is not match with the len of LogicalExpression {}",N)) + })?, + )) + } +} + impl Debug for LogicalExpression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("LogicalExpression") @@ -79,11 +103,12 @@ where } /// Unary predicate, for example, `a IS NULL`. -#[derive(PartialEq, Clone)] +#[derive(PartialEq, Clone, Serialize, Deserialize)] pub struct UnaryExpression { /// Operator of this predicate, must be single operand operator. op: PredicateOperator, /// Term of this predicate, for example, `a` in `a IS NULL`. + #[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))] term: T, } @@ -129,11 +154,12 @@ impl UnaryExpression { } /// Binary predicate, for example, `a > 10`. -#[derive(PartialEq, Clone)] +#[derive(PartialEq, Clone, Serialize, Deserialize)] pub struct BinaryExpression { /// Operator of this predicate, must be binary operator, such as `=`, `>`, `<`, etc. op: PredicateOperator, /// Term of this predicate, for example, `a` in `a > 10`. + #[serde(bound(serialize = "T: Serialize", deserialize = "T: Deserialize<'de>"))] term: T, /// Literal of this predicate, for example, `10` in `a > 10`. literal: Datum, @@ -190,7 +216,7 @@ impl Bind for BinaryExpression { } /// Set predicates, for example, `a in (1, 2, 3)`. -#[derive(PartialEq, Clone)] +#[derive(PartialEq, Clone, Serialize, Deserialize)] pub struct SetExpression { /// Operator of this predicate, must be set operator, such as `IN`, `NOT IN`, etc. op: PredicateOperator, @@ -253,7 +279,7 @@ impl Display for SetExpression { } /// Unbound predicate expression before binding to a schema. 
-#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] pub enum Predicate { /// AlwaysTrue predicate, for example, `TRUE`. AlwaysTrue, @@ -622,7 +648,7 @@ impl Not for Predicate { } /// Bound predicate expression after binding to a schema. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub enum BoundPredicate { /// An expression always evaluates to true. AlwaysTrue, @@ -678,9 +704,9 @@ mod tests { use std::ops::Not; use std::sync::Arc; - use crate::expr::Bind; use crate::expr::Predicate::{AlwaysFalse, AlwaysTrue}; use crate::expr::Reference; + use crate::expr::{Bind, BoundPredicate}; use crate::spec::Datum; use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; @@ -879,12 +905,19 @@ mod tests { ) } + fn test_bound_predicate_serialize_diserialize(bound_predicate: BoundPredicate) { + let serialized = serde_json::to_string(&bound_predicate).unwrap(); + let deserialized: BoundPredicate = serde_json::from_str(&serialized).unwrap(); + assert_eq!(bound_predicate, deserialized); + } + #[test] fn test_bind_is_null() { let schema = table_schema_simple(); let expr = Reference::new("foo").is_null(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "foo IS NULL"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -893,6 +926,7 @@ mod tests { let expr = Reference::new("bar").is_null(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -901,6 +935,7 @@ mod tests { let expr = Reference::new("foo").is_not_null(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "foo IS NOT NULL"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -909,6 +944,7 @@ mod tests { let expr = Reference::new("bar").is_not_null(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "True"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -922,6 +958,7 @@ mod tests { let expr_string = Reference::new("foo").is_nan(); let bound_expr_string = expr_string.bind(schema_string, true); assert!(bound_expr_string.is_err()); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -938,6 +975,7 @@ mod tests { let expr = Reference::new("qux").is_not_nan(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "qux IS NOT NAN"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -954,6 +992,7 @@ mod tests { let expr = Reference::new("bar").less_than(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar < 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -970,6 +1009,7 @@ mod tests { let expr = Reference::new("bar").less_than_or_equal_to(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar <= 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -986,6 +1026,7 @@ mod tests { let expr = Reference::new("bar").greater_than(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar > 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1002,6 +1043,7 @@ mod tests { let expr = Reference::new("bar").greater_than_or_equal_to(Datum::int(10)); let bound_expr = 
expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar >= 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1018,6 +1060,7 @@ mod tests { let expr = Reference::new("bar").equal_to(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar = 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1034,6 +1077,7 @@ mod tests { let expr = Reference::new("bar").not_equal_to(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar != 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1050,6 +1094,7 @@ mod tests { let expr = Reference::new("foo").starts_with(Datum::string("abcd")); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), r#"foo STARTS WITH "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1066,6 +1111,7 @@ mod tests { let expr = Reference::new("foo").not_starts_with(Datum::string("abcd")); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), r#"foo NOT STARTS WITH "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1082,6 +1128,7 @@ mod tests { let expr = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar IN (20, 10)"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1090,6 +1137,7 @@ mod tests { let expr = Reference::new("bar").is_in(vec![]); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1098,6 +1146,7 @@ mod tests { let expr = Reference::new("bar").is_in(vec![Datum::int(10)]); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar = 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1114,6 +1163,7 @@ mod tests { let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar NOT IN (20, 10)"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1122,6 +1172,7 @@ mod tests { let expr = Reference::new("bar").is_not_in(vec![]); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "True"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1130,6 +1181,7 @@ mod tests { let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "bar != 10"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1148,6 +1200,7 @@ mod tests { .and(Reference::new("foo").is_null()); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "(bar < 10) AND (foo IS NULL)"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1158,6 +1211,7 @@ mod tests { .and(Reference::new("bar").is_null()); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1168,6 +1222,7 @@ mod tests { .and(Reference::new("bar").is_not_null()); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), r#"foo < 
"abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1178,6 +1233,7 @@ mod tests { .or(Reference::new("foo").is_null()); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "(bar < 10) OR (foo IS NULL)"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1188,6 +1244,7 @@ mod tests { .or(Reference::new("bar").is_not_null()); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "True"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1198,6 +1255,7 @@ mod tests { .or(Reference::new("bar").is_null()); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1206,6 +1264,7 @@ mod tests { let expr = !Reference::new("bar").less_than(Datum::int(10)); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1214,6 +1273,7 @@ mod tests { let expr = !Reference::new("bar").is_not_null(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), "False"); + test_bound_predicate_serialize_diserialize(bound_expr); } #[test] @@ -1222,5 +1282,6 @@ mod tests { let expr = !Reference::new("bar").is_null(); let bound_expr = expr.bind(schema, true).unwrap(); assert_eq!(&format!("{bound_expr}"), r#"True"#); + test_bound_predicate_serialize_diserialize(bound_expr); } } diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs index 1fbf86c50..909aa62bc 100644 --- a/crates/iceberg/src/expr/term.rs +++ b/crates/iceberg/src/expr/term.rs @@ -20,6 +20,7 @@ use std::fmt::{Display, Formatter}; use fnv::FnvHashSet; +use serde::{Deserialize, Serialize}; use crate::expr::accessor::{StructAccessor, StructAccessorRef}; use crate::expr::Bind; @@ -32,7 +33,7 @@ pub type Term = Reference; /// A named reference in an unbound expression. /// For example, `a` in `a > 10`. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct Reference { name: String, } @@ -351,7 +352,7 @@ impl Bind for Reference { } /// A named reference in a bound expression after binding to a schema. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct BoundReference { // This maybe different from [`name`] filed in [`NestedField`] since this contains full path. // For example, if the field is `a.b.c`, then `field.name` is `c`, but `original_name` is `a.b.c`. diff --git a/crates/iceberg/src/expr/visitors/expression_evaluator.rs b/crates/iceberg/src/expr/visitors/expression_evaluator.rs index b69d093bd..81e91f3ee 100644 --- a/crates/iceberg/src/expr/visitors/expression_evaluator.rs +++ b/crates/iceberg/src/expr/visitors/expression_evaluator.rs @@ -47,7 +47,7 @@ impl ExpressionEvaluator { /// to see if this [`DataFile`] could possibly contain data that matches /// the scan's filter. pub(crate) fn eval(&self, data_file: &DataFile) -> Result { - let mut visitor = ExpressionEvaluatorVisitor::new(self, data_file.partition()); + let mut visitor = ExpressionEvaluatorVisitor::new(data_file.partition()); visit(&mut visitor, &self.partition_filter) } @@ -58,19 +58,14 @@ impl ExpressionEvaluator { /// specifically for data file partitions. 
#[derive(Debug)] struct ExpressionEvaluatorVisitor<'a> { - /// Reference to an [`ExpressionEvaluator`]. - expression_evaluator: &'a ExpressionEvaluator, /// Reference to a [`DataFile`]'s partition [`Struct`]. partition: &'a Struct, } impl<'a> ExpressionEvaluatorVisitor<'a> { /// Creates a new [`ExpressionEvaluatorVisitor`]. - fn new(expression_evaluator: &'a ExpressionEvaluator, partition: &'a Struct) -> Self { - Self { - expression_evaluator, - partition, - } + fn new(partition: &'a Struct) -> Self { + Self { partition } } } diff --git a/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs b/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs index 5f73f2f84..8d45fa29d 100644 --- a/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs +++ b/crates/iceberg/src/expr/visitors/inclusive_metrics_evaluator.rs @@ -1206,64 +1206,6 @@ mod test { assert!(result, "Should read: id above upper bound"); } - fn test_case_insensitive_integer_not_eq_rewritten() { - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MIN_VALUE - 25), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id below lower bound"); - - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MIN_VALUE - 1), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id below lower bound"); - - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MIN_VALUE), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id equal to lower bound"); - - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MAX_VALUE - 4), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id between lower and upper bound"); - - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MAX_VALUE), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id equal to upper bound"); - - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MAX_VALUE + 1), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id above upper bound"); - - let result = InclusiveMetricsEvaluator::eval( - &equal_int_not_case_insensitive("ID", INT_MAX_VALUE + 6), - &get_test_file_1(), - true, - ) - .unwrap(); - assert!(result, "Should read: id above upper bound"); - } - #[test] #[should_panic] fn test_case_sensitive_integer_not_eq_rewritten() { @@ -1882,17 +1824,6 @@ mod test { filter.bind(schema.clone(), true).unwrap() } - fn equal_int_not_case_insensitive(reference: &str, int_literal: i32) -> BoundPredicate { - let schema = create_test_schema(); - let filter = Predicate::Binary(BinaryExpression::new( - Eq, - Reference::new(reference), - Datum::int(int_literal), - )) - .not(); - filter.bind(schema.clone(), false).unwrap() - } - fn not_equal_int(reference: &str, int_literal: i32) -> BoundPredicate { let schema = create_test_schema(); let filter = Predicate::Binary(BinaryExpression::new( diff --git a/crates/iceberg/src/expr/visitors/manifest_evaluator.rs b/crates/iceberg/src/expr/visitors/manifest_evaluator.rs index c1b6dbed9..eb770ea2c 100644 --- a/crates/iceberg/src/expr/visitors/manifest_evaluator.rs +++ b/crates/iceberg/src/expr/visitors/manifest_evaluator.rs @@ -17,8 +17,9 @@ use crate::expr::visitors::bound_predicate_visitor::{visit, BoundPredicateVisitor}; use crate::expr::{BoundPredicate, 
BoundReference}; -use crate::spec::{Datum, FieldSummary, ManifestFile}; +use crate::spec::{Datum, FieldSummary, ManifestFile, PrimitiveLiteral, Type}; use crate::Result; +use crate::{Error, ErrorKind}; use fnv::FnvHashSet; /// Evaluates a [`ManifestFile`] to see if the partition summaries @@ -45,37 +46,35 @@ impl ManifestEvaluator { return Ok(true); } - let mut evaluator = ManifestFilterVisitor::new(self, &manifest_file.partitions); + let mut evaluator = ManifestFilterVisitor::new(&manifest_file.partitions); visit(&mut evaluator, &self.partition_filter) } } struct ManifestFilterVisitor<'a> { - manifest_evaluator: &'a ManifestEvaluator, partitions: &'a Vec, } impl<'a> ManifestFilterVisitor<'a> { - fn new(manifest_evaluator: &'a ManifestEvaluator, partitions: &'a Vec) -> Self { - ManifestFilterVisitor { - manifest_evaluator, - partitions, - } + fn new(partitions: &'a Vec) -> Self { + ManifestFilterVisitor { partitions } } } -// Remove this annotation once all todos have been removed -#[allow(unused_variables)] +const ROWS_MIGHT_MATCH: Result = Ok(true); +const ROWS_CANNOT_MATCH: Result = Ok(false); +const IN_PREDICATE_LIMIT: usize = 200; + impl BoundPredicateVisitor for ManifestFilterVisitor<'_> { type T = bool; fn always_true(&mut self) -> crate::Result { - Ok(true) + ROWS_MIGHT_MATCH } fn always_false(&mut self) -> crate::Result { - Ok(false) + ROWS_CANNOT_MATCH } fn and(&mut self, lhs: bool, rhs: bool) -> crate::Result { @@ -103,7 +102,15 @@ impl BoundPredicateVisitor for ManifestFilterVisitor<'_> { reference: &BoundReference, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + + // contains_null encodes whether at least one partition value is null, + // lowerBound is null if all partition values are null + if ManifestFilterVisitor::are_all_null(field, &reference.field().field_type) { + ROWS_CANNOT_MATCH + } else { + ROWS_MIGHT_MATCH + } } fn is_nan( @@ -111,10 +118,18 @@ impl BoundPredicateVisitor for ManifestFilterVisitor<'_> { reference: &BoundReference, _predicate: &BoundPredicate, ) -> crate::Result { - Ok(self - .field_summary_for_reference(reference) - .contains_nan - .is_some()) + let field = self.field_summary_for_reference(reference); + if let Some(contains_nan) = field.contains_nan { + if !contains_nan { + return ROWS_CANNOT_MATCH; + } + } + + if ManifestFilterVisitor::are_all_null(field, &reference.field().field_type) { + return ROWS_CANNOT_MATCH; + } + + ROWS_MIGHT_MATCH } fn not_nan( @@ -122,79 +137,210 @@ impl BoundPredicateVisitor for ManifestFilterVisitor<'_> { reference: &BoundReference, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + if let Some(contains_nan) = field.contains_nan { + // check if all values are nan + if contains_nan && !field.contains_null && field.lower_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + } + ROWS_MIGHT_MATCH } fn less_than( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + match &field.lower_bound { + Some(bound) if datum <= bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } } fn less_than_or_eq( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + match &field.lower_bound 
{ + Some(bound) if datum < bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } } fn greater_than( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + match &field.upper_bound { + Some(bound) if datum >= bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } } fn greater_than_or_eq( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + match &field.upper_bound { + Some(bound) if datum > bound => ROWS_CANNOT_MATCH, + Some(_) => ROWS_MIGHT_MATCH, + None => ROWS_CANNOT_MATCH, + } } fn eq( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + + if field.lower_bound.is_none() || field.upper_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + + if let Some(lower_bound) = &field.lower_bound { + if lower_bound > datum { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = &field.upper_bound { + if upper_bound < datum { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH } fn not_eq( &mut self, - reference: &BoundReference, - literal: &Datum, + _reference: &BoundReference, + _datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + // because the bounds are not necessarily a min or max value, this cannot be answered using + // them. notEq(col, X) with (X, Y) doesn't guarantee that X is a value in col. + ROWS_MIGHT_MATCH } fn starts_with( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + + if field.lower_bound.is_none() || field.upper_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + + let prefix = ManifestFilterVisitor::datum_as_str( + datum, + "Cannot perform starts_with on non-string value", + )?; + let prefix_len = prefix.len(); + + if let Some(lower_bound) = &field.lower_bound { + let lower_bound_str = ManifestFilterVisitor::datum_as_str( + lower_bound, + "Cannot perform starts_with on non-string lower bound", + )?; + let min_len = lower_bound_str.len().min(prefix_len); + if prefix.as_bytes().lt(&lower_bound_str.as_bytes()[..min_len]) { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = &field.upper_bound { + let upper_bound_str = ManifestFilterVisitor::datum_as_str( + upper_bound, + "Cannot perform starts_with on non-string upper bound", + )?; + let min_len = upper_bound_str.len().min(prefix_len); + if prefix.as_bytes().gt(&upper_bound_str.as_bytes()[..min_len]) { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH } fn not_starts_with( &mut self, reference: &BoundReference, - literal: &Datum, + datum: &Datum, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + + if field.contains_null || field.lower_bound.is_none() || field.upper_bound.is_none() { + return ROWS_MIGHT_MATCH; + } + + let prefix = ManifestFilterVisitor::datum_as_str( + datum, + "Cannot perform not_starts_with on non-string value", + )?; + let prefix_len = prefix.len(); + + // not_starts_with will match unless all values must start with the prefix. 
This happens when + // the lower and upper bounds both start with the prefix. + if let Some(lower_bound) = &field.lower_bound { + let lower_bound_str = ManifestFilterVisitor::datum_as_str( + lower_bound, + "Cannot perform not_starts_with on non-string lower bound", + )?; + + // if lower is shorter than the prefix then lower doesn't start with the prefix + if prefix_len > lower_bound_str.len() { + return ROWS_MIGHT_MATCH; + } + + if prefix + .as_bytes() + .eq(&lower_bound_str.as_bytes()[..prefix_len]) + { + if let Some(upper_bound) = &field.upper_bound { + let upper_bound_str = ManifestFilterVisitor::datum_as_str( + upper_bound, + "Cannot perform not_starts_with on non-string upper bound", + )?; + + // if upper is shorter than the prefix then upper can't start with the prefix + if prefix_len > upper_bound_str.len() { + return ROWS_MIGHT_MATCH; + } + + if prefix + .as_bytes() + .eq(&upper_bound_str.as_bytes()[..prefix_len]) + { + return ROWS_CANNOT_MATCH; + } + } + } + } + + ROWS_MIGHT_MATCH } fn r#in( @@ -203,16 +349,39 @@ impl BoundPredicateVisitor for ManifestFilterVisitor<'_> { literals: &FnvHashSet, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + let field = self.field_summary_for_reference(reference); + if field.lower_bound.is_none() { + return ROWS_CANNOT_MATCH; + } + + if literals.len() > IN_PREDICATE_LIMIT { + return ROWS_MIGHT_MATCH; + } + + if let Some(lower_bound) = &field.lower_bound { + if literals.iter().all(|datum| lower_bound > datum) { + return ROWS_CANNOT_MATCH; + } + } + + if let Some(upper_bound) = &field.upper_bound { + if literals.iter().all(|datum| upper_bound < datum) { + return ROWS_CANNOT_MATCH; + } + } + + ROWS_MIGHT_MATCH } fn not_in( &mut self, - reference: &BoundReference, - literals: &FnvHashSet, + _reference: &BoundReference, + _literals: &FnvHashSet, _predicate: &BoundPredicate, ) -> crate::Result { - todo!() + // because the bounds are not necessarily a min or max value, this cannot be answered using + // them. notIn(col, {X, ...}) with (X, Y) doesn't guarantee that X is a value in col. + ROWS_MIGHT_MATCH } } @@ -221,43 +390,222 @@ impl ManifestFilterVisitor<'_> { let pos = reference.accessor().position(); &self.partitions[pos] } + + fn are_all_null(field: &FieldSummary, r#type: &Type) -> bool { + // contains_null encodes whether at least one partition value is null, + // lowerBound is null if all partition values are null + let mut all_null: bool = field.contains_null && field.lower_bound.is_none(); + + if all_null && r#type.is_floating_type() { + // floating point types may include NaN values, which we check separately. + // In case bounds don't include NaN value, contains_nan needs to be checked against. 
+ all_null = match field.contains_nan { + Some(val) => !val, + None => false, + } + } + + all_null + } + + fn datum_as_str<'a>(bound: &'a Datum, err_msg: &str) -> crate::Result<&'a String> { + let PrimitiveLiteral::String(bound) = bound.literal() else { + return Err(Error::new(ErrorKind::Unexpected, err_msg)); + }; + Ok(bound) + } } #[cfg(test)] mod test { - use crate::expr::visitors::inclusive_projection::InclusiveProjection; use crate::expr::visitors::manifest_evaluator::ManifestEvaluator; use crate::expr::{ - Bind, BoundPredicate, Predicate, PredicateOperator, Reference, UnaryExpression, + BinaryExpression, Bind, Predicate, PredicateOperator, Reference, SetExpression, + UnaryExpression, }; use crate::spec::{ - FieldSummary, ManifestContentType, ManifestFile, NestedField, PartitionField, - PartitionSpec, PartitionSpecRef, PrimitiveType, Schema, SchemaRef, Transform, Type, + Datum, FieldSummary, ManifestContentType, ManifestFile, NestedField, PrimitiveType, Schema, + SchemaRef, Type, }; use crate::Result; + use fnv::FnvHashSet; + use std::ops::Not; use std::sync::Arc; - fn create_schema_and_partition_spec() -> Result<(SchemaRef, PartitionSpecRef)> { + const INT_MIN_VALUE: i32 = 30; + const INT_MAX_VALUE: i32 = 79; + + const STRING_MIN_VALUE: &str = "a"; + const STRING_MAX_VALUE: &str = "z"; + + fn create_schema() -> Result { let schema = Schema::builder() - .with_fields(vec![Arc::new(NestedField::optional( - 1, - "a", - Type::Primitive(PrimitiveType::Float), - ))]) + .with_fields(vec![ + Arc::new(NestedField::required( + 1, + "id", + Type::Primitive(PrimitiveType::Int), + )), + Arc::new(NestedField::optional( + 2, + "all_nulls_missing_nan", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 3, + "some_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 4, + "no_nulls", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 5, + "float", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 6, + "all_nulls_double", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 7, + "all_nulls_no_nans", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 8, + "all_nans", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 9, + "both_nan_and_null", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 10, + "no_nan_or_null", + Type::Primitive(PrimitiveType::Double), + )), + Arc::new(NestedField::optional( + 11, + "all_nulls_missing_nan_float", + Type::Primitive(PrimitiveType::Float), + )), + Arc::new(NestedField::optional( + 12, + "all_same_value_or_null", + Type::Primitive(PrimitiveType::String), + )), + Arc::new(NestedField::optional( + 13, + "no_nulls_same_value_a", + Type::Primitive(PrimitiveType::String), + )), + ]) .build()?; - let spec = PartitionSpec::builder() - .with_spec_id(1) - .with_fields(vec![PartitionField::builder() - .source_id(1) - .name("a".to_string()) - .field_id(1) - .transform(Transform::Identity) - .build()]) - .build() - .unwrap(); + Ok(Arc::new(schema)) + } - Ok((Arc::new(schema), Arc::new(spec))) + fn create_partitions() -> Vec { + vec![ + // id + FieldSummary { + contains_null: false, + contains_nan: None, + lower_bound: Some(Datum::int(INT_MIN_VALUE)), + upper_bound: Some(Datum::int(INT_MAX_VALUE)), + }, + // all_nulls_missing_nan + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: None, + 
upper_bound: None, + }, + // some_nulls + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MAX_VALUE)), + }, + // no_nulls + FieldSummary { + contains_null: false, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MAX_VALUE)), + }, + // float + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: Some(Datum::float(0.0)), + upper_bound: Some(Datum::float(20.0)), + }, + // all_nulls_double + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: None, + upper_bound: None, + }, + // all_nulls_no_nans + FieldSummary { + contains_null: true, + contains_nan: Some(false), + lower_bound: None, + upper_bound: None, + }, + // all_nans + FieldSummary { + contains_null: false, + contains_nan: Some(true), + lower_bound: None, + upper_bound: None, + }, + // both_nan_and_null + FieldSummary { + contains_null: true, + contains_nan: Some(true), + lower_bound: None, + upper_bound: None, + }, + // no_nan_or_null + FieldSummary { + contains_null: false, + contains_nan: Some(false), + lower_bound: Some(Datum::float(0.0)), + upper_bound: Some(Datum::float(20.0)), + }, + // all_nulls_missing_nan_float + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: None, + upper_bound: None, + }, + // all_same_value_or_null + FieldSummary { + contains_null: true, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MIN_VALUE)), + }, + // no_nulls_same_value_a + FieldSummary { + contains_null: false, + contains_nan: None, + lower_bound: Some(Datum::string(STRING_MIN_VALUE)), + upper_bound: Some(Datum::string(STRING_MIN_VALUE)), + }, + ] } fn create_manifest_file(partitions: Vec) -> ManifestFile { @@ -280,131 +628,663 @@ mod test { } } - fn create_partition_schema( - partition_spec: &PartitionSpecRef, - schema: &Schema, - ) -> Result { - let partition_type = partition_spec.partition_type(schema)?; + #[test] + fn test_always_true() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::AlwaysTrue.bind(schema.clone(), case_sensitive)?; + + assert!(ManifestEvaluator::new(filter).eval(&manifest_file)?); + + Ok(()) + } + + #[test] + fn test_always_false() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::AlwaysFalse.bind(schema.clone(), case_sensitive)?; + + assert!(!ManifestEvaluator::new(filter).eval(&manifest_file)?); + + Ok(()) + } + + #[test] + fn test_all_nulls() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // all_nulls_missing_nan + let all_nulls_missing_nan_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("all_nulls_missing_nan"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(all_nulls_missing_nan_filter).eval(&manifest_file)?, + "Should skip: all nulls column with non-floating type contains all null" + ); + + // all_nulls_missing_nan_float + let all_nulls_missing_nan_float_filter = 
Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("all_nulls_missing_nan_float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_missing_nan_float_filter).eval(&manifest_file)?, + "Should read: no NaN information may indicate presence of NaN value" + ); - let partition_fields: Vec<_> = partition_type.fields().iter().map(Arc::clone).collect(); + // some_nulls + let some_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("some_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(some_nulls_filter).eval(&manifest_file)?, + "Should read: column with some nulls contains a non-null value" + ); - let partition_schema = Arc::new( - Schema::builder() - .with_schema_id(partition_spec.spec_id) - .with_fields(partition_fields) - .build()?, + // no_nulls + let no_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNull, + Reference::new("no_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + + assert!( + ManifestEvaluator::new(no_nulls_filter).eval(&manifest_file)?, + "Should read: non-null column contains a non-null value" ); - Ok(partition_schema) + Ok(()) } - fn create_partition_filter( - partition_spec: PartitionSpecRef, - partition_schema: SchemaRef, - filter: &BoundPredicate, - case_sensitive: bool, - ) -> Result { - let mut inclusive_projection = InclusiveProjection::new(partition_spec); + #[test] + fn test_no_nulls() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // all_nulls_missing_nan + let all_nulls_missing_nan_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("all_nulls_missing_nan"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_missing_nan_filter).eval(&manifest_file)?, + "Should read: at least one null value in all null column" + ); + + // some_nulls + let some_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("some_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(some_nulls_filter).eval(&manifest_file)?, + "Should read: column with some nulls contains a null value" + ); + + // no_nulls + let no_nulls_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("no_nulls"), + )) + .bind(schema.clone(), case_sensitive)?; - let partition_filter = inclusive_projection - .project(filter)? 
- .rewrite_not() - .bind(partition_schema, case_sensitive)?; + assert!( + !ManifestEvaluator::new(no_nulls_filter).eval(&manifest_file)?, + "Should skip: non-null column contains no null values" + ); - Ok(partition_filter) + // both_nan_and_null + let both_nan_and_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + Reference::new("both_nan_and_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(both_nan_and_null_filter).eval(&manifest_file)?, + "Should read: both_nan_and_null column contains no null values" + ); + + Ok(()) } - fn create_manifest_evaluator( - schema: SchemaRef, - partition_spec: PartitionSpecRef, - filter: &BoundPredicate, - case_sensitive: bool, - ) -> Result { - let partition_schema = create_partition_schema(&partition_spec, &schema)?; - let partition_filter = create_partition_filter( - partition_spec, - partition_schema.clone(), - filter, - case_sensitive, - )?; + #[test] + fn test_is_nan() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // float + let float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(float_filter).eval(&manifest_file)?, + "Should read: no information on if there are nan value in float column" + ); + + // all_nulls_double + let all_nulls_double_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nulls_double"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_double_filter).eval(&manifest_file)?, + "Should read: no NaN information may indicate presence of NaN value" + ); + + // all_nulls_missing_nan_float + let all_nulls_missing_nan_float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nulls_missing_nan_float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_missing_nan_float_filter).eval(&manifest_file)?, + "Should read: no NaN information may indicate presence of NaN value" + ); + + // all_nulls_no_nans + let all_nulls_no_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nulls_no_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(all_nulls_no_nans_filter).eval(&manifest_file)?, + "Should skip: no nan column doesn't contain nan value" + ); + + // all_nans + let all_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("all_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nans_filter).eval(&manifest_file)?, + "Should read: all_nans column contains nan value" + ); + + // both_nan_and_null + let both_nan_and_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("both_nan_and_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(both_nan_and_null_filter).eval(&manifest_file)?, + "Should read: both_nan_and_null column contains nan value" + ); - Ok(ManifestEvaluator::new(partition_filter)) + // no_nan_or_null + let no_nan_or_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::IsNan, + Reference::new("no_nan_or_null"), + )) + .bind(schema.clone(), 
case_sensitive)?; + assert!( + !ManifestEvaluator::new(no_nan_or_null_filter).eval(&manifest_file)?, + "Should skip: no_nan_or_null column doesn't contain nan value" + ); + + Ok(()) } #[test] - fn test_manifest_file_empty_partitions() -> Result<()> { - let case_sensitive = false; + fn test_not_nan() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + // float + let float_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("float"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(float_filter).eval(&manifest_file)?, + "Should read: no information on if there are nan value in float column" + ); - let (schema, partition_spec) = create_schema_and_partition_spec()?; + // all_nulls_double + let all_nulls_double_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("all_nulls_double"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_double_filter).eval(&manifest_file)?, + "Should read: all null column contains non nan value" + ); - let filter = Predicate::AlwaysTrue.bind(schema.clone(), case_sensitive)?; + // all_nulls_no_nans + let all_nulls_no_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("all_nulls_no_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(all_nulls_no_nans_filter).eval(&manifest_file)?, + "Should read: no_nans column contains non nan value" + ); - let manifest_file = create_manifest_file(vec![]); + // all_nans + let all_nans_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("all_nans"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(all_nans_filter).eval(&manifest_file)?, + "Should skip: all nans column doesn't contain non nan value" + ); - let manifest_evaluator = - create_manifest_evaluator(schema, partition_spec, &filter, case_sensitive)?; + // both_nan_and_null + let both_nan_and_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("both_nan_and_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(both_nan_and_null_filter).eval(&manifest_file)?, + "Should read: both_nan_and_null nans column contains non nan value" + ); + + // no_nan_or_null + let no_nan_or_null_filter = Predicate::Unary(UnaryExpression::new( + PredicateOperator::NotNan, + Reference::new("no_nan_or_null"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(no_nan_or_null_filter).eval(&manifest_file)?, + "Should read: no_nan_or_null column contains non nan value" + ); - let result = manifest_evaluator.eval(&manifest_file)?; + Ok(()) + } - assert!(result); + #[test] + fn test_and() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .and(Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 30), + ))) + .bind(schema.clone(), case_sensitive)?; + assert!( + 
!ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: no information on if there are nan value in float column" + ); Ok(()) } #[test] - fn test_manifest_file_trivial_partition_passing_filter() -> Result<()> { + fn test_or() -> Result<()> { let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .or(Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 1), + ))) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should skip: or(false, false)" + ); - let (schema, partition_spec) = create_schema_and_partition_spec()?; + Ok(()) + } - let filter = Predicate::Unary(UnaryExpression::new( - PredicateOperator::IsNull, - Reference::new("a"), + #[test] + fn test_not() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), )) + .not() .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: not(false)" + ); - let manifest_file = create_manifest_file(vec![FieldSummary { - contains_null: true, - contains_nan: None, - lower_bound: None, - upper_bound: None, - }]); + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .not() + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should skip: not(true)" + ); - let manifest_evaluator = - create_manifest_evaluator(schema, partition_spec, &filter, case_sensitive)?; + Ok(()) + } - let result = manifest_evaluator.eval(&manifest_file)?; + #[test] + fn test_less_than() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThan, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range below lower bound (5 < 30)" + ); - assert!(result); + Ok(()) + } + + #[test] + fn test_less_than_or_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::LessThanOrEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range below lower bound (5 < 30)" + ); Ok(()) } #[test] - fn test_manifest_file_trivial_partition_rejected_filter() -> Result<()> { + fn test_greater_than() -> Result<()> { let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = 
create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThan, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 6), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range above upper bound (85 < 79)" + ); - let (schema, partition_spec) = create_schema_and_partition_spec()?; + Ok(()) + } - let filter = Predicate::Unary(UnaryExpression::new( - PredicateOperator::IsNan, - Reference::new("a"), + #[test] + fn test_greater_than_or_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE + 6), )) .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id range above upper bound (85 < 79)" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::GreaterThanOrEq, + Reference::new("id"), + Datum::int(INT_MAX_VALUE), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: one possible id" + ); + + Ok(()) + } + + #[test] + fn test_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id below lower bound" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id equal to lower bound" + ); - let manifest_file = create_manifest_file(vec![FieldSummary { - contains_null: false, - contains_nan: None, - lower_bound: None, - upper_bound: None, - }]); + Ok(()) + } - let manifest_evaluator = - create_manifest_evaluator(schema, partition_spec, &filter, case_sensitive)?; + #[test] + fn test_not_eq() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + Reference::new("id"), + Datum::int(INT_MIN_VALUE - 25), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id below lower bound" + ); + + Ok(()) + } - let result = manifest_evaluator.eval(&manifest_file).unwrap(); + #[test] + fn test_in() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new("id"), + FnvHashSet::from_iter(vec![ + Datum::int(INT_MIN_VALUE - 25), + Datum::int(INT_MIN_VALUE - 24), + ]), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + 
!ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: id below lower bound (5 < 30, 6 < 30)" + ); - assert!(!result); + let filter = Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new("id"), + FnvHashSet::from_iter(vec![ + Datum::int(INT_MIN_VALUE - 1), + Datum::int(INT_MIN_VALUE), + ]), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id equal to lower bound (30 == 30)" + ); + + Ok(()) + } + + #[test] + fn test_not_in() -> Result<()> { + let case_sensitive = true; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + Reference::new("id"), + FnvHashSet::from_iter(vec![ + Datum::int(INT_MIN_VALUE - 25), + Datum::int(INT_MIN_VALUE - 24), + ]), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: id below lower bound (5 < 30, 6 < 30)" + ); + + Ok(()) + } + + #[test] + fn test_starts_with() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + Reference::new("some_nulls"), + Datum::string("a"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: range matches" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::StartsWith, + Reference::new("some_nulls"), + Datum::string("zzzz"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should skip: range doesn't match" + ); + + Ok(()) + } + + #[test] + fn test_not_starts_with() -> Result<()> { + let case_sensitive = false; + let schema = create_schema()?; + let partitions = create_partitions(); + let manifest_file = create_manifest_file(partitions); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + Reference::new("some_nulls"), + Datum::string("a"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should read: range matches" + ); + + let filter = Predicate::Binary(BinaryExpression::new( + PredicateOperator::NotStartsWith, + Reference::new("no_nulls_same_value_a"), + Datum::string("a"), + )) + .bind(schema.clone(), case_sensitive)?; + assert!( + !ManifestEvaluator::new(filter).eval(&manifest_file)?, + "Should not read: all values start with the prefix" + ); Ok(()) } diff --git a/crates/iceberg/src/io.rs b/crates/iceberg/src/io/file_io.rs similarity index 66% rename from crates/iceberg/src/io.rs rename to crates/iceberg/src/io/file_io.rs index c045b22f1..54b2cd487 100644 --- a/crates/iceberg/src/io.rs +++ b/crates/iceberg/src/io/file_io.rs @@ -15,71 +15,15 @@ // specific language governing permissions and limitations // under the License. -//! File io implementation. -//! -//! # How to build `FileIO` -//! -//! We provided a `FileIOBuilder` to build `FileIO` from scratch. For example: -//! ```rust -//! use iceberg::io::{FileIOBuilder, S3_REGION}; -//! -//! let file_io = FileIOBuilder::new("s3") -//! .with_prop(S3_REGION, "us-east-1") -//! .build() -//! .unwrap(); -//! ``` -//! -//! 
Or you can pass a path to ask `FileIO` to infer schema for you: -//! ```rust -//! use iceberg::io::{FileIO, S3_REGION}; -//! let file_io = FileIO::from_path("s3://bucket/a") -//! .unwrap() -//! .with_prop(S3_REGION, "us-east-1") -//! .build() -//! .unwrap(); -//! ``` -//! -//! # How to use `FileIO` -//! -//! Currently `FileIO` provides simple methods for file operations: -//! -//! - `delete`: Delete file. -//! - `is_exist`: Check if file exists. -//! - `new_input`: Create input file for reading. -//! - `new_output`: Create output file for writing. - +use super::storage::Storage; +use crate::{Error, ErrorKind, Result}; use bytes::Bytes; +use opendal::Operator; +use std::collections::HashMap; use std::ops::Range; -use std::{collections::HashMap, sync::Arc}; - -use crate::{error::Result, Error, ErrorKind}; -use once_cell::sync::Lazy; -use opendal::{Operator, Scheme}; +use std::sync::Arc; use url::Url; -/// Following are arguments for [s3 file io](https://py.iceberg.apache.org/configuration/#s3). -/// S3 endopint. -pub const S3_ENDPOINT: &str = "s3.endpoint"; -/// S3 access key id. -pub const S3_ACCESS_KEY_ID: &str = "s3.access-key-id"; -/// S3 secret access key. -pub const S3_SECRET_ACCESS_KEY: &str = "s3.secret-access-key"; -/// S3 region. -pub const S3_REGION: &str = "s3.region"; - -/// A mapping from iceberg s3 configuration key to [`opendal::Operator`] configuration key. -static S3_CONFIG_MAPPING: Lazy> = Lazy::new(|| { - let mut m = HashMap::with_capacity(4); - m.insert(S3_ENDPOINT, "endpoint"); - m.insert(S3_ACCESS_KEY_ID, "access_key_id"); - m.insert(S3_SECRET_ACCESS_KEY, "secret_access_key"); - m.insert(S3_REGION, "region"); - - m -}); - -const DEFAULT_ROOT_PATH: &str = "/"; - /// FileIO implementation, used to manipulate files in underlying storage. /// /// # Note @@ -91,59 +35,6 @@ pub struct FileIO { inner: Arc, } -/// Builder for [`FileIO`]. -#[derive(Debug)] -pub struct FileIOBuilder { - /// This is used to infer scheme of operator. - /// - /// If this is `None`, then [`FileIOBuilder::build`](FileIOBuilder::build) will build a local file io. - scheme_str: Option, - /// Arguments for operator. - props: HashMap, -} - -impl FileIOBuilder { - /// Creates a new builder with scheme. - pub fn new(scheme_str: impl ToString) -> Self { - Self { - scheme_str: Some(scheme_str.to_string()), - props: HashMap::default(), - } - } - - /// Creates a new builder for local file io. - pub fn new_fs_io() -> Self { - Self { - scheme_str: None, - props: HashMap::default(), - } - } - - /// Add argument for operator. - pub fn with_prop(mut self, key: impl ToString, value: impl ToString) -> Self { - self.props.insert(key.to_string(), value.to_string()); - self - } - - /// Add argument for operator. - pub fn with_props( - mut self, - args: impl IntoIterator, - ) -> Self { - self.props - .extend(args.into_iter().map(|e| (e.0.to_string(), e.1.to_string()))); - self - } - - /// Builds [`FileIO`]. - pub fn build(self) -> Result { - let storage = Storage::build(self)?; - Ok(FileIO { - inner: Arc::new(storage), - }) - } -} - impl FileIO { /// Try to infer file io scheme from path. /// @@ -151,7 +42,7 @@ impl FileIO { /// If it's not a valid url, will try to detect if it's a file path. /// /// Otherwise will return parsing error. - pub fn from_path(path: impl AsRef) -> Result { + pub fn from_path(path: impl AsRef) -> crate::Result { let url = Url::parse(path.as_ref()) .map_err(Error::from) .or_else(|e| { @@ -205,6 +96,66 @@ impl FileIO { } } +/// Builder for [`FileIO`]. 
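As a side note on the relocated builder, here is a minimal sketch (not part of this patch) of driving it with several S3 properties at once through `with_props`; the endpoint and credentials are placeholder values.

use iceberg::io::{FileIOBuilder, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY};

fn build_s3_file_io() -> iceberg::io::FileIO {
    // Each key is one of the public `s3.*` constants; the values are placeholders.
    FileIOBuilder::new("s3")
        .with_props(vec![
            (S3_ENDPOINT, "http://localhost:9000"),
            (S3_REGION, "us-east-1"),
            (S3_ACCESS_KEY_ID, "admin"),
            (S3_SECRET_ACCESS_KEY, "password"),
        ])
        .build()
        .unwrap()
}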
+#[derive(Debug)] +pub struct FileIOBuilder { + /// This is used to infer scheme of operator. + /// + /// If this is `None`, then [`FileIOBuilder::build`](FileIOBuilder::build) will build a local file io. + scheme_str: Option, + /// Arguments for operator. + props: HashMap, +} + +impl FileIOBuilder { + /// Creates a new builder with scheme. + pub fn new(scheme_str: impl ToString) -> Self { + Self { + scheme_str: Some(scheme_str.to_string()), + props: HashMap::default(), + } + } + + /// Creates a new builder for local file io. + pub fn new_fs_io() -> Self { + Self { + scheme_str: None, + props: HashMap::default(), + } + } + + /// Fetch the scheme string. + /// + /// The scheme_str will be empty if it's None. + pub(crate) fn into_parts(self) -> (String, HashMap) { + (self.scheme_str.unwrap_or_default(), self.props) + } + + /// Add argument for operator. + pub fn with_prop(mut self, key: impl ToString, value: impl ToString) -> Self { + self.props.insert(key.to_string(), value.to_string()); + self + } + + /// Add argument for operator. + pub fn with_props( + mut self, + args: impl IntoIterator, + ) -> Self { + self.props + .extend(args.into_iter().map(|e| (e.0.to_string(), e.1.to_string()))); + self + } + + /// Builds [`FileIO`]. + pub fn build(self) -> crate::Result { + let storage = Storage::build(self)?; + Ok(FileIO { + inner: Arc::new(storage), + }) + } +} + /// The struct the represents the metadata of a file. /// /// TODO: we can add last modified time, content type, etc. in the future. @@ -224,12 +175,12 @@ pub trait FileRead: Send + Unpin + 'static { /// Read file content with given range. /// /// TODO: we can support reading non-contiguous bytes in the future. - async fn read(&self, range: Range) -> Result; + async fn read(&self, range: Range) -> crate::Result; } #[async_trait::async_trait] impl FileRead for opendal::Reader { - async fn read(&self, range: Range) -> Result { + async fn read(&self, range: Range) -> crate::Result { Ok(opendal::Reader::read(self, range).await?.to_bytes()) } } @@ -251,7 +202,7 @@ impl InputFile { } /// Check if file exists. - pub async fn exists(&self) -> Result { + pub async fn exists(&self) -> crate::Result { Ok(self .op .is_exist(&self.path[self.relative_path_pos..]) @@ -259,7 +210,7 @@ impl InputFile { } /// Fetch and returns metadata of file. - pub async fn metadata(&self) -> Result { + pub async fn metadata(&self) -> crate::Result { let meta = self.op.stat(&self.path[self.relative_path_pos..]).await?; Ok(FileMetadata { @@ -270,7 +221,7 @@ impl InputFile { /// Read and returns whole content of file. /// /// For continues reading, use [`Self::reader`] instead. - pub async fn read(&self) -> Result { + pub async fn read(&self) -> crate::Result { Ok(self .op .read(&self.path[self.relative_path_pos..]) @@ -281,7 +232,7 @@ impl InputFile { /// Creates [`FileRead`] for continues reading. /// /// For one-time reading, use [`Self::read`] instead. - pub async fn reader(&self) -> Result { + pub async fn reader(&self) -> crate::Result { Ok(self.op.reader(&self.path[self.relative_path_pos..]).await?) } } @@ -297,21 +248,21 @@ pub trait FileWrite: Send + Unpin + 'static { /// Write bytes to file. /// /// TODO: we can support writing non-contiguous bytes in the future. - async fn write(&mut self, bs: Bytes) -> Result<()>; + async fn write(&mut self, bs: Bytes) -> crate::Result<()>; /// Close file. /// /// Calling close on closed file will generate an error. 
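For orientation, an illustrative (non-patch) sketch of a one-shot read through the relocated `InputFile` API; it assumes `FileIO::new_input` keeps its current signature and that the crate-root `Result` alias stays publicly re-exported.

async fn read_all(file_io: &iceberg::io::FileIO, path: &str) -> iceberg::Result<bytes::Bytes> {
    // `new_input` only constructs the handle; the bytes are fetched by `read`.
    let input = file_io.new_input(path)?;
    input.read().await
}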
- async fn close(&mut self) -> Result<()>; + async fn close(&mut self) -> crate::Result<()>; } #[async_trait::async_trait] impl FileWrite for opendal::Writer { - async fn write(&mut self, bs: Bytes) -> Result<()> { + async fn write(&mut self, bs: Bytes) -> crate::Result<()> { Ok(opendal::Writer::write(self, bs).await?) } - async fn close(&mut self) -> Result<()> { + async fn close(&mut self) -> crate::Result<()> { Ok(opendal::Writer::close(self).await?) } } @@ -333,7 +284,7 @@ impl OutputFile { } /// Checks if file exists. - pub async fn exists(&self) -> Result { + pub async fn exists(&self) -> crate::Result { Ok(self .op .is_exist(&self.path[self.relative_path_pos..]) @@ -355,7 +306,7 @@ impl OutputFile { /// /// Calling `write` will overwrite the file if it exists. /// For continues writing, use [`Self::writer`]. - pub async fn write(&self, bs: Bytes) -> Result<()> { + pub async fn write(&self, bs: Bytes) -> crate::Result<()> { let mut writer = self.writer().await?; writer.write(bs).await?; writer.close().await @@ -366,114 +317,13 @@ impl OutputFile { /// # Notes /// /// For one-time writing, use [`Self::write`] instead. - pub async fn writer(&self) -> Result> { + pub async fn writer(&self) -> crate::Result> { Ok(Box::new( self.op.writer(&self.path[self.relative_path_pos..]).await?, )) } } -// We introduce this because I don't want to handle unsupported `Scheme` in every method. -#[derive(Debug)] -enum Storage { - LocalFs { - op: Operator, - }, - S3 { - scheme_str: String, - props: HashMap, - }, -} - -impl Storage { - /// Creates operator from path. - /// - /// # Arguments - /// - /// * path: It should be *absolute* path starting with scheme string used to construct [`FileIO`]. - /// - /// # Returns - /// - /// The return value consists of two parts: - /// - /// * An [`opendal::Operator`] instance used to operate on file. - /// * Relative path to the root uri of [`opendal::Operator`]. - /// - fn create_operator<'a>(&self, path: &'a impl AsRef) -> Result<(Operator, &'a str)> { - let path = path.as_ref(); - match self { - Storage::LocalFs { op } => { - if let Some(stripped) = path.strip_prefix("file:/") { - Ok((op.clone(), stripped)) - } else { - Ok((op.clone(), &path[1..])) - } - } - Storage::S3 { scheme_str, props } => { - let mut props = props.clone(); - let url = Url::parse(path)?; - let bucket = url.host_str().ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - format!("Invalid s3 url: {}, missing bucket", path), - ) - })?; - - props.insert("bucket".to_string(), bucket.to_string()); - - let prefix = format!("{}://{}/", scheme_str, bucket); - if path.starts_with(&prefix) { - Ok((Operator::via_map(Scheme::S3, props)?, &path[prefix.len()..])) - } else { - Err(Error::new( - ErrorKind::DataInvalid, - format!("Invalid s3 url: {}, should start with {}", path, prefix), - )) - } - } - } - } - - /// Parse scheme. - fn parse_scheme(scheme: &str) -> Result { - match scheme { - "file" | "" => Ok(Scheme::Fs), - "s3" | "s3a" => Ok(Scheme::S3), - s => Ok(s.parse::()?), - } - } - - /// Convert iceberg config to opendal config. 
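Similarly illustrative, and not part of the diff: the `write`/`writer` split on `OutputFile` described above can be exercised like this, with placeholder local paths and an assumed unchanged `new_output` signature.

async fn write_examples(file_io: &iceberg::io::FileIO) -> iceberg::Result<()> {
    use bytes::Bytes;

    // One-shot write; per the docs above this overwrites the file if it exists.
    let out = file_io.new_output("/tmp/iceberg-demo/one-shot.bin")?;
    out.write(Bytes::from_static(b"hello")).await?;

    // Streaming write through the `FileWrite` trait; `close` finishes the write.
    let out = file_io.new_output("/tmp/iceberg-demo/streamed.bin")?;
    let mut writer = out.writer().await?;
    writer.write(Bytes::from_static(b"chunk-1")).await?;
    writer.write(Bytes::from_static(b"chunk-2")).await?;
    writer.close().await
}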
- fn build(file_io_builder: FileIOBuilder) -> Result { - let scheme_str = file_io_builder.scheme_str.unwrap_or("".to_string()); - let scheme = Self::parse_scheme(&scheme_str)?; - let mut new_props = HashMap::default(); - new_props.insert("root".to_string(), DEFAULT_ROOT_PATH.to_string()); - - match scheme { - Scheme::Fs => Ok(Self::LocalFs { - op: Operator::via_map(Scheme::Fs, new_props)?, - }), - Scheme::S3 => { - for prop in file_io_builder.props { - if let Some(op_key) = S3_CONFIG_MAPPING.get(prop.0.as_str()) { - new_props.insert(op_key.to_string(), prop.1); - } - } - - Ok(Self::S3 { - scheme_str, - props: new_props, - }) - } - _ => Err(Error::new( - ErrorKind::FeatureUnsupported, - format!("Constructing file io from scheme: {scheme} not supported now",), - )), - } - } -} - #[cfg(test)] mod tests { use std::io::Write; diff --git a/crates/iceberg/src/io/mod.rs b/crates/iceberg/src/io/mod.rs new file mode 100644 index 000000000..914293da3 --- /dev/null +++ b/crates/iceberg/src/io/mod.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! File io implementation. +//! +//! # How to build `FileIO` +//! +//! We provided a `FileIOBuilder` to build `FileIO` from scratch. For example: +//! ```rust +//! use iceberg::io::{FileIOBuilder, S3_REGION}; +//! +//! let file_io = FileIOBuilder::new("s3") +//! .with_prop(S3_REGION, "us-east-1") +//! .build() +//! .unwrap(); +//! ``` +//! +//! Or you can pass a path to ask `FileIO` to infer schema for you: +//! ```rust +//! use iceberg::io::{FileIO, S3_REGION}; +//! let file_io = FileIO::from_path("s3://bucket/a") +//! .unwrap() +//! .with_prop(S3_REGION, "us-east-1") +//! .build() +//! .unwrap(); +//! ``` +//! +//! # How to use `FileIO` +//! +//! Currently `FileIO` provides simple methods for file operations: +//! +//! - `delete`: Delete file. +//! - `is_exist`: Check if file exists. +//! - `new_input`: Create input file for reading. +//! - `new_output`: Create output file for writing. + +mod file_io; +pub use file_io::*; + +mod storage; +#[cfg(feature = "storage-s3")] +mod storage_s3; +#[cfg(feature = "storage-s3")] +pub use storage_s3::*; +#[cfg(feature = "storage-fs")] +mod storage_fs; +#[cfg(feature = "storage-fs")] +use storage_fs::*; diff --git a/crates/iceberg/src/io/storage.rs b/crates/iceberg/src/io/storage.rs new file mode 100644 index 000000000..8d7df45b8 --- /dev/null +++ b/crates/iceberg/src/io/storage.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use super::FileIOBuilder; +#[cfg(feature = "storage-fs")] +use super::FsConfig; +#[cfg(feature = "storage-s3")] +use super::S3Config; +use crate::{Error, ErrorKind}; +use opendal::{Operator, Scheme}; + +/// The storage carries all supported storage services in iceberg +#[derive(Debug)] +pub(crate) enum Storage { + #[cfg(feature = "storage-fs")] + LocalFs { config: FsConfig }, + #[cfg(feature = "storage-s3")] + S3 { + /// s3 storage could have `s3://` and `s3a://`. + /// Storing the scheme string here to return the correct path. + scheme_str: String, + config: S3Config, + }, +} + +impl Storage { + /// Convert iceberg config to opendal config. + pub(crate) fn build(file_io_builder: FileIOBuilder) -> crate::Result { + let (scheme_str, props) = file_io_builder.into_parts(); + let scheme = Self::parse_scheme(&scheme_str)?; + + match scheme { + #[cfg(feature = "storage-fs")] + Scheme::Fs => Ok(Self::LocalFs { + config: FsConfig::new(props), + }), + #[cfg(feature = "storage-s3")] + Scheme::S3 => Ok(Self::S3 { + scheme_str, + config: S3Config::new(props), + }), + _ => Err(Error::new( + ErrorKind::FeatureUnsupported, + format!("Constructing file io from scheme: {scheme} not supported now",), + )), + } + } + + /// Creates operator from path. + /// + /// # Arguments + /// + /// * path: It should be *absolute* path starting with scheme string used to construct [`FileIO`]. + /// + /// # Returns + /// + /// The return value consists of two parts: + /// + /// * An [`opendal::Operator`] instance used to operate on file. + /// * Relative path to the root uri of [`opendal::Operator`]. + /// + pub(crate) fn create_operator<'a>( + &self, + path: &'a impl AsRef, + ) -> crate::Result<(Operator, &'a str)> { + let path = path.as_ref(); + match self { + #[cfg(feature = "storage-fs")] + Storage::LocalFs { config } => { + let op = config.build(path)?; + + if let Some(stripped) = path.strip_prefix("file:/") { + Ok((op, stripped)) + } else { + Ok((op, &path[1..])) + } + } + #[cfg(feature = "storage-s3")] + Storage::S3 { scheme_str, config } => { + let op = config.build(path)?; + let op_info = op.info(); + + // Check prefix of s3 path. + let prefix = format!("{}://{}/", scheme_str, op_info.name()); + if path.starts_with(&prefix) { + Ok((op, &path[prefix.len()..])) + } else { + Err(Error::new( + ErrorKind::DataInvalid, + format!("Invalid s3 url: {}, should start with {}", path, prefix), + )) + } + } + #[cfg(all(not(feature = "storage-s3"), not(feature = "storage-fs")))] + _ => Err(Error::new( + ErrorKind::FeatureUnsupported, + "No storage service has been enabled", + )), + } + } + + /// Parse scheme. + fn parse_scheme(scheme: &str) -> crate::Result { + match scheme { + "file" | "" => Ok(Scheme::Fs), + "s3" | "s3a" => Ok(Scheme::S3), + s => Ok(s.parse::()?), + } + } +} + +/// redact_secret will redact the secret part of the string. 
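Backing up to `parse_scheme` above: for quick orientation, the accepted scheme strings map onto opendal schemes as in this standalone sketch (not the crate's private helper itself).

use opendal::Scheme;

fn scheme_for(s: &str) -> Option<Scheme> {
    match s {
        "file" | "" => Some(Scheme::Fs),  // an empty scheme falls back to the local filesystem
        "s3" | "s3a" => Some(Scheme::S3), // both spellings select the S3 backend
        _ => None,                        // anything else is delegated to opendal's own parser
    }
}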
+#[inline] +pub(crate) fn redact_secret(s: &str) -> String { + let len = s.len(); + if len <= 6 { + return "***".to_string(); + } + + format!("{}***{}", &s[0..3], &s[len - 3..len]) +} diff --git a/crates/iceberg/src/io/storage_fs.rs b/crates/iceberg/src/io/storage_fs.rs new file mode 100644 index 000000000..38c3fa129 --- /dev/null +++ b/crates/iceberg/src/io/storage_fs.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::Result; +use opendal::{Operator, Scheme}; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; + +/// # TODO +/// +/// opendal has a plan to introduce native config support. +/// We manually parse the config here and those code will be finally removed. +#[derive(Default, Clone)] +pub(crate) struct FsConfig {} + +impl Debug for FsConfig { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FsConfig").finish() + } +} + +impl FsConfig { + /// Decode from iceberg props. + pub fn new(_: HashMap) -> Self { + Self::default() + } + + /// Build new opendal operator from give path. + /// + /// fs always build from `/` + pub fn build(&self, _: &str) -> Result { + let m = HashMap::from_iter([("root".to_string(), "/".to_string())]); + Ok(Operator::via_map(Scheme::Fs, m)?) + } +} diff --git a/crates/iceberg/src/io/storage_s3.rs b/crates/iceberg/src/io/storage_s3.rs new file mode 100644 index 000000000..d001e06cd --- /dev/null +++ b/crates/iceberg/src/io/storage_s3.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::io::storage::redact_secret; +use crate::{Error, ErrorKind, Result}; +use opendal::{Operator, Scheme}; +use std::collections::HashMap; +use std::fmt::{Debug, Formatter}; +use url::Url; + +/// Following are arguments for [s3 file io](https://py.iceberg.apache.org/configuration/#s3). +/// S3 endpoint. +pub const S3_ENDPOINT: &str = "s3.endpoint"; +/// S3 access key id. +pub const S3_ACCESS_KEY_ID: &str = "s3.access-key-id"; +/// S3 secret access key. 
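A small illustrative test for the redaction helper above (hypothetical, shown only to pin down the behaviour; it would have to live in the same module because the helper is `pub(crate)`):

#[test]
fn redact_secret_examples() {
    assert_eq!(redact_secret("abcdef"), "***");                // six chars or fewer: fully masked
    assert_eq!(redact_secret("supersecretkey"), "sup***key");  // otherwise keep a 3-char prefix and suffix
}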
+pub const S3_SECRET_ACCESS_KEY: &str = "s3.secret-access-key"; +/// S3 region. +pub const S3_REGION: &str = "s3.region"; + +/// # TODO +/// +/// opendal has a plan to introduce native config support. +/// We manually parse the config here and those code will be finally removed. +#[derive(Default, Clone)] +pub(crate) struct S3Config { + pub endpoint: String, + pub access_key_id: String, + pub secret_access_key: String, + pub region: String, +} + +impl Debug for S3Config { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("S3Config") + .field("endpoint", &self.endpoint) + .field("region", &self.region) + .field("access_key_id", &redact_secret(&self.access_key_id)) + .field("secret_access_key", &redact_secret(&self.secret_access_key)) + .finish() + } +} + +impl S3Config { + /// Decode from iceberg props. + pub fn new(m: HashMap) -> Self { + let mut cfg = Self::default(); + if let Some(endpoint) = m.get(S3_ENDPOINT) { + cfg.endpoint = endpoint.clone(); + }; + if let Some(access_key_id) = m.get(S3_ACCESS_KEY_ID) { + cfg.access_key_id = access_key_id.clone(); + }; + if let Some(secret_access_key) = m.get(S3_SECRET_ACCESS_KEY) { + cfg.secret_access_key = secret_access_key.clone(); + }; + if let Some(region) = m.get(S3_REGION) { + cfg.region = region.clone(); + }; + + cfg + } + + /// Build new opendal operator from give path. + pub fn build(&self, path: &str) -> Result { + let url = Url::parse(path)?; + let bucket = url.host_str().ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Invalid s3 url: {}, missing bucket", path), + ) + })?; + + let mut m = HashMap::with_capacity(5); + m.insert("bucket".to_string(), bucket.to_string()); + m.insert("endpoint".to_string(), self.endpoint.clone()); + m.insert("access_key_id".to_string(), self.access_key_id.clone()); + m.insert( + "secret_access_key".to_string(), + self.secret_access_key.clone(), + ); + m.insert("region".to_string(), self.region.clone()); + + Ok(Operator::via_map(Scheme::S3, m)?) + } +} diff --git a/crates/iceberg/src/lib.rs b/crates/iceberg/src/lib.rs index 407009861..8d22b5d4a 100644 --- a/crates/iceberg/src/lib.rs +++ b/crates/iceberg/src/lib.rs @@ -40,7 +40,6 @@ pub use catalog::TableUpdate; pub use catalog::ViewCreation; pub use catalog::ViewUpdate; -#[allow(dead_code)] pub mod table; mod avro; @@ -49,10 +48,11 @@ pub mod spec; pub mod scan; -#[allow(dead_code)] pub mod expr; pub mod transaction; pub mod transform; +mod runtime; + pub mod arrow; pub mod writer; diff --git a/crates/iceberg/src/runtime/mod.rs b/crates/iceberg/src/runtime/mod.rs new file mode 100644 index 000000000..65c30e82c --- /dev/null +++ b/crates/iceberg/src/runtime/mod.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
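Stepping back to the S3 configuration added above: a crate-internal sketch (not in the patch) of how the iceberg-level property keys flow through `S3Config` into an opendal operator; all values are placeholders.

use std::collections::HashMap;

fn build_s3_operator_sketch() -> crate::Result<opendal::Operator> {
    let props = HashMap::from_iter([
        (S3_ENDPOINT.to_string(), "http://localhost:9000".to_string()),
        (S3_REGION.to_string(), "us-east-1".to_string()),
        (S3_ACCESS_KEY_ID.to_string(), "admin".to_string()),
        (S3_SECRET_ACCESS_KEY.to_string(), "password".to_string()),
    ]);
    // `new` copies the recognised keys; `build` derives the bucket from the path.
    let config = S3Config::new(props);
    config.build("s3://my-bucket/warehouse/table/data.parquet")
}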
+ +// This module contains the async runtime abstraction for iceberg. + +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub enum JoinHandle { + #[cfg(feature = "tokio")] + Tokio(tokio::task::JoinHandle), + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + AsyncStd(async_std::task::JoinHandle), + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + Unimplemented(Box), +} + +impl Future for JoinHandle { + type Output = T; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.get_mut() { + #[cfg(feature = "tokio")] + JoinHandle::Tokio(handle) => Pin::new(handle) + .poll(cx) + .map(|h| h.expect("tokio spawned task failed")), + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + JoinHandle::AsyncStd(handle) => Pin::new(handle).poll(cx), + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + JoinHandle::Unimplemented(_) => unimplemented!("no runtime has been enabled"), + } + } +} + +#[allow(dead_code)] +pub fn spawn(f: F) -> JoinHandle +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + #[cfg(feature = "tokio")] + return JoinHandle::Tokio(tokio::task::spawn(f)); + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + return JoinHandle::AsyncStd(async_std::task::spawn(f)); + + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + unimplemented!("no runtime has been enabled") +} + +#[allow(dead_code)] +pub fn spawn_blocking(f: F) -> JoinHandle +where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, +{ + #[cfg(feature = "tokio")] + return JoinHandle::Tokio(tokio::task::spawn_blocking(f)); + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + return JoinHandle::AsyncStd(async_std::task::spawn_blocking(f)); + + #[cfg(all(not(feature = "async-std"), not(feature = "tokio")))] + unimplemented!("no runtime has been enabled") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_tokio_spawn() { + let handle = spawn(async { 1 + 1 }); + assert_eq!(handle.await, 2); + } + + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_tokio_spawn_blocking() { + let handle = spawn_blocking(|| 1 + 1); + assert_eq!(handle.await, 2); + } + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + #[async_std::test] + async fn test_async_std_spawn() { + let handle = spawn(async { 1 + 1 }); + assert_eq!(handle.await, 2); + } + + #[cfg(all(feature = "async-std", not(feature = "tokio")))] + #[async_std::test] + async fn test_async_std_spawn_blocking() { + let handle = spawn_blocking(|| 1 + 1); + assert_eq!(handle.await, 2); + } +} diff --git a/crates/iceberg/src/scan.rs b/crates/iceberg/src/scan.rs index 5f0922e93..730e3dadf 100644 --- a/crates/iceberg/src/scan.rs +++ b/crates/iceberg/src/scan.rs @@ -49,7 +49,6 @@ pub struct TableScanBuilder<'a> { table: &'a Table, // Empty column names means to select all columns column_names: Vec, - predicates: Option, snapshot_id: Option, batch_size: Option, case_sensitive: bool, @@ -61,7 +60,6 @@ impl<'a> TableScanBuilder<'a> { Self { table, column_names: vec![], - predicates: None, snapshot_id: None, batch_size: None, case_sensitive: true, @@ -96,12 +94,6 @@ impl<'a> TableScanBuilder<'a> { self } - /// Add a predicate to the scan. The scan will only return rows that match the predicate. - pub fn filter(mut self, predicate: Predicate) -> Self { - self.predicates = Some(predicate); - self - } - /// Select some columns of the table. 
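Crate-internal usage sketch for the runtime shim introduced above (illustrative, not part of the patch): callers go through `spawn`/`spawn_blocking` instead of naming tokio or async-std directly, so whichever runtime feature is enabled is picked up transparently.

use crate::runtime::{spawn, spawn_blocking};

async fn runtime_shim_example() -> u64 {
    // CPU-bound work goes to `spawn_blocking`, async work to `spawn`.
    let cpu_bound = spawn_blocking(|| (0..1_000u64).sum::<u64>());
    let io_bound = spawn(async { 42u64 });
    cpu_bound.await + io_bound.await
}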
pub fn select(mut self, column_names: impl IntoIterator) -> Self { self.column_names = column_names @@ -161,17 +153,56 @@ impl<'a> TableScanBuilder<'a> { } } - let bound_predicates = if let Some(ref predicates) = self.predicates { + let bound_predicates = if let Some(ref predicates) = self.filter { Some(predicates.bind(schema.clone(), true)?) } else { None }; + let mut field_ids = vec![]; + for column_name in &self.column_names { + let field_id = schema.field_id_by_name(column_name).ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!( + "Column {} not found in table. Schema: {}", + column_name, schema + ), + ) + })?; + + let field = schema + .as_struct() + .field_by_id(field_id) + .ok_or_else(|| { + Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Column {} is not a direct child of schema but a nested field, which is not supported now. Schema: {}", + column_name, schema + ), + ) + })?; + + if !field.field_type.is_primitive() { + return Err(Error::new( + ErrorKind::FeatureUnsupported, + format!( + "Column {} is not a primitive type. Schema: {}", + column_name, schema + ), + )); + } + + field_ids.push(field_id); + } + Ok(TableScan { snapshot, file_io: self.table.file_io().clone(), table_metadata: self.table.metadata_ref(), column_names: self.column_names, + field_ids, bound_predicates, schema, batch_size: self.batch_size, @@ -183,12 +214,12 @@ impl<'a> TableScanBuilder<'a> { /// Table scan. #[derive(Debug)] -#[allow(dead_code)] pub struct TableScan { snapshot: SnapshotRef, table_metadata: TableMetadataRef, file_io: FileIO, column_names: Vec, + field_ids: Vec, bound_predicates: Option, schema: SchemaRef, batch_size: Option, @@ -212,6 +243,9 @@ impl TableScan { let mut manifest_evaluator_cache = ManifestEvaluatorCache::new(); let mut expression_evaluator_cache = ExpressionEvaluatorCache::new(); + let field_ids = self.field_ids.clone(); + let bound_predicates = self.bound_predicates.clone(); + Ok(try_stream! { let manifest_list = context .snapshot @@ -280,6 +314,9 @@ impl TableScan { data_file_path: manifest_entry.data_file().file_path().to_string(), start: 0, length: manifest_entry.file_size_in_bytes(), + project_field_ids: field_ids.clone(), + predicate: bound_predicates.clone(), + schema: context.schema.clone(), }); yield scan_task?; } @@ -292,57 +329,12 @@ impl TableScan { /// Returns an [`ArrowRecordBatchStream`]. pub async fn to_arrow(&self) -> Result { - let mut arrow_reader_builder = - ArrowReaderBuilder::new(self.file_io.clone(), self.schema.clone()); - - let mut field_ids = vec![]; - for column_name in &self.column_names { - let field_id = self.schema.field_id_by_name(column_name).ok_or_else(|| { - Error::new( - ErrorKind::DataInvalid, - format!( - "Column {} not found in table. Schema: {}", - column_name, self.schema - ), - ) - })?; - - let field = self.schema - .as_struct() - .field_by_id(field_id) - .ok_or_else(|| { - Error::new( - ErrorKind::FeatureUnsupported, - format!( - "Column {} is not a direct child of schema but a nested field, which is not supported now. Schema: {}", - column_name, self.schema - ), - ) - })?; - - if !field.field_type.is_primitive() { - return Err(Error::new( - ErrorKind::FeatureUnsupported, - format!( - "Column {} is not a primitive type. 
Schema: {}", - column_name, self.schema - ), - )); - } - - field_ids.push(field_id as usize); - } - - arrow_reader_builder = arrow_reader_builder.with_field_ids(field_ids); + let mut arrow_reader_builder = ArrowReaderBuilder::new(self.file_io.clone()); if let Some(batch_size) = self.batch_size { arrow_reader_builder = arrow_reader_builder.with_batch_size(batch_size); } - if let Some(ref bound_predicates) = self.bound_predicates { - arrow_reader_builder = arrow_reader_builder.with_predicates(bound_predicates.clone()); - } - arrow_reader_builder.build().read(self.plan_files().await?) } @@ -353,6 +345,11 @@ impl TableScan { } false } + + /// Returns a reference to the column names of the table scan. + pub fn column_names(&self) -> &[String] { + &self.column_names + } } /// Holds the context necessary for file scanning operations @@ -499,10 +496,12 @@ impl ExpressionEvaluatorCache { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct FileScanTask { data_file_path: String, - #[allow(dead_code)] start: u64, - #[allow(dead_code)] length: u64, + project_field_ids: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + predicate: Option, + schema: SchemaRef, } impl FileScanTask { @@ -510,21 +509,39 @@ impl FileScanTask { pub fn data_file_path(&self) -> &str { &self.data_file_path } + + /// Returns the project field id of this file scan task. + pub fn project_field_ids(&self) -> &[i32] { + &self.project_field_ids + } + + /// Returns the predicate of this file scan task. + pub fn predicate(&self) -> Option<&BoundPredicate> { + self.predicate.as_ref() + } + + /// Returns the schema id of this file scan task. + pub fn schema(&self) -> &Schema { + &self.schema + } } #[cfg(test)] mod tests { - use crate::expr::Reference; + use crate::arrow::ArrowReaderBuilder; + use crate::expr::{BoundPredicate, Reference}; use crate::io::{FileIO, OutputFile}; + use crate::scan::FileScanTask; use crate::spec::{ DataContentType, DataFileBuilder, DataFileFormat, Datum, FormatVersion, Literal, Manifest, ManifestContentType, ManifestEntry, ManifestListWriter, ManifestMetadata, ManifestStatus, - ManifestWriter, Struct, TableMetadata, EMPTY_SNAPSHOT_ID, + ManifestWriter, NestedField, PrimitiveType, Schema, Struct, TableMetadata, Type, + EMPTY_SNAPSHOT_ID, }; use crate::table::Table; use crate::TableIdent; - use arrow_array::{ArrayRef, Int64Array, RecordBatch}; - use futures::TryStreamExt; + use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray}; + use futures::{stream, TryStreamExt}; use parquet::arrow::{ArrowWriter, PARQUET_FIELD_ID_META_KEY}; use parquet::basic::Compression; use parquet::file::properties::WriterProperties; @@ -705,10 +722,15 @@ mod tests { PARQUET_FIELD_ID_META_KEY.to_string(), "3".to_string(), )])), + arrow_schema::Field::new("a", arrow_schema::DataType::Utf8, false) + .with_metadata(HashMap::from([( + PARQUET_FIELD_ID_META_KEY.to_string(), + "4".to_string(), + )])), ]; Arc::new(arrow_schema::Schema::new(fields)) }; - // 3 columns: + // 4 columns: // x: [1, 1, 1, 1, ...] 
let col1 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; @@ -725,7 +747,14 @@ mod tests { // z: [3, 3, 3, 3, ..., 4, 4, 4, 4] let col3 = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; - let to_write = RecordBatch::try_new(schema.clone(), vec![col1, col2, col3]).unwrap(); + + // a: ["Apache", "Apache", "Apache", ..., "Iceberg", "Iceberg", "Iceberg"] + let mut values = vec!["Apache"; 512]; + values.append(vec!["Iceberg"; 512].as_mut()); + let col4 = Arc::new(StringArray::from_iter_values(values)) as ArrayRef; + + let to_write = + RecordBatch::try_new(schema.clone(), vec![col1, col2, col3, col4]).unwrap(); // Write the Parquet files let props = WriterProperties::builder() @@ -773,7 +802,7 @@ mod tests { fn test_select_no_exist_column() { let table = TableTestFixture::new().table; - let table_scan = table.scan().select(["x", "y", "z", "a"]).build(); + let table_scan = table.scan().select(["x", "y", "z", "a", "b"]).build(); assert!(table_scan.is_err()); } @@ -861,6 +890,39 @@ mod tests { assert_eq!(int64_arr.value(0), 1); } + #[tokio::test] + async fn test_open_parquet_no_deletions_by_separate_reader() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Create table scan for current snapshot and plan files + let table_scan = fixture.table.scan().build().unwrap(); + + let mut plan_task: Vec<_> = table_scan + .plan_files() + .await + .unwrap() + .try_collect() + .await + .unwrap(); + assert_eq!(plan_task.len(), 2); + + let reader = ArrowReaderBuilder::new(fixture.table.file_io().clone()).build(); + let batch_stream = reader + .clone() + .read(Box::pin(stream::iter(vec![Ok(plan_task.remove(0))]))) + .unwrap(); + let batche1: Vec<_> = batch_stream.try_collect().await.unwrap(); + + let reader = ArrowReaderBuilder::new(fixture.table.file_io().clone()).build(); + let batch_stream = reader + .read(Box::pin(stream::iter(vec![Ok(plan_task.remove(0))]))) + .unwrap(); + let batche2: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batche1, batche2); + } + #[tokio::test] async fn test_open_parquet_with_projection() { let mut fixture = TableTestFixture::new(); @@ -892,7 +954,7 @@ mod tests { // Filter: y < 3 let mut builder = fixture.table.scan(); let predicate = Reference::new("y").less_than(Datum::long(3)); - builder = builder.filter(predicate); + builder = builder.with_filter(predicate); let table_scan = builder.build().unwrap(); let batch_stream = table_scan.to_arrow().await.unwrap(); @@ -918,7 +980,7 @@ mod tests { // Filter: y >= 5 let mut builder = fixture.table.scan(); let predicate = Reference::new("y").greater_than_or_equal_to(Datum::long(5)); - builder = builder.filter(predicate); + builder = builder.with_filter(predicate); let table_scan = builder.build().unwrap(); let batch_stream = table_scan.to_arrow().await.unwrap(); @@ -944,7 +1006,7 @@ mod tests { // Filter: y is null let mut builder = fixture.table.scan(); let predicate = Reference::new("y").is_null(); - builder = builder.filter(predicate); + builder = builder.with_filter(predicate); let table_scan = builder.build().unwrap(); let batch_stream = table_scan.to_arrow().await.unwrap(); @@ -961,7 +1023,7 @@ mod tests { // Filter: y is not null let mut builder = fixture.table.scan(); let predicate = Reference::new("y").is_not_null(); - builder = builder.filter(predicate); + builder = builder.with_filter(predicate); let table_scan = builder.build().unwrap(); let batch_stream = table_scan.to_arrow().await.unwrap(); @@ -980,7 +1042,7 @@ mod tests { let 
predicate = Reference::new("y") .less_than(Datum::long(5)) .and(Reference::new("z").greater_than_or_equal_to(Datum::long(4))); - builder = builder.filter(predicate); + builder = builder.with_filter(predicate); let table_scan = builder.build().unwrap(); let batch_stream = table_scan.to_arrow().await.unwrap(); @@ -1014,7 +1076,7 @@ mod tests { let predicate = Reference::new("y") .less_than(Datum::long(5)) .or(Reference::new("z").greater_than_or_equal_to(Datum::long(4))); - builder = builder.filter(predicate); + builder = builder.with_filter(predicate); let table_scan = builder.build().unwrap(); let batch_stream = table_scan.to_arrow().await.unwrap(); @@ -1040,4 +1102,141 @@ mod tests { let expected_z = Arc::new(Int64Array::from_iter_values(values)) as ArrayRef; assert_eq!(col, &expected_z); } + + #[tokio::test] + async fn test_filter_on_arrow_startswith() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a STARTSWITH "Ice" + let mut builder = fixture.table.scan(); + let predicate = Reference::new("a").starts_with(Datum::string("Ice")); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Iceberg"); + } + + #[tokio::test] + async fn test_filter_on_arrow_not_startswith() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a NOT STARTSWITH "Ice" + let mut builder = fixture.table.scan(); + let predicate = Reference::new("a").not_starts_with(Datum::string("Ice")); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Apache"); + } + + #[tokio::test] + async fn test_filter_on_arrow_in() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a IN ("Sioux", "Iceberg") + let mut builder = fixture.table.scan(); + let predicate = + Reference::new("a").is_in([Datum::string("Sioux"), Datum::string("Iceberg")]); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Iceberg"); + } + + #[tokio::test] + async fn test_filter_on_arrow_not_in() { + let mut fixture = TableTestFixture::new(); + fixture.setup_manifest_files().await; + + // Filter: a NOT IN ("Sioux", "Iceberg") + let mut builder = fixture.table.scan(); + let predicate = + Reference::new("a").is_not_in([Datum::string("Sioux"), Datum::string("Iceberg")]); + builder = builder.with_filter(predicate); + let table_scan = builder.build().unwrap(); + + let batch_stream = table_scan.to_arrow().await.unwrap(); + + let batches: Vec<_> = 
batch_stream.try_collect().await.unwrap(); + + assert_eq!(batches[0].num_rows(), 512); + + let col = batches[0].column_by_name("a").unwrap(); + let string_arr = col.as_any().downcast_ref::().unwrap(); + assert_eq!(string_arr.value(0), "Apache"); + } + + #[test] + fn test_file_scan_task_serialize_deserialize() { + let test_fn = |task: FileScanTask| { + let serialized = serde_json::to_string(&task).unwrap(); + let deserialized: FileScanTask = serde_json::from_str(&serialized).unwrap(); + + assert_eq!(task.data_file_path, deserialized.data_file_path); + assert_eq!(task.start, deserialized.start); + assert_eq!(task.length, deserialized.length); + assert_eq!(task.project_field_ids, deserialized.project_field_ids); + assert_eq!(task.predicate, deserialized.predicate); + assert_eq!(task.schema, deserialized.schema); + }; + + // without predicate + let schema = Arc::new( + Schema::builder() + .with_fields(vec![Arc::new(NestedField::required( + 1, + "x", + Type::Primitive(PrimitiveType::Binary), + ))]) + .build() + .unwrap(), + ); + let task = FileScanTask { + data_file_path: "data_file_path".to_string(), + start: 0, + length: 100, + project_field_ids: vec![1, 2, 3], + predicate: None, + schema: schema.clone(), + }; + test_fn(task); + + // with predicate + let task = FileScanTask { + data_file_path: "data_file_path".to_string(), + start: 0, + length: 100, + project_field_ids: vec![1, 2, 3], + predicate: Some(BoundPredicate::AlwaysTrue), + schema, + }; + test_fn(task); + } } diff --git a/crates/iceberg/src/spec/datatypes.rs b/crates/iceberg/src/spec/datatypes.rs index cc911fa1a..80ac4a0a8 100644 --- a/crates/iceberg/src/spec/datatypes.rs +++ b/crates/iceberg/src/spec/datatypes.rs @@ -41,38 +41,37 @@ pub(crate) const MAX_DECIMAL_BYTES: u32 = 24; pub(crate) const MAX_DECIMAL_PRECISION: u32 = 38; mod _decimal { - use lazy_static::lazy_static; + use once_cell::sync::Lazy; use crate::spec::{MAX_DECIMAL_BYTES, MAX_DECIMAL_PRECISION}; - lazy_static! { - // Max precision of bytes, starts from 1 - pub(super) static ref MAX_PRECISION: [u32; MAX_DECIMAL_BYTES as usize] = { - let mut ret: [u32; 24] = [0; 24]; - for (i, prec) in ret.iter_mut().enumerate() { - *prec = 2f64.powi((8 * (i + 1) - 1) as i32).log10().floor() as u32; - } + // Max precision of bytes, starts from 1 + pub(super) static MAX_PRECISION: Lazy<[u32; MAX_DECIMAL_BYTES as usize]> = Lazy::new(|| { + let mut ret: [u32; 24] = [0; 24]; + for (i, prec) in ret.iter_mut().enumerate() { + *prec = 2f64.powi((8 * (i + 1) - 1) as i32).log10().floor() as u32; + } - ret - }; + ret + }); - // Required bytes of precision, starts from 1 - pub(super) static ref REQUIRED_LENGTH: [u32; MAX_DECIMAL_PRECISION as usize] = { - let mut ret: [u32; MAX_DECIMAL_PRECISION as usize] = [0; MAX_DECIMAL_PRECISION as usize]; + // Required bytes of precision, starts from 1 + pub(super) static REQUIRED_LENGTH: Lazy<[u32; MAX_DECIMAL_PRECISION as usize]> = + Lazy::new(|| { + let mut ret: [u32; MAX_DECIMAL_PRECISION as usize] = + [0; MAX_DECIMAL_PRECISION as usize]; for (i, required_len) in ret.iter_mut().enumerate() { for j in 0..MAX_PRECISION.len() { - if MAX_PRECISION[j] >= ((i+1) as u32) { - *required_len = (j+1) as u32; + if MAX_PRECISION[j] >= ((i + 1) as u32) { + *required_len = (j + 1) as u32; break; } } } ret - }; - - } + }); } #[derive(Debug, PartialEq, Eq, Clone)] @@ -667,6 +666,13 @@ pub struct ListType { pub element_field: NestedFieldRef, } +impl ListType { + /// Construct a list type with the given element field. 
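///
/// # Example
///
/// An illustrative sketch (added for documentation, not taken from the change
/// itself); the field id `1` and the element name are arbitrary:
///
/// ```rust,ignore
/// use iceberg::spec::{ListType, NestedField, PrimitiveType, Type};
///
/// let list_of_longs = ListType::new(
///     NestedField::required(1, "element", Type::Primitive(PrimitiveType::Long)).into(),
/// );
/// assert_eq!(list_of_longs.element_field.name, "element");
/// ```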
+ pub fn new(element_field: NestedFieldRef) -> Self { + Self { element_field } + } +} + /// Module for type serialization/deserialization. pub(super) mod _serde { use crate::spec::datatypes::Type::Map; @@ -782,6 +788,16 @@ pub struct MapType { pub value_field: NestedFieldRef, } +impl MapType { + /// Construct a map type with the given key and value fields. + pub fn new(key_field: NestedFieldRef, value_field: NestedFieldRef) -> Self { + Self { + key_field, + value_field, + } + } +} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; diff --git a/crates/iceberg/src/spec/manifest.rs b/crates/iceberg/src/spec/manifest.rs index f5a598472..f4b933175 100644 --- a/crates/iceberg/src/spec/manifest.rs +++ b/crates/iceberg/src/spec/manifest.rs @@ -289,8 +289,8 @@ impl ManifestWriter { avro_writer.append(value)?; } - let length = avro_writer.flush()?; let content = avro_writer.into_inner()?; + let length = content.len(); self.output.write(Bytes::from(content)).await?; let partition_summary = @@ -1203,13 +1203,11 @@ impl std::fmt::Display for DataFileFormat { mod _serde { use std::collections::HashMap; - use serde_bytes::ByteBuf; use serde_derive::{Deserialize, Serialize}; use serde_with::serde_as; use crate::spec::Datum; use crate::spec::Literal; - use crate::spec::PrimitiveLiteral; use crate::spec::RawLiteral; use crate::spec::Schema; use crate::spec::Struct; @@ -1333,12 +1331,8 @@ mod _serde { value_counts: Some(to_i64_entry(value.value_counts)?), null_value_counts: Some(to_i64_entry(value.null_value_counts)?), nan_value_counts: Some(to_i64_entry(value.nan_value_counts)?), - lower_bounds: Some(to_bytes_entry( - value.lower_bounds.into_iter().map(|(k, v)| (k, v.into())), - )), - upper_bounds: Some(to_bytes_entry( - value.upper_bounds.into_iter().map(|(k, v)| (k, v.into())), - )), + lower_bounds: Some(to_bytes_entry(value.lower_bounds)), + upper_bounds: Some(to_bytes_entry(value.upper_bounds)), key_metadata: Some(serde_bytes::ByteBuf::from(value.key_metadata)), split_offsets: Some(value.split_offsets), equality_ids: Some(value.equality_ids), @@ -1442,11 +1436,11 @@ mod _serde { Ok(m) } - fn to_bytes_entry(v: impl IntoIterator) -> Vec { + fn to_bytes_entry(v: impl IntoIterator) -> Vec { v.into_iter() .map(|e| BytesEntry { key: e.0, - value: Into::::into(e.1), + value: e.1.to_bytes(), }) .collect() } @@ -1906,8 +1900,7 @@ mod tests { partition: Struct::from_iter( vec![ Some( - Literal::try_from_bytes(&[120], &Type::Primitive(PrimitiveType::String)) - .unwrap() + Literal::string("x"), ), ] .into_iter() diff --git a/crates/iceberg/src/spec/manifest_list.rs b/crates/iceberg/src/spec/manifest_list.rs index ec7e2d8e6..688bdef7d 100644 --- a/crates/iceberg/src/spec/manifest_list.rs +++ b/crates/iceberg/src/spec/manifest_list.rs @@ -675,7 +675,7 @@ pub struct FieldSummary { /// [ManifestFileV1] and [ManifestFileV2] are internal struct that are only used for serialization and deserialization. 
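///
/// A hedged illustration (added for documentation, not part of the change): the
/// `FieldSummary` lower/upper bounds handled by this module are stored in the
/// Iceberg single-value binary form, which `Datum::to_bytes` produces and
/// `Datum::try_from_bytes` reads back; the `long` value below is arbitrary.
///
/// ```rust,ignore
/// use iceberg::spec::{Datum, PrimitiveType};
///
/// let lower = Datum::long(42);
/// let bytes = lower.to_bytes();
/// // A long bound is encoded as its 8 little-endian bytes.
/// assert_eq!(bytes.as_slice(), 42i64.to_le_bytes().as_slice());
/// // Decoding needs the column's primitive type.
/// assert_eq!(Datum::try_from_bytes(&bytes, PrimitiveType::Long).unwrap(), lower);
/// ```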
pub(super) mod _serde { use crate::{ - spec::{Datum, PrimitiveLiteral, PrimitiveType, StructType}, + spec::{Datum, PrimitiveType, StructType}, Error, }; pub use serde_bytes::ByteBuf; @@ -965,8 +965,8 @@ pub(super) mod _serde { .map(|v| FieldSummary { contains_null: v.contains_null, contains_nan: v.contains_nan, - lower_bound: v.lower_bound.map(|v| PrimitiveLiteral::from(v).into()), - upper_bound: v.upper_bound.map(|v| PrimitiveLiteral::from(v).into()), + lower_bound: v.lower_bound.map(|v| v.to_bytes()), + upper_bound: v.upper_bound.map(|v| v.to_bytes()), }) .collect(), ) diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs index 3e188e157..c76701b0c 100644 --- a/crates/iceberg/src/spec/schema.rs +++ b/crates/iceberg/src/spec/schema.rs @@ -317,7 +317,7 @@ impl Schema { /// Returns [`schema_id`]. #[inline] - pub fn schema_id(&self) -> i32 { + pub fn schema_id(&self) -> SchemaId { self.schema_id } @@ -329,8 +329,8 @@ impl Schema { /// Returns [`identifier_field_ids`]. #[inline] - pub fn identifier_field_ids(&self) -> &HashSet { - &self.identifier_field_ids + pub fn identifier_field_ids(&self) -> impl Iterator + '_ { + self.identifier_field_ids.iter().copied() } /// Get field id by full name. @@ -1196,7 +1196,7 @@ mod tests { (schema, record) } - fn table_schema_nested() -> Schema { + pub fn table_schema_nested() -> Schema { Schema::builder() .with_schema_id(1) .with_identifier_field_ids(vec![2]) diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 567a847c0..310bec1d1 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -30,6 +30,9 @@ use bitvec::vec::BitVec; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use ordered_float::OrderedFloat; use rust_decimal::Decimal; +use serde::de::{self, MapAccess}; +use serde::ser::SerializeStruct; +use serde::{Deserialize, Serialize}; use serde_bytes::ByteBuf; use serde_json::{Map as JsonMap, Number, Value as JsonValue}; use uuid::Uuid; @@ -105,6 +108,115 @@ pub struct Datum { literal: PrimitiveLiteral, } +impl Serialize for Datum { + fn serialize( + &self, + serializer: S, + ) -> std::result::Result { + let mut struct_ser = serializer + .serialize_struct("Datum", 2) + .map_err(serde::ser::Error::custom)?; + struct_ser + .serialize_field("type", &self.r#type) + .map_err(serde::ser::Error::custom)?; + struct_ser + .serialize_field( + "literal", + &RawLiteral::try_from( + Literal::Primitive(self.literal.clone()), + &Type::Primitive(self.r#type.clone()), + ) + .map_err(serde::ser::Error::custom)?, + ) + .map_err(serde::ser::Error::custom)?; + struct_ser.end() + } +} + +impl<'de> Deserialize<'de> for Datum { + fn deserialize>( + deserializer: D, + ) -> std::result::Result { + #[derive(Deserialize)] + #[serde(field_identifier, rename_all = "lowercase")] + enum Field { + Type, + Literal, + } + + struct DatumVisitor; + + impl<'de> serde::de::Visitor<'de> for DatumVisitor { + type Value = Datum; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("struct Datum") + } + + fn visit_seq(self, mut seq: A) -> std::result::Result + where + A: serde::de::SeqAccess<'de>, + { + let r#type = seq + .next_element::()? + .ok_or_else(|| serde::de::Error::invalid_length(0, &self))?; + let value = seq + .next_element::()? 
+ .ok_or_else(|| serde::de::Error::invalid_length(1, &self))?; + let Literal::Primitive(primitive) = value + .try_into(&Type::Primitive(r#type.clone())) + .map_err(serde::de::Error::custom)? + .ok_or_else(|| serde::de::Error::custom("None value"))? + else { + return Err(serde::de::Error::custom("Invalid value")); + }; + + Ok(Datum::new(r#type, primitive)) + } + + fn visit_map(self, mut map: V) -> std::result::Result + where + V: MapAccess<'de>, + { + let mut raw_primitive: Option = None; + let mut r#type: Option = None; + while let Some(key) = map.next_key()? { + match key { + Field::Type => { + if r#type.is_some() { + return Err(de::Error::duplicate_field("type")); + } + r#type = Some(map.next_value()?); + } + Field::Literal => { + if raw_primitive.is_some() { + return Err(de::Error::duplicate_field("literal")); + } + raw_primitive = Some(map.next_value()?); + } + } + } + let Some(r#type) = r#type else { + return Err(serde::de::Error::missing_field("type")); + }; + let Some(raw_primitive) = raw_primitive else { + return Err(serde::de::Error::missing_field("literal")); + }; + let Literal::Primitive(primitive) = raw_primitive + .try_into(&Type::Primitive(r#type.clone())) + .map_err(serde::de::Error::custom)? + .ok_or_else(|| serde::de::Error::custom("None value"))? + else { + return Err(serde::de::Error::custom("Invalid value")); + }; + Ok(Datum::new(r#type, primitive)) + } + } + const FIELDS: &[&str] = &["type", "literal"]; + deserializer.deserialize_struct("Datum", FIELDS, DatumVisitor) + } +} + impl PartialOrd for Datum { fn partial_cmp(&self, other: &Self) -> Option { match (&self.literal, &other.literal, &self.r#type, &other.r#type) { @@ -270,7 +382,9 @@ impl Datum { Datum { r#type, literal } } - /// Create iceberg value from bytes + /// Create iceberg value from bytes. + /// + /// See [this spec](https://iceberg.apache.org/spec/#binary-single-value-serialization) for reference. pub fn try_from_bytes(bytes: &[u8], data_type: PrimitiveType) -> Result { let literal = match data_type { PrimitiveType::Boolean => { @@ -312,6 +426,34 @@ impl Datum { Ok(Datum::new(data_type, literal)) } + /// Convert the value to bytes + /// + /// See [this spec](https://iceberg.apache.org/spec/#binary-single-value-serialization) for reference. + pub fn to_bytes(&self) -> ByteBuf { + match &self.literal { + PrimitiveLiteral::Boolean(val) => { + if *val { + ByteBuf::from([1u8]) + } else { + ByteBuf::from([0u8]) + } + } + PrimitiveLiteral::Int(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Long(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Float(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Double(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Date(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Time(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Timestamp(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::Timestamptz(val) => ByteBuf::from(val.to_le_bytes()), + PrimitiveLiteral::String(val) => ByteBuf::from(val.as_bytes()), + PrimitiveLiteral::UUID(val) => ByteBuf::from(val.as_u128().to_be_bytes()), + PrimitiveLiteral::Fixed(val) => ByteBuf::from(val.as_slice()), + PrimitiveLiteral::Binary(val) => ByteBuf::from(val.as_slice()), + PrimitiveLiteral::Decimal(_) => todo!(), + } + } + /// Creates a boolean value. 
/// /// Example: @@ -1336,78 +1478,6 @@ impl Literal { } } -impl From for ByteBuf { - fn from(value: PrimitiveLiteral) -> Self { - match value { - PrimitiveLiteral::Boolean(val) => { - if val { - ByteBuf::from([1u8]) - } else { - ByteBuf::from([0u8]) - } - } - PrimitiveLiteral::Int(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Long(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Float(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Double(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Date(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Time(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Timestamp(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::Timestamptz(val) => ByteBuf::from(val.to_le_bytes()), - PrimitiveLiteral::String(val) => ByteBuf::from(val.as_bytes()), - PrimitiveLiteral::UUID(val) => ByteBuf::from(val.as_u128().to_be_bytes()), - PrimitiveLiteral::Fixed(val) => ByteBuf::from(val), - PrimitiveLiteral::Binary(val) => ByteBuf::from(val), - PrimitiveLiteral::Decimal(_) => todo!(), - } - } -} - -impl From for ByteBuf { - fn from(value: Literal) -> Self { - match value { - Literal::Primitive(val) => val.into(), - _ => unimplemented!(), - } - } -} - -impl From for Vec { - fn from(value: PrimitiveLiteral) -> Self { - match value { - PrimitiveLiteral::Boolean(val) => { - if val { - Vec::from([1u8]) - } else { - Vec::from([0u8]) - } - } - PrimitiveLiteral::Int(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Long(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Float(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Double(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Date(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Time(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Timestamp(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::Timestamptz(val) => Vec::from(val.to_le_bytes()), - PrimitiveLiteral::String(val) => Vec::from(val.as_bytes()), - PrimitiveLiteral::UUID(val) => Vec::from(val.as_u128().to_be_bytes()), - PrimitiveLiteral::Fixed(val) => val, - PrimitiveLiteral::Binary(val) => val, - PrimitiveLiteral::Decimal(_) => todo!(), - } - } -} - -impl From for Vec { - fn from(value: Literal) -> Self { - match value { - Literal::Primitive(val) => val.into(), - _ => unimplemented!(), - } - } -} - /// The partition struct stores the tuple of partition values for each file. /// Its type is derived from the partition fields of the partition spec used to write the manifest file. /// In v2, the partition struct’s field ids must match the ids from the partition spec. @@ -1510,21 +1580,9 @@ impl FromIterator> for Struct { } impl Literal { - /// Create iceberg value from bytes - pub fn try_from_bytes(bytes: &[u8], data_type: &Type) -> Result { - match data_type { - Type::Primitive(primitive_type) => { - let datum = Datum::try_from_bytes(bytes, primitive_type.clone())?; - Ok(Literal::Primitive(datum.literal)) - } - _ => Err(Error::new( - crate::ErrorKind::DataInvalid, - "Converting bytes to non-primitive types is not supported.", - )), - } - } - /// Create iceberg value from a json value + /// + /// See [this spec](https://iceberg.apache.org/spec/#json-single-value-serialization) for reference. 
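///
/// # Example
///
/// A minimal sketch (added for documentation, not taken from the change itself)
/// of parsing a JSON number against a `long` type; a JSON `null` maps to
/// `Ok(None)`:
///
/// ```rust,ignore
/// use iceberg::spec::{Literal, PrimitiveType, Type};
///
/// let parsed = Literal::try_from_json(
///     serde_json::json!(34),
///     &Type::Primitive(PrimitiveType::Long),
/// )
/// .unwrap();
/// assert_eq!(parsed, Some(Literal::long(34)));
/// ```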
pub fn try_from_json(value: JsonValue, data_type: &Type) -> Result> { match data_type { Type::Primitive(primitive) => match (primitive, value) { @@ -2320,10 +2378,17 @@ mod _serde { RawLiteralEnum::Boolean(v) => Ok(Some(Literal::bool(v))), RawLiteralEnum::Int(v) => match ty { Type::Primitive(PrimitiveType::Int) => Ok(Some(Literal::int(v))), + Type::Primitive(PrimitiveType::Long) => Ok(Some(Literal::long(i64::from(v)))), Type::Primitive(PrimitiveType::Date) => Ok(Some(Literal::date(v))), _ => Err(invalid_err("int")), }, RawLiteralEnum::Long(v) => match ty { + Type::Primitive(PrimitiveType::Int) => Ok(Some(Literal::int( + i32::try_from(v).map_err(|_| invalid_err("long"))?, + ))), + Type::Primitive(PrimitiveType::Date) => Ok(Some(Literal::date( + i32::try_from(v).map_err(|_| invalid_err("long"))?, + ))), Type::Primitive(PrimitiveType::Long) => Ok(Some(Literal::long(v))), Type::Primitive(PrimitiveType::Time) => Ok(Some(Literal::time(v))), Type::Primitive(PrimitiveType::Timestamp) => Ok(Some(Literal::timestamp(v))), @@ -2334,9 +2399,23 @@ mod _serde { }, RawLiteralEnum::Float(v) => match ty { Type::Primitive(PrimitiveType::Float) => Ok(Some(Literal::float(v))), + Type::Primitive(PrimitiveType::Double) => { + Ok(Some(Literal::double(f64::from(v)))) + } _ => Err(invalid_err("float")), }, RawLiteralEnum::Double(v) => match ty { + Type::Primitive(PrimitiveType::Float) => { + let v_32 = v as f32; + if v_32.is_finite() { + let v_64 = f64::from(v_32); + if (v_64 - v).abs() > f32::EPSILON as f64 { + // there is a precision loss + return Err(invalid_err("double")); + } + } + Ok(Some(Literal::float(v_32))) + } Type::Primitive(PrimitiveType::Double) => Ok(Some(Literal::double(v))), _ => Err(invalid_err("double")), }, @@ -2418,6 +2497,89 @@ mod _serde { } Ok(Some(Literal::Map(map))) } + Type::Primitive(PrimitiveType::Uuid) => { + if v.list.len() != 16 { + return Err(invalid_err_with_reason( + "list", + "The length of list should be 16", + )); + } + let mut bytes = [0u8; 16]; + for (i, v) in v.list.iter().enumerate() { + if let Some(RawLiteralEnum::Long(v)) = v { + bytes[i] = *v as u8; + } else { + return Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )); + } + } + Ok(Some(Literal::uuid(uuid::Uuid::from_bytes(bytes)))) + } + Type::Primitive(PrimitiveType::Decimal { + precision: _, + scale: _, + }) => { + if v.list.len() != 16 { + return Err(invalid_err_with_reason( + "list", + "The length of list should be 16", + )); + } + let mut bytes = [0u8; 16]; + for (i, v) in v.list.iter().enumerate() { + if let Some(RawLiteralEnum::Long(v)) = v { + bytes[i] = *v as u8; + } else { + return Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )); + } + } + Ok(Some(Literal::decimal(i128::from_be_bytes(bytes)))) + } + Type::Primitive(PrimitiveType::Binary) => { + let bytes = v + .list + .into_iter() + .map(|v| { + if let Some(RawLiteralEnum::Long(v)) = v { + Ok(v as u8) + } else { + Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )) + } + }) + .collect::, Error>>()?; + Ok(Some(Literal::binary(bytes))) + } + Type::Primitive(PrimitiveType::Fixed(size)) => { + if v.list.len() != *size as usize { + return Err(invalid_err_with_reason( + "list", + "The length of list should be equal to size", + )); + } + let bytes = v + .list + .into_iter() + .map(|v| { + if let Some(RawLiteralEnum::Long(v)) = v { + Ok(v as u8) + } else { + Err(invalid_err_with_reason( + "list", + "The element of list should be int", + )) + } + }) + .collect::, 
Error>>()?; + Ok(Some(Literal::fixed(bytes))) + } _ => Err(invalid_err("list")), } } @@ -2498,23 +2660,27 @@ mod tests { assert_eq!(parsed_json_value, raw_json_value); } - fn check_avro_bytes_serde(input: Vec, expected_literal: Literal, expected_type: &Type) { + fn check_avro_bytes_serde( + input: Vec, + expected_datum: Datum, + expected_type: &PrimitiveType, + ) { let raw_schema = r#""bytes""#; let schema = apache_avro::Schema::parse_str(raw_schema).unwrap(); let bytes = ByteBuf::from(input); - let literal = Literal::try_from_bytes(&bytes, expected_type).unwrap(); - assert_eq!(literal, expected_literal); + let datum = Datum::try_from_bytes(&bytes, expected_type.clone()).unwrap(); + assert_eq!(datum, expected_datum); let mut writer = apache_avro::Writer::new(&schema, Vec::new()); - writer.append_ser(ByteBuf::from(literal)).unwrap(); + writer.append_ser(datum.to_bytes()).unwrap(); let encoded = writer.into_inner().unwrap(); let reader = apache_avro::Reader::with_schema(&schema, &*encoded).unwrap(); for record in reader { let result = apache_avro::from_value::(&record.unwrap()).unwrap(); - let desered_literal = Literal::try_from_bytes(&result, expected_type).unwrap(); - assert_eq!(desered_literal, expected_literal); + let desered_datum = Datum::try_from_bytes(&result, expected_type.clone()).unwrap(); + assert_eq!(desered_datum, expected_datum); } } @@ -2783,66 +2949,42 @@ mod tests { fn avro_bytes_boolean() { let bytes = vec![1u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Boolean(true)), - &Type::Primitive(PrimitiveType::Boolean), - ); + check_avro_bytes_serde(bytes, Datum::bool(true), &PrimitiveType::Boolean); } #[test] fn avro_bytes_int() { let bytes = vec![32u8, 0u8, 0u8, 0u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Int(32)), - &Type::Primitive(PrimitiveType::Int), - ); + check_avro_bytes_serde(bytes, Datum::int(32), &PrimitiveType::Int); } #[test] fn avro_bytes_long() { let bytes = vec![32u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Long(32)), - &Type::Primitive(PrimitiveType::Long), - ); + check_avro_bytes_serde(bytes, Datum::long(32), &PrimitiveType::Long); } #[test] fn avro_bytes_float() { let bytes = vec![0u8, 0u8, 128u8, 63u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Float(OrderedFloat(1.0))), - &Type::Primitive(PrimitiveType::Float), - ); + check_avro_bytes_serde(bytes, Datum::float(1.0), &PrimitiveType::Float); } #[test] fn avro_bytes_double() { let bytes = vec![0u8, 0u8, 0u8, 0u8, 0u8, 0u8, 240u8, 63u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::Double(OrderedFloat(1.0))), - &Type::Primitive(PrimitiveType::Double), - ); + check_avro_bytes_serde(bytes, Datum::double(1.0), &PrimitiveType::Double); } #[test] fn avro_bytes_string() { let bytes = vec![105u8, 99u8, 101u8, 98u8, 101u8, 114u8, 103u8]; - check_avro_bytes_serde( - bytes, - Literal::Primitive(PrimitiveLiteral::String("iceberg".to_string())), - &Type::Primitive(PrimitiveType::String), - ); + check_avro_bytes_serde(bytes, Datum::string("iceberg"), &PrimitiveType::String); } #[test] @@ -3180,4 +3322,78 @@ mod tests { "Parse timestamptz with invalid input should fail!" 
); } + + #[test] + fn test_datum_ser_deser() { + let test_fn = |datum: Datum| { + let json = serde_json::to_value(&datum).unwrap(); + let desered_datum: Datum = serde_json::from_value(json).unwrap(); + assert_eq!(datum, desered_datum); + }; + let datum = Datum::int(1); + test_fn(datum); + let datum = Datum::long(1); + test_fn(datum); + + let datum = Datum::float(1.0); + test_fn(datum); + let datum = Datum::float(0_f32); + test_fn(datum); + let datum = Datum::float(-0_f32); + test_fn(datum); + let datum = Datum::float(f32::MAX); + test_fn(datum); + let datum = Datum::float(f32::MIN); + test_fn(datum); + + // serde_json can't serialize f32::INFINITY, f32::NEG_INFINITY, f32::NAN + let datum = Datum::float(f32::INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::float(f32::NEG_INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::float(f32::NAN); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + + let datum = Datum::double(1.0); + test_fn(datum); + let datum = Datum::double(f64::MAX); + test_fn(datum); + let datum = Datum::double(f64::MIN); + test_fn(datum); + + // serde_json can't serialize f32::INFINITY, f32::NEG_INFINITY, f32::NAN + let datum = Datum::double(f64::INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::double(f64::NEG_INFINITY); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + let datum = Datum::double(f64::NAN); + let json = serde_json::to_string(&datum).unwrap(); + assert!(serde_json::from_str::(&json).is_err()); + + let datum = Datum::string("iceberg"); + test_fn(datum); + let datum = Datum::bool(true); + test_fn(datum); + let datum = Datum::date(17486); + test_fn(datum); + let datum = Datum::time_from_hms_micro(22, 15, 33, 111).unwrap(); + test_fn(datum); + let datum = Datum::timestamp_micros(1510871468123456); + test_fn(datum); + let datum = Datum::timestamptz_micros(1510871468123456); + test_fn(datum); + let datum = Datum::uuid(Uuid::parse_str("f79c3e09-677c-4bbd-a479-3f349cb785e7").unwrap()); + test_fn(datum); + let datum = Datum::decimal(1420).unwrap(); + test_fn(datum); + let datum = Datum::binary(vec![1, 2, 3, 4, 5]); + test_fn(datum); + let datum = Datum::fixed(vec![1, 2, 3, 4, 5]); + test_fn(datum); + } } diff --git a/crates/iceberg/src/spec/view_metadata.rs b/crates/iceberg/src/spec/view_metadata.rs index 1b52ed45b..35d96d4b0 100644 --- a/crates/iceberg/src/spec/view_metadata.rs +++ b/crates/iceberg/src/spec/view_metadata.rs @@ -304,7 +304,7 @@ fn is_same_version(a: &ViewVersion, b: &ViewVersion) -> bool { } fn is_same_schema(a: &Schema, b: &Schema) -> bool { - a.as_struct() == b.as_struct() && a.identifier_field_ids() == b.identifier_field_ids() + a.as_struct() == b.as_struct() && a.identifier_field_ids().collect::>() == b.identifier_field_ids().collect::>() } /// Manipulating view metadata. diff --git a/crates/iceberg/src/table.rs b/crates/iceberg/src/table.rs index fd8bd28f2..c76b28612 100644 --- a/crates/iceberg/src/table.rs +++ b/crates/iceberg/src/table.rs @@ -16,6 +16,7 @@ // under the License. //! 
Table API for Apache Iceberg +use crate::arrow::ArrowReaderBuilder; use crate::io::FileIO; use crate::scan::TableScanBuilder; use crate::spec::{TableMetadata, TableMetadataRef}; @@ -70,6 +71,11 @@ impl Table { pub fn readonly(&self) -> bool { self.readonly } + + /// Create a reader for the table. + pub fn reader_builder(&self) -> ArrowReaderBuilder { + ArrowReaderBuilder::new(self.file_io.clone()) + } } /// `StaticTable` is a read-only table struct that can be created from a metadata file or from `TableMetaData` without a catalog. @@ -138,6 +144,11 @@ impl StaticTable { pub fn into_table(self) -> Table { self.0 } + + /// Create a reader for the table. + pub fn reader_builder(&self) -> ArrowReaderBuilder { + ArrowReaderBuilder::new(self.0.file_io.clone()) + } } #[cfg(test)] diff --git a/crates/iceberg/src/writer/base_writer/data_file_writer.rs b/crates/iceberg/src/writer/base_writer/data_file_writer.rs index 442c9f164..638a90584 100644 --- a/crates/iceberg/src/writer/base_writer/data_file_writer.rs +++ b/crates/iceberg/src/writer/base_writer/data_file_writer.rs @@ -108,10 +108,13 @@ impl CurrentFileStatus for DataFileWriter { #[cfg(test)] mod test { - use std::{collections::HashMap, sync::Arc}; + use std::sync::Arc; - use arrow_array::{types::Int64Type, ArrayRef, Int64Array, RecordBatch, StructArray}; - use parquet::{arrow::PARQUET_FIELD_ID_META_KEY, file::properties::WriterProperties}; + use crate::{ + spec::{DataContentType, Schema, Struct}, + Result, + }; + use parquet::file::properties::WriterProperties; use tempfile::TempDir; use crate::{ @@ -123,13 +126,12 @@ mod test { location_generator::{test::MockLocationGenerator, DefaultFileNameGenerator}, ParquetWriterBuilder, }, - tests::check_parquet_data_file, IcebergWriter, IcebergWriterBuilder, }, }; #[tokio::test] - async fn test_data_file_writer() -> Result<(), anyhow::Error> { + async fn test_parquet_writer() -> Result<()> { let temp_dir = TempDir::new().unwrap(); let file_io = FileIOBuilder::new_fs_io().build().unwrap(); let location_gen = @@ -137,181 +139,22 @@ mod test { let file_name_gen = DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); - // prepare data - // Int, Struct(Int), String, List(Int), Struct(Struct(Int)) - let schema = { - let fields = vec![ - arrow_schema::Field::new("col0", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "0".to_string(), - )])), - arrow_schema::Field::new( - "col1", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_col", - arrow_schema::DataType::Int64, - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "5".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "1".to_string(), - )])), - arrow_schema::Field::new("col2", arrow_schema::DataType::Utf8, true).with_metadata( - HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "2".to_string())]), - ), - arrow_schema::Field::new( - "col3", - arrow_schema::DataType::List(Arc::new( - arrow_schema::Field::new("item", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "6".to_string(), - )])), - )), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "3".to_string(), - )])), - arrow_schema::Field::new( - "col4", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_col", - arrow_schema::DataType::Struct( - 
vec![arrow_schema::Field::new( - "sub_sub_col", - arrow_schema::DataType::Int64, - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "7".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "8".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "4".to_string(), - )])), - ]; - Arc::new(arrow_schema::Schema::new(fields)) - }; - let col0 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; - let col1 = Arc::new(StructArray::new( - vec![ - arrow_schema::Field::new("sub_col", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "5".to_string(), - )])), - ] - .into(), - vec![Arc::new(Int64Array::from_iter_values(vec![1; 1024]))], - None, - )); - let col2 = Arc::new(arrow_array::StringArray::from_iter_values(vec![ - "test"; - 1024 - ])) as ArrayRef; - let col3 = Arc::new({ - let list_parts = arrow_array::ListArray::from_iter_primitive::(vec![ - Some( - vec![Some(1),] - ); - 1024 - ]) - .into_parts(); - arrow_array::ListArray::new( - Arc::new(list_parts.0.as_ref().clone().with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "6".to_string(), - )]))), - list_parts.1, - list_parts.2, - list_parts.3, - ) - }) as ArrayRef; - let col4 = Arc::new(StructArray::new( - vec![arrow_schema::Field::new( - "sub_col", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_sub_col", - arrow_schema::DataType::Int64, - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "7".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "8".to_string(), - )]))] - .into(), - vec![Arc::new(StructArray::new( - vec![ - arrow_schema::Field::new("sub_sub_col", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "7".to_string(), - )])), - ] - .into(), - vec![Arc::new(Int64Array::from_iter_values(vec![1; 1024]))], - None, - ))], - None, - )); - let to_write = - RecordBatch::try_new(schema.clone(), vec![col0, col1, col2, col3, col4]).unwrap(); - - // prepare writer - let pb = ParquetWriterBuilder::new( + let pw = ParquetWriterBuilder::new( WriterProperties::builder().build(), - to_write.schema(), + Arc::new(Schema::builder().build().unwrap()), file_io.clone(), location_gen, file_name_gen, ); - let mut data_file_writer = DataFileWriterBuilder::new(pb) + let mut data_file_writer = DataFileWriterBuilder::new(pw) .build(DataFileWriterConfig::new(None)) .await?; - // write - data_file_writer.write(to_write.clone()).await?; - let res = data_file_writer.close().await?; - assert_eq!(res.len(), 1); - let data_file = res.into_iter().next().unwrap(); - - // check - check_parquet_data_file(&file_io, &data_file, &to_write).await; + let data_file = data_file_writer.close().await.unwrap(); + assert_eq!(data_file.len(), 1); + assert_eq!(data_file[0].file_format, DataFileFormat::Parquet); + assert_eq!(data_file[0].content, DataContentType::Data); + assert_eq!(data_file[0].partition, Struct::empty()); Ok(()) } diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs index a67d308af..5f50e417d 100644 --- a/crates/iceberg/src/writer/file_writer/parquet_writer.rs +++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs 
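A hedged usage sketch of the reworked writer below, added for illustration and not part of the diff: `ParquetWriterBuilder::new` now takes an Iceberg `Schema` instead of an Arrow schema, and the field ids, value counts, and lower/upper bounds of the produced `DataFile` are derived from that schema when the writer is closed. The helper name `example_parquet_writer_builder`, the single-column schema, and the "data" file prefix are assumptions for the sketch; any `LocationGenerator` implementation can be supplied (the unit tests below use a mock one).

use std::sync::Arc;

use parquet::file::properties::WriterProperties;

use crate::io::FileIOBuilder;
use crate::spec::{DataFileFormat, NestedField, PrimitiveType, Schema, Type};
use crate::writer::file_writer::location_generator::{DefaultFileNameGenerator, LocationGenerator};
use crate::writer::file_writer::ParquetWriterBuilder;
use crate::Result;

/// Hypothetical helper: assemble a builder for a single `long` column.
fn example_parquet_writer_builder<T: LocationGenerator>(
    location_gen: T,
) -> Result<ParquetWriterBuilder<T, DefaultFileNameGenerator>> {
    let schema = Arc::new(
        Schema::builder()
            .with_fields(vec![NestedField::required(
                0,
                "col",
                Type::Primitive(PrimitiveType::Long),
            )
            .into()])
            .build()?,
    );
    let file_io = FileIOBuilder::new_fs_io().build()?;
    let file_name_gen =
        DefaultFileNameGenerator::new("data".to_string(), None, DataFileFormat::Parquet);
    // The builder is later turned into a `ParquetWriter` via `FileWriterBuilder::build().await`.
    Ok(ParquetWriterBuilder::new(
        WriterProperties::builder().build(),
        schema,
        file_io,
        location_gen,
        file_name_gen,
    ))
}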
@@ -17,13 +17,17 @@ //! The module contains the file writer for parquet file format. -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{ - collections::HashMap, - sync::{atomic::AtomicI64, Arc}, +use super::{ + location_generator::{FileNameGenerator, LocationGenerator}, + track_writer::TrackWriter, + FileWriter, FileWriterBuilder, }; - +use crate::arrow::DEFAULT_MAP_FIELD_NAME; +use crate::spec::{ + visit_schema, Datum, ListType, MapType, NestedFieldRef, PrimitiveLiteral, PrimitiveType, + Schema, SchemaRef, SchemaVisitor, StructType, Type, +}; +use crate::ErrorKind; use crate::{io::FileIO, io::FileWrite, Result}; use crate::{ io::OutputFile, @@ -34,20 +38,31 @@ use crate::{ use arrow_schema::SchemaRef as ArrowSchemaRef; use bytes::Bytes; use futures::future::BoxFuture; -use parquet::{arrow::AsyncArrowWriter, format::FileMetaData}; -use parquet::{arrow::PARQUET_FIELD_ID_META_KEY, file::properties::WriterProperties}; - -use super::{ - location_generator::{FileNameGenerator, LocationGenerator}, - track_writer::TrackWriter, - FileWriter, FileWriterBuilder, +use itertools::Itertools; +use parquet::data_type::{ + BoolType, ByteArrayType, DataType as ParquetDataType, DoubleType, FixedLenByteArrayType, + FloatType, Int32Type, Int64Type, +}; +use parquet::file::properties::WriterProperties; +use parquet::file::statistics::TypedStatistics; +use parquet::{ + arrow::async_writer::AsyncFileWriter as ArrowAsyncFileWriter, arrow::AsyncArrowWriter, + format::FileMetaData, +}; +use parquet::{ + data_type::{ByteArray, FixedLenByteArray}, + file::statistics::{from_thrift, Statistics}, }; +use std::collections::HashMap; +use std::sync::atomic::AtomicI64; +use std::sync::Arc; +use uuid::Uuid; /// ParquetWriterBuilder is used to builder a [`ParquetWriter`] #[derive(Clone)] pub struct ParquetWriterBuilder { props: WriterProperties, - schema: ArrowSchemaRef, + schema: SchemaRef, file_io: FileIO, location_generator: T, @@ -59,7 +74,7 @@ impl ParquetWriterBuilder { /// To construct the write result, the schema should contain the `PARQUET_FIELD_ID_META_KEY` metadata for each field. pub fn new( props: WriterProperties, - schema: ArrowSchemaRef, + schema: SchemaRef, file_io: FileIO, location_generator: T, file_name_generator: F, @@ -78,29 +93,7 @@ impl FileWriterBuilder for ParquetWr type R = ParquetWriter; async fn build(self) -> crate::Result { - // Fetch field id from schema - let field_ids = self - .schema - .fields() - .iter() - .map(|field| { - field - .metadata() - .get(PARQUET_FIELD_ID_META_KEY) - .ok_or_else(|| { - Error::new( - crate::ErrorKind::Unexpected, - "Field id not found in arrow schema metadata.", - ) - })? 
- .parse::() - .map_err(|err| { - Error::new(crate::ErrorKind::Unexpected, "Failed to parse field id.") - .with_source(err) - }) - }) - .collect::>>()?; - + let arrow_schema: ArrowSchemaRef = Arc::new(self.schema.as_ref().try_into()?); let written_size = Arc::new(AtomicI64::new(0)); let out_file = self.file_io.new_output( self.location_generator @@ -108,76 +101,400 @@ impl FileWriterBuilder for ParquetWr )?; let inner_writer = TrackWriter::new(out_file.writer().await?, written_size.clone()); let async_writer = AsyncFileWriter::new(inner_writer); - let writer = AsyncArrowWriter::try_new(async_writer, self.schema.clone(), Some(self.props)) - .map_err(|err| { - Error::new( - crate::ErrorKind::Unexpected, - "Failed to build parquet writer.", - ) - .with_source(err) - })?; + let writer = + AsyncArrowWriter::try_new(async_writer, arrow_schema.clone(), Some(self.props)) + .map_err(|err| { + Error::new(ErrorKind::Unexpected, "Failed to build parquet writer.") + .with_source(err) + })?; Ok(ParquetWriter { + schema: self.schema.clone(), writer, written_size, current_row_num: 0, out_file, - field_ids, }) } } +struct IndexByParquetPathName { + name_to_id: HashMap, + + field_names: Vec, + + field_id: i32, +} + +impl IndexByParquetPathName { + pub fn new() -> Self { + Self { + name_to_id: HashMap::new(), + field_names: Vec::new(), + field_id: 0, + } + } + + pub fn get(&self, name: &str) -> Option<&i32> { + self.name_to_id.get(name) + } +} + +impl SchemaVisitor for IndexByParquetPathName { + type T = (); + + fn before_struct_field(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names.push(field.name.to_string()); + self.field_id = field.id; + Ok(()) + } + + fn after_struct_field(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn before_list_element(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names.push(format!("list.{}", field.name)); + self.field_id = field.id; + Ok(()) + } + + fn after_list_element(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn before_map_key(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names + .push(format!("{DEFAULT_MAP_FIELD_NAME}.key")); + self.field_id = field.id; + Ok(()) + } + + fn after_map_key(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn before_map_value(&mut self, field: &NestedFieldRef) -> Result<()> { + self.field_names + .push(format!("{DEFAULT_MAP_FIELD_NAME}.value")); + self.field_id = field.id; + Ok(()) + } + + fn after_map_value(&mut self, _field: &NestedFieldRef) -> Result<()> { + self.field_names.pop(); + Ok(()) + } + + fn schema(&mut self, _schema: &Schema, _value: Self::T) -> Result { + Ok(()) + } + + fn field(&mut self, _field: &NestedFieldRef, _value: Self::T) -> Result { + Ok(()) + } + + fn r#struct(&mut self, _struct: &StructType, _results: Vec) -> Result { + Ok(()) + } + + fn list(&mut self, _list: &ListType, _value: Self::T) -> Result { + Ok(()) + } + + fn map(&mut self, _map: &MapType, _key_value: Self::T, _value: Self::T) -> Result { + Ok(()) + } + + fn primitive(&mut self, _p: &PrimitiveType) -> Result { + let full_name = self.field_names.iter().map(String::as_str).join("."); + let field_id = self.field_id; + if let Some(existing_field_id) = self.name_to_id.get(full_name.as_str()) { + return Err(Error::new(ErrorKind::DataInvalid, format!("Invalid schema: multiple fields for name {full_name}: {field_id} and {existing_field_id}"))); + } else { + 
self.name_to_id.insert(full_name, field_id); + } + + Ok(()) + } +} + /// `ParquetWriter`` is used to write arrow data into parquet file on storage. pub struct ParquetWriter { + schema: SchemaRef, out_file: OutputFile, writer: AsyncArrowWriter>, written_size: Arc, current_row_num: usize, - field_ids: Vec, +} + +/// Used to aggregate min and max value of each column. +struct MinMaxColAggregator { + lower_bounds: HashMap, + upper_bounds: HashMap, + schema: SchemaRef, +} + +impl MinMaxColAggregator { + fn new(schema: SchemaRef) -> Self { + Self { + lower_bounds: HashMap::new(), + upper_bounds: HashMap::new(), + schema, + } + } + + fn update_state( + &mut self, + field_id: i32, + state: &TypedStatistics, + convert_func: impl Fn(::T) -> Result, + ) { + if state.min_is_exact() { + let val = convert_func(state.min().clone()).unwrap(); + self.lower_bounds + .entry(field_id) + .and_modify(|e| { + if *e > val { + *e = val.clone() + } + }) + .or_insert(val); + } + if state.max_is_exact() { + let val = convert_func(state.max().clone()).unwrap(); + self.upper_bounds + .entry(field_id) + .and_modify(|e| { + if *e < val { + *e = val.clone() + } + }) + .or_insert(val); + } + } + + fn update(&mut self, field_id: i32, value: Statistics) -> Result<()> { + let Some(ty) = self + .schema + .field_by_id(field_id) + .map(|f| f.field_type.as_ref()) + else { + // Following java implementation: https://github.com/apache/iceberg/blob/29a2c456353a6120b8c882ed2ab544975b168d7b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L163 + // Ignore the field if it is not in schema. + return Ok(()); + }; + let Type::Primitive(ty) = ty.clone() else { + return Err(Error::new( + ErrorKind::Unexpected, + format!( + "Composed type {} is not supported for min max aggregation.", + ty + ), + )); + }; + + match (&ty, value) { + (PrimitiveType::Boolean, Statistics::Boolean(stat)) => { + let convert_func = |v: bool| Result::::Ok(Datum::bool(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Int, Statistics::Int32(stat)) => { + let convert_func = |v: i32| Result::::Ok(Datum::int(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Long, Statistics::Int64(stat)) => { + let convert_func = |v: i64| Result::::Ok(Datum::long(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Float, Statistics::Float(stat)) => { + let convert_func = |v: f32| Result::::Ok(Datum::float(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Double, Statistics::Double(stat)) => { + let convert_func = |v: f64| Result::::Ok(Datum::double(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::String, Statistics::ByteArray(stat)) => { + let convert_func = |v: ByteArray| { + Result::::Ok(Datum::string( + String::from_utf8(v.data().to_vec()).unwrap(), + )) + }; + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Binary, Statistics::ByteArray(stat)) => { + let convert_func = + |v: ByteArray| Result::::Ok(Datum::binary(v.data().to_vec())); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Date, Statistics::Int32(stat)) => { + let convert_func = |v: i32| Result::::Ok(Datum::date(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Time, Statistics::Int64(stat)) => { + let convert_func = |v: i64| Datum::time_micros(v); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Timestamp, Statistics::Int64(stat)) => { + let convert_func = 
|v: i64| Result::::Ok(Datum::timestamp_micros(v)); + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Timestamptz, Statistics::Int64(stat)) => { + let convert_func = |v: i64| Result::::Ok(Datum::timestamptz_micros(v)); + self.update_state::(field_id, &stat, convert_func) + } + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::ByteArray(stat), + ) => { + let convert_func = |v: ByteArray| -> Result { + Result::::Ok(Datum::new( + ty.clone(), + PrimitiveLiteral::Decimal(i128::from_le_bytes( + v.data().try_into().unwrap(), + )), + )) + }; + self.update_state::(field_id, &stat, convert_func) + } + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::Int32(stat), + ) => { + let convert_func = |v: i32| { + Result::::Ok(Datum::new( + ty.clone(), + PrimitiveLiteral::Decimal(i128::from(v)), + )) + }; + self.update_state::(field_id, &stat, convert_func) + } + ( + PrimitiveType::Decimal { + precision: _, + scale: _, + }, + Statistics::Int64(stat), + ) => { + let convert_func = |v: i64| { + Result::::Ok(Datum::new( + ty.clone(), + PrimitiveLiteral::Decimal(i128::from(v)), + )) + }; + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Uuid, Statistics::FixedLenByteArray(stat)) => { + let convert_func = |v: FixedLenByteArray| { + if v.len() != 16 { + return Err(Error::new( + ErrorKind::Unexpected, + "Invalid length of uuid bytes.", + )); + } + Ok(Datum::uuid(Uuid::from_bytes( + v.data()[..16].try_into().unwrap(), + ))) + }; + self.update_state::(field_id, &stat, convert_func) + } + (PrimitiveType::Fixed(len), Statistics::FixedLenByteArray(stat)) => { + let convert_func = |v: FixedLenByteArray| { + if v.len() != *len as usize { + return Err(Error::new( + ErrorKind::Unexpected, + "Invalid length of fixed bytes.", + )); + } + Ok(Datum::fixed(v.data().to_vec())) + }; + self.update_state::(field_id, &stat, convert_func) + } + (ty, value) => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Statistics {} is not match with field type {}.", value, ty), + )) + } + } + Ok(()) + } + + fn produce(self) -> (HashMap, HashMap) { + (self.lower_bounds, self.upper_bounds) + } } impl ParquetWriter { fn to_data_file_builder( - field_ids: &[i32], + schema: SchemaRef, metadata: FileMetaData, written_size: usize, file_path: String, ) -> Result { - // Only enter here when the file is not empty. 
- assert!(!metadata.row_groups.is_empty()); - if field_ids.len() != metadata.row_groups[0].columns.len() { - return Err(Error::new( - crate::ErrorKind::Unexpected, - "Len of field id is not match with len of columns in parquet metadata.", - )); - } + let index_by_parquet_path = { + let mut visitor = IndexByParquetPathName::new(); + visit_schema(&schema, &mut visitor)?; + visitor + }; - let (column_sizes, value_counts, null_value_counts) = - { - let mut per_col_size: HashMap = HashMap::new(); - let mut per_col_val_num: HashMap = HashMap::new(); - let mut per_col_null_val_num: HashMap = HashMap::new(); - metadata.row_groups.iter().for_each(|group| { - group.columns.iter().zip(field_ids.iter()).for_each( - |(column_chunk, &field_id)| { - if let Some(column_chunk_metadata) = &column_chunk.meta_data { - *per_col_size.entry(field_id).or_insert(0) += - column_chunk_metadata.total_compressed_size as u64; - *per_col_val_num.entry(field_id).or_insert(0) += - column_chunk_metadata.num_values as u64; - *per_col_null_val_num.entry(field_id).or_insert(0_u64) += - column_chunk_metadata - .statistics - .as_ref() - .map(|s| s.null_count) - .unwrap_or(None) - .unwrap_or(0) as u64; - } - }, - ) - }); - (per_col_size, per_col_val_num, per_col_null_val_num) - }; + let (column_sizes, value_counts, null_value_counts, (lower_bounds, upper_bounds)) = { + let mut per_col_size: HashMap = HashMap::new(); + let mut per_col_val_num: HashMap = HashMap::new(); + let mut per_col_null_val_num: HashMap = HashMap::new(); + let mut min_max_agg = MinMaxColAggregator::new(schema); + + for row_group in &metadata.row_groups { + for column_chunk in row_group.columns.iter() { + let Some(column_chunk_metadata) = &column_chunk.meta_data else { + continue; + }; + let physical_type = column_chunk_metadata.type_; + let Some(&field_id) = + index_by_parquet_path.get(&column_chunk_metadata.path_in_schema.join(".")) + else { + // Following java implementation: https://github.com/apache/iceberg/blob/29a2c456353a6120b8c882ed2ab544975b168d7b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java#L163 + // Ignore the field if it is not in schema. + continue; + }; + *per_col_size.entry(field_id).or_insert(0) += + column_chunk_metadata.total_compressed_size as u64; + *per_col_val_num.entry(field_id).or_insert(0) += + column_chunk_metadata.num_values as u64; + if let Some(null_count) = column_chunk_metadata + .statistics + .as_ref() + .and_then(|s| s.null_count) + { + *per_col_null_val_num.entry(field_id).or_insert(0_u64) += null_count as u64; + } + if let Some(statistics) = &column_chunk_metadata.statistics { + min_max_agg.update( + field_id, + from_thrift(physical_type.try_into()?, Some(statistics.clone()))? 
+ .unwrap(), + )?; + } + } + } + + ( + per_col_size, + per_col_val_num, + per_col_null_val_num, + min_max_agg.produce(), + ) + }; let mut builder = DataFileBuilder::default(); builder @@ -188,10 +505,11 @@ impl ParquetWriter { .column_sizes(column_sizes) .value_counts(value_counts) .null_value_counts(null_value_counts) - // # TODO + .lower_bounds(lower_bounds) + .upper_bounds(upper_bounds) + // # TODO(#417) // - nan_value_counts - // - lower_bounds - // - upper_bounds + // - distinct_counts .key_metadata(metadata.footer_signing_key_metadata.unwrap_or_default()) .split_offsets( metadata @@ -209,7 +527,7 @@ impl FileWriter for ParquetWriter { self.current_row_num += batch.num_rows(); self.writer.write(batch).await.map_err(|err| { Error::new( - crate::ErrorKind::Unexpected, + ErrorKind::Unexpected, "Failed to write using parquet writer.", ) .with_source(err) @@ -219,17 +537,13 @@ impl FileWriter for ParquetWriter { async fn close(self) -> crate::Result> { let metadata = self.writer.close().await.map_err(|err| { - Error::new( - crate::ErrorKind::Unexpected, - "Failed to close parquet writer.", - ) - .with_source(err) + Error::new(ErrorKind::Unexpected, "Failed to close parquet writer.").with_source(err) })?; let written_size = self.written_size.load(std::sync::atomic::Ordering::Relaxed); Ok(vec![Self::to_data_file_builder( - &self.field_ids, + self.schema, metadata, written_size as usize, self.out_file.location().to_string(), @@ -256,123 +570,187 @@ impl CurrentFileStatus for ParquetWriter { /// # NOTES /// /// We keep this wrapper been used inside only. -/// -/// # TODO -/// -/// Maybe we can use the buffer from ArrowWriter directly. -struct AsyncFileWriter(State); - -enum State { - Idle(Option), - Write(BoxFuture<'static, (W, Result<()>)>), - Close(BoxFuture<'static, (W, Result<()>)>), -} +struct AsyncFileWriter(W); impl AsyncFileWriter { /// Create a new `AsyncFileWriter` with the given writer. 
pub fn new(writer: W) -> Self { - Self(State::Idle(Some(writer))) + Self(writer) } } -impl tokio::io::AsyncWrite for AsyncFileWriter { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.get_mut(); - loop { - match &mut this.0 { - State::Idle(w) => { - let mut writer = w.take().unwrap(); - let bs = Bytes::copy_from_slice(buf); - let fut = async move { - let res = writer.write(bs).await; - (writer, res) - }; - this.0 = State::Write(Box::pin(fut)); - } - State::Write(fut) => { - let (writer, res) = futures::ready!(fut.as_mut().poll(cx)); - this.0 = State::Idle(Some(writer)); - return Poll::Ready(res.map(|_| buf.len()).map_err(|err| { - std::io::Error::new(std::io::ErrorKind::Other, Box::new(err)) - })); - } - State::Close(_) => { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "file is closed", - ))); - } - } - } +impl ArrowAsyncFileWriter for AsyncFileWriter { + fn write(&mut self, bs: Bytes) -> BoxFuture<'_, parquet::errors::Result<()>> { + Box::pin(async { + self.0 + .write(bs) + .await + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))) + }) } - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) - } - - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let this = self.get_mut(); - loop { - match &mut this.0 { - State::Idle(w) => { - let mut writer = w.take().unwrap(); - let fut = async move { - let res = writer.close().await; - (writer, res) - }; - this.0 = State::Close(Box::pin(fut)); - } - State::Write(_) => { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "file is writing", - ))); - } - State::Close(fut) => { - let (writer, res) = futures::ready!(fut.as_mut().poll(cx)); - this.0 = State::Idle(Some(writer)); - return Poll::Ready(res.map_err(|err| { - std::io::Error::new(std::io::ErrorKind::Other, Box::new(err)) - })); - } - } - } + fn complete(&mut self) -> BoxFuture<'_, parquet::errors::Result<()>> { + Box::pin(async { + self.0 + .close() + .await + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))) + }) } } #[cfg(test)] mod tests { + use std::collections::HashMap; use std::sync::Arc; use anyhow::Result; use arrow_array::types::Int64Type; use arrow_array::ArrayRef; + use arrow_array::BooleanArray; + use arrow_array::Int32Array; use arrow_array::Int64Array; + use arrow_array::ListArray; use arrow_array::RecordBatch; use arrow_array::StructArray; + use arrow_schema::DataType; + use arrow_schema::SchemaRef as ArrowSchemaRef; use arrow_select::concat::concat_batches; use parquet::arrow::PARQUET_FIELD_ID_META_KEY; use tempfile::TempDir; use super::*; use crate::io::FileIOBuilder; - use crate::spec::Struct; + use crate::spec::*; + use crate::spec::{PrimitiveLiteral, Struct}; use crate::writer::file_writer::location_generator::test::MockLocationGenerator; use crate::writer::file_writer::location_generator::DefaultFileNameGenerator; use crate::writer::tests::check_parquet_data_file; - #[derive(Clone)] - struct TestLocationGen; + fn schema_for_all_type() -> Schema { + Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::optional(0, "boolean", Type::Primitive(PrimitiveType::Boolean)).into(), + NestedField::optional(1, "int", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(2, "long", Type::Primitive(PrimitiveType::Long)).into(), + NestedField::optional(3, "float", Type::Primitive(PrimitiveType::Float)).into(), + 
NestedField::optional(4, "double", Type::Primitive(PrimitiveType::Double)).into(), + NestedField::optional(5, "string", Type::Primitive(PrimitiveType::String)).into(), + NestedField::optional(6, "binary", Type::Primitive(PrimitiveType::Binary)).into(), + NestedField::optional(7, "date", Type::Primitive(PrimitiveType::Date)).into(), + NestedField::optional(8, "time", Type::Primitive(PrimitiveType::Time)).into(), + NestedField::optional(9, "timestamp", Type::Primitive(PrimitiveType::Timestamp)) + .into(), + NestedField::optional( + 10, + "timestamptz", + Type::Primitive(PrimitiveType::Timestamptz), + ) + .into(), + NestedField::optional( + 11, + "decimal", + Type::Primitive(PrimitiveType::Decimal { + precision: 10, + scale: 5, + }), + ) + .into(), + NestedField::optional(12, "uuid", Type::Primitive(PrimitiveType::Uuid)).into(), + NestedField::optional(13, "fixed", Type::Primitive(PrimitiveType::Fixed(10))) + .into(), + ]) + .build() + .unwrap() + } + + fn nested_schema_for_test() -> Schema { + // Int, Struct(Int,Int), String, List(Int), Struct(Struct(Int)), Map(String, List(Int)) + Schema::builder() + .with_schema_id(1) + .with_fields(vec![ + NestedField::required(0, "col0", Type::Primitive(PrimitiveType::Long)).into(), + NestedField::required( + 1, + "col1", + Type::Struct(StructType::new(vec![ + NestedField::required(5, "col_1_5", Type::Primitive(PrimitiveType::Long)) + .into(), + NestedField::required(6, "col_1_6", Type::Primitive(PrimitiveType::Long)) + .into(), + ])), + ) + .into(), + NestedField::required(2, "col2", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required( + 3, + "col3", + Type::List(ListType::new( + NestedField::required(7, "element", Type::Primitive(PrimitiveType::Long)) + .into(), + )), + ) + .into(), + NestedField::required( + 4, + "col4", + Type::Struct(StructType::new(vec![NestedField::required( + 8, + "col_4_8", + Type::Struct(StructType::new(vec![NestedField::required( + 9, + "col_4_8_9", + Type::Primitive(PrimitiveType::Long), + ) + .into()])), + ) + .into()])), + ) + .into(), + NestedField::required( + 10, + "col5", + Type::Map(MapType::new( + NestedField::required(11, "key", Type::Primitive(PrimitiveType::String)) + .into(), + NestedField::required( + 12, + "value", + Type::List(ListType::new( + NestedField::required( + 13, + "item", + Type::Primitive(PrimitiveType::Long), + ) + .into(), + )), + ) + .into(), + )), + ) + .into(), + ]) + .build() + .unwrap() + } + + #[tokio::test] + async fn test_index_by_parquet_path() { + let expect = HashMap::from([ + ("col0".to_string(), 0), + ("col1.col_1_5".to_string(), 5), + ("col1.col_1_6".to_string(), 6), + ("col2".to_string(), 2), + ("col3.list.element".to_string(), 7), + ("col4.col_4_8.col_4_8_9".to_string(), 9), + ("col5.key_value.key".to_string(), 11), + ("col5.key_value.value.list.item".to_string(), 13), + ]); + let mut visitor = IndexByParquetPathName::new(); + visit_schema(&nested_schema_for_test(), &mut visitor).unwrap(); + assert_eq!(visitor.name_to_id, expect); + } #[tokio::test] async fn test_parquet_writer() -> Result<()> { @@ -392,7 +770,7 @@ mod tests { ]; Arc::new(arrow_schema::Schema::new(fields)) }; - let col = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; + let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef; let null_col = Arc::new(Int64Array::new_null(1024)) as ArrayRef; let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap(); let to_write_null = RecordBatch::try_new(schema.clone(), vec![null_col]).unwrap(); @@ -400,7 +778,7 @@ 
mod tests { // write data let mut pw = ParquetWriterBuilder::new( WriterProperties::builder().build(), - to_write.schema(), + Arc::new(to_write.schema().as_ref().try_into().unwrap()), file_io.clone(), loccation_gen, file_name_gen, @@ -421,6 +799,19 @@ mod tests { .build() .unwrap(); + // check data file + assert_eq!(data_file.record_count(), 2048); + assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)])); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([(0, Datum::long(0))]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([(0, Datum::long(1023))]) + ); + assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)])); + // check the written file let expect_batch = concat_batches(&schema, vec![&to_write, &to_write_null]).unwrap(); check_parquet_data_file(&file_io, &data_file, &expect_batch).await; @@ -438,164 +829,153 @@ mod tests { DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); // prepare data - // Int, Struct(Int), String, List(Int), Struct(Struct(Int)) - let schema = { - let fields = vec![ - arrow_schema::Field::new("col0", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "0".to_string(), - )])), - arrow_schema::Field::new( - "col1", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_col", - arrow_schema::DataType::Int64, - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "5".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "1".to_string(), - )])), - arrow_schema::Field::new("col2", arrow_schema::DataType::Utf8, true).with_metadata( - HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "2".to_string())]), - ), - arrow_schema::Field::new( - "col3", - arrow_schema::DataType::List(Arc::new( - arrow_schema::Field::new("item", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "6".to_string(), - )])), - )), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "3".to_string(), - )])), - arrow_schema::Field::new( - "col4", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_col", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_sub_col", - arrow_schema::DataType::Int64, - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "7".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "8".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "4".to_string(), - )])), - ]; - Arc::new(arrow_schema::Schema::new(fields)) - }; - let col0 = Arc::new(Int64Array::from_iter_values(vec![1; 1024])) as ArrayRef; + let schema = nested_schema_for_test(); + let arrow_schema: ArrowSchemaRef = Arc::new((&schema).try_into().unwrap()); + let col0 = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef; let col1 = Arc::new(StructArray::new( + { + if let DataType::Struct(fields) = arrow_schema.field(1).data_type() { + fields.clone() + } else { + unreachable!() + } + }, vec![ - arrow_schema::Field::new("sub_col", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "5".to_string(), - )])), - ] - .into(), - vec![Arc::new(Int64Array::from_iter_values(vec![1; 
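These hunks also stop hand-assembling the Arrow schema with PARQUET_FIELD_ID_META_KEY metadata and instead convert between the Iceberg and Arrow schemas with try_into, in both directions. A sketch of the two conversions, assuming the TryFrom impls exercised above are usable from outside the crate:

use std::sync::Arc;

use arrow_schema::SchemaRef as ArrowSchemaRef;
use iceberg::spec::Schema;

// Iceberg -> Arrow: field-id metadata is carried along, so the writer and the
// statistics collector agree on ids without any hand-written metadata maps.
fn to_arrow(schema: &Schema) -> ArrowSchemaRef {
    Arc::new(schema.try_into().unwrap())
}

// Arrow -> Iceberg: the direction ParquetWriterBuilder::new now receives in the
// simple test above.
fn to_iceberg(arrow: &arrow_schema::Schema) -> Schema {
    arrow.try_into().unwrap()
}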
1024]))], + Arc::new(Int64Array::from_iter_values(0..1024)), + Arc::new(Int64Array::from_iter_values(0..1024)), + ], None, )); - let col2 = Arc::new(arrow_array::StringArray::from_iter_values(vec![ - "test"; - 1024 - ])) as ArrayRef; + let col2 = Arc::new(arrow_array::StringArray::from_iter_values( + (0..1024).map(|n| n.to_string()), + )) as ArrayRef; let col3 = Arc::new({ - let list_parts = arrow_array::ListArray::from_iter_primitive::(vec![ - Some( - vec![Some(1),] - ); - 1024 - ]) + let list_parts = arrow_array::ListArray::from_iter_primitive::( + (0..1024).map(|n| Some(vec![Some(n)])), + ) .into_parts(); arrow_array::ListArray::new( - Arc::new(list_parts.0.as_ref().clone().with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "6".to_string(), - )]))), + { + if let DataType::List(field) = arrow_schema.field(3).data_type() { + field.clone() + } else { + unreachable!() + } + }, list_parts.1, list_parts.2, list_parts.3, ) }) as ArrayRef; let col4 = Arc::new(StructArray::new( - vec![arrow_schema::Field::new( - "sub_col", - arrow_schema::DataType::Struct( - vec![arrow_schema::Field::new( - "sub_sub_col", - arrow_schema::DataType::Int64, - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "7".to_string(), - )]))] - .into(), - ), - true, - ) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "8".to_string(), - )]))] - .into(), + { + if let DataType::Struct(fields) = arrow_schema.field(4).data_type() { + fields.clone() + } else { + unreachable!() + } + }, vec![Arc::new(StructArray::new( - vec![ - arrow_schema::Field::new("sub_sub_col", arrow_schema::DataType::Int64, true) - .with_metadata(HashMap::from([( - PARQUET_FIELD_ID_META_KEY.to_string(), - "7".to_string(), - )])), - ] - .into(), - vec![Arc::new(Int64Array::from_iter_values(vec![1; 1024]))], + { + if let DataType::Struct(fields) = arrow_schema.field(4).data_type() { + if let DataType::Struct(fields) = fields[0].data_type() { + fields.clone() + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + vec![Arc::new(Int64Array::from_iter_values(0..1024))], None, ))], None, )); - let to_write = - RecordBatch::try_new(schema.clone(), vec![col0, col1, col2, col3, col4]).unwrap(); + let col5 = Arc::new({ + let mut map_array_builder = arrow_array::builder::MapBuilder::new( + None, + arrow_array::builder::StringBuilder::new(), + arrow_array::builder::ListBuilder::new(arrow_array::builder::PrimitiveBuilder::< + Int64Type, + >::new()), + ); + for i in 0..1024 { + map_array_builder.keys().append_value(i.to_string()); + map_array_builder + .values() + .append_value(vec![Some(i as i64); i + 1]); + map_array_builder.append(true)?; + } + let (_, offset_buffer, struct_array, null_buffer, ordered) = + map_array_builder.finish().into_parts(); + let struct_array = { + let (_, mut arrays, nulls) = struct_array.into_parts(); + let list_array = { + let list_array = arrays[1] + .as_any() + .downcast_ref::() + .unwrap() + .clone(); + let (_, offsets, array, nulls) = list_array.into_parts(); + let list_field = { + if let DataType::Map(map_field, _) = arrow_schema.field(5).data_type() { + if let DataType::Struct(fields) = map_field.data_type() { + if let DataType::List(list_field) = fields[1].data_type() { + list_field.clone() + } else { + unreachable!() + } + } else { + unreachable!() + } + } else { + unreachable!() + } + }; + ListArray::new(list_field, offsets, array, nulls) + }; + arrays[1] = Arc::new(list_array) as ArrayRef; + StructArray::new( + { + if let 
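The repeated `if let DataType::Struct(..)` blocks above all serve one purpose: reuse the field definitions from the derived Arrow schema when building nested arrays, so the field-id metadata on the data always matches the schema handed to the writer. A sketch of that idea as a hypothetical helper for the single-child struct case:

use std::sync::Arc;

use arrow_array::{ArrayRef, StructArray};
use arrow_schema::{DataType, Schema as ArrowSchema};

fn struct_column(arrow_schema: &ArrowSchema, index: usize, child: ArrayRef) -> ArrayRef {
    let DataType::Struct(fields) = arrow_schema.field(index).data_type() else {
        unreachable!("field {index} is declared as a struct in the schema")
    };
    // Assumes the struct has exactly one child column; wider structs would pass more arrays.
    Arc::new(StructArray::new(fields.clone(), vec![child], None))
}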
DataType::Map(map_field, _) = arrow_schema.field(5).data_type() { + if let DataType::Struct(fields) = map_field.data_type() { + fields.clone() + } else { + unreachable!() + } + } else { + unreachable!() + } + }, + arrays, + nulls, + ) + }; + arrow_array::MapArray::new( + { + if let DataType::Map(map_field, _) = arrow_schema.field(5).data_type() { + map_field.clone() + } else { + unreachable!() + } + }, + offset_buffer, + struct_array, + null_buffer, + ordered, + ) + }) as ArrayRef; + let to_write = RecordBatch::try_new( + arrow_schema.clone(), + vec![col0, col1, col2, col3, col4, col5], + ) + .unwrap(); // write data let mut pw = ParquetWriterBuilder::new( WriterProperties::builder().build(), - to_write.schema(), + Arc::new(schema), file_io.clone(), location_gen, file_name_gen, @@ -615,6 +995,253 @@ mod tests { .build() .unwrap(); + // check data file + assert_eq!(data_file.record_count(), 1024); + assert_eq!( + *data_file.value_counts(), + HashMap::from([ + (0, 1024), + (5, 1024), + (6, 1024), + (2, 1024), + (7, 1024), + (9, 1024), + (11, 1024), + (13, (1..1025).sum()), + ]) + ); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([ + (0, Datum::long(0)), + (5, Datum::long(0)), + (6, Datum::long(0)), + (2, Datum::string("0")), + (7, Datum::long(0)), + (9, Datum::long(0)), + (11, Datum::string("0")), + (13, Datum::long(0)) + ]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([ + (0, Datum::long(1023)), + (5, Datum::long(1023)), + (6, Datum::long(1023)), + (2, Datum::string("999")), + (7, Datum::long(1023)), + (9, Datum::long(1023)), + (11, Datum::string("999")), + (13, Datum::long(1023)) + ]) + ); + + // check the written file + check_parquet_data_file(&file_io, &data_file, &to_write).await; + + Ok(()) + } + + #[tokio::test] + async fn test_all_type_for_write() -> Result<()> { + let temp_dir = TempDir::new().unwrap(); + let file_io = FileIOBuilder::new_fs_io().build().unwrap(); + let loccation_gen = + MockLocationGenerator::new(temp_dir.path().to_str().unwrap().to_string()); + let file_name_gen = + DefaultFileNameGenerator::new("test".to_string(), None, DataFileFormat::Parquet); + + // prepare data + // generate iceberg schema for all type + let schema = schema_for_all_type(); + let arrow_schema: ArrowSchemaRef = Arc::new((&schema).try_into().unwrap()); + let col0 = Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + None, + Some(true), + ])) as ArrayRef; + let col1 = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])) as ArrayRef; + let col2 = Arc::new(Int64Array::from(vec![Some(1), Some(2), None, Some(4)])) as ArrayRef; + let col3 = Arc::new(arrow_array::Float32Array::from(vec![ + Some(0.5), + Some(2.0), + None, + Some(3.5), + ])) as ArrayRef; + let col4 = Arc::new(arrow_array::Float64Array::from(vec![ + Some(0.5), + Some(2.0), + None, + Some(3.5), + ])) as ArrayRef; + let col5 = Arc::new(arrow_array::StringArray::from(vec![ + Some("a"), + Some("b"), + None, + Some("d"), + ])) as ArrayRef; + let col6 = Arc::new(arrow_array::LargeBinaryArray::from_opt_vec(vec![ + Some(b"one"), + None, + Some(b""), + Some(b"zzzz"), + ])) as ArrayRef; + let col7 = Arc::new(arrow_array::Date32Array::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col8 = Arc::new(arrow_array::Time64MicrosecondArray::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col9 = Arc::new(arrow_array::TimestampMicrosecondArray::from(vec![ + Some(0), + Some(1), + None, + Some(3), + ])) as ArrayRef; + let col10 = Arc::new( + 
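The value count asserted for field id 13 above is not a typo: the map is built row by row with i + 1 items in row i, so the leaf column of the value lists carries 1 + 2 + ... + 1024 entries. A one-line check of that arithmetic:

fn expected_item_count(rows: usize) -> usize {
    (1..=rows).sum() // 524_800 for 1024 rows, i.e. (1..1025).sum() in the assertion above
}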
arrow_array::TimestampMicrosecondArray::from(vec![Some(0), Some(1), None, Some(3)]) + .with_timezone_utc(), + ) as ArrayRef; + let col11 = Arc::new( + arrow_array::Decimal128Array::from(vec![Some(1), Some(2), None, Some(100)]) + .with_precision_and_scale(10, 5) + .unwrap(), + ) as ArrayRef; + let col12 = Arc::new( + arrow_array::FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![ + Some(Uuid::from_u128(0).as_bytes().to_vec()), + Some(Uuid::from_u128(1).as_bytes().to_vec()), + None, + Some(Uuid::from_u128(3).as_bytes().to_vec()), + ] + .into_iter(), + 16, + ) + .unwrap(), + ) as ArrayRef; + let col13 = Arc::new( + arrow_array::FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![ + Some(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]), + Some(vec![11, 12, 13, 14, 15, 16, 17, 18, 19, 20]), + None, + Some(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30]), + ] + .into_iter(), + 10, + ) + .unwrap(), + ) as ArrayRef; + let to_write = RecordBatch::try_new( + arrow_schema.clone(), + vec![ + col0, col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11, col12, + col13, + ], + ) + .unwrap(); + + // write data + let mut pw = ParquetWriterBuilder::new( + WriterProperties::builder().build(), + Arc::new(schema), + file_io.clone(), + loccation_gen, + file_name_gen, + ) + .build() + .await?; + pw.write(&to_write).await?; + let res = pw.close().await?; + assert_eq!(res.len(), 1); + let data_file = res + .into_iter() + .next() + .unwrap() + // Put dummy field for build successfully. + .content(crate::spec::DataContentType::Data) + .partition(Struct::empty()) + .build() + .unwrap(); + + // check data file + assert_eq!(data_file.record_count(), 4); + assert!(data_file.value_counts().iter().all(|(_, &v)| { v == 4 })); + assert!(data_file + .null_value_counts() + .iter() + .all(|(_, &v)| { v == 1 })); + assert_eq!( + *data_file.lower_bounds(), + HashMap::from([ + (0, Datum::bool(false)), + (1, Datum::int(1)), + (2, Datum::long(1)), + (3, Datum::float(0.5)), + (4, Datum::double(0.5)), + (5, Datum::string("a")), + (6, Datum::binary(vec![])), + (7, Datum::date(0)), + (8, Datum::time_micros(0).unwrap()), + (9, Datum::timestamp_micros(0)), + (10, Datum::timestamptz_micros(0)), + ( + 11, + Datum::new( + PrimitiveType::Decimal { + precision: 10, + scale: 5 + }, + PrimitiveLiteral::Decimal(1) + ) + ), + (12, Datum::uuid(Uuid::from_u128(0))), + (13, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])), + (12, Datum::uuid(Uuid::from_u128(0))), + (13, Datum::fixed(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10])), + ]) + ); + assert_eq!( + *data_file.upper_bounds(), + HashMap::from([ + (0, Datum::bool(true)), + (1, Datum::int(4)), + (2, Datum::long(4)), + (3, Datum::float(3.5)), + (4, Datum::double(3.5)), + (5, Datum::string("d")), + (6, Datum::binary(vec![122, 122, 122, 122])), + (7, Datum::date(3)), + (8, Datum::time_micros(3).unwrap()), + (9, Datum::timestamp_micros(3)), + (10, Datum::timestamptz_micros(3)), + ( + 11, + Datum::new( + PrimitiveType::Decimal { + precision: 10, + scale: 5 + }, + PrimitiveLiteral::Decimal(100) + ) + ), + (12, Datum::uuid(Uuid::from_u128(3))), + ( + 13, + Datum::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30]) + ), + ]) + ); + // check the written file check_parquet_data_file(&file_io, &data_file, &to_write).await; diff --git a/crates/iceberg/src/writer/mod.rs b/crates/iceberg/src/writer/mod.rs index 5f3ae5581..06b763d6e 100644 --- a/crates/iceberg/src/writer/mod.rs +++ b/crates/iceberg/src/writer/mod.rs @@ -127,62 +127,11 @@ mod tests { let input_content = 
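The bounds assertions above rely on one Datum constructor per primitive type; only decimals need the long form through Datum::new with an explicit PrimitiveLiteral, and time_micros is the one constructor that returns a Result. A sketch listing them together, assuming the same spec re-exports the tests use:

use iceberg::spec::{Datum, PrimitiveLiteral, PrimitiveType};
use uuid::Uuid;

fn example_bounds() -> Vec<Datum> {
    vec![
        Datum::bool(true),
        Datum::int(4),
        Datum::long(1023),
        Datum::float(3.5),
        Datum::double(3.5),
        Datum::string("d"),
        Datum::binary(vec![122, 122, 122, 122]),
        Datum::date(3),
        Datum::time_micros(3).unwrap(),
        Datum::timestamp_micros(3),
        Datum::timestamptz_micros(3),
        Datum::new(
            PrimitiveType::Decimal { precision: 10, scale: 5 },
            PrimitiveLiteral::Decimal(100),
        ),
        Datum::uuid(Uuid::from_u128(3)),
        Datum::fixed(vec![21, 22, 23, 24, 25, 26, 27, 28, 29, 30]),
    ]
}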
input_file.read().await.unwrap(); let reader_builder = ParquetRecordBatchReaderBuilder::try_new(input_content.clone()).unwrap(); - let metadata = reader_builder.metadata().clone(); // check data let reader = reader_builder.build().unwrap(); let batches = reader.map(|batch| batch.unwrap()).collect::>(); let res = concat_batches(&batch.schema(), &batches).unwrap(); assert_eq!(*batch, res); - - // check metadata - let expect_column_num = batch.num_columns(); - - assert_eq!( - data_file.record_count, - metadata - .row_groups() - .iter() - .map(|group| group.num_rows()) - .sum::() as u64 - ); - - assert_eq!(data_file.file_size_in_bytes, input_content.len() as u64); - - assert_eq!(data_file.column_sizes.len(), expect_column_num); - data_file.column_sizes.iter().for_each(|(&k, &v)| { - let expect = metadata - .row_groups() - .iter() - .map(|group| group.column(k as usize).compressed_size()) - .sum::() as u64; - assert_eq!(v, expect); - }); - - assert_eq!(data_file.value_counts.len(), expect_column_num); - data_file.value_counts.iter().for_each(|(_, &v)| { - let expect = metadata - .row_groups() - .iter() - .map(|group| group.num_rows()) - .sum::() as u64; - assert_eq!(v, expect); - }); - - assert_eq!(data_file.null_value_counts.len(), expect_column_num); - data_file.null_value_counts.iter().for_each(|(&k, &v)| { - let expect = batch.column(k as usize).null_count() as u64; - assert_eq!(v, expect); - }); - - assert_eq!(data_file.split_offsets.len(), metadata.num_row_groups()); - data_file - .split_offsets - .iter() - .enumerate() - .for_each(|(i, &v)| { - let expect = metadata.row_groups()[i].file_offset().unwrap(); - assert_eq!(v, expect); - }); } } diff --git a/crates/iceberg/testdata/example_table_metadata_v2.json b/crates/iceberg/testdata/example_table_metadata_v2.json index 809c35587..cf9fef96d 100644 --- a/crates/iceberg/testdata/example_table_metadata_v2.json +++ b/crates/iceberg/testdata/example_table_metadata_v2.json @@ -15,7 +15,8 @@ "fields": [ {"id": 1, "name": "x", "required": true, "type": "long"}, {"id": 2, "name": "y", "required": true, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": true, "type": "long"} + {"id": 3, "name": "z", "required": true, "type": "long"}, + {"id": 4, "name": "a", "required": true, "type": "string"} ] } ], diff --git a/crates/iceberg/testdata/file_io_s3/docker-compose.yaml b/crates/iceberg/testdata/file_io_s3/docker-compose.yaml index 3c1cfcc78..0793d225b 100644 --- a/crates/iceberg/testdata/file_io_s3/docker-compose.yaml +++ b/crates/iceberg/testdata/file_io_s3/docker-compose.yaml @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -version: '3' services: minio: image: minio/minio:RELEASE.2024-02-26T09-33-48Z diff --git a/crates/iceberg/tests/file_io_s3_test.rs b/crates/iceberg/tests/file_io_s3_test.rs index 36e24f153..6d62a0416 100644 --- a/crates/iceberg/tests/file_io_s3_test.rs +++ b/crates/iceberg/tests/file_io_s3_test.rs @@ -17,86 +17,79 @@ //! Integration tests for FileIO S3. 
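With the per-column metadata assertions moved next to the writer tests, check_parquet_data_file keeps only the content comparison. A sketch of that read-back path, assuming the file contents are already in memory as bytes::Bytes:

use arrow_array::RecordBatch;
use arrow_select::concat::concat_batches;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

fn read_back(content: bytes::Bytes) -> RecordBatch {
    let reader = ParquetRecordBatchReaderBuilder::try_new(content)
        .unwrap()
        .build()
        .unwrap();
    let batches: Vec<_> = reader.map(|batch| batch.unwrap()).collect();
    // Panics on an empty file; the tests always write at least one batch.
    concat_batches(&batches[0].schema(), &batches).unwrap()
}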
+use ctor::{ctor, dtor}; use iceberg::io::{ FileIO, FileIOBuilder, S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY, }; use iceberg_test_utils::docker::DockerCompose; +use iceberg_test_utils::{normalize_test_name, set_up}; +use std::sync::RwLock; -struct MinIOFixture { - _docker_compose: DockerCompose, - file_io: FileIO, +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); + +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + let docker_compose = DockerCompose::new( + normalize_test_name(module_path!()), + format!("{}/testdata/file_io_s3", env!("CARGO_MANIFEST_DIR")), + ); + docker_compose.run(); + guard.replace(docker_compose); } -impl MinIOFixture { - async fn new(project_name: impl ToString) -> Self { - // Start the Docker container for the test fixture - let docker = DockerCompose::new( - project_name.to_string(), - format!("{}/testdata/file_io_s3", env!("CARGO_MANIFEST_DIR")), - ); - docker.run(); - let container_ip = docker.get_container_ip("minio"); - let read_port = format!("{}:{}", container_ip, 9000); - MinIOFixture { - _docker_compose: docker, - file_io: FileIOBuilder::new("s3") - .with_props(vec![ - (S3_ENDPOINT, format!("http://{}", read_port)), - (S3_ACCESS_KEY_ID, "admin".to_string()), - (S3_SECRET_ACCESS_KEY, "password".to_string()), - (S3_REGION, "us-east-1".to_string()), - ]) - .build() - .unwrap(), - } - } +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} + +async fn get_file_io() -> FileIO { + set_up(); + + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + let container_ip = docker_compose.get_container_ip("minio"); + let read_port = format!("{}:{}", container_ip, 9000); + + FileIOBuilder::new("s3") + .with_props(vec![ + (S3_ENDPOINT, format!("http://{}", read_port)), + (S3_ACCESS_KEY_ID, "admin".to_string()), + (S3_SECRET_ACCESS_KEY, "password".to_string()), + (S3_REGION, "us-east-1".to_string()), + ]) + .build() + .unwrap() } #[tokio::test] async fn test_file_io_s3_is_exist() { - let fixture = MinIOFixture::new("test_file_io_s3_is_exist").await; - assert!(!fixture.file_io.is_exist("s3://bucket2/any").await.unwrap()); - assert!(fixture.file_io.is_exist("s3://bucket1/").await.unwrap()); + let file_io = get_file_io().await; + assert!(!file_io.is_exist("s3://bucket2/any").await.unwrap()); + assert!(file_io.is_exist("s3://bucket1/").await.unwrap()); } #[tokio::test] async fn test_file_io_s3_output() { - // Start the Docker container for the test fixture - let fixture = MinIOFixture::new("test_file_io_s3_output").await; - assert!(!fixture - .file_io - .is_exist("s3://bucket1/test_output") - .await - .unwrap()); - let output_file = fixture - .file_io - .new_output("s3://bucket1/test_output") - .unwrap(); + let file_io = get_file_io().await; + assert!(!file_io.is_exist("s3://bucket1/test_output").await.unwrap()); + let output_file = file_io.new_output("s3://bucket1/test_output").unwrap(); { output_file.write("123".into()).await.unwrap(); } - assert!(fixture - .file_io - .is_exist("s3://bucket1/test_output") - .await - .unwrap()); + assert!(file_io.is_exist("s3://bucket1/test_output").await.unwrap()); } #[tokio::test] async fn test_file_io_s3_input() { - let fixture = MinIOFixture::new("test_file_io_s3_input").await; - let output_file = fixture - .file_io - .new_output("s3://bucket1/test_input") - .unwrap(); + let file_io = get_file_io().await; + let output_file = file_io.new_output("s3://bucket1/test_input").unwrap(); { 
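The S3 tests now share one docker-compose project per test binary instead of one per test: a ctor hook starts it, a dtor hook tears it down, and the handle lives in a static RwLock. The pattern reduced to a sketch, with a hypothetical Resource standing in for DockerCompose:

use std::sync::RwLock;

use ctor::{ctor, dtor};

struct Resource;

static FIXTURE: RwLock<Option<Resource>> = RwLock::new(None);

#[ctor]
fn before_all() {
    // Runs once before any test in the binary executes.
    FIXTURE.write().unwrap().replace(Resource);
}

#[dtor]
fn after_all() {
    // Dropping the value here is what tears the shared environment down.
    FIXTURE.write().unwrap().take();
}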
output_file.write("test_input".into()).await.unwrap(); } - let input_file = fixture - .file_io - .new_input("s3://bucket1/test_input") - .unwrap(); + let input_file = file_io.new_input("s3://bucket1/test_input").unwrap(); { let buffer = input_file.read().await.unwrap(); diff --git a/crates/integrations/datafusion/Cargo.toml b/crates/integrations/datafusion/Cargo.toml index 56036dcc7..7e12d73e8 100644 --- a/crates/integrations/datafusion/Cargo.toml +++ b/crates/integrations/datafusion/Cargo.toml @@ -31,14 +31,14 @@ keywords = ["iceberg", "integrations", "datafusion"] [dependencies] anyhow = { workspace = true } async-trait = { workspace = true } -datafusion = { version = "38.0.0" } +datafusion = { version = "39.0.0" } futures = { workspace = true } iceberg = { workspace = true } log = { workspace = true } tokio = { workspace = true } [dev-dependencies] +ctor = { workspace = true } iceberg-catalog-hms = { workspace = true } iceberg_test_utils = { path = "../../test_utils", features = ["tests"] } -opendal = { workspace = true, features = ["services-s3"] } port_scanner = { workspace = true } diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs b/crates/integrations/datafusion/src/physical_plan/scan.rs index cc01148f1..8a74caa6a 100644 --- a/crates/integrations/datafusion/src/physical_plan/scan.rs +++ b/crates/integrations/datafusion/src/physical_plan/scan.rs @@ -76,7 +76,7 @@ impl ExecutionPlan for IcebergTableScan { self } - fn children(&self) -> Vec> { + fn children(&self) -> Vec<&Arc<(dyn ExecutionPlan + 'static)>> { vec![] } diff --git a/crates/integrations/datafusion/src/schema.rs b/crates/integrations/datafusion/src/schema.rs index 2ba69621a..f7b1a21d2 100644 --- a/crates/integrations/datafusion/src/schema.rs +++ b/crates/integrations/datafusion/src/schema.rs @@ -89,7 +89,7 @@ impl SchemaProvider for IcebergSchemaProvider { } fn table_exist(&self, name: &str) -> bool { - self.tables.get(name).is_some() + self.tables.contains_key(name) } async fn table(&self, name: &str) -> DFResult>> { diff --git a/crates/integrations/datafusion/testdata/docker-compose.yaml b/crates/integrations/datafusion/testdata/docker-compose.yaml index 282dc66ca..be915ab20 100644 --- a/crates/integrations/datafusion/testdata/docker-compose.yaml +++ b/crates/integrations/datafusion/testdata/docker-compose.yaml @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -version: '3.8' - services: minio: image: minio/minio:RELEASE.2024-03-07T00-43-48Z @@ -43,6 +41,7 @@ services: hive-metastore: image: iceberg-hive-metastore build: ./hms_catalog/ + platform: ${DOCKER_DEFAULT_PLATFORM} expose: - 9083 environment: diff --git a/crates/integrations/datafusion/testdata/hms_catalog/Dockerfile b/crates/integrations/datafusion/testdata/hms_catalog/Dockerfile index ff8c9fae6..abece560e 100644 --- a/crates/integrations/datafusion/testdata/hms_catalog/Dockerfile +++ b/crates/integrations/datafusion/testdata/hms_catalog/Dockerfile @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-FROM openjdk:8-jre-slim AS build +FROM --platform=$BUILDPLATFORM openjdk:8-jre-slim AS build RUN apt-get update -qq && apt-get -qq -y install curl diff --git a/crates/integrations/datafusion/tests/integration_datafusion_hms_test.rs b/crates/integrations/datafusion/tests/integration_datafusion_hms_test.rs index 20c5cc872..9ad1d401f 100644 --- a/crates/integrations/datafusion/tests/integration_datafusion_hms_test.rs +++ b/crates/integrations/datafusion/tests/integration_datafusion_hms_test.rs @@ -18,8 +18,9 @@ //! Integration tests for Iceberg Datafusion with Hive Metastore. use std::collections::HashMap; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; +use ctor::{ctor, dtor}; use datafusion::arrow::datatypes::DataType; use datafusion::execution::context::SessionContext; use iceberg::io::{S3_ACCESS_KEY_ID, S3_ENDPOINT, S3_REGION, S3_SECRET_ACCESS_KEY}; @@ -34,24 +35,55 @@ use tokio::time::sleep; const HMS_CATALOG_PORT: u16 = 9083; const MINIO_PORT: u16 = 9000; +static DOCKER_COMPOSE_ENV: RwLock> = RwLock::new(None); struct TestFixture { - _docker_compose: DockerCompose, hms_catalog: HmsCatalog, + props: HashMap, + hms_catalog_ip: String, } -async fn set_test_fixture(func: &str) -> TestFixture { - set_up(); - +#[ctor] +fn before_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); let docker_compose = DockerCompose::new( - normalize_test_name(format!("{}_{func}", module_path!())), + normalize_test_name(module_path!()), format!("{}/testdata", env!("CARGO_MANIFEST_DIR")), ); - docker_compose.run(); + guard.replace(docker_compose); +} + +#[dtor] +fn after_all() { + let mut guard = DOCKER_COMPOSE_ENV.write().unwrap(); + guard.take(); +} + +impl TestFixture { + fn get_catalog(&self) -> HmsCatalog { + let config = HmsCatalogConfig::builder() + .address(format!("{}:{}", self.hms_catalog_ip, HMS_CATALOG_PORT)) + .thrift_transport(HmsThriftTransport::Buffered) + .warehouse("s3a://warehouse/hive".to_string()) + .props(self.props.clone()) + .build(); + + HmsCatalog::new(config).unwrap() + } +} + +async fn get_test_fixture() -> TestFixture { + set_up(); - let hms_catalog_ip = docker_compose.get_container_ip("hive-metastore"); - let minio_ip = docker_compose.get_container_ip("minio"); + let (hms_catalog_ip, minio_ip) = { + let guard = DOCKER_COMPOSE_ENV.read().unwrap(); + let docker_compose = guard.as_ref().unwrap(); + ( + docker_compose.get_container_ip("hive-metastore"), + docker_compose.get_container_ip("minio"), + ) + }; let read_port = format!("{}:{}", hms_catalog_ip, HMS_CATALOG_PORT); loop { @@ -77,17 +109,26 @@ async fn set_test_fixture(func: &str) -> TestFixture { .address(format!("{}:{}", hms_catalog_ip, HMS_CATALOG_PORT)) .thrift_transport(HmsThriftTransport::Buffered) .warehouse("s3a://warehouse/hive".to_string()) - .props(props) + .props(props.clone()) .build(); let hms_catalog = HmsCatalog::new(config).unwrap(); TestFixture { - _docker_compose: docker_compose, hms_catalog, + props, + hms_catalog_ip, } } +async fn set_test_namespace(catalog: &HmsCatalog, namespace: &NamespaceIdent) -> Result<()> { + let properties = HashMap::new(); + + catalog.create_namespace(namespace, properties).await?; + + Ok(()) +} + fn set_table_creation(location: impl ToString, name: impl ToString) -> Result { let schema = Schema::builder() .with_schema_id(0) @@ -109,24 +150,24 @@ fn set_table_creation(location: impl ToString, name: impl ToString) -> Result Result<()> { - let fixture = set_test_fixture("test_provider_get_table_schema").await; + let fixture = get_test_fixture().await; + let namespace = 
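Each test now builds its own HmsCatalog from the shared fixture state via get_catalog. A sketch of that construction, assuming HmsCatalog, HmsCatalogConfig, and HmsThriftTransport are exported at the root of iceberg-catalog-hms; the port and warehouse values mirror the fixture defaults:

use std::collections::HashMap;

use iceberg_catalog_hms::{HmsCatalog, HmsCatalogConfig, HmsThriftTransport};

fn build_catalog(hms_ip: &str, props: HashMap<String, String>) -> HmsCatalog {
    let config = HmsCatalogConfig::builder()
        .address(format!("{}:{}", hms_ip, 9083))
        .thrift_transport(HmsThriftTransport::Buffered)
        .warehouse("s3a://warehouse/hive".to_string())
        .props(props)
        .build();

    HmsCatalog::new(config).unwrap()
}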
NamespaceIdent::new("test_provider_get_table_schema".to_string()); + set_test_namespace(&fixture.hms_catalog, &namespace).await?; - let namespace = NamespaceIdent::new("default".to_string()); let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - fixture .hms_catalog .create_table(&namespace, creation) .await?; - let client = Arc::new(fixture.hms_catalog); + let client = Arc::new(fixture.get_catalog()); let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); let ctx = SessionContext::new(); ctx.register_catalog("hive", catalog); let provider = ctx.catalog("hive").unwrap(); - let schema = provider.schema("default").unwrap(); + let schema = provider.schema("test_provider_get_table_schema").unwrap(); let table = schema.table("my_table").await.unwrap().unwrap(); let table_schema = table.schema(); @@ -144,24 +185,24 @@ async fn test_provider_get_table_schema() -> Result<()> { #[tokio::test] async fn test_provider_list_table_names() -> Result<()> { - let fixture = set_test_fixture("test_provider_list_table_names").await; + let fixture = get_test_fixture().await; + let namespace = NamespaceIdent::new("test_provider_list_table_names".to_string()); + set_test_namespace(&fixture.hms_catalog, &namespace).await?; - let namespace = NamespaceIdent::new("default".to_string()); let creation = set_table_creation("s3a://warehouse/hive", "my_table")?; - fixture .hms_catalog .create_table(&namespace, creation) .await?; - let client = Arc::new(fixture.hms_catalog); + let client = Arc::new(fixture.get_catalog()); let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); let ctx = SessionContext::new(); ctx.register_catalog("hive", catalog); let provider = ctx.catalog("hive").unwrap(); - let schema = provider.schema("default").unwrap(); + let schema = provider.schema("test_provider_list_table_names").unwrap(); let expected = vec!["my_table"]; let result = schema.table_names(); @@ -173,10 +214,12 @@ async fn test_provider_list_table_names() -> Result<()> { #[tokio::test] async fn test_provider_list_schema_names() -> Result<()> { - let fixture = set_test_fixture("test_provider_list_schema_names").await; - set_table_creation("default", "my_table")?; + let fixture = get_test_fixture().await; + let namespace = NamespaceIdent::new("test_provider_list_schema_names".to_string()); + set_test_namespace(&fixture.hms_catalog, &namespace).await?; - let client = Arc::new(fixture.hms_catalog); + set_table_creation("test_provider_list_schema_names", "my_table")?; + let client = Arc::new(fixture.get_catalog()); let catalog = Arc::new(IcebergCatalogProvider::try_new(client).await?); let ctx = SessionContext::new(); @@ -184,10 +227,11 @@ async fn test_provider_list_schema_names() -> Result<()> { let provider = ctx.catalog("hive").unwrap(); - let expected = vec!["default"]; + let expected = ["default", "test_provider_list_schema_names"]; let result = provider.schema_names(); - assert_eq!(result, expected); - + assert!(expected + .iter() + .all(|item| result.contains(&item.to_string()))); Ok(()) } diff --git a/crates/test_utils/src/docker.rs b/crates/test_utils/src/docker.rs index df3c75439..3247fc96e 100644 --- a/crates/test_utils/src/docker.rs +++ b/crates/test_utils/src/docker.rs @@ -39,10 +39,23 @@ impl DockerCompose { self.project_name.as_str() } + fn get_os_arch() -> String { + let mut cmd = Command::new("docker"); + cmd.arg("info") + .arg("--format") + .arg("{{.OSType}}/{{.Architecture}}"); + + get_cmd_output(cmd, "Get os arch".to_string()) + .trim() + .to_string() + } + pub fn 
run(&self) { let mut cmd = Command::new("docker"); cmd.current_dir(&self.docker_compose_dir); + cmd.env("DOCKER_DEFAULT_PLATFORM", Self::get_os_arch()); + cmd.args(vec![ "compose", "-p", diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 6685489a6..7b10a8692 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -16,5 +16,5 @@ # under the License. [toolchain] -channel = "1.77.1" +channel = "nightly-2024-06-10" components = ["rustfmt", "clippy"]
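get_os_arch plus the DOCKER_DEFAULT_PLATFORM environment variable make docker compose build and pull images for the platform the local daemon reports. A std-only sketch of the same flow; the compose arguments after "-p" are an assumption here:

use std::path::Path;
use std::process::Command;

fn docker_platform() -> String {
    let output = Command::new("docker")
        .args(["info", "--format", "{{.OSType}}/{{.Architecture}}"])
        .output()
        .expect("failed to run `docker info`");
    String::from_utf8_lossy(&output.stdout).trim().to_string()
}

fn compose_up(compose_dir: &Path, project: &str) {
    let status = Command::new("docker")
        .current_dir(compose_dir)
        .env("DOCKER_DEFAULT_PLATFORM", docker_platform())
        .args(["compose", "-p", project, "up", "-d"])
        .status()
        .expect("failed to run `docker compose up`");
    assert!(status.success());
}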