Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions riva/client/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,20 @@

import argparse

def validate_grpc_message_size(value):
"""Validate that the GRPC message size is within acceptable limits."""
min_size = 4 * 1024 * 1024 # 4MB
max_size = 1024 * 1024 * 1024 # 1GB

try:
size = int(value)
if size < min_size:
raise argparse.ArgumentTypeError(f"GRPC message size must be at least {min_size} bytes (4MB)")
if size > max_size:
raise argparse.ArgumentTypeError(f"GRPC message size must be at most {max_size} bytes (1GB)")
return size
except ValueError:
raise argparse.ArgumentTypeError(f"'{value}' is not a valid integer")

def add_asr_config_argparse_parameters(
parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False
Expand Down Expand Up @@ -102,13 +116,16 @@ def add_asr_config_argparse_parameters(

def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser.add_argument("--server", default="localhost:50051", help="URI to GRPC server endpoint.")
parser.add_argument("--ssl-cert", help="Path to SSL client certificates file.")
parser.add_argument("--ssl-root-cert", help="Path to SSL root certificates file.")
parser.add_argument("--ssl-client-cert", help="Path to SSL client certificates file.")
parser.add_argument("--ssl-client-key", help="Path to SSL client key file.")
parser.add_argument(
"--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used."
)
parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server")
parser.add_argument("--options", action='append', nargs='+', help="Send GRPC options to server")
parser.add_argument(
"--max-message-length", type=int, default=64 * 1024 * 1024, help="Maximum message length for GRPC server."
"--max-message-length", type=validate_grpc_message_size, default=64 * 1024 * 1024, help="Maximum message length for GRPC server."
)
return parser

Expand Down
82 changes: 64 additions & 18 deletions riva/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,34 @@


def create_channel(
ssl_cert: Optional[Union[str, os.PathLike]] = None,
ssl_root_cert: Optional[Union[str, os.PathLike]] = None,
ssl_client_cert: Optional[Union[str, os.PathLike]] = None,
ssl_client_key: Optional[Union[str, os.PathLike]] = None,
use_ssl: bool = False,
uri: str = "localhost:50051",
metadata: Optional[List[Tuple[str, str]]] = None,
max_message_length: int = 64 * 1024 * 1024,
options: Optional[List[Tuple[str, str]]] = [],
) -> grpc.Channel:
def metadata_callback(context, callback):
callback(metadata, None)

options = [('grpc.max_receive_message_length', max_message_length), ('grpc.max_send_message_length', max_message_length)]
if ssl_cert is not None or use_ssl:
if ssl_root_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl:
root_certificates = None
if ssl_cert is not None:
ssl_cert = Path(ssl_cert).expanduser()
with open(ssl_cert, 'rb') as f:
client_certificates = None
client_key = None
if ssl_root_cert is not None:
ssl_root_cert = Path(ssl_root_cert).expanduser()
with open(ssl_root_cert, 'rb') as f:
root_certificates = f.read()
creds = grpc.ssl_channel_credentials(root_certificates)
if ssl_client_cert is not None:
ssl_client_cert = Path(ssl_client_cert).expanduser()
with open(ssl_client_cert, 'rb') as f:
client_certificates = f.read()
if ssl_client_key is not None:
ssl_client_key = Path(ssl_client_key).expanduser()
with open(ssl_client_key, 'rb') as f:
client_key = f.read()
creds = grpc.ssl_channel_credentials(root_certificates=root_certificates, private_key=client_key, certificate_chain=client_certificates)
if metadata:
auth_creds = grpc.metadata_call_credentials(metadata_callback)
creds = grpc.composite_channel_credentials(creds, auth_creds)
Expand All @@ -37,23 +48,58 @@ def metadata_callback(context, callback):
class Auth:
def __init__(
self,
ssl_cert: Optional[Union[str, os.PathLike]] = None,
ssl_root_cert: Optional[Union[str, os.PathLike]] = None,
use_ssl: bool = False,
uri: str = "localhost:50051",
metadata_args: List[List[str]] = None,
max_message_length: int = 64 * 1024 * 1024,
ssl_client_cert: Optional[Union[str, os.PathLike]] = None,
ssl_client_key: Optional[Union[str, os.PathLike]] = None,
options: Optional[List[Tuple[str, str]]] = [],
) -> None:
"""
A class responsible for establishing connection with a server and providing security metadata.
Initialize the Auth class for establishing secure connections with a server.

This class handles SSL/TLS configuration, authentication metadata, and gRPC channel creation
for secure communication with Riva services.

Args:
ssl_cert (:obj:`Union[str, os.PathLike]`, `optional`): a path to SSL certificate file. If :param:`use_ssl`
is :obj:`False` and :param:`ssl_cert` is not :obj:`None`, then SSL is used.
use_ssl (:obj:`bool`, defaults to :obj:`False`): whether to use SSL. If :param:`ssl_cert` is :obj:`None`,
then SSL is still used but with default credentials.
uri (:obj:`str`, defaults to :obj:`"localhost:50051"`): a Riva URI.
ssl_root_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL root certificate file.
If provided and use_ssl is False, SSL will still be enabled. Defaults to None.
use_ssl (bool, optional): Whether to use SSL/TLS encryption. If True and ssl_root_cert is None,
SSL will be used with default credentials. Defaults to False.
uri (str, optional): The Riva server URI in format "host:port". Defaults to "localhost:50051".
metadata_args (List[List[str]], optional): List of metadata key-value pairs for authentication.
Each inner list should contain exactly 2 elements: [key, value]. Defaults to None.
ssl_client_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL client certificate file.
Used for mutual TLS authentication. Defaults to None.
ssl_client_key (Optional[Union[str, os.PathLike]], optional): Path to the SSL client private key file.
Used for mutual TLS authentication. Defaults to None.
options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options.
Each tuple should contain (option_name, option_value). Defaults to [].

Raises:
ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair).

Example:
>>> # Basic connection without SSL
>>> auth = Auth(uri="localhost:50051")

>>> # SSL connection with custom certificate
>>> auth = Auth(
... use_ssl=True,
... ssl_root_cert="/path/to/cert.pem",
... uri="secure-server:50051"
... )

>>> # Connection with authentication metadata
>>> auth = Auth(
... metadata_args=[["api-key", "your-api-key"], ["user-id", "12345"]],
... uri="auth-server:50051"
... )
"""
self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser()
self.ssl_root_cert: Optional[Path] = None if ssl_root_cert is None else Path(ssl_root_cert).expanduser()
self.ssl_client_cert: Optional[Path] = None if ssl_client_cert is None else Path(ssl_client_cert).expanduser()
self.ssl_client_key: Optional[Path] = None if ssl_client_key is None else Path(ssl_client_key).expanduser()
self.uri: str = uri
self.use_ssl: bool = use_ssl
self.metadata = []
Expand All @@ -65,7 +111,7 @@ def __init__(
)
self.metadata.append(tuple(meta))
self.channel: grpc.Channel = create_channel(
self.ssl_cert, self.use_ssl, self.uri, self.metadata, max_message_length=max_message_length
self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options
)

def get_auth_metadata(self) -> List[Tuple[str, str]]:
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/riva_streaming_asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,15 @@ def streaming_transcription_worker(
) -> None:
output_file = Path(output_file).expanduser()
try:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/transcribe_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,15 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_output_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
asr_service = riva.client.ASRService(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/transcribe_file_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@ def parse_args() -> argparse.Namespace:
def main() -> None:
args = parse_args()

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)]
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=options)
asr_service = riva.client.ASRService(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/asr/transcribe_mic.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,15 @@ def main() -> None:
if args.list_devices:
riva.client.audio_io.list_input_devices()
return
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
asr_service = riva.client.ASRService(auth)
config = riva.client.StreamingRecognitionConfig(
config=riva.client.RecognitionConfig(
Expand Down
20 changes: 18 additions & 2 deletions scripts/nlp/punctuation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,15 @@ def parse_args() -> argparse.Namespace:


def run_punct_capit(args: argparse.Namespace) -> None:
auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nlp_service = riva.client.NLPService(auth)
if args.interactive:
while True:
Expand Down Expand Up @@ -134,7 +142,15 @@ def run_tests(args: argparse.Namespace) -> int:
],
}

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nlp_service = riva.client.NLPService(auth)

fail_count = 0
Expand Down
10 changes: 9 additions & 1 deletion scripts/nmt/nmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,15 @@ def request(inputs,args):

args = parse_args()

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/nmt/nmt_speech_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@ def main():
if not os.path.exists(args.audio_file):
raise FileNotFoundError(f"Input audio file not found: {args.audio_file}")

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
10 changes: 9 additions & 1 deletion scripts/nmt/nmt_speech_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,15 @@ def main():
if not os.path.exists(args.audio_file):
raise FileNotFoundError(f"Input audio file not found: {args.audio_file}")

auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata)
auth = riva.client.Auth(
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
nmt_client = riva.client.NeuralMachineTranslationClient(auth)

if args.list_models:
Expand Down
13 changes: 12 additions & 1 deletion scripts/tts/talk.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,19 @@ def main() -> None:
riva.client.audio_io.list_output_devices()
return

if args.options is None:
args.options = []
args.options.append(('grpc.max_receive_message_length', args.max_message_length))
args.options.append(('grpc.max_send_message_length', args.max_message_length))

auth = riva.client.Auth(
args.ssl_cert, args.use_ssl, args.server, args.metadata, max_message_length=args.max_message_length
ssl_root_cert=args.ssl_root_cert,
ssl_client_cert=args.ssl_client_cert,
ssl_client_key=args.ssl_client_key,
use_ssl=args.use_ssl,
uri=args.server,
metadata_args=args.metadata,
options=args.options
)
service = riva.client.SpeechSynthesisService(auth)
nchannels = 1
Expand Down