diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java index e77b89f9a4ec..48cff6df9302 100644 --- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java +++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactory.java @@ -16,6 +16,8 @@ package com.google.cloud.dataflow.sdk.options; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.cloud.dataflow.sdk.options.Validation.Required; import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.runners.PipelineRunnerRegistrar; @@ -1391,7 +1393,10 @@ private static ListMultimap parseCommandLine( * split up each string on ','. * *

We special case the "runner" option. It is mapped to the class of the {@link PipelineRunner} - * based off of the {@link PipelineRunner}s simple class name or fully qualified class name. + * based off of the {@link PipelineRunner PipelineRunners} simple class name. If the provided + * runner name is not registered via a {@link PipelineRunnerRegistrar}, we attempt to obtain the + * class that the name represents using {@link Class#forName(String)} and use the result class if + * it subclasses {@link PipelineRunner}. * *

If strict parsing is enabled, unknown options or options that cannot be converted to * the expected java type using an {@link ObjectMapper} will be ignored. @@ -1442,10 +1447,26 @@ public boolean apply(@Nullable String input) { JavaType type = MAPPER.getTypeFactory().constructType(method.getGenericReturnType()); if ("runner".equals(entry.getKey())) { String runner = Iterables.getOnlyElement(entry.getValue()); - Preconditions.checkArgument(SUPPORTED_PIPELINE_RUNNERS.containsKey(runner), - "Unknown 'runner' specified '%s', supported pipeline runners %s", - runner, Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); - convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + if (SUPPORTED_PIPELINE_RUNNERS.containsKey(runner)) { + convertedOptions.put("runner", SUPPORTED_PIPELINE_RUNNERS.get(runner)); + } else { + try { + Class runnerClass = Class.forName(runner); + checkArgument( + PipelineRunner.class.isAssignableFrom(runnerClass), + "Class '%s' does not implement PipelineRunner. Supported pipeline runners %s", + runner, + Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + convertedOptions.put("runner", runnerClass); + } catch (ClassNotFoundException e) { + String msg = + String.format( + "Unknown 'runner' specified '%s', supported pipeline runners %s", + runner, + Sets.newTreeSet(SUPPORTED_PIPELINE_RUNNERS.keySet())); + throw new IllegalArgumentException(msg, e); + } + } } else if ((returnType.isArray() && (SIMPLE_TYPES.contains(returnType.getComponentType()) || returnType.getComponentType().isEnum())) || Collection.class.isAssignableFrom(returnType)) { diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java index e687f2798946..045a8ad0f257 100644 --- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java +++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/options/PipelineOptionsFactoryTest.java @@ -25,8 +25,12 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; +import com.google.cloud.dataflow.sdk.Pipeline; +import com.google.cloud.dataflow.sdk.PipelineResult; import com.google.cloud.dataflow.sdk.runners.BlockingDataflowPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.DataflowPipelineRunner; import com.google.cloud.dataflow.sdk.runners.DirectPipelineRunner; +import com.google.cloud.dataflow.sdk.runners.PipelineRunner; import com.google.cloud.dataflow.sdk.testing.ExpectedLogs; import com.google.cloud.dataflow.sdk.testing.RestoreSystemProperties; import com.google.common.collect.ArrayListMultimap; @@ -824,6 +828,14 @@ public void testSettingRunner() { assertEquals(BlockingDataflowPipelineRunner.class, options.getRunner()); } + @Test + public void testSettingRunnerFullName() { + String[] args = + new String[] {String.format("--runner=%s", DataflowPipelineRunner.class.getName())}; + PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(opts.getRunner(), DataflowPipelineRunner.class); + } + @Test public void testSettingUnknownRunner() { String[] args = new String[] {"--runner=UnknownRunner"}; @@ -834,6 +846,30 @@ public void testSettingUnknownRunner() { PipelineOptionsFactory.fromArgs(args).create(); } + private static class ExampleTestRunner extends PipelineRunner { + @Override + public PipelineResult run(Pipeline pipeline) { + return null; + } + } + + @Test + public void testSettingRunnerCanonicalClassNameNotInSupportedExists() { + String[] args = new String[] {String.format("--runner=%s", ExampleTestRunner.class.getName())}; + PipelineOptions opts = PipelineOptionsFactory.fromArgs(args).create(); + assertEquals(opts.getRunner(), ExampleTestRunner.class); + } + + @Test + public void testSettingRunnerCanonicalClassNameNotInSupportedNotPipelineRunner() { + String[] args = new String[] {"--runner=java.lang.String"}; + expectedException.expect(IllegalArgumentException.class); + expectedException.expectMessage("does not implement PipelineRunner"); + expectedException.expectMessage("java.lang.String"); + + PipelineOptionsFactory.fromArgs(args).create(); + } + @Test public void testUsingArgumentWithUnknownPropertyIsNotAllowed() { String[] args = new String[] {"--unknownProperty=value"};