Add derive macro for SiNodeWeight
jhelwig committed Oct 9, 2024
1 parent c1da56e commit eadb5e6
Showing 9 changed files with 420 additions and 86 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock


5 changes: 5 additions & 0 deletions lib/dal-macros/Cargo.toml
@@ -19,4 +19,9 @@ quote = { workspace = true }
syn = { workspace = true }

[dev-dependencies]
dal = { path = "../dal" }
proc-macro2 = { workspace = true }
quote = { workspace = true }
si-events = { path = "../si-events-rs" }
syn = { workspace = true }
trybuild = { version = "1.0.99", features = ["diff"] }
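
The trybuild dev-dependency drives compile-pass/compile-fail UI tests like the ones added below. The harness file itself is not among the diffs shown here; a minimal sketch of the conventional trybuild setup (the file name and glob are assumptions, not part of this commit) would be:

// Hypothetical harness, e.g. lib/dal-macros/tests/compile_tests.rs (not shown in this diff).
#[test]
fn ui() {
    let t = trybuild::TestCases::new();
    // Every *-pass.rs file under tests/ui must compile cleanly.
    t.pass("tests/ui/*-pass.rs");
}
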
9 changes: 9 additions & 0 deletions lib/dal-macros/src/lib.rs
@@ -10,3 +10,12 @@ pub fn si_versioned_node_weight(
) -> manyhow::Result<proc_macro::TokenStream> {
    node_weight::versioned::derive_si_versioned_node_weight(input, errors)
}

#[manyhow]
#[proc_macro_derive(SiNodeWeight, attributes(si_node_weight))]
pub fn si_node_weight(
    input: proc_macro::TokenStream,
    errors: &mut manyhow::Emitter,
) -> manyhow::Result<proc_macro::TokenStream> {
    node_weight::derive_si_node_weight(input, errors)
}
292 changes: 292 additions & 0 deletions lib/dal-macros/src/node_weight.rs
@@ -1 +1,293 @@
use std::collections::HashSet;

use darling::{util::IdentString, FromAttributes, FromMeta};
use manyhow::{bail, emit};
use proc_macro2::TokenStream;
use quote::{quote, ToTokens};
use syn::{Data, DeriveInput, Expr, Path};

pub(crate) mod versioned;

#[derive(Debug, PartialEq, Eq, Hash)]
enum SkipOption {
    ContentHash,
    Id,
    LineageId,
    MerkleTreeHash,
    NodeHash,
    NodeWeightDiscriminants,
    SetId,
    SetLineageId,
    SetMerkleTreeHash,
}

impl FromMeta for SkipOption {
    fn from_string(value: &str) -> darling::Result<Self> {
        match value {
            "content_hash" => Ok(Self::ContentHash),
            "id" => Ok(Self::Id),
            "lineage_id" => Ok(Self::LineageId),
            "merkle_tree_hash" => Ok(Self::MerkleTreeHash),
            "node_hash" => Ok(Self::NodeHash),
            "node_weight_discriminants" => Ok(Self::NodeWeightDiscriminants),
            "set_id" => Ok(Self::SetId),
            "set_lineage_id" => Ok(Self::SetLineageId),
            "set_merkle_tree_hash" => Ok(Self::SetMerkleTreeHash),
            v => Err(darling::Error::unknown_value(v)),
        }
    }
}

#[derive(Debug, Default)]
struct SkipOptionSet {
    options: HashSet<SkipOption>,
}

impl SkipOptionSet {
    fn should_skip(&self, option: SkipOption) -> bool {
        self.options.contains(&option)
    }

    fn add_skip_option(&mut self, option: SkipOption) {
        self.options.insert(option);
    }
}

impl FromMeta for SkipOptionSet {
    fn from_list(items: &[darling::ast::NestedMeta]) -> darling::Result<Self> {
        let mut new_option_set = SkipOptionSet::default();
        for item in items {
            new_option_set.add_skip_option(SkipOption::from_nested_meta(item)?);
        }

        Ok(new_option_set)
    }
}

#[derive(FromAttributes)]
#[darling(attributes(si_node_weight))]
struct SiNodeWeightOptions {
    discriminant: Option<Path>,
    #[darling(default)]
    skip: SkipOptionSet,
}

#[derive(Debug, Default)]
struct NodeHashParticipation {
    participant: bool,
    custom_code: Option<String>,
}

impl FromMeta for NodeHashParticipation {
    fn from_word() -> darling::Result<Self> {
        Ok(Self {
            participant: true,
            custom_code: None,
        })
    }

    fn from_value(value: &syn::Lit) -> darling::Result<Self> {
        match value {
            syn::Lit::Str(custom_str) => Ok(Self {
                participant: true,
                custom_code: Some(custom_str.value()),
            }),
            _ => Err(darling::Error::unexpected_lit_type(value)),
        }
    }
}

#[derive(Debug, Default, FromAttributes)]
#[darling(attributes(si_node_weight))]
struct NodeWeightFieldAttrs {
    #[darling(default, rename = "node_hash")]
    node_hash_participation: NodeHashParticipation,
}

pub fn derive_si_node_weight(
    input: proc_macro::TokenStream,
    errors: &mut manyhow::Emitter,
) -> manyhow::Result<proc_macro::TokenStream> {
    let input = syn::parse::<DeriveInput>(input)?;
    let DeriveInput {
        ident,
        data: type_data,
        attrs,
        ..
    } = input.clone();
    let struct_options = SiNodeWeightOptions::from_attributes(&attrs)?;

    let node_weight_discriminant = if let Some(discriminant) = struct_options.discriminant {
        discriminant
    } else {
        emit!(errors, input, "No NodeWeightDiscriminants was specified.");
        syn::parse_str::<Path>("")?
    };

    let struct_data = match &type_data {
        Data::Struct(data) => data,
        _ => {
            bail!(input, "SiNodeWeight must be derived on a struct.");
        }
    };
    let mut struct_has_content_address = false;
    let mut node_hash_parts = Vec::new();
    for field in &struct_data.fields {
        if let Some(field_ident) = &field.ident {
            let ident_string: IdentString = field_ident.clone().into();
            if ident_string == "content_address" {
                struct_has_content_address = true;
            }
        } else {
            emit!(errors, field, "No identifier found for field.");
            continue;
        }
        let field_attrs = NodeWeightFieldAttrs::from_attributes(&field.attrs)?;
        if field_attrs.node_hash_participation.participant {
            let hasher_update =
                if let Some(custom_code) = field_attrs.node_hash_participation.custom_code {
                    match syn::parse_str::<Expr>(&custom_code) {
                        Ok(expr) => expr.to_token_stream(),
                        Err(e) => {
                            emit!(errors, syn::Error::new_spanned(field, e.to_string()));
                            continue;
                        }
                    }
                } else {
                    let field_ident = field.ident.clone();
                    quote! {
                        self.#field_ident.as_bytes()
                    }
                };
            node_hash_parts.push(hasher_update);
        }
    }
    errors.into_result()?;

    let id_fn = if struct_options.skip.should_skip(SkipOption::Id) {
        quote! {}
    } else {
        quote! {
            fn id(&self) -> Ulid {
                self.id
            }
        }
    };

    let lineage_id_fn = if struct_options.skip.should_skip(SkipOption::LineageId) {
        quote! {}
    } else {
        quote! {
            fn lineage_id(&self) -> Ulid {
                self.lineage_id
            }
        }
    };

    let content_hash_fn = if struct_options.skip.should_skip(SkipOption::ContentHash) {
        quote! {}
    } else {
        let content_hash_location = if struct_has_content_address {
            quote! { self.content_address.content_hash() }
        } else {
            quote! { self.node_hash() }
        };

        quote! {
            fn content_hash(&self) -> ContentHash {
                #content_hash_location
            }
        }
    };

    let merkle_tree_hash_fn = if struct_options.skip.should_skip(SkipOption::MerkleTreeHash) {
        quote! {}
    } else {
        quote! {
            fn merkle_tree_hash(&self) -> MerkleTreeHash {
                self.merkle_tree_hash
            }
        }
    };

    let node_hash_fn = if struct_options.skip.should_skip(SkipOption::NodeHash) {
        quote! {}
    } else {
        let mut hash_updates = TokenStream::new();
        for update in node_hash_parts {
            let full_update = quote! { content_hasher.update(#update); };
            hash_updates.extend(full_update);
        }

        quote! {
            fn node_hash(&self) -> ContentHash {
                let mut content_hasher = ContentHash::hasher();
                #hash_updates

                content_hasher.finalize()
            }
        }
    };

    let node_weight_discriminant_fn = if struct_options
        .skip
        .should_skip(SkipOption::NodeWeightDiscriminants)
    {
        quote! {}
    } else {
        quote! {
            fn node_weight_discriminant(&self) -> NodeWeightDiscriminants {
                #node_weight_discriminant
            }
        }
    };

    let set_id_fn = if struct_options.skip.should_skip(SkipOption::SetId) {
        quote! {}
    } else {
        quote! {
            fn set_id(&mut self, new_id: Ulid) {
                self.id = new_id;
            }
        }
    };

    let set_lineage_id_fn = if struct_options.skip.should_skip(SkipOption::SetLineageId) {
        quote! {}
    } else {
        quote! {
            fn set_lineage_id(&mut self, new_lineage_id: Ulid) {
                self.lineage_id = new_lineage_id;
            }
        }
    };

    let set_merkle_tree_hash_fn = if struct_options
        .skip
        .should_skip(SkipOption::SetMerkleTreeHash)
    {
        quote! {}
    } else {
        quote! {
            fn set_merkle_tree_hash(&mut self, new_merkle_tree_hash: MerkleTreeHash) {
                self.merkle_tree_hash = new_merkle_tree_hash;
            }
        }
    };

    let output = quote! {
        impl SiNodeWeight for #ident {
            #id_fn
            #lineage_id_fn
            #content_hash_fn
            #merkle_tree_hash_fn
            #node_hash_fn
            #node_weight_discriminant_fn
            #set_id_fn
            #set_lineage_id_fn
            #set_merkle_tree_hash_fn
        }
    };

    Ok(output.into())
}
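
The UI test below only exercises the defaults. For a fuller picture of the attribute surface implemented above, a hypothetical node weight (not from this diff) using skip and node_hash might look roughly like the following. The exact syntax is inferred from the FromMeta impls above, the ContentAddress field type is an assumption, and whether the result compiles depends on the SiNodeWeight trait definition in dal, which is not part of this diff:

// Hypothetical example of the struct- and field-level attributes implemented above.
// skip(...) options are matched by SkipOption::from_string, so they appear as string
// literals; each skipped option drops the corresponding method from the generated impl.
#[derive(dal_macros::SiNodeWeight)]
#[si_node_weight(
    discriminant = NodeWeightDiscriminants::InputSocket,
    skip("set_merkle_tree_hash")
)]
pub struct ExampleNodeWeight {
    id: Ulid,
    lineage_id: Ulid,
    merkle_tree_hash: MerkleTreeHash,
    // Bare `node_hash` makes the generated node_hash() hash `self.name.as_bytes()`.
    #[si_node_weight(node_hash)]
    name: String,
    // A string value is parsed as a custom expression and fed to the hasher instead;
    // a field named `content_address` also switches content_hash() to use the content address.
    #[si_node_weight(node_hash = "self.content_address.content_hash().as_bytes()")]
    content_address: ContentAddress,
}

Each skip option simply replaces the corresponding generated method with an empty token stream, so a skipped accessor is omitted from the impl block entirely.
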
25 changes: 25 additions & 0 deletions lib/dal-macros/tests/ui/06-node_weight-basic-tests-pass.rs
@@ -0,0 +1,25 @@
use dal::workspace_snapshot::{
    edge_weight::EdgeWeightKindDiscriminants,
    node_weight::{
        traits::{CorrectExclusiveOutgoingEdge, CorrectTransforms, SiNodeWeight},
        NodeWeightDiscriminants,
    },
};
use si_events::{merkle_tree_hash::MerkleTreeHash, ulid::Ulid, ContentHash};

#[derive(dal_macros::SiNodeWeight)]
#[si_node_weight(discriminant = NodeWeightDiscriminants::InputSocket)]
pub struct TestingNodeWeight {
    id: Ulid,
    lineage_id: Ulid,
    merkle_tree_hash: MerkleTreeHash,
}

impl CorrectTransforms for TestingNodeWeight {}
impl CorrectExclusiveOutgoingEdge for TestingNodeWeight {
    fn exclusive_outgoing_edges(&self) -> &[EdgeWeightKindDiscriminants] {
        todo!()
    }
}

fn main() {}
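
For reference, reading the quote! templates in node_weight.rs above, the derive on TestingNodeWeight expands to an impl along these lines (hand-reconstructed; not rustc's actual expansion output):

impl SiNodeWeight for TestingNodeWeight {
    fn id(&self) -> Ulid {
        self.id
    }

    fn lineage_id(&self) -> Ulid {
        self.lineage_id
    }

    // No field is marked `node_hash` and there is no `content_address` field,
    // so content_hash() falls back to node_hash(), which hashes nothing.
    fn content_hash(&self) -> ContentHash {
        self.node_hash()
    }

    fn merkle_tree_hash(&self) -> MerkleTreeHash {
        self.merkle_tree_hash
    }

    fn node_hash(&self) -> ContentHash {
        let mut content_hasher = ContentHash::hasher();

        content_hasher.finalize()
    }

    fn node_weight_discriminant(&self) -> NodeWeightDiscriminants {
        NodeWeightDiscriminants::InputSocket
    }

    fn set_id(&mut self, new_id: Ulid) {
        self.id = new_id;
    }

    fn set_lineage_id(&mut self, new_lineage_id: Ulid) {
        self.lineage_id = new_lineage_id;
    }

    fn set_merkle_tree_hash(&mut self, new_merkle_tree_hash: MerkleTreeHash) {
        self.merkle_tree_hash = new_merkle_tree_hash;
    }
}
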
@@ -0,0 +1,25 @@
use dal::workspace_snapshot::{
    edge_weight::EdgeWeightKindDiscriminants,
    node_weight::{
        traits::{CorrectExclusiveOutgoingEdge, CorrectTransforms, SiNodeWeight},
        NodeWeightDiscriminants,
    },
};
use si_events::{merkle_tree_hash::MerkleTreeHash, ulid::Ulid, ContentHash};

#[derive(dal_macros::SiNodeWeight)]
#[si_node_weight]
pub struct TestingNodeWeight {
    id: Ulid,
    lineage_id: Ulid,
    merkle_tree_hash: MerkleTreeHash,
}

impl CorrectTransforms for TestingNodeWeight {}
impl CorrectExclusiveOutgoingEdge for TestingNodeWeight {
    fn exclusive_outgoing_edges(&self) -> &[EdgeWeightKindDiscriminants] {
        &[]
    }
}

fn main() {}