Skip to content

Commit

Permalink
ARROW-11188: [Rust] Support crypto functions from PostgreSQL dialect
Browse files Browse the repository at this point in the history
Implemented functions:

- [x] MD5 (return type string)
- [x] SHA224 (return type binary)
- [x] SHA256 (return type binary)
- [x] SHA384 (return type binary)
- [x] SHA512 (return type binary)

Closes #9139 from ovr/crypto-functions

Authored-by: Dmitry Patsura <zaets28rus@gmail.com>
Signed-off-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
ovr authored and alamb committed Jan 11, 2021
1 parent 01dec7e commit 6da7718
Show file tree
Hide file tree
Showing 9 changed files with 297 additions and 19 deletions.
24 changes: 24 additions & 0 deletions rust/arrow/src/util/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,27 @@ macro_rules! make_string {
}};
}

// It's not possible to do array.value($row).to_string() for &[u8], let's format it as hex
macro_rules! make_string_hex {
($array_type:ty, $column: ident, $row: ident) => {{
let array = $column.as_any().downcast_ref::<$array_type>().unwrap();

let s = if array.is_null($row) {
"".to_string()
} else {
let mut tmp = "".to_string();

for character in array.value($row) {
tmp += &format!("{:02x}", character);
}

tmp
};

Ok(s)
}};
}

