diff --git a/src/services/refactors/convertExport.ts b/src/services/refactors/convertExport.ts
index 3d3a9de095b80..f2f361aa86b70 100644
--- a/src/services/refactors/convertExport.ts
+++ b/src/services/refactors/convertExport.ts
@@ -65,8 +65,8 @@ namespace ts.refactor {
return { error: getLocaleSpecificMessage(Diagnostics.Could_not_find_export_statement) };
}
- const exportingModuleSymbol = isSourceFile(exportNode.parent) ? exportNode.parent.symbol : exportNode.parent.parent.symbol;
-
+ const checker = program.getTypeChecker();
+ const exportingModuleSymbol = getExportingModuleSymbol(exportNode, checker);
const flags = getSyntacticModifierFlags(exportNode) || ((isExportAssignment(exportNode) && !exportNode.isExportEquals) ? ModifierFlags.ExportDefault : ModifierFlags.None);
const wasDefault = !!(flags & ModifierFlags.Default);
@@ -75,7 +75,6 @@ namespace ts.refactor {
return { error: getLocaleSpecificMessage(Diagnostics.This_file_already_has_a_default_export) };
}
- const checker = program.getTypeChecker();
const noSymbolError = (id: Node) =>
(isIdentifier(id) && checker.getSymbolAtLocation(id)) ? undefined
: { error: getLocaleSpecificMessage(Diagnostics.Can_only_convert_named_export) };
@@ -165,6 +164,7 @@ namespace ts.refactor {
const checker = program.getTypeChecker();
const exportSymbol = Debug.checkDefined(checker.getSymbolAtLocation(exportName), "Export name should resolve to a symbol");
FindAllReferences.Core.eachExportReference(program.getSourceFiles(), checker, cancellationToken, exportSymbol, exportingModuleSymbol, exportName.text, wasDefault, ref => {
+ if (exportName === ref) return;
const importingSourceFile = ref.getSourceFile();
if (wasDefault) {
changeDefaultToNamedImport(importingSourceFile, ref, changes, exportName.text);
@@ -258,4 +258,16 @@ namespace ts.refactor {
function makeExportSpecifier(propertyName: string, name: string): ExportSpecifier {
return factory.createExportSpecifier(/*isTypeOnly*/ false, propertyName === name ? undefined : factory.createIdentifier(propertyName), factory.createIdentifier(name));
}
+
+ function getExportingModuleSymbol(node: Node, checker: TypeChecker) {
+ const parent = node.parent;
+ if (isSourceFile(parent)) {
+ return parent.symbol;
+ }
+ const symbol = parent.parent.symbol;
+ if (symbol.valueDeclaration && isExternalModuleAugmentation(symbol.valueDeclaration)) {
+ return checker.getMergedSymbol(symbol);
+ }
+ return symbol;
+ }
}
diff --git a/tests/cases/fourslash/refactorConvertExport_namedToDefaultInModuleAugmentation1.ts b/tests/cases/fourslash/refactorConvertExport_namedToDefaultInModuleAugmentation1.ts
new file mode 100644
index 0000000000000..d1fc03349715e
--- /dev/null
+++ b/tests/cases/fourslash/refactorConvertExport_namedToDefaultInModuleAugmentation1.ts
@@ -0,0 +1,26 @@
+///
+
+// @Filename: /node_modules/@types/foo/index.d.ts
+////export {};
+////declare module "foo" {
+//// /*a*/export function foo(): void;/*b*/
+////}
+
+// @Filename: /b.ts
+////import { foo } from "foo";
+
+goTo.select("a", "b");
+edit.applyRefactor({
+ refactorName: "Convert export",
+ actionName: "Convert named export to default export",
+ actionDescription: "Convert named export to default export",
+ newContent: {
+ "/node_modules/@types/foo/index.d.ts":
+`export {};
+declare module "foo" {
+ export default function foo(): void;
+}`,
+ "/b.ts":
+`import foo from "foo";`
+ }
+});
diff --git a/tests/cases/fourslash/refactorConvertExport_namedToDefaultInModuleAugmentation2.ts b/tests/cases/fourslash/refactorConvertExport_namedToDefaultInModuleAugmentation2.ts
new file mode 100644
index 0000000000000..e278769f568df
--- /dev/null
+++ b/tests/cases/fourslash/refactorConvertExport_namedToDefaultInModuleAugmentation2.ts
@@ -0,0 +1,18 @@
+///
+
+////export {};
+////declare module "foo" {
+//// /*a*/export function func(): void;/*b*/
+////}
+
+goTo.select("a", "b");
+edit.applyRefactor({
+ refactorName: "Convert export",
+ actionName: "Convert named export to default export",
+ actionDescription: "Convert named export to default export",
+ newContent:
+`export {};
+declare module "foo" {
+ export default function func(): void;
+}`
+});