Skip to content

Commit

Permalink
feat(python): primitive kwargs in plugins (pola-rs#11268)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 16, 2023
1 parent 6e886f9 commit 5d48cc8
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 19 deletions.
4 changes: 4 additions & 0 deletions crates/polars-arrow/src/ffi/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,10 @@ impl ArrowSchema {
}
}

pub fn is_null(&self) -> bool {
self.private_data.is_null()
}

/// returns the format of this schema.
pub(crate) fn format(&self) -> &str {
assert!(!self.format.is_null());
Expand Down
16 changes: 16 additions & 0 deletions crates/polars-ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ pub struct SeriesExport {
private_data: *mut std::os::raw::c_void,
}

impl SeriesExport {
pub fn empty() -> Self {
Self {
field: std::ptr::null_mut(),
arrays: std::ptr::null_mut(),
len: 0,
release: None,
private_data: std::ptr::null_mut(),
}
}

pub fn is_null(&self) -> bool {
self.private_data.is_null()
}
}

impl Drop for SeriesExport {
fn drop(&mut self) {
if let Some(release) = self.release {
Expand Down
26 changes: 23 additions & 3 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,15 @@ pub enum FunctionExpr {
},
SetSortedFlag(IsSorted),
#[cfg(feature = "ffi_plugin")]
/// Creating this node is unsafe.
/// This will lead to calls over FFI.
FfiPlugin {
/// Shared library.
lib: Arc<str>,
/// Identifier in the shared lib.
symbol: Arc<str>,
/// Pickle serialized keyword arguments.
kwargs: Arc<[u8]>,
},
BackwardFill {
limit: FillNullLimit,
Expand Down Expand Up @@ -309,7 +315,12 @@ impl Hash for FunctionExpr {
#[cfg(feature = "dtype-categorical")]
FunctionExpr::Categorical(f) => f.hash(state),
#[cfg(feature = "ffi_plugin")]
FunctionExpr::FfiPlugin { lib, symbol } => {
FunctionExpr::FfiPlugin {
lib,
symbol,
kwargs,
} => {
kwargs.hash(state);
lib.hash(state);
symbol.hash(state);
},
Expand Down Expand Up @@ -767,8 +778,17 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
},
SetSortedFlag(sorted) => map!(dispatch::set_sorted_flag, sorted),
#[cfg(feature = "ffi_plugin")]
FfiPlugin { lib, symbol, .. } => unsafe {
map_as_slice!(plugin::call_plugin, lib.as_ref(), symbol.as_ref())
FfiPlugin {
lib,
symbol,
kwargs,
} => unsafe {
map_as_slice!(
plugin::call_plugin,
lib.as_ref(),
symbol.as_ref(),
kwargs.as_ref()
)
},
BackwardFill { limit } => map!(dispatch::backward_fill, limit),
ForwardFill { limit } => map!(dispatch::forward_fill, limit),
Expand Down
72 changes: 61 additions & 11 deletions crates/polars-plan/src/dsl/function_expr/plugin.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::ffi::CString;
use std::sync::RwLock;

use arrow::ffi::{import_field_from_c, ArrowSchema};
Expand Down Expand Up @@ -30,24 +31,59 @@ fn get_lib(lib: &str) -> PolarsResult<&'static Library> {
}
}

pub(super) unsafe fn call_plugin(s: &[Series], lib: &str, symbol: &str) -> PolarsResult<Series> {
/// Fetches the plugin's last error message as an owned C string.
///
/// Looks up the `get_last_error_message` symbol in `lib`, calls it, and copies
/// the NUL-terminated string the returned pointer refers to. A null pointer
/// yields an empty string.
///
/// # Safety
///
/// `lib` must export `get_last_error_message` as
/// `unsafe extern "C" fn() -> *mut c_char`, and any non-null pointer it returns
/// must reference a valid NUL-terminated string that stays alive for the
/// duration of this call.
///
/// # Panics
///
/// Panics if `lib` does not export `get_last_error_message`.
unsafe fn retrieve_error_msg(lib: &Library) -> CString {
    let symbol: libloading::Symbol<unsafe extern "C" fn() -> *mut std::os::raw::c_char> = lib
        .get(b"get_last_error_message\0")
        .expect("plugin library does not export `get_last_error_message`");
    let msg_ptr = symbol();

    if msg_ptr.is_null() {
        // Nothing to report; never dereference a null pointer.
        return CString::default();
    }

    // Copy the bytes instead of adopting the pointer. `CString::from_raw` is
    // only sound for pointers produced by `CString::into_raw` with the same
    // allocator, which cannot be guaranteed for a pointer handed back by a
    // dynamically loaded library. Ownership of the buffer stays with the
    // plugin, so it is never freed from this side.
    std::ffi::CStr::from_ptr(msg_ptr).to_owned()
}

/// Calls the plugin function `symbol` from the shared library `lib`.
///
/// The input series are exported to their FFI representation and handed to
/// the plugin together with the serialized `kwargs` bytes. The plugin reports
/// its result through an out-pointer; when the returned export is still null
/// afterwards, the plugin's last error message is fetched and returned as a
/// `ComputeError`.
///
/// # Safety
///
/// `symbol` must name a function in `lib` with the signature
/// `unsafe extern "C" fn(*const SeriesExport, usize, *const u8, usize, *mut SeriesExport)`
/// that honours the argument layout described in the body.
///
/// # Errors
///
/// Returns an error when the library cannot be obtained, when `symbol` cannot
/// be resolved in it, when the plugin leaves the return value null, or when
/// importing the returned series fails.
pub(super) unsafe fn call_plugin(
    s: &[Series],
    lib: &str,
    symbol: &str,
    kwargs: &[u8],
) -> PolarsResult<Series> {
    let lib = get_lib(lib)?;

    // Plugin ABI, in argument order:
    // *const SeriesExport: pointer to the first exported input series
    // usize: number of exported input series
    // *const u8: pointer to the serialized kwargs bytes
    // usize: length of the kwargs byte slice
    // *mut SeriesExport: pointer where the return value should be written.
    let plugin_fn: libloading::Symbol<
        unsafe extern "C" fn(*const SeriesExport, usize, *const u8, usize, *mut SeriesExport),
    > = match lib.get(symbol.as_bytes()) {
        Ok(f) => f,
        // A missing or misspelled symbol is reported as an error, not a panic.
        Err(err) => {
            polars_bail!(ComputeError: "could not load plugin symbol '{}': {}", symbol, err)
        },
    };

    let input = s.iter().map(export_series).collect::<Vec<_>>();
    let input_len = s.len();
    let slice_ptr = input.as_ptr();

    let kwargs_ptr = kwargs.as_ptr();
    let kwargs_len = kwargs.len();

    // Starts out null; the plugin overwrites it on success.
    let mut return_value = SeriesExport::empty();
    let return_value_ptr = &mut return_value as *mut SeriesExport;
    plugin_fn(
        slice_ptr,
        input_len,
        kwargs_ptr,
        kwargs_len,
        return_value_ptr,
    );

    // The inputs get dropped when the ffi side calls the drop callback, so
    // they must not be dropped a second time here.
    for e in input {
        std::mem::forget(e);
    }

    if !return_value.is_null() {
        import_series(return_value)
    } else {
        let msg = retrieve_error_msg(lib);
        let msg = msg.to_string_lossy();
        polars_bail!(ComputeError: "the plugin failed with message: {}", msg)
    }
}

pub(super) unsafe fn plugin_field(
Expand All @@ -57,8 +93,12 @@ pub(super) unsafe fn plugin_field(
) -> PolarsResult<Field> {
let lib = get_lib(lib)?;

let symbol: libloading::Symbol<unsafe extern "C" fn(*const ArrowSchema, usize) -> ArrowSchema> =
lib.get(symbol.as_bytes()).unwrap();
// *const ArrowSchema: pointer to heap Box<ArrowSchema>
// usize: length of the boxed slice
// *mut ArrowSchema: pointer where the return value can be written
let symbol: libloading::Symbol<
unsafe extern "C" fn(*const ArrowSchema, usize, *mut ArrowSchema),
> = lib.get(symbol.as_bytes()).unwrap();

// we deallocate the fields buffer
let fields = fields
Expand All @@ -68,8 +108,18 @@ pub(super) unsafe fn plugin_field(
.into_boxed_slice();
let n_args = fields.len();
let slice_ptr = fields.as_ptr();
let out = symbol(slice_ptr, n_args);

let arrow_field = import_field_from_c(&out)?;
Ok(Field::from(&arrow_field))
let mut return_value = ArrowSchema::empty();
let return_value_ptr = &mut return_value as *mut ArrowSchema;
symbol(slice_ptr, n_args, return_value_ptr);

if !return_value.is_null() {
let arrow_field = import_field_from_c(&return_value)?;
let out = Field::from(&arrow_field);
Ok(out)
} else {
let msg = retrieve_error_msg(lib);
let msg = msg.to_string_lossy();
polars_bail!(ComputeError: "the plugin failed with message: {}", msg)
}
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ impl FunctionExpr {
Random { .. } => mapper.with_same_dtype(),
SetSortedFlag(_) => mapper.with_same_dtype(),
#[cfg(feature = "ffi_plugin")]
FfiPlugin { lib, symbol } => unsafe {
FfiPlugin { lib, symbol, .. } => unsafe {
plugin::plugin_field(fields, lib, &format!("__polars_field_{}", symbol.as_ref()))
},
BackwardFill { .. } => mapper.with_same_dtype(),
Expand Down
14 changes: 13 additions & 1 deletion py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9508,10 +9508,11 @@ def is_last(self) -> Self:

def _register_plugin(
self,
*,
lib: str,
symbol: str,
args: list[IntoExpr] | None = None,
*,
kwargs: dict[Any, Any] | None = None,
is_elementwise: bool = False,
input_wildcard_expansion: bool = False,
auto_explode: bool = False,
Expand All @@ -9536,6 +9537,9 @@ def _register_plugin(
Function to load.
args
Arguments (other than self) passed to this function.
These arguments have to be of type Expression.
kwargs
Non-expression arguments. They are serialized with `pickle` before being passed to the plugin, so keep them to plain, JSON-like values.
is_elementwise
If the function only operates on scalars
this will trigger fast paths.
Expand All @@ -9552,11 +9556,19 @@ def _register_plugin(
args = []
else:
args = [parse_as_expression(a) for a in args]
if kwargs is None:
serialized_kwargs = b""
else:
import pickle

serialized_kwargs = pickle.dumps(kwargs, protocol=2)

return self._from_pyexpr(
self._pyexpr.register_plugin(
lib,
symbol,
args,
serialized_kwargs,
is_elementwise,
input_wildcard_expansion,
auto_explode,
Expand Down
8 changes: 5 additions & 3 deletions py-polars/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -892,11 +892,12 @@ impl PyExpr {
lib: &str,
symbol: &str,
args: Vec<PyExpr>,
kwargs: Vec<u8>,
is_elementwise: bool,
input_wildcard_expansion: bool,
auto_explode: bool,
cast_to_supertypes: bool,
) -> Self {
) -> PyResult<Self> {
use polars_plan::prelude::*;
let inner = self.inner.clone();

Expand All @@ -911,11 +912,12 @@ impl PyExpr {
input.push(a.inner)
}

Expr::Function {
Ok(Expr::Function {
input,
function: FunctionExpr::FfiPlugin {
lib: Arc::from(lib),
symbol: Arc::from(symbol),
kwargs: Arc::from(kwargs),
},
options: FunctionOptions {
collect_groups,
Expand All @@ -925,6 +927,6 @@ impl PyExpr {
..Default::default()
},
}
.into()
.into())
}
}

0 comments on commit 5d48cc8

Please sign in to comment.