From 1e611aa6f478dfb1527b2f21344dcdf989200599 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Fri, 11 Jul 2025 15:15:52 +0530 Subject: [PATCH 1/6] allow passing grpc channel create options --- riva/client/argparse_utils.py | 16 +++++++++++++++- riva/client/auth.py | 7 +++---- scripts/asr/transcribe_file_offline.py | 3 ++- scripts/tts/talk.py | 3 ++- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index d662697..b4b04d0 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -3,6 +3,20 @@ import argparse +def validate_grpc_message_size(value): + """Validate that the GRPC message size is within acceptable limits.""" + min_size = 4 * 1024 * 1024 # 4MB + max_size = 1024 * 1024 * 1024 # 1GB + + try: + size = int(value) + if size < min_size: + raise argparse.ArgumentTypeError(f"GRPC message size must be at least {min_size} bytes (4MB)") + if size > max_size: + raise argparse.ArgumentTypeError(f"GRPC message size must be at most {max_size} bytes (1GB)") + return size + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not a valid integer") def add_asr_config_argparse_parameters( parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False @@ -108,7 +122,7 @@ def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argpa ) parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server") parser.add_argument( - "--max-message-length", type=int, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." + "--max-message-length", type=validate_grpc_message_size, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." ) return parser diff --git a/riva/client/auth.py b/riva/client/auth.py index 046f409..7db531d 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -12,12 +12,11 @@ def create_channel( use_ssl: bool = False, uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None, - max_message_length: int = 64 * 1024 * 1024, + options: Optional[List[Tuple[str, str]]] = [], ) -> grpc.Channel: def metadata_callback(context, callback): callback(metadata, None) - options = [('grpc.max_receive_message_length', max_message_length), ('grpc.max_send_message_length', max_message_length)] if ssl_cert is not None or use_ssl: root_certificates = None if ssl_cert is not None: @@ -41,7 +40,7 @@ def __init__( use_ssl: bool = False, uri: str = "localhost:50051", metadata_args: List[List[str]] = None, - max_message_length: int = 64 * 1024 * 1024, + options: Optional[List[Tuple[str, str]]] = [], ) -> None: """ A class responsible for establishing connection with a server and providing security metadata. @@ -65,7 +64,7 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_cert, self.use_ssl, self.uri, self.metadata, max_message_length=max_message_length + self.ssl_cert, self.use_ssl, self.uri, self.metadata, options=options ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 8d1fd48..ea8a757 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -33,7 +33,8 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] + auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata, options=options) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index ac3ed6d..c99a5ec 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -99,8 +99,9 @@ def main() -> None: riva.client.audio_io.list_output_devices() return + options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] auth = riva.client.Auth( - args.ssl_cert, args.use_ssl, args.server, args.metadata, max_message_length=args.max_message_length + args.ssl_cert, args.use_ssl, args.server, args.metadata, options=options ) service = riva.client.SpeechSynthesisService(auth) nchannels = 1 From 9ca6fe217f0a3b575669f2e90dde51ce2b012c68 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Sat, 12 Jul 2025 03:00:40 +0530 Subject: [PATCH 2/6] add support for using client cert for MTLS --- riva/client/argparse_utils.py | 5 ++++- riva/client/auth.py | 22 +++++++++++++++++++--- scripts/asr/riva_streaming_asr_client.py | 10 +++++++++- scripts/asr/transcribe_file.py | 10 +++++++++- scripts/asr/transcribe_file_offline.py | 9 ++++++++- scripts/asr/transcribe_mic.py | 10 +++++++++- scripts/nlp/punctuation_client.py | 20 ++++++++++++++++++-- scripts/nmt/nmt_speech_to_speech.py | 10 +++++++++- scripts/nmt/nmt_speech_to_text.py | 10 +++++++++- scripts/tts/talk.py | 14 ++++++++++++-- 10 files changed, 106 insertions(+), 14 deletions(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index b4b04d0..d76f91e 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -116,11 +116,14 @@ def add_asr_config_argparse_parameters( def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--server", default="localhost:50051", help="URI to GRPC server endpoint.") - parser.add_argument("--ssl-cert", help="Path to SSL client certificates file.") + parser.add_argument("--ssl-cert", help="Path to SSL root certificates file.") + parser.add_argument("--ssl-client-cert", help="Path to SSL client certificates file.") + parser.add_argument("--ssl-client-key", help="Path to SSL client key file.") parser.add_argument( "--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used." ) parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server") + parser.add_argument("--options", action='append', nargs='+', help="Send GRPC options to server") parser.add_argument( "--max-message-length", type=validate_grpc_message_size, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." ) diff --git a/riva/client/auth.py b/riva/client/auth.py index 7db531d..55188b7 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -9,6 +9,8 @@ def create_channel( ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_key: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None, @@ -17,13 +19,23 @@ def create_channel( def metadata_callback(context, callback): callback(metadata, None) - if ssl_cert is not None or use_ssl: + if ssl_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl: root_certificates = None + client_certificates = None + client_key = None if ssl_cert is not None: ssl_cert = Path(ssl_cert).expanduser() with open(ssl_cert, 'rb') as f: root_certificates = f.read() - creds = grpc.ssl_channel_credentials(root_certificates) + if ssl_client_cert is not None: + ssl_client_cert = Path(ssl_client_cert).expanduser() + with open(ssl_client_cert, 'rb') as f: + client_certificates = f.read() + if ssl_client_key is not None: + ssl_client_key = Path(ssl_client_key).expanduser() + with open(ssl_client_key, 'rb') as f: + client_key = f.read() + creds = grpc.ssl_channel_credentials(root_certificates=root_certificates, private_key=client_key, certificate_chain=client_certificates) if metadata: auth_creds = grpc.metadata_call_credentials(metadata_callback) creds = grpc.composite_channel_credentials(creds, auth_creds) @@ -40,6 +52,8 @@ def __init__( use_ssl: bool = False, uri: str = "localhost:50051", metadata_args: List[List[str]] = None, + ssl_client_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_key: Optional[Union[str, os.PathLike]] = None, options: Optional[List[Tuple[str, str]]] = [], ) -> None: """ @@ -53,6 +67,8 @@ def __init__( uri (:obj:`str`, defaults to :obj:`"localhost:50051"`): a Riva URI. """ self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser() + self.ssl_client_cert: Optional[Path] = None if ssl_client_cert is None else Path(ssl_client_cert).expanduser() + self.ssl_client_key: Optional[Path] = None if ssl_client_key is None else Path(ssl_client_key).expanduser() self.uri: str = uri self.use_ssl: bool = use_ssl self.metadata = [] @@ -64,7 +80,7 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_cert, self.use_ssl, self.uri, self.metadata, options=options + self.ssl_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 4387394..21ed237 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -50,7 +50,15 @@ def streaming_transcription_worker( ) -> None: output_file = Path(output_file).expanduser() try: - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index f9cbd7c..ee9bde3 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -66,7 +66,15 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_output_devices() return - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index ea8a757..38ebd11 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -34,7 +34,14 @@ def main() -> None: args = parse_args() options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata, options=options) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=options) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 77d38fb..fa29ae3 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -45,7 +45,15 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_input_devices() return - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( diff --git a/scripts/nlp/punctuation_client.py b/scripts/nlp/punctuation_client.py index ab843a7..e63cced 100644 --- a/scripts/nlp/punctuation_client.py +++ b/scripts/nlp/punctuation_client.py @@ -39,7 +39,15 @@ def parse_args() -> argparse.Namespace: def run_punct_capit(args: argparse.Namespace) -> None: - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nlp_service = riva.client.NLPService(auth) if args.interactive: while True: @@ -134,7 +142,15 @@ def run_tests(args: argparse.Namespace) -> int: ], } - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nlp_service = riva.client.NLPService(auth) fail_count = 0 diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 49426bf..d493207 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -32,7 +32,15 @@ def main(): if not os.path.exists(args.audio_file): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index 7d75a6c..f83fb21 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -43,7 +43,15 @@ def main(): if not os.path.exists(args.audio_file): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index c99a5ec..cf32d9b 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -99,9 +99,19 @@ def main() -> None: riva.client.audio_io.list_output_devices() return - options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] + if args.options is None: + args.options = [] + args.options.append(('grpc.max_receive_message_length', args.max_message_length)) + args.options.append(('grpc.max_send_message_length', args.max_message_length)) + auth = riva.client.Auth( - args.ssl_cert, args.use_ssl, args.server, args.metadata, options=options + ssl_cert=args.ssl_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options ) service = riva.client.SpeechSynthesisService(auth) nchannels = 1 From 7f85e65aa60ee244c5183a1c0108717b5eb404e9 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Thu, 7 Aug 2025 12:17:16 +0530 Subject: [PATCH 3/6] rename ssl-cert to ssl-root-cert --- riva/client/argparse_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index d76f91e..a8cc1a7 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -116,7 +116,7 @@ def add_asr_config_argparse_parameters( def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--server", default="localhost:50051", help="URI to GRPC server endpoint.") - parser.add_argument("--ssl-cert", help="Path to SSL root certificates file.") + parser.add_argument("--ssl-root-cert", help="Path to SSL root certificates file.") parser.add_argument("--ssl-client-cert", help="Path to SSL client certificates file.") parser.add_argument("--ssl-client-key", help="Path to SSL client key file.") parser.add_argument( From 33202bb33b42c1c078a97fbf7ff47f5d4d2c93a8 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Thu, 7 Aug 2025 13:14:27 +0530 Subject: [PATCH 4/6] fix: renaming ssl_cert -> ssl_root_cert --- riva/client/auth.py | 10 +++++----- scripts/asr/riva_streaming_asr_client.py | 2 +- scripts/asr/transcribe_file.py | 2 +- scripts/asr/transcribe_file_offline.py | 2 +- scripts/asr/transcribe_mic.py | 2 +- scripts/nlp/punctuation_client.py | 2 +- scripts/nmt/nmt.py | 10 +++++++++- scripts/nmt/nmt_speech_to_speech.py | 2 +- scripts/nmt/nmt_speech_to_text.py | 2 +- scripts/tts/talk.py | 2 +- 10 files changed, 22 insertions(+), 14 deletions(-) diff --git a/riva/client/auth.py b/riva/client/auth.py index 55188b7..941029f 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -8,7 +8,7 @@ def create_channel( - ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_root_cert: Optional[Union[str, os.PathLike]] = None, ssl_client_cert: Optional[Union[str, os.PathLike]] = None, ssl_client_key: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, @@ -19,13 +19,13 @@ def create_channel( def metadata_callback(context, callback): callback(metadata, None) - if ssl_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl: + if ssl_root_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl: root_certificates = None client_certificates = None client_key = None - if ssl_cert is not None: - ssl_cert = Path(ssl_cert).expanduser() - with open(ssl_cert, 'rb') as f: + if ssl_root_cert is not None: + ssl_root_cert = Path(ssl_root_cert).expanduser() + with open(ssl_root_cert, 'rb') as f: root_certificates = f.read() if ssl_client_cert is not None: ssl_client_cert = Path(ssl_client_cert).expanduser() diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 21ed237..f600af6 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -51,7 +51,7 @@ def streaming_transcription_worker( output_file = Path(output_file).expanduser() try: auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index ee9bde3..1849a67 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -67,7 +67,7 @@ def main() -> None: riva.client.audio_io.list_output_devices() return auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 38ebd11..afedb95 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -35,7 +35,7 @@ def main() -> None: options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index fa29ae3..3fd2b5a 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -46,7 +46,7 @@ def main() -> None: riva.client.audio_io.list_input_devices() return auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/nlp/punctuation_client.py b/scripts/nlp/punctuation_client.py index e63cced..c0b57de 100644 --- a/scripts/nlp/punctuation_client.py +++ b/scripts/nlp/punctuation_client.py @@ -40,7 +40,7 @@ def parse_args() -> argparse.Namespace: def run_punct_capit(args: argparse.Namespace) -> None: auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index a978cc9..f3d13db 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -123,7 +123,15 @@ def request(inputs,args): args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index d493207..7348525 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -33,7 +33,7 @@ def main(): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index f83fb21..bc8c0f2 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -44,7 +44,7 @@ def main(): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index cf32d9b..2df233b 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -105,7 +105,7 @@ def main() -> None: args.options.append(('grpc.max_send_message_length', args.max_message_length)) auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl, From 12e49f6274310ca79c51cb48c44c467359c4bd03 Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Thu, 7 Aug 2025 13:22:54 +0530 Subject: [PATCH 5/6] fix auth init and docstring --- riva/client/auth.py | 47 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/riva/client/auth.py b/riva/client/auth.py index 941029f..8a71be7 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -48,7 +48,7 @@ def metadata_callback(context, callback): class Auth: def __init__( self, - ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_root_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata_args: List[List[str]] = None, @@ -57,16 +57,47 @@ def __init__( options: Optional[List[Tuple[str, str]]] = [], ) -> None: """ - A class responsible for establishing connection with a server and providing security metadata. + Initialize the Auth class for establishing secure connections with a server. + + This class handles SSL/TLS configuration, authentication metadata, and gRPC channel creation + for secure communication with Riva services. Args: - ssl_cert (:obj:`Union[str, os.PathLike]`, `optional`): a path to SSL certificate file. If :param:`use_ssl` - is :obj:`False` and :param:`ssl_cert` is not :obj:`None`, then SSL is used. - use_ssl (:obj:`bool`, defaults to :obj:`False`): whether to use SSL. If :param:`ssl_cert` is :obj:`None`, - then SSL is still used but with default credentials. - uri (:obj:`str`, defaults to :obj:`"localhost:50051"`): a Riva URI. + ssl_root_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL root certificate file. + If provided and use_ssl is False, SSL will still be enabled. Defaults to None. + use_ssl (bool, optional): Whether to use SSL/TLS encryption. If True and ssl_root_cert is None, + SSL will be used with default credentials. Defaults to False. + uri (str, optional): The Riva server URI in format "host:port". Defaults to "localhost:50051". + metadata_args (List[List[str]], optional): List of metadata key-value pairs for authentication. + Each inner list should contain exactly 2 elements: [key, value]. Defaults to None. + ssl_client_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL client certificate file. + Used for mutual TLS authentication. Defaults to None. + ssl_client_key (Optional[Union[str, os.PathLike]], optional): Path to the SSL client private key file. + Used for mutual TLS authentication. Defaults to None. + options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options. + Each tuple should contain (option_name, option_value). Defaults to []. + + Raises: + ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair). + + Example: + >>> # Basic connection without SSL + >>> auth = Auth(uri="localhost:50051") + + >>> # SSL connection with custom certificate + >>> auth = Auth( + ... use_ssl=True, + ... ssl_root_cert="/path/to/cert.pem", + ... uri="secure-server:50051" + ... ) + + >>> # Connection with authentication metadata + >>> auth = Auth( + ... metadata_args=[["api-key", "your-api-key"], ["user-id", "12345"]], + ... uri="auth-server:50051" + ... ) """ - self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser() + self.ssl_root_cert: Optional[Path] = None if ssl_root_cert is None else Path(ssl_root_cert).expanduser() self.ssl_client_cert: Optional[Path] = None if ssl_client_cert is None else Path(ssl_client_cert).expanduser() self.ssl_client_key: Optional[Path] = None if ssl_client_key is None else Path(ssl_client_key).expanduser() self.uri: str = uri From a1bf1843348cff3a008d3fdc25e18648232eac6f Mon Sep 17 00:00:00 2001 From: Viraj Karandikar Date: Thu, 7 Aug 2025 13:22:54 +0530 Subject: [PATCH 6/6] fix missing typos --- riva/client/auth.py | 2 +- scripts/nlp/punctuation_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/riva/client/auth.py b/riva/client/auth.py index 8a71be7..4ee7085 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -111,7 +111,7 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options + self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/scripts/nlp/punctuation_client.py b/scripts/nlp/punctuation_client.py index c0b57de..437cce2 100644 --- a/scripts/nlp/punctuation_client.py +++ b/scripts/nlp/punctuation_client.py @@ -143,7 +143,7 @@ def run_tests(args: argparse.Namespace) -> int: } auth = riva.client.Auth( - ssl_cert=args.ssl_cert, + ssl_root_cert=args.ssl_root_cert, ssl_client_cert=args.ssl_client_cert, ssl_client_key=args.ssl_client_key, use_ssl=args.use_ssl,