diff --git a/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOr.java b/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOr.java index 61b9b9b11a..f105129558 100644 --- a/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOr.java +++ b/src/main/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOr.java @@ -83,8 +83,8 @@ private J handlePredicatesMethod(J.MethodInvocation method, String operation) { return method; } - // If the first argument is a method reference, wrap it with a cast - if (result instanceof J.MemberReference && result.getType() != null) { + // If the first argument is a method reference or a lambda, wrap it with a cast + if ((result instanceof J.MemberReference || result instanceof J.Lambda) && result.getType() != null) { String typeString = result.getType().toString().replace("com.google.common.base.", ""); result = JavaTemplate.apply("((" + typeString + ") #{any()})", getCursor(), method.getCoordinates().replace(), result); } diff --git a/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOrTest.java b/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOrTest.java index 7bdcac353a..799a8b94ff 100644 --- a/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOrTest.java +++ b/src/test/java/org/openrewrite/java/migrate/guava/NoGuavaPredicatesAndOrTest.java @@ -101,18 +101,38 @@ void replacePredicatesAndWithLambdas() { import com.google.common.base.Predicates; class Test { - Predicate isNotNull = s -> s != null; - Predicate isLong = s -> s.length() > 5; - Predicate combined = Predicates.and(isNotNull, isLong); + Predicate combined = Predicates.and(s -> s != null, s -> s.length() > 5); } """, """ import com.google.common.base.Predicate; class Test { - Predicate isNotNull = s -> s != null; - Predicate isLong = s -> s.length() > 5; - Predicate combined = isNotNull.and(isLong); + Predicate combined = ((Predicate) s -> s != null).and(s -> s.length() > 5); + } + """ + ) + ); + } + + @Test + void replacePredicatesOrWithLambdas() { + //language=java + rewriteRun( + java( + """ + import com.google.common.base.Predicate; + import com.google.common.base.Predicates; + + class Test { + Predicate combined = Predicates.or(s -> s != null, s -> s.length() > 5); + } + """, + """ + import com.google.common.base.Predicate; + + class Test { + Predicate combined = ((Predicate) s -> s != null).or(s -> s.length() > 5); } """ )