Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@

import argparse

def validate_grpc_message_size(value):
"""Validate that the GRPC message size is within acceptable limits."""
min_size = 4 * 1024 * 1024 # 4MB
max_size = 1024 * 1024 * 1024 # 1GB

try:
size = int(value)
if size < min_size:
raise argparse.ArgumentTypeError(f"GRPC message size must be at least {min_size} bytes (4MB)")
if size > max_size:
raise argparse.ArgumentTypeError(f"GRPC message size must be at most {max_size} bytes (1GB)")
return size
except ValueError:
raise argparse.ArgumentTypeError(f"'{value}' is not a valid integer")

def add_asr_config_argparse_parameters(
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False
Expand Down Expand Up @@ -102,13 +116,16 @@ def add_asr_config_argparse_parameters(

def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--server", default="localhost:50051", help="URI to GRPC server endpoint.")
parser.add_argument("--ssl-cert", help="Path to SSL client certificates file.")
parser.add_argument("--ssl-root-cert", help="Path to SSL root certificates file.")
parser.add_argument("--ssl-client-cert", help="Path to SSL client certificates file.")
parser.add_argument("--ssl-client-key", help="Path to SSL client key file.")
parser.add_argument(
"--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used."
)
parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server")
parser.add_argument("--options", action='append', nargs='+', help="Send GRPC options to server")
parser.add_argument(
"--max-message-length", type=int, default=64 * 1024 * 1024, help="Maximum message length for GRPC server."
"--max-message-length", type=validate_grpc_message_size, default=64 * 1024 * 1024, help="Maximum message length for GRPC server."
)
return parser

Expand Down
82 changes: 64 additions & 18 deletions riva/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,34 @@


def create_channel(
ssl_cert: Optional[Union[str, os.PathLike]] = None,
ssl_root_cert: Optional[Union[str, os.PathLike]] = None,
ssl_client_cert: Optional[Union[str, os.PathLike]] = None,
ssl_client_key: Optional[Union[str, os.PathLike]] = None,
use_ssl: bool = False,
uri: str = "localhost:50051",
metadata: Optional[List[Tuple[str, str]]] = None,
max_message_length: int = 64 * 1024 * 1024,
options: Optional[List[Tuple[str, str]]] = [],
) -> grpc.Channel:
def metadata_callback(context, callback):
callback(metadata, None)

options = [('grpc.max_receive_message_length', max_message_length), ('grpc.max_send_message_length', max_message_length)]
if ssl_cert is not None or use_ssl:
if ssl_root_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl:
root_certificates = None
if ssl_cert is not None:
ssl_cert = Path(ssl_cert).expanduser()
with open(ssl_cert, 'rb') as f:
client_certificates = None
client_key = None
if ssl_root_cert is not None:
ssl_root_cert = Path(ssl_root_cert).expanduser()
with open(ssl_root_cert, 'rb') as f:
root_certificates = f.read()
creds = grpc.ssl_channel_credentials(root_certificates)
if ssl_client_cert is not None:
ssl_client_cert = Path(ssl_client_cert).expanduser()
with open(ssl_client_cert, 'rb') as f:
client_certificates = f.read()
if ssl_client_key is not None:
ssl_client_key = Path(ssl_client_key).expanduser()
with open(ssl_client_key, 'rb') as f:
client_key = f.read()
creds = grpc.ssl_channel_credentials(root_certificates=root_certificates, private_key=client_key, certificate_chain=client_certificates)
if metadata:
auth_creds = grpc.metadata_call_credentials(metadata_callback)
creds = grpc.composite_channel_credentials(creds, auth_creds)
Expand All @@ -37,23 +48,58 @@ def metadata_callback(context, callback):
class Auth:
def __init__(
self,
ssl_cert: Optional[Union[str, os.PathLike]] = None,
ssl_root_cert: Optional[Union[str, os.PathLike]] = None,
use_ssl: bool = False,
uri: str = "localhost:50051",
metadata_args: List[List[str]] = None,
max_message_length: int = 64 * 1024 * 1024,
ssl_client_cert: Optional[Union[str, os.PathLike]] = None,
ssl_client_key: Optional[Union[str, os.PathLike]] = None,
options: Optional[List[Tuple[str, str]]] = [],
) -> None:
"""
A class responsible for establishing connection with a server and providing security metadata.
Initialize the Auth class for establishing secure connections with a server.

This class handles SSL/TLS configuration, authentication metadata, and gRPC channel creation
for secure communication with Riva services.

Args:
ssl_cert (:obj:`Union[str, os.PathLike]`, `optional`): a path to SSL certificate file. If :param:`use_ssl`
is :obj:`False` and :param:`ssl_cert` is not :obj:`None`, then SSL is used.
use_ssl (:obj:`bool`, defaults to :obj:`False`): whether to use SSL. If :param:`ssl_cert` is :obj:`None`,
then SSL is still used but with default credentials.
uri (:obj:`str`, defaults to :obj:`"localhost:50051"`): a Riva URI.
ssl_root_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL root certificate file.
If provided and use_ssl is False, SSL will still be enabled. Defaults to None.
use_ssl (bool, optional): Whether to use SSL/TLS encryption. If True and ssl_root_cert is None,
SSL will be used with default credentials. Defaults to False.
uri (str, optional): The Riva server URI in format "host:port". Defaults to "localhost:50051".
metadata_args (List[List[str]], optional): List of metadata key-value pairs for authentication.
Each inner list should contain exactly 2 elements: [key, value]. Defaults to None.
ssl_client_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL client certificate file.
Used for mutual TLS authentication. Defaults to None.
ssl_client_key (Optional[Union[str, os.PathLike]], optional): Path to the SSL client private key file.
Used for mutual TLS authentication. Defaults to None.
options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options.
Each tuple should contain (option_name, option_value). Defaults to [].

Raises:
ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair).

Example:
>>> # Basic connection without SSL
>>> auth = Auth(uri="localhost:50051")

>>> # SSL connection with custom certificate
>>> auth = Auth(
... use_ssl=True,
... ssl_root_cert="/path/to/cert.pem",
... uri="secure-server:50051"
... )

>>> # Connection with authentication metadata
>>> auth = Auth(
... metadata_args=[["api-key", "your-api-key"], ["user-id", "12345"]],
... uri="auth-server:50051"
... )
"""
self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser()
self.ssl_root_cert: Optional[Path] = None if ssl_root_cert is None else Path(ssl_root_cert).expanduser()
self.ssl_client_cert: Optional[Path] = None if ssl_client_cert is None else Path(ssl_client_cert).expanduser()
self.ssl_client_key: Optional[Path] = None if ssl_client_key is None else Path(ssl_client_key).expanduser()
self.uri: str = uri
self.use_ssl: bool = use_ssl
self.metadata = []
Expand All @@ -65,7 +111,7 @@ def __init__(
)
self.metadata.append(tuple(meta))
self.channel: grpc.Channel = create_channel(
self.ssl_cert, self.use_ssl, self.uri, self.metadata, max_message_length=max_message_length
self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options
)

def get_auth_metadata(self) -> List[Tuple[str, str]]:
Expand Down
30 changes: 19 additions & 11 deletions scripts/asr/realtime_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,12 @@ def parse_args() -> argparse.Namespace:
help="Duration in seconds to record from microphone (only used with --mic)",
default=None
)

# Audio device configuration
try:
import riva.client.audio_io
default_device_info = riva.client.audio_io.get_default_input_device_info()
default_device_index = None if default_device_info is None else default_device_info['index']
except ModuleNotFoundError:
default_device_index = None

parser.add_argument(
"--input-device",
type=int,
default=default_device_index,
help="Input audio device index to use (only used with --mic)"
default=None,
help="Input audio device index to use (only used with --mic). If not specified, will use default device."
)
parser.add_argument(
"--list-devices",
Expand Down Expand Up @@ -126,6 +118,15 @@ def parse_args() -> argparse.Namespace:
return args


def get_default_device_index():
"""Get default audio device index only when needed."""
try:
import riva.client.audio_io
default_device_info = riva.client.audio_io.get_default_input_device_info()
return None if default_device_info is None else default_device_info['index']
except ModuleNotFoundError:
return None

def setup_signal_handler():
"""Set up signal handler for graceful shutdown."""
def signal_handler(sig, frame):
Expand All @@ -145,11 +146,18 @@ async def create_audio_iterator(args):
Audio iterator for streaming audio data
"""
if args.mic:
# Only import when using microphone
from riva.client.audio_io import MicrophoneStream

# Get default device index if not specified
device_index = args.input_device
if device_index is None:
device_index = get_default_device_index()

audio_chunk_iterator = MicrophoneStream(
args.sample_rate_hz,
args.file_streaming_chunk,
device=args.input_device
device=device_index
)
args.num_channels = 1
else:
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ def streaming_transcription_worker(
) -> None:
output_file = Path(output_file).expanduser()
try:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,15 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_output_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
asr_service = riva.client.ASRService(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
args = parse_args()

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)]
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=options)
asr_service = riva.client.ASRService(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_input_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
20 changes: 18 additions & 2 deletions scripts/nlp/punctuation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ def parse_args() -> argparse.Namespace:


def run_punct_capit(args: argparse.Namespace) -> None:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nlp_service = riva.client.NLPService(auth)
if args.interactive:
while True:
Expand Down Expand Up @@ -134,7 +142,15 @@ def run_tests(args: argparse.Namespace) -> int:
],
}

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nlp_service = riva.client.NLPService(auth)

fail_count = 0
Expand Down
10 changes: 9 additions & 1 deletion scripts/nmt/nmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,15 @@ def request(inputs,args):

args = parse_args()

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/nmt/nmt_speech_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ def main():
if not os.path.exists(args.audio_file):
raise FileNotFoundError(f"Input audio file not found: {args.audio_file}")

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/nmt/nmt_speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ def main():
if not os.path.exists(args.audio_file):
raise FileNotFoundError(f"Input audio file not found: {args.audio_file}")

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
13 changes: 12 additions & 1 deletion scripts/tts/talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,19 @@ def main() -> None:
riva.client.audio_io.list_output_devices()
return

if args.options is None:
args.options = []
args.options.append(('grpc.max_receive_message_length', args.max_message_length))
args.options.append(('grpc.max_send_message_length', args.max_message_length))

auth = riva.client.Auth(
args.ssl_cert, args.use_ssl, args.server, args.metadata, max_message_length=args.max_message_length
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
service = riva.client.SpeechSynthesisService(auth)
nchannels = 1
Expand Down