Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions riva/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@
from riva.client.proto.riva_asr_pb2 import RecognitionConfig, StreamingRecognitionConfig
from riva.client.proto.riva_audio_pb2 import AudioEncoding
from riva.client.proto.riva_nlp_pb2 import AnalyzeIntentOptions
from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig
from riva.client.tts import SpeechSynthesisService
from riva.client.nmt import NeuralMachineTranslationClient
109 changes: 107 additions & 2 deletions riva/client/nmt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: MIT

from typing import Generator, Optional, Union, List

from typing import Callable, Dict, Generator, Iterable, List, Optional, TextIO, Union
from grpc._channel import _MultiThreadedRendezvous

import riva.client.proto.riva_nmt_pb2 as riva_nmt
import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_srv
from riva.client import Auth

def streaming_s2s_request_generator(
    audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechRequest, None, None]:
    """Yield streaming speech-to-speech requests: the config message first, then one request per audio chunk."""
    # The streaming protocol requires the configuration as the very first request.
    yield riva_nmt.StreamingTranslateSpeechToSpeechRequest(config=streaming_config)
    yield from (
        riva_nmt.StreamingTranslateSpeechToSpeechRequest(audio_content=piece)
        for piece in audio_chunks
    )

def streaming_s2t_request_generator(
    audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToTextRequest, None, None]:
    """Yield streaming speech-to-text requests: the config message first, then one request per audio chunk."""
    # The streaming protocol requires the configuration as the very first request.
    yield riva_nmt.StreamingTranslateSpeechToTextRequest(config=streaming_config)
    yield from (
        riva_nmt.StreamingTranslateSpeechToTextRequest(audio_content=piece)
        for piece in audio_chunks
    )

class NeuralMachineTranslationClient:
"""
Expand All @@ -25,6 +37,99 @@ def __init__(self, auth: Auth) -> None:
self.auth = auth
self.stub = riva_nmt_srv.RivaTranslationStub(self.auth.channel)

def streaming_s2s_response_generator(
    self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToSpeechConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToSpeechResponse, None, None]:
    """
    Generates speech to speech translation responses for fragments of speech audio in :param:`audio_chunks`.
    The purpose of the method is to perform speech to speech translation "online" - as soon as
    audio is acquired on small chunks of audio.

    All available audio chunks will be sent to a server on first ``next()`` call.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
            of speech. For example, such raw audio can be obtained with

            .. code-block:: python

                import wave
                with wave.open(file_name, 'rb') as wav_f:
                    raw_audio = wav_f.readframes(n_frames)

        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechConfig`): a config for streaming.
            You may find description of config fields in message ``StreamingTranslateSpeechToSpeechConfig`` in
            `common repo
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
            An example of creation of streaming config:

            .. code-block:: python

                from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig
                config = RecognitionConfig(enable_automatic_punctuation=True)
                asr_config = StreamingRecognitionConfig(config, interim_results=True)
                translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
                tts_config = SynthesizeSpeechConfig(sample_rate_hz=44100, voice_name="English-US.Female-1")
                streaming_config = StreamingTranslateSpeechToSpeechConfig(asr_config, translation_config, tts_config)

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToSpeechResponse`: responses for audio chunks in
            :param:`audio_chunks`. You may find description of response fields in declaration of
            ``StreamingTranslateSpeechToSpeechResponse``
            message `here
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
    """
    # Build the request stream (config first, then audio) and relay the
    # server's streaming responses directly to the caller.
    request_iterator = streaming_s2s_request_generator(audio_chunks, streaming_config)
    yield from self.stub.StreamingTranslateSpeechToSpeech(
        request_iterator, metadata=self.auth.get_auth_metadata()
    )


def streaming_s2t_response_generator(
    self, audio_chunks: Iterable[bytes], streaming_config: riva_nmt.StreamingTranslateSpeechToTextConfig
) -> Generator[riva_nmt.StreamingTranslateSpeechToTextResponse, None, None]:
    """
    Generates speech to text translation responses for fragments of speech audio in :param:`audio_chunks`.
    The purpose of the method is to perform speech to text translation "online" - as soon as
    audio is acquired on small chunks of audio.

    All available audio chunks will be sent to a server on first ``next()`` call.

    Args:
        audio_chunks (:obj:`Iterable[bytes]`): an iterable object which contains raw audio fragments
            of speech. For example, such raw audio can be obtained with

            .. code-block:: python

                import wave
                with wave.open(file_name, 'rb') as wav_f:
                    raw_audio = wav_f.readframes(n_frames)

        streaming_config (:obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextConfig`): a config for streaming.
            You may find description of config fields in message ``StreamingTranslateSpeechToTextConfig`` in
            `common repo
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
            An example of creation of streaming config:

            .. code-block:: python

                from riva.client import RecognitionConfig, StreamingRecognitionConfig, StreamingTranslateSpeechToTextConfig, TranslationConfig
                config = RecognitionConfig(enable_automatic_punctuation=True)
                asr_config = StreamingRecognitionConfig(config, interim_results=True)
                translation_config = TranslationConfig(source_language_code="es-US", target_language_code="en-US")
                streaming_config = StreamingTranslateSpeechToTextConfig(asr_config, translation_config)

    Yields:
        :obj:`riva.client.proto.riva_nmt_pb2.StreamingTranslateSpeechToTextResponse`: responses for audio chunks in
            :param:`audio_chunks`. You may find description of response fields in declaration of
            ``StreamingTranslateSpeechToTextResponse``
            message `here
            <https://docs.nvidia.com/deeplearning/riva/user-guide/docs/reference/protos/protos.html#riva-proto-riva-nmt-proto>`_.
    """
    # Build the request stream (config first, then audio) and relay the
    # server's streaming responses directly to the caller.
    request_iterator = streaming_s2t_request_generator(audio_chunks, streaming_config)
    yield from self.stub.StreamingTranslateSpeechToText(
        request_iterator, metadata=self.auth.get_auth_metadata()
    )


def translate(
self,
texts: List[str],
Expand Down