[FEAT]: dyn function registry (#2466)
This adds a new variant to `Expr`: `Expr::ScalarFunction`. It is similar to
`FunctionExpr`, but uses dynamic dispatch and a registry instead of enum
variants.
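
For reference, the new pieces look roughly like this (a hand-written sketch inferred from the `hash.rs` and `expr.rs` changes below; the real definitions live in the new `scalar.rs` module, which isn't expanded in this diff):

```rust
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use crate::ExprRef;

/// Approximate shape of the trait; exact bounds (plus the erased-serde
/// plumbing that makes the trait object serializable) are omitted here.
pub trait ScalarUDF: Send + Sync + std::fmt::Debug {
    fn as_any(&self) -> &dyn std::any::Any;
    fn name(&self) -> &'static str;
    fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series>;
    fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field>;
}

/// The new `Expr::ScalarFunction` variant carries the trait object plus its
/// input expressions. (Clone/Eq/Hash/serde/Display plumbing omitted.)
pub struct ScalarFunction {
    pub udf: Arc<dyn ScalarUDF>,
    pub inputs: Vec<ExprRef>,
}

impl ScalarFunction {
    pub fn new<F: ScalarUDF + 'static>(udf: F, inputs: Vec<ExprRef>) -> Self {
        Self {
            udf: Arc::new(udf),
            inputs,
        }
    }

    /// Delegates planning-time schema resolution to the wrapped function.
    pub fn to_field(&self, schema: &Schema) -> DaftResult<Field> {
        self.udf.to_field(&self.inputs, schema)
    }
}
```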

The registry is inspired by DataFusion's function registry.

When an expr is serialized, only the function name and its inputs are serialized.

For example, `col("text").hash(seed=42)` serializes to roughly:

```js
{
  "name": "hash",
  "inputs": [
    `col('text')`, // serialized repr of this expr
    `lit(42)`      // serialized repr of this expr
  ]
}
```

When deserializing, the function name is looked up in the registry to fetch
the matching implementation _(erroring if no match is found)_.
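
Concretely, that lookup amounts to something like the following (a minimal sketch against the `Registry` API added in `registry.rs` below; `resolve_function` is just an illustrative helper, not the actual deserializer):

```rust
use std::sync::Arc;

use common_error::{DaftError, DaftResult};

use crate::functions::{registry::REGISTRY, ScalarUDF};

/// Resolve a serialized function name back to its registered implementation,
/// erroring if nothing under that name exists in the registry.
fn resolve_function(name: &str) -> DaftResult<Arc<dyn ScalarUDF>> {
    REGISTRY.get(name).ok_or_else(|| {
        DaftError::ValueError(format!("function `{name}` not found in registry"))
    })
}
```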



To make sure everything works end to end, I also refactored the `hash`
function to use the new paradigm.
universalmind303 authored Jul 9, 2024
1 parent ecebb82 commit cf9a09b
Showing 13 changed files with 309 additions and 36 deletions.
17 changes: 17 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 3 additions & 0 deletions src/daft-dsl/Cargo.toml
@@ -6,8 +6,11 @@ common-treenode = {path = "../common/treenode", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
daft-sketch = {path = "../daft-sketch", default-features = false}
dashmap = "6.0.1"
erased-serde = "0.4.5"
indexmap = {workspace = true}
itertools = {workspace = true}
lazy_static = {workspace = true}
pyo3 = {workspace = true, optional = true}
pyo3-log = {workspace = true, optional = true}
serde = {workspace = true}
17 changes: 14 additions & 3 deletions src/daft-dsl/src/expr.rs
@@ -7,10 +7,10 @@ use daft_core::{

use crate::{
functions::{
function_display, function_semantic_id,
function_display, function_semantic_id, scalar_function_semantic_id,
sketch::{HashableVecPercentiles, SketchExpr},
struct_::StructExpr,
FunctionEvaluator,
FunctionEvaluator, ScalarFunction,
},
lit,
optimization::{get_required_columns, requires_computation},
@@ -58,6 +58,7 @@ pub enum Expr {
if_false: ExprRef,
predicate: ExprRef,
},
ScalarFunction(ScalarFunction),
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Hash, Eq)]
@@ -576,6 +577,7 @@ impl Expr {

// Agg: Separate path.
Agg(agg_expr) => agg_expr.semantic_id(schema),
ScalarFunction(sf) => scalar_function_semantic_id(sf, schema),
}
}

@@ -607,6 +609,7 @@ impl Expr {
vec![if_true.clone(), if_false.clone(), predicate.clone()]
}
FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()],
ScalarFunction(sf) => sf.inputs.clone(),
}
}

@@ -658,6 +661,7 @@ impl Expr {
func: func.clone(),
inputs: children,
},
ScalarFunction(sf) => ScalarFunction(sf.clone()),
}
}

@@ -710,6 +714,8 @@ impl Expr {
}
Literal(value) => Ok(Field::new("literal", value.get_type())),
Function { func, inputs } => func.to_field(inputs.as_slice(), schema, func),
ScalarFunction(sf) => sf.to_field(schema),

BinaryOp { op, left, right } => {
let left_field = left.to_field(schema)?;
let right_field = right.to_field(schema)?;
@@ -814,6 +820,7 @@ impl Expr {
FunctionExpr::Struct(StructExpr::Get(name)) => name,
_ => inputs.first().unwrap().name(),
},
ScalarFunction(func) => func.inputs.first().unwrap().name(),
BinaryOp {
op: _,
left,
@@ -903,7 +910,8 @@ impl Expr {
| Expr::IsIn(..)
| Expr::Between(..)
| Expr::Function { .. }
| Expr::FillNull(..) => Err(io::Error::new(
| Expr::FillNull(..)
| Expr::ScalarFunction { .. } => Err(io::Error::new(
io::ErrorKind::Other,
"Unsupported expression for SQL translation",
)),
@@ -946,6 +954,8 @@ impl Display for Expr {
Between(expr, lower, upper) => write!(f, "{expr} in [{lower},{upper}]"),
Literal(val) => write!(f, "lit({val})"),
Function { func, inputs } => function_display(f, func, inputs),
ScalarFunction(func) => write!(f, "{func}"),

IfElse {
if_true,
if_false,
@@ -1130,6 +1140,7 @@ fn expr_has_agg(expr: &ExprRef) -> bool {
Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => expr_has_agg(e),
BinaryOp { left, right, .. } => expr_has_agg(left) || expr_has_agg(right),
Function { inputs, .. } => inputs.iter().any(expr_has_agg),
ScalarFunction(func) => func.inputs.iter().any(expr_has_agg),
IsIn(l, r) | FillNull(l, r) => expr_has_agg(l) || expr_has_agg(r),
Between(v, l, u) => expr_has_agg(v) || expr_has_agg(l) || expr_has_agg(u),
IfElse {
51 changes: 26 additions & 25 deletions src/daft-dsl/src/functions/hash.rs
@@ -4,33 +4,25 @@ use daft_core::{
schema::Schema,
DataType, IntoSeries, Series,
};
use serde::{Deserialize, Serialize};

use crate::{
functions::{FunctionEvaluator, FunctionExpr},
Expr, ExprRef,
};
use crate::{Expr, ExprRef};

pub(super) struct HashEvaluator {}
use super::{ScalarFunction, ScalarUDF};

impl FunctionEvaluator for HashEvaluator {
fn fn_name(&self) -> &'static str {
"hash"
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub(super) struct HashFunction;

impl ScalarUDF for HashFunction {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
[input] | [input, _] => match input.to_field(schema) {
Ok(field) => Ok(Field::new(field.name, DataType::UInt64)),
e => e,
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input arg, got {}",
inputs.len()
))),
}
fn name(&self) -> &'static str {
"hash"
}

fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
fn evaluate(&self, inputs: &[Series]) -> DaftResult<Series> {
match inputs {
[input] => input.hash(None).map(|s| s.into_series()),
[input, seed] => {
@@ -60,6 +52,19 @@ impl FunctionEvaluator for HashEvaluator {
_ => Err(DaftError::ValueError("Expected 2 input arg".to_string())),
}
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema) -> DaftResult<Field> {
match inputs {
[input] | [input, _] => match input.to_field(schema) {
Ok(field) => Ok(Field::new(field.name, DataType::UInt64)),
e => e,
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input arg, got {}",
inputs.len()
))),
}
}
}

pub fn hash(input: ExprRef, seed: Option<ExprRef>) -> ExprRef {
@@ -68,9 +73,5 @@ pub fn hash(input: ExprRef, seed: Option<ExprRef>) -> ExprRef {
None => vec![input],
};

Expr::Function {
func: FunctionExpr::Hash,
inputs,
}
.into()
Expr::ScalarFunction(ScalarFunction::new(HashFunction {}, inputs)).into()
}
2 changes: 1 addition & 1 deletion src/daft-dsl/src/functions/json/mod.rs
@@ -14,7 +14,7 @@ pub enum JsonExpr {

impl JsonExpr {
#[inline]
pub fn query_evaluator(&self) -> &dyn FunctionEvaluator {
pub fn get_evaluator(&self) -> &dyn FunctionEvaluator {
use JsonExpr::*;
match self {
Query(_) => &JsonQueryEvaluator {},
11 changes: 7 additions & 4 deletions src/daft-dsl/src/functions/mod.rs
@@ -7,13 +7,16 @@ pub mod map;
pub mod minhash;
pub mod numeric;
pub mod partitioning;
pub mod registry;
pub mod scalar;
pub mod sketch;
pub mod struct_;
pub mod temporal;
pub mod uri;
pub mod utf8;

use std::fmt::{Display, Formatter, Result};
use std::hash::Hash;

use crate::ExprRef;

@@ -28,10 +31,12 @@ use self::struct_::StructExpr;
use self::temporal::TemporalExpr;
use self::utf8::Utf8Expr;
use self::{float::FloatExpr, uri::UriExpr};
pub use scalar::*;

use common_error::DaftResult;
use daft_core::datatypes::FieldID;
use daft_core::{datatypes::Field, schema::Schema, series::Series};
use hash::HashEvaluator;

use minhash::{MinHashEvaluator, MinHashExpr};
use serde::{Deserialize, Serialize};

@@ -56,7 +61,6 @@ pub enum FunctionExpr {
Python(PythonUDF),
Partitioning(PartitioningExpr),
Uri(UriExpr),
Hash,
MinHash(MinHashExpr),
}

@@ -84,13 +88,12 @@ impl FunctionExpr {
Map(expr) => expr.get_evaluator(),
Sketch(expr) => expr.get_evaluator(),
Struct(expr) => expr.get_evaluator(),
Json(expr) => expr.query_evaluator(),
Json(expr) => expr.get_evaluator(),
Image(expr) => expr.get_evaluator(),
Uri(expr) => expr.get_evaluator(),
#[cfg(feature = "python")]
Python(expr) => expr,
Partitioning(expr) => expr.get_evaluator(),
Hash => &HashEvaluator {},
MinHash(_) => &MinHashEvaluator {},
}
}
43 changes: 43 additions & 0 deletions src/daft-dsl/src/functions/registry.rs
@@ -0,0 +1,43 @@
use std::sync::Arc;

use common_error::{DaftError, DaftResult};
use dashmap::DashMap;

use super::{hash::HashFunction, ScalarUDF};

lazy_static::lazy_static! {
pub static ref REGISTRY: Registry = Registry::new();
}

pub struct Registry {
functions: DashMap<&'static str, Arc<dyn ScalarUDF>>,
}

impl Registry {
fn new() -> Self {
let iter: Vec<Arc<dyn ScalarUDF>> = vec![Arc::new(HashFunction {})];

let functions = iter.into_iter().map(|f| (f.name(), f)).collect();

Self { functions }
}
pub fn register(&mut self, function: Arc<dyn ScalarUDF>) -> DaftResult<()> {
if self.functions.contains_key(function.name()) {
Err(DaftError::ValueError(format!(
"function {} already exists",
function.name()
)))
} else {
self.functions.insert(function.name(), function);
Ok(())
}
}

pub fn get(&self, name: &str) -> Option<Arc<dyn ScalarUDF>> {
self.functions.get(name).map(|f| f.value().clone())
}

pub fn names(&self) -> Vec<&'static str> {
self.functions.iter().map(|pair| pair.name()).collect()
}
}
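
On the consumer side, usage is just a name-based fetch; a small sketch against the API above:

```rust
use crate::functions::registry::REGISTRY;

fn registry_demo() {
    // Only the built-ins wired up in `Registry::new` are present; for this PR
    // that is just "hash".
    assert!(REGISTRY.names().contains(&"hash"));

    // `get` hands back a cheap `Arc` clone of the registered trait object,
    // or `None` if nothing was registered under that name.
    let udf = REGISTRY.get("hash").expect("hash should be registered");
    assert_eq!(udf.name(), "hash");
    assert!(REGISTRY.get("no_such_function").is_none());
}
```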