diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java index 720f1ee8b516..bf7b8d408eed 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/External.java @@ -98,8 +98,12 @@ SingleOutputExpandableTransform of( } @VisibleForTesting - static SingleOutputExpandableTransform of( - String urn, byte[] payload, String endpoint, ExpansionServiceClientFactory clientFactory) { + public static + SingleOutputExpandableTransform of( + String urn, + byte[] payload, + String endpoint, + ExpansionServiceClientFactory clientFactory) { Endpoints.ApiServiceDescriptor apiDesc = Endpoints.ApiServiceDescriptor.newBuilder().setUrl(endpoint).build(); return new SingleOutputExpandableTransform<>( diff --git a/runners/google-cloud-dataflow-java/build.gradle b/runners/google-cloud-dataflow-java/build.gradle index 8bbbfe4eceb7..b2a2be9cdfcc 100644 --- a/runners/google-cloud-dataflow-java/build.gradle +++ b/runners/google-cloud-dataflow-java/build.gradle @@ -114,6 +114,7 @@ dependencies { testImplementation project(path: ":sdks:java:core", configuration: "shadowTest") testImplementation project(path: ":sdks:java:extensions:google-cloud-platform-core", configuration: "testRuntimeMigration") testImplementation project(path: ":runners:core-construction-java", configuration: "testRuntimeMigration") + testImplementation project(path: ":sdks:java:extensions:python", configuration: "testRuntimeMigration") testImplementation library.java.google_cloud_dataflow_java_proto_library_all testImplementation library.java.jackson_dataformat_yaml testImplementation library.java.mockito_core @@ -417,7 +418,6 @@ createCrossLanguageValidatesRunnerTask( "--project=${dataflowProject}", "--region=${dataflowRegion}", "--sdk_harness_container_image_overrides=.*java.*,${dockerJavaImageContainer}:${dockerTag}", - "--experiments=use_runner_v2", // TODO(BEAM-11779) remove shuffle_mode=appliance with runner v2 once issue is resolved "--experiments=shuffle_mode=appliance", ], @@ -429,7 +429,7 @@ createCrossLanguageValidatesRunnerTask( "--sdkContainerImage=${dockerJavaImageContainer}:${dockerTag}", "--sdkHarnessContainerImageOverrides=.*python.*,${dockerPythonImageContainer}:${dockerTag}", // TODO(BEAM-11779) remove shuffle_mode=appliance with runner v2 once issue is resolved. - "--experiments=use_runner_v2,shuffle_mode=appliance", + "--experiments=shuffle_mode=appliance", ], pytestOptions: [ "--capture=no", diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index c0abbaded25e..e4de35e278e7 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -1038,10 +1038,51 @@ private List getDefaultArtifacts() { return Environments.getArtifacts(pathsToStageBuilder.build()); } + @VisibleForTesting + static boolean isMultiLanguagePipeline(Pipeline pipeline) { + class IsMultiLanguageVisitor extends PipelineVisitor.Defaults { + private boolean isMultiLanguage = false; + + private void performMultiLanguageTest(Node node) { + if (node.getTransform() instanceof External.ExpandableTransform) { + isMultiLanguage = true; + } + } + + @Override + public CompositeBehavior enterCompositeTransform(Node node) { + performMultiLanguageTest(node); + return super.enterCompositeTransform(node); + } + + @Override + public void visitPrimitiveTransform(Node node) { + performMultiLanguageTest(node); + super.visitPrimitiveTransform(node); + } + } + + IsMultiLanguageVisitor visitor = new IsMultiLanguageVisitor(); + pipeline.traverseTopologically(visitor); + + return visitor.isMultiLanguage; + } + @Override public DataflowPipelineJob run(Pipeline pipeline) { + if (DataflowRunner.isMultiLanguagePipeline(pipeline)) { + List experiments = firstNonNull(options.getExperiments(), Collections.emptyList()); + if (!experiments.contains("use_runner_v2")) { + LOG.info( + "Automatically enabling Dataflow Runner v2 since the pipeline used cross-language" + + " transforms"); + options.setExperiments( + ImmutableList.builder().addAll(experiments).add("use_runner_v2").build()); + } + } if (useUnifiedWorker(options)) { - List experiments = options.getExperiments(); // non-null if useUnifiedWorker is true + List experiments = + new ArrayList<>(options.getExperiments()); // non-null if useUnifiedWorker is true if (!experiments.contains("use_runner_v2")) { experiments.add("use_runner_v2"); } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index c192f239a61b..afd75bb843cc 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -84,9 +84,14 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Pattern; import java.util.stream.Collectors; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.Endpoints; import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.runners.core.construction.BeamUrns; import org.apache.beam.runners.core.construction.Environments; +import org.apache.beam.runners.core.construction.ExpansionServiceClient; +import org.apache.beam.runners.core.construction.ExpansionServiceClientFactory; +import org.apache.beam.runners.core.construction.External; import org.apache.beam.runners.core.construction.PipelineTranslation; import org.apache.beam.runners.core.construction.SdkComponents; import org.apache.beam.runners.dataflow.DataflowRunner.StreamingShardedWriteFactory; @@ -2225,6 +2230,65 @@ public void testStreamingGroupIntoBatchesWithShardedKeyOverrideBytes() throws IO verifyGroupIntoBatchesOverrideBytes(p, true, true); } + static class TestExpansionServiceClientFactory implements ExpansionServiceClientFactory { + ExpansionApi.ExpansionResponse response; + + @Override + public ExpansionServiceClient getExpansionServiceClient( + Endpoints.ApiServiceDescriptor endpoint) { + return new ExpansionServiceClient() { + @Override + public ExpansionApi.ExpansionResponse expand(ExpansionApi.ExpansionRequest request) { + Pipeline p = TestPipeline.create(); + p.apply(Create.of(1, 2, 3)); + SdkComponents sdkComponents = + SdkComponents.create(p.getOptions()).withNewIdPrefix(request.getNamespace()); + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p, sdkComponents); + String transformId = Iterables.getOnlyElement(pipelineProto.getRootTransformIdsList()); + RunnerApi.Components components = pipelineProto.getComponents(); + ImmutableList.Builder requirementsBuilder = ImmutableList.builder(); + requirementsBuilder.addAll(pipelineProto.getRequirementsList()); + requirementsBuilder.add("ExternalTranslationTest_Requirement_URN"); + response = + ExpansionApi.ExpansionResponse.newBuilder() + .setComponents(components) + .setTransform( + components + .getTransformsOrThrow(transformId) + .toBuilder() + .setUniqueName(transformId)) + .addAllRequirements(requirementsBuilder.build()) + .build(); + return response; + } + + @Override + public void close() throws Exception { + // do nothing + } + }; + } + + @Override + public void close() throws Exception { + // do nothing + } + } + + @Test + public void testIsMultiLanguage() throws IOException { + PipelineOptions options = buildPipelineOptions(); + Pipeline pipeline = Pipeline.create(options); + PCollection col = + pipeline + .apply(Create.of("1", "2", "3")) + .apply( + External.of( + "dummy_urn", new byte[] {}, "", new TestExpansionServiceClientFactory())); + + assertTrue(DataflowRunner.isMultiLanguagePipeline(pipeline)); + } + private void testStreamingWriteOverride(PipelineOptions options, int expectedNumShards) { TestPipeline p = TestPipeline.fromOptions(options);