diff --git a/riva/client/argparse_utils.py b/riva/client/argparse_utils.py index d662697f..a8cc1a72 100644 --- a/riva/client/argparse_utils.py +++ b/riva/client/argparse_utils.py @@ -3,6 +3,20 @@ import argparse +def validate_grpc_message_size(value): + """Validate that the GRPC message size is within acceptable limits.""" + min_size = 4 * 1024 * 1024 # 4MB + max_size = 1024 * 1024 * 1024 # 1GB + + try: + size = int(value) + if size < min_size: + raise argparse.ArgumentTypeError(f"GRPC message size must be at least {min_size} bytes (4MB)") + if size > max_size: + raise argparse.ArgumentTypeError(f"GRPC message size must be at most {max_size} bytes (1GB)") + return size + except ValueError: + raise argparse.ArgumentTypeError(f"'{value}' is not a valid integer") def add_asr_config_argparse_parameters( parser: argparse.ArgumentParser, max_alternatives: bool = False, profanity_filter: bool = False, word_time_offsets: bool = False @@ -102,13 +116,16 @@ def add_asr_config_argparse_parameters( def add_connection_argparse_parameters(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--server", default="localhost:50051", help="URI to GRPC server endpoint.") - parser.add_argument("--ssl-cert", help="Path to SSL client certificates file.") + parser.add_argument("--ssl-root-cert", help="Path to SSL root certificates file.") + parser.add_argument("--ssl-client-cert", help="Path to SSL client certificates file.") + parser.add_argument("--ssl-client-key", help="Path to SSL client key file.") parser.add_argument( "--use-ssl", action='store_true', help="Boolean to control if SSL/TLS encryption should be used." ) parser.add_argument("--metadata", action='append', nargs='+', help="Send HTTP Header(s) to server") + parser.add_argument("--options", action='append', nargs='+', help="Send GRPC options to server") parser.add_argument( - "--max-message-length", type=int, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." + "--max-message-length", type=validate_grpc_message_size, default=64 * 1024 * 1024, help="Maximum message length for GRPC server." ) return parser diff --git a/riva/client/auth.py b/riva/client/auth.py index 046f409d..4ee70854 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -8,23 +8,34 @@ def create_channel( - ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_root_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_key: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None, - max_message_length: int = 64 * 1024 * 1024, + options: Optional[List[Tuple[str, str]]] = [], ) -> grpc.Channel: def metadata_callback(context, callback): callback(metadata, None) - options = [('grpc.max_receive_message_length', max_message_length), ('grpc.max_send_message_length', max_message_length)] - if ssl_cert is not None or use_ssl: + if ssl_root_cert is not None or ssl_client_cert is not None or ssl_client_key is not None or use_ssl: root_certificates = None - if ssl_cert is not None: - ssl_cert = Path(ssl_cert).expanduser() - with open(ssl_cert, 'rb') as f: + client_certificates = None + client_key = None + if ssl_root_cert is not None: + ssl_root_cert = Path(ssl_root_cert).expanduser() + with open(ssl_root_cert, 'rb') as f: root_certificates = f.read() - creds = grpc.ssl_channel_credentials(root_certificates) + if ssl_client_cert is not None: + ssl_client_cert = Path(ssl_client_cert).expanduser() + with open(ssl_client_cert, 'rb') as f: + client_certificates = f.read() + if ssl_client_key is not None: + ssl_client_key = Path(ssl_client_key).expanduser() + with open(ssl_client_key, 'rb') as f: + client_key = f.read() + creds = grpc.ssl_channel_credentials(root_certificates=root_certificates, private_key=client_key, certificate_chain=client_certificates) if metadata: auth_creds = grpc.metadata_call_credentials(metadata_callback) creds = grpc.composite_channel_credentials(creds, auth_creds) @@ -37,23 +48,58 @@ def metadata_callback(context, callback): class Auth: def __init__( self, - ssl_cert: Optional[Union[str, os.PathLike]] = None, + ssl_root_cert: Optional[Union[str, os.PathLike]] = None, use_ssl: bool = False, uri: str = "localhost:50051", metadata_args: List[List[str]] = None, - max_message_length: int = 64 * 1024 * 1024, + ssl_client_cert: Optional[Union[str, os.PathLike]] = None, + ssl_client_key: Optional[Union[str, os.PathLike]] = None, + options: Optional[List[Tuple[str, str]]] = [], ) -> None: """ - A class responsible for establishing connection with a server and providing security metadata. + Initialize the Auth class for establishing secure connections with a server. + + This class handles SSL/TLS configuration, authentication metadata, and gRPC channel creation + for secure communication with Riva services. Args: - ssl_cert (:obj:`Union[str, os.PathLike]`, `optional`): a path to SSL certificate file. If :param:`use_ssl` - is :obj:`False` and :param:`ssl_cert` is not :obj:`None`, then SSL is used. - use_ssl (:obj:`bool`, defaults to :obj:`False`): whether to use SSL. If :param:`ssl_cert` is :obj:`None`, - then SSL is still used but with default credentials. - uri (:obj:`str`, defaults to :obj:`"localhost:50051"`): a Riva URI. + ssl_root_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL root certificate file. + If provided and use_ssl is False, SSL will still be enabled. Defaults to None. + use_ssl (bool, optional): Whether to use SSL/TLS encryption. If True and ssl_root_cert is None, + SSL will be used with default credentials. Defaults to False. + uri (str, optional): The Riva server URI in format "host:port". Defaults to "localhost:50051". + metadata_args (List[List[str]], optional): List of metadata key-value pairs for authentication. + Each inner list should contain exactly 2 elements: [key, value]. Defaults to None. + ssl_client_cert (Optional[Union[str, os.PathLike]], optional): Path to the SSL client certificate file. + Used for mutual TLS authentication. Defaults to None. + ssl_client_key (Optional[Union[str, os.PathLike]], optional): Path to the SSL client private key file. + Used for mutual TLS authentication. Defaults to None. + options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options. + Each tuple should contain (option_name, option_value). Defaults to []. + + Raises: + ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair). + + Example: + >>> # Basic connection without SSL + >>> auth = Auth(uri="localhost:50051") + + >>> # SSL connection with custom certificate + >>> auth = Auth( + ... use_ssl=True, + ... ssl_root_cert="/path/to/cert.pem", + ... uri="secure-server:50051" + ... ) + + >>> # Connection with authentication metadata + >>> auth = Auth( + ... metadata_args=[["api-key", "your-api-key"], ["user-id", "12345"]], + ... uri="auth-server:50051" + ... ) """ - self.ssl_cert: Optional[Path] = None if ssl_cert is None else Path(ssl_cert).expanduser() + self.ssl_root_cert: Optional[Path] = None if ssl_root_cert is None else Path(ssl_root_cert).expanduser() + self.ssl_client_cert: Optional[Path] = None if ssl_client_cert is None else Path(ssl_client_cert).expanduser() + self.ssl_client_key: Optional[Path] = None if ssl_client_key is None else Path(ssl_client_key).expanduser() self.uri: str = uri self.use_ssl: bool = use_ssl self.metadata = [] @@ -65,7 +111,7 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_cert, self.use_ssl, self.uri, self.metadata, max_message_length=max_message_length + self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/scripts/asr/riva_streaming_asr_client.py b/scripts/asr/riva_streaming_asr_client.py index 43873942..f600af66 100644 --- a/scripts/asr/riva_streaming_asr_client.py +++ b/scripts/asr/riva_streaming_asr_client.py @@ -50,7 +50,15 @@ def streaming_transcription_worker( ) -> None: output_file = Path(output_file).expanduser() try: - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( diff --git a/scripts/asr/transcribe_file.py b/scripts/asr/transcribe_file.py index f9cbd7c9..1849a675 100644 --- a/scripts/asr/transcribe_file.py +++ b/scripts/asr/transcribe_file.py @@ -66,7 +66,15 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_output_devices() return - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/asr/transcribe_file_offline.py b/scripts/asr/transcribe_file_offline.py index 8d1fd482..afedb957 100644 --- a/scripts/asr/transcribe_file_offline.py +++ b/scripts/asr/transcribe_file_offline.py @@ -33,7 +33,15 @@ def parse_args() -> argparse.Namespace: def main() -> None: args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + options = [('grpc.max_receive_message_length', args.max_message_length), ('grpc.max_send_message_length', args.max_message_length)] + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=options) asr_service = riva.client.ASRService(auth) if args.list_models: diff --git a/scripts/asr/transcribe_mic.py b/scripts/asr/transcribe_mic.py index 77d38fb3..3fd2b5a2 100644 --- a/scripts/asr/transcribe_mic.py +++ b/scripts/asr/transcribe_mic.py @@ -45,7 +45,15 @@ def main() -> None: if args.list_devices: riva.client.audio_io.list_input_devices() return - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) asr_service = riva.client.ASRService(auth) config = riva.client.StreamingRecognitionConfig( config=riva.client.RecognitionConfig( diff --git a/scripts/nlp/punctuation_client.py b/scripts/nlp/punctuation_client.py index ab843a77..437cce21 100644 --- a/scripts/nlp/punctuation_client.py +++ b/scripts/nlp/punctuation_client.py @@ -39,7 +39,15 @@ def parse_args() -> argparse.Namespace: def run_punct_capit(args: argparse.Namespace) -> None: - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nlp_service = riva.client.NLPService(auth) if args.interactive: while True: @@ -134,7 +142,15 @@ def run_tests(args: argparse.Namespace) -> int: ], } - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nlp_service = riva.client.NLPService(auth) fail_count = 0 diff --git a/scripts/nmt/nmt.py b/scripts/nmt/nmt.py index a978cc9b..f3d13dbe 100644 --- a/scripts/nmt/nmt.py +++ b/scripts/nmt/nmt.py @@ -123,7 +123,15 @@ def request(inputs,args): args = parse_args() - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/nmt/nmt_speech_to_speech.py b/scripts/nmt/nmt_speech_to_speech.py index 49426bfe..7348525b 100644 --- a/scripts/nmt/nmt_speech_to_speech.py +++ b/scripts/nmt/nmt_speech_to_speech.py @@ -32,7 +32,15 @@ def main(): if not os.path.exists(args.audio_file): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/nmt/nmt_speech_to_text.py b/scripts/nmt/nmt_speech_to_text.py index 7d75a6c7..bc8c0f2a 100644 --- a/scripts/nmt/nmt_speech_to_text.py +++ b/scripts/nmt/nmt_speech_to_text.py @@ -43,7 +43,15 @@ def main(): if not os.path.exists(args.audio_file): raise FileNotFoundError(f"Input audio file not found: {args.audio_file}") - auth = riva.client.Auth(args.ssl_cert, args.use_ssl, args.server, args.metadata) + auth = riva.client.Auth( + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options + ) nmt_client = riva.client.NeuralMachineTranslationClient(auth) if args.list_models: diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py index ac3ed6d0..2df233b0 100644 --- a/scripts/tts/talk.py +++ b/scripts/tts/talk.py @@ -99,8 +99,19 @@ def main() -> None: riva.client.audio_io.list_output_devices() return + if args.options is None: + args.options = [] + args.options.append(('grpc.max_receive_message_length', args.max_message_length)) + args.options.append(('grpc.max_send_message_length', args.max_message_length)) + auth = riva.client.Auth( - args.ssl_cert, args.use_ssl, args.server, args.metadata, max_message_length=args.max_message_length + ssl_root_cert=args.ssl_root_cert, + ssl_client_cert=args.ssl_client_cert, + ssl_client_key=args.ssl_client_key, + use_ssl=args.use_ssl, + uri=args.server, + metadata_args=args.metadata, + options=args.options ) service = riva.client.SpeechSynthesisService(auth) nchannels = 1