Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FunctionRegistry::register_udaf and FunctionRegistry::register_udwf #9075

Merged
merged 4 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,15 +823,7 @@ impl SessionContext {
/// Any functions registered with the udf name or its aliases will be overwritten with this new function
pub fn register_udf(&self, f: ScalarUDF) {
let mut state = self.state.write();
let aliases = f.aliases();
for alias in aliases {
state
.scalar_functions
.insert(alias.to_string(), Arc::new(f.clone()));
}
state
.scalar_functions
.insert(f.name().to_string(), Arc::new(f));
state.register_udf(Arc::new(f)).ok();
Copy link
Contributor Author

@alamb alamb Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved into SessionState -- it is important that the alias resolution is done there so when new functions are registered via SessionState their aliases are as well. Otherwise aliases are only added when the function is defined via SessionContext

}

/// Registers an aggregate UDF within this context.
Expand All @@ -842,10 +834,7 @@ impl SessionContext {
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
pub fn register_udaf(&self, f: AggregateUDF) {
self.state
.write()
.aggregate_functions
.insert(f.name().to_string(), Arc::new(f));
self.state.write().register_udaf(Arc::new(f)).ok();
}

/// Registers a window UDF within this context.
Expand All @@ -856,10 +845,7 @@ impl SessionContext {
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
pub fn register_udwf(&self, f: WindowUDF) {
self.state
.write()
.window_functions
.insert(f.name().to_string(), Arc::new(f));
self.state.write().register_udwf(Arc::new(f)).ok();
}

/// Creates a [`DataFrame`] for reading a data source.
Expand Down Expand Up @@ -1984,8 +1970,24 @@ impl FunctionRegistry for SessionState {
}

fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
let aliases = udf.aliases();
alamb marked this conversation as resolved.
Show resolved Hide resolved
for alias in aliases {
alamb marked this conversation as resolved.
Show resolved Hide resolved
self.scalar_functions.insert(alias.to_string(), udf.clone());
alamb marked this conversation as resolved.
Show resolved Hide resolved
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}


Ok(self.scalar_functions.insert(udf.name().into(), udf))
}

fn register_udaf(
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
Ok(self.aggregate_functions.insert(udaf.name().into(), udaf))
}

fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
Ok(self.window_functions.insert(udwf.name().into(), udwf))
}
}

impl OptimizerConfig for SessionState {
Expand Down
30 changes: 26 additions & 4 deletions datafusion/execution/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ pub trait FunctionRegistry {
/// Set of all available udfs.
fn udfs(&self) -> HashSet<String>;

/// Returns a reference to the udf named `name`.
/// Returns a reference to the user defined scalar function (udf) named
/// `name`.
fn udf(&self, name: &str) -> Result<Arc<ScalarUDF>>;

/// Returns a reference to the udaf named `name`.
/// Returns a reference to the user defined table function (udaf) named
alamb marked this conversation as resolved.
Show resolved Hide resolved
/// `name`.
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;

/// Returns a reference to the udwf named `name`.
/// Returns a reference to the user defined window function (udwf) named
/// `name`.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;

/// Registers a new [`ScalarUDF`], returning any previously registered
Expand All @@ -45,7 +48,26 @@ pub trait FunctionRegistry {
not_impl_err!("Registering ScalarUDF")
}

// TODO add register_udaf and register_udwf
/// Registers a new [`AggregateUDF`], returning any previously registered
/// implementation.
///
/// Returns an error (the default) if the function can not be registered,
/// for example if the registry is read only.
fn register_udaf(
Copy link
Contributor Author

@alamb alamb Jan 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the key (backwards compatible) API addition in this PR.

It sets us up for being able to pull out BuiltInAggregate and BuiltInWindowFunction

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering why this is a backwards compatible API? Does FunctionRegistry provide this API before?

Copy link
Contributor Author

@alamb alamb Feb 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is backwards compatible because it provides a default implementation (to return a Not Yet Implemented error)

Thus, any existing implementations of FunctionRegistry will continue to compile and work as it did before, without any required code changes

&mut self,
_udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
not_impl_err!("Registering AggregateUDF")
}

/// Registers a new [`WindowUDF`], returning any previously registered
/// implementation.
///
/// Returns an error (the default) if the function can not be registered,
/// for example if the registry is read only.
fn register_udwf(&mut self, _udaf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
not_impl_err!("Registering WindowUDF")
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
Expand Down
Loading