macro_rules! make_string_from_list {
($column: ident, $row: ident) => {{
let list = $column
Expand All @@ -67,6 +88,9 @@ macro_rules! make_string_from_list {
pub fn array_value_to_string(column: &array::ArrayRef, row: usize) -> Result<String> {
match column.data_type() {
DataType::Utf8 => make_string!(array::StringArray, column, row),
DataType::LargeUtf8 => make_string!(array::LargeStringArray, column, row),
DataType::Binary => make_string_hex!(array::BinaryArray, column, row),
DataType::LargeBinary => make_string_hex!(array::LargeBinaryArray, column, row),
DataType::Boolean => make_string!(array::BooleanArray, column, row),
DataType::Int8 => make_string!(array::Int8Array, column, row),
DataType::Int16 => make_string!(array::Int16Array, column, row),
Expand Down
2 changes: 2 additions & 0 deletions rust/datafusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ futures = "0.3"
pin-project-lite= "^0.2.0"
tokio = { version = "0.2", features = ["macros", "blocking", "rt-core", "rt-threaded", "sync"] }
log = "^0.4"
md-5 = "^0.9.1"
sha2 = "^0.9.1"

[dev-dependencies]
rand = "0.8"
Expand Down
5 changes: 5 additions & 0 deletions rust/datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,11 @@ unary_scalar_expr!(Trim, trim);
unary_scalar_expr!(Ltrim, ltrim);
unary_scalar_expr!(Rtrim, rtrim);
unary_scalar_expr!(Upper, upper);
unary_scalar_expr!(MD5, md5);
unary_scalar_expr!(SHA224, sha224);
unary_scalar_expr!(SHA256, sha256);
unary_scalar_expr!(SHA384, sha384);
unary_scalar_expr!(SHA512, sha512);

/// returns the length of a string in bytes
pub fn length(e: Expr) -> Expr {
Expand Down
5 changes: 3 additions & 2 deletions rust/datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ pub use display::display_schema;
pub use expr::{
abs, acos, and, array, asin, atan, avg, binary_expr, case, ceil, col, concat, cos,
count, count_distinct, create_udaf, create_udf, exp, exprlist_to_fields, floor,
in_list, length, lit, ln, log10, log2, lower, ltrim, max, min, or, round, rtrim,
signum, sin, sqrt, sum, tan, trim, trunc, upper, when, Expr, Literal,
in_list, length, lit, ln, log10, log2, lower, ltrim, max, md5, min, or, round, rtrim,
sha224, sha256, sha384, sha512, signum, sin, sqrt, sum, tan, trim, trunc, upper,
when, Expr, Literal,
};
pub use extension::UserDefinedLogicalNode;
pub use operators::Operator;
Expand Down
106 changes: 106 additions & 0 deletions rust/datafusion/src/physical_plan/crypto_expressions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Crypto expressions
use md5::Md5;
use sha2::{
digest::Output as SHA2DigestOutput, Digest as SHA2Digest, Sha224, Sha256, Sha384,
Sha512,
};

use crate::error::{DataFusionError, Result};
use arrow::array::{
ArrayRef, GenericBinaryArray, GenericStringArray, StringOffsetSizeTrait,
};

fn md5_process(input: &str) -> String {
let mut digest = Md5::default();
digest.update(&input);

let mut result = String::new();

for byte in &digest.finalize() {
result.push_str(&format!("{:02x}", byte));
}

result
}

// It's not possible to return &[u8], because trait in trait without short lifetime
fn sha_process<D: SHA2Digest + Default>(input: &str) -> SHA2DigestOutput<D> {
let mut digest = D::default();
digest.update(&input);

digest.finalize()
}

macro_rules! crypto_unary_string_function {
($NAME:ident, $FUNC:expr) => {
/// crypto function that accepts Utf8 or LargeUtf8 and returns Utf8 string
pub fn $NAME<T: StringOffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<GenericStringArray<i32>> {
if args.len() != 1 {
return Err(DataFusionError::Internal(format!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
String::from(stringify!($NAME)),
)));
}

let array = args[0]
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.unwrap();

// first map is the iterator, second is for the `Option<_>`
Ok(array.iter().map(|x| x.map(|x| $FUNC(x))).collect())
}
};
}

macro_rules! crypto_unary_binary_function {
($NAME:ident, $FUNC:expr) => {
/// crypto function that accepts Utf8 or LargeUtf8 and returns Binary
pub fn $NAME<T: StringOffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<GenericBinaryArray<i32>> {
if args.len() != 1 {
return Err(DataFusionError::Internal(format!(
"{:?} args were supplied but {} takes exactly one argument",
args.len(),
String::from(stringify!($NAME)),
)));
}

let array = args[0]
.as_any()
.downcast_ref::<GenericStringArray<T>>()
.unwrap();

// first map is the iterator, second is for the `Option<_>`
Ok(array.iter().map(|x| x.map(|x| $FUNC(x))).collect())
}
};
}

crypto_unary_string_function!(md5, md5_process);
crypto_unary_binary_function!(sha224, sha_process::<Sha224>);
crypto_unary_binary_function!(sha256, sha_process::<Sha256>);
crypto_unary_binary_function!(sha384, sha_process::<Sha384>);
crypto_unary_binary_function!(sha512, sha_process::<Sha512>);
133 changes: 117 additions & 16 deletions rust/datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use super::{
};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::array_expressions;
use crate::physical_plan::crypto_expressions;
use crate::physical_plan::datetime_expressions;
use crate::physical_plan::expressions::{nullif_func, SUPPORTED_NULLIF_TYPES};
use crate::physical_plan::math_expressions;
Expand Down Expand Up @@ -136,6 +137,16 @@ pub enum BuiltinScalarFunction {
NullIf,
/// Date truncate
DateTrunc,
/// MD5
MD5,
/// SHA224
SHA224,
/// SHA256,
SHA256,
/// SHA384
SHA384,
/// SHA512,
SHA512,
}

impl fmt::Display for BuiltinScalarFunction {
Expand Down Expand Up @@ -179,6 +190,11 @@ impl FromStr for BuiltinScalarFunction {
"date_trunc" => BuiltinScalarFunction::DateTrunc,
"array" => BuiltinScalarFunction::Array,
"nullif" => BuiltinScalarFunction::NullIf,
"md5" => BuiltinScalarFunction::MD5,
"sha224" => BuiltinScalarFunction::SHA224,
"sha256" => BuiltinScalarFunction::SHA256,
"sha384" => BuiltinScalarFunction::SHA384,
"sha512" => BuiltinScalarFunction::SHA512,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -288,6 +304,56 @@ pub fn return_type(
let coerced_types = data_types(arg_types, &signature(fun));
coerced_types.map(|typs| typs[0].clone())
}
BuiltinScalarFunction::MD5 => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::LargeUtf8,
DataType::Utf8 => DataType::Utf8,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The md5 function can only accept strings.".to_string(),
));
}
}),
BuiltinScalarFunction::SHA224 => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::Binary,
DataType::Utf8 => DataType::Binary,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The sha224 function can only accept strings.".to_string(),
));
}
}),
BuiltinScalarFunction::SHA256 => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::Binary,
DataType::Utf8 => DataType::Binary,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The sha256 function can only accept strings.".to_string(),
));
}
}),
BuiltinScalarFunction::SHA384 => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::Binary,
DataType::Utf8 => DataType::Binary,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The sha384 function can only accept strings.".to_string(),
));
}
}),
BuiltinScalarFunction::SHA512 => Ok(match arg_types[0] {
DataType::LargeUtf8 => DataType::Binary,
DataType::Utf8 => DataType::Binary,
_ => {
// this error is internal as `data_types` should have captured this.
return Err(DataFusionError::Internal(
"The sha512 function can only accept strings.".to_string(),
));
}
}),
_ => Ok(DataType::Float64),
}
}
Expand Down Expand Up @@ -318,6 +384,46 @@ pub fn create_physical_expr(
BuiltinScalarFunction::Abs => math_expressions::abs,
BuiltinScalarFunction::Signum => math_expressions::signum,
BuiltinScalarFunction::NullIf => nullif_func,
BuiltinScalarFunction::MD5 => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(crypto_expressions::md5::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::md5::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function md5",
other,
))),
},
BuiltinScalarFunction::SHA224 => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha224::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha224::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function sha224",
other,
))),
},
BuiltinScalarFunction::SHA256 => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha256::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha256::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function sha256",
other,
))),
},
BuiltinScalarFunction::SHA384 => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha384::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha384::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function sha384",
other,
))),
},
BuiltinScalarFunction::SHA512 => |args| match args[0].data_type() {
DataType::Utf8 => Ok(Arc::new(crypto_expressions::sha512::<i32>(args)?)),
DataType::LargeUtf8 => Ok(Arc::new(crypto_expressions::sha512::<i64>(args)?)),
other => Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function sha512",
other,
))),
},
BuiltinScalarFunction::Length => |args| Ok(length(args[0].as_ref())?),
BuiltinScalarFunction::Concat => {
|args| Ok(Arc::new(string_expressions::concatenate(args)?))
Expand Down Expand Up @@ -392,23 +498,18 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {

// for now, the list is small, as we do not have many built-in functions.
match fun {
BuiltinScalarFunction::Length => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
BuiltinScalarFunction::Lower => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Upper => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Trim => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Ltrim => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::Rtrim => {
BuiltinScalarFunction::Upper
| BuiltinScalarFunction::Lower
| BuiltinScalarFunction::Length
| BuiltinScalarFunction::Trim
| BuiltinScalarFunction::Ltrim
| BuiltinScalarFunction::Rtrim
| BuiltinScalarFunction::MD5
| BuiltinScalarFunction::SHA224
| BuiltinScalarFunction::SHA256
| BuiltinScalarFunction::SHA384
| BuiltinScalarFunction::SHA512 => {
Signature::Uniform(1, vec![DataType::Utf8, DataType::LargeUtf8])
}
BuiltinScalarFunction::ToTimestamp => Signature::Uniform(1, vec![DataType::Utf8]),
Expand Down
1 change: 1 addition & 0 deletions rust/datafusion/src/physical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ pub mod aggregates;
pub mod array_expressions;
pub mod coalesce_batches;
pub mod common;
pub mod crypto_expressions;
pub mod csv;
pub mod datetime_expressions;
pub mod distinct_expressions;
Expand Down
3 changes: 2 additions & 1 deletion rust/datafusion/src/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub use crate::dataframe::DataFrame;
pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
pub use crate::logical_plan::{
array, avg, col, concat, count, create_udf, in_list, length, lit, lower, ltrim, max,
min, rtrim, sum, trim, upper, JoinType, Partitioning,
md5, min, rtrim, sha224, sha256, sha384, sha512, sum, trim, upper, JoinType,
Partitioning,
};
pub use crate::physical_plan::csv::CsvReadOptions;
Loading

0 comments on commit 6da7718

Please sign in to comment.