From 83ecc0156a82a28b2524163da3465dc56b2bcec5 Mon Sep 17 00:00:00 2001 From: Raunak Shah Date: Wed, 22 May 2024 14:57:59 -0700 Subject: [PATCH] feat: update merge_insert to add statistics for inserted, updated, deleted rows (#2357) Addresses https://github.com/lancedb/lance/issues/2019 --- benchmarks/dbpedia-openai/benchmarks.py | 3 +- benchmarks/flat/benchmark.py | 1 - benchmarks/full_report/_lib.py | 25 ++-- benchmarks/sift/index.py | 1 - benchmarks/tpch/benchmark.py | 17 ++- docs/conf.py | 1 - docs/examples/gcs_example.py | 16 ++- python/python/lance/dataset.py | 16 ++- python/python/tests/test_dataset.py | 106 ++++++++++---- python/src/dataset.rs | 11 +- rust/lance/src/dataset/write/merge_insert.rs | 138 +++++++++++++++---- test_data/v0.10.5/datagen.py | 18 +-- 12 files changed, 249 insertions(+), 104 deletions(-) diff --git a/benchmarks/dbpedia-openai/benchmarks.py b/benchmarks/dbpedia-openai/benchmarks.py index 382dfc2cd4..d3b783aef8 100755 --- a/benchmarks/dbpedia-openai/benchmarks.py +++ b/benchmarks/dbpedia-openai/benchmarks.py @@ -48,7 +48,8 @@ def ground_truth( def compute_recall(gt: np.ndarray, result: np.ndarray) -> float: recalls = [ - np.isin(rst, gt_vector).sum() / rst.shape[0] for (rst, gt_vector) in zip(result, gt) + np.isin(rst, gt_vector).sum() / rst.shape[0] + for (rst, gt_vector) in zip(result, gt) ] return np.mean(recalls) diff --git a/benchmarks/flat/benchmark.py b/benchmarks/flat/benchmark.py index 80c2b4ff93..a56fbf04ca 100755 --- a/benchmarks/flat/benchmark.py +++ b/benchmarks/flat/benchmark.py @@ -17,7 +17,6 @@ import time import lance -import matplotlib.pyplot as plt import numpy as np import pandas as pd import pyarrow as pa diff --git a/benchmarks/full_report/_lib.py b/benchmarks/full_report/_lib.py index bf4be0d973..302953f3b3 100644 --- a/benchmarks/full_report/_lib.py +++ b/benchmarks/full_report/_lib.py @@ -6,9 +6,6 @@ from typing import List import gzip -import lance -import numpy as np -import pyarrow as pa import requests @@ -33,15 +30,15 @@ def cosine(X, Y): def knn( query: np.ndarray, data: np.ndarray, - metric: Literal['L2', 'cosine'], + metric: Literal["L2", "cosine"], k: int, ) -> np.ndarray: - if metric == 'L2': + if metric == "L2": dist = l2 - elif metric == 'cosine': + elif metric == "cosine": dist = cosine else: - raise ValueError('Invalid metric') + raise ValueError("Invalid metric") return np.argpartition(dist(query, data), k, axis=1)[:, 0:k] @@ -51,10 +48,12 @@ def write_lance( ): dims = data.shape[1] - schema = pa.schema([ - pa.field("vec", pa.list_(pa.float32(), dims)), - pa.field("id", pa.uint32(), False), - ]) + schema = pa.schema( + [ + pa.field("vec", pa.list_(pa.float32(), dims)), + pa.field("id", pa.uint32(), False), + ] + ) fsl = pa.FixedSizeListArray.from_arrays( pa.array(data.reshape(-1).astype(np.float32), type=pa.float32()), @@ -65,6 +64,7 @@ def write_lance( lance.write_dataset(t, path) + # NYT _DATA_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/bag-of-words/docword.nytimes.txt.gz" @@ -112,7 +112,8 @@ def _get_nyt_vectors( tfidf = TfidfTransformer().fit_transform(freq) print("computing dense projection") dense_projection = random_projection.GaussianRandomProjection( - n_components=output_dims, random_state=42, + n_components=output_dims, + random_state=42, ).fit_transform(tfidf) dense_projection = dense_projection.astype(np.float32) np.save(_CACHE_PATH, dense_projection) diff --git a/benchmarks/sift/index.py b/benchmarks/sift/index.py index d3929e87c0..2a9c801852 100755 --- a/benchmarks/sift/index.py +++ 
b/benchmarks/sift/index.py @@ -20,7 +20,6 @@ from subprocess import check_output import lance -import pyarrow as pa def main(): diff --git a/benchmarks/tpch/benchmark.py b/benchmarks/tpch/benchmark.py index 602d029f21..0f56735d4c 100644 --- a/benchmarks/tpch/benchmark.py +++ b/benchmarks/tpch/benchmark.py @@ -1,7 +1,5 @@ # Benchmark performance Lance vs Parquet w/ Tpch Q1 and Q6 import lance -import pandas as pd -import pyarrow as pa import duckdb import sys @@ -46,10 +44,10 @@ num_args = len(sys.argv) assert num_args == 2 -query = '' -if sys.argv[1] == 'q1': +query = "" +if sys.argv[1] == "q1": query = Q1 -elif sys.argv[1] == 'q6': +elif sys.argv[1] == "q6": query = Q6 else: sys.exit("We only support Q1 and Q6 for now") @@ -62,17 +60,18 @@ res1 = duckdb.sql(query).df() end1 = time.time() -print("Lance Latency: ",str(round(end1 - start1, 3)) + 's') +print("Lance Latency: ", str(round(end1 - start1, 3)) + "s") print(res1) ##### Parquet ##### lineitem = None start2 = time.time() # read from parquet and create a view instead of table from it -duckdb.sql("CREATE VIEW lineitem AS SELECT * FROM read_parquet('./dataset/lineitem_sf1.parquet');") +duckdb.sql( + "CREATE VIEW lineitem AS SELECT * FROM read_parquet('./dataset/lineitem_sf1.parquet');" +) res2 = duckdb.sql(query).df() end2 = time.time() -print("Parquet Latency: ",str(round(end2 - start2, 3)) + 's') +print("Parquet Latency: ", str(round(end2 - start2, 3)) + "s") print(res2) - diff --git a/docs/conf.py b/docs/conf.py index d8d2671c99..b2cb7fb8dd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -1,7 +1,6 @@ # Configuration file for the Sphinx documentation builder. import shutil -import subprocess def run_apidoc(_): diff --git a/docs/examples/gcs_example.py b/docs/examples/gcs_example.py index 606072b2ec..dcb72901e6 100644 --- a/docs/examples/gcs_example.py +++ b/docs/examples/gcs_example.py @@ -1,25 +1,29 @@ -# +# # Lance example loading a dataset from Google Cloud Storage # # You need to set one of the following environment variables in order to authenticate with GS # - GOOGLE_SERVICE_ACCOUNT: location of service account file # - GOOGLE_SERVICE_ACCOUNT_KEY: JSON serialized service account key # -# Follow this doc in order to create an service key: https://cloud.google.com/iam/docs/keys-create-delete +# Follow this doc in order to create an service key: https://cloud.google.com/iam/docs/keys-create-delete # import lance +import pandas as pd ds = lance.dataset("gs://eto-public/datasets/oxford_pet/oxford_pet.lance") count = ds.count_rows() print(f"There are {count} pets") # You can also write to GCS -import pandas as pd + uri = "gs://eto-public/datasets/oxford_pet/example.lance" -lance.write_dataset(pd.DataFrame({"a": pd.array([10], dtype="Int32")}), uri, mode='create') +lance.write_dataset( + pd.DataFrame({"a": pd.array([10], dtype="Int32")}), uri, mode="create" +) assert lance.dataset(uri).version == 1 -lance.write_dataset(pd.DataFrame({"a": pd.array([5], dtype="Int32")}), uri, mode='append') +lance.write_dataset( + pd.DataFrame({"a": pd.array([5], dtype="Int32")}), uri, mode="append" +) assert lance.dataset(uri).version == 2 - diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index bf96f09307..04ee5906ef 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -83,7 +83,9 @@ class MergeInsertBuilder(_MergeInsertBuilder): def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None): """Executes the merge insert operation - There is no return value but the original 
dataset will be updated. + This function updates the original dataset and returns a dictionary with + information about merge statistics - i.e. the number of inserted, updated, + and deleted rows. Parameters ---------- @@ -97,7 +99,8 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None): source is some kind of generator. """ reader = _coerce_reader(data_obj, schema) - super(MergeInsertBuilder, self).execute(reader) + + return super(MergeInsertBuilder, self).execute(reader) # These next three overrides exist only to document the methods @@ -945,10 +948,11 @@ def merge_insert( >>> dataset = lance.write_dataset(table, "example") >>> new_table = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]}) >>> # Perform a "upsert" operation - >>> dataset.merge_insert("a") \\ - ... .when_matched_update_all() \\ - ... .when_not_matched_insert_all() \\ - ... .execute(new_table) + >>> dataset.merge_insert("a") \\ + ... .when_matched_update_all() \\ + ... .when_not_matched_insert_all() \\ + ... .execute(new_table) + {'num_inserted_rows': 1, 'num_updated_rows': 2, 'num_deleted_rows': 0} >>> dataset.to_table().sort_by("a").to_pandas() a b 0 1 b diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 61070001bf..9b3254a86d 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -915,6 +915,14 @@ def test_delete_data(tmp_path: Path): assert dataset.count_rows() == 0 +def check_merge_stats(merge_dict, expected): + assert ( + merge_dict["num_inserted_rows"], + merge_dict["num_updated_rows"], + merge_dict["num_deleted_rows"], + ) == expected + + def test_merge_insert(tmp_path: Path): nrows = 1000 table = pa.Table.from_pydict( @@ -939,59 +947,79 @@ def test_merge_insert(tmp_path: Path): is_new = pc.field("b") == 2 - dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table) + merge_dict = ( + dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table) + ) table = dataset.to_table() assert table.num_rows == 1300 assert table.filter(is_new).num_rows == 300 + check_merge_stats(merge_dict, (300, 0, 0)) dataset = lance.dataset(tmp_path / "dataset", version=version) dataset.restore() - dataset.merge_insert("a").when_matched_update_all().execute(new_table) + merge_dict = dataset.merge_insert("a").when_matched_update_all().execute(new_table) table = dataset.to_table() assert table.num_rows == 1000 assert table.filter(is_new).num_rows == 700 + check_merge_stats(merge_dict, (0, 700, 0)) dataset = lance.dataset(tmp_path / "dataset", version=version) dataset.restore() - dataset.merge_insert( - "a" - ).when_not_matched_insert_all().when_matched_update_all().execute(new_table) + merge_dict = ( + dataset.merge_insert("a") + .when_not_matched_insert_all() + .when_matched_update_all() + .execute(new_table) + ) table = dataset.to_table() assert table.num_rows == 1300 assert table.filter(is_new).num_rows == 1000 + check_merge_stats(merge_dict, (300, 700, 0)) dataset = lance.dataset(tmp_path / "dataset", version=version) dataset.restore() - dataset.merge_insert("a").when_not_matched_insert_all().when_matched_update_all( - "target.c == source.c" - ).execute(new_table) + merge_dict = ( + dataset.merge_insert("a") + .when_not_matched_insert_all() + .when_matched_update_all("target.c == source.c") + .execute(new_table) + ) table = dataset.to_table() assert table.num_rows == 1300 assert table.filter(is_new).num_rows == 650 + check_merge_stats(merge_dict, (300, 350, 0)) dataset = lance.dataset(tmp_path / 
"dataset", version=version) dataset.restore() - dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table) + merge_dict = ( + dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table) + ) table = dataset.to_table() assert table.num_rows == 700 assert table.filter(is_new).num_rows == 0 + check_merge_stats(merge_dict, (0, 0, 300)) dataset = lance.dataset(tmp_path / "dataset", version=version) dataset.restore() - dataset.merge_insert("a").when_not_matched_by_source_delete( - "a < 100" - ).when_not_matched_insert_all().execute(new_table) + merge_dict = ( + dataset.merge_insert("a") + .when_not_matched_by_source_delete("a < 100") + .when_not_matched_insert_all() + .execute(new_table) + ) table = dataset.to_table() assert table.num_rows == 1200 assert table.filter(is_new).num_rows == 300 + check_merge_stats(merge_dict, (300, 0, 100)) # If the user doesn't specify anything then the merge_insert is # a no-op and the operation fails dataset = lance.dataset(tmp_path / "dataset", version=version) dataset.restore() with pytest.raises(ValueError): - dataset.merge_insert("a").execute(new_table) + merge_dict = dataset.merge_insert("a").execute(new_table) + check_merge_stats(merge_dict, (None, None, None)) def test_flat_vector_search_with_delete(tmp_path: Path): @@ -1031,9 +1059,11 @@ def test_merge_insert_conditional_upsert_example(tmp_path: Path): } ) - dataset.merge_insert("id").when_matched_update_all( - "target.txNumber < source.txNumber" - ).execute(new_table) + merge_dict = ( + dataset.merge_insert("id") + .when_matched_update_all("target.txNumber < source.txNumber") + .execute(new_table) + ) table = dataset.to_table() @@ -1049,6 +1079,7 @@ def test_merge_insert_conditional_upsert_example(tmp_path: Path): ) assert table.sort_by("id") == expected + check_merge_stats(merge_dict, (0, 2, 0)) # No matches @@ -1060,9 +1091,13 @@ def test_merge_insert_conditional_upsert_example(tmp_path: Path): } ) - dataset.merge_insert("id").when_matched_update_all( - "target.txNumber < source.txNumber" - ).execute(new_table) + merge_dict = ( + dataset.merge_insert("id") + .when_matched_update_all("target.txNumber < source.txNumber") + .execute(new_table) + ) + + check_merge_stats(merge_dict, (0, 0, 0)) def test_merge_insert_source_is_dataset(tmp_path: Path): @@ -1085,22 +1120,28 @@ def test_merge_insert_source_is_dataset(tmp_path: Path): is_new = pc.field("b") == 2 - dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset) + merge_dict = ( + dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset) + ) table = dataset.to_table() assert table.num_rows == 1300 assert table.filter(is_new).num_rows == 300 + check_merge_stats(merge_dict, (300, 0, 0)) dataset = lance.dataset(tmp_path / "dataset", version=version) dataset.restore() reader = new_dataset.to_batches() - dataset.merge_insert("a").when_not_matched_insert_all().execute( - reader, schema=new_dataset.schema + merge_dict = ( + dataset.merge_insert("a") + .when_not_matched_insert_all() + .execute(reader, schema=new_dataset.schema) ) table = dataset.to_table() assert table.num_rows == 1300 assert table.filter(is_new).num_rows == 300 + check_merge_stats(merge_dict, (300, 0, 0)) def test_merge_insert_multiple_keys(tmp_path: Path): @@ -1132,10 +1173,13 @@ def test_merge_insert_multiple_keys(tmp_path: Path): is_new = pc.field("b") == 2 - dataset.merge_insert(["a", "c"]).when_matched_update_all().execute(new_table) + merge_dict = ( + dataset.merge_insert(["a", 
"c"]).when_matched_update_all().execute(new_table) + ) table = dataset.to_table() assert table.num_rows == 1000 assert table.filter(is_new).num_rows == 350 + check_merge_stats(merge_dict, (0, 350, 0)) def test_merge_insert_incompatible_schema(tmp_path: Path): @@ -1157,7 +1201,10 @@ def test_merge_insert_incompatible_schema(tmp_path: Path): ) with pytest.raises(OSError): - dataset.merge_insert("a").when_matched_update_all().execute(new_table) + merge_dict = ( + dataset.merge_insert("a").when_matched_update_all().execute(new_table) + ) + check_merge_stats(merge_dict, (None, None, None)) def test_merge_insert_vector_column(tmp_path: Path): @@ -1179,10 +1226,12 @@ def test_merge_insert_vector_column(tmp_path: Path): table, tmp_path / "dataset", mode="create", max_rows_per_file=100 ) - dataset.merge_insert( - ["key"] - ).when_not_matched_insert_all().when_matched_update_all().execute(new_table) - + merge_dict = ( + dataset.merge_insert(["key"]) + .when_not_matched_insert_all() + .when_matched_update_all() + .execute(new_table) + ) expected = pa.Table.from_pydict( { "vec": pa.array( @@ -1193,6 +1242,7 @@ def test_merge_insert_vector_column(tmp_path: Path): ) assert dataset.to_table().sort_by("key") == expected + check_merge_stats(merge_dict, (1, 1, 0)) def test_update_dataset(tmp_path: Path): diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 2f7083b68c..e53b2a0e04 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -175,7 +175,7 @@ impl MergeInsertBuilder { Ok(slf) } - pub fn execute(&mut self, new_data: &PyAny) -> PyResult<()> { + pub fn execute(&mut self, new_data: &PyAny) -> PyResult { let py = new_data.py(); let new_data: Box = if new_data.is_instance_of::() { @@ -199,9 +199,14 @@ impl MergeInsertBuilder { let dataset = self.dataset.as_ref(py); - dataset.borrow_mut().ds = new_self; + dataset.borrow_mut().ds = new_self.0; + let merge_stats = new_self.1; + let merge_dict = PyDict::new(py); + merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?; + merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?; + merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?; - Ok(()) + Ok(merge_dict.into()) } } diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index d77f2fb041..c6a1db8c56 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -325,7 +325,7 @@ impl MergeInsertJob { pub async fn execute_reader( self, source: Box, - ) -> Result> { + ) -> Result<(Arc, MergeStats)> { let stream = reader_to_stream(source); self.execute(stream).await } @@ -343,6 +343,7 @@ impl MergeInsertJob { async fn join_key_as_scalar_index(&self) -> Result> { if self.params.on.len() != 1 { + // joining on more than one column Ok(None) } else { let col = &self.params.on[0]; @@ -371,7 +372,7 @@ impl MergeInsertJob { let schema = shared_input.schema(); let field = schema.field_with_name(&self.params.on[0])?; let key_only_schema = - lance_core::datatypes::Schema::try_from(&Schema::new(vec![field.clone()]))?; + lance_core::datatypes::Schema::try_from(&Schema::new(vec![field.clone()]))?; // schema for only the key join column let index_mapper_input = Arc::new(ProjectionExec::try_new( shared_input.clone(), Arc::new(key_only_schema), @@ -380,6 +381,7 @@ impl MergeInsertJob { // Then we pass the key column into the index mapper let index_column = self.params.on[0].clone(); let index_mapper = Arc::new(MapIndexExec::new( + // create index from 
original data and key column self.dataset.clone(), index_column.clone(), index_mapper_input, @@ -459,11 +461,11 @@ impl MergeInsertJob { let new_data = session_ctx.read_one_shot(source)?; let join_cols = self .params - .on + .on // columns to join on .iter() .map(|c| c.as_str()) - .collect::>(); - let joined = new_data.join(existing, JoinType::Full, &join_cols, &join_cols, None)?; + .collect::>(); // vector of strings of col names to join + let joined = new_data.join(existing, JoinType::Full, &join_cols, &join_cols, None)?; // full join Ok(joined.execute_stream().await?) } @@ -473,10 +475,11 @@ impl MergeInsertJob { ) -> Result { // We need to do a full index scan if we're deleting source data let can_use_scalar_index = matches!( - self.params.delete_not_matched_by_source, + self.params.delete_not_matched_by_source, // this value marks behavior for rows in target that are not matched by the source. Value assigned earlier. WhenNotMatchedBySource::Keep ); if can_use_scalar_index { + // keeping unmatched rows, no deletion if let Some(index) = self.join_key_as_scalar_index().await? { self.create_indexed_scan_joined_stream(source, index).await } else { @@ -492,11 +495,15 @@ impl MergeInsertJob { /// /// This will take in the source, merge it with the existing target data, and insert new /// rows, update existing rows, and delete existing rows - pub async fn execute(self, source: SendableRecordBatchStream) -> Result> { + pub async fn execute( + self, + source: SendableRecordBatchStream, + ) -> Result<(Arc, MergeStats)> { let schema = source.schema(); let joined = self.create_joined_stream(source).await?; let merger = Merger::try_new(self.params, schema.clone())?; + let merge_statistics = merger.merge_stats.clone(); let deleted_rows = merger.deleted_rows.clone(); let stream = joined .and_then(move |batch| merger.clone().execute_batch(batch)) @@ -515,17 +522,25 @@ impl MergeInsertJob { // Apply deletions let removed_row_ids = Arc::into_inner(deleted_rows).unwrap().into_inner().unwrap(); + let (old_fragments, removed_fragment_ids) = Self::apply_deletions(&self.dataset, &removed_row_ids).await?; // Commit updated and new fragments - Self::commit( + let committed_ds = Self::commit( self.dataset, removed_fragment_ids, old_fragments, new_fragments, ) - .await + .await?; + + let stats = Arc::into_inner(merge_statistics) + .unwrap() + .into_inner() + .unwrap(); + + Ok((committed_ds, stats)) } // Delete a batch of rows by id, returns the fragments modified and the fragments removed @@ -606,6 +621,19 @@ impl MergeInsertJob { } } +/// Merger will store these statistics as it runs (for each batch) +#[derive(Debug, Default, Clone)] +pub struct MergeStats { + /// Number of inserted rows (for user statistics) + pub num_inserted_rows: u64, + /// Number of updated rows (for user statistics) + pub num_updated_rows: u64, + /// Number of deleted rows (for user statistics) + /// Note: This is different from internal references to 'deleted_rows', since we technically "delete" updated rows during processing. + /// However those rows are not shared with the user. + pub num_deleted_rows: u64, +} + // A sync-safe structure that is shared by all of the "process batch" tasks. 
// // Note: we are not currently using parallelism but this still needs to be sync because it is @@ -616,6 +644,8 @@ struct Merger { deleted_rows: Arc>, // Physical delete expression, only set if params.delete_not_matched_by_source is DeleteIf delete_expr: Option>, + // User statistics for merging + merge_stats: Arc>, // Physical "when matched update if" expression, only set if params.when_matched is UpdateIf match_filter_expr: Option>, // The parameters controlling the merge @@ -657,6 +687,7 @@ impl Merger { Ok(Self { deleted_rows: Arc::new(Mutex::new(RoaringTreemap::new())), delete_expr, + merge_stats: Arc::new(Mutex::new(MergeStats::default())), match_filter_expr, params, schema, @@ -721,6 +752,7 @@ impl Merger { batch: RecordBatch, ) -> datafusion::common::Result>> { + let mut merge_statistics = self.merge_stats.lock().unwrap(); let num_fields = batch.schema().fields.len(); // The schema of the combined batches will be: // source_keys, source_payload, target_keys, target_payload, row_id @@ -743,6 +775,7 @@ impl Merger { if self.params.when_matched != WhenMatched::DoNothing { let mut matched = arrow::compute::filter_record_batch(&batch, &in_both)?; + if let Some(match_filter) = self.match_filter_expr { let unzipped = unzip_batch(&matched, &self.schema); let filtered = match_filter.evaluate(&unzipped)?; @@ -761,6 +794,9 @@ impl Merger { } } } + + merge_statistics.num_updated_rows = matched.num_rows() as u64; + // If the filter eliminated all rows then its important we don't try and write // the batch at all. Writing an empty batch currently panics if matched.num_rows() > 0 { @@ -787,11 +823,14 @@ impl Merger { self.schema.clone(), Vec::from_iter(not_matched.columns().iter().cloned()), )?; + + merge_statistics.num_inserted_rows = not_matched.num_rows() as u64; batches.push(Ok(not_matched)); } match self.params.delete_not_matched_by_source { WhenNotMatchedBySource::Delete => { let unmatched = arrow::compute::filter(batch.column(row_id_col), &right_only)?; + merge_statistics.num_deleted_rows = unmatched.len() as u64; let row_ids = unmatched.as_primitive::(); deleted_row_ids.extend(row_ids.values()); } @@ -800,6 +839,7 @@ impl Merger { let unmatched = arrow::compute::filter_record_batch(&target_data, &right_only)?; let row_id_col = unmatched.num_columns() - 1; let to_delete = self.delete_expr.unwrap().evaluate(&unmatched)?; + match to_delete { ColumnarValue::Array(mask) => { let row_ids = arrow::compute::filter( @@ -807,11 +847,13 @@ impl Merger { mask.as_boolean(), )?; let row_ids = row_ids.as_primitive::(); + merge_statistics.num_deleted_rows = row_ids.len() as u64; deleted_row_ids.extend(row_ids.values()); } ColumnarValue::Scalar(scalar) => { if let ScalarValue::Boolean(Some(true)) = scalar { let row_ids = unmatched.column(row_id_col).as_primitive::(); + merge_statistics.num_deleted_rows = row_ids.len() as u64; deleted_row_ids.extend(row_ids.values()); } } @@ -819,6 +861,7 @@ impl Merger { } WhenNotMatchedBySource::Keep => {} } + Ok(stream::iter(batches)) } } @@ -845,6 +888,7 @@ mod tests { mut job: MergeInsertJob, keys_from_left: &[u32], keys_from_right: &[u32], + stats: &[u64], ) { let mut dataset = (*job.dataset).clone(); dataset.restore().await.unwrap(); @@ -854,7 +898,7 @@ mod tests { let new_reader = Box::new(RecordBatchIterator::new([Ok(new_data)], schema.clone())); let new_stream = reader_to_stream(new_reader); - let merged_dataset = job.execute(new_stream).await.unwrap(); + let (merged_dataset, merge_stats) = job.execute(new_stream).await.unwrap(); let batches = merged_dataset 
.scan() @@ -895,6 +939,9 @@ mod tests { right_keys.sort(); assert_eq!(left_keys, keys_from_left); assert_eq!(right_keys, keys_from_right); + assert_eq!(merge_stats.num_inserted_rows, stats[0]); + assert_eq!(merge_stats.num_updated_rows, stats[1]); + assert_eq!(merge_stats.num_deleted_rows, stats[2]); } #[tokio::test] @@ -940,7 +987,14 @@ mod tests { .unwrap() .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 2, 3, 4, 5, 6], &[7, 8, 9]).await; + check( + new_batch.clone(), + job, + &[1, 2, 3, 4, 5, 6], + &[7, 8, 9], + &[3, 0, 0], + ) + .await; // upsert, no delete let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -948,7 +1002,14 @@ mod tests { .when_matched(WhenMatched::UpdateAll) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 2, 3], &[4, 5, 6, 7, 8, 9]).await; + check( + new_batch.clone(), + job, + &[1, 2, 3], + &[4, 5, 6, 7, 8, 9], + &[3, 3, 0], + ) + .await; // conditional upsert, no delete let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -958,7 +1019,14 @@ mod tests { ) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 2, 3, 4, 5], &[6, 7, 8, 9]).await; + check( + new_batch.clone(), + job, + &[1, 2, 3, 4, 5], + &[6, 7, 8, 9], + &[3, 1, 0], + ) + .await; // conditional update, no matches let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -967,7 +1035,7 @@ mod tests { .when_matched(WhenMatched::update_if(&ds, "target.filterme = 'z'").unwrap()) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 2, 3, 4, 5, 6], &[]).await; + check(new_batch.clone(), job, &[1, 2, 3, 4, 5, 6], &[], &[0, 0, 0]).await; // update only, no delete (useful for bulk update) let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -976,7 +1044,7 @@ mod tests { .when_not_matched(WhenNotMatched::DoNothing) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 2, 3], &[4, 5, 6]).await; + check(new_batch.clone(), job, &[1, 2, 3], &[4, 5, 6], &[0, 3, 0]).await; // Conditional update let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -987,7 +1055,7 @@ mod tests { .when_not_matched(WhenNotMatched::DoNothing) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 2, 3, 6], &[4, 5]).await; + check(new_batch.clone(), job, &[1, 2, 3, 6], &[4, 5], &[0, 2, 0]).await; // No-op (will raise an error) assert!(MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -1002,7 +1070,7 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::Delete) .try_build() .unwrap(); - check(new_batch.clone(), job, &[4, 5, 6], &[7, 8, 9]).await; + check(new_batch.clone(), job, &[4, 5, 6], &[7, 8, 9], &[3, 0, 3]).await; // upsert, with delete all let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -1011,7 +1079,7 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::Delete) .try_build() .unwrap(); - check(new_batch.clone(), job, &[], &[4, 5, 6, 7, 8, 9]).await; + check(new_batch.clone(), job, &[], &[4, 5, 6, 7, 8, 9], &[3, 3, 3]).await; // update only, with delete all (unusual) let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -1021,7 +1089,7 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::Delete) .try_build() .unwrap(); - check(new_batch.clone(), job, &[], &[4, 5, 6]).await; + check(new_batch.clone(), job, &[], &[4, 5, 6], &[0, 3, 3]).await; // just delete all (not real case, just use delete) let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -1030,7 +1098,7 @@ mod tests { 
.when_not_matched_by_source(WhenNotMatchedBySource::Delete) .try_build() .unwrap(); - check(new_batch.clone(), job, &[4, 5, 6], &[]).await; + check(new_batch.clone(), job, &[4, 5, 6], &[], &[0, 0, 3]).await; // For the "delete some" tests we use key > 1 let condition = Expr::gt( @@ -1043,7 +1111,14 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone())) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 4, 5, 6], &[7, 8, 9]).await; + check( + new_batch.clone(), + job, + &[1, 4, 5, 6], + &[7, 8, 9], + &[3, 0, 2], + ) + .await; // upsert, with delete some let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -1052,9 +1127,16 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone())) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1], &[4, 5, 6, 7, 8, 9]).await; + check( + new_batch.clone(), + job, + &[1], + &[4, 5, 6, 7, 8, 9], + &[3, 3, 2], + ) + .await; - // update only, with delete some (unusual) + // update only, witxh delete some (unusual) let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) .unwrap() .when_matched(WhenMatched::UpdateAll) @@ -1062,7 +1144,7 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone())) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1], &[4, 5, 6]).await; + check(new_batch.clone(), job, &[1], &[4, 5, 6], &[0, 3, 2]).await; // just delete some (not real case, just use delete) let job = MergeInsertBuilder::try_new(ds.clone(), keys.clone()) @@ -1071,7 +1153,7 @@ mod tests { .when_not_matched_by_source(WhenNotMatchedBySource::DeleteIf(condition.clone())) .try_build() .unwrap(); - check(new_batch.clone(), job, &[1, 4, 5, 6], &[]).await; + check(new_batch.clone(), job, &[1, 4, 5, 6], &[], &[0, 0, 2]).await; } #[tokio::test] @@ -1140,7 +1222,7 @@ mod tests { )); // Run merge_insert - let ds = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) .unwrap() .when_not_matched(WhenNotMatched::DoNothing) .when_matched(WhenMatched::UpdateAll) @@ -1166,7 +1248,7 @@ mod tests { schema.clone(), )); // Run merge_insert - let ds = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) .unwrap() .when_not_matched(WhenNotMatched::DoNothing) .when_matched(WhenMatched::UpdateAll) @@ -1186,7 +1268,7 @@ mod tests { )); // Run merge_insert one last time. The index is now completely out of date. Every // row it points to is a deleted row. Make sure that doesn't break. - let ds = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) + let (ds, _) = MergeInsertBuilder::try_new(ds.clone(), vec!["key".to_string()]) .unwrap() .when_not_matched(WhenNotMatched::DoNothing) .when_matched(WhenMatched::UpdateAll) diff --git a/test_data/v0.10.5/datagen.py b/test_data/v0.10.5/datagen.py index 2e77fd9784..5b4d5c2369 100644 --- a/test_data/v0.10.5/datagen.py +++ b/test_data/v0.10.5/datagen.py @@ -22,12 +22,14 @@ dataset.add_columns({"b": "x * 4", "c": "x * 5"}) # This is the bug: b and c will show data from z and a. 
-assert dataset.to_table() == pa.table({ - "x": range(4), - "y": [0, 2, 4, 6], - "b": [0, 3, 6, 9], - "c": [0, -1, -2, -3], -}) - -fragment_sizes = { len(frag.data_files()) for frag in dataset.get_fragments() } +assert dataset.to_table() == pa.table( + { + "x": range(4), + "y": [0, 2, 4, 6], + "b": [0, 3, 6, 9], + "c": [0, -1, -2, -3], + } +) + +fragment_sizes = {len(frag.data_files()) for frag in dataset.get_fragments()} assert fragment_sizes == {4, 2}
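
As a quick illustration of the behavior this patch adds, below is a minimal Python sketch of the new `execute()` return value, mirroring the doctest added to `python/python/lance/dataset.py` above. The dataset path `./merge_stats_demo.lance` is a placeholder chosen for this sketch; the three statistics keys come from the dict built in `python/src/dataset.rs`.

# Minimal sketch of the new merge_insert() statistics; mirrors the doctest added above.
# "./merge_stats_demo.lance" is a placeholder path and is assumed not to exist yet.
import lance
import pyarrow as pa

table = pa.table({"a": [1, 2, 3], "b": ["a", "b", "c"]})
dataset = lance.write_dataset(table, "./merge_stats_demo.lance")

new_table = pa.table({"a": [2, 3, 4], "b": ["x", "y", "z"]})

# execute() used to return nothing; with this patch it returns the merge statistics.
stats = (
    dataset.merge_insert("a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(new_table)
)

# a=4 is inserted, a=2 and a=3 are updated, nothing is deleted.
assert stats == {
    "num_inserted_rows": 1,
    "num_updated_rows": 2,
    "num_deleted_rows": 0,
}

On the Rust side the same numbers surface through the new `MergeStats` struct, which `MergeInsertJob::execute` and `execute_reader` now return alongside the committed dataset as `(Arc<Dataset>, MergeStats)`.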