diff --git a/crates/oxc_ast/src/ast_impl/js.rs b/crates/oxc_ast/src/ast_impl/js.rs index 759cb224a631b9..26006130775ed1 100644 --- a/crates/oxc_ast/src/ast_impl/js.rs +++ b/crates/oxc_ast/src/ast_impl/js.rs @@ -1007,6 +1007,13 @@ impl<'a> Function<'a> { } } +impl GetAddress for Function<'_> { + #[inline] + fn address(&self) -> Address { + Address::from_ptr(self) + } +} + impl<'a> FormalParameters<'a> { /// Number of parameters bound in this parameter list. pub fn parameters_count(&self) -> usize { diff --git a/crates/oxc_transformer/src/common/statement_injector.rs b/crates/oxc_transformer/src/common/statement_injector.rs index 4453d6cdc239c2..6b8a5b12313484 100644 --- a/crates/oxc_transformer/src/common/statement_injector.rs +++ b/crates/oxc_transformer/src/common/statement_injector.rs @@ -43,11 +43,13 @@ impl<'a, 'ctx> Traverse<'a> for StatementInjector<'a, 'ctx> { } } +#[derive(Debug)] enum Direction { Before, After, } +#[derive(Debug)] struct AdjacentStatement<'a> { stmt: Statement<'a>, direction: Direction, @@ -78,7 +80,7 @@ impl<'a> StatementInjectorStore<'a> { } /// Add a statement to be inserted immediately after the target statement. - pub fn insert_after(&self, target: &Statement<'a>, stmt: Statement<'a>) { + pub fn insert_after(&self, target: &A, stmt: Statement<'a>) { let mut insertions = self.insertions.borrow_mut(); let adjacent_stmts = insertions.entry(target.address()).or_default(); adjacent_stmts.push(AdjacentStatement { stmt, direction: Direction::After }); diff --git a/crates/oxc_transformer/src/jsx/refresh.rs b/crates/oxc_transformer/src/jsx/refresh.rs index 67547b609daba0..33934cdaad4e4e 100644 --- a/crates/oxc_transformer/src/jsx/refresh.rs +++ b/crates/oxc_transformer/src/jsx/refresh.rs @@ -1,5 +1,3 @@ -use std::iter::once; - use base64::prelude::{Engine, BASE64_STANDARD}; use rustc_hash::FxHashMap; use sha1::{Digest, Sha1}; @@ -107,7 +105,6 @@ pub struct ReactRefresh<'a, 'ctx> { /// Used to wrap call expression with signature. /// (eg: hoc(() => {}) -> _s1(hoc(_s1(() => {})))) last_signature: Option<(BindingIdentifier<'a>, oxc_allocator::Vec<'a, Argument<'a>>)>, - extra_statements: FxHashMap>>, // (function_scope_id, (hook_name, hook_key, custom_hook_callee) hook_calls: FxHashMap, Atom<'a>)>>, non_builtin_hooks_callee: FxHashMap>>>, @@ -127,7 +124,6 @@ impl<'a, 'ctx> ReactRefresh<'a, 'ctx> { registrations: Vec::default(), ctx, last_signature: None, - extra_statements: FxHashMap::default(), hook_calls: FxHashMap::default(), non_builtin_hooks_callee: FxHashMap::default(), } @@ -196,30 +192,18 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> { stmts: &mut oxc_allocator::Vec<'a, Statement<'a>>, ctx: &mut TraverseCtx<'a>, ) { - // TODO: check is there any function declaration - - let mut new_stmts = ctx.ast.vec_with_capacity(stmts.len() + 1); - let declarations = self.signature_declarator_items.pop().unwrap(); if !declarations.is_empty() { - new_stmts.push(Statement::from(ctx.ast.declaration_variable( - SPAN, - VariableDeclarationKind::Var, - declarations, - false, - ))); + stmts.insert( + 0, + Statement::from(ctx.ast.declaration_variable( + SPAN, + VariableDeclarationKind::Var, + declarations, + false, + )), + ); } - new_stmts.extend(stmts.drain(..).flat_map(move |stmt| { - let symbol_ids = get_symbol_id_from_function_and_declarator(&stmt); - let extra_stmts = symbol_ids - .into_iter() - .filter_map(|symbol_id| self.extra_statements.remove(&symbol_id)) - .flatten() - .collect::>(); - once(stmt).chain(extra_stmts) - })); - - *stmts = new_stmts; } fn exit_expression(&mut self, expr: &mut Expression<'a>, ctx: &mut TraverseCtx<'a>) { @@ -268,7 +252,6 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> { let first_argument = Argument::from(id_binding.create_read_expression(ctx)); arguments.insert(0, first_argument); - let statement = ctx.ast.statement_expression( SPAN, ctx.ast.expression_call( @@ -279,10 +262,19 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> { false, ), ); - self.extra_statements - .entry(id_binding.symbol_id) - .or_insert(ctx.ast.vec()) - .push(statement); + + let mut target_ancestor = ctx.ancestor(1); + for ancestor in ctx.ancestors().skip(2) { + if !matches!( + ancestor, + Ancestor::VariableDeclarationDeclarations(_) + | Ancestor::ExportNamedDeclarationDeclaration(_) + ) { + break; + } + target_ancestor = ancestor; + } + self.ctx.statement_injector.insert_after(&target_ancestor, statement); return; } } @@ -334,18 +326,27 @@ impl<'a, 'ctx> Traverse<'a> for ReactRefresh<'a, 'ctx> { arguments.insert(0, Argument::from(id_binding.create_read_expression(ctx))); let binding = BoundIdentifier::from_binding_ident(&binding_identifier); - self.extra_statements.entry(id_binding.symbol_id).or_insert(ctx.ast.vec()).push( - ctx.ast.statement_expression( - SPAN, - ctx.ast.expression_call( - SPAN, - binding.create_read_expression(ctx), - NONE, - arguments, - false, - ), - ), - ); + let callee = binding.create_read_expression(ctx); + let expr = ctx.ast.expression_call(SPAN, callee, NONE, arguments, false); + let statement = ctx.ast.statement_expression(SPAN, expr); + + let mut target_ancestor = Ancestor::None; + for ancestor in ctx.ancestors() { + if !matches!( + ancestor, + Ancestor::ExportNamedDeclarationDeclaration(_) + | Ancestor::ExportDefaultDeclarationDeclaration(_) + ) { + break; + } + target_ancestor = ancestor; + } + + if matches!(target_ancestor, Ancestor::None) { + self.ctx.statement_injector.insert_after(func, statement); + } else { + self.ctx.statement_injector.insert_after(&target_ancestor, statement); + } } fn enter_call_expression( @@ -898,42 +899,3 @@ fn is_builtin_hook(hook_name: &str) -> bool { "useOptimistic" ) } - -fn get_symbol_id_from_function_and_declarator(stmt: &Statement<'_>) -> Vec { - let mut symbol_ids = vec![]; - match stmt { - Statement::FunctionDeclaration(ref func) => { - if !func.is_typescript_syntax() { - symbol_ids.push(func.symbol_id().unwrap()); - } - } - Statement::VariableDeclaration(ref decl) => { - symbol_ids.extend(decl.declarations.iter().filter_map(|decl| { - decl.id.get_binding_identifier().and_then(|id| id.symbol_id.get()) - })); - } - Statement::ExportNamedDeclaration(ref export_decl) => { - if let Some(Declaration::FunctionDeclaration(func)) = &export_decl.declaration { - if !func.is_typescript_syntax() { - symbol_ids.push(func.symbol_id().unwrap()); - } - } else if let Some(Declaration::VariableDeclaration(decl)) = &export_decl.declaration { - symbol_ids.extend(decl.declarations.iter().filter_map(|decl| { - decl.id.get_binding_identifier().and_then(|id| id.symbol_id.get()) - })); - } - } - Statement::ExportDefaultDeclaration(ref export_decl) => { - if let ExportDefaultDeclarationKind::FunctionDeclaration(func) = - &export_decl.declaration - { - if let Some(id) = func.symbol_id() { - symbol_ids.push(id); - } - } - } - _ => {} - }; - - symbol_ids -}