Skip to content

Commit

Permalink
feat(core): introduce data control attributes for wren MDL base (#1014)
Browse files Browse the repository at this point in the history
  • Loading branch information
goldmedal authored Dec 26, 2024
1 parent f79314f commit cd89fd1
Show file tree
Hide file tree
Showing 6 changed files with 640 additions and 4 deletions.
145 changes: 145 additions & 0 deletions wren-core-base/manifest-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ pub fn column(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStrea
pub expression: Option<String>,
#[serde(default, with = "bool_from_int")]
pub is_hidden: bool,
pub rls: Option<RowLevelSecurity>,
pub cls: Option<ColumnLevelSecurity>,
}
};
proc_macro::TokenStream::from(expanded)
Expand Down Expand Up @@ -342,3 +344,146 @@ pub fn view(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn row_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
pub struct RowLevelSecurity {
pub name: String,
pub operator: RowLevelOperator,
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn row_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass(eq, eq_int)]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum RowLevelOperator {
Equals,
NotEquals,
GreaterThan,
LessThan,
GreaterThanOrEquals,
LessThanOrEquals,
IN,
NotIn,
LIKE,
NotLike,
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn column_level_security(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash)]
pub struct ColumnLevelSecurity {
pub name: String,
pub operator: ColumnLevelOperator,
pub threshold: NormalizedExpr,
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn column_level_operator(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass(eq, eq_int)]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ColumnLevelOperator {
Equals,
NotEquals,
GreaterThan,
LessThan,
GreaterThanOrEquals,
LessThanOrEquals,
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn normalized_expr(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(SerializeDisplay, DeserializeFromStr, Debug, PartialEq, Eq, Hash)]
pub struct NormalizedExpr {
pub value: String,
#[serde_with(alias = "type")]
pub data_type: NormalizedExprType,
}
};
proc_macro::TokenStream::from(expanded)
}

#[proc_macro]
pub fn normalized_expr_type(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(python_binding as LitBool);
let python_binding = if input.value {
quote! {
#[pyclass(eq, eq_int)]
}
} else {
quote! {}
};
let expanded = quote! {
#python_binding
#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum NormalizedExprType {
Numeric,
String,
}
};
proc_macro::TokenStream::from(expanded)
}
45 changes: 45 additions & 0 deletions wren-core-base/src/mdl/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
use crate::mdl::manifest::{
Column, DataSource, JoinType, Manifest, Metric, Model, Relationship, TimeGrain, TimeUnit, View,
};
use crate::mdl::{
ColumnLevelOperator, ColumnLevelSecurity, NormalizedExpr, RowLevelOperator, RowLevelSecurity,
};
use std::sync::Arc;

/// A builder for creating a Manifest
Expand Down Expand Up @@ -165,6 +168,8 @@ impl ColumnBuilder {
is_hidden: false,
not_null: false,
expression: None,
rls: None,
cls: None,
},
}
}
Expand Down Expand Up @@ -202,6 +207,27 @@ impl ColumnBuilder {
self
}

pub fn row_level_security(mut self, name: &str, operator: RowLevelOperator) -> Self {
self.column.rls = Some(RowLevelSecurity {
name: name.to_string(),
operator,
});
self
}

pub fn column_level_security(
mut self,
name: &str,
operator: ColumnLevelOperator,
threshold: &str,
) -> Self {
self.column.cls = Some(ColumnLevelSecurity {
name: name.to_string(),
operator,
threshold: NormalizedExpr::new(threshold),
});
self
}
pub fn build(self) -> Arc<Column> {
Arc::new(self.column)
}
Expand Down Expand Up @@ -356,6 +382,7 @@ mod test {
use crate::mdl::manifest::{
Column, DataSource, JoinType, Manifest, Metric, Model, Relationship, TimeUnit, View,
};
use crate::mdl::{ColumnLevelOperator, RowLevelOperator};
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
Expand All @@ -368,6 +395,8 @@ mod test {
.not_null(true)
.hidden(true)
.expression("test")
.row_level_security("SESSION_STATUS", RowLevelOperator::Equals)
.column_level_security("SESSION_LEVEL", ColumnLevelOperator::Equals, "'NORMAL'")
.build();

let json_str = serde_json::to_string(&expected).unwrap();
Expand Down Expand Up @@ -661,6 +690,22 @@ mod test {
.calculated(true)
.build(),
)
.column(
ColumnBuilder::new("rls_orderkey", "integer")
.row_level_security("SESSION_STATUS", RowLevelOperator::Equals)
.expression("o_orderkey")
.build(),
)
.column(
ColumnBuilder::new("cls_orderkey", "integer")
.column_level_security(
"SESSION_LEVEL",
ColumnLevelOperator::Equals,
"'NORMAL'",
)
.expression("o_orderkey")
.build(),
)
.primary_key("o_orderkey")
.build(),
)
Expand Down
Loading

0 comments on commit cd89fd1

Please sign in to comment.