From 7930e211abddc7eb69727fb1ac38532f00fa7628 Mon Sep 17 00:00:00 2001
From: Weston Pace
Date: Wed, 8 May 2024 15:30:02 -0700
Subject: [PATCH] WIP

---
 python/python/lance/file.py            | 21 +++++
 python/python/lance/lance/__init__.pyi |  3 +
 python/python/tests/test_file.py       | 17 ++++
 python/src/file.rs                     | 20 ++++-
 rust/lance-encoding/src/decoder.rs     | 29 +++----
 .../src/encodings/logical/struct.rs    | 33 ++++++--
 rust/lance-encoding/src/testing.rs     |  8 +-
 rust/lance-file/src/v2/reader.rs       | 78 +++++++++++++++----
 8 files changed, 170 insertions(+), 39 deletions(-)

diff --git a/python/python/lance/file.py b/python/python/lance/file.py
index 6b394bf692..2ea0daa266 100644
--- a/python/python/lance/file.py
+++ b/python/python/lance/file.py
@@ -106,6 +106,27 @@ def read_range(
             self._reader.read_range(start, num_rows, batch_size, batch_readahead)
         )
 
+    def take_rows(
+        self, indices, *, batch_size: int = 1024, batch_readahead=16
+    ) -> ReaderResults:
+        """
+        Read a specific set of rows from the file
+
+        Parameters
+        ----------
+        indices: List[int]
+            The indices of the rows to read from the file
+        batch_size: int, default 1024
+            The file will be read in batches. This parameter controls
+            how many rows will be in each batch (except the final batch)
+
+            Smaller batches will use less memory but might be slightly
+            slower because there is more per-batch overhead
+        """
+        return ReaderResults(
+            self._reader.take_rows(indices, batch_size, batch_readahead)
+        )
+
     def metadata(self) -> LanceFileMetadata:
         """
         Return metadata describing the file contents
diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi
index b56496a01c..159926ffda 100644
--- a/python/python/lance/lance/__init__.pyi
+++ b/python/python/lance/lance/__init__.pyi
@@ -52,6 +52,9 @@ class LanceFileReader:
     def read_range(
         self, start: int, num_rows: int, batch_size: int, batch_readahead: int
     ) -> pa.RecordBatchReader: ...
+    def take_rows(
+        self, indices: List[int], batch_size: int, batch_readahead: int
+    ) -> pa.RecordBatchReader: ...
 
 class LanceBufferDescriptor:
     position: int
diff --git a/python/python/tests/test_file.py b/python/python/tests/test_file.py
index 3d1634400c..a8bbb38466 100644
--- a/python/python/tests/test_file.py
+++ b/python/python/tests/test_file.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright The Lance Authors
 
 import pyarrow as pa
+import pytest
 from lance.file import LanceFileReader, LanceFileWriter
 
 
@@ -33,6 +34,22 @@ def test_multiple_close(tmp_path):
     writer.close()
 
 
+def test_take(tmp_path):
+    path = tmp_path / "foo.lance"
+    schema = pa.schema([pa.field("a", pa.int64())])
+    writer = LanceFileWriter(str(path), schema)
+    writer.write_batch(pa.table({"a": [i for i in range(100)]}))
+    writer.close()
+
+    reader = LanceFileReader(str(path))
+    # Can't read out of range
+    with pytest.raises(ValueError):
+        reader.take_rows([0, 100]).to_table()
+
+    table = reader.take_rows([0, 77, 83]).to_table()
+    assert table == pa.table({"a": [0, 77, 83]})
+
+
 def test_different_types(tmp_path):
     path = tmp_path / "foo.lance"
     schema = pa.schema(
diff --git a/python/src/file.rs b/python/src/file.rs
index bb361c3368..f862769752 100644
--- a/python/src/file.rs
+++ b/python/src/file.rs
@@ -15,7 +15,7 @@
 use std::{pin::Pin, sync::Arc};
 
 use arrow::pyarrow::PyArrowType;
-use arrow_array::{RecordBatch, RecordBatchReader};
+use arrow_array::{RecordBatch, RecordBatchReader, UInt32Array};
 use arrow_schema::Schema as ArrowSchema;
 use futures::stream::StreamExt;
 use lance::io::{ObjectStore, RecordBatchStream};
@@ -344,6 +344,24 @@ impl LanceFileReader {
         )
     }
 
+    pub fn take_rows(
+        &mut self,
+        row_indices: Vec<u64>,
+        batch_size: u32,
+        batch_readahead: u32,
+    ) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
+        let indices = row_indices
+            .into_iter()
+            .map(|idx| idx as u32)
+            .collect::<Vec<u32>>();
+        let indices_arr = UInt32Array::from(indices);
+        self.read_stream(
+            lance_io::ReadBatchParams::Indices(indices_arr),
+            batch_size,
+            batch_readahead,
+        )
+    }
+
     pub fn metadata(&mut self, py: Python) -> LanceFileMetadata {
         let inner_meta = self.inner.metadata();
         LanceFileMetadata::new(inner_meta, py)
diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs
index f271158379..a0098a24c7 100644
--- a/rust/lance-encoding/src/decoder.rs
+++ b/rust/lance-encoding/src/decoder.rs
@@ -196,7 +196,6 @@
 //! * The "batch overhead" is very small in Lance compared to other formats because it has no
 //!   relation to the way the data is stored.
 
-use std::future::Future;
 use std::{ops::Range, sync::Arc};
 
 use arrow_array::cast::AsArray;
@@ -557,11 +556,14 @@ impl DecodeBatchScheduler {
     /// * `scheduler` An I/O scheduler to issue I/O requests
     pub async fn schedule_take(
         &mut self,
-        indices: &[u32],
+        indices: &[u64],
         sink: mpsc::UnboundedSender<Box<dyn LogicalPageDecoder>>,
         scheduler: &Arc<dyn EncodingsIo>,
     ) -> Result<()> {
         debug_assert!(indices.windows(2).all(|w| w[0] < w[1]));
+        if indices.is_empty() {
+            return Ok(());
+        }
         trace!(
             "Scheduling take of {} rows [{}]",
             indices.len(),
@@ -574,15 +576,17 @@ impl DecodeBatchScheduler {
         if indices.is_empty() {
             return Ok(());
         }
+        // TODO: Figure out how to handle u64 indices
+        let indices = indices.iter().map(|i| *i as u32).collect::<Vec<_>>();
         self.root_scheduler
-            .schedule_take(indices, scheduler, &sink, indices[0] as u64)?;
+            .schedule_take(&indices, scheduler, &sink, indices[0] as u64)?;
         trace!("Finished scheduling take of {} rows", indices.len());
         Ok(())
     }
 }
 
-pub struct ReadBatchTask<Fut: Future<Output = Result<RecordBatch>>> {
-    pub task: Fut,
+pub struct ReadBatchTask {
+    pub task: BoxFuture<'static, Result<RecordBatch>>,
     pub num_rows: u32,
 }
 
@@ -620,7 +624,10 @@ impl BatchDecodeStream {
 
     #[instrument(level = "debug", skip_all)]
     async fn next_batch_task(&mut self) -> Result<Option<NextDecodeTask>> {
-        trace!("Draining batch task");
+        trace!(
+            "Draining batch task (rows_remaining={})",
+            self.rows_remaining
+        );
         if self.rows_remaining == 0 {
             return Ok(None);
         }
@@ -651,9 +658,7 @@ impl BatchDecodeStream {
         Ok(RecordBatch::from(struct_arr.as_struct()))
     }
 
-    pub fn into_stream(
-        self,
-    ) -> BoxStream<'static, ReadBatchTask<impl Future<Output = Result<RecordBatch>>>> {
+    pub fn into_stream(self) -> BoxStream<'static, ReadBatchTask> {
         let stream = futures::stream::unfold(self, |mut slf| async move {
             let next_task = slf.next_batch_task().await;
             let next_task = next_task.transpose().map(|next_task| {
@@ -665,10 +670,8 @@ impl BatchDecodeStream {
                 (task, num_rows)
             });
             next_task.map(|(task, num_rows)| {
-                let next_task = ReadBatchTask {
-                    task: task.map(|join_wrapper| join_wrapper.unwrap()),
-                    num_rows,
-                };
+                let task = task.map(|join_wrapper| join_wrapper.unwrap()).boxed();
+                let next_task = ReadBatchTask { task, num_rows };
                 (next_task, slf)
             })
         });
diff --git a/rust/lance-encoding/src/encodings/logical/struct.rs b/rust/lance-encoding/src/encodings/logical/struct.rs
index ee1fd33bb8..1672cf619b 100644
--- a/rust/lance-encoding/src/encodings/logical/struct.rs
+++ b/rust/lance-encoding/src/encodings/logical/struct.rs
@@ -248,6 +248,7 @@ impl LogicalPageScheduler for SimpleStructScheduler {
         // NOTE: See schedule_range for a description of the scheduling algorithm
         let mut current_top_level_row = top_level_row;
         while rows_to_read > 0 {
+            trace!("Beginning scheduler scan of columns");
             let mut min_rows_added = u32::MAX;
             for (col_idx, field_scheduler) in self.children.iter().enumerate() {
                 let status = &mut field_status[col_idx];
@@ -263,12 +264,15 @@ impl LogicalPageScheduler for SimpleStructScheduler {
                         "{}",
                         if indices_in_page.is_empty() {
                             format!(
-                                "Skipping entire page of {} rows",
-                                next_candidate_page.num_rows()
+                                "Skipping entire page of {} rows for column {}",
+                                next_candidate_page.num_rows(),
+                                col_idx,
                             )
                         } else {
                             format!(
-                                "Found page with {} overlapping indices",
+                                "Found page for column {} with {} rows that had {} overlapping indices",
+                                col_idx,
+                                next_candidate_page.num_rows(),
                                 indices_in_page.len()
                             )
                         }
@@ -290,11 +294,18 @@ impl LogicalPageScheduler for SimpleStructScheduler {
                     status.rows_queued += rows_scheduled;
 
                     min_rows_added = min_rows_added.min(rows_scheduled);
+                } else {
+                    // TODO: Unit tests are not covering this path right now
+                    min_rows_added = min_rows_added.min(status.rows_queued);
                 }
             }
             if min_rows_added == 0 {
                 panic!("Error in scheduling logic, panic to avoid infinite loop");
             }
+            trace!(
+                "One scheduling pass complete, {} rows added",
+                min_rows_added
+            );
             rows_to_read -= min_rows_added;
             current_top_level_row += min_rows_added as u64;
             for field_status in &mut field_status {
@@ -380,6 +391,11 @@ impl ChildState {
         );
         let mut remaining = num_rows.saturating_sub(self.rows_available);
         if remaining > 0 {
+            trace!(
+                "Struct must await {} rows from {} unawaited rows",
+                remaining,
+                self.rows_unawaited
+            );
             if let Some(back) = self.awaited.back_mut() {
                 if back.unawaited() > 0 {
                     let rows_to_wait = remaining.min(back.unawaited());
@@ -394,7 +410,12 @@ impl ChildState {
                     let newly_avail = back.avail() - previously_avail;
                     trace!("The await loaded {} rows", newly_avail);
                     self.rows_available += newly_avail;
-                    self.rows_unawaited -= newly_avail;
+                    // Need to use saturating_sub here because we might have asked for range
+                    // 0-1000 and this page we just loaded might cover 900-1100 and so newly_avail
+                    // is 200 but rows_unawaited is only 100
+                    //
+                    // TODO: Unit tests are not covering this branch right now
+                    self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
                     remaining -= rows_to_wait;
                     if remaining == 0 {
                         return Ok(true);
@@ -425,9 +446,9 @@ impl ChildState {
                 // newly_avail this way (we do this above too)
                 let newly_avail = decoder.avail() - previously_avail;
                 self.awaited.push_back(decoder);
-                self.rows_available += newly_avail;
-                self.rows_unawaited -= newly_avail;
                 trace!("The new await loaded {} rows", newly_avail);
+                self.rows_available += newly_avail;
+                self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
                 Ok(remaining == rows_to_wait)
             } else {
                 Ok(true)
diff --git a/rust/lance-encoding/src/testing.rs b/rust/lance-encoding/src/testing.rs
index 58a2d4da47..d270f9d9aa 100644
--- a/rust/lance-encoding/src/testing.rs
+++ b/rust/lance-encoding/src/testing.rs
@@ -3,7 +3,7 @@
 
 use std::{ops::Range, sync::Arc};
 
-use arrow_array::{Array, UInt32Array};
+use arrow_array::{Array, UInt64Array};
 use arrow_schema::{DataType, Field, Schema};
 use arrow_select::concat::concat;
 use bytes::{Bytes, BytesMut};
@@ -122,7 +122,7 @@ fn supports_nulls(data_type: &DataType) -> bool {
 #[derive(Clone, Default)]
 pub struct TestCases {
     ranges: Vec<Range<u64>>,
-    indices: Vec<Vec<u32>>,
+    indices: Vec<Vec<u64>>,
     skip_validation: bool,
 }
 
@@ -132,7 +132,7 @@ impl TestCases {
         self
     }
 
-    pub fn with_indices(mut self, indices: Vec<u32>) -> Self {
+    pub fn with_indices(mut self, indices: Vec<u64>) -> Self {
         self.indices.push(indices);
         self
     }
@@ -285,7 +285,7 @@ async fn check_round_trip_encoding_inner(
             );
         }
         let num_rows = indices.len() as u64;
-        let indices_arr = UInt32Array::from(indices.clone());
+        let indices_arr = UInt64Array::from(indices.clone());
         let expected = concat_data
             .as_ref()
             .map(|concat_data| arrow_select::take::take(&concat_data, &indices_arr, None).unwrap());
diff --git a/rust/lance-file/src/v2/reader.rs b/rust/lance-file/src/v2/reader.rs
index f34115b02e..0e103b7639 100644
--- a/rust/lance-file/src/v2/reader.rs
+++ b/rust/lance-file/src/v2/reader.rs
@@ -3,11 +3,10 @@
 
 use std::{collections::BTreeSet, io::Cursor, ops::Range, pin::Pin, sync::Arc};
 
-use arrow_array::RecordBatch;
 use arrow_schema::Schema as ArrowSchema;
 use byteorder::{ByteOrder, LittleEndian, ReadBytesExt};
 use bytes::{Bytes, BytesMut};
-use futures::{stream::BoxStream, Future, Stream, StreamExt};
+use futures::{stream::BoxStream, Stream, StreamExt};
 use lance_arrow::DataTypeExt;
 use lance_encoding::{
     decoder::{BatchDecodeStream, ColumnInfo, DecodeBatchScheduler, PageInfo, ReadBatchTask},
@@ -583,7 +582,7 @@ impl FileReader {
         range: Range<u64>,
         batch_size: u32,
         projection: &ReaderProjection,
-    ) -> Result<BoxStream<'static, ReadBatchTask<impl Future<Output = Result<RecordBatch>>>>> {
+    ) -> Result<BoxStream<'static, ReadBatchTask>> {
         let column_infos = self.collect_columns_from_projection(projection)?;
         debug!(
             "Reading range {:?} with batch_size {} from columns {:?}",
@@ -599,12 +598,49 @@ impl FileReader {
 
         let (tx, rx) = mpsc::unbounded_channel();
 
+        let num_rows_to_read = range.end - range.start;
+
         let scheduler = self.scheduler.clone() as Arc<dyn EncodingsIo>;
         tokio::task::spawn(
             async move { decode_scheduler.schedule_range(range, tx, &scheduler).await },
         );
 
-        Ok(BatchDecodeStream::new(rx, batch_size, self.num_rows).into_stream())
+        Ok(BatchDecodeStream::new(rx, batch_size, num_rows_to_read).into_stream())
+    }
+
+    fn take_rows(
+        &self,
+        indices: Vec<u64>,
+        batch_size: u32,
+        projection: &ReaderProjection,
+    ) -> Result<BoxStream<'static, ReadBatchTask>> {
+        let column_infos = self.collect_columns_from_projection(projection)?;
+        debug!(
+            "Taking {} rows spread across range {}..{} with batch_size {} from columns {:?}",
+            indices.len(),
+            indices[0],
+            indices[indices.len() - 1],
+            batch_size,
+            column_infos.iter().map(|ci| ci.index).collect::<Vec<_>>()
+        );
+        let mut decode_scheduler = DecodeBatchScheduler::new(
+            &projection.schema,
+            column_infos.iter().map(|ci| ci.as_ref()),
+            &vec![],
+        );
+
+        let (tx, rx) = mpsc::unbounded_channel();
+
+        let num_rows_to_read = indices.len() as u64;
+
+        let scheduler = self.scheduler.clone() as Arc<dyn EncodingsIo>;
+        tokio::task::spawn(async move {
+            decode_scheduler
+                .schedule_take(&indices, tx, &scheduler)
+                .await
+        });
+
+        Ok(BatchDecodeStream::new(rx, batch_size, num_rows_to_read).into_stream())
     }
 
     /// Creates a stream of "read tasks" to read the data from the file
@@ -622,14 +658,10 @@ impl FileReader {
         params: ReadBatchParams,
         batch_size: u32,
         projection: &ReaderProjection,
-    ) -> Result<
-        Pin<
-            Box<dyn Stream<Item = ReadBatchTask<impl Future<Output = Result<RecordBatch>>>> + Send>,
-        >,
-    > {
+    ) -> Result<Pin<Box<dyn Stream<Item = ReadBatchTask> + Send>>> {
         Self::validate_projection(projection, &self.metadata)?;
-        let verify_bound = |params: &ReadBatchParams, bound: usize| {
-            if bound > u32::MAX as usize {
+        let verify_bound = |params: &ReadBatchParams, bound: u64, inclusive: bool| {
+            if bound > self.num_rows || bound == self.num_rows && inclusive {
                 Err(Error::invalid_input(
                     format!(
                         "cannot read {:?} from file with {} rows",
@@ -642,17 +674,33 @@ impl FileReader {
             }
         };
         match &params {
-            ReadBatchParams::Indices(_) => todo!(),
+            ReadBatchParams::Indices(indices) => {
+                for idx in indices {
+                    match idx {
+                        None => {
+                            return Err(Error::invalid_input(
+                                "Null value in indices array",
+                                location!(),
+                            ))
+                        }
+                        Some(idx) => {
+                            verify_bound(&params, idx as u64, true)?;
+                        }
+                    }
+                }
+                let indices = indices.iter().map(|idx| idx.unwrap() as u64).collect();
+                self.take_rows(indices, batch_size, projection)
+            }
             ReadBatchParams::Range(range) => {
-                verify_bound(&params, range.end)?;
+                verify_bound(&params, range.end as u64, false)?;
                 self.read_range(range.start as u64..range.end as u64, batch_size, projection)
             }
             ReadBatchParams::RangeFrom(range) => {
-                verify_bound(&params, range.start)?;
+                verify_bound(&params, range.start as u64, true)?;
                 self.read_range(range.start as u64..self.num_rows, batch_size, projection)
             }
             ReadBatchParams::RangeTo(range) => {
-                verify_bound(&params, range.end)?;
+                verify_bound(&params, range.end as u64, false)?;
                 self.read_range(0..range.end as u64, batch_size, projection)
             }
             ReadBatchParams::RangeFull => self.read_range(0..self.num_rows, batch_size, projection),
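
Usage note (not part of the patch): a minimal sketch of how the new take_rows API is exercised from Python, mirroring the test_take test above. The file path and table contents here are illustrative assumptions, not taken from the patch.

    import pyarrow as pa
    from lance.file import LanceFileReader, LanceFileWriter

    # Write a small single-column file (path and contents are hypothetical)
    schema = pa.schema([pa.field("a", pa.int64())])
    writer = LanceFileWriter("/tmp/example.lance", schema)
    writer.write_batch(pa.table({"a": list(range(100))}))
    writer.close()

    # Take a sparse set of row indices; an out-of-range index raises
    # ValueError. During scheduling, pages that contain none of the
    # requested indices are skipped entirely (see the "Skipping entire
    # page" trace in struct.rs).
    reader = LanceFileReader("/tmp/example.lance")
    table = reader.take_rows([0, 77, 83]).to_table()
    assert table == pa.table({"a": [0, 77, 83]})

Note that the decoder expects the indices to be sorted and unique; the debug_assert!(indices.windows(2).all(|w| w[0] < w[1])) in schedule_take only enforces this in debug builds.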