diff --git a/e2e_test/streaming/temporal_join.slt b/e2e_test/streaming/temporal_join/temporal_join.slt similarity index 100% rename from e2e_test/streaming/temporal_join.slt rename to e2e_test/streaming/temporal_join/temporal_join.slt diff --git a/e2e_test/streaming/temporal_join/temporal_join_with_index.slt b/e2e_test/streaming/temporal_join/temporal_join_with_index.slt new file mode 100644 index 000000000000..f714cefcc51b --- /dev/null +++ b/e2e_test/streaming/temporal_join/temporal_join_with_index.slt @@ -0,0 +1,84 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + +statement ok +create table version(id2 int, a2 int, b2 int, primary key (id2)); + +statement ok +create index idx on version (a2); + +statement ok +create materialized view v as select id1, a1, id2, a2 from stream left join idx FOR SYSTEM_TIME AS OF PROCTIME() on b1 = b2 and a1 = a2; + +statement ok +insert into stream values(1, 11, 111); + +statement ok +insert into version values(1, 11, 111); + +statement ok +insert into version values(9, 11, 111); + +statement ok +insert into stream values(1, 11, 111); + +statement ok +delete from version; + +query IIII rowsort +select * from v; +---- +1 11 1 11 +1 11 9 11 +1 11 NULL NULL + +statement ok +insert into version values(2, 22, 222); + +statement ok +insert into stream values(2, 22, 222); + +statement ok +insert into version values(8, 22, 222); + +statement ok +insert into stream values(2, 22, 222); + +query IIII rowsort +select * from v; +---- +1 11 1 11 +1 11 9 11 +1 11 NULL NULL +2 22 2 22 +2 22 2 22 +2 22 8 22 + +statement ok +update version set b2 = 333 where id2 = 2; + +statement ok +insert into stream values(2, 22, 222); + +query IIII rowsort +select * from v; +---- +1 11 1 11 +1 11 9 11 +1 11 NULL NULL +2 22 2 22 +2 22 2 22 +2 22 8 22 +2 22 8 22 + +statement ok +drop materialized view v; + +statement ok +drop table stream; + +statement ok +drop table version; diff --git 
a/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml b/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml index e42cec784158..eeb19aba2548 100644 --- a/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml +++ b/src/frontend/planner_test/tests/testdata/input/temporal_join.yaml @@ -88,3 +88,35 @@ join version2 FOR SYSTEM_TIME AS OF PROCTIME() on stream.id2 = version2.id2 where a1 < 10; expected_outputs: - stream_plan +- name: temporal join with an index (distribution key size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2) distributed by (a2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + expected_outputs: + - stream_plan +- name: temporal join with an index (distribution key size = 2) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + expected_outputs: + - stream_plan +- name: temporal join with an index (index column size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + expected_outputs: + - stream_plan +- name: temporal join with singleton table + sql: | + create table t (a int) append only; + create materialized view v as select count(*) from t; + select * from t left join v FOR SYSTEM_TIME AS OF PROCTIME() on a = count; + expected_outputs: + - stream_plan + diff --git a/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml 
b/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml index c1c920e6f97c..2b93cce79725 100644 --- a/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml +++ b/src/frontend/planner_test/tests/testdata/output/temporal_join.yaml @@ -152,3 +152,55 @@ │ └─StreamTableScan { table: version1, columns: [version1.id1, version1.x1], pk: [version1.id1], dist: UpstreamHashShard(version1.id1) } └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(version2.id2) } └─StreamTableScan { table: version2, columns: [version2.id2, version2.x2], pk: [version2.id2], dist: UpstreamHashShard(version2.id2) } +- name: temporal join with an index (distribution key size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (a2, b2) distributed by (a2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + stream_plan: |- + StreamMaterialize { columns: [id1, a1, id2, a2, stream._row_id(hidden), stream.b1(hidden)], stream_key: [stream._row_id, id2, a1, stream.b1], pk_columns: [stream._row_id, id2, a1, stream.b1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: stream.a1 = idx2.a2 AND stream.b1 = idx2.b2, output: [stream.id1, stream.a1, idx2.id2, idx2.a2, stream._row_id, stream.b1] } + ├─StreamExchange { dist: HashShard(stream.a1) } + │ └─StreamTableScan { table: stream, columns: [stream.id1, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(idx2.a2) } + └─StreamTableScan { table: idx2, columns: [idx2.a2, idx2.b2, idx2.id2], pk: [idx2.id2], dist: UpstreamHashShard(idx2.a2) } +- name: temporal join with an index (distribution key size = 2) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, 
primary key (id2)); + create index idx2 on version (a2, b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + stream_plan: |- + StreamMaterialize { columns: [id1, a1, id2, a2, stream._row_id(hidden), stream.b1(hidden)], stream_key: [stream._row_id, id2, a1, stream.b1], pk_columns: [stream._row_id, id2, a1, stream.b1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: stream.a1 = idx2.a2 AND stream.b1 = idx2.b2, output: [stream.id1, stream.a1, idx2.id2, idx2.a2, stream._row_id, stream.b1] } + ├─StreamExchange { dist: HashShard(stream.a1, stream.b1) } + │ └─StreamTableScan { table: stream, columns: [stream.id1, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + └─StreamExchange [no_shuffle] { dist: UpstreamHashShard(idx2.a2, idx2.b2) } + └─StreamTableScan { table: idx2, columns: [idx2.a2, idx2.b2, idx2.id2], pk: [idx2.id2], dist: UpstreamHashShard(idx2.a2, idx2.b2) } +- name: temporal join with an index (index column size = 1) + sql: | + create table stream(id1 int, a1 int, b1 int) APPEND ONLY; + create table version(id2 int, a2 int, b2 int, primary key (id2)); + create index idx2 on version (b2); + select id1, a1, id2, a2 from stream left join idx2 FOR SYSTEM_TIME AS OF PROCTIME() on a1 = a2 and b1 = b2; + stream_plan: |- + StreamMaterialize { columns: [id1, a1, id2, a2, stream._row_id(hidden), stream.b1(hidden)], stream_key: [stream._row_id, id2, stream.b1, a1], pk_columns: [stream._row_id, id2, stream.b1, a1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: stream.b1 = idx2.b2 AND (stream.a1 = idx2.a2), output: [stream.id1, stream.a1, idx2.id2, idx2.a2, stream._row_id, stream.b1] } + ├─StreamExchange { dist: HashShard(stream.b1) } + │ └─StreamTableScan { table: stream, columns: [stream.id1, stream.a1, stream.b1, stream._row_id], pk: [stream._row_id], dist: UpstreamHashShard(stream._row_id) } + 
└─StreamExchange [no_shuffle] { dist: UpstreamHashShard(idx2.b2) } + └─StreamTableScan { table: idx2, columns: [idx2.b2, idx2.id2, idx2.a2], pk: [idx2.id2], dist: UpstreamHashShard(idx2.b2) } +- name: temporal join with singleton table + sql: | + create table t (a int) append only; + create materialized view v as select count(*) from t; + select * from t left join v FOR SYSTEM_TIME AS OF PROCTIME() on a = count; + stream_plan: |- + StreamMaterialize { columns: [a, count, t._row_id(hidden), $expr1(hidden)], stream_key: [t._row_id, $expr1], pk_columns: [t._row_id, $expr1], pk_conflict: NoCheck } + └─StreamTemporalJoin { type: LeftOuter, predicate: AND ($expr1 = v.count), output: [t.a, v.count, t._row_id, $expr1] } + ├─StreamExchange { dist: Single } + │ └─StreamProject { exprs: [t.a, t.a::Int64 as $expr1, t._row_id] } + │ └─StreamTableScan { table: t, columns: [t.a, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } + └─StreamExchange [no_shuffle] { dist: Single } + └─StreamTableScan { table: v, columns: [v.count], pk: [], dist: Single } diff --git a/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs b/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs index cd35ffd83e15..4539d17b2acd 100644 --- a/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs +++ b/src/frontend/src/optimizer/plan_node/eq_join_predicate.rs @@ -14,6 +14,7 @@ use std::fmt; +use itertools::Itertools; use risingwave_common::catalog::Schema; use crate::expr::{ @@ -270,6 +271,43 @@ impl EqJoinPredicate { ) } + /// Retain the prefix of `eq_keys` based on the `prefix_len`. The other part is moved to the + /// other condition. 
+ pub fn retain_prefix_eq_key(self, prefix_len: usize) -> Self { + assert!(prefix_len <= self.eq_keys.len()); + let (retain_eq_key, other_eq_key) = self.eq_keys.split_at(prefix_len); + let mut new_other_conjunctions = self.other_cond.conjunctions; + new_other_conjunctions.extend( + other_eq_key + .iter() + .cloned() + .map(|(l, r, null_safe)| { + FunctionCall::new( + if null_safe { + ExprType::IsNotDistinctFrom + } else { + ExprType::Equal + }, + vec![l.into(), r.into()], + ) + .unwrap() + .into() + }) + .collect_vec(), + ); + + let new_other_cond = Condition { + conjunctions: new_other_conjunctions, + }; + + Self::new( + new_other_cond, + retain_eq_key.to_owned(), + self.left_cols_num, + self.right_cols_num, + ) + } + pub fn rewrite_exprs(&self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self { let mut new = self.clone(); new.other_cond = new.other_cond.rewrite_expr(rewriter); diff --git a/src/frontend/src/optimizer/plan_node/logical_join.rs b/src/frontend/src/optimizer/plan_node/logical_join.rs index 6911bc3f5869..0d72c1d5fb1c 100644 --- a/src/frontend/src/optimizer/plan_node/logical_join.rs +++ b/src/frontend/src/optimizer/plan_node/logical_join.rs @@ -366,19 +366,19 @@ impl LogicalJoin { .expect("dist_key must in order_key"); dist_key_in_order_key_pos.push(pos); } - // The at least prefix of order key that contains distribution key. - let at_least_prefix_len = dist_key_in_order_key_pos + // The shortest prefix of order key that contains distribution key. + let shortest_prefix_len = dist_key_in_order_key_pos .iter() .max() .map_or(0, |pos| pos + 1); // Distributed lookup join can't support lookup table with a singleton distribution. - if at_least_prefix_len == 0 { + if shortest_prefix_len == 0 { return None; } // Reorder the join equal predicate to match the order key. 
- let mut reorder_idx = Vec::with_capacity(at_least_prefix_len); + let mut reorder_idx = Vec::with_capacity(shortest_prefix_len); for order_col_id in order_col_ids { let mut found = false; for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() { @@ -392,7 +392,7 @@ impl LogicalJoin { break; } } - if reorder_idx.len() < at_least_prefix_len { + if reorder_idx.len() < shortest_prefix_len { return None; } let lookup_prefix_len = reorder_idx.len(); @@ -966,18 +966,6 @@ impl LogicalJoin { ) -> Result { assert!(predicate.has_eq()); - let left = self.left().to_stream_with_dist_required( - &RequiredDist::shard_by_key(self.left().schema().len(), &predicate.left_eq_indexes()), - ctx, - )?; - - if !left.append_only() { - return Err(RwError::from(ErrorCode::NotSupported( - "Temporal join requires an append-only left input".into(), - "Please ensure your left input is append-only".into(), - ))); - } - let right = self.right(); let Some(logical_scan) = right.as_logical_scan() else { return Err(RwError::from(ErrorCode::NotSupported( @@ -994,30 +982,76 @@ impl LogicalJoin { } let table_desc = logical_scan.table_desc(); + let output_column_ids = logical_scan.output_column_ids(); - // Verify that right join key columns are the primary key of the lookup table. + // Verify that the right join key columns are the prefix of the primary key and + // also contain the distribution key. let order_col_ids = table_desc.order_column_ids(); - let order_col_ids_len = order_col_ids.len(); - let output_column_ids = logical_scan.output_column_ids(); + let order_key = table_desc.order_column_indices(); + let dist_key = table_desc.distribution_key.clone(); + + let mut dist_key_in_order_key_pos = vec![]; + for d in dist_key { + let pos = order_key + .iter() + .position(|&x| x == d) + .expect("dist_key must in order_key"); + dist_key_in_order_key_pos.push(pos); + } + // The shortest prefix of order key that contains distribution key. 
+ let shortest_prefix_len = dist_key_in_order_key_pos + .iter() + .max() + .map_or(0, |pos| pos + 1); // Reorder the join equal predicate to match the order key. - let mut reorder_idx = vec![]; + let mut reorder_idx = Vec::with_capacity(shortest_prefix_len); for order_col_id in order_col_ids { + let mut found = false; for (i, eq_idx) in predicate.right_eq_indexes().into_iter().enumerate() { if order_col_id == output_column_ids[eq_idx] { reorder_idx.push(i); + found = true; break; } } + if !found { + break; + } } - if order_col_ids_len != predicate.eq_keys().len() || reorder_idx.len() < order_col_ids_len { + if reorder_idx.len() < shortest_prefix_len { + // TODO: support index selection for temporal join and refine this error message. return Err(RwError::from(ErrorCode::NotSupported( "Temporal join requires the lookup table's primary key contained exactly in the equivalence condition".into(), "Please add the primary key of the lookup table to the join condition and remove any other conditions".into(), ))); } + let lookup_prefix_len = reorder_idx.len(); let predicate = predicate.reorder(&reorder_idx); + let left = if dist_key_in_order_key_pos.is_empty() { + self.left() + .to_stream_with_dist_required(&RequiredDist::single(), ctx)? + } else { + let left_eq_indexes = predicate.left_eq_indexes(); + let left_dist_key = dist_key_in_order_key_pos + .iter() + .map(|pos| left_eq_indexes[*pos]) + .collect_vec(); + + self.left().to_stream_with_dist_required( + &RequiredDist::shard_by_key(self.left().schema().len(), &left_dist_key), + ctx, + )? + }; + + if !left.append_only() { + return Err(RwError::from(ErrorCode::NotSupported( + "Temporal join requires an append-only left input".into(), + "Please ensure your left input is append-only".into(), + ))); + } + // Extract the predicate from logical scan. Only pure scan is supported. 
let (new_scan, scan_predicate, project_expr) = logical_scan.predicate_pull_up(); // Construct output column to require column mapping @@ -1090,6 +1124,8 @@ impl LogicalJoin { new_join_output_indices, ); + let new_predicate = new_predicate.retain_prefix_eq_key(lookup_prefix_len); + Ok(StreamTemporalJoin::new(new_logical_join, new_predicate).into()) } diff --git a/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs b/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs index c3a8b4ab7b1d..f9fb325b8af8 100644 --- a/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_temporal_join.rs @@ -21,7 +21,6 @@ use risingwave_pb::stream_plan::TemporalJoinNode; use super::utils::{childless_record, watermark_pretty, Distill}; use super::{generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamNode}; use crate::expr::{Expr, ExprRewriter}; -use crate::optimizer::plan_node::generic::GenericPlanRef; use crate::optimizer::plan_node::plan_tree_node::PlanTreeNodeUnary; use crate::optimizer::plan_node::stream::StreamPlanRef; use crate::optimizer::plan_node::utils::IndicesDisplay; @@ -42,7 +41,6 @@ impl StreamTemporalJoin { pub fn new(logical: generic::Join, eq_join_predicate: EqJoinPredicate) -> Self { assert!(logical.join_type == JoinType::Inner || logical.join_type == JoinType::LeftOuter); assert!(logical.left.append_only()); - assert!(logical.right.logical_pk() == eq_join_predicate.right_eq_indexes()); let right = logical.right.clone(); let exchange: &StreamExchange = right .as_stream_exchange() diff --git a/src/stream/src/executor/temporal_join.rs b/src/stream/src/executor/temporal_join.rs index 5b96193614b4..dc12b6ccdef0 100644 --- a/src/stream/src/executor/temporal_join.rs +++ b/src/stream/src/executor/temporal_join.rs @@ -13,22 +13,29 @@ // limitations under the License. 
use std::alloc::Global; +use std::collections::HashMap; +use std::ops::{Deref, DerefMut}; use std::pin::pin; use std::sync::Arc; use either::Either; use futures::stream::{self, PollNext}; -use futures::{StreamExt, TryStreamExt}; +use futures::{pin_mut, StreamExt, TryStreamExt}; use futures_async_stream::try_stream; use local_stats_alloc::{SharedStatsAlloc, StatsAlloc}; use lru::DefaultHasher; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::catalog::Schema; +use risingwave_common::estimate_size::{EstimateSize, KvSize}; +use risingwave_common::hash::{HashKey, NullBitmap}; use risingwave_common::row::{OwnedRow, Row, RowExt}; -use risingwave_common::util::iter_util::ZipEqFast; +use risingwave_common::types::DataType; +use risingwave_common::util::iter_util::ZipEqDebug; use risingwave_expr::expr::BoxedExpression; use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch}; +use risingwave_storage::store::PrefetchOptions; use risingwave_storage::table::batch_table::storage_table::StorageTable; +use risingwave_storage::table::TableIter; use risingwave_storage::StateStore; use super::{Barrier, Executor, Message, MessageStream, StreamExecutorError, StreamExecutorResult}; @@ -39,11 +46,11 @@ use crate::executor::monitor::StreamingMetrics; use crate::executor::{ActorContextRef, BoxedExecutor, JoinType, JoinTypePrimitive, PkIndices}; use crate::task::AtomicU64Ref; -pub struct TemporalJoinExecutor { +pub struct TemporalJoinExecutor { ctx: ActorContextRef, left: BoxedExecutor, right: BoxedExecutor, - right_table: TemporalSide, + right_table: TemporalSide, left_join_keys: Vec, right_join_keys: Vec, null_safe: Vec, @@ -58,20 +65,86 @@ pub struct TemporalJoinExecutor { metrics: Arc, } -struct TemporalSide { +#[derive(Default)] +pub struct JoinEntry { + /// pk -> row + cached: HashMap, + kv_heap_size: KvSize, +} + +impl EstimateSize for JoinEntry { + fn estimated_heap_size(&self) -> usize { + // TODO: Add internal size. 
+ // https://github.com/risingwavelabs/risingwave/issues/9713 + self.kv_heap_size.size() + } +} + +impl JoinEntry { + /// Insert into the cache. + pub fn insert(&mut self, key: OwnedRow, value: OwnedRow) { + self.kv_heap_size.add(&key, &value); + self.cached.try_insert(key, value).unwrap(); + } + + /// Delete from the cache. + pub fn remove(&mut self, pk: &OwnedRow) { + if let Some(value) = self.cached.remove(pk) { + self.kv_heap_size.sub(pk, &value); + } else { + panic!("pk {:?} should be in the cache", pk); + } + } + + pub fn is_empty(&self) -> bool { + self.cached.is_empty() + } +} + +struct JoinEntryWrapper(Option); + +impl EstimateSize for JoinEntryWrapper { + fn estimated_heap_size(&self) -> usize { + self.0.estimated_heap_size() + } +} + +impl JoinEntryWrapper { + const MESSAGE: &str = "the state should always be `Some`"; + + /// Take the value out of the wrapper. Panic if the value is `None`. + pub fn take(&mut self) -> JoinEntry { + self.0.take().expect(Self::MESSAGE) + } +} + +impl Deref for JoinEntryWrapper { + type Target = JoinEntry; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().expect(Self::MESSAGE) + } +} + +impl DerefMut for JoinEntryWrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0.as_mut().expect(Self::MESSAGE) + } +} + +struct TemporalSide { source: StorageTable, + pk: Vec, table_output_indices: Vec, - cache: ManagedLruCache, DefaultHasher, SharedStatsAlloc>, + cache: ManagedLruCache>, ctx: ActorContextRef, + join_key_data_types: Vec, } -impl TemporalSide { - async fn lookup( - &mut self, - key: impl Row, - epoch: HummockEpoch, - ) -> StreamExecutorResult> { - let key = key.into_owned_row(); +impl TemporalSide { + /// Lookup the temporal side table and return a `JoinEntry` which could be empty if there are no + /// matched records. 
+ async fn lookup(&mut self, key: &K, epoch: HummockEpoch) -> StreamExecutorResult { let table_id_str = self.source.table_id().to_string(); let actor_id_str = self.ctx.id.to_string(); self.ctx @@ -79,36 +152,71 @@ impl TemporalSide { .temporal_join_total_query_cache_count .with_label_values(&[&table_id_str, &actor_id_str]) .inc(); - Ok(match self.cache.get(&key) { - Some(res) => res.clone(), - None => { - // cache miss - self.ctx - .streaming_metrics - .temporal_join_cache_miss_count - .with_label_values(&[&table_id_str, &actor_id_str]) - .inc(); - let res = self - .source - .get_row(key.clone(), HummockReadEpoch::NoWait(epoch)) - .await? - .map(|row| row.project(&self.table_output_indices).into_owned_row()); - self.cache.put(key, res.clone()); - res + + let res = if self.cache.contains(key) { + let mut state = self.cache.peek_mut(key).unwrap(); + state.take() + } else { + // cache miss + self.ctx + .streaming_metrics + .temporal_join_cache_miss_count + .with_label_values(&[&table_id_str, &actor_id_str]) + .inc(); + + let pk_prefix = key.deserialize(&self.join_key_data_types)?; + + let iter = self + .source + .batch_iter_with_pk_bounds( + HummockReadEpoch::NoWait(epoch), + &pk_prefix, + .., + false, + PrefetchOptions::new_for_exhaust_iter(), + ) + .await?; + + let mut entry = JoinEntry::default(); + + pin_mut!(iter); + while let Some(row) = iter.next_row().await? 
{ + entry.insert( + row.as_ref().project(&self.pk).into_owned_row(), + row.project(&self.table_output_indices).into_owned_row(), + ); } - }) + + entry + }; + + Ok(res) } - fn update(&mut self, payload: Vec, join_keys: &[usize]) { - payload.iter().flat_map(|c| c.rows()).for_each(|(op, row)| { - let key = row.project(join_keys).into_owned_row(); - if let Some(mut value) = self.cache.get_mut(&key) { - match op { - Op::Insert | Op::UpdateInsert => *value = Some(row.into_owned_row()), - Op::Delete | Op::UpdateDelete => *value = None, - }; + fn update( + &mut self, + chunks: Vec, + join_keys: &[usize], + ) -> StreamExecutorResult<()> { + for chunk in chunks { + let keys = K::build(join_keys, chunk.data_chunk())?; + for ((op, row), key) in chunk.rows().zip_eq_debug(keys.into_iter()) { + if self.cache.contains(&key) { + // Update cache + let mut entry = self.cache.get_mut(&key).unwrap(); + let pk = row.project(&self.pk).into_owned_row(); + match op { + Op::Insert | Op::UpdateInsert => entry.insert(pk, row.into_owned_row()), + Op::Delete | Op::UpdateDelete => entry.remove(&pk), + }; + } } - }); + } + Ok(()) + } + + pub fn insert_back(&mut self, key: K, state: JoinEntry) { + self.cache.put(key, JoinEntryWrapper(Some(state))); } } @@ -184,7 +292,7 @@ async fn align_input(left: Box, right: Box) { } } -impl TemporalJoinExecutor { +impl TemporalJoinExecutor { #[allow(clippy::too_many_arguments)] pub fn new( ctx: ActorContextRef, @@ -202,6 +310,7 @@ impl TemporalJoinExecutor { watermark_epoch: AtomicU64Ref, metrics: Arc, chunk_size: usize, + join_key_data_types: Vec, ) -> Self { let schema_fields = [left.schema().fields.clone(), right.schema().fields.clone()].concat(); @@ -226,15 +335,19 @@ impl TemporalJoinExecutor { alloc, ); + let pk = table.pk_in_output_indices().unwrap(); + Self { ctx: ctx.clone(), left, right, right_table: TemporalSide { source: table, + pk, table_output_indices, cache, ctx, + join_key_data_types, }, left_join_keys, right_join_keys, @@ -257,6 +370,8 @@ impl 
TemporalJoinExecutor { self.right.schema().len(), ); + let null_matched = K::Bitmap::from_bool_vec(self.null_safe); + let mut prev_epoch = None; let table_id_str = self.right_table.source.table_id().to_string(); @@ -271,6 +386,8 @@ impl TemporalJoinExecutor { .set(self.right_table.cache.len() as i64); match msg? { InternalMessage::Chunk(chunk) => { + // Compact chunk, otherwise the following keys and chunk rows might fail to zip. + let chunk = chunk.compact(); let mut builder = StreamChunkBuilder::new( self.chunk_size, &self.schema.data_types(), @@ -278,33 +395,33 @@ impl TemporalJoinExecutor { right_map.clone(), ); let epoch = prev_epoch.expect("Chunk data should come after some barrier."); - for (op, left_row) in chunk.rows() { - let key = left_row.project(&self.left_join_keys); - if key - .iter() - .zip_eq_fast(self.null_safe.iter()) - .any(|(datum, can_null)| datum.is_none() && !*can_null) - { - continue; - } - if let Some(right_row) = self.right_table.lookup(key, epoch).await? { - // check join condition - let ok = if let Some(ref mut cond) = self.condition { - let concat_row = left_row.chain(&right_row).into_owned_row(); - cond.eval_row_infallible(&concat_row, |err| { - self.ctx.on_compute_error(err, self.identity.as_str()) - }) - .await - .map(|s| *s.as_bool()) - .unwrap_or(false) - } else { - true - }; - if ok { - if let Some(chunk) = builder.append_row(op, left_row, &right_row) { - yield Message::Chunk(chunk); + let keys = K::build(&self.left_join_keys, chunk.data_chunk())?; + for ((op, left_row), key) in chunk.rows().zip_eq_debug(keys.into_iter()) { + if key.null_bitmap().is_subset(&null_matched) + && let join_entry = self.right_table.lookup(&key, epoch).await? 
+ && !join_entry.is_empty() { + for right_row in join_entry.cached.values() { + // check join condition + let ok = if let Some(ref mut cond) = self.condition { + let concat_row = left_row.chain(&right_row).into_owned_row(); + cond.eval_row_infallible(&concat_row, |err| { + self.ctx.on_compute_error(err, self.identity.as_str()) + }) + .await + .map(|s| *s.as_bool()) + .unwrap_or(false) + } else { + true + }; + + if ok { + if let Some(chunk) = builder.append_row(op, left_row, right_row) { + yield Message::Chunk(chunk); + } } } + // Insert back the state taken from ht. + self.right_table.insert_back(key.clone(), join_entry); } else if T == JoinType::LeftOuter { if let Some(chunk) = builder.append_row_update(op, left_row) { yield Message::Chunk(chunk); @@ -324,7 +441,7 @@ impl TemporalJoinExecutor { } } self.right_table.cache.update_epoch(barrier.epoch.curr); - self.right_table.update(updates, &self.right_join_keys); + self.right_table.update(updates, &self.right_join_keys)?; prev_epoch = Some(barrier.epoch.curr); yield Message::Barrier(barrier) } @@ -333,7 +450,9 @@ impl TemporalJoinExecutor { } } -impl Executor for TemporalJoinExecutor { +impl Executor + for TemporalJoinExecutor +{ fn execute(self: Box) -> super::BoxedMessageStream { self.into_stream().boxed() } diff --git a/src/stream/src/from_proto/temporal_join.rs b/src/stream/src/from_proto/temporal_join.rs index 4c7b8695066b..19df99d2d66e 100644 --- a/src/stream/src/from_proto/temporal_join.rs +++ b/src/stream/src/from_proto/temporal_join.rs @@ -15,6 +15,8 @@ use std::sync::Arc; use risingwave_common::catalog::{ColumnDesc, TableId, TableOption}; +use risingwave_common::hash::{HashKey, HashKeyDispatcher}; +use risingwave_common::types::DataType; use risingwave_common::util::sort_util::OrderType; use risingwave_expr::expr::{build_from_prost, BoxedExpression}; use risingwave_pb::plan_common::{JoinType as JoinTypeProto, StorageTableDesc}; @@ -141,6 +143,11 @@ impl ExecutorBuilder for TemporalJoinExecutorBuilder { 
.map(|&x| x as usize) .collect_vec(); + let join_key_data_types = left_join_keys + .iter() + .map(|idx| source_l.schema().fields[*idx].data_type()) + .collect_vec(); + let dispatcher_args = TemporalJoinExecutorDispatcherArgs { ctx: params.actor_context, left: source_l, @@ -158,6 +165,7 @@ impl ExecutorBuilder for TemporalJoinExecutorBuilder { chunk_size: params.env.config().developer.chunk_size, metrics: params.executor_stats, join_type_proto: node.get_join_type()?, + join_key_data_types, }; dispatcher_args.dispatch() @@ -181,31 +189,38 @@ struct TemporalJoinExecutorDispatcherArgs { chunk_size: usize, metrics: Arc, join_type_proto: JoinTypeProto, + join_key_data_types: Vec, } -impl TemporalJoinExecutorDispatcherArgs { - pub fn dispatch(self) -> StreamResult { +impl HashKeyDispatcher for TemporalJoinExecutorDispatcherArgs { + type Output = StreamResult; + + fn dispatch_impl(self) -> Self::Output { + /// This macro helps to fill the const generic type parameter. macro_rules! build { ($join_type:ident) => { - Ok(Box::new( - TemporalJoinExecutor::::new( - self.ctx, - self.left, - self.right, - self.right_table, - self.left_join_keys, - self.right_join_keys, - self.null_safe, - self.condition, - self.pk_indices, - self.output_indices, - self.table_output_indices, - self.executor_id, - self.watermark_epoch, - self.metrics, - self.chunk_size, - ), - )) + Ok(Box::new(TemporalJoinExecutor::< + K, + S, + { JoinType::$join_type }, + >::new( + self.ctx, + self.left, + self.right, + self.right_table, + self.left_join_keys, + self.right_join_keys, + self.null_safe, + self.condition, + self.pk_indices, + self.output_indices, + self.table_output_indices, + self.executor_id, + self.watermark_epoch, + self.metrics, + self.chunk_size, + self.join_key_data_types, + ))) }; } match self.join_type_proto { @@ -214,4 +229,8 @@ impl TemporalJoinExecutorDispatcherArgs { _ => unreachable!(), } } + + fn data_types(&self) -> &[DataType] { + &self.join_key_data_types + } }