diff --git a/packages/react-test-renderer/src/ReactShallowRenderer.js b/packages/react-test-renderer/src/ReactShallowRenderer.js index 9bf373d033c..988e5975074 100644 --- a/packages/react-test-renderer/src/ReactShallowRenderer.js +++ b/packages/react-test-renderer/src/ReactShallowRenderer.js @@ -143,7 +143,8 @@ class ReactShallowRenderer { this._rendering = true; this._element = element; - this._context = getMaskedContext(element.type.contextTypes, context); + + this._context = getContext(element, context); if (this._instance) { this._updateClassComponent(element, this._context); @@ -372,6 +373,14 @@ function shouldConstruct(Component) { return !!(Component.prototype && Component.prototype.isReactComponent); } +function getContextTypeContext(contextType, context) { + if (context !== undefined && context !== emptyObject) { + return context; + } + + return contextType._currentValue; +} + function getMaskedContext(contextTypes, unmaskedContext) { if (!contextTypes) { return emptyObject; @@ -383,4 +392,12 @@ function getMaskedContext(contextTypes, unmaskedContext) { return context; } +function getContext(element, contextOption) { + if (element.type.contextType) { + return getContextTypeContext(element.type.contextType, contextOption); + } + + return getMaskedContext(element.type.contextTypes, contextOption); +} + export default ReactShallowRenderer; diff --git a/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js b/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js index b6c4259af9f..5518ab2113a 100644 --- a/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js +++ b/packages/react-test-renderer/src/__tests__/ReactShallowRenderer-test.js @@ -594,6 +594,38 @@ describe('ReactShallowRenderer', () => { expect(result).toEqual(
); }); + it('can shallowly render components with contextType and default context', () => { + const SimpleContext = React.createContext('hello world'); + + class SimpleComponent extends React.Component { + static contextType = SimpleContext; + + render() { + return
{this.context}
; + } + } + + const shallowRenderer = createRenderer(); + const result = shallowRenderer.render(); + expect(result).toEqual(
hello world
); + }); + + it('can shallowly render components with contextType', () => { + const SimpleContext = React.createContext('hello world'); + + class SimpleComponent extends React.Component { + static contextType = SimpleContext; + + render() { + return
{this.context}
; + } + } + + const shallowRenderer = createRenderer(); + const result = shallowRenderer.render(, 'Foo bar'); + expect(result).toEqual(
Foo bar
); + }); + it('passes expected params to legacy component lifecycle methods', () => { const componentDidUpdateParams = []; const componentWillReceivePropsParams = [];