Skip to content

Commit

Permalink
[FEAT]: SQL add hash and minhash (#2948)
Browse files Browse the repository at this point in the history
Add two more scalar functions that were missing from the SQL API: `hash` and `minhash`.
  • Loading branch information
universalmind303 authored Sep 30, 2024
1 parent 768f422 commit f1194b5
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/daft-functions/src/minhash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ use daft_dsl::{
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct MinHashFunction {
num_hashes: usize,
ngram_size: usize,
seed: u32,
pub struct MinHashFunction {
pub num_hashes: usize,
pub ngram_size: usize,
pub seed: u32,
}

#[typetag::serde]
Expand Down
7 changes: 6 additions & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{collections::HashMap, sync::Arc};

use daft_dsl::ExprRef;
use hashing::SQLModuleHashing;
use once_cell::sync::Lazy;
use sqlparser::ast::{
Function, FunctionArg, FunctionArgExpr, FunctionArgOperator, FunctionArguments,
Expand All @@ -18,6 +19,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
let mut functions = SQLFunctions::new();
functions.register::<SQLModuleAggs>();
functions.register::<SQLModuleFloat>();
functions.register::<SQLModuleHashing>();
functions.register::<SQLModuleImage>();
functions.register::<SQLModuleJson>();
functions.register::<SQLModuleList>();
Expand Down Expand Up @@ -235,7 +237,10 @@ impl SQLPlanner {
}
}

fn try_unwrap_function_arg_expr(&self, expr: &FunctionArgExpr) -> SQLPlannerResult<ExprRef> {
pub(crate) fn try_unwrap_function_arg_expr(
&self,
expr: &FunctionArgExpr,
) -> SQLPlannerResult<ExprRef> {
match expr {
FunctionArgExpr::Expr(expr) => self.plan_expr(expr),
_ => unsupported_sql_err!("Wildcard function args not yet supported"),
Expand Down
111 changes: 111 additions & 0 deletions src/daft-sql/src/modules/hashing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
use daft_dsl::ExprRef;
use daft_functions::{
hash::hash,
minhash::{minhash, MinHashFunction},
};
use sqlparser::ast::FunctionArg;

use super::SQLModule;
use crate::{
error::{PlannerError, SQLPlannerResult},
functions::{SQLFunction, SQLFunctionArguments, SQLFunctions},
unsupported_sql_err,
};

/// SQL module that contributes the hashing-related functions.
pub struct SQLModuleHashing;

impl SQLModule for SQLModuleHashing {
    /// Adds the `hash` and `minhash` SQL functions to `registry`.
    fn register(registry: &mut SQLFunctions) {
        registry.add_fn("hash", SQLHash);
        registry.add_fn("minhash", SQLMinhash);
    }
}

/// SQL binding for `hash(input)` and `hash(input, seed)`.
///
/// The seed may be given positionally (`hash(a, 0)`) or by name
/// (`hash(a, seed := 0)`).
pub struct SQLHash;

impl SQLFunction for SQLHash {
    /// Plans a `hash` call into an expression.
    ///
    /// # Errors
    ///
    /// Returns an "Invalid arguments for hash" error when the call does not
    /// have one or two arguments, or when the second argument is a named
    /// argument other than `seed`. Errors from planning the individual
    /// arguments are propagated unchanged.
    fn to_expr(
        &self,
        inputs: &[FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> SQLPlannerResult<ExprRef> {
        match inputs {
            [value] => {
                let value = planner.plan_function_arg(value)?;
                Ok(hash(value, None))
            }
            [value, seed_arg] => {
                // The hashed value is planned before the seed is inspected, so
                // its errors take precedence over a malformed seed argument.
                let value = planner.plan_function_arg(value)?;
                let seed = match seed_arg {
                    FunctionArg::Named { name, arg, .. } if name.value == "seed" => {
                        planner.try_unwrap_function_arg_expr(arg)?
                    }
                    FunctionArg::Unnamed(_) => planner.plan_function_arg(seed_arg)?,
                    _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"),
                };
                Ok(hash(value, Some(seed)))
            }
            _ => unsupported_sql_err!("Invalid arguments for hash: '{inputs:?}'"),
        }
    }
}

/// SQL binding for the `minhash` function.
pub struct SQLMinhash;

/// Builds the `minhash` parameters from the arguments of a SQL call.
///
/// * `num_hashes` - required, integer literal, must not be negative.
/// * `ngram_size` - required, integer literal, must not be negative.
/// * `seed` - optional integer literal that must fit in a `u32`; defaults to `1`.
impl TryFrom<SQLFunctionArguments> for MinHashFunction {
    type Error = PlannerError;

    /// Validates and converts the named arguments.
    ///
    /// # Errors
    ///
    /// Returns an invalid-operation error when a required argument is
    /// missing, when an argument is not an integer literal, or when its value
    /// cannot be represented in the target type.
    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
        let num_hashes = args
            .get_named("num_hashes")
            .ok_or_else(|| PlannerError::invalid_operation("num_hashes is required"))?
            .as_literal()
            .and_then(|lit| lit.as_i64())
            .ok_or_else(|| PlannerError::invalid_operation("num_hashes must be an integer"))?;
        // A plain `as usize` cast would silently wrap a negative literal into
        // an enormous count, so the conversion is checked instead.
        let num_hashes = usize::try_from(num_hashes).map_err(|_| {
            PlannerError::invalid_operation("num_hashes must be a non-negative integer")
        })?;

        let ngram_size = args
            .get_named("ngram_size")
            .ok_or_else(|| PlannerError::invalid_operation("ngram_size is required"))?
            .as_literal()
            .and_then(|lit| lit.as_i64())
            .ok_or_else(|| PlannerError::invalid_operation("ngram_size must be an integer"))?;
        let ngram_size = usize::try_from(ngram_size).map_err(|_| {
            PlannerError::invalid_operation("ngram_size must be a non-negative integer")
        })?;

        // `seed` is optional: a missing argument falls back to 1, while a
        // present but non-integer argument is an error.
        let seed = args
            .get_named("seed")
            .map(|arg| {
                arg.as_literal()
                    .and_then(|lit| lit.as_i64())
                    .ok_or_else(|| PlannerError::invalid_operation("seed must be an integer"))
            })
            .transpose()?
            .unwrap_or(1);
        // Reject values an `as u32` cast would have truncated or wrapped.
        let seed = u32::try_from(seed).map_err(|_| {
            PlannerError::invalid_operation("seed must fit in an unsigned 32-bit integer")
        })?;

        Ok(Self {
            num_hashes,
            ngram_size,
            seed,
        })
    }
}

impl SQLFunction for SQLMinhash {
fn to_expr(
&self,
inputs: &[FunctionArg],
planner: &crate::planner::SQLPlanner,
) -> SQLPlannerResult<ExprRef> {
match inputs {
[input, args @ ..] => {
let input = planner.plan_function_arg(input)?;
let args: MinHashFunction =
planner.plan_function_args(args, &["num_hashes", "ngram_size", "seed"], 0)?;

Ok(minhash(input, args.num_hashes, args.ngram_size, args.seed))
}
_ => unsupported_sql_err!("Invalid arguments for minhash: '{inputs:?}'"),
}
}
}
1 change: 1 addition & 0 deletions src/daft-sql/src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::functions::SQLFunctions;

pub mod aggs;
pub mod float;
pub mod hashing;
pub mod image;
pub mod json;
pub mod list;
Expand Down
47 changes: 47 additions & 0 deletions tests/sql/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytest

import daft
from daft import col


def test_nested():
Expand All @@ -20,3 +23,47 @@ def test_nested():
expected = df.with_column("try_this", df["A"] + 1).collect()

assert actual.to_pydict() == expected.to_pydict()


def test_hash_exprs():
    """Check that SQL `hash`/`minhash` produce the same results as the DataFrame API."""
    # NOTE(review): the frame must stay bound to the local name `df`, and
    # `daft.sql` must be called from this scope, because the queries below
    # refer to the table as `df` - confirm how `daft.sql` resolves it.
    df = daft.from_pydict(
        {
            "a": ["foo", "bar", "baz", "qux"],
            "ints": [1, 2, 3, 4],
            "floats": [1.5, 2.5, 3.5, 4.5],
        }
    )

    query = """
    SELECT
    hash(a) as hash_a,
    hash(a, 0) as hash_a_0,
    hash(a, seed:=0) as hash_a_seed_0,
    minhash(a, num_hashes:=10, ngram_size:= 100, seed:=10) as minhash_a,
    minhash(a, num_hashes:=10, ngram_size:= 100) as minhash_a_no_seed,
    FROM df
    """
    actual = daft.sql(query).collect().to_pydict()

    # Same projections, built through the expression API.
    projections = [
        col("a").hash().alias("hash_a"),
        col("a").hash(0).alias("hash_a_0"),
        col("a").hash(seed=0).alias("hash_a_seed_0"),
        col("a").minhash(num_hashes=10, ngram_size=100, seed=10).alias("minhash_a"),
        col("a").minhash(num_hashes=10, ngram_size=100).alias("minhash_a_no_seed"),
    ]
    expected = df.select(*projections).collect().to_pydict()

    assert actual == expected

    # Calling minhash with no arguments at all is rejected.
    with pytest.raises(Exception, match="Invalid arguments for minhash"):
        daft.sql("SELECT minhash() as hash_a FROM df").collect()

    # Omitting the required num_hashes argument is rejected.
    with pytest.raises(Exception, match="num_hashes is required"):
        daft.sql("SELECT minhash(a) as hash_a FROM df").collect()

0 comments on commit f1194b5

Please sign in to comment.