diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index 5404ac90b..6cad4adbe 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -71,7 +71,6 @@ serde_derive = { workspace = true } serde_json = { workspace = true } serde_repr = { workspace = true } serde_with = { workspace = true } -tokio = { workspace = true } typed-builder = { workspace = true } url = { workspace = true } urlencoding = { workspace = true } @@ -82,3 +81,4 @@ iceberg_test_utils = { path = "../test_utils", features = ["tests"] } pretty_assertions = { workspace = true } tempfile = { workspace = true } tera = { workspace = true } +tokio = { workspace = true } diff --git a/crates/iceberg/src/writer/file_writer/parquet_writer.rs b/crates/iceberg/src/writer/file_writer/parquet_writer.rs index 50a507e2e..43e49fc8b 100644 --- a/crates/iceberg/src/writer/file_writer/parquet_writer.rs +++ b/crates/iceberg/src/writer/file_writer/parquet_writer.rs @@ -17,6 +17,11 @@ //! The module contains the file writer for parquet file format. +use super::{ + location_generator::{FileNameGenerator, LocationGenerator}, + track_writer::TrackWriter, + FileWriter, FileWriterBuilder, +}; use crate::arrow::DEFAULT_MAP_FIELD_NAME; use crate::spec::{ visit_schema, Datum, ListType, MapType, NestedFieldRef, PrimitiveLiteral, PrimitiveType, @@ -40,25 +45,19 @@ use parquet::data_type::{ }; use parquet::file::properties::WriterProperties; use parquet::file::statistics::TypedStatistics; -use parquet::{arrow::AsyncArrowWriter, format::FileMetaData}; +use parquet::{ + arrow::async_writer::AsyncFileWriter as ArrowAsyncFileWriter, arrow::AsyncArrowWriter, + format::FileMetaData, +}; use parquet::{ data_type::{ByteArray, FixedLenByteArray}, file::statistics::{from_thrift, Statistics}, }; -use std::pin::Pin; -use std::task::{Context, Poll}; -use std::{ - collections::HashMap, - sync::{atomic::AtomicI64, Arc}, -}; +use std::collections::HashMap; +use std::sync::atomic::AtomicI64; +use std::sync::Arc; use uuid::Uuid; -use super::{ - location_generator::{FileNameGenerator, LocationGenerator}, - track_writer::TrackWriter, - FileWriter, FileWriterBuilder, -}; - /// ParquetWriterBuilder is used to builder a [`ParquetWriter`] #[derive(Clone)] pub struct ParquetWriterBuilder { @@ -571,102 +570,38 @@ impl CurrentFileStatus for ParquetWriter { /// # NOTES /// /// We keep this wrapper been used inside only. -/// -/// # TODO -/// -/// Maybe we can use the buffer from ArrowWriter directly. -struct AsyncFileWriter(State); - -enum State { - Idle(Option), - Write(BoxFuture<'static, (W, Result<()>)>), - Close(BoxFuture<'static, (W, Result<()>)>), -} +struct AsyncFileWriter(W); impl AsyncFileWriter { /// Create a new `AsyncFileWriter` with the given writer. pub fn new(writer: W) -> Self { - Self(State::Idle(Some(writer))) + Self(writer) } } -impl tokio::io::AsyncWrite for AsyncFileWriter { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let this = self.get_mut(); - loop { - match &mut this.0 { - State::Idle(w) => { - let mut writer = w.take().unwrap(); - let bs = Bytes::copy_from_slice(buf); - let fut = async move { - let res = writer.write(bs).await; - (writer, res) - }; - this.0 = State::Write(Box::pin(fut)); - } - State::Write(fut) => { - let (writer, res) = futures::ready!(fut.as_mut().poll(cx)); - this.0 = State::Idle(Some(writer)); - return Poll::Ready(res.map(|_| buf.len()).map_err(|err| { - std::io::Error::new(std::io::ErrorKind::Other, Box::new(err)) - })); - } - State::Close(_) => { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "file is closed", - ))); - } - } - } - } - - fn poll_flush( - self: Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - Poll::Ready(Ok(())) +impl ArrowAsyncFileWriter for AsyncFileWriter { + fn write(&mut self, bs: Bytes) -> BoxFuture<'_, parquet::errors::Result<()>> { + Box::pin(async { + self.0 + .write(bs) + .await + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))) + }) } - fn poll_shutdown( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let this = self.get_mut(); - loop { - match &mut this.0 { - State::Idle(w) => { - let mut writer = w.take().unwrap(); - let fut = async move { - let res = writer.close().await; - (writer, res) - }; - this.0 = State::Close(Box::pin(fut)); - } - State::Write(_) => { - return Poll::Ready(Err(std::io::Error::new( - std::io::ErrorKind::Other, - "file is writing", - ))); - } - State::Close(fut) => { - let (writer, res) = futures::ready!(fut.as_mut().poll(cx)); - this.0 = State::Idle(Some(writer)); - return Poll::Ready(res.map_err(|err| { - std::io::Error::new(std::io::ErrorKind::Other, Box::new(err)) - })); - } - } - } + fn complete(&mut self) -> BoxFuture<'_, parquet::errors::Result<()>> { + Box::pin(async { + self.0 + .close() + .await + .map_err(|err| parquet::errors::ParquetError::External(Box::new(err))) + }) } } #[cfg(test)] mod tests { + use std::collections::HashMap; use std::sync::Arc; use anyhow::Result; @@ -686,8 +621,8 @@ mod tests { use super::*; use crate::io::FileIOBuilder; - use crate::spec::NestedField; - use crate::spec::Struct; + use crate::spec::*; + use crate::spec::{PrimitiveLiteral, Struct}; use crate::writer::file_writer::location_generator::test::MockLocationGenerator; use crate::writer::file_writer::location_generator::DefaultFileNameGenerator; use crate::writer::tests::check_parquet_data_file;