diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5fc8dbcfdfb3..380723bd72a5 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -875,9 +875,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.103" +version = "1.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2755ff20a1d93490d26ba33a6f092a38a508398a5320df5d4b3014fcccce9410" +checksum = "74b6a57f98764a267ff415d50a25e6e166f3831a5071af4995296ea97d210490" dependencies = [ "jobserver", "libc", @@ -900,7 +900,7 @@ dependencies = [ "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -1333,6 +1333,7 @@ dependencies = [ "chrono", "datafusion-common", "datafusion-expr", + "datafusion-functions", "datafusion-physical-expr", "hashbrown 0.14.5", "indexmap 2.2.6", @@ -1964,9 +1965,9 @@ dependencies = [ [[package]] name = "hyper" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe575dd17d0862a9a33781c8c4696a55c320909004a67a00fb286ba8b1bc496d" +checksum = "c4fe55fb7a772d59a5ff1dfbff4fe0258d19b89fec4b233e75d35d5d2316badc" dependencies = [ "bytes", "futures-channel", @@ -2005,10 +2006,10 @@ checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" dependencies = [ "futures-util", "http 1.1.0", - "hyper 1.3.1", + "hyper 1.4.0", "hyper-util", "rustls 0.23.10", - "rustls-native-certs 0.7.0", + "rustls-native-certs 0.7.1", "rustls-pki-types", "tokio", "tokio-rustls 0.26.0", @@ -2017,16 +2018,16 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b875924a60b96e5d7b9ae7b066540b1dd1cbd90d1828f54c92e02a283351c56" +checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" dependencies = [ "bytes", "futures-channel", "futures-util", "http 1.1.0", "http-body 1.0.0", - "hyper 1.3.1", + "hyper 1.4.0", "pin-project-lite", "socket2", "tokio", @@ -2501,7 +2502,7 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.3.1", + "hyper 1.4.0", "itertools", "md-5", "parking_lot", @@ -2573,7 +2574,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -2975,7 +2976,7 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", - "hyper 1.3.1", + "hyper 1.4.0", "hyper-rustls 0.27.2", "hyper-util", "ipnet", @@ -2987,7 +2988,7 @@ dependencies = [ "pin-project-lite", "quinn", "rustls 0.23.10", - "rustls-native-certs 0.7.0", + "rustls-native-certs 0.7.1", "rustls-pemfile 2.1.2", "rustls-pki-types", "serde", @@ -3142,9 +3143,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" dependencies = [ "openssl-probe", "rustls-pemfile 2.1.2", @@ -3180,9 +3181,9 @@ checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.4" +version = "0.102.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" +checksum = "f9a6fccd794a42c2c105b513a2f62bc3fd8f3ba57a4593677ceb0bd035164d78" dependencies = [ "ring 0.17.8", "rustls-pki-types", @@ -3315,9 +3316,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.119" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8eddb61f0697cc3989c5d64b452f5488e2b8a60fd7d5076a3045076ffef8cb0" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -4097,7 +4098,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4115,7 +4116,7 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", ] [[package]] @@ -4135,18 +4136,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -4157,9 +4158,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -4169,9 +4170,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -4181,15 +4182,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -4199,9 +4200,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -4211,9 +4212,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -4223,9 +4224,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -4235,9 +4236,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winreg" @@ -4266,18 +4267,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae87e3fcd617500e5d106f0380cf7b77f3c6092aae37191433159dda23cfb087" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.34" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15e934569e47891f7d9411f1a451d947a60e000ab3bd24fbb970f000387d1b3b" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 579f5fed578f..fd0d1e41612b 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1413,12 +1413,16 @@ impl Expr { .unwrap() } + /// Returns true if the expression node is volatile, i.e. whether it can return + /// different results when evaluated multiple times with the same input. + pub fn is_volatile_node(&self) -> bool { + matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile) + } + /// Returns true if the expression is volatile, i.e. whether it can return different /// results when evaluated multiple times with the same input. pub fn is_volatile(&self) -> Result { - self.exists(|expr| { - Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile )) - }) + self.exists(|expr| Ok(expr.is_volatile_node())) } /// Recursively find all [`Expr::Placeholder`] expressions, and diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 1a9e9630c076..54d8c472f13f 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } hashbrown = { workspace = true } indexmap = { workspace = true } diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index cebae410f309..2bdcfc5cbe2a 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -191,24 +191,19 @@ impl CommonSubexprEliminate { id_array: &mut IdArray<'n>, expr_mask: ExprMask, ) -> Result { - // Don't consider volatile expressions for CSE. - Ok(if expr.is_volatile()? { - false - } else { - let mut visitor = ExprIdentifierVisitor { - expr_stats, - id_array, - visit_stack: vec![], - down_index: 0, - up_index: 0, - expr_mask, - random_state: &self.random_state, - found_common: false, - }; - expr.visit(&mut visitor)?; + let mut visitor = ExprIdentifierVisitor { + expr_stats, + id_array, + visit_stack: vec![], + down_index: 0, + up_index: 0, + expr_mask, + random_state: &self.random_state, + found_common: false, + }; + expr.visit(&mut visitor)?; - visitor.found_common - }) + Ok(visitor.found_common) } /// Rewrites `exprs_list` with common sub-expressions replaced with a new @@ -917,27 +912,36 @@ struct ExprIdentifierVisitor<'a, 'n> { /// Record item that used when traversing an expression tree. enum VisitRecord<'n> { - /// Contains the post-order index assigned in during the first, visiting traversal and - /// a boolean flag to indicate if the record marks an expression subtree (not just a - /// single node). + /// Marks the beginning of expression. It contains: + /// - The post-order index assigned during the first, visiting traversal. + /// - A boolean flag if the record marks an expression subtree (not just a single + /// node). EnterMark(usize, bool), - /// Accumulated identifier of sub expression. - ExprItem(Identifier<'n>), + + /// Marks an accumulated subexpression tree. It contains: + /// - The accumulated identifier of a subexpression. + /// - A boolean flag if the expression is valid for subexpression elimination. + /// The flag is propagated up from children to parent. (E.g. volatile expressions + /// are not valid and can't be extracted, but non-volatile children of volatile + /// expressions can be extracted.) + ExprItem(Identifier<'n>, bool), } impl<'n> ExprIdentifierVisitor<'_, 'n> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> (usize, bool, Option>) { + fn pop_enter_mark(&mut self) -> (usize, bool, Option>, bool) { let mut expr_id = None; + let mut is_valid = true; while let Some(item) = self.visit_stack.pop() { match item { - VisitRecord::EnterMark(down_index, tree) => { - return (down_index, tree, expr_id); + VisitRecord::EnterMark(down_index, is_tree) => { + return (down_index, is_tree, expr_id, is_valid); } - VisitRecord::ExprItem(id) => { - expr_id = Some(id.combine(expr_id)); + VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => { + expr_id = Some(sub_expr_id.combine(expr_id)); + is_valid &= sub_expr_is_valid; } } } @@ -949,8 +953,6 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { type Node = Expr; fn f_down(&mut self, expr: &'n Expr) -> Result { - // TODO: consider non-volatile sub-expressions for CSE - // If an expression can short circuit its children then don't consider its // children for CSE (https://github.com/apache/arrow-datafusion/issues/8814). // This means that we don't recurse into its children, but handle the expression @@ -972,13 +974,14 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { } fn f_up(&mut self, expr: &'n Expr) -> Result { - let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark(); + let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark(); let expr_id = Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id); + let is_valid = !expr.is_volatile_node() && sub_expr_is_valid; self.id_array[down_index].0 = self.up_index; - if !self.expr_mask.ignores(expr) { + if is_valid && !self.expr_mask.ignores(expr) { self.id_array[down_index].1 = Some(expr_id); let count = self.expr_stats.entry(expr_id).or_insert(0); *count += 1; @@ -986,7 +989,8 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> { self.found_common = true; } } - self.visit_stack.push(VisitRecord::ExprItem(expr_id)); + self.visit_stack + .push(VisitRecord::ExprItem(expr_id, is_valid)); self.up_index += 1; Ok(TreeNodeRecursion::Continue) @@ -1105,13 +1109,14 @@ mod test { use std::iter; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_expr::expr::AggregateFunction; + use datafusion_expr::expr::{AggregateFunction, ScalarFunction}; use datafusion_expr::logical_plan::{table_scan, JoinType}; use datafusion_expr::{ grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature, SimpleAggregateUDF, Volatility, }; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; + use datafusion_functions::math; use crate::optimizer::OptimizerContext; use crate::test::*; @@ -1838,4 +1843,27 @@ mod test { Ok(()) } + + #[test] + fn test_volatile() -> Result<()> { + let table_scan = test_table_scan()?; + + let extracted_child = col("a") + col("b"); + let rand = Expr::ScalarFunction(ScalarFunction::new_udf(math::random(), vec![])); + let not_extracted_volatile = extracted_child + rand; + let plan = LogicalPlanBuilder::from(table_scan.clone()) + .project(vec![ + not_extracted_volatile.clone().alias("c1"), + not_extracted_volatile.alias("c2"), + ])? + .build()?; + + let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\ + \n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\ + \n TableScan: test"; + + assert_optimized_plan_eq(expected, plan, None); + + Ok(()) + } }