diff --git a/tasks/ast_tools/src/generators/visit.rs b/tasks/ast_tools/src/generators/visit.rs index 4488e44aa553e7..e7d75f5cb95384 100644 --- a/tasks/ast_tools/src/generators/visit.rs +++ b/tasks/ast_tools/src/generators/visit.rs @@ -464,6 +464,7 @@ impl<'a> VisitBuilder<'a> { }; let mut enter_scope_at = 0; + let mut exit_scope_at: Option = None; let mut enter_node_at = 0; let fields_visits: Vec = struct_ .fields @@ -481,6 +482,7 @@ impl<'a> VisitBuilder<'a> { let visit_args = markers.visit.visit_args.clone(); let have_enter_scope = markers.scope.enter_before; + let have_exit_scope = markers.scope.exit_before; let have_enter_node = markers.visit.enter_before; let (args_def, args) = visit_args @@ -525,6 +527,18 @@ impl<'a> VisitBuilder<'a> { }; enter_scope_at = ix; } + if have_exit_scope { + assert!( + exit_scope_at.is_none(), + "Scopes cannot be exited more than once. Remove the extra `#[scope(exit_before)]` attribute(s)." + ); + let scope_leave = &scope_events.1; + result = quote! { + #scope_leave + #result + }; + exit_scope_at = Some(ix); + } #[expect(unreachable_code)] if have_enter_node { @@ -563,17 +577,25 @@ impl<'a> VisitBuilder<'a> { }, }; - let with_scope_events = |body: TokenStream| match (scope_events, enter_scope_at) { - ((enter, leave), 0) => quote! { - #enter - #body - #leave - }, - ((_, leave), _) => quote! { - #body - #leave - }, - }; + let with_scope_events = + |body: TokenStream| match (scope_events, enter_scope_at, exit_scope_at) { + ((enter, leave), 0, None) => quote! { + #enter + #body + #leave + }, + ((_, leave), _, None) => quote! { + #body + #leave + }, + ((enter, _), 0, Some(_)) => quote! { + #enter + #body + }, + ((_, _), _, Some(_)) => quote! { + #body + }, + }; let body = with_node_events(with_scope_events(quote!(#(#fields_visits)*))); diff --git a/tasks/ast_tools/src/markers.rs b/tasks/ast_tools/src/markers.rs index 4165e623b2e5ab..7d634c8bd8a009 100644 --- a/tasks/ast_tools/src/markers.rs +++ b/tasks/ast_tools/src/markers.rs @@ -64,7 +64,10 @@ pub struct VisitMarkers { /// A struct representing `#[scope(...)]` markers #[derive(Default, Debug)] pub struct ScopeMarkers { + /// `#[scope(enter_before)]` pub enter_before: bool, + /// `#[scope(exit_before)]` + pub exit_before: bool, } /// A struct representing all the helper attributes that might be used with `#[generate_derive(...)]` @@ -204,7 +207,10 @@ where || Ok(ScopeMarkers::default()), |attr| { attr.parse_args_with(Ident::parse) - .map(|id| ScopeMarkers { enter_before: id == "enter_before" }) + .map(|id| ScopeMarkers { + enter_before: id == "enter_before", + exit_before: id == "exit_before", + }) .normalize() }, )