diff --git a/sdks/java/io/jms/build.gradle b/sdks/java/io/jms/build.gradle index 69cd88f81d1f..3f9d0c719ec6 100644 --- a/sdks/java/io/jms/build.gradle +++ b/sdks/java/io/jms/build.gradle @@ -36,6 +36,7 @@ dependencies { testCompile library.java.activemq_kahadb_store testCompile library.java.activemq_client testCompile library.java.junit + testCompile library.java.mockito_core testRuntimeOnly library.java.slf4j_jdk14 testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") } diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java new file mode 100644 index 000000000000..0e023d1aae09 --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/AutoScaler.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms; + +import java.io.Serializable; +import org.apache.beam.sdk.io.UnboundedSource; + +/** + * Enables users to specify their own `JMS` backlog reporters enabling {@link JmsIO} to report + * {@link UnboundedSource.UnboundedReader#getTotalBacklogBytes()}. + */ +public interface AutoScaler extends Serializable { + + /** The {@link AutoScaler} is started when the {@link JmsIO.UnboundedJmsReader} is started. */ + void start(); + + /** + * Returns the size of the backlog of unread data in the underlying data source represented by all + * splits of this source. + */ + long getTotalBacklogBytes(); + + /** The {@link AutoScaler} is stopped when the {@link JmsIO.UnboundedJmsReader} is closed. */ + void stop(); +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java new file mode 100644 index 000000000000..2b05cf630bf8 --- /dev/null +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/DefaultAutoscaler.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.jms; + +import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN; + +/** + * Default implementation of {@link AutoScaler}. Returns {@link + * org.apache.beam.sdk.io.UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} as the default value. + */ +public class DefaultAutoscaler implements AutoScaler { + @Override + public void start() {} + + @Override + public long getTotalBacklogBytes() { + return BACKLOG_UNKNOWN; + } + + @Override + public void stop() {} +} diff --git a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java index 4999e10ee9e8..9fa4492cf235 100644 --- a/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java +++ b/sdks/java/io/jms/src/main/java/org/apache/beam/sdk/io/jms/JmsIO.java @@ -196,6 +196,8 @@ public abstract static class Read extends PTransform> abstract @Nullable Coder getCoder(); + abstract @Nullable AutoScaler getAutoScaler(); + abstract Builder builder(); @AutoValue.Builder @@ -218,6 +220,8 @@ abstract static class Builder { abstract Builder setCoder(Coder coder); + abstract Builder setAutoScaler(AutoScaler autoScaler); + abstract Read build(); } @@ -344,6 +348,14 @@ public Read withCoder(Coder coder) { return builder().setCoder(coder).build(); } + /** + * Sets the {@link AutoScaler} to use for reporting backlog during the execution of this source. + */ + public Read withAutoScaler(AutoScaler autoScaler) { + checkArgument(autoScaler != null, "autoScaler can not be null"); + return builder().setAutoScaler(autoScaler).build(); + } + @Override public PCollection expand(PBegin input) { checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required"); @@ -447,6 +459,7 @@ static class UnboundedJmsReader extends UnboundedReader { private Connection connection; private Session session; private MessageConsumer consumer; + private AutoScaler autoScaler; private T currentMessage; private Instant currentTimestamp; @@ -474,6 +487,12 @@ public boolean start() throws IOException { } connection.start(); this.connection = connection; + if (spec.getAutoScaler() == null) { + this.autoScaler = new DefaultAutoscaler(); + } else { + this.autoScaler = spec.getAutoScaler(); + } + this.autoScaler.start(); } catch (Exception e) { throw new IOException("Error connecting to JMS", e); } @@ -544,6 +563,11 @@ public CheckpointMark getCheckpointMark() { return checkpointMark; } + @Override + public long getTotalBacklogBytes() { + return this.autoScaler.getTotalBacklogBytes(); + } + @Override public UnboundedSource getCurrentSource() { return source; @@ -565,6 +589,10 @@ public void close() throws IOException { connection.close(); connection = null; } + if (autoScaler != null) { + autoScaler.stop(); + autoScaler = null; + } } catch (Exception e) { throw new IOException(e); } diff --git a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java index c335f8ab11de..a9f3c3f004ef 100644 --- a/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java +++ b/sdks/java/io/jms/src/test/java/org/apache/beam/sdk/io/jms/JmsIOTest.java @@ -17,12 +17,17 @@ */ package org.apache.beam.sdk.io.jms; +import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.IOException; import java.lang.reflect.Proxy; @@ -421,6 +426,50 @@ public void testCheckpointMarkDefaultCoder() throws Exception { CoderProperties.coderDecodeEncodeEqual(coder, jmsCheckpointMark); } + @Test + public void testDefaultAutoscaler() throws IOException { + JmsIO.Read spec = + JmsIO.read() + .withConnectionFactory(connectionFactory) + .withUsername(USERNAME) + .withPassword(PASSWORD) + .withQueue(QUEUE); + JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec); + JmsIO.UnboundedJmsReader reader = source.createReader(null, null); + + // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values + reader.start(); + assertEquals(BACKLOG_UNKNOWN, reader.getSplitBacklogBytes()); + assertEquals(BACKLOG_UNKNOWN, reader.getTotalBacklogBytes()); + reader.close(); + } + + @Test + public void testCustomAutoscaler() throws IOException { + long excpectedTotalBacklogBytes = 1111L; + + AutoScaler autoScaler = mock(DefaultAutoscaler.class); + when(autoScaler.getTotalBacklogBytes()).thenReturn(excpectedTotalBacklogBytes); + JmsIO.Read spec = + JmsIO.read() + .withConnectionFactory(connectionFactory) + .withUsername(USERNAME) + .withPassword(PASSWORD) + .withQueue(QUEUE) + .withAutoScaler(autoScaler); + + JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec); + JmsIO.UnboundedJmsReader reader = source.createReader(null, null); + + // start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values + reader.start(); + verify(autoScaler, times(1)).start(); + assertEquals(excpectedTotalBacklogBytes, reader.getTotalBacklogBytes()); + verify(autoScaler, times(1)).getTotalBacklogBytes(); + reader.close(); + verify(autoScaler, times(1)).stop(); + } + private int count(String queue) throws Exception { Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD); connection.start();