From ea621d1be7a896d78337f5563aa1dcdca2930371 Mon Sep 17 00:00:00 2001 From: Rahul Mittal Date: Mon, 24 Apr 2023 17:34:44 +0530 Subject: [PATCH] add s2s and s2t client utility functions --- riva/client/__init__.py | 1 + riva/client/nmt.py | 109 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 108 insertions(+), 2 deletions(-) diff --git a/riva/client/__init__.py b/riva/client/__init__.py index d08430bd..d4a22785 100644 --- a/riva/client/__init__.py +++ b/riva/client/__init__.py @@ -36,5 +36,6 @@ from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig from riva.client.proto.riva_audio_pb2 import AudioEncoding from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions +from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig from riva.client.tts import SpeechSynthesisService from riva.client.nmt import NeuralMachineTranslationClient diff --git a/riva/client/nmt.py b/riva/client/nmt.py index 0a261522..29cb4f8d 100644 --- a/riva/client/nmt.py +++ b/riva/client/nmt.py @@ -1,14 +1,26 @@ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: MIT -from typing import Generator, Optional, Union, List - +from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union from grpc._channel import _MultiThreadedRendezvous import riva.client.proto.riva_nmt_pb2 as riva_nmt import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv from riva.client import Auth +def streaming_s2s_request_generator( + audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig +) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechRequest, None, None]: + yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(config=streaming_config) + for chunk in audio_chunks: + yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(audio_content=chunk) + +def streaming_s2t_request_generator( + audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig +) -> Generator[riva_nmt.StreamingTranslateSpeechToTextRequest, None, None]: + yield riva_nmt.StreamingTranslateSpeechToTextRequest(config=streaming_config) + for chunk in audio_chunks: + yield riva_nmt.StreamingTranslateSpeechToTextRequest(audio_content=chunk) class NeuralMachineTranslationClient: """ @@ -25,6 +37,99 @@ def __init__(self, auth: Auth) -> None: self.auth = auth self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel) + def streaming_s2s_response_generator( + self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig + ) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechResponse, None, None]: + """ + Generates speech to speech translation responses for fragments of speech audio in :param:`audio_chunks`. + The purpose of the method is to perform speech to speech translation "online" - as soon as + audio is acquired on small chunks of audio. + + All available audio chunks will be sent to a server on first ``next()`` call. 
+ + Args: + audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments + of speech. For example, such raw audio can be obtained with + + .. code-block:: python + + import wave + with wave.open(file_name, 'rb') as wav_f: + raw_audio = wav_f.readframes(n_frames) + + streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig`): a config for streaming. + You may find description of config fields in message ``StreamingTranslateSpeechToSpeechConfig`` in + `common repo + `_. + An example of creation of streaming config: + + .. code-block:: python + + from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig + config = RecognitionConfig(enable_automatic_punctuation=True) + asr_config = StreamingRecognitionConfig(config=config, interim_results=True) + translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US") + tts_config = SynthesizeSpeechConfig(sample_rate_hz=44100, voice_name="English-US.Female-1") + streaming_config = StreamingTranslateSpeechToSpeechConfig(asr_config=asr_config, translation_config=translation_config, tts_config=tts_config) + + Yields: + :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechResponse`: responses for audio chunks in + :param:`audio_chunks`. You may find description of response fields in declaration of + ``StreamingTranslateSpeechToSpeechResponse`` + message `here + `_. 
+ """ + generator = streaming_s2s_request_generator(audio_chunks, streaming_config) + for response in self.stub.StreamingTranslateSpeechToSpeech(generator, metadata=self.auth.get_auth_metadata()): + yield response + + + def streaming_s2t_response_generator( + self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig + ) -> Generator[riva_nmt.StreamingTranslateSpeechToTextResponse, None, None]: + """ + Generates speech to text translation responses for fragments of speech audio in :param:`audio_chunks`. + The purpose of the method is to perform speech to text translation "online" - as soon as + audio is acquired on small chunks of audio. + + All available audio chunks will be sent to a server on first ``next()`` call. + + Args: + audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments + of speech. For example, such raw audio can be obtained with + + .. code-block:: python + + import wave + with wave.open(file_name, 'rb') as wav_f: + raw_audio = wav_f.readframes(n_frames) + + streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextConfig`): a config for streaming. + You may find description of config fields in message ``StreamingTranslateSpeechToTextConfig`` in + `common repo + `_. + An example of creation of streaming config: + + .. code-block:: python + + from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToTextConfig, TranslationConfig + config = RecognitionConfig(enable_automatic_punctuation=True) + asr_config = StreamingRecognitionConfig(config=config, interim_results=True) + translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US") + streaming_config = StreamingTranslateSpeechToTextConfig(asr_config=asr_config, translation_config=translation_config) + + Yields: + :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextResponse`: responses for audio chunks in + :param:`audio_chunks`. 
You may find description of response fields in declaration of + ``StreamingTranslateSpeechToTextResponse`` + message `here + `_. + """ + generator = streaming_s2t_request_generator(audio_chunks, streaming_config) + for response in self.stub.StreamingTranslateSpeechToText(generator, metadata=self.auth.get_auth_metadata()): + yield response + + def translate( self, texts: List[str],