Skip to content

Commit

Permalink
[compiler] Add lowerContextAccess pass
Browse files Browse the repository at this point in the history
*This is only for internal profiling, not intended to ship.*

This pass is intended to be used with #30407.

This pass synthesizes selector functions by collecting immediately
destructured context acesses. We bailout for other types of context
access.

This pass lowers context access to use a selector function by passing
the synthesized selector function as the second argument.

ghstack-source-id: 61ef1a29e98e3b5a2e5d70e5870956ba5fa6bd8d
Pull Request resolved: #30548
  • Loading branch information
gsathya committed Aug 6, 2024
1 parent eb1d52b commit 89e2305
Show file tree
Hide file tree
Showing 17 changed files with 575 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ import {
} from '../Validation';
import {validateLocalsNotReassignedAfterRender} from '../Validation/ValidateLocalsNotReassignedAfterRender';
import {outlineFunctions} from '../Optimization/OutlineFunctions';
import {lowerContextAccess} from '../Optimization/LowerContextAccess';

export type CompilerPipelineValue =
| {kind: 'ast'; name: string; value: CodegenFunction}
Expand Down Expand Up @@ -199,6 +200,10 @@ function* runWithEnvironment(
validateNoCapitalizedCalls(hir);
}

if (env.config.enableLowerContextAccess) {
lowerContextAccess(hir);
}

analyseFunctions(hir);
yield log({kind: 'hir', name: 'AnalyseFunctions', value: hir});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import {
ArrayExpression,
BasicBlock,
CallExpression,
Destructure,
Environment,
GeneratedSource,
HIRFunction,
IdentifierId,
Instruction,
LoadLocal,
Place,
PropertyLoad,
isUseContextHookType,
makeBlockId,
makeInstructionId,
markInstructionIds,
promoteTemporary,
reversePostorderBlocks,
} from '../HIR';
import {createTemporaryPlace} from '../HIR/HIRBuilder';
import {enterSSA} from '../SSA';
import {inferTypes} from '../TypeInference';

export function lowerContextAccess(fn: HIRFunction): void {
const contextAccess: Map<IdentifierId, CallExpression> = new Map();
const contextKeys: Map<IdentifierId, Array<string>> = new Map();

// collect context access and keys
for (const [, block] of fn.body.blocks) {
for (const instr of block.instructions) {
const {value, lvalue} = instr;

if (
value.kind === 'CallExpression' &&
isUseContextHookType(value.callee.identifier)
) {
contextAccess.set(lvalue.identifier.id, value);
continue;
}

if (value.kind !== 'Destructure') {
continue;
}

const destructureId = value.value.identifier.id;
if (!contextAccess.has(destructureId)) {
continue;
}

const keys = getContextKeys(value);
if (keys === null) {
return;
}

if (contextKeys.has(destructureId)) {
/*
* TODO(gsn): Add support for accessing context over multiple
* statements.
*/
return;
} else {
contextKeys.set(destructureId, keys);
}
}
}

if (contextAccess.size > 0) {
for (const [, block] of fn.body.blocks) {
let nextInstructions: Array<Instruction> | null = null;

for (let i = 0; i < block.instructions.length; i++) {
const instr = block.instructions[i];
const {lvalue, value} = instr;
if (
value.kind === 'CallExpression' &&
isUseContextHookType(value.callee.identifier) &&
contextKeys.has(lvalue.identifier.id)
) {
const keys = contextKeys.get(lvalue.identifier.id)!;
const selectorFnInstr = emitSelectorFn(fn.env, keys);
if (nextInstructions === null) {
nextInstructions = block.instructions.slice(0, i);
}
nextInstructions.push(selectorFnInstr);

const selectorFn = selectorFnInstr.lvalue;
value.args.push(selectorFn);
}

if (nextInstructions) {
nextInstructions.push(instr);
}
}
if (nextInstructions) {
block.instructions = nextInstructions;
}
}
markInstructionIds(fn.body);
}
}

function getContextKeys(value: Destructure): Array<string> | null {
const keys = [];
const pattern = value.lvalue.pattern;

switch (pattern.kind) {
case 'ArrayPattern': {
return null;
}

case 'ObjectPattern': {
for (const place of pattern.properties) {
if (
place.kind !== 'ObjectProperty' ||
place.type !== 'property' ||
place.key.kind !== 'identifier' ||
place.place.identifier.name === null ||
place.place.identifier.name.kind !== 'named'
) {
return null;
}
keys.push(place.key.name);
}
return keys;
}
}
}

function emitPropertyLoad(
env: Environment,
obj: Place,
property: string,
): {instructions: Array<Instruction>; element: Place} {
const loadObj: LoadLocal = {
kind: 'LoadLocal',
place: obj,
loc: GeneratedSource,
};
const object: Place = createTemporaryPlace(env, GeneratedSource);
const loadLocalInstr: Instruction = {
lvalue: object,
value: loadObj,
id: makeInstructionId(0),
loc: GeneratedSource,
};

const loadProp: PropertyLoad = {
kind: 'PropertyLoad',
object,
property,
loc: GeneratedSource,
};
const element: Place = createTemporaryPlace(env, GeneratedSource);
const loadPropInstr: Instruction = {
lvalue: element,
value: loadProp,
id: makeInstructionId(0),
loc: GeneratedSource,
};
return {
instructions: [loadLocalInstr, loadPropInstr],
element: element,
};
}

function emitSelectorFn(env: Environment, keys: Array<string>): Instruction {
const obj: Place = createTemporaryPlace(env, GeneratedSource);
promoteTemporary(obj.identifier);
const instr: Array<Instruction> = [];
const elements = [];
for (const key of keys) {
const {instructions, element: prop} = emitPropertyLoad(env, obj, key);
instr.push(...instructions);
elements.push(prop);
}

const arrayInstr = emitArrayInstr(elements, env);
instr.push(arrayInstr);

const block: BasicBlock = {
kind: 'block',
id: makeBlockId(0),
instructions: instr,
terminal: {
id: makeInstructionId(0),
kind: 'return',
loc: GeneratedSource,
value: arrayInstr.lvalue,
},
preds: new Set(),
phis: new Set(),
};

const fn: HIRFunction = {
loc: GeneratedSource,
id: null,
fnType: 'Other',
env,
params: [obj],
returnType: null,
context: [],
effects: null,
body: {
entry: block.id,
blocks: new Map([[block.id, block]]),
},
generator: false,
async: false,
directives: [],
};

reversePostorderBlocks(fn.body);
markInstructionIds(fn.body);
enterSSA(fn);
inferTypes(fn);

const fnInstr: Instruction = {
id: makeInstructionId(0),
value: {
kind: 'FunctionExpression',
name: null,
loweredFunc: {
func: fn,
dependencies: [],
},
type: 'ArrowFunctionExpression',
loc: GeneratedSource,
},
lvalue: createTemporaryPlace(env, GeneratedSource),
loc: GeneratedSource,
};
return fnInstr;
}

function emitArrayInstr(elements: Array<Place>, env: Environment): Instruction {
const array: ArrayExpression = {
kind: 'ArrayExpression',
elements,
loc: GeneratedSource,
};
const arrayLvalue: Place = createTemporaryPlace(env, GeneratedSource);
const arrayInstr: Instruction = {
id: makeInstructionId(0),
value: array,
lvalue: arrayLvalue,
loc: GeneratedSource,
};
return arrayInstr;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

## Input

```javascript
// @enableLowerContextAccess
function App() {
const {foo} = useContext(MyContext);
const {bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}

```

## Code

```javascript
import { c as _c } from "react/compiler-runtime"; // @enableLowerContextAccess
function App() {
const $ = _c(3);
const { foo } = useContext(MyContext, _temp);
const { bar } = useContext(MyContext, _temp2);
let t0;
if ($[0] !== foo || $[1] !== bar) {
t0 = <Bar foo={foo} bar={bar} />;
$[0] = foo;
$[1] = bar;
$[2] = t0;
} else {
t0 = $[2];
}
return t0;
}
function _temp2(t0) {
return [t0.bar];
}
function _temp(t0) {
return [t0.foo];
}

```
### Eval output
(kind: exception) Fixture not implemented
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// @enableLowerContextAccess
function App() {
const {foo} = useContext(MyContext);
const {bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

## Input

```javascript
// @enableLowerContextAccess
function App() {
const {foo, bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}

```

## Code

```javascript
import { c as _c } from "react/compiler-runtime"; // @enableLowerContextAccess
function App() {
const $ = _c(3);
const { foo, bar } = useContext(MyContext, _temp);
let t0;
if ($[0] !== foo || $[1] !== bar) {
t0 = <Bar foo={foo} bar={bar} />;
$[0] = foo;
$[1] = bar;
$[2] = t0;
} else {
t0 = $[2];
}
return t0;
}
function _temp(t0) {
return [t0.foo, t0.bar];
}

```
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// @enableLowerContextAccess
function App() {
const {foo, bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}
Loading

0 comments on commit 89e2305

Please sign in to comment.