diff --git a/src/services/codefixes/convertToAsyncFunction.ts b/src/services/codefixes/convertToAsyncFunction.ts index 6474c84df9b08..4a4c14655c3a4 100644 --- a/src/services/codefixes/convertToAsyncFunction.ts +++ b/src/services/codefixes/convertToAsyncFunction.ts @@ -7,11 +7,11 @@ namespace ts.codefix { errorCodes, getCodeActions(context: CodeFixContext) { codeActionSucceeded = true; - const changes = textChanges.ChangeTracker.with(context, (t) => convertToAsyncFunction(t, context.sourceFile, context.span.start, context.program.getTypeChecker(), context)); + const changes = textChanges.ChangeTracker.with(context, (t) => convertToAsyncFunction(t, context.sourceFile, context.span.start, context.program.getTypeChecker())); return codeActionSucceeded ? [createCodeFixAction(fixId, changes, Diagnostics.Convert_to_async_function, fixId, Diagnostics.Convert_all_to_async_functions)] : []; }, fixIds: [fixId], - getAllCodeActions: context => codeFixAll(context, errorCodes, (changes, err) => convertToAsyncFunction(changes, err.file, err.start, context.program.getTypeChecker(), context)), + getAllCodeActions: context => codeFixAll(context, errorCodes, (changes, err) => convertToAsyncFunction(changes, err.file, err.start, context.program.getTypeChecker())), }); const enum SynthBindingNameKind { @@ -43,7 +43,7 @@ namespace ts.codefix { readonly isInJSFile: boolean; } - function convertToAsyncFunction(changes: textChanges.ChangeTracker, sourceFile: SourceFile, position: number, checker: TypeChecker, context: CodeFixContextBase): void { + function convertToAsyncFunction(changes: textChanges.ChangeTracker, sourceFile: SourceFile, position: number, checker: TypeChecker): void { // get the function declaration - returns a promise const tokenAtPosition = getTokenAtPosition(sourceFile, position); let functionToConvert: FunctionLikeDeclaration | undefined; @@ -64,7 +64,7 @@ namespace ts.codefix { const synthNamesMap = new Map(); const isInJavascript = isInJSFile(functionToConvert); const setOfExpressionsToReturn = getAllPromiseExpressionsToReturn(functionToConvert, checker); - const functionToConvertRenamed = renameCollidingVarNames(functionToConvert, checker, synthNamesMap, context.sourceFile); + const functionToConvertRenamed = renameCollidingVarNames(functionToConvert, checker, synthNamesMap); const returnStatements = functionToConvertRenamed.body && isBlock(functionToConvertRenamed.body) ? getReturnStatementsWithPromiseHandlers(functionToConvertRenamed.body) : emptyArray; const transformer: Transformer = { checker, synthNamesMap, setOfExpressionsToReturn, isInJSFile: isInJavascript }; @@ -139,16 +139,12 @@ namespace ts.codefix { return !!checker.getPromisedTypeOfPromise(checker.getTypeAtLocation(node)); } - function declaredInFile(symbol: Symbol, sourceFile: SourceFile): boolean { - return symbol.valueDeclaration && symbol.valueDeclaration.getSourceFile() === sourceFile; - } - /* Renaming of identifiers may be neccesary as the refactor changes scopes - This function collects all existing identifier names and names of identifiers that will be created in the refactor. It then checks for any collisions and renames them through getSynthesizedDeepClone */ - function renameCollidingVarNames(nodeToRename: FunctionLikeDeclaration, checker: TypeChecker, synthNamesMap: ESMap, sourceFile: SourceFile): FunctionLikeDeclaration { + function renameCollidingVarNames(nodeToRename: FunctionLikeDeclaration, checker: TypeChecker, synthNamesMap: ESMap): FunctionLikeDeclaration { const identsToRenameMap = new Map(); // key is the symbol id const collidingSymbolMap = createMultiMap(); forEachChild(nodeToRename, function visit(node: Node) { @@ -156,11 +152,8 @@ namespace ts.codefix { forEachChild(node, visit); return; } - const symbol = checker.getSymbolAtLocation(node); - const isDefinedInFile = symbol && declaredInFile(symbol, sourceFile); - - if (symbol && isDefinedInFile) { + if (symbol) { const type = checker.getTypeAtLocation(node); // Note - the choice of the last call signature is arbitrary const lastCallSignature = getLastCallSignature(type, checker); diff --git a/src/testRunner/unittests/services/convertToAsyncFunction.ts b/src/testRunner/unittests/services/convertToAsyncFunction.ts index 952f6bf4d68e7..20c51045c47fe 100644 --- a/src/testRunner/unittests/services/convertToAsyncFunction.ts +++ b/src/testRunner/unittests/services/convertToAsyncFunction.ts @@ -255,6 +255,14 @@ interface String { charAt: any; } interface Array {}` }; + const moduleFile: TestFSWithWatch.File = { + path: "/module.ts", + content: +`export function fn(res: any): any { + return res; +}` + }; + type WithSkipAndOnly = ((...args: T) => void) & { skip: (...args: T) => void; only: (...args: T) => void; @@ -269,7 +277,7 @@ interface Array {}` } } - function testConvertToAsyncFunction(it: Mocha.PendingTestFunction, caption: string, text: string, baselineFolder: string, includeLib?: boolean, expectFailure = false, onlyProvideAction = false) { + function testConvertToAsyncFunction(it: Mocha.PendingTestFunction, caption: string, text: string, baselineFolder: string, includeLib?: boolean, includeModule?: boolean, expectFailure = false, onlyProvideAction = false) { const t = extractTest(text); const selectionRange = t.ranges.get("selection")!; if (!selectionRange) { @@ -283,7 +291,7 @@ interface Array {}` function runBaseline(extension: Extension) { const path = "/a" + extension; - const languageService = makeLanguageService({ path, content: t.source }, includeLib); + const languageService = makeLanguageService({ path, content: t.source }, includeLib, includeModule); const program = languageService.getProgram()!; if (hasSyntacticDiagnostics(program)) { @@ -338,17 +346,23 @@ interface Array {}` const newText = textChanges.applyChanges(sourceFile.text, changes[0].textChanges); data.push(newText); - const diagProgram = makeLanguageService({ path, content: newText }, includeLib).getProgram()!; + const diagProgram = makeLanguageService({ path, content: newText }, includeLib, includeModule).getProgram()!; assert.isFalse(hasSyntacticDiagnostics(diagProgram)); Harness.Baseline.runBaseline(`${baselineFolder}/${caption}${extension}`, data.join(newLineCharacter)); } - function makeLanguageService(f: { path: string, content: string }, includeLib?: boolean) { - - const host = projectSystem.createServerHost(includeLib ? [f, libFile] : [f]); // libFile is expensive to parse repeatedly - only test when required + function makeLanguageService(file: TestFSWithWatch.File, includeLib?: boolean, includeModule?: boolean) { + const files = [file]; + if (includeLib) { + files.push(libFile); // libFile is expensive to parse repeatedly - only test when required + } + if (includeModule) { + files.push(moduleFile); + } + const host = projectSystem.createServerHost(files); const projectService = projectSystem.createProjectService(host); - projectService.openClientFile(f.path); - return projectService.inferredProjects[0].getLanguageService(); + projectService.openClientFile(file.path); + return first(projectService.inferredProjects).getLanguageService(); } function hasSyntacticDiagnostics(program: Program) { @@ -362,11 +376,15 @@ interface Array {}` }); const _testConvertToAsyncFunctionFailed = createTestWrapper((it, caption: string, text: string) => { - testConvertToAsyncFunction(it, caption, text, "convertToAsyncFunction", /*includeLib*/ true, /*expectFailure*/ true); + testConvertToAsyncFunction(it, caption, text, "convertToAsyncFunction", /*includeLib*/ true, /*includeModule*/ false, /*expectFailure*/ true); }); const _testConvertToAsyncFunctionFailedSuggestion = createTestWrapper((it, caption: string, text: string) => { - testConvertToAsyncFunction(it, caption, text, "convertToAsyncFunction", /*includeLib*/ true, /*expectFailure*/ true, /*onlyProvideAction*/ true); + testConvertToAsyncFunction(it, caption, text, "convertToAsyncFunction", /*includeLib*/ true, /*includeModule*/ false, /*expectFailure*/ true, /*onlyProvideAction*/ true); + }); + + const _testConvertToAsyncFunctionWithModule = createTestWrapper((it, caption: string, text: string) => { + testConvertToAsyncFunction(it, caption, text, "convertToAsyncFunction", /*includeLib*/ true, /*includeModule*/ true); }); describe("unittests:: services:: convertToAsyncFunction", () => { @@ -1453,6 +1471,13 @@ const fn = (): Promise<(message: string) => void> => function [#|f|]() { return fn().then(res => res("test")); } +`); + + _testConvertToAsyncFunctionWithModule("convertToAsyncFunction_importedFunction", ` +import { fn } from "./module"; +function [#|f|]() { + return Promise.resolve(0).then(fn); +} `); }); diff --git a/tests/baselines/reference/convertToAsyncFunction/convertToAsyncFunction_importedFunction.js b/tests/baselines/reference/convertToAsyncFunction/convertToAsyncFunction_importedFunction.js new file mode 100644 index 0000000000000..7d81cc2bbb148 --- /dev/null +++ b/tests/baselines/reference/convertToAsyncFunction/convertToAsyncFunction_importedFunction.js @@ -0,0 +1,14 @@ +// ==ORIGINAL== + +import { fn } from "./module"; +function /*[#|*/f/*|]*/() { + return Promise.resolve(0).then(fn); +} + +// ==ASYNC FUNCTION::Convert to async function== + +import { fn } from "./module"; +async function f() { + const res = await Promise.resolve(0); + return fn(res); +} diff --git a/tests/baselines/reference/convertToAsyncFunction/convertToAsyncFunction_importedFunction.ts b/tests/baselines/reference/convertToAsyncFunction/convertToAsyncFunction_importedFunction.ts new file mode 100644 index 0000000000000..7d81cc2bbb148 --- /dev/null +++ b/tests/baselines/reference/convertToAsyncFunction/convertToAsyncFunction_importedFunction.ts @@ -0,0 +1,14 @@ +// ==ORIGINAL== + +import { fn } from "./module"; +function /*[#|*/f/*|]*/() { + return Promise.resolve(0).then(fn); +} + +// ==ASYNC FUNCTION::Convert to async function== + +import { fn } from "./module"; +async function f() { + const res = await Promise.resolve(0); + return fn(res); +}