@@ -17,9 +17,6 @@ import {
1717 subsetEquality ,
1818} from './jest-utils'
1919
20- const isAsyncFunction = ( fn : unknown ) =>
21- typeof fn === 'function' && ( fn as any ) [ Symbol . toStringTag ] === 'AsyncFunction'
22-
2320const getMatcherState = ( assertion : Chai . AssertionStatic & Chai . Assertion , expect : Vi . ExpectStatic ) => {
2421 const obj = assertion . _obj
2522 const isNot = util . flag ( assertion , 'negate' ) as boolean
@@ -56,30 +53,27 @@ class JestExtendError extends Error {
5653function JestExtendPlugin ( expect : Vi . ExpectStatic , matchers : MatchersObject ) : ChaiPlugin {
5754 return ( c , utils ) => {
5855 Object . entries ( matchers ) . forEach ( ( [ expectAssertionName , expectAssertion ] ) => {
59- function expectSyncWrapper ( this : Chai . AssertionStatic & Chai . Assertion , ...args : any [ ] ) {
56+ function expectWrapper ( this : Chai . AssertionStatic & Chai . Assertion , ...args : any [ ] ) {
6057 const { state, isNot, obj } = getMatcherState ( this , expect )
6158
6259 // @ts -expect-error args wanting tuple
63- const { pass, message, actual, expected } = expectAssertion . call ( state , obj , ...args ) as SyncExpectationResult
64-
65- if ( ( pass && isNot ) || ( ! pass && ! isNot ) )
66- throw new JestExtendError ( message ( ) , actual , expected )
67- }
60+ const result = expectAssertion . call ( state , obj , ...args )
6861
69- async function expectAsyncWrapper ( this : Chai . AssertionStatic & Chai . Assertion , ...args : any [ ] ) {
70- const { state, isNot, obj } = getMatcherState ( this , expect )
62+ if ( result && typeof result === 'object' && result instanceof Promise ) {
63+ return result . then ( ( { pass, message, actual, expected } ) => {
64+ if ( ( pass && isNot ) || ( ! pass && ! isNot ) )
65+ throw new JestExtendError ( message ( ) , actual , expected )
66+ } )
67+ }
7168
72- // @ts -expect-error args wanting tuple
73- const { pass, message, actual, expected } = await expectAssertion . call ( state , obj , ...args ) as SyncExpectationResult
69+ const { pass, message, actual, expected } = result
7470
7571 if ( ( pass && isNot ) || ( ! pass && ! isNot ) )
7672 throw new JestExtendError ( message ( ) , actual , expected )
7773 }
7874
79- const expectAssertionWrapper = isAsyncFunction ( expectAssertion ) ? expectAsyncWrapper : expectSyncWrapper
80-
81- utils . addMethod ( ( globalThis as any ) [ JEST_MATCHERS_OBJECT ] . matchers , expectAssertionName , expectAssertionWrapper )
82- utils . addMethod ( c . Assertion . prototype , expectAssertionName , expectAssertionWrapper )
75+ utils . addMethod ( ( globalThis as any ) [ JEST_MATCHERS_OBJECT ] . matchers , expectAssertionName , expectWrapper )
76+ utils . addMethod ( c . Assertion . prototype , expectAssertionName , expectWrapper )
8377
8478 class CustomMatcher extends AsymmetricMatcher < [ unknown , ...unknown [ ] ] > {
8579 constructor ( inverse = false , ...sample : [ unknown , ...unknown [ ] ] ) {
0 commit comments