Skip to content

Commit

Permalink
feat(case when): short-circuit optimization for case when (#6581)
Browse files Browse the repository at this point in the history
* feat(case when): short-circuit optimization for case when (close #6580)

* add test

* fix rebase

* rerun

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
soundOfDestiny and mergify[bot] authored Nov 28, 2022
1 parent 28878d8 commit b804479
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 61 deletions.
6 changes: 6 additions & 0 deletions e2e_test/batch/basic/func.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ select false and (1 / (v1-v1) > 2) from generate_series(1, 3, 1) as t(v1) where
f
f

query I
select case when 1=1 then 9 else 1 / (v1-v1) end from generate_series(1, 3, 1) as t(v1) where v1 != 2;
----
9
9

statement ok
create table t1 (v1 int, v2 int, v3 int);

Expand Down
104 changes: 43 additions & 61 deletions src/expr/src/expr/expr_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;
use risingwave_common::array::{ArrayRef, DataChunk};
use std::sync::Arc;

use risingwave_common::array::{ArrayRef, DataChunk, Vis};
use risingwave_common::row::Row;
use risingwave_common::types::{DataType, Datum, ScalarImpl, ScalarRefImpl, ToOwnedDatum};
use risingwave_common::types::{DataType, Datum};
use risingwave_common::{bail, ensure};
use risingwave_pb::expr::expr_node::{RexNode, Type};
use risingwave_pb::expr::ExprNode;
Expand Down Expand Up @@ -62,71 +63,52 @@ impl Expression for CaseExpression {
}

fn eval(&self, input: &DataChunk) -> Result<ArrayRef> {
let vis = input.vis();
let mut els = self
.else_clause
.as_deref()
.map(|else_clause| else_clause.eval_checked(input).unwrap());
let when_thens = self
.when_clauses
.iter()
.map(|when_clause| {
(
when_clause.when.eval_checked(input).unwrap(),
when_clause.then.eval_checked(input).unwrap(),
)
})
.collect_vec();
let mut output_array = self.return_type().create_array_builder(input.capacity());
for idx in 0..input.capacity() {
if vis.is_set(idx) {
if let Some((_, t)) = when_thens
.iter()
.map(|(w, t)| (w.value_at(idx), t.value_at(idx)))
.find(|(w, _)| {
*w.unwrap_or(ScalarRefImpl::Bool(false))
.into_scalar_impl()
.as_bool()
})
{
output_array.append_datum(&t.to_owned_datum());
} else if let Some(els) = els.as_mut() {
let t = els.datum_at(idx);
output_array.append_datum(&t);
} else {
output_array.append_null();
};
let mut input = input.clone();
let input_len = input.capacity();
let mut selection = vec![None; input_len];
let when_len = self.when_clauses.len();
let mut result_array = Vec::with_capacity(when_len + 1);
for (when_idx, WhenClause { when, then }) in self.when_clauses.iter().enumerate() {
let calc_then_vis: Vis = when.eval_checked(&input)?.as_bool().to_bitmap().into();
let input_vis = input.vis().clone();
input.set_vis(calc_then_vis.clone());
let then_res = then.eval_checked(&input)?;
calc_then_vis
.ones()
.for_each(|pos| selection[pos] = Some(when_idx));
input.set_vis(&input_vis & (!&calc_then_vis));
result_array.push(then_res);
}
if let Some(ref else_expr) = self.else_clause {
let else_res = else_expr.eval_checked(&input)?;
input
.vis()
.ones()
.for_each(|pos| selection[pos] = Some(when_len));
result_array.push(else_res);
}
let mut builder = self.return_type().create_array_builder(input.capacity());
for (i, sel) in selection.into_iter().enumerate() {
if let Some(when_idx) = sel {
builder.append_datum(result_array[when_idx].value_at(i));
} else {
output_array.append_null();
builder.append_null();
}
}
let output_array = output_array.finish().into();
Ok(output_array)
Ok(Arc::new(builder.finish()))
}

fn eval_row(&self, input: &Row) -> Result<Datum> {
let els = self
.else_clause
.as_deref()
.map(|else_clause| else_clause.eval_row(input).unwrap());
let when_then_first = self
.when_clauses
.iter()
.map(|when_clause| {
(
when_clause.when.eval_row(input).unwrap(),
when_clause.then.eval_row(input).unwrap(),
)
})
.find(|(w, _)| *(w.as_ref().unwrap_or(&ScalarImpl::Bool(false)).as_bool()));

let ret = if let Some((_, t)) = when_then_first {
t
for WhenClause { when, then } in &self.when_clauses {
if when.eval_row(input)?.map_or(false, |w| w.into_bool()) {
return then.eval_row(input);
}
}
if let Some(ref else_expr) = self.else_clause {
else_expr.eval_row(input)
} else {
els.unwrap_or(None)
};

Ok(ret)
Ok(None)
}
}
}

Expand Down

0 comments on commit b804479

Please sign in to comment.