Skip to content

Commit

Permalink
feat: add take_rows operation to the v2 file reader's python bindings (#2331)
Browse files Browse the repository at this point in the history

This also plumbs out basic support for take_rows and is a prerequisite
for adding a fragment-level take
  • Loading branch information
westonpace authored May 15, 2024
1 parent dfb531b commit 578afdd
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 39 deletions.
21 changes: 21 additions & 0 deletions python/python/lance/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,27 @@ def read_range(
self._reader.read_range(start, num_rows, batch_size, batch_readahead)
)

def take_rows(
    self, indices, *, batch_size: int = 1024, batch_readahead=16
) -> ReaderResults:
    """
    Read a specific set of rows from the file

    Parameters
    ----------
    indices: List[int]
        The indices of the rows to read from the file

        NOTE(review): the underlying decoder asserts the indices are
        strictly increasing (see the debug_assert in schedule_take) —
        confirm whether unsorted/duplicate indices are supported here
    batch_size: int, default 1024
        The file will be read in batches. This parameter controls
        how many rows will be in each batch (except the final batch)
        Smaller batches will use less memory but might be slightly
        slower because there is more per-batch overhead
    batch_readahead: int, default 16
        The number of batches to read ahead of the consumer; larger
        values hide I/O latency at the cost of buffering more data

    Returns
    -------
    ReaderResults
        A wrapper around the resulting batch stream (e.g. convertible
        via ``to_table()``)
    """
    # Defaults are resolved here; the native binding takes all three
    # arguments positionally.
    return ReaderResults(
        self._reader.take_rows(indices, batch_size, batch_readahead)
    )

