Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdks/java/io/jms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies {
testCompile library.java.activemq_kahadb_store
testCompile library.java.activemq_client
testCompile library.java.junit
testCompile library.java.mockito_core
testRuntimeOnly library.java.slf4j_jdk14
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.jms;

import java.io.Serializable;
import org.apache.beam.sdk.io.UnboundedSource;

/**
* Enables users to specify their own `JMS` backlog reporters enabling {@link JmsIO} to report
* {@link UnboundedSource.UnboundedReader#getTotalBacklogBytes()}.
*/
public interface AutoScaler extends Serializable {

/** The {@link AutoScaler} is started when the {@link JmsIO.UnboundedJmsReader} is started. */
void start();

/**
* Returns the size of the backlog of unread data in the underlying data source represented by all
* splits of this source.
*/
long getTotalBacklogBytes();

/** The {@link AutoScaler} is stopped when the {@link JmsIO.UnboundedJmsReader} is closed. */
void stop();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.jms;

import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;

/**
* Default implementation of {@link AutoScaler}. Returns {@link
* org.apache.beam.sdk.io.UnboundedSource.UnboundedReader#BACKLOG_UNKNOWN} as the default value.
*/
public class DefaultAutoscaler implements AutoScaler {
@Override
public void start() {}

@Override
public long getTotalBacklogBytes() {
return BACKLOG_UNKNOWN;
}

@Override
public void stop() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>>

abstract @Nullable Coder<T> getCoder();

abstract @Nullable AutoScaler getAutoScaler();

abstract Builder<T> builder();

@AutoValue.Builder
Expand All @@ -218,6 +220,8 @@ abstract static class Builder<T> {

abstract Builder<T> setCoder(Coder<T> coder);

abstract Builder<T> setAutoScaler(AutoScaler autoScaler);

abstract Read<T> build();
}

Expand Down Expand Up @@ -344,6 +348,14 @@ public Read<T> withCoder(Coder<T> coder) {
return builder().setCoder(coder).build();
}

/**
* Sets the {@link AutoScaler} to use for reporting backlog during the execution of this source.
*/
public Read<T> withAutoScaler(AutoScaler autoScaler) {
checkArgument(autoScaler != null, "autoScaler can not be null");
return builder().setAutoScaler(autoScaler).build();
}

@Override
public PCollection<T> expand(PBegin input) {
checkArgument(getConnectionFactory() != null, "withConnectionFactory() is required");
Expand Down Expand Up @@ -447,6 +459,7 @@ static class UnboundedJmsReader<T> extends UnboundedReader<T> {
private Connection connection;
private Session session;
private MessageConsumer consumer;
private AutoScaler autoScaler;

private T currentMessage;
private Instant currentTimestamp;
Expand Down Expand Up @@ -474,6 +487,12 @@ public boolean start() throws IOException {
}
connection.start();
this.connection = connection;
if (spec.getAutoScaler() == null) {
this.autoScaler = new DefaultAutoscaler();
} else {
this.autoScaler = spec.getAutoScaler();
}
this.autoScaler.start();
} catch (Exception e) {
throw new IOException("Error connecting to JMS", e);
}
Expand Down Expand Up @@ -544,6 +563,11 @@ public CheckpointMark getCheckpointMark() {
return checkpointMark;
}

@Override
public long getTotalBacklogBytes() {
return this.autoScaler.getTotalBacklogBytes();
}

@Override
public UnboundedSource<T, ?> getCurrentSource() {
return source;
Expand All @@ -565,6 +589,10 @@ public void close() throws IOException {
connection.close();
connection = null;
}
if (autoScaler != null) {
autoScaler.stop();
autoScaler = null;
}
} catch (Exception e) {
throw new IOException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,17 @@
*/
package org.apache.beam.sdk.io.jms;

import static org.apache.beam.sdk.io.UnboundedSource.UnboundedReader.BACKLOG_UNKNOWN;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import java.io.IOException;
import java.lang.reflect.Proxy;
Expand Down Expand Up @@ -421,6 +426,50 @@ public void testCheckpointMarkDefaultCoder() throws Exception {
CoderProperties.coderDecodeEncodeEqual(coder, jmsCheckpointMark);
}

@Test
public void testDefaultAutoscaler() throws IOException {
JmsIO.Read spec =
JmsIO.read()
.withConnectionFactory(connectionFactory)
.withUsername(USERNAME)
.withPassword(PASSWORD)
.withQueue(QUEUE);
JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
JmsIO.UnboundedJmsReader reader = source.createReader(null, null);

// start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
reader.start();
assertEquals(BACKLOG_UNKNOWN, reader.getSplitBacklogBytes());
assertEquals(BACKLOG_UNKNOWN, reader.getTotalBacklogBytes());
reader.close();
}

@Test
public void testCustomAutoscaler() throws IOException {
long excpectedTotalBacklogBytes = 1111L;

AutoScaler autoScaler = mock(DefaultAutoscaler.class);
when(autoScaler.getTotalBacklogBytes()).thenReturn(excpectedTotalBacklogBytes);
JmsIO.Read spec =
JmsIO.read()
.withConnectionFactory(connectionFactory)
.withUsername(USERNAME)
.withPassword(PASSWORD)
.withQueue(QUEUE)
.withAutoScaler(autoScaler);

JmsIO.UnboundedJmsSource source = new JmsIO.UnboundedJmsSource(spec);
JmsIO.UnboundedJmsReader reader = source.createReader(null, null);

// start the reader and check getSplitBacklogBytes and getTotalBacklogBytes values
reader.start();
verify(autoScaler, times(1)).start();
assertEquals(excpectedTotalBacklogBytes, reader.getTotalBacklogBytes());
verify(autoScaler, times(1)).getTotalBacklogBytes();
reader.close();
verify(autoScaler, times(1)).stop();
}

private int count(String queue) throws Exception {
Connection connection = connectionFactory.createConnection(USERNAME, PASSWORD);
connection.start();
Expand Down