feat: add take_rows operation to the v2 file reader's python bindings #2331

Merged (1 commit) on May 15, 2024
python/python/lance/file.py (+21 −0)

@@ -106,6 +106,27 @@ def read_range(
            self._reader.read_range(start, num_rows, batch_size, batch_readahead)
        )

    def take_rows(
        self, indices, *, batch_size: int = 1024, batch_readahead: int = 16
    ) -> ReaderResults:
        """
        Read a specific set of rows from the file

        Parameters
        ----------
        indices: List[int]
            The indices of the rows to read from the file
        batch_size: int, default 1024
            The file will be read in batches. This parameter controls
            how many rows will be in each batch (except the final batch)

            Smaller batches will use less memory but might be slightly
            slower because there is more per-batch overhead
        """
        return ReaderResults(
            self._reader.take_rows(indices, batch_size, batch_readahead)
        )

    def metadata(self) -> LanceFileMetadata:
        """
        Return metadata describing the file contents
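For context, a minimal usage sketch of the new binding (mirroring test_take below; the path is a placeholder):

    import pyarrow as pa
    from lance.file import LanceFileReader, LanceFileWriter

    # Write a small single-column file
    writer = LanceFileWriter("/tmp/foo.lance", pa.schema([pa.field("a", pa.int64())]))
    writer.write_batch(pa.table({"a": list(range(100))}))
    writer.close()

    # Read back three specific rows by index
    reader = LanceFileReader("/tmp/foo.lance")
    table = reader.take_rows([0, 77, 83]).to_table()
    assert table == pa.table({"a": [0, 77, 83]})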
python/python/lance/lance/__init__.pyi (+3 −0)

@@ -52,6 +52,9 @@ class LanceFileReader:
    def read_range(
        self, start: int, num_rows: int, batch_size: int, batch_readahead: int
    ) -> pa.RecordBatchReader: ...
    def take_rows(
        self, indices: List[int], batch_size: int, batch_readahead: int
    ) -> pa.RecordBatchReader: ...

class LanceBufferDescriptor:
    position: int
python/python/tests/test_file.py (+17 −0)

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright The Lance Authors

import pyarrow as pa
import pytest
from lance.file import LanceFileReader, LanceFileWriter


@@ -33,6 +34,22 @@ def test_multiple_close(tmp_path):
    writer.close()


def test_take(tmp_path):
    path = tmp_path / "foo.lance"
    schema = pa.schema([pa.field("a", pa.int64())])
    writer = LanceFileWriter(str(path), schema)
    writer.write_batch(pa.table({"a": [i for i in range(100)]}))
    writer.close()

    reader = LanceFileReader(str(path))
    # Can't read out of range
    with pytest.raises(ValueError):
        reader.take_rows([0, 100]).to_table()

    table = reader.take_rows([0, 77, 83]).to_table()
    assert table == pa.table({"a": [0, 77, 83]})


def test_different_types(tmp_path):
    path = tmp_path / "foo.lance"
    schema = pa.schema(
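The batch_size tradeoff described in the file.py docstring applies to take_rows as well. A sketch, assuming ReaderResults also exposes to_batches() (that method is not shown in this diff):

    # With batch_size=10, fifty indices should come back as batches of at most 10 rows
    reader = LanceFileReader("/tmp/foo.lance")
    for batch in reader.take_rows(list(range(50)), batch_size=10).to_batches():
        assert batch.num_rows <= 10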
python/src/file.rs (+19 −1)

@@ -15,7 +15,7 @@
use std::{pin::Pin, sync::Arc};

use arrow::pyarrow::PyArrowType;
-use arrow_array::{RecordBatch, RecordBatchReader};
+use arrow_array::{RecordBatch, RecordBatchReader, UInt32Array};
use arrow_schema::Schema as ArrowSchema;
use futures::stream::StreamExt;
use lance::io::{ObjectStore, RecordBatchStream};
@@ -344,6 +344,24 @@ impl LanceFileReader {
        )
    }

    pub fn take_rows(
        &mut self,
        row_indices: Vec<u64>,
        batch_size: u32,
        batch_readahead: u32,
    ) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
        let indices = row_indices
            .into_iter()
            .map(|idx| idx as u32)
            .collect::<Vec<_>>();
        let indices_arr = UInt32Array::from(indices);
        self.read_stream(
            lance_io::ReadBatchParams::Indices(indices_arr),
            batch_size,
            batch_readahead,
        )
    }

    pub fn metadata(&mut self, py: Python) -> LanceFileMetadata {
        let inner_meta = self.inner.metadata();
        LanceFileMetadata::new(inner_meta, py)
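Note the narrowing above: the binding accepts u64 row indices from Python but casts them to u32 before scheduling (the decoder carries a matching TODO, shown below), so rows at or beyond 2**32 are not addressable through this path yet. A hypothetical caller-side guard, not part of this PR:

    def take_rows_checked(reader, indices):
        # The binding narrows each index to u32, so reject anything
        # the current reader cannot address yet
        if any(i < 0 or i >= 2**32 for i in indices):
            raise ValueError("take_rows currently supports indices below 2**32")
        return reader.take_rows(indices)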
rust/lance-encoding/src/decoder.rs (+16 −13)

@@ -196,7 +196,6 @@
//! * The "batch overhead" is very small in Lance compared to other formats because it has no
//! relation to the way the data is stored.

-use std::future::Future;
use std::{ops::Range, sync::Arc};

use arrow_array::cast::AsArray;
@@ -557,11 +556,14 @@ impl DecodeBatchScheduler {
    /// * `scheduler` An I/O scheduler to issue I/O requests
    pub async fn schedule_take(
        &mut self,
-        indices: &[u32],
+        indices: &[u64],
        sink: mpsc::UnboundedSender<Box<dyn LogicalPageDecoder>>,
        scheduler: &Arc<dyn EncodingsIo>,
    ) -> Result<()> {
        debug_assert!(indices.windows(2).all(|w| w[0] < w[1]));
+        if indices.is_empty() {
+            return Ok(());
+        }
        trace!(
            "Scheduling take of {} rows [{}]",
            indices.len(),
@@ -574,15 +576,17 @@
        if indices.is_empty() {
            return Ok(());
        }
+        // TODO: Figure out how to handle u64 indices
+        let indices = indices.iter().map(|i| *i as u32).collect::<Vec<_>>();
        self.root_scheduler
-            .schedule_take(indices, scheduler, &sink, indices[0] as u64)?;
+            .schedule_take(&indices, scheduler, &sink, indices[0] as u64)?;
        trace!("Finished scheduling take of {} rows", indices.len());
        Ok(())
    }
}

-pub struct ReadBatchTask<Fut: Future<Output = Result<RecordBatch>>> {
-    pub task: Fut,
+pub struct ReadBatchTask {
+    pub task: BoxFuture<'static, Result<RecordBatch>>,
    pub num_rows: u32,
}

@@ -620,7 +624,10 @@ impl BatchDecodeStream {

    #[instrument(level = "debug", skip_all)]
    async fn next_batch_task(&mut self) -> Result<Option<NextDecodeTask>> {
-        trace!("Draining batch task");
+        trace!(
+            "Draining batch task (rows_remaining={})",
+            self.rows_remaining
+        );
        if self.rows_remaining == 0 {
            return Ok(None);
        }
@@ -651,9 +658,7 @@ impl BatchDecodeStream {
        Ok(RecordBatch::from(struct_arr.as_struct()))
    }

-    pub fn into_stream(
-        self,
-    ) -> BoxStream<'static, ReadBatchTask<impl Future<Output = Result<RecordBatch>>>> {
+    pub fn into_stream(self) -> BoxStream<'static, ReadBatchTask> {
        let stream = futures::stream::unfold(self, |mut slf| async move {
            let next_task = slf.next_batch_task().await;
            let next_task = next_task.transpose().map(|next_task| {
@@ -665,10 +670,8 @@
                (task, num_rows)
            });
            next_task.map(|(task, num_rows)| {
-                let next_task = ReadBatchTask {
-                    task: task.map(|join_wrapper| join_wrapper.unwrap()),
-                    num_rows,
-                };
+                let task = task.map(|join_wrapper| join_wrapper.unwrap()).boxed();
+                let next_task = ReadBatchTask { task, num_rows };
                (next_task, slf)
            })
        });
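One usage detail surfaced by the debug_assert in schedule_take: indices are expected to be strictly increasing, and the assert only fires in debug builds. A caller-side sketch of normalizing indices first (reader as in the earlier example):

    # Sort and de-duplicate before calling take_rows, since the scheduler
    # expects strictly increasing indices
    raw = [83, 0, 77, 77]
    table = reader.take_rows(sorted(set(raw))).to_table()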
rust/lance-encoding/src/encodings/logical/struct.rs (+27 −6)

@@ -248,6 +248,7 @@ impl LogicalPageScheduler for SimpleStructScheduler {
        // NOTE: See schedule_range for a description of the scheduling algorithm
        let mut current_top_level_row = top_level_row;
        while rows_to_read > 0 {
            trace!("Beginning scheduler scan of columns");
            let mut min_rows_added = u32::MAX;
            for (col_idx, field_scheduler) in self.children.iter().enumerate() {
                let status = &mut field_status[col_idx];
@@ -263,12 +264,15 @@
                    "{}",
                    if indices_in_page.is_empty() {
                        format!(
-                            "Skipping entire page of {} rows",
-                            next_candidate_page.num_rows()
+                            "Skipping entire page of {} rows for column {}",
+                            next_candidate_page.num_rows(),
+                            col_idx,
                        )
                    } else {
                        format!(
-                            "Found page with {} overlapping indices",
+                            "Found page for column {} with {} rows that had {} overlapping indices",
+                            col_idx,
+                            next_candidate_page.num_rows(),
                            indices_in_page.len()
                        )
                    }
@@ -290,11 +294,18 @@
                    status.rows_queued += rows_scheduled;

                    min_rows_added = min_rows_added.min(rows_scheduled);
                } else {
                    // TODO: Unit tests are not covering this path right now
                    min_rows_added = min_rows_added.min(status.rows_queued);
                }
            }
            if min_rows_added == 0 {
                panic!("Error in scheduling logic, panic to avoid infinite loop");
            }
            trace!(
                "One scheduling pass complete, {} rows added",
                min_rows_added
            );
            rows_to_read -= min_rows_added;
            current_top_level_row += min_rows_added as u64;
            for field_status in &mut field_status {
@@ -380,6 +391,11 @@ impl ChildState {
        );
        let mut remaining = num_rows.saturating_sub(self.rows_available);
        if remaining > 0 {
            trace!(
                "Struct must await {} rows from {} unawaited rows",
                remaining,
                self.rows_unawaited
            );
            if let Some(back) = self.awaited.back_mut() {
                if back.unawaited() > 0 {
                    let rows_to_wait = remaining.min(back.unawaited());
@@ -394,7 +410,12 @@
                    let newly_avail = back.avail() - previously_avail;
                    trace!("The await loaded {} rows", newly_avail);
                    self.rows_available += newly_avail;
-                    self.rows_unawaited -= newly_avail;
+                    // Need to use saturating_sub here because we might have asked for range
+                    // 0-1000 and this page we just loaded might cover 900-1100 and so newly_avail
+                    // is 200 but rows_unawaited is only 100
+                    //
+                    // TODO: Unit tests are not covering this branch right now
+                    self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
                    remaining -= rows_to_wait;
                    if remaining == 0 {
                        return Ok(true);
@@ -425,9 +446,9 @@ impl ChildState {
            // newly_avail this way (we do this above too)
            let newly_avail = decoder.avail() - previously_avail;
            self.awaited.push_back(decoder);
-            self.rows_available += newly_avail;
-            self.rows_unawaited -= newly_avail;
            trace!("The new await loaded {} rows", newly_avail);
+            self.rows_available += newly_avail;
+            self.rows_unawaited = self.rows_unawaited.saturating_sub(newly_avail);
            Ok(remaining == rows_to_wait)
        } else {
            Ok(true)
rust/lance-encoding/src/testing.rs (+4 −4)

@@ -3,7 +3,7 @@

use std::{ops::Range, sync::Arc};

-use arrow_array::{Array, UInt32Array};
+use arrow_array::{Array, UInt64Array};
use arrow_schema::{DataType, Field, Schema};
use arrow_select::concat::concat;
use bytes::{Bytes, BytesMut};
@@ -122,7 +122,7 @@ fn supports_nulls(data_type: &DataType) -> bool {
#[derive(Clone, Default)]
pub struct TestCases {
    ranges: Vec<Range<u64>>,
-    indices: Vec<Vec<u32>>,
+    indices: Vec<Vec<u64>>,
    skip_validation: bool,
}

@@ -132,7 +132,7 @@ impl TestCases {
        self
    }

-    pub fn with_indices(mut self, indices: Vec<u32>) -> Self {
+    pub fn with_indices(mut self, indices: Vec<u64>) -> Self {
        self.indices.push(indices);
        self
    }
@@ -285,7 +285,7 @@ async fn check_round_trip_encoding_inner(
        );
    }
    let num_rows = indices.len() as u64;
-    let indices_arr = UInt32Array::from(indices.clone());
+    let indices_arr = UInt64Array::from(indices.clone());
    let expected = concat_data
        .as_ref()
        .map(|concat_data| arrow_select::take::take(&concat_data, &indices_arr, None).unwrap());