diff --git a/src/main/java/org/openrewrite/java/migrate/lang/ExplicitRecordImport.java b/src/main/java/org/openrewrite/java/migrate/lang/ExplicitRecordImport.java index e5576cb5ef..f7adf22cde 100644 --- a/src/main/java/org/openrewrite/java/migrate/lang/ExplicitRecordImport.java +++ b/src/main/java/org/openrewrite/java/migrate/lang/ExplicitRecordImport.java @@ -20,10 +20,9 @@ import org.openrewrite.Recipe; import org.openrewrite.TreeVisitor; import org.openrewrite.java.JavaIsoVisitor; +import org.openrewrite.java.search.FindTypes; import org.openrewrite.java.search.UsesType; -import org.openrewrite.java.tree.J; -import org.openrewrite.java.tree.JavaSourceFile; -import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.*; public class ExplicitRecordImport extends Recipe { @Override @@ -45,10 +44,12 @@ public TreeVisitor getVisitor() { public J.CompilationUnit visitCompilationUnit(J.CompilationUnit cu, ExecutionContext ctx) { JavaSourceFile javaSourceFile = getCursor().firstEnclosing(JavaSourceFile.class); if (javaSourceFile != null) { - for (JavaType type : cu.getTypesInUse().getTypesInUse()) { - if (type instanceof JavaType.FullyQualified) { - JavaType.FullyQualified ref = (JavaType.FullyQualified) type; - if ("Record".equals(ref.getClassName()) && !ref.getPackageName().startsWith("java.lang")) { + for (NameTree nameTree : FindTypes.findAssignable(cu, "*..Record")) { + if (nameTree.getType() instanceof JavaType.FullyQualified) { + JavaType.FullyQualified ref = (JavaType.FullyQualified) nameTree.getType(); + if ("Record".equals(ref.getClassName()) && + !ref.getPackageName().startsWith("java.lang") && + !nameTree.getMarkers().findFirst(JavaVarKeyword.class).isPresent()) { maybeAddImport(ref.getFullyQualifiedName()); } } diff --git a/src/test/java/org/openrewrite/java/migrate/lang/ExplicitRecordImportTest.java b/src/test/java/org/openrewrite/java/migrate/lang/ExplicitRecordImportTest.java index f7da06d7f0..803bb8dc0f 100644 --- a/src/test/java/org/openrewrite/java/migrate/lang/ExplicitRecordImportTest.java +++ b/src/test/java/org/openrewrite/java/migrate/lang/ExplicitRecordImportTest.java @@ -15,7 +15,6 @@ */ package org.openrewrite.java.migrate.lang; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.Issue; @@ -32,12 +31,19 @@ public void defaults(RecipeSpec spec) { spec.recipe(new ExplicitRecordImport()) //language=java .parser(JavaParser.fromJavaVersion().dependsOn(""" - package com.acme.music; - public class Record { - String name; - } - """ - ) + package com.acme.music; + public class Record { + public String name; + } + """, + """ + package com.acme.music; + import java.util.List; + public class RecordList { + public List records; + } + """ + ) ); } @@ -51,7 +57,7 @@ void addImportFromSamePackage() { """ package com.acme.music; - public class Test { + class Test { Record record; } """, @@ -60,7 +66,7 @@ public class Test { import com.acme.music.Record; - public class Test { + class Test { Record record; } """ @@ -68,19 +74,78 @@ public class Test { ); } + @Test + void noChangeIfUsingVar() { + rewriteRun( + //language=java + java( + """ + package com.acme.music; + + import com.acme.music.RecordList; + + class Test { + void test() { + for (var record : new RecordList().records) { + } + } + } + """ + ) + ); + } @Test - @Disabled("Not handled yet; deemed unlikely") - void noChangeIfAlreadyFullyQualified() { + void genericUseOfRecordClass() { rewriteRun( //language=java java( """ package com.acme.music; - public class Test { + import java.util.List; + + class Test { + List records; + } + """, + """ + package com.acme.music; + + import com.acme.music.Record; + + import java.util.List; + + class Test { + List records; + } + """ + ) + ); + } + + + @Test + void documentChangeWhenFullyQualified() { + rewriteRun( + //language=java + java( + """ + package com.acme.music; + + class Test { com.acme.music.Record record; } + """, + // Perhaps undesired, but also unlikely, so not worth changing + """ + package com.acme.music; + + import com.acme.music.Record; + + class Test { + Record record; + } """ ) ); @@ -97,7 +162,7 @@ void noChangeIfAlreadyImported() { import com.acme.music.Record; - public class Test { + class Test { Record record; } """ @@ -113,7 +178,7 @@ void noImportAddedForJavaLangRecord() { """ package foo.bar; - public class Test { + class Test { Record record; } """