diff --git a/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts b/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts index 6d231919a6693..97ec36626b67b 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts @@ -70,7 +70,9 @@ import { } from "../ReactiveScopes"; import { alignMethodCallScopes } from "../ReactiveScopes/AlignMethodCallScopes"; import { alignReactiveScopesToBlockScopesHIR } from "../ReactiveScopes/AlignReactiveScopesToBlockScopesHIR"; +import { flattenReactiveLoopsHIR } from "../ReactiveScopes/FlattenReactiveLoopsHIR"; import { pruneAlwaysInvalidatingScopes } from "../ReactiveScopes/PruneAlwaysInvalidatingScopes"; +import pruneInitializationDependencies from "../ReactiveScopes/PruneInitializationDependencies"; import { stabilizeBlockIds } from "../ReactiveScopes/StabilizeBlockIds"; import { eliminateRedundantPhi, enterSSA, leaveSSA } from "../SSA"; import { inferTypes } from "../TypeInference"; @@ -91,7 +93,6 @@ import { validatePreservedManualMemoization, validateUseMemo, } from "../Validation"; -import pruneInitializationDependencies from "../ReactiveScopes/PruneInitializationDependencies"; export type CompilerPipelineValue = | { kind: "ast"; name: string; value: CodegenFunction } @@ -281,6 +282,13 @@ function* runWithEnvironment( }); assertValidBlockNesting(hir); + + flattenReactiveLoopsHIR(hir); + yield log({ + kind: "hir", + name: "FlattenReactiveLoopsHIR", + value: hir, + }); } const reactiveFunction = buildReactiveFunction(hir); @@ -320,14 +328,14 @@ function* runWithEnvironment( name: "BuildReactiveBlocks", value: reactiveFunction, }); - } - flattenReactiveLoops(reactiveFunction); - yield log({ - kind: "reactive", - name: "FlattenReactiveLoops", - value: reactiveFunction, - }); + flattenReactiveLoops(reactiveFunction); + yield log({ + kind: "reactive", + name: "FlattenReactiveLoops", + value: reactiveFunction, + }); + } assertScopeInstructionsWithinScopes(reactiveFunction); diff --git a/compiler/packages/babel-plugin-react-compiler/src/ReactiveScopes/FlattenReactiveLoopsHIR.ts b/compiler/packages/babel-plugin-react-compiler/src/ReactiveScopes/FlattenReactiveLoopsHIR.ts new file mode 100644 index 0000000000000..6fa65ee2bf257 --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/ReactiveScopes/FlattenReactiveLoopsHIR.ts @@ -0,0 +1,71 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import { BlockId, HIRFunction, PrunedScopeTerminal } from "../HIR"; +import { assertExhaustive, retainWhere } from "../Utils/utils"; + +/** + * Prunes any reactive scopes that are within a loop (for, while, etc). We don't yet + * support memoization within loops because this would require an extra layer of reconciliation + * (plus a way to identify values across runs, similar to how we use `key` in JSX for lists). + * Eventually we may integrate more deeply into the runtime so that we can do a single level + * of reconciliation, but for now we've found it's sufficient to memoize *around* the loop. + */ +export function flattenReactiveLoopsHIR(fn: HIRFunction): void { + const activeLoops = Array(); + for (const [, block] of fn.body.blocks) { + retainWhere(activeLoops, (id) => id !== block.id); + const { terminal } = block; + switch (terminal.kind) { + case "do-while": + case "for": + case "for-in": + case "for-of": + case "while": { + activeLoops.push(terminal.fallthrough); + break; + } + case "scope": { + if (activeLoops.length !== 0) { + block.terminal = { + kind: "pruned-scope", + block: terminal.block, + fallthrough: terminal.fallthrough, + id: terminal.id, + loc: terminal.loc, + scope: terminal.scope, + } as PrunedScopeTerminal; + } + break; + } + case "branch": + case "goto": + case "if": + case "label": + case "logical": + case "maybe-throw": + case "optional": + case "pruned-scope": + case "return": + case "sequence": + case "switch": + case "ternary": + case "throw": + case "try": + case "unreachable": + case "unsupported": { + break; + } + default: { + assertExhaustive( + terminal, + `Unexpected terminal kind \`${(terminal as any).kind}\`` + ); + } + } + } +}