Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: update merge_insert to add statistics for inserted, updated, deleted rows #2357

Merged
merged 14 commits
May 22, 2024
Merged
8 changes: 3 additions & 5 deletions python/python/lance/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

import pyarrow as pa

from ._arrow.bf16 import ( # noqa: F401
BFloat16,
from ._arrow.bf16 import (
BFloat16Array,
BFloat16Type,
PandasBFloat16Array,
)
BFloat16Type, # noqa: F401
)
from .dependencies import numpy as np
from .lance import bfloat16_array

Expand Down
4 changes: 3 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
source is some kind of generator.
"""
reader = _coerce_reader(data_obj, schema)
super(MergeInsertBuilder, self).execute(reader)
inserted, updated, deleted = super(MergeInsertBuilder, self).execute(reader)

return inserted, updated, deleted
raunaks13 marked this conversation as resolved.
Show resolved Hide resolved

# These next three overrides exist only to document the methods

Expand Down
8 changes: 2 additions & 6 deletions python/python/lance/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,8 @@
LanceFileMetadata,
LancePageMetadata,
)
from .lance import (
LanceFileReader as _LanceFileReader,
)
from .lance import (
LanceFileWriter as _LanceFileWriter,
)
from .lance import LanceFileReader as _LanceFileReader
from .lance import LanceFileWriter as _LanceFileWriter


class ReaderResults:
Expand Down
2 changes: 0 additions & 2 deletions python/python/lance/optimize.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,11 @@ class RewriteResult:
class CompactionTask:
read_version: int
fragments: List["FragmentMetadata"]

def execute(self, dataset: "Dataset") -> RewriteResult: ...

class CompactionPlan:
read_version: int
tasks: List[CompactionTask]

def num_tasks(self) -> int: ...

class Compaction:
Expand Down
101 changes: 72 additions & 29 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,59 +939,81 @@ def test_merge_insert(tmp_path: Path):

is_new = pc.field("b") == 2

dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table)
)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 300
assert (inserted, updated, deleted) == (300, 0, 0)

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_matched_update_all().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a").when_matched_update_all().execute(new_table)
)
table = dataset.to_table()
assert table.num_rows == 1000
assert table.filter(is_new).num_rows == 700
assert (inserted, updated, deleted) == (0, 700, 0)

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert(
"a"
).when_not_matched_insert_all().when_matched_update_all().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a")
.when_not_matched_insert_all()
.when_matched_update_all()
.execute(new_table)
)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 1000
assert (inserted, updated, deleted) == (300, 700, 0)

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_not_matched_insert_all().when_matched_update_all(
"target.c == source.c"
).execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a")
.when_not_matched_insert_all()
.when_matched_update_all("target.c == source.c")
.execute(new_table)
)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 650
assert (inserted, updated, deleted) == (300, 350, 0)

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table)
)
table = dataset.to_table()
assert table.num_rows == 700
assert table.filter(is_new).num_rows == 0
assert (inserted, updated, deleted) == (0, 0, 300)

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
dataset.merge_insert("a").when_not_matched_by_source_delete(
"a < 100"
).when_not_matched_insert_all().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a")
.when_not_matched_by_source_delete("a < 100")
.when_not_matched_insert_all()
.execute(new_table)
)

table = dataset.to_table()
assert table.num_rows == 1200
assert table.filter(is_new).num_rows == 300
assert (inserted, updated, deleted) == (300, 0, 100)

# If the user doesn't specify anything then the merge_insert is
# a no-op and the operation fails
dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()
with pytest.raises(ValueError):
dataset.merge_insert("a").execute(new_table)
inserted, updated, deleted = dataset.merge_insert("a").execute(new_table)
assert (inserted, updated, deleted) == (None, None, None)


def test_flat_vector_search_with_delete(tmp_path: Path):
Expand Down Expand Up @@ -1031,9 +1053,11 @@ def test_merge_insert_conditional_upsert_example(tmp_path: Path):
}
)

dataset.merge_insert("id").when_matched_update_all(
"target.txNumber < source.txNumber"
).execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("id")
.when_matched_update_all("target.txNumber < source.txNumber")
.execute(new_table)
)

table = dataset.to_table()

Expand All @@ -1049,6 +1073,7 @@ def test_merge_insert_conditional_upsert_example(tmp_path: Path):
)

assert table.sort_by("id") == expected
assert (inserted, updated, deleted) == (0, 2, 0)

# No matches

Expand All @@ -1060,9 +1085,12 @@ def test_merge_insert_conditional_upsert_example(tmp_path: Path):
}
)

dataset.merge_insert("id").when_matched_update_all(
"target.txNumber < source.txNumber"
).execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("id")
.when_matched_update_all("target.txNumber < source.txNumber")
.execute(new_table)
)
assert (inserted, updated, deleted) == (0, 0, 0)


def test_merge_insert_source_is_dataset(tmp_path: Path):
Expand All @@ -1085,22 +1113,28 @@ def test_merge_insert_source_is_dataset(tmp_path: Path):

is_new = pc.field("b") == 2

dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset)
inserted, updated, deleted = (
dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset)
)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 300
assert (inserted, updated, deleted) == (300, 0, 0)

dataset = lance.dataset(tmp_path / "dataset", version=version)
dataset.restore()

reader = new_dataset.to_batches()

dataset.merge_insert("a").when_not_matched_insert_all().execute(
reader, schema=new_dataset.schema
inserted, updated, deleted = (
dataset.merge_insert("a")
.when_not_matched_insert_all()
.execute(reader, schema=new_dataset.schema)
)
table = dataset.to_table()
assert table.num_rows == 1300
assert table.filter(is_new).num_rows == 300
assert (inserted, updated, deleted) == (300, 0, 0)


def test_merge_insert_multiple_keys(tmp_path: Path):
Expand Down Expand Up @@ -1132,10 +1166,13 @@ def test_merge_insert_multiple_keys(tmp_path: Path):

is_new = pc.field("b") == 2

dataset.merge_insert(["a", "c"]).when_matched_update_all().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert(["a", "c"]).when_matched_update_all().execute(new_table)
)
table = dataset.to_table()
assert table.num_rows == 1000
assert table.filter(is_new).num_rows == 350
assert (inserted, updated, deleted) == (0, 350, 0)


def test_merge_insert_incompatible_schema(tmp_path: Path):
Expand All @@ -1157,7 +1194,10 @@ def test_merge_insert_incompatible_schema(tmp_path: Path):
)

with pytest.raises(OSError):
dataset.merge_insert("a").when_matched_update_all().execute(new_table)
inserted, updated, deleted = (
dataset.merge_insert("a").when_matched_update_all().execute(new_table)
)
assert (inserted, updated, deleted) == (None, None, None)


def test_merge_insert_vector_column(tmp_path: Path):
Expand All @@ -1179,10 +1219,12 @@ def test_merge_insert_vector_column(tmp_path: Path):
table, tmp_path / "dataset", mode="create", max_rows_per_file=100
)

dataset.merge_insert(
["key"]
).when_not_matched_insert_all().when_matched_update_all().execute(new_table)

inserted, updated, deleted = (
dataset.merge_insert(["key"])
.when_not_matched_insert_all()
.when_matched_update_all()
.execute(new_table)
)
expected = pa.Table.from_pydict(
{
"vec": pa.array(
Expand All @@ -1193,6 +1235,7 @@ def test_merge_insert_vector_column(tmp_path: Path):
)

assert dataset.to_table().sort_by("key") == expected
assert (inserted, updated, deleted) == (1, 1, 0)


def test_update_dataset(tmp_path: Path):
Expand Down Expand Up @@ -1512,7 +1555,7 @@ def test_scan_with_row_ids(tmp_path: Path):
tbl = ds.scanner(filter="a % 10 == 0 AND a < 500", with_row_id=True).to_table()
assert "_rowid" in tbl.column_names
row_ids = tbl["_rowid"].to_pylist()
assert row_ids == list(range(0, 250, 10)) + list(range(2**32, 2**32 + 250, 10))
assert row_ids == list(range(0, 250, 10)) + list(range(2 ** 32, 2 ** 32 + 250, 10))

tbl2 = ds._take_rows(row_ids)
assert tbl2["a"] == tbl["a"]
Expand Down
4 changes: 2 additions & 2 deletions python/python/tests/test_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_nearest(tmp_path):

schema = pa.schema([pa.field("emb", pa.list_(pa.float32(), 32), False)])
npvals = np.random.rand(100, 32)
npvals /= np.sqrt((npvals**2).sum(axis=1))[:, None]
npvals /= np.sqrt((npvals ** 2).sum(axis=1))[:, None]
values = pa.array(npvals.ravel(), type=pa.float32())
arr = pa.FixedSizeListArray.from_arrays(values, 32)
tbl = pa.Table.from_arrays([arr], schema=schema)
Expand Down Expand Up @@ -158,7 +158,7 @@ def test_create_index_shuffle_params(tmp_path):
def _create_dataset(uri, num_batches=1):
schema = pa.schema([pa.field("emb", pa.list_(pa.float32(), 32), False)])
npvals = np.random.rand(1000, 32)
npvals /= np.sqrt((npvals**2).sum(axis=1))[:, None]
npvals /= np.sqrt((npvals ** 2).sum(axis=1))[:, None]
values = pa.array(npvals.ravel(), type=pa.float32())
arr = pa.FixedSizeListArray.from_arrays(values, 32)
tbl = pa.Table.from_arrays([arr], schema=schema)
Expand Down
4 changes: 2 additions & 2 deletions python/python/tests/test_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
ray = pytest.importorskip("ray")


from lance.ray.sink import ( # noqa: E402
from lance.ray.sink import (
LanceCommitter,
LanceDatasink,
LanceDatasink, # noqa: E402
LanceFragmentWriter,
_register_hooks,
)
Expand Down
9 changes: 6 additions & 3 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl MergeInsertBuilder {
Ok(slf)
}

pub fn execute(&mut self, new_data: &PyAny) -> PyResult<()> {
pub fn execute(&mut self, new_data: &PyAny) -> PyResult<(u64, u64, u64)> {
let py = new_data.py();

let new_data: Box<dyn RecordBatchReader + Send> = if new_data.is_instance_of::<Scanner>() {
Expand All @@ -199,9 +199,12 @@ impl MergeInsertBuilder {

let dataset = self.dataset.as_ref(py);

dataset.borrow_mut().ds = new_self;
dataset.borrow_mut().ds = new_self.0;
let inserted = new_self.1;
let updated = new_self.2;
let deleted = new_self.3;

Ok(())
Ok((inserted, updated, deleted))
}
}

Expand Down
Loading
Loading