From 50cac4efe172f47b3b0a4d40a38d78598816c766 Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 25 Jun 2024 17:47:11 -0500 Subject: [PATCH] provides workaround for half-migrated UDAF `sum` Ref #730 --- examples/tpch/_tests.py | 1 - src/functions.rs | 20 +++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/tpch/_tests.py b/examples/tpch/_tests.py index aa9491bf..3f973d9f 100644 --- a/examples/tpch/_tests.py +++ b/examples/tpch/_tests.py @@ -74,7 +74,6 @@ def check_q17(df): ("q10_returned_item_reporting", "q10"), pytest.param( "q11_important_stock_identification", "q11", - marks=pytest.mark.xfail # https://github.com/apache/datafusion-python/issues/730 ), ("q12_ship_mode_order_priority", "q12"), ("q13_customer_distribution", "q13"), diff --git a/src/functions.rs b/src/functions.rs index 8e395ae4..cbe61aa5 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -320,14 +320,20 @@ fn window( window_frame: Option, ctx: Option, ) -> PyResult { - let fun = find_df_window_func(name).or_else(|| { - ctx.and_then(|ctx| { - ctx.ctx - .udaf(name) - .map(WindowFunctionDefinition::AggregateUDF) - .ok() + // workaround for https://github.com/apache/datafusion-python/issues/730 + let fun = if name == "sum" { + let sum_udf = functions_aggregate::sum::sum_udaf(); + Some(WindowFunctionDefinition::AggregateUDF(sum_udf)) + } else { + find_df_window_func(name).or_else(|| { + ctx.and_then(|ctx| { + ctx.ctx + .udaf(name) + .map(WindowFunctionDefinition::AggregateUDF) + .ok() + }) }) - }); + }; if fun.is_none() { return Err(DataFusionError::Common("window function not found".to_string()).into()); }