From 37aad057ef7478c673036290dce170b16196eaf6 Mon Sep 17 00:00:00 2001 From: Daniel Tschinder <231804+danez@users.noreply.github.com> Date: Wed, 26 Oct 2022 12:03:25 +0200 Subject: [PATCH] feat: Add support for useImperativeHandle --- .../componentMethodsHandler-test.ts.snap | 150 +++++++ .../__tests__/componentMethodsHandler-test.ts | 106 +++++ .../src/handlers/componentMethodsHandler.ts | 156 ++++++-- .../__tests__/findFunctionReturn-test.ts | 266 ++++++++++++ .../__tests__/isStatelessComponent-test.ts | 377 ++++-------------- .../src/utils/findFunctionReturn.ts | 154 +++++++ .../src/utils/isStatelessComponent.ts | 200 +--------- 7 files changed, 895 insertions(+), 514 deletions(-) create mode 100644 packages/react-docgen/src/utils/__tests__/findFunctionReturn-test.ts create mode 100644 packages/react-docgen/src/utils/findFunctionReturn.ts diff --git a/packages/react-docgen/src/handlers/__tests__/__snapshots__/componentMethodsHandler-test.ts.snap b/packages/react-docgen/src/handlers/__tests__/__snapshots__/componentMethodsHandler-test.ts.snap index 2f034fbf3dc..e7a7a858ba2 100644 --- a/packages/react-docgen/src/handlers/__tests__/__snapshots__/componentMethodsHandler-test.ts.snap +++ b/packages/react-docgen/src/handlers/__tests__/__snapshots__/componentMethodsHandler-test.ts.snap @@ -158,3 +158,153 @@ exports[`componentMethodsHandler should handle and ignore computed methods 1`] = }, ] `; + +exports[`componentMethodsHandler useImperativeHandle AssignmentExpression and useImperativeHandle 1`] = ` +[ + { + "docblock": null, + "modifiers": [ + "static", + ], + "name": "other", + "params": [], + "returns": null, + }, + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle VariableDeclaration and useImperativeHandle 1`] = ` +[ + { + "docblock": null, + "modifiers": [ + "static", + ], + "name": "other", + "params": [], + "returns": null, + }, + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle assigned ReturnStatement ArrowFunctionExpression Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle assigned ReturnStatement FunctionDeclaration Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle assigned ReturnStatement FunctionExpression Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle direct ObjectExpression ArrowFunctionExpression Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle direct ObjectExpression FunctionDeclaration Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle direct ObjectExpression FunctionExpression Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle regular ReturnStatement ArrowFunctionExpression Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle regular ReturnStatement FunctionDeclaration Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; + +exports[`componentMethodsHandler useImperativeHandle regular ReturnStatement FunctionExpression Component 1`] = ` +[ + { + "docblock": null, + "modifiers": [], + "name": "method", + "params": [], + "returns": null, + }, +] +`; diff --git a/packages/react-docgen/src/handlers/__tests__/componentMethodsHandler-test.ts b/packages/react-docgen/src/handlers/__tests__/componentMethodsHandler-test.ts index 995e46e02c9..eb0bc0e7aa0 100644 --- a/packages/react-docgen/src/handlers/__tests__/componentMethodsHandler-test.ts +++ b/packages/react-docgen/src/handlers/__tests__/componentMethodsHandler-test.ts @@ -4,11 +4,15 @@ import Documentation from '../../Documentation'; import type DocumentationMock from '../../__mocks__/Documentation'; import type { ArrowFunctionExpression, + AssignmentExpression, ClassDeclaration, ExportDefaultDeclaration, FunctionDeclaration, + FunctionExpression, + VariableDeclaration, } from '@babel/types'; import type { NodePath } from '@babel/traverse'; +import type { ComponentNode } from '../../resolver'; jest.mock('../../Documentation'); @@ -81,6 +85,108 @@ describe('componentMethodsHandler', () => { ]); } + describe('useImperativeHandle', () => { + // Other cases BlockScopeBody with return, both assigned and useImperativeHandles + + const methodDefinitions = { + 'direct ObjectExpression': '({ method: () => {} })', + 'regular ReturnStatement': '{x; return { method: () => {} };}', + 'assigned ReturnStatement': '{const r = { method: () => {} }; return r;}', + }; + + Object.entries(methodDefinitions).forEach(([name, code]) => { + describe(name, () => { + it('FunctionExpression Component', () => { + const definition = parse.expressionLast( + `import { useImperativeHandle } from 'react'; + (function () { + useImperativeHandle(ref, () => ${code}); + return
; + });`, + ); + + componentMethodsHandler(documentation, definition); + + expect(documentation.methods).toHaveLength(1); + expect(documentation.methods).toMatchSnapshot(); + }); + + it('FunctionDeclaration Component', () => { + const definition = parse.statementLast( + `import { useImperativeHandle } from 'react'; + function Component() { + useImperativeHandle(ref, () => ${code}); + return
; + }`, + ); + + componentMethodsHandler(documentation, definition); + + expect(documentation.methods).toHaveLength(1); + expect(documentation.methods).toMatchSnapshot(); + }); + + it('ArrowFunctionExpression Component', () => { + const definition = parse.expressionLast( + `import { useImperativeHandle } from 'react'; + (() => { + useImperativeHandle(ref, () => ${code}); + return
; + });`, + ); + + componentMethodsHandler(documentation, definition); + + expect(documentation.methods).toHaveLength(1); + expect(documentation.methods).toMatchSnapshot(); + }); + }); + }); + + it('AssignmentExpression and useImperativeHandle', () => { + const definition = parse + .statement( + `import { useImperativeHandle } from 'react'; + let Component; + Component = function () { + test(); + useImperativeHandle(ref, () => ({ method: () => {} })); + + return
; + }; + Component.other = () => {}; + `, + -2, + ) + .get('expression.right') as NodePath; + + componentMethodsHandler(documentation, definition); + + expect(documentation.methods).toMatchSnapshot(); + }); + + it('VariableDeclaration and useImperativeHandle', () => { + const definition = parse + .statement( + `import { useImperativeHandle } from 'react'; + let Component = function () { + test(); + useImperativeHandle(ref, () => ({ method: () => {} })); + + return
; + }; + Component.other = () => {}; + `, + -2, + ) + .get('declarations.0.init') as NodePath; + + componentMethodsHandler(documentation, definition); + + expect(documentation.methods).toMatchSnapshot(); + }); + }); + it('extracts the documentation for an ObjectExpression', () => { const src = ` { diff --git a/packages/react-docgen/src/handlers/componentMethodsHandler.ts b/packages/react-docgen/src/handlers/componentMethodsHandler.ts index f21199deb99..945ef3022a0 100644 --- a/packages/react-docgen/src/handlers/componentMethodsHandler.ts +++ b/packages/react-docgen/src/handlers/componentMethodsHandler.ts @@ -8,9 +8,20 @@ import { shallowIgnoreVisitors } from '../utils/traverse'; import resolveToValue from '../utils/resolveToValue'; import type { NodePath, Scope } from '@babel/traverse'; import { visitors } from '@babel/traverse'; -import type { AssignmentExpression, Identifier } from '@babel/types'; +import type { + AssignmentExpression, + BlockStatement, + Identifier, + ObjectExpression, +} from '@babel/types'; import type { ComponentNode } from '../resolver'; import type { Handler } from '.'; +import { + isReactBuiltinCall, + isReactForwardRefCall, + isStatelessComponent, +} from '../utils'; +import findFunctionReturn from '../utils/findFunctionReturn'; /** * The following values/constructs are considered methods: @@ -20,7 +31,7 @@ import type { Handler } from '.'; * - Public class fields in classes whose value are a functions * - Object properties whose values are functions */ -function isMethod(path: NodePath): boolean { +function isMethod(path: NodePath): path is MethodNodePath { let isProbablyMethod = (path.isClassMethod() && path.node.kind !== 'constructor') || path.isObjectMethod(); @@ -53,10 +64,10 @@ const explodedVisitors = visitors.explode({ const binding = assignmentPath.scope.getBinding(name); if ( + binding && left.isMemberExpression() && left.get('object').isIdentifier() && (left.node.object as Identifier).name === name && - binding && binding.scope === scope && resolveToValue(assignmentPath.get('right')).isFunction() ) { @@ -67,10 +78,98 @@ const explodedVisitors = visitors.explode({ }, }); +interface MethodDefinition { + path: MethodNodePath; + isStatic?: boolean; +} + +interface TraverseImperativeHandleState { + results: MethodNodePath[]; +} + +function isObjectExpression(path: NodePath): boolean { + return path.isObjectExpression(); +} + +const explodedImperativeHandleVisitors = + visitors.explode({ + ...shallowIgnoreVisitors, + + CallExpression: { + enter: function (path, state) { + if (!isReactBuiltinCall(path, 'useImperativeHandle')) { + return path.skip(); + } + + // useImperativeHandle(ref, () => ({ name: () => {}, ...})) + const arg = path.get('arguments')[1]; + + if (arg && !arg.isFunction()) { + return path.skip(); + } + + const body = resolveToValue(arg.get('body') as NodePath); + + let definition: NodePath | undefined; + + if (body.isObjectExpression()) { + definition = body; + } else { + definition = findFunctionReturn(arg, isObjectExpression) as + | NodePath + | undefined; + } + + // We found the object body, now add all of the properties as methods. + definition?.get('properties').forEach(p => { + if (isMethod(p)) { + state.results.push(p); + } + }); + + path.skip(); + }, + }, + }); + +function findStatelessComponentBody( + componentDefinition: NodePath, +): NodePath | undefined { + if (isStatelessComponent(componentDefinition)) { + const body = componentDefinition.get('body'); + + if (body.isBlockStatement()) { + return body; + } + } else if (isReactForwardRefCall(componentDefinition)) { + const inner = resolveToValue(componentDefinition.get('arguments')[0]); + + return findStatelessComponentBody(inner); + } + + return undefined; +} + +function findImperativeHandleMethods( + componentDefinition: NodePath, +): MethodDefinition[] { + const body = findStatelessComponentBody(componentDefinition); + + if (!body) { + return []; + } + + const state: TraverseImperativeHandleState = { results: [] }; + + body.traverse(explodedImperativeHandleVisitors, state); + + return state.results.map(p => ({ path: p })); +} + function findAssignedMethods( path: NodePath, idPath: NodePath, -): Array> { +): MethodDefinition[] { if (!idPath.hasNode() || !idPath.isIdentifier()) { return []; } @@ -91,7 +190,7 @@ function findAssignedMethods( path.traverse(explodedVisitors, state); - return state.methods; + return state.methods.map(p => ({ path: p })); } /** @@ -103,20 +202,19 @@ const componentMethodsHandler: Handler = function ( componentDefinition: NodePath, ): void { // Extract all methods from the class or object. - let methodPaths: Array<{ path: MethodNodePath; isStatic?: boolean }> = []; + let methodPaths: MethodDefinition[] = []; + const parent = componentDefinition.parentPath; if (isReactComponentClass(componentDefinition)) { methodPaths = ( componentDefinition .get('body') .get('body') - .filter(body => isMethod(body)) as MethodNodePath[] + .filter(isMethod) as MethodNodePath[] ).map(p => ({ path: p })); } else if (componentDefinition.isObjectExpression()) { methodPaths = ( - componentDefinition - .get('properties') - .filter(props => isMethod(props)) as MethodNodePath[] + componentDefinition.get('properties').filter(isMethod) as MethodNodePath[] ).map(p => ({ path: p })); // Add the statics object properties. @@ -126,37 +224,41 @@ const componentMethodsHandler: Handler = function ( statics.get('properties').forEach(property => { if (isMethod(property)) { methodPaths.push({ - path: property as MethodNodePath, + path: property, isStatic: true, }); } }); } } else if ( - componentDefinition.parentPath && - componentDefinition.parentPath.isVariableDeclarator() && - componentDefinition.parentPath.node.init === componentDefinition.node && - componentDefinition.parentPath.get('id').isIdentifier() + parent.isVariableDeclarator() && + parent.node.init === componentDefinition.node && + parent.get('id').isIdentifier() ) { methodPaths = findAssignedMethods( - componentDefinition.parentPath.scope.path, - componentDefinition.parentPath.get('id') as NodePath, - ).map(p => ({ path: p })); + parent.scope.path, + parent.get('id') as NodePath, + ); } else if ( - componentDefinition.parentPath && - componentDefinition.parentPath.isAssignmentExpression() && - componentDefinition.parentPath.node.right === componentDefinition.node && - componentDefinition.parentPath.get('left').isIdentifier() + parent.isAssignmentExpression() && + parent.node.right === componentDefinition.node && + parent.get('left').isIdentifier() ) { methodPaths = findAssignedMethods( - componentDefinition.parentPath.scope.path, - componentDefinition.parentPath.get('left') as NodePath, - ).map(p => ({ path: p })); + parent.scope.path, + parent.get('left') as NodePath, + ); } else if (componentDefinition.isFunctionDeclaration()) { methodPaths = findAssignedMethods( - componentDefinition.parentPath.scope.path, + parent.scope.path, componentDefinition.get('id'), - ).map(p => ({ path: p })); + ); + } + + const imperativeHandles = findImperativeHandleMethods(componentDefinition); + + if (imperativeHandles) { + methodPaths = [...methodPaths, ...imperativeHandles]; } documentation.set( diff --git a/packages/react-docgen/src/utils/__tests__/findFunctionReturn-test.ts b/packages/react-docgen/src/utils/__tests__/findFunctionReturn-test.ts new file mode 100644 index 00000000000..9407b077e50 --- /dev/null +++ b/packages/react-docgen/src/utils/__tests__/findFunctionReturn-test.ts @@ -0,0 +1,266 @@ +import type { NodePath } from '@babel/traverse'; +import type { StringLiteral } from '@babel/types'; +import { parse, makeMockImporter, noopImporter } from '../../../tests/utils'; +import type { Importer } from '../../importer'; +import findFunctionReturn from '../findFunctionReturn'; + +const predicate = (path: NodePath): boolean => + path.isStringLiteral() && + (path.node.value === 'value' || path.node.value === 'wrong'); + +const value = '"value"'; + +function expectValue(received: NodePath | undefined) { + expect(received).not.toBeUndefined(); + expect(received?.node.type).toBe('StringLiteral'); + expect((received?.node as StringLiteral).value).toBe('value'); +} + +describe('findFunctionReturn', () => { + const wrongFunction = `const wrong = () => "wrong";`; + const functionStyle: Record< + string, + [(name: string, expr: string) => string, string] + > = { + ArrowExpression: [ + (name: string, expr: string): string => `var ${name} = () => (${expr});`, + 'declarations.0.init', + ], + ArrowBlock: [ + (name: string, expr: string): string => + `var ${name} = () => { ${wrongFunction}return (${expr}); };`, + 'declarations.0.init', + ], + FunctionExpression: [ + (name: string, expr: string): string => + `var ${name} = function () { ${wrongFunction}return (${expr}); }`, + 'declarations.0.init', + ], + FunctionDeclaration: [ + (name: string, expr: string): string => + `function ${name} () { ${wrongFunction}return (${expr}); }`, + '', + ], + }; + + const modifiers = { + default: (): string => value, + 'conditional consequent': (): string => `x ? ${value} : null`, + 'conditional alternate': (): string => `x ? null : ${value}`, + 'OR left': (): string => `${value} || null`, + 'AND right': (): string => `x && ${value}`, + }; + + type ComponentFactory = (name: string, expression: string) => string; + + const cases = { + 'no reference': [ + (expr: string, componentFactory: ComponentFactory): string => + `${componentFactory('Foo', expr)}`, + 'body.0', + ], + 'with Identifier reference': [ + (expr: string, componentFactory: ComponentFactory): string => ` + var variable = (${expr}); + ${componentFactory('Foo', 'variable')} + `, + 'body.1', + ], + }; + + Object.entries(functionStyle).forEach(([name, style]) => { + cases[`with ${name} reference`] = [ + (expr: string, componentFactory: ComponentFactory): string => ` + ${style[0]('subfunc', expr)} + ${componentFactory('Foo', 'subfunc()')} + `, + 'body.1', + ]; + }); + + const negativeModifiers = { + 'nested ArrowExpression': (expr: string): string => `() => ${expr}`, + 'nested ArrowBlock': (expr: string): string => `() => { return ${expr} }`, + 'nested FunctionExpression': (expr: string): string => + `function () { return ${expr} }`, + }; + + Object.keys(cases).forEach(name => { + const [caseFactory, caseSelector] = cases[name]; + + describe(name, () => { + Object.entries(functionStyle).forEach( + ([functionName, [functionFactory, functionSelector]]) => { + describe(functionName, () => { + Object.keys(modifiers).forEach(modifierName => { + const modifierFactory = modifiers[modifierName]; + + it(modifierName, () => { + const code = caseFactory(modifierFactory(), functionFactory); + const def: NodePath = parse(code).get( + `${caseSelector}.${functionSelector}`.replace(/\.$/, ''), + ) as NodePath; + + expectValue(findFunctionReturn(def, predicate)); + }); + }); + + Object.keys(negativeModifiers).forEach(modifierName => { + const modifierFactory = negativeModifiers[modifierName]; + + it(modifierName, () => { + const code = caseFactory(modifierFactory(), functionFactory); + + const def: NodePath = parse(code).get( + `${caseSelector}.${functionSelector}`.replace(/\.$/, ''), + ) as NodePath; + + expect(findFunctionReturn(def, predicate)).toBeUndefined(); + }); + }); + }); + }, + ); + }); + }); + + describe('resolving return values', () => { + function test( + desc: string, + src: string, + importer: Importer = noopImporter, + ) { + it(desc, () => { + const def = parse(src, importer).get('body')[0]; + + expectValue(findFunctionReturn(def, predicate)); + }); + } + + const mockImporter = makeMockImporter({ + bar: stmtLast => stmtLast(`export default "value";`).get('declaration'), + }); + + it('handles recursive function calls', () => { + const def = parse.statement(` + function Foo (props) { + return props && Foo(props); + } + `); + + expect(findFunctionReturn(def, predicate)).toBeUndefined(); + }); + + test( + 'does not see ifs as separate block', + ` + function Foo (props) { + if (x) { + return "value"; + } + } + `, + ); + + test( + 'handles simple resolves', + ` + function Foo (props) { + function bar() { + return "value"; + } + + return bar(); + } + `, + ); + + test( + 'handles reference resolves', + ` + function Foo (props) { + var result = bar(); + + return result; + + function bar() { + return "value"; + } + } + `, + ); + + test( + 'handles shallow member call expression resolves', + ` + function Foo (props) { + var shallow = { + shallowMember() { + return "value"; + } + }; + + return shallow.shallowMember(); + } + `, + ); + + test( + 'handles deep member call expression resolves', + ` + function Foo (props) { + var obj = { + deep: { + member() { + return "value"; + } + } + }; + + return obj.deep.member(); + } + `, + ); + + test( + 'handles external reference member call expression resolves', + ` + function Foo (props) { + var member = () => "value"; + var obj = { + deep: { + member: member + } + }; + + return obj.deep.member(); + } + `, + ); + + test( + 'handles all sorts of JavaScript things', + ` + function Foo (props) { + var external = { + member: () => "value" + }; + var obj = {external}; + + return obj.external.member(); + } + `, + ); + + test( + 'resolves imported values as return', + ` + function Foo (props) { + return bar; + } + import bar from 'bar'; + `, + mockImporter, + ); + }); +}); diff --git a/packages/react-docgen/src/utils/__tests__/isStatelessComponent-test.ts b/packages/react-docgen/src/utils/__tests__/isStatelessComponent-test.ts index 3c9a1e20736..7c7981a86ee 100644 --- a/packages/react-docgen/src/utils/__tests__/isStatelessComponent-test.ts +++ b/packages/react-docgen/src/utils/__tests__/isStatelessComponent-test.ts @@ -1,155 +1,93 @@ -import type { NodePath } from '@babel/traverse'; -import { parse, makeMockImporter, noopImporter } from '../../../tests/utils'; -import type { Importer } from '../../importer'; +import { parse } from '../../../tests/utils'; import isStatelessComponent from '../isStatelessComponent'; describe('isStatelessComponent', () => { - const componentIdentifiers = { - JSX: '
', - JSXFragment: '<>', - 'React.createElement': 'React.createElement("div", null)', - 'React.cloneElement': 'React.cloneElement(children, null)', - 'React.Children.only()': 'React.Children.only(children, null)', - }; - - const componentStyle = { - ArrowExpression: [ - (name: string, expr: string): string => `var ${name} = () => (${expr});`, - 'declarations.0.init', - ], - ArrowBlock: [ - (name: string, expr: string): string => - `var ${name} = () => { return (${expr}); };`, - 'declarations.0.init', - ], - FunctionExpression: [ - (name: string, expr: string): string => - `var ${name} = function () { return (${expr}); }`, - 'declarations.0.init', - ], - FunctionDeclaration: [ - (name: string, expr: string): string => - `function ${name} () { return (${expr}); }`, - '', - ], - }; - - const modifiers = { - default: (expr: string): string => expr, - 'conditional consequent': (expr: string): string => `x ? ${expr} : null`, - 'conditional alternate': (expr: string): string => `x ? null : ${expr}`, - 'OR left': (expr: string): string => `${expr} || null`, - 'AND right': (expr: string): string => `x && ${expr}`, - }; - - type ComponentFactory = (name: string, expression: string) => string; - - const cases = { - 'no reference': [ - (expr: string, componentFactory: ComponentFactory): string => ` - var React = require('react'); - ${componentFactory('Foo', expr)} - `, - 'body.1', - ], - 'with Identifier reference': [ - (expr: string, componentFactory: ComponentFactory): string => ` + it('accepts jsx', () => { + const def = parse(` var React = require('react'); - var jsx = (${expr}); - ${componentFactory('Foo', 'jsx')} - `, - 'body.2', - ], - }; + var Foo = () =>
; + `) + .get('body')[1] + .get('declarations')[0] + .get('init'); - Object.keys(componentStyle).forEach(name => { - cases[`with ${name} reference`] = [ - (expr: string, componentFactory: ComponentFactory): string => ` - var React = require('react'); - ${componentStyle[name][0]('jsx', expr)} - ${componentFactory('Foo', 'jsx()')} - `, - 'body.2', - ]; + expect(isStatelessComponent(def)).toBe(true); }); - const negativeModifiers = { - 'nested ArrowExpression': (expr: string): string => `() => ${expr}`, - 'nested ArrowBlock': (expr: string): string => `() => { return ${expr} }`, - 'nested FunctionExpression': (expr: string): string => - `function () { return ${expr} }`, - }; - - Object.keys(cases).forEach(name => { - const [caseFactory, caseSelector] = cases[name]; - - describe(name, () => { - Object.keys(componentIdentifiers).forEach(componentIdentifierName => { - const returnExpr = componentIdentifiers[componentIdentifierName]; + it('accepts jsx fragment', () => { + const def = parse(` + var React = require('react'); + var Foo = () => <>; + `) + .get('body')[1] + .get('declarations')[0] + .get('init'); - describe(componentIdentifierName, () => { - Object.keys(componentStyle).forEach(componentName => { - const [componentFactory, componentSelector] = - componentStyle[componentName]; + expect(isStatelessComponent(def)).toBe(true); + }); - describe(componentName, () => { - Object.keys(modifiers).forEach(modifierName => { - const modifierFactory = modifiers[modifierName]; + it('accepts createElement', () => { + const def = parse(` + var React = require('react'); + var Foo = () => React.createElement("div", null); + `) + .get('body')[1] + .get('declarations')[0] + .get('init'); - it(modifierName, () => { - const code = caseFactory( - modifierFactory(returnExpr), - componentFactory, - ); + expect(isStatelessComponent(def)).toBe(true); + }); - const def: NodePath = parse(code).get( - `${caseSelector}.${componentSelector}`.replace(/\.$/, ''), - ) as NodePath; + it('accepts cloneElement', () => { + const def = parse(` + var React = require('react'); + var Foo = () => React.cloneElement("div", null); + `) + .get('body')[1] + .get('declarations')[0] + .get('init'); - expect(isStatelessComponent(def)).toBe(true); - }); - }); + expect(isStatelessComponent(def)).toBe(true); + }); - Object.keys(negativeModifiers).forEach(modifierName => { - const modifierFactory = negativeModifiers[modifierName]; + it('accepts React.Children.only', () => { + const def = parse(` + var React = require('react'); + var Foo = () => React.Children.only(children); + `) + .get('body')[1] + .get('declarations')[0] + .get('init'); - it(modifierName, () => { - const code = caseFactory( - modifierFactory(returnExpr), - componentFactory, - ); + expect(isStatelessComponent(def)).toBe(true); + }); - const def: NodePath = parse(code).get( - `${caseSelector}.${componentSelector}`.replace(/\.$/, ''), - ) as NodePath; + it('accepts React.Children.map', () => { + const def = parse(` + var React = require('react'); + var Foo = () => React.Children.map(children, child => child); + `) + .get('body')[1] + .get('declarations')[0] + .get('init'); - expect(isStatelessComponent(def)).toBe(false); - }); - }); - }); - }); - }); - }); - }); + expect(isStatelessComponent(def)).toBe(true); }); - describe('Stateless Function Components with React.createElement', () => { - it('accepts different name than React', () => { - const def = parse(` + it('accepts different name than React', () => { + const def = parse(` var AlphaBetters = require('react'); var Foo = () => AlphaBetters.createElement("div", null); `) - .get('body')[1] - .get('declarations')[0] - .get('init'); + .get('body')[1] + .get('declarations')[0] + .get('init'); - expect(isStatelessComponent(def)).toBe(true); - }); + expect(isStatelessComponent(def)).toBe(true); }); - describe('Stateless Function Components inside module pattern', () => { - it('', () => { - const def = parse(` + it('Stateless Function Components inside module pattern', () => { + const def = parse(` var React = require('react'); var Foo = { Bar() { return
; }, @@ -159,22 +97,21 @@ describe('isStatelessComponent', () => { world: function({ children }) { return React.cloneElement(children, {}); }, } `) - .get('body')[1] - .get('declarations')[0] - .get('init'); - - const bar = def.get('properties')[0]; - const baz = def.get('properties')[1].get('value'); - const hello = def.get('properties')[2].get('value'); - const render = def.get('properties')[3]; - const world = def.get('properties')[4].get('value'); - - expect(isStatelessComponent(bar)).toBe(true); - expect(isStatelessComponent(baz)).toBe(true); - expect(isStatelessComponent(hello)).toBe(true); - expect(isStatelessComponent(render)).toBe(false); - expect(isStatelessComponent(world)).toBe(true); - }); + .get('body')[1] + .get('declarations')[0] + .get('init'); + + const bar = def.get('properties')[0]; + const baz = def.get('properties')[1].get('value'); + const hello = def.get('properties')[2].get('value'); + const render = def.get('properties')[3]; + const world = def.get('properties')[4].get('value'); + + expect(isStatelessComponent(bar)).toBe(true); + expect(isStatelessComponent(baz)).toBe(true); + expect(isStatelessComponent(hello)).toBe(true); + expect(isStatelessComponent(render)).toBe(false); + expect(isStatelessComponent(world)).toBe(true); }); describe('is not overzealous', () => { @@ -215,156 +152,4 @@ describe('isStatelessComponent', () => { expect(isStatelessComponent(def)).toBe(false); }); }); - - describe('resolving return values', () => { - function test( - desc: string, - src: string, - importer: Importer = noopImporter, - ) { - it(desc, () => { - const def = parse(src, importer).get('body')[1]; - - expect(isStatelessComponent(def)).toBe(true); - }); - } - - const mockImporter = makeMockImporter({ - bar: stmtLast => - stmtLast( - ` - export default
; - `, - ).get('declaration'), - }); - - it('does not see ifs as separate block', () => { - const def = parse.statement(` - function Foo (props) { - if (x) { - return
; - } - } - `); - - expect(isStatelessComponent(def)).toBe(true); - }); - - it('handles recursive function calls', () => { - const def = parse.statement(` - function Foo (props) { - return props && Foo(props); - } - `); - - expect(isStatelessComponent(def)).toBe(false); - }); - - test( - 'handles simple resolves', - ` - var React = require('react'); - function Foo (props) { - function bar() { - return React.createElement("div", props); - } - - return bar(); - } - `, - ); - - test( - 'handles reference resolves', - ` - var React = require('react'); - function Foo (props) { - var result = bar(); - - return result; - - function bar() { - return
; - } - } - `, - ); - - test( - 'handles shallow member call expression resolves', - ` - var React = require('react'); - function Foo (props) { - var shallow = { - shallowMember() { - return
; - } - }; - - return shallow.shallowMember(); - } - `, - ); - - test( - 'handles deep member call expression resolves', - ` - var React = require('react'); - function Foo (props) { - var obj = { - deep: { - member() { - return
; - } - } - }; - - return obj.deep.member(); - } - `, - ); - - test( - 'handles external reference member call expression resolves', - ` - var React = require('react'); - function Foo (props) { - var member = () =>
; - var obj = { - deep: { - member: member - } - }; - - return obj.deep.member(); - } - `, - ); - - test( - 'handles all sorts of JavaScript things', - ` - var React = require('react'); - function Foo (props) { - var external = { - member: () =>
- }; - var obj = {external}; - - return obj.external.member(); - } - `, - ); - - test( - 'resolves imported values as return', - ` - import bar from 'bar'; - function Foo (props) { - return bar; - } - `, - mockImporter, - ); - }); }); diff --git a/packages/react-docgen/src/utils/findFunctionReturn.ts b/packages/react-docgen/src/utils/findFunctionReturn.ts new file mode 100644 index 00000000000..8718dac57bb --- /dev/null +++ b/packages/react-docgen/src/utils/findFunctionReturn.ts @@ -0,0 +1,154 @@ +import type { NodePath } from '@babel/traverse'; +import { visitors } from '@babel/traverse'; +import resolveToValue from './resolveToValue'; +import { ignore } from './traverse'; + +type Predicate = (p: NodePath) => boolean; +interface TraverseState { + readonly predicate: Predicate; + resolvedReturnPath?: NodePath; + readonly seen: WeakSet; +} + +const explodedVisitors = visitors.explode({ + Function: { enter: ignore }, + Class: { enter: ignore }, + ObjectExpression: { enter: ignore }, + ReturnStatement: { + enter: function (path, state) { + const argument = path.get('argument'); + + if (argument.hasNode()) { + const resolvedPath = resolvesToFinalValue( + argument, + state.predicate, + state.seen, + ); + + if (resolvedPath) { + state.resolvedReturnPath = resolvedPath; + path.stop(); + } + } + }, + }, +}); + +function resolvesToFinalValue( + path: NodePath, + predicate: Predicate, + seen: WeakSet, +): NodePath | undefined { + // avoid returns with recursive function calls + if (seen.has(path)) { + return; + } + seen.add(path); + + // Is the path already passes then return it. + if (predicate(path)) { + return path; + } + + const resolvedPath = resolveToValue(path); + + // If the resolved path is already passing then no need to further check + // Only do this if the resolvedPath actually resolved something as otherwise we did this check already + if (resolvedPath.node !== path.node && predicate(resolvedPath)) { + return resolvedPath; + } + + // If the path points to a conditional expression, then we need to look only at + // the two possible paths + if (resolvedPath.isConditionalExpression()) { + return ( + resolvesToFinalValue(resolvedPath.get('consequent'), predicate, seen) || + resolvesToFinalValue(resolvedPath.get('alternate'), predicate, seen) + ); + } + + // If the path points to a logical expression (AND, OR, ...), then we need to look only at + // the two possible paths + if (resolvedPath.isLogicalExpression()) { + return ( + resolvesToFinalValue(resolvedPath.get('left'), predicate, seen) || + resolvesToFinalValue(resolvedPath.get('right'), predicate, seen) + ); + } + + // If we have a call expression, lets try to follow it + if (resolvedPath.isCallExpression()) { + const returnValue = findFunctionReturnWithCache( + resolveToValue(resolvedPath.get('callee')), + predicate, + seen, + ); + + if (returnValue) { + return returnValue; + } + } + + return; +} + +/** + * This can be used in two ways + * 1. Find the first return path that passes the predicate function + * (for example to check if a function is returning something) + * 2. Find all occurrences of return values + * For this the predicate acts more like a collector and always needs to return false + */ +function findFunctionReturnWithCache( + path: NodePath, + predicate: Predicate, + seen: WeakSet, +): NodePath | undefined { + let functionPath: NodePath = path; + + if (functionPath.isObjectProperty()) { + functionPath = functionPath.get('value'); + } else if (functionPath.isClassProperty()) { + const classPropertyValue = functionPath.get('value'); + + if (classPropertyValue.hasNode()) { + functionPath = classPropertyValue; + } + } + + if (!functionPath.isFunction()) { + return; + } + + // skip traversing for ArrowFunctionExpressions with no block + if (path.isArrowFunctionExpression()) { + const body = path.get('body'); + + if (!body.isBlockStatement()) { + return resolvesToFinalValue(body, predicate, seen); + } + } + + const state: TraverseState = { + predicate, + seen, + }; + + path.traverse(explodedVisitors, state); + + return state.resolvedReturnPath; +} + +/** + * This can be used in two ways + * 1. Find the first return path that passes the predicate function + * (for example to check if a function is returning something) + * 2. Find all occurrences of return values + * For this the predicate acts more like a collector and always needs to return false + */ +export default function findFunctionReturn( + path: NodePath, + predicate: Predicate, +): NodePath | undefined { + return findFunctionReturnWithCache(path, predicate, new WeakSet()); +} diff --git a/packages/react-docgen/src/utils/isStatelessComponent.ts b/packages/react-docgen/src/utils/isStatelessComponent.ts index bb4afb74107..842d3f62db0 100644 --- a/packages/react-docgen/src/utils/isStatelessComponent.ts +++ b/packages/react-docgen/src/utils/isStatelessComponent.ts @@ -1,12 +1,9 @@ -import getPropertyValuePath from './getPropertyValuePath'; import isReactCreateElementCall from './isReactCreateElementCall'; import isReactCloneElementCall from './isReactCloneElementCall'; import isReactChildrenElementCall from './isReactChildrenElementCall'; -import resolveToValue from './resolveToValue'; -import type { NodePath, Scope } from '@babel/traverse'; -import { visitors } from '@babel/traverse'; -import type { Expression } from '@babel/types'; -import { ignore } from './traverse'; +import type { NodePath } from '@babel/traverse'; +import type { StatelessComponentNode } from '../resolver'; +import findFunctionReturn from './findFunctionReturn'; const validPossibleStatelessComponentTypes = [ 'ArrowFunctionExpression', @@ -26,196 +23,17 @@ function isJSXElementOrReactCall(path: NodePath): boolean { ); } -function resolvesToJSXElementOrReactCall( - path: NodePath, - seen: WeakSet, -): boolean { - // avoid returns with recursive function calls - if (seen.has(path)) { - return false; - } - - seen.add(path); - - // Is the path is already a JSX element or a call to one of the React.* functions - if (isJSXElementOrReactCall(path)) { - return true; - } - - const resolvedPath = resolveToValue(path); - - // If the path points to a conditional expression, then we need to look only at - // the two possible paths - if (resolvedPath.isConditionalExpression()) { - return ( - resolvesToJSXElementOrReactCall( - resolvedPath.get('consequent'), - - seen, - ) || - resolvesToJSXElementOrReactCall( - resolvedPath.get('alternate'), - - seen, - ) - ); - } - - // If the path points to a logical expression (AND, OR, ...), then we need to look only at - // the two possible paths - if (resolvedPath.isLogicalExpression()) { - return ( - resolvesToJSXElementOrReactCall( - resolvedPath.get('left'), - - seen, - ) || resolvesToJSXElementOrReactCall(resolvedPath.get('right'), seen) - ); - } - - // Is the resolved path is already a JSX element or a call to one of the React.* functions - // Only do this if the resolvedPath actually resolved something as otherwise we did this check already - if (resolvedPath !== path && isJSXElementOrReactCall(resolvedPath)) { - return true; - } - - // If we have a call expression, lets try to follow it - if (resolvedPath.isCallExpression()) { - let calleeValue = resolveToValue(resolvedPath.get('callee')); - - if (returnsJSXElementOrReactCall(calleeValue, seen)) { - return true; - } - - if (calleeValue.isMemberExpression()) { - let resolvedValue: NodePath | undefined; - const namesToResolve: NodePath[] = []; - - const calleeObj = calleeValue.get('object'); - - if (calleeObj.isIdentifier()) { - namesToResolve.push(calleeValue.get('property')); - resolvedValue = resolveToValue(calleeObj); - } else { - do { - namesToResolve.unshift(calleeValue.get('property')); - calleeValue = calleeValue.get('object'); - } while (calleeValue.isMemberExpression()); - - resolvedValue = resolveToValue(calleeValue); - } - - if (resolvedValue && resolvedValue.isObjectExpression()) { - const resolvedMemberExpression = namesToResolve.reduce( - (result: NodePath | null, nodePath) => { - if (result) { - if ( - (!nodePath.isIdentifier() && !nodePath.isStringLiteral()) || - !result.isObjectExpression() - ) { - return null; - } - result = getPropertyValuePath( - result, - nodePath.isIdentifier() - ? nodePath.node.name - : nodePath.node.value, - ); - if (result && result.isIdentifier()) { - return resolveToValue(result); - } - } - - return result; - }, - resolvedValue, - ); - - if ( - !resolvedMemberExpression || - returnsJSXElementOrReactCall(resolvedMemberExpression, seen) - ) { - return true; - } - } - } - } - - return false; -} - -interface TraverseState { - readonly initialScope: Scope; - isStatelessComponent: boolean; - readonly seen: WeakSet; -} - -const explodedVisitors = visitors.explode({ - Function: { enter: ignore }, - Class: { enter: ignore }, - ObjectExpression: { enter: ignore }, - ReturnStatement: { - enter: function (path, state) { - // Only check return statements which are part of the checked function scope - if (path.scope.getFunctionParent() !== state.initialScope) { - path.skip(); - - return; - } - - if ( - path.node.argument && - resolvesToJSXElementOrReactCall( - path.get('argument') as NodePath, - state.seen, - ) - ) { - state.isStatelessComponent = true; - path.stop(); - } - }, - }, -}); - -function returnsJSXElementOrReactCall( - path: NodePath, - seen: WeakSet = new WeakSet(), -): boolean { - if (path.isObjectProperty()) { - path = path.get('value'); - } - - if (!path.isFunction()) { - return false; - } - - // early exit for ArrowFunctionExpressions - if ( - path.isArrowFunctionExpression() && - !path.get('body').isBlockStatement() && - resolvesToJSXElementOrReactCall(path.get('body'), seen) - ) { - return true; - } - - const state: TraverseState = { - initialScope: path.scope, - isStatelessComponent: false, - seen, - }; - - path.traverse(explodedVisitors, state); - - return state.isStatelessComponent; -} - /** * Returns `true` if the path represents a function which returns a JSXElement */ -export default function isStatelessComponent(path: NodePath): boolean { +export default function isStatelessComponent( + path: NodePath, +): path is NodePath { if (!path.inType(...validPossibleStatelessComponentTypes)) { return false; } - return returnsJSXElementOrReactCall(path); + const foundPath = findFunctionReturn(path, isJSXElementOrReactCall); + + return Boolean(foundPath); }