diff --git a/requirements.txt b/requirements.txt index 0db03ce..2e4b2d2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ setuptools==78.1.1 grpcio==1.67.1 grpcio-tools==1.67.1 +websockets==15.0.1 diff --git a/riva/client/auth.py b/riva/client/auth.py index 4ee7085..8a4688d 100644 --- a/riva/client/auth.py +++ b/riva/client/auth.py @@ -15,6 +15,7 @@ def create_channel( uri: str = "localhost:50051", metadata: Optional[List[Tuple[str, str]]] = None, options: Optional[List[Tuple[str, str]]] = [], + use_aio: Optional[bool] = False, ) -> grpc.Channel: def metadata_callback(context, callback): callback(metadata, None) @@ -39,9 +40,15 @@ def metadata_callback(context, callback): if metadata: auth_creds = grpc.metadata_call_credentials(metadata_callback) creds = grpc.composite_channel_credentials(creds, auth_creds) - channel = grpc.secure_channel(uri, creds, options=options) + if use_aio: + channel = grpc.aio.secure_channel(uri, creds, options=options) + else: + channel = grpc.secure_channel(uri, creds, options=options) else: - channel = grpc.insecure_channel(uri, options=options) + if use_aio: + channel = grpc.aio.insecure_channel(uri, options=options) + else: + channel = grpc.insecure_channel(uri, options=options) return channel @@ -55,6 +62,7 @@ def __init__( ssl_client_cert: Optional[Union[str, os.PathLike]] = None, ssl_client_key: Optional[Union[str, os.PathLike]] = None, options: Optional[List[Tuple[str, str]]] = [], + use_aio: bool = False, ) -> None: """ Initialize the Auth class for establishing secure connections with a server. @@ -76,6 +84,7 @@ def __init__( Used for mutual TLS authentication. Defaults to None. options (Optional[List[Tuple[str, str]]], optional): Additional gRPC channel options. Each tuple should contain (option_name, option_value). Defaults to []. + use_aio (bool, optional): Whether to use asyncio for the channel. Defaults to False. Raises: ValueError: If any metadata argument doesn't contain exactly 2 elements (key-value pair). @@ -111,7 +120,14 @@ def __init__( ) self.metadata.append(tuple(meta)) self.channel: grpc.Channel = create_channel( - self.ssl_root_cert, self.ssl_client_cert, self.ssl_client_key, self.use_ssl, self.uri, self.metadata, options=options + self.ssl_root_cert, + self.ssl_client_cert, + self.ssl_client_key, + self.use_ssl, + self.uri, + self.metadata, + options=options, + use_aio=use_aio, ) def get_auth_metadata(self) -> List[Tuple[str, str]]: diff --git a/riva/client/realtime.py b/riva/client/realtime.py index 5de310b..7ff2ad6 100644 --- a/riva/client/realtime.py +++ b/riva/client/realtime.py @@ -10,10 +10,11 @@ import requests import websockets +import ssl from websockets.exceptions import WebSocketException logging.basicConfig( - level=logging.INFO, + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -21,40 +22,41 @@ class RealtimeClient: """Client for real-time transcription via WebSocket connection.""" - + def __init__(self, args: argparse.Namespace): """Initialize the RealtimeClient. - + Args: args: Command line arguments containing configuration """ self.args = args self.websocket = None self.session_config = None - + # Input audio playback self.input_audio_queue = queue.Queue() self.input_playback_thread = None self.is_input_playing = False self.input_buffer_size = 1024 # Buffer size for input audio playback - + # Transcription results self.delta_transcripts: List[str] = [] self.interim_final_transcripts: List[str] = [] self.final_transcript: str = "" self.is_config_updated = False + async def connect(self): """Establish connection to the ASR server.""" try: # Initialize session via HTTP POST session_data = await self._initialize_http_session() self.session_config = session_data - + # Connect to WebSocket await self._connect_websocket() await self._initialize_session() - + except requests.exceptions.RequestException as e: logger.error(f"HTTP request failed: {e}") raise @@ -68,27 +70,48 @@ async def connect(self): async def _initialize_http_session(self) -> Dict[str, Any]: """Initialize session via HTTP POST request.""" headers = {"Content-Type": "application/json"} + uri = f"http://{self.args.server}/v1/realtime/transcription_sessions" + if self.args.use_ssl: + uri = f"https://{self.args.server}/v1/realtime/transcription_sessions" + logger.info(f"Initializing session via HTTP POST request to: {uri}") response = requests.post( - f"http://{self.args.server}/v1/realtime/transcription_sessions", + uri, headers=headers, - json={} + json={}, + cert=(self.args.ssl_client_cert, self.args.ssl_client_key) if self.args.ssl_client_cert and self.args.ssl_client_key else None, + verify=self.args.ssl_root_cert if self.args.ssl_root_cert else True ) - + if response.status_code != 200: raise Exception( f"Failed to initialize session. Status: {response.status_code}, " f"Error: {response.text}" ) - + session_data = response.json() logger.info(f"Session initialized: {session_data}") return session_data async def _connect_websocket(self): """Connect to WebSocket endpoint.""" + ssl_context = None ws_url = f"ws://{self.args.server}{self.args.endpoint}?{self.args.query_params}" + if self.args.use_ssl: + ws_url = f"wss://{self.args.server}{self.args.endpoint}?{self.args.query_params}" + + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + # Load a custom CA certificate bundle + if self.args.ssl_root_cert: + ssl_context.load_verify_locations(self.args.ssl_root_cert) + # Load a client certificate and key + if self.args.ssl_client_cert and self.args.ssl_client_key: + ssl_context.load_cert_chain(self.args.ssl_client_cert, self.args.ssl_client_key) + # Disable hostname verification + ssl_context.check_hostname = False + # ssl_context.verify_mode = ssl.CERT_REQUIRED + logger.info(f"Connecting to WebSocket: {ws_url}") - self.websocket = await websockets.connect(ws_url) + self.websocket = await websockets.connect(ws_url, ssl=ssl_context) async def _initialize_session(self): """Initialize the WebSocket session.""" @@ -97,7 +120,7 @@ async def _initialize_session(self): response = await self.websocket.recv() response_data = json.loads(response) logger.info("Session created: %s", response_data) - + event_type = response_data.get("type", "") if event_type == "conversation.created": logger.info("Conversation created successfully") @@ -111,9 +134,9 @@ async def _initialize_session(self): if not self.is_config_updated: logger.error("Failed to update session") raise Exception("Failed to update session") - + logger.info("Session initialization complete") - + except json.JSONDecodeError as e: logger.error(f"Failed to parse JSON response: {e}") raise @@ -126,7 +149,7 @@ async def _initialize_session(self): def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, section: str = None): """Safely update a configuration value, creating the section if it doesn't exist. - + Args: config: The configuration dictionary to update key: The key to update @@ -144,62 +167,62 @@ def _safe_update_config(self, config: Dict[str, Any], key: str, value: Any, sect async def _update_session(self) -> bool: """Update session configuration by selectively overriding server defaults. - + Returns: True if session was updated successfully, False otherwise """ logger.info("Updating session configuration...") logger.info(f"Server default config: {self.session_config}") - + # Create a copy of the session config from server defaults session_config = self.session_config.copy() - + # Track what we're overriding overrides = [] - + # Update input audio transcription - only override if args are provided if hasattr(self.args, 'language_code') and self.args.language_code: self._safe_update_config(session_config, "language", self.args.language_code, "input_audio_transcription") overrides.append("language") - + if hasattr(self.args, 'model_name') and self.args.model_name: self._safe_update_config(session_config, "model", self.args.model_name, "input_audio_transcription") overrides.append("model") - + if hasattr(self.args, 'prompt') and self.args.prompt: self._safe_update_config(session_config, "prompt", self.args.prompt, "input_audio_transcription") overrides.append("prompt") - + # Update input audio parameters - only override if args are provided if hasattr(self.args, 'sample_rate_hz') and self.args.sample_rate_hz: self._safe_update_config(session_config, "sample_rate_hz", self.args.sample_rate_hz, "input_audio_params") overrides.append("sample_rate_hz") - + if hasattr(self.args, 'num_channels') and self.args.num_channels: self._safe_update_config(session_config, "num_channels", self.args.num_channels, "input_audio_params") overrides.append("num_channels") - + # Update recognition settings - only override if args are provided if hasattr(self.args, 'max_alternatives') and self.args.max_alternatives is not None: self._safe_update_config(session_config, "max_alternatives", self.args.max_alternatives, "recognition_config") overrides.append("max_alternatives") - + if hasattr(self.args, 'automatic_punctuation') and self.args.automatic_punctuation is not None: self._safe_update_config(session_config, "enable_automatic_punctuation", self.args.automatic_punctuation, "recognition_config") overrides.append("automatic_punctuation") - + if hasattr(self.args, 'word_time_offsets') and self.args.word_time_offsets is not None: self._safe_update_config(session_config, "enable_word_time_offsets", self.args.word_time_offsets, "recognition_config") overrides.append("word_time_offsets") - + if hasattr(self.args, 'profanity_filter') and self.args.profanity_filter is not None: self._safe_update_config(session_config, "enable_profanity_filter", self.args.profanity_filter, "recognition_config") overrides.append("profanity_filter") - + if hasattr(self.args, 'no_verbatim_transcripts') and self.args.no_verbatim_transcripts is not None: self._safe_update_config(session_config, "enable_verbatim_transcripts", self.args.no_verbatim_transcripts, "recognition_config") overrides.append("verbatim_transcripts") - + # Configure speaker diarization if enabled if hasattr(self.args, 'speaker_diarization') and self.args.speaker_diarization: session_config["speaker_diarization"] = { @@ -207,10 +230,10 @@ async def _update_session(self) -> bool: "max_speaker_count": getattr(self.args, 'diarization_max_speakers', 2) } overrides.append("speaker_diarization") - + # Configure word boosting if enabled - if (hasattr(self.args, 'boosted_lm_words') and - self.args.boosted_lm_words and + if (hasattr(self.args, 'boosted_lm_words') and + self.args.boosted_lm_words and len(self.args.boosted_lm_words)): word_boosting_list = [ { @@ -223,44 +246,44 @@ async def _update_session(self) -> bool: "word_boosting_list": word_boosting_list } overrides.append("word_boosting") - + # Configure endpointing if any parameters are set if self._has_endpointing_config(): session_config["endpointing_config"] = self._build_endpointing_config() overrides.append("endpointing_config") - + # Configure custom configuration if provided if hasattr(self.args, 'custom_configuration') and self.args.custom_configuration: custom_config = self._parse_custom_configuration(self.args.custom_configuration) if custom_config: session_config["custom_configuration"] = custom_config overrides.append("custom_configuration") - + if overrides: logger.info(f"Overriding server defaults for: {', '.join(overrides)}") else: logger.info("Using server default configuration (no overrides)") - + logger.info(f"Final session config: {session_config}") - + # Send update request update_session_request = { "type": "transcription_session.update", "session": session_config } await self._send_message(update_session_request) - + # Handle response return await self._handle_session_update_response() def _has_endpointing_config(self) -> bool: """Check if any endpointing configuration parameters are set.""" return ( - self.args.start_history > 0 or - self.args.start_threshold > 0 or - self.args.stop_history > 0 or - self.args.stop_history_eou > 0 or - self.args.stop_threshold > 0 or + self.args.start_history > 0 or + self.args.start_threshold > 0 or + self.args.stop_history > 0 or + self.args.stop_history_eou > 0 or + self.args.stop_threshold > 0 or self.args.stop_threshold_eou > 0 ) @@ -277,41 +300,41 @@ def _build_endpointing_config(self) -> Dict[str, Any]: def _parse_custom_configuration(self, custom_configuration: str) -> Dict[str, str]: """Parse custom configuration string into a dictionary. - + Args: custom_configuration: String in format "key1:value1,key2:value2" - + Returns: Dictionary of custom configuration key-value pairs - + Raises: ValueError: If the custom configuration format is invalid """ custom_config = {} custom_configuration = custom_configuration.strip().replace(" ", "") - + if not custom_configuration: return custom_config - + for pair in custom_configuration.split(","): key_value = pair.split(":") if len(key_value) == 2: custom_config[key_value[0]] = key_value[1] else: raise ValueError(f"Invalid key:value pair {key_value}") - + return custom_config async def _handle_session_update_response(self) -> bool: """Handle session update response. - + Returns: True if session was updated successfully, False otherwise """ response = await self.websocket.recv() response_data = json.loads(response) logger.info("Session updated: %s", response_data) - + event_type = response_data.get("type", "") if event_type == "transcription_session.updated": logger.info("Transcription session updated successfully") @@ -330,23 +353,23 @@ async def _send_message(self, message: Dict[str, Any]): async def send_audio_chunks(self, audio_chunks): """Send audio chunks to the server for transcription.""" logger.info("Sending audio chunks...") - + for chunk in audio_chunks: chunk_base64 = base64.b64encode(chunk).decode("utf-8") - + # Send chunk to the server await self._send_message({ "type": "input_audio_buffer.append", "audio": chunk_base64, }) - + # Commit the chunk await self._send_message({ "type": "input_audio_buffer.commit", }) - + logger.info("All chunks sent") - + # Tell the server that we are done sending chunks await self._send_message({ "type": "input_audio_buffer.done", @@ -356,7 +379,7 @@ async def receive_responses(self): """Receive and process transcription responses from the server.""" logger.info("Listening for responses...") received_final_response = False - + while not received_final_response: try: response = await asyncio.wait_for(self.websocket.recv(), 10.0) @@ -367,13 +390,13 @@ async def receive_responses(self): delta = event.get("delta", "") logger.info("Transcript: %s", delta) self.delta_transcripts.append(delta) - + elif event_type == "conversation.item.input_audio_transcription.completed": is_last_result = event.get("is_last_result", False) interim_final_transcript = event.get("transcript", "") self.interim_final_transcripts.append(interim_final_transcript) self.final_transcript = interim_final_transcript - + if is_last_result: logger.info("Final Transcript: %s", self.final_transcript) logger.info("Transcription completed") @@ -381,9 +404,9 @@ async def receive_responses(self): break else: logger.info("Interim Transcript: %s", interim_final_transcript) - + logger.info("Words Info: %s", event.get("words_info", "")) - + elif "error" in event_type.lower(): logger.error( f"Error: {event.get('error', {}).get('message', 'Unknown error')}" @@ -399,7 +422,7 @@ async def receive_responses(self): def save_responses(self, output_text_file: str): """Save collected transcription text to a file. - + Args: output_text_file: Path to the output text file """ diff --git a/scripts/asr/realtime_asr_client.py b/scripts/asr/realtime_asr_client.py index ea1e6f8..eb13c7b 100644 --- a/scripts/asr/realtime_asr_client.py +++ b/scripts/asr/realtime_asr_client.py @@ -11,6 +11,7 @@ from riva.client.argparse_utils import ( add_asr_config_argparse_parameters, add_realtime_config_argparse_parameters, + add_connection_argparse_parameters, ) @@ -27,94 +28,91 @@ def parse_args() -> argparse.Namespace: ), formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - + # Input configuration parser.add_argument( - "--input-file", - required=False, + "--input-file", + required=False, help="Input audio file (required when not using --mic)" ) parser.add_argument( - "--mic", - action="store_true", - help="Use microphone input instead of file input", + "--mic", + action="store_true", + help="Use microphone input instead of file input", default=False ) parser.add_argument( - "--duration", - type=int, - help="Duration in seconds to record from microphone (only used with --mic)", + "--duration", + type=int, + help="Duration in seconds to record from microphone (only used with --mic)", default=None ) - + parser.add_argument( - "--input-device", - type=int, - default=None, + "--input-device", + type=int, + default=None, help="Input audio device index to use (only used with --mic). If not specified, will use default device." ) parser.add_argument( - "--list-devices", - action="store_true", + "--list-devices", + action="store_true", help="List available input audio device indices" ) - + # Audio parameters parser.add_argument( - "--sample-rate-hz", - type=int, - help="Number of frames per second in audio streamed from a microphone.", + "--sample-rate-hz", + type=int, + help="Number of frames per second in audio streamed from a microphone.", default=16000 ) parser.add_argument( - "--num-channels", - type=int, - help="Number of audio channels.", + "--num-channels", + type=int, + help="Number of audio channels.", default=1 ) parser.add_argument( - "--file-streaming-chunk", - type=int, - default=1600, + "--file-streaming-chunk", + type=int, + default=1600, help="Maximum number of frames in one chunk sent to server." ) - + # Output configuration parser.add_argument( - "--output-text", - type=str, + "--output-text", + type=str, help="Output text file" ) parser.add_argument( - "--prompt", - default="", + "--prompt", + default="", help="Prompt to be used for transcription." ) - - parser.add_argument( - "--server", - default="localhost:9090", - help="URI to WebSocket server endpoint." - ) - + + # Add connection parameters + parser = add_connection_argparse_parameters(parser) + # Add ASR and realtime configuration parameters parser = add_asr_config_argparse_parameters( - parser, - max_alternatives=True, - profanity_filter=True, + parser, + max_alternatives=True, + profanity_filter=True, word_time_offsets=True ) parser = add_realtime_config_argparse_parameters(parser) - + args = parser.parse_args() - + # Validate input configuration if not args.mic and not args.input_file: parser.error("Either --input-file or --mic must be specified") - + if args.mic and args.input_file: parser.error("Cannot specify both --input-file and --mic") - + return args @@ -138,25 +136,25 @@ def signal_handler(sig, frame): async def create_audio_iterator(args): """Create appropriate audio iterator based on input type. - + Args: args: Command line arguments containing input configuration - + Returns: Audio iterator for streaming audio data """ if args.mic: # Only import when using microphone from riva.client.audio_io import MicrophoneStream - + # Get default device index if not specified device_index = args.input_device if device_index is None: device_index = get_default_device_index() - + audio_chunk_iterator = MicrophoneStream( - args.sample_rate_hz, - args.file_streaming_chunk, + args.sample_rate_hz, + args.file_streaming_chunk, device=device_index ) args.num_channels = 1 @@ -166,29 +164,29 @@ async def create_audio_iterator(args): args.sample_rate_hz = wav_parameters['framerate'] args.num_channels = wav_parameters['nchannels'] audio_chunk_iterator = AudioChunkFileIterator( - args.input_file, - args.file_streaming_chunk, + args.input_file, + args.file_streaming_chunk, delay_callback=None ) - + return audio_chunk_iterator async def run_transcription(args): """Run the transcription process. - + Args: args: Command line arguments containing all configuration """ client = RealtimeClient(args=args) - + try: # Create audio iterator audio_chunk_iterator = await create_audio_iterator(args) - + # Connect and start transcription await client.connect() - + # Run send and receive tasks concurrently send_task = asyncio.create_task( client.send_audio_chunks(audio_chunk_iterator) @@ -196,13 +194,13 @@ async def run_transcription(args): receive_task = asyncio.create_task( client.receive_responses() ) - + await asyncio.gather(send_task, receive_task) - + # Save results if output file specified if args.output_text: client.save_responses(args.output_text) - + except Exception as e: print(f"Error: {e}") raise @@ -213,7 +211,7 @@ async def run_transcription(args): async def main() -> None: """Main entry point for the realtime ASR client.""" args = parse_args() - + # Handle list devices option if args.list_devices: try: @@ -222,7 +220,7 @@ async def main() -> None: except ModuleNotFoundError: print("PyAudio not available. Please install PyAudio to list audio devices.") return - + setup_signal_handler() try: