diff --git a/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts b/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts index 0752894d94f..6d231919a66 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/Entrypoint/Pipeline.ts @@ -91,6 +91,7 @@ import { validatePreservedManualMemoization, validateUseMemo, } from "../Validation"; +import pruneInitializationDependencies from "../ReactiveScopes/PruneInitializationDependencies"; export type CompilerPipelineValue = | { kind: "ast"; name: string; value: CodegenFunction } @@ -379,6 +380,15 @@ function* runWithEnvironment( value: reactiveFunction, }); + if (env.config.enableChangeDetectionForDebugging != null) { + pruneInitializationDependencies(reactiveFunction); + yield log({ + kind: "reactive", + name: "PruneInitializationDependencies", + value: reactiveFunction, + }); + } + propagateEarlyReturns(reactiveFunction); yield log({ kind: "reactive", diff --git a/compiler/packages/babel-plugin-react-compiler/src/ReactiveScopes/PruneInitializationDependencies.ts b/compiler/packages/babel-plugin-react-compiler/src/ReactiveScopes/PruneInitializationDependencies.ts new file mode 100644 index 00000000000..b9939addcf1 --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/ReactiveScopes/PruneInitializationDependencies.ts @@ -0,0 +1,290 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +import { CompilerError } from "../CompilerError"; +import { + Environment, + Identifier, + IdentifierId, + InstructionId, + Place, + ReactiveBlock, + ReactiveFunction, + ReactiveInstruction, + ReactiveScopeBlock, + ReactiveTerminalStatement, + getHookKind, + isUseRefType, + isUseStateType, +} from "../HIR"; +import { eachCallArgument, eachInstructionLValue } from "../HIR/visitors"; +import DisjointSet from "../Utils/DisjointSet"; +import { assertExhaustive } from "../Utils/utils"; +import { ReactiveFunctionVisitor, visitReactiveFunction } from "./visitors"; + +/** + * This pass is built based on the observation by @jbrown215 that arguments + * to useState and useRef are only used the first time a component is rendered. + * Any subsequent times, the arguments will be evaluated but ignored. In this pass, + * we use this fact to improve the output of the compiler by not recomputing values that + * are only used as arguments (or inputs to arguments to) useState and useRef. + * + * This pass isn't yet stress-tested so it's not enabled by default. It's only enabled + * to support certain debug modes that detect non-idempotent code, since non-idempotent + * code can "safely" be used if its only passed to useState and useRef. We plan to rewrite + * this pass in HIR and enable it as an optimization in the future. + * + * Algorithm: + * We take two passes over the reactive function AST. In the first pass, we gather + * aliases and build relationships between property accesses--the key thing we need + * to do here is to find that, e.g., $0.x and $1 refer to the same value if + * $1 = PropertyLoad $0.x. + * + * In the second pass, we traverse the AST in reverse order and track how each place + * is used. If a place is read from in any Terminal, we mark the place as "Update", meaning + * it is used whenever the component is updated/re-rendered. If a place is read from in + * a useState or useRef hook call, we mark it as "Create", since it is only used when the + * component is created. In other instructions, we propagate the inferred place for the + * instructions lvalues onto any other instructions that are read. + * + * Whenever we finish this reverse pass over a reactive block, we can look at the blocks + * dependencies and see whether the dependencies are used in an "Update" context or only + * in a "Create" context. If a dependency is create-only, then we can remove that dependency + * from the block. + */ + +type CreateUpdate = "Create" | "Update" | "Unknown"; + +type KindMap = Map; + +class Visitor extends ReactiveFunctionVisitor { + map: KindMap = new Map(); + aliases: DisjointSet; + paths: Map>; + env: Environment; + + constructor( + env: Environment, + aliases: DisjointSet, + paths: Map> + ) { + super(); + this.aliases = aliases; + this.paths = paths; + this.env = env; + } + + join(values: Array): CreateUpdate { + function join2(l: CreateUpdate, r: CreateUpdate): CreateUpdate { + if (l === "Update" || r === "Update") { + return "Update"; + } else if (l === "Create" || r === "Create") { + return "Create"; + } else if (l === "Unknown" || r === "Unknown") { + return "Unknown"; + } + assertExhaustive(r, `Unhandled variable kind ${r}`); + } + return values.reduce(join2, "Unknown"); + } + + isCreateOnlyHook(id: Identifier): boolean { + return isUseStateType(id) || isUseRefType(id); + } + + override visitPlace( + _: InstructionId, + place: Place, + state: CreateUpdate + ): void { + this.map.set( + place.identifier.id, + this.join([state, this.map.get(place.identifier.id) ?? "Unknown"]) + ); + } + + override visitBlock(block: ReactiveBlock, state: CreateUpdate): void { + super.visitBlock([...block].reverse(), state); + } + + override visitInstruction(instruction: ReactiveInstruction): void { + const state = this.join( + [...eachInstructionLValue(instruction)].map( + (operand) => this.map.get(operand.identifier.id) ?? "Unknown" + ) + ); + + const visitCallOrMethodNonArgs = (): void => { + switch (instruction.value.kind) { + case "CallExpression": { + this.visitPlace(instruction.id, instruction.value.callee, state); + break; + } + case "MethodCall": { + this.visitPlace(instruction.id, instruction.value.property, state); + this.visitPlace(instruction.id, instruction.value.receiver, state); + break; + } + } + }; + + const isHook = (): boolean => { + let callee = null; + switch (instruction.value.kind) { + case "CallExpression": { + callee = instruction.value.callee.identifier; + break; + } + case "MethodCall": { + callee = instruction.value.property.identifier; + break; + } + } + return callee != null && getHookKind(this.env, callee) != null; + }; + + switch (instruction.value.kind) { + case "CallExpression": + case "MethodCall": { + if ( + instruction.lvalue && + this.isCreateOnlyHook(instruction.lvalue.identifier) + ) { + [...eachCallArgument(instruction.value.args)].forEach((operand) => + this.visitPlace(instruction.id, operand, "Create") + ); + visitCallOrMethodNonArgs(); + } else { + this.traverseInstruction(instruction, isHook() ? "Update" : state); + } + break; + } + default: { + this.traverseInstruction(instruction, state); + } + } + } + + override visitScope(scope: ReactiveScopeBlock): void { + const state = this.join( + [ + ...scope.scope.declarations.keys(), + ...[...scope.scope.reassignments.values()].map((ident) => ident.id), + ].map((id) => this.map.get(id) ?? "Unknown") + ); + super.visitScope(scope, state); + [...scope.scope.dependencies].forEach((ident) => { + let target: undefined | IdentifierId = + this.aliases.find(ident.identifier.id) ?? ident.identifier.id; + ident.path.forEach((key) => { + target &&= this.paths.get(target)?.get(key); + }); + if (target && this.map.get(target) === "Create") { + scope.scope.dependencies.delete(ident); + } + }); + } + + override visitTerminal( + stmt: ReactiveTerminalStatement, + state: CreateUpdate + ): void { + CompilerError.invariant(state !== "Create", { + reason: "Visiting a terminal statement with state 'Create'", + loc: stmt.terminal.loc, + }); + super.visitTerminal(stmt, state); + } + + override visitReactiveFunctionValue( + _id: InstructionId, + _dependencies: Array, + fn: ReactiveFunction, + state: CreateUpdate + ): void { + visitReactiveFunction(fn, this, state); + } +} + +export default function pruneInitializationDependencies( + fn: ReactiveFunction +): void { + const [aliases, paths] = getAliases(fn); + visitReactiveFunction(fn, new Visitor(fn.env, aliases, paths), "Update"); +} + +function update( + map: Map>, + key: IdentifierId, + path: string, + value: IdentifierId +): void { + const inner = map.get(key) ?? new Map(); + inner.set(path, value); + map.set(key, inner); +} + +class AliasVisitor extends ReactiveFunctionVisitor { + scopeIdentifiers: DisjointSet = new DisjointSet(); + scopePaths: Map> = new Map(); + + override visitInstruction(instr: ReactiveInstruction): void { + if ( + instr.value.kind === "StoreLocal" || + instr.value.kind === "StoreContext" + ) { + this.scopeIdentifiers.union([ + instr.value.lvalue.place.identifier.id, + instr.value.value.identifier.id, + ]); + } else if ( + instr.value.kind === "LoadLocal" || + instr.value.kind === "LoadContext" + ) { + instr.lvalue && + this.scopeIdentifiers.union([ + instr.lvalue.identifier.id, + instr.value.place.identifier.id, + ]); + } else if (instr.value.kind === "PropertyLoad") { + instr.lvalue && + update( + this.scopePaths, + instr.value.object.identifier.id, + instr.value.property, + instr.lvalue.identifier.id + ); + } else if (instr.value.kind === "PropertyStore") { + update( + this.scopePaths, + instr.value.object.identifier.id, + instr.value.property, + instr.value.value.identifier.id + ); + } + } +} + +function getAliases( + fn: ReactiveFunction +): [DisjointSet, Map>] { + const visitor = new AliasVisitor(); + visitReactiveFunction(fn, visitor, null); + let disjoint = visitor.scopeIdentifiers; + let scopePaths = new Map>(); + for (const [key, value] of visitor.scopePaths) { + for (const [path, id] of value) { + update( + scopePaths, + disjoint.find(key) ?? key, + path, + disjoint.find(id) ?? id + ); + } + } + return [disjoint, scopePaths]; +} diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-and-other-hook-unpruned-dependency.expect.md b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-and-other-hook-unpruned-dependency.expect.md new file mode 100644 index 00000000000..414c9cd143b --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-and-other-hook-unpruned-dependency.expect.md @@ -0,0 +1,84 @@ + +## Input + +```javascript +import { useState } from "react"; // @enableChangeDetectionForDebugging + +function useOther(x) { + return x; +} + +function Component(props) { + const w = f(props.x); + const z = useOther(w); + const [x, _] = useState(z); + return
{x}
; +} + +function f(x) { + return x; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Component, + params: [{ x: 42 }], + isComponent: true, +}; + +``` + +## Code + +```javascript +import { $structuralCheck } from "react-compiler-runtime"; +import { c as _c } from "react/compiler-runtime"; +import { useState } from "react"; // @enableChangeDetectionForDebugging + +function useOther(x) { + return x; +} + +function Component(props) { + const $ = _c(4); + let t0; + { + t0 = f(props.x); + if (!($[0] !== props.x)) { + let old$t0; + old$t0 = $[1]; + $structuralCheck(old$t0, t0, "t0", "Component"); + t0 = old$t0; + } + $[0] = props.x; + $[1] = t0; + } + const w = t0; + const z = useOther(w); + const [x] = useState(z); + let t1; + { + t1 =
{x}
; + if (!($[2] !== x)) { + let old$t1; + old$t1 = $[3]; + $structuralCheck(old$t1, t1, "t1", "Component"); + t1 = old$t1; + } + $[2] = x; + $[3] = t1; + } + return t1; +} + +function f(x) { + return x; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Component, + params: [{ x: 42 }], + isComponent: true, +}; + +``` + \ No newline at end of file diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-and-other-hook-unpruned-dependency.js b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-and-other-hook-unpruned-dependency.js new file mode 100644 index 00000000000..4f57f785d90 --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-and-other-hook-unpruned-dependency.js @@ -0,0 +1,22 @@ +import { useState } from "react"; // @enableChangeDetectionForDebugging + +function useOther(x) { + return x; +} + +function Component(props) { + const w = f(props.x); + const z = useOther(w); + const [x, _] = useState(z); + return
{x}
; +} + +function f(x) { + return x; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Component, + params: [{ x: 42 }], + isComponent: true, +}; diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-pruned-dependency-change-detect.expect.md b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-pruned-dependency-change-detect.expect.md index 482fb5cbbdc..f44b54f99df 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-pruned-dependency-change-detect.expect.md +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-pruned-dependency-change-detect.expect.md @@ -20,31 +20,26 @@ import { c as _c } from "react/compiler-runtime"; // @enableChangeDetectionForDe import { useState } from "react"; function Component(props) { - const $ = _c(4); + const $ = _c(3); let t0; - { + if ($[0] === Symbol.for("react.memo_cache_sentinel")) { t0 = f(props.x); - if (!($[0] !== props.x)) { - let old$t0; - old$t0 = $[1]; - $structuralCheck(old$t0, t0, "t0", "Component"); - t0 = old$t0; - } - $[0] = props.x; - $[1] = t0; + $[0] = t0; + } else { + t0 = $[0]; } const [x] = useState(t0); let t1; { t1 =
{x}
; - if (!($[2] !== x)) { + if (!($[1] !== x)) { let old$t1; - old$t1 = $[3]; + old$t1 = $[2]; $structuralCheck(old$t1, t1, "t1", "Component"); t1 = old$t1; } - $[2] = x; - $[3] = t1; + $[1] = x; + $[2] = t1; } return t1; } diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-unpruned-dependency.expect.md b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-unpruned-dependency.expect.md new file mode 100644 index 00000000000..cb399a0bc56 --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-unpruned-dependency.expect.md @@ -0,0 +1,85 @@ + +## Input + +```javascript +import { useState } from "react"; // @enableChangeDetectionForDebugging + +function Component(props) { + const w = f(props.x); + const [x, _] = useState(w); + return ( +
+ {x} + {w} +
+ ); +} + +function f(x) { + return x; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Component, + params: [{ x: 42 }], + isComponent: true, +}; + +``` + +## Code + +```javascript +import { $structuralCheck } from "react-compiler-runtime"; +import { c as _c } from "react/compiler-runtime"; +import { useState } from "react"; // @enableChangeDetectionForDebugging + +function Component(props) { + const $ = _c(5); + let t0; + { + t0 = f(props.x); + if (!($[0] !== props.x)) { + let old$t0; + old$t0 = $[1]; + $structuralCheck(old$t0, t0, "t0", "Component"); + t0 = old$t0; + } + $[0] = props.x; + $[1] = t0; + } + const w = t0; + const [x] = useState(w); + let t1; + { + t1 = ( +
+ {x} + {w} +
+ ); + if (!($[2] !== x || $[3] !== w)) { + let old$t1; + old$t1 = $[4]; + $structuralCheck(old$t1, t1, "t1", "Component"); + t1 = old$t1; + } + $[2] = x; + $[3] = w; + $[4] = t1; + } + return t1; +} + +function f(x) { + return x; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Component, + params: [{ x: 42 }], + isComponent: true, +}; + +``` + \ No newline at end of file diff --git a/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-unpruned-dependency.js b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-unpruned-dependency.js new file mode 100644 index 00000000000..c63c16aebc4 --- /dev/null +++ b/compiler/packages/babel-plugin-react-compiler/src/__tests__/fixtures/compiler/useState-unpruned-dependency.js @@ -0,0 +1,22 @@ +import { useState } from "react"; // @enableChangeDetectionForDebugging + +function Component(props) { + const w = f(props.x); + const [x, _] = useState(w); + return ( +
+ {x} + {w} +
+ ); +} + +function f(x) { + return x; +} + +export const FIXTURE_ENTRYPOINT = { + fn: Component, + params: [{ x: 42 }], + isComponent: true, +}; diff --git a/compiler/packages/snap/src/SproutTodoFilter.ts b/compiler/packages/snap/src/SproutTodoFilter.ts index c6fdc12b727..1dcd1c2b55e 100644 --- a/compiler/packages/snap/src/SproutTodoFilter.ts +++ b/compiler/packages/snap/src/SproutTodoFilter.ts @@ -496,6 +496,8 @@ const skipFilter = new Set([ "fast-refresh-refresh-on-const-changes-dev", "useState-pruned-dependency-change-detect", + "useState-unpruned-dependency", + "useState-and-other-hook-unpruned-dependency", ]); export default skipFilter;