diff --git a/src/main/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToList.java b/src/main/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToList.java index 42c6da6813..67c1bef233 100644 --- a/src/main/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToList.java +++ b/src/main/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToList.java @@ -27,6 +27,8 @@ import org.openrewrite.java.search.UsesMethod; import org.openrewrite.java.tree.Expression; import org.openrewrite.java.tree.J; +import org.openrewrite.java.tree.JavaType; +import org.openrewrite.java.tree.TypeUtils; import java.time.Duration; import java.util.Collections; @@ -95,6 +97,12 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu Expression command = method.getArguments().get(0); if (COLLECT_TO_UNMODIFIABLE_LIST.matches(command) || convertToList && COLLECT_TO_LIST.matches(command)) { + + // Check if the transformation would result in incompatible types + if (!areTypesCompatible(result)) { + return result; + } + maybeRemoveImport("java.util.stream.Collectors"); J.MethodInvocation toList = JavaTemplate.apply("#{any(java.util.stream.Stream)}.toList()", updateCursor(result), result.getCoordinates().replace(), result.getSelect()); @@ -102,5 +110,19 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu } return result; } + + private boolean areTypesCompatible(J.MethodInvocation method) { + if (method.getSelect() == null || + method.getSelect().getType() == null || + !(method.getSelect().getType() instanceof JavaType.Parameterized) || + !(method.getType() instanceof JavaType.Parameterized)) { + return false; + } + // Check if the stream element type and expected list element type are exactly the same + // If they differ (e.g., Stream but List), don't transform + return TypeUtils.isOfType( + ((JavaType.Parameterized) method.getSelect().getType()).getTypeParameters().get(0), + ((JavaType.Parameterized) method.getType()).getTypeParameters().get(0)); + } } } diff --git a/src/test/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToListTest.java b/src/test/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToListTest.java index 5e2f09a510..36d9f9db81 100644 --- a/src/test/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToListTest.java +++ b/src/test/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToListTest.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; +import org.openrewrite.Issue; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; @@ -180,4 +181,99 @@ List test(Stream stream) { ); } + @Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791") + @Test + void doesNotReplaceWhenReturnTypeIsIncompatible() { + rewriteRun( + //language=java + java( + """ + import java.util.stream.Collectors; + import java.util.stream.Stream; + import java.util.List; + + class Example { + List foo() { + return Stream.of(Integer.valueOf(1)).collect(Collectors.toUnmodifiableList()); + } + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791") + @Test + void replacesWhenTypesAreCompatible() { + rewriteRun( + //language=java + java( + """ + import java.util.stream.Collectors; + import java.util.stream.Stream; + import java.util.List; + + class Example { + List foo() { + return Stream.of(Integer.valueOf(1)).collect(Collectors.toUnmodifiableList()); + } + } + """, + """ + import java.util.stream.Stream; + import java.util.List; + + class Example { + List foo() { + return Stream.of(Integer.valueOf(1)).toList(); + } + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791") + @Test + void doesNotReplaceInVariableAssignmentWithIncompatibleTypes() { + rewriteRun( + //language=java + java( + """ + import java.util.stream.Collectors; + import java.util.stream.Stream; + import java.util.List; + + class Example { + void foo() { + List numbers = Stream.of(Integer.valueOf(1)).collect(Collectors.toUnmodifiableList()); + } + } + """ + ) + ); + } + + @Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791") + @Test + void doesNotReplaceWithToListWhenConvertToListFlagIsTrue() { + rewriteRun( + recipeSpec -> recipeSpec.recipe(new ReplaceStreamCollectWithToList(true)), + //language=java + java( + """ + import java.util.stream.Collectors; + import java.util.stream.Stream; + import java.util.List; + + class Example { + List foo() { + return Stream.of(Integer.valueOf(1)).collect(Collectors.toList()); + } + } + """ + ) + ); + } + }