5 changes: 5 additions & 0 deletions sdks/java/container/Dockerfile
@@ -23,6 +23,11 @@ ADD target/slf4j-api.jar /opt/apache/beam/jars/
ADD target/slf4j-jdk14.jar /opt/apache/beam/jars/
ADD target/beam-sdks-java-harness.jar /opt/apache/beam/jars/

# Required to run cross-language pipelines with KafkaIO
# TODO May be removed once custom environments are supported
ADD target/beam-sdks-java-io-kafka.jar /opt/apache/beam/jars/
ADD target/kafka-clients.jar /opt/apache/beam/jars/

ADD target/linux_amd64/boot /opt/apache/beam/

ENTRYPOINT ["/opt/apache/beam/boot"]
5 changes: 5 additions & 0 deletions sdks/java/container/Dockerfile-java11
@@ -23,6 +23,11 @@ ADD target/slf4j-api.jar /opt/apache/beam/jars/
ADD target/slf4j-jdk14.jar /opt/apache/beam/jars/
ADD target/beam-sdks-java-harness.jar /opt/apache/beam/jars/

# Required to run cross-language pipelines with KafkaIO
# TODO May be removed once custom environments are supported
ADD target/beam-sdks-java-io-kafka.jar /opt/apache/beam/jars/
ADD target/kafka-clients.jar /opt/apache/beam/jars/

ADD target/linux_amd64/boot /opt/apache/beam/

ENTRYPOINT ["/opt/apache/beam/boot"]
2 changes: 2 additions & 0 deletions sdks/java/container/boot.go
@@ -102,6 +102,8 @@ func main() {
filepath.Join(jarsDir, "slf4j-api.jar"),
filepath.Join(jarsDir, "slf4j-jdk14.jar"),
filepath.Join(jarsDir, "beam-sdks-java-harness.jar"),
filepath.Join(jarsDir, "beam-sdks-java-io-kafka.jar"),
filepath.Join(jarsDir, "kafka-clients.jar"),
}

var hasWorkerExperiment = strings.Contains(options, "use_staged_dataflow_worker_jar")
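These jars have to be on the SDK harness classpath because KafkaIO's external builders resolve serializer and deserializer classes by name (see resolveClass further down in this diff). A minimal sketch of that reflective lookup — the class name is only an example — assuming kafka-clients.jar is staged as above:

import org.apache.kafka.common.serialization.Serializer;

public class ResolveSerializerSketch {
  public static void main(String[] args) throws Exception {
    // Throws ClassNotFoundException unless kafka-clients.jar is on the classpath,
    // which is what the Dockerfile and boot.go changes provide inside the container.
    Class<?> clazz = Class.forName("org.apache.kafka.common.serialization.LongSerializer");
    System.out.println(Serializer.class.isAssignableFrom(clazz)); // prints "true"
  }
}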
4 changes: 4 additions & 0 deletions sdks/java/container/build.gradle
@@ -41,6 +41,8 @@ dependencies {
dockerDependency library.java.slf4j_api
dockerDependency library.java.slf4j_jdk14
dockerDependency project(path: ":beam-sdks-java-harness", configuration: "shadow")
// For executing KafkaIO, e.g. as an external transform
dockerDependency project(path: ":beam-sdks-java-io-kafka", configuration: "shadow")
}

def dockerfileName = project.findProperty('dockerfile') ?: 'Dockerfile'
@@ -50,6 +52,8 @@ task copyDockerfileDependencies(type: Copy) {
rename "slf4j-api.*", "slf4j-api.jar"
rename "slf4j-jdk14.*", "slf4j-jdk14.jar"
rename 'beam-sdks-java-harness-.*.jar', 'beam-sdks-java-harness.jar'
rename 'beam-sdks-java-io-kafka.*.jar', 'beam-sdks-java-io-kafka.jar'
rename 'kafka-clients.*.jar', 'kafka-clients.jar'
into "build/target"
}

@@ -458,18 +458,6 @@ private static Coder resolveCoder(Class deserializer) {
}
throw new RuntimeException("Couldn't resolve coder for Deserializer: " + deserializer);
}

private static Class resolveClass(String className) {
try {
return Class.forName(className);
} catch (ClassNotFoundException e) {
throw new RuntimeException("Could not find deserializer class: " + className);
}
}

private static String utf8String(byte[] bytes) {
return new String(bytes, Charsets.UTF_8);
}
}

/**
@@ -486,7 +474,7 @@ public Map<String, Class<? extends ExternalTransformBuilder>> knownBuilders() {
return ImmutableMap.of(URN, AutoValue_KafkaIO_Read.Builder.class);
}

/** Parameters class to expose the transform to an external SDK. */
/** Parameters class to expose the Read transform to an external SDK. */
public static class Configuration {

// All byte arrays are UTF-8 encoded strings
@@ -1325,12 +1313,77 @@ public abstract static class Write<K, V> extends PTransform<PCollection<KV<K, V>>, PDone> {
abstract Builder<K, V> toBuilder();

@AutoValue.Builder
abstract static class Builder<K, V> {
abstract static class Builder<K, V>
implements ExternalTransformBuilder<External.Configuration, PCollection<KV<K, V>>, PDone> {
abstract Builder<K, V> setTopic(String topic);

abstract Builder<K, V> setWriteRecordsTransform(WriteRecords<K, V> transform);

abstract Write<K, V> build();

@Override
public PTransform<PCollection<KV<K, V>>, PDone> buildExternal(
External.Configuration configuration) {
String topic = utf8String(configuration.topic);
setTopic(topic);

Map<String, Object> producerConfig = new HashMap<>();
for (KV<byte[], byte[]> kv : configuration.producerConfig) {
String key = utf8String(kv.getKey());
String value = utf8String(kv.getValue());
producerConfig.put(key, value);
}
Class keySerializer = resolveClass(utf8String(configuration.keySerializer));
Class valSerializer = resolveClass(utf8String(configuration.valueSerializer));

WriteRecords<K, V> writeRecords =
KafkaIO.<K, V>writeRecords()
.updateProducerProperties(producerConfig)
.withKeySerializer(keySerializer)
.withValueSerializer(valSerializer)
.withTopic(topic);
setWriteRecordsTransform(writeRecords);

return build();
}
}

/** Exposes {@link KafkaIO.Write} as an external transform for cross-language usage. */
@AutoService(ExternalTransformRegistrar.class)
public static class External implements ExternalTransformRegistrar {

public static final String URN = "beam:external:java:kafka:write:v1";

@Override
public Map<String, Class<? extends ExternalTransformBuilder>> knownBuilders() {
return ImmutableMap.of(URN, AutoValue_KafkaIO_Write.Builder.class);
}

/** Parameters class to expose the Write transform to an external SDK. */
public static class Configuration {

// All byte arrays are UTF-8 encoded strings
private Iterable<KV<byte[], byte[]>> producerConfig;
private byte[] topic;
private byte[] keySerializer;
private byte[] valueSerializer;

public void setProducerConfig(Iterable<KV<byte[], byte[]>> producerConfig) {
this.producerConfig = producerConfig;
}

public void setTopic(byte[] topic) {
this.topic = topic;
}

public void setKeySerializer(byte[] keySerializer) {
this.keySerializer = keySerializer;
}

public void setValueSerializer(byte[] valueSerializer) {
this.valueSerializer = valueSerializer;
}
}
}

/** Used mostly to reduce the boilerplate of wrapping {@link WriteRecords} methods. */
@@ -1580,4 +1633,16 @@ static <T> NullableCoder<T> inferCoder(
throw new RuntimeException(
String.format("Could not extract the Kafka Deserializer type from %s", deserializer));
}

private static String utf8String(byte[] bytes) {
return new String(bytes, Charsets.UTF_8);
}

private static Class resolveClass(String className) {
try {
return Class.forName(className);
} catch (ClassNotFoundException e) {
throw new RuntimeException("Could not find class: " + className);
}
}
}
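For comparison, the WriteRecords transform that buildExternal assembles above corresponds to the usual direct Java configuration of the Kafka sink. A minimal sketch — broker address, topic, and key/value types are hypothetical; the producer_config map from the external payload would be applied the same way buildExternal does it, via updateProducerProperties on WriteRecords:

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.KV;
import org.apache.kafka.common.serialization.LongSerializer;
import org.apache.kafka.common.serialization.StringSerializer;

public class KafkaWriteSketch {
  public static void main(String[] args) {
    Pipeline p = Pipeline.create();

    p.apply(Create.of(KV.of("key", 1L))
            .withCoder(KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of())))
        // Equivalent of the external write: topic, key_serializer and value_serializer
        // mirror the Configuration payload; KafkaIO.Write wraps WriteRecords internally.
        .apply(
            KafkaIO.<String, Long>write()
                .withBootstrapServers("localhost:9092") // hypothetical broker
                .withTopic("topic")
                .withKeySerializer(StringSerializer.class)
                .withValueSerializer(LongSerializer.class));

    p.run().waitUntilFinish();
  }
}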
@@ -27,29 +27,38 @@
import org.apache.beam.model.expansion.v1.ExpansionApi;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.ReadTranslation;
import org.apache.beam.runners.core.construction.expansion.ExpansionService;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.WithKeys;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p13p1.io.grpc.stub.StreamObserver;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Charsets;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.hamcrest.Matchers;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.internal.util.reflection.Whitebox;

/** Tests for building {@link KafkaIO} externally via the ExpansionService. */
@RunWith(JUnit4.class)
public class KafkaIOExternalTest {
@Test
public void testConstructKafkaIO() throws Exception {
public void testConstructKafkaRead() throws Exception {
List<String> topics = ImmutableList.of("topic1", "topic2");
String keyDeserializer = "org.apache.kafka.common.serialization.ByteArrayDeserializer";
String valueDeserializer = "org.apache.kafka.common.serialization.LongDeserializer";
@@ -136,10 +145,112 @@ public void testConstructKafkaIO() throws Exception {
assertThat(spec.getValueDeserializer().getName(), Matchers.is(valueDeserializer));
}

@Test
public void testConstructKafkaWrite() throws Exception {
String topic = "topic";
String keySerializer = "org.apache.kafka.common.serialization.ByteArraySerializer";
String valueSerializer = "org.apache.kafka.common.serialization.LongSerializer";
ImmutableMap<String, String> producerConfig =
ImmutableMap.<String, String>builder()
.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "server1:port,server2:port")
.put("retries", "3")
.build();

ExternalTransforms.ExternalConfigurationPayload payload =
ExternalTransforms.ExternalConfigurationPayload.newBuilder()
.putConfiguration(
"topic",
ExternalTransforms.ConfigValue.newBuilder()
.addCoderUrn("beam:coder:bytes:v1")
.setPayload(ByteString.copyFrom(encodeString(topic)))
.build())
.putConfiguration(
"producer_config",
ExternalTransforms.ConfigValue.newBuilder()
.addCoderUrn("beam:coder:iterable:v1")
.addCoderUrn("beam:coder:kv:v1")
.addCoderUrn("beam:coder:bytes:v1")
.addCoderUrn("beam:coder:bytes:v1")
.setPayload(ByteString.copyFrom(mapAsBytes(producerConfig)))
.build())
.putConfiguration(
"key_serializer",
ExternalTransforms.ConfigValue.newBuilder()
.addCoderUrn("beam:coder:bytes:v1")
.setPayload(ByteString.copyFrom(encodeString(keySerializer)))
.build())
.putConfiguration(
"value_serializer",
ExternalTransforms.ConfigValue.newBuilder()
.addCoderUrn("beam:coder:bytes:v1")
.setPayload(ByteString.copyFrom(encodeString(valueSerializer)))
.build())
.build();

Pipeline p = Pipeline.create();
p.apply(Impulse.create()).apply(WithKeys.of("key"));
RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p);
String inputPCollection =
Iterables.getOnlyElement(
Iterables.getLast(pipelineProto.getComponents().getTransformsMap().values())
.getOutputsMap()
.values());

ExpansionApi.ExpansionRequest request =
ExpansionApi.ExpansionRequest.newBuilder()
.setComponents(pipelineProto.getComponents())
.setTransform(
RunnerApi.PTransform.newBuilder()
.setUniqueName("test")
.putInputs("input", inputPCollection)
.setSpec(
RunnerApi.FunctionSpec.newBuilder()
.setUrn("beam:external:java:kafka:write:v1")
.setPayload(payload.toByteString())))
.setNamespace("test_namespace")
.build();

ExpansionService expansionService = new ExpansionService();
TestStreamObserver<ExpansionApi.ExpansionResponse> observer = new TestStreamObserver<>();
expansionService.expand(request, observer);

ExpansionApi.ExpansionResponse result = observer.result;
RunnerApi.PTransform transform = result.getTransform();
assertThat(
transform.getSubtransformsList(),
Matchers.contains(
"test_namespacetest/Kafka ProducerRecord", "test_namespacetest/KafkaIO.WriteRecords"));
assertThat(transform.getInputsCount(), Matchers.is(1));
assertThat(transform.getOutputsCount(), Matchers.is(0));

RunnerApi.PTransform writeComposite =
result.getComponents().getTransformsOrThrow(transform.getSubtransforms(1));
RunnerApi.PTransform writeParDo =
result
.getComponents()
.getTransformsOrThrow(
result
.getComponents()
.getTransformsOrThrow(writeComposite.getSubtransforms(0))
.getSubtransforms(0));

RunnerApi.ParDoPayload parDoPayload =
RunnerApi.ParDoPayload.parseFrom(writeParDo.getSpec().getPayload());
DoFn kafkaWriter = ParDoTranslation.getDoFn(parDoPayload);
assertThat(kafkaWriter, Matchers.instanceOf(KafkaWriter.class));
KafkaIO.WriteRecords spec =
(KafkaIO.WriteRecords) Whitebox.getInternalState(kafkaWriter, "spec");

assertThat(spec.getProducerConfig(), Matchers.is(producerConfig));
assertThat(spec.getTopic(), Matchers.is(topic));
assertThat(spec.getKeySerializer().getName(), Matchers.is(keySerializer));
assertThat(spec.getValueSerializer().getName(), Matchers.is(valueSerializer));
}

private static byte[] listAsBytes(List<String> stringList) throws IOException {
IterableCoder<byte[]> coder = IterableCoder.of(ByteArrayCoder.of());
List<byte[]> bytesList =
stringList.stream().map(KafkaIOExternalTest::rawBytes).collect(Collectors.toList());
stringList.stream().map(KafkaIOExternalTest::utf8Bytes).collect(Collectors.toList());
ByteArrayOutputStream baos = new ByteArrayOutputStream();
coder.encode(bytesList, baos);
return baos.toByteArray();
@@ -150,7 +261,7 @@ private static byte[] mapAsBytes(Map<String, String> stringMap) throws IOException {
IterableCoder.of(KvCoder.of(ByteArrayCoder.of(), ByteArrayCoder.of()));
List<KV<byte[], byte[]>> bytesList =
stringMap.entrySet().stream()
.map(kv -> KV.of(rawBytes(kv.getKey()), rawBytes(kv.getValue())))
.map(kv -> KV.of(utf8Bytes(kv.getKey()), utf8Bytes(kv.getValue())))
.collect(Collectors.toList());
ByteArrayOutputStream baos = new ByteArrayOutputStream();
coder.encode(bytesList, baos);
@@ -159,11 +270,11 @@ private static byte[] mapAsBytes(Map<String, String> stringMap) throws IOException {

private static byte[] encodeString(String str) throws IOException {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ByteArrayCoder.of().encode(rawBytes(str), baos);
ByteArrayCoder.of().encode(utf8Bytes(str), baos);
return baos.toByteArray();
}

private static byte[] rawBytes(String str) {
private static byte[] utf8Bytes(String str) {
Preconditions.checkNotNull(str, "String must not be null.");
return str.getBytes(Charsets.UTF_8);
}