Skip to content

Commit

Permalink
[FEAT]: sql count(*) (#2832)
Browse files Browse the repository at this point in the history
adds support for sql `count(*)`

closes #2742
  • Loading branch information
universalmind303 authored Sep 12, 2024
1 parent 805fbce commit 6594d87
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 13 deletions.
52 changes: 41 additions & 11 deletions src/daft-sql/src/modules/aggs.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use daft_dsl::{AggExpr, Expr, ExprRef, LiteralValue};
use sqlparser::ast::FunctionArg;
use daft_dsl::{col, AggExpr, Expr, ExprRef, LiteralValue};
use sqlparser::ast::{FunctionArg, FunctionArgExpr};

use crate::{
ensure,
Expand Down Expand Up @@ -34,20 +34,50 @@ impl SQLModule for SQLModuleAggs {

impl SQLFunction for AggExpr {
fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
let inputs = self.args_to_expr_unnamed(inputs, planner)?;
to_expr(self, inputs.as_slice())
// COUNT(*) needs a bit of extra handling, so we process that outside of `to_expr`
if let AggExpr::Count(_, _) = self {
handle_count(inputs, planner)
} else {
let inputs = self.args_to_expr_unnamed(inputs, planner)?;
to_expr(self, inputs.as_slice())
}
}
}

fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
Ok(match inputs {
[FunctionArg::Unnamed(FunctionArgExpr::Wildcard)] => match planner.relation_opt() {
Some(rel) => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
.alias("count")
}
None => unsupported_sql_err!("Wildcard is not supported in this context"),
},
[FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name))] => {
match planner.relation_opt() {
Some(rel) if name.to_string() == rel.name => {
let schema = rel.schema();
col(schema.fields[0].name.clone())
.count(daft_core::count_mode::CountMode::All)
.alias("count")
}
_ => unsupported_sql_err!("Wildcard is not supported in this context"),
}
}
[expr] => {
// SQL default COUNT ignores nulls
let input = planner.plan_function_arg(expr)?;
input.count(daft_core::count_mode::CountMode::Valid)
}
_ => unsupported_sql_err!("COUNT takes exactly one argument"),
})
}

pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult<ExprRef> {
match expr {
AggExpr::Count(_, _) => {
// SQL default COUNT ignores nulls.
ensure!(args.len() == 1, "count takes exactly one argument");
Ok(args[0]
.clone()
.count(daft_core::count_mode::CountMode::Valid))
}
AggExpr::Count(_, _) => unreachable!("count should be handled by by this point"),
AggExpr::Sum(_) => {
ensure!(args.len() == 1, "sum takes exactly one argument");
Ok(args[0].clone().sum())
Expand Down
11 changes: 9 additions & 2 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,17 @@ use sqlparser::{
/// This is used to keep track of the table name associated with a logical plan while planning a SQL query
#[derive(Debug, Clone)]
pub(crate) struct Relation {
inner: LogicalPlanBuilder,
name: String,
pub(crate) inner: LogicalPlanBuilder,
pub(crate) name: String,
}

impl Relation {
pub fn new(inner: LogicalPlanBuilder, name: String) -> Self {
Relation { inner, name }
}
pub(crate) fn schema(&self) -> SchemaRef {
self.inner.schema()
}
}

pub struct SQLPlanner {
Expand Down Expand Up @@ -70,6 +73,10 @@ impl SQLPlanner {
self.current_relation.as_mut().expect("relation not set")
}

pub(crate) fn relation_opt(&self) -> Option<&Relation> {
self.current_relation.as_ref()
}

pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult<LogicalPlanRef> {
let tokens = Tokenizer::new(&GenericDialect {}, sql).tokenize()?;

Expand Down
18 changes: 18 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,21 @@ def test_sql_groupby_agg():
catalog = SQLCatalog({"test": df})
df = daft.sql("SELECT sum(v) FROM test GROUP BY n ORDER BY n", catalog=catalog)
assert df.collect().to_pydict() == {"n": [1, 2], "v": [3, 7]}


def test_sql_count_star():
df = daft.from_pydict(
{
"a": ["a", "b", None, "c"],
"b": [4, 3, 2, None],
}
)
catalog = SQLCatalog({"df": df})
df2 = daft.sql("SELECT count(*) FROM df", catalog)
actual = df2.collect().to_pydict()
expected = df.count().collect().to_pydict()
assert actual == expected
df2 = daft.sql("SELECT count(b) FROM df", catalog)
actual = df2.collect().to_pydict()
expected = df.agg(daft.col("b").count()).collect().to_pydict()
assert actual == expected

0 comments on commit 6594d87

Please sign in to comment.