def metadata(self) -> LanceFileMetadata:
"""
Return metadata describing the file contents
Expand Down
3 changes: 3 additions & 0 deletions python/python/lance/lance/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class LanceFileReader:
def read_range(
self, start: int, num_rows: int, batch_size: int, batch_readahead: int
) -> pa.RecordBatchReader: ...
# Low-level binding mirrored by lance.file.LanceFileReader.take_rows;
# all arguments are required positionally here (keyword defaults such as
# batch_size=1024 are applied by the python wrapper, not at this layer).
def take_rows(
    self, indices: List[int], batch_size: int, batch_readahead: int
) -> pa.RecordBatchReader: ...

class LanceBufferDescriptor:
position: int
Expand Down
17 changes: 17 additions & 0 deletions python/python/tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright The Lance Authors

import pyarrow as pa
import pytest
from lance.file import LanceFileReader, LanceFileWriter


Expand Down Expand Up @@ -33,6 +34,22 @@ def test_multiple_close(tmp_path):
writer.close()


def test_take(tmp_path):
    """take_rows returns exactly the requested rows and rejects out-of-range indices."""
    data_path = tmp_path / "foo.lance"
    schema = pa.schema([pa.field("a", pa.int64())])

    # Write a single batch of 100 sequential int64 values.
    writer = LanceFileWriter(str(data_path), schema)
    writer.write_batch(pa.table({"a": list(range(100))}))
    writer.close()

    reader = LanceFileReader(str(data_path))

    # Index 100 is one past the last row; the read must fail rather than
    # silently clip or wrap.
    with pytest.raises(ValueError):
        reader.take_rows([0, 100]).to_table()

    selected = reader.take_rows([0, 77, 83]).to_table()
    assert selected == pa.table({"a": [0, 77, 83]})


def test_different_types(tmp_path):
path = tmp_path / "foo.lance"
schema = pa.schema(
Expand Down
20 changes: 19 additions & 1 deletion python/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
use std::{pin::Pin, sync::Arc};

use arrow::pyarrow::PyArrowType;
use arrow_array::{RecordBatch, RecordBatchReader};
use arrow_array::{RecordBatch, RecordBatchReader, UInt32Array};
use arrow_schema::Schema as ArrowSchema;
use futures::stream::StreamExt;
use lance::io::{ObjectStore, RecordBatchStream};
Expand Down Expand Up @@ -344,6 +344,24 @@ impl LanceFileReader {
)
}

/// Read a specific set of rows from the file, identified by 0-based index.
///
/// The resulting stream yields batches of at most `batch_size` rows and
/// reads up to `batch_readahead` batches ahead of the consumer.
///
/// NOTE(review): `idx as u32` below silently truncates indices above
/// `u32::MAX` — TODO confirm whether files that large are reachable here,
/// or whether this should be a checked `try_into` returning a PyErr.
pub fn take_rows(
    &mut self,
    row_indices: Vec<u64>,
    batch_size: u32,
    batch_readahead: u32,
) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
    // ReadBatchParams::Indices takes a UInt32Array, so narrow the u64
    // indices first (see truncation note above).
    let indices = row_indices
        .into_iter()
        .map(|idx| idx as u32)
        .collect::<Vec<_>>();
    let indices_arr = UInt32Array::from(indices);
    self.read_stream(
        lance_io::ReadBatchParams::Indices(indices_arr),
        batch_size,
        batch_readahead,
    )
}

pub fn metadata(&mut self, py: Python) -> LanceFileMetadata {
let inner_meta = self.inner.metadata();
LanceFileMetadata::new(inner_meta, py)
Expand Down
29 changes: 16 additions & 13 deletions rust/lance-encoding/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@
//! * The "batch overhead" is very small in Lance compared to other formats because it has no
//! relation to the way the data is stored.

use std::future::Future;
use std::{ops::Range, sync::Arc};

use arrow_array::cast::AsArray;
Expand Down Expand Up @@ -557,11 +556,14 @@ impl DecodeBatchScheduler {
/// * `scheduler` An I/O scheduler to issue I/O requests
pub async fn schedule_take(
&mut self,
indices: &[u32],
indices: &[u64],
sink: mpsc::UnboundedSender<Box<dyn LogicalPageDecoder>>,
scheduler: &Arc<dyn EncodingsIo>,
) -> Result<()> {
debug_assert!(indices.windows(2).all(|w| w[0] < w[1]));
if indices.is_empty() {
return Ok(());
}
trace!(
"Scheduling take of {} rows [{}]",
indices.len(),
Expand All @@ -574,15 +576,17 @@ impl DecodeBatchScheduler {
if indices.is_empty() {
return Ok(());
}
// TODO: Figure out how to handle u64 indices
let indices = indices.iter().map(|i| *i as u32).collect::<Vec<_>>();
self.root_scheduler
.schedule_take(indices, scheduler, &sink, indices[0] as u64)?;
.schedule_take(&indices, scheduler, &sink, indices[0] as u64)?;
trace!("Finished scheduling take of {} rows", indices.len());
Ok(())
}
}

pub struct ReadBatchTask<Fut: Future<Output = Result<RecordBatch>>> {
pub task: Fut,
pub struct ReadBatchTask {
pub task: BoxFuture<'static, Result<RecordBatch>>,
pub num_rows: u32,
}

Expand Down Expand Up @@ -620,7 +624,10 @@ impl BatchDecodeStream {

#[instrument(level = "debug", skip_all)]
async fn next_batch_task(&mut self) -> Result<Option<NextDecodeTask>> {
trace!("Draining batch task");
trace!(
"Draining batch task (rows_remaining={})",
self.rows_remaining
);
if self.rows_remaining == 0 {
return Ok(None);
}
Expand Down Expand Up @@ -651,9 +658,7 @@ impl BatchDecodeStream {
Ok(RecordBatch::from(struct_arr.as_struct()))
}

pub fn into_stream(
self,
) -> BoxStream<'static, ReadBatchTask<impl Future<Output = Result<RecordBatch>>>> {
pub fn into_stream(self) -> BoxStream<'static, ReadBatchTask> {
let stream = futures::stream::unfold(self, |mut slf| async move {
let next_task = slf.next_batch_task().await;
let next_task = next_task.transpose().map(|next_task| {
Expand All @@ -665,10 +670,8 @@ impl BatchDecodeStream {
(task, num_rows)
});
next_task.map(|(task, num_rows)| {
let next_task = ReadBatchTask {
task: task.map(|join_wrapper| join_wrapper.unwrap()),
num_rows,
};
let task = task.map(|join_wrapper| join_wrapper.unwrap()).boxed();
let next_task = ReadBatchTask { task, num_rows };
(next_task, slf)
})
});
Expand Down
33 changes: 27 additions & 6 deletions rust/lance-encoding/src/encodings/logical/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ impl LogicalPageScheduler for SimpleStructScheduler {
// NOTE: See schedule_range for a description of the scheduling algorithm
let mut current_top_level_row = top_level_row;
while rows_to_read > 0 {
trace!("Beginning scheduler scan of columns");
let mut min_rows_added = u32::MAX;
for (col_idx, field_scheduler) in self.children.iter().enumerate() {
let status = &mut field_status[col_idx];
Expand All @@ -263,12 +264,15 @@ impl LogicalPageScheduler for SimpleStructScheduler {
"{}",
if indices_in_page.is_empty() {
format!(
"Skipping entire page of {} rows",
next_candidate_page.num_rows()
"Skipping entire page of {} rows for column {}",
next_candidate_page.num_rows(),
col_idx,
)
} else {
format!(
"Found page with {} overlapping indices",
"Found page for column {} with {} rows that had {} overlapping indices",
col_idx,
next_candidate_page.num_rows(),
indices_in_page.len()
)
}
Expand All @@ -290,11 +294,18 @@ impl LogicalPageScheduler for SimpleStructScheduler {
status.rows_queued += rows_scheduled;

min_rows_added = min_rows_added.min(rows_scheduled);
} else {
// TODO: Unit tests are not covering this path right now
min_rows_added = min_rows_added.min(status.rows_queued);
}
}
if min_rows_added == 0 {
panic!("Error in scheduling logic, panic to avoid infinite loop");
}
trace!(
"One scheduling pass complete, {} rows added",
min_rows_added
);
rows_to_read -= min_rows_added;
current_top_level_row += min_rows_added as u64;
for field_status in &mut field_status {
Expand Down Expand Up @@ -380,6 +391,11 @@ impl ChildState {
);
let mut remaining = num_rows.saturating_sub(self.rows_available);
if remaining > 0 {
trace!(
"Struct must await {} rows from {} unawaited rows",
remaining,
self.rows_unawaited
);
if let Some(back) = self.awaited.back_mut() {
if back.unawaited() > 0 {
let rows_to_wait = remaining.min(back.unawaited());
Expand All @@ -394,7 +410,12 @@ impl ChildState {
let newly_avail = back.avail() - previously_avail;
trace!("The await loaded {} rows", newly_avail);
self.rows_available += newly_avail;
self.rows_unawaited -= newly_avail;
// Need to use saturating_sub here because we might have asked for range
// 0-1000 and this page we just loaded might cover 900-1100 and so newly_avail
// is 200 but rows_unawaited is only 100
//
// TODO: Unit tests are not covering this branch right now
self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
remaining -= rows_to_wait;
if remaining == 0 {
return Ok(true);
Expand Down Expand Up @@ -425,9 +446,9 @@ impl ChildState {
// newly_avail this way (we do this above too)
let newly_avail = decoder.avail() - previously_avail;
self.awaited.push_back(decoder);
self.rows_available += newly_avail;
self.rows_unawaited -= newly_avail;
trace!("The new await loaded {} rows", newly_avail);
self.rows_available += newly_avail;
self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
Ok(remaining == rows_to_wait)
} else {
Ok(true)
Expand Down
8 changes: 4 additions & 4 deletions rust/lance-encoding/src/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use std::{ops::Range, sync::Arc};

use arrow_array::{Array, UInt32Array};
use arrow_array::{Array, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use arrow_select::concat::concat;
use bytes::{Bytes, BytesMut};
Expand Down Expand Up @@ -122,7 +122,7 @@ fn supports_nulls(data_type: &DataType) -> bool {
#[derive(Clone, Default)]
pub struct TestCases {
ranges: Vec<Range<u64>>,
indices: Vec<Vec<u32>>,
indices: Vec<Vec<u64>>,
skip_validation: bool,
}

Expand All @@ -132,7 +132,7 @@ impl TestCases {
self
}

/// Registers a take-style test case: the round-trip test will decode
/// exactly these row indices and compare against the expected rows.
/// May be called multiple times; each call adds an independent case.
pub fn with_indices(mut self, indices: Vec<u64>) -> Self {
    self.indices.push(indices);
    self
}
}
Expand Down Expand Up @@ -285,7 +285,7 @@ async fn check_round_trip_encoding_inner(
);
}
let num_rows = indices.len() as u64;
let indices_arr = UInt32Array::from(indices.clone());
let indices_arr = UInt64Array::from(indices.clone());
let expected = concat_data
.as_ref()
.map(|concat_data| arrow_select::take::take(&concat_data, &indices_arr, None).unwrap());
Expand Down
Loading

0 comments on commit 578afdd

Please sign in to comment.