diff --git a/speech/grpc/pom.xml b/speech/grpc/pom.xml
index 71d4536f246..ca857bb323b 100644
--- a/speech/grpc/pom.xml
+++ b/speech/grpc/pom.xml
@@ -156,6 +156,12 @@ limitations under the License.
       <version>0.31</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <version>1.10.19</version>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>io.grpc</groupId>
       <artifactId>grpc-auth</artifactId>
diff --git a/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java b/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java
index 03d029ec1fe..e7f17e1d29a 100644
--- a/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java
+++ b/speech/grpc/src/main/java/com/examples/cloud/speech/StreamingRecognizeClient.java
@@ -23,10 +23,10 @@
import com.google.cloud.speech.v1beta1.RecognitionConfig.AudioEncoding;
import com.google.cloud.speech.v1beta1.SpeechGrpc;
import com.google.cloud.speech.v1beta1.StreamingRecognitionConfig;
+import com.google.cloud.speech.v1beta1.StreamingRecognitionResult;
import com.google.cloud.speech.v1beta1.StreamingRecognizeRequest;
import com.google.cloud.speech.v1beta1.StreamingRecognizeResponse;
import com.google.protobuf.ByteString;
-import com.google.protobuf.TextFormat;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
@@ -44,14 +44,17 @@
import org.apache.log4j.Logger;
import org.apache.log4j.SimpleLayout;
-import java.io.File;
-import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
+import javax.sound.sampled.AudioFormat;
+import javax.sound.sampled.AudioSystem;
+import javax.sound.sampled.DataLine;
+import javax.sound.sampled.LineUnavailableException;
+import javax.sound.sampled.TargetDataLine;
/**
@@ -59,37 +62,36 @@
*/
public class StreamingRecognizeClient {
- private final String file;
- private final int samplingRate;
-
private static final Logger logger = Logger.getLogger(StreamingRecognizeClient.class.getName());
private final ManagedChannel channel;
-
private final SpeechGrpc.SpeechStub speechClient;
-
- private static final int BYTES_PER_BUFFER = 3200; //buffer size in bytes
- private static final int BYTES_PER_SAMPLE = 2; //bytes per sample for LINEAR16
-
private static final List<String> OAUTH2_SCOPES =
Arrays.asList("https://www.googleapis.com/auth/cloud-platform");
+ static final int BYTES_PER_SAMPLE = 2; // bytes per sample for LINEAR16
+
+ private final int samplingRate;
+ final int bytesPerBuffer; // buffer size in bytes
+
+ // Used for testing
+ protected TargetDataLine mockDataLine = null;
+
/**
* Construct client connecting to Cloud Speech server at {@code host:port}.
*/
- public StreamingRecognizeClient(ManagedChannel channel, String file, int samplingRate)
+ public StreamingRecognizeClient(ManagedChannel channel, int samplingRate)
throws IOException {
- this.file = file;
this.samplingRate = samplingRate;
this.channel = channel;
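+ // At 16 kHz this works out to 16000 samples/s * 2 bytes/sample / 10 = 3200 bytes per buffer.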
+ this.bytesPerBuffer = samplingRate * BYTES_PER_SAMPLE / 10; // 100 ms
speechClient = SpeechGrpc.newStub(channel);
// Send log4j logs to Console
// If you are going to run this on GCE, you might wish to integrate with
- // google-cloud-java logging. See:
+ // google-cloud-java logging. See:
// https://github.com/GoogleCloudPlatform/google-cloud-java/blob/master/README.md#stackdriver-logging-alpha
-
ConsoleAppender appender = new ConsoleAppender(new SimpleLayout(), SYSTEM_OUT);
logger.addAppender(appender);
}
@@ -109,19 +111,73 @@ static ManagedChannel createChannel(String host, int port) throws IOException {
return channel;
}
+ /**
+ * Return a Line to the audio input device.
+ */
+ private TargetDataLine getAudioInputLine() {
+ // For testing
+ if (null != mockDataLine) {
+ return mockDataLine;
+ }
+
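+ // 16-bit signed little-endian mono PCM, i.e. the LINEAR16 encoding the API expects.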
+ AudioFormat format = new AudioFormat(samplingRate, BYTES_PER_SAMPLE * 8, 1, true, false);
+ DataLine.Info info = new DataLine.Info(TargetDataLine.class, format);
+ if (!AudioSystem.isLineSupported(info)) {
+ throw new RuntimeException(String.format(
+ "Device doesn't support LINEAR16 mono raw audio format at %dHz", samplingRate));
+ }
+ try {
+ TargetDataLine line = (TargetDataLine) AudioSystem.getLine(info);
+ // Make sure the line buffer doesn't overflow while we're filling this thread's buffer.
+ line.open(format, bytesPerBuffer * 5);
+ return line;
+ } catch (LineUnavailableException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
/** Send streaming recognize requests to server. */
public void recognize() throws InterruptedException, IOException {
final CountDownLatch finishLatch = new CountDownLatch(1);
StreamObserver<StreamingRecognizeResponse> responseObserver =
new StreamObserver<StreamingRecognizeResponse>() {
+ private int sentenceLength = 1;
+ /**
+ * Prints the transcription results. Interim results are overwritten by subsequent
+ * results, until a final one is returned, at which point we start a new line.
+ *
+ * Flags the program to exit when it hears "exit".
+ */
@Override
public void onNext(StreamingRecognizeResponse response) {
- logger.info("Received response: " + TextFormat.printToString(response));
+ List<StreamingRecognitionResult> results = response.getResultsList();
+ if (results.size() < 1) {
+ return;
+ }
+
+ StreamingRecognitionResult result = results.get(0);
+ String transcript = result.getAlternatives(0).getTranscript();
+
+ // Print interim results with a trailing carriage return, so the next transcription
+ // overwrites it. A final result prints a newline instead, preserving the line.
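+ // %-<width>s left-justifies and pads with spaces, blanking out leftover characters
+ // from a longer previous interim result.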
+ String format = "%-" + this.sentenceLength + 's';
+ if (result.getIsFinal()) {
+ format += '\n';
+ this.sentenceLength = 1;
+
+ if (transcript.toLowerCase().indexOf("exit") >= 0) {
+ finishLatch.countDown();
+ }
+ } else {
+ format += '\r';
+ this.sentenceLength = transcript.length();
+ }
+ System.out.print(String.format(format, transcript));
}
@Override
public void onError(Throwable error) {
- logger.log(Level.WARN, "recognize failed: {0}", error);
+ logger.log(Level.ERROR, "recognize failed", error);
finishLatch.countDown();
}
@@ -146,33 +202,28 @@ public void onCompleted() {
StreamingRecognitionConfig.newBuilder()
.setConfig(config)
.setInterimResults(true)
- .setSingleUtterance(true)
+ .setSingleUtterance(false)
.build();
StreamingRecognizeRequest initial =
StreamingRecognizeRequest.newBuilder().setStreamingConfig(streamingConfig).build();
requestObserver.onNext(initial);
- // Open audio file. Read and send sequential buffers of audio as additional RecognizeRequests.
- FileInputStream in = new FileInputStream(new File(file));
- // For LINEAR16 at 16000 Hz sample rate, 3200 bytes corresponds to 100 milliseconds of audio.
- byte[] buffer = new byte[BYTES_PER_BUFFER];
+ // Get a Line to the audio input device.
+ TargetDataLine in = getAudioInputLine();
+ byte[] buffer = new byte[bytesPerBuffer];
int bytesRead;
- int totalBytes = 0;
- int samplesPerBuffer = BYTES_PER_BUFFER / BYTES_PER_SAMPLE;
- int samplesPerMillis = samplingRate / 1000;
- while ((bytesRead = in.read(buffer)) != -1) {
- totalBytes += bytesRead;
+ in.start();
+ // Read and send sequential buffers of audio as additional RecognizeRequests.
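+ // Stop when the latch is tripped (an "exit" transcript, or an error) or the
+ // audio source is exhausted.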
+ while (finishLatch.getCount() > 0
+ && (bytesRead = in.read(buffer, 0, buffer.length)) != -1) {
StreamingRecognizeRequest request =
StreamingRecognizeRequest.newBuilder()
.setAudioContent(ByteString.copyFrom(buffer, 0, bytesRead))
.build();
requestObserver.onNext(request);
- // To simulate real-time audio, sleep after sending each audio buffer.
- Thread.sleep(samplesPerBuffer / samplesPerMillis);
}
- logger.info("Sent " + totalBytes + " bytes from audio file: " + file);
} catch (RuntimeException e) {
// Cancel RPC.
requestObserver.onError(e);
@@ -187,21 +238,13 @@ public void onCompleted() {
public static void main(String[] args) throws Exception {
- String audioFile = "";
- String host = "speech.googleapis.com";
- Integer port = 443;
- Integer sampling = 16000;
+ String host = null;
+ Integer port = null;
+ Integer sampling = null;
CommandLineParser parser = new DefaultParser();
Options options = new Options();
- options.addOption(
- Option.builder()
- .longOpt("file")
- .desc("path to audio file")
- .hasArg()
- .argName("FILE_PATH")
- .build());
options.addOption(
Option.builder()
.longOpt("host")
@@ -226,31 +269,14 @@ public static void main(String[] args) throws Exception {
try {
CommandLine line = parser.parse(options, args);
- if (line.hasOption("file")) {
- audioFile = line.getOptionValue("file");
- } else {
- System.err.println("An Audio file must be specified (e.g. /foo/baz.raw).");
- System.exit(1);
- }
-
- if (line.hasOption("host")) {
- host = line.getOptionValue("host");
- } else {
- System.err.println("An API enpoint must be specified (typically speech.googleapis.com).");
- System.exit(1);
- }
- if (line.hasOption("port")) {
- port = Integer.parseInt(line.getOptionValue("port"));
- } else {
- System.err.println("An SSL port must be specified (typically 443).");
- System.exit(1);
- }
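+ // Default to the public endpoint and standard SSL port when the flags are omitted.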
+ host = line.getOptionValue("host", "speech.googleapis.com");
+ port = Integer.parseInt(line.getOptionValue("port", "443"));
if (line.hasOption("sampling")) {
sampling = Integer.parseInt(line.getOptionValue("sampling"));
} else {
- System.err.println("An Audio sampling rate must be specified.");
+ System.err.println("An Audio sampling rate (--sampling) must be specified. (e.g. 16000)");
System.exit(1);
}
} catch (ParseException exp) {
@@ -259,7 +285,7 @@ public static void main(String[] args) throws Exception {
}
ManagedChannel channel = createChannel(host, port);
- StreamingRecognizeClient client = new StreamingRecognizeClient(channel, audioFile, sampling);
+ StreamingRecognizeClient client = new StreamingRecognizeClient(channel, sampling);
try {
client.recognize();
} finally {
diff --git a/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java b/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java
index 773af36c707..7ed15a0fe9e 100644
--- a/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java
+++ b/speech/grpc/src/test/java/com/examples/cloud/speech/StreamingRecognizeClientTest.java
@@ -17,24 +17,26 @@
package com.examples.cloud.speech;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.anyInt;
+import static org.mockito.Mockito.when;
import io.grpc.ManagedChannel;
-import org.apache.log4j.Logger;
-import org.apache.log4j.SimpleLayout;
-import org.apache.log4j.WriterAppender;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
-import java.io.File;
+import java.io.ByteArrayOutputStream;
+import java.io.FileInputStream;
import java.io.IOException;
-import java.io.StringWriter;
-import java.io.Writer;
-import java.net.URI;
-import java.nio.file.Path;
-import java.nio.file.Paths;
+import java.io.PrintStream;
+import javax.sound.sampled.TargetDataLine;
/**
@@ -42,46 +44,100 @@
*/
@RunWith(JUnit4.class)
public class StreamingRecognizeClientTest {
- private Writer writer;
- private WriterAppender appender;
+ private final ByteArrayOutputStream stdout = new ByteArrayOutputStream();
+ private static final PrintStream REAL_OUT = System.out;
+
+ @Mock private TargetDataLine mockDataLine;
@Before
public void setUp() {
- writer = new StringWriter();
- appender = new WriterAppender(new SimpleLayout(), writer);
- Logger.getRootLogger().addAppender(appender);
+ MockitoAnnotations.initMocks(this);
+ System.setOut(new PrintStream(stdout));
}
@After
public void tearDown() {
- Logger.getRootLogger().removeAppender(appender);
+ System.setOut(REAL_OUT);
}
@Test
public void test16KHzAudio() throws InterruptedException, IOException {
- URI uri = new File("resources/audio.raw").toURI();
- Path path = Paths.get(uri);
-
String host = "speech.googleapis.com";
int port = 443;
ManagedChannel channel = StreamingRecognizeClient.createChannel(host, port);
- StreamingRecognizeClient client = new StreamingRecognizeClient(channel, path.toString(), 16000);
+
+ final FileInputStream in = new FileInputStream("resources/audio.raw");
+
+ final int samplingRate = 16000;
+ final StreamingRecognizeClient client = new StreamingRecognizeClient(channel, samplingRate);
+
+ // When audio data is requested from the mock, get it from the file
+ when(mockDataLine.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer<Object>() {
+ public Object answer(InvocationOnMock invocation) {
+ Object[] args = invocation.getArguments();
+ byte[] buffer = (byte[])args[0];
+ int offset = (int)args[1];
+ int len = (int)args[2];
+ assertThat(buffer.length).isEqualTo(len);
+
+ try {
+ // Sleep, to simulate realtime
+ int samplesPerBuffer = client.bytesPerBuffer / StreamingRecognizeClient.BYTES_PER_SAMPLE;
+ int samplesPerMillis = samplingRate / 1000;
+ Thread.sleep(samplesPerBuffer / samplesPerMillis);
+
+ // Provide the audio bytes from the file
+ return in.read(buffer, offset, len);
+
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ });
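+ // Inject the mock line so the client reads audio from the file instead of a microphone.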
+ client.mockDataLine = mockDataLine;
client.recognize();
- assertThat(writer.toString()).contains("transcript: \"how old is the Brooklyn Bridge\"");
+
+ assertThat(stdout.toString()).contains("how old is the Brooklyn Bridge");
}
@Test
public void test32KHzAudio() throws InterruptedException, IOException {
- URI uri = new File("resources/audio32KHz.raw").toURI();
- Path path = Paths.get(uri);
-
String host = "speech.googleapis.com";
int port = 443;
ManagedChannel channel = StreamingRecognizeClient.createChannel(host, port);
- StreamingRecognizeClient client = new StreamingRecognizeClient(channel, path.toString(), 32000);
+
+ final FileInputStream in = new FileInputStream("resources/audio32KHz.raw");
+
+ final int samplingRate = 32000;
+ final StreamingRecognizeClient client = new StreamingRecognizeClient(channel, samplingRate);
+
+ // When audio data is requested from the mock, get it from the file
+ when(mockDataLine.read(any(byte[].class), anyInt(), anyInt())).thenAnswer(new Answer<Object>() {
+ public Object answer(InvocationOnMock invocation) {
+ Object[] args = invocation.getArguments();
+ byte[] buffer = (byte[])args[0];
+ int offset = (int)args[1];
+ int len = (int)args[2];
+ assertThat(buffer.length).isEqualTo(len);
+
+ try {
+ // Sleep, to simulate realtime
+ int samplesPerBuffer = client.bytesPerBuffer / StreamingRecognizeClient.BYTES_PER_SAMPLE;
+ int samplesPerMillis = samplingRate / 1000;
+ Thread.sleep(samplesPerBuffer / samplesPerMillis);
+
+ // Provide the audio bytes from the file
+ return in.read(buffer, offset, len);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ });
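+ // As above, inject the mock so the audio comes from the 32 kHz file.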
+ client.mockDataLine = mockDataLine;
client.recognize();
- assertThat(writer.toString()).contains("transcript: \"how old is the Brooklyn Bridge\"");
+
+ assertThat(stdout.toString()).contains("how old is the Brooklyn Bridge");
}
}