diff --git a/riva/client/__init__.py b/riva/client/__init__.py index 7656bd69..108263aa 100644 --- a/riva/client/__init__.py +++ b/riva/client/__init__.py @@ -41,3 +41,6 @@ from riva.client.proto.riva_nmt_pb2 import StreamingTranslateSpeechToSpeechConfig, TranslationConfig, SynthesizeSpeechConfig, StreamingTranslateSpeechToTextConfig from riva.client.tts import SpeechSynthesisService from riva.client.nmt import NeuralMachineTranslationClient + +# Async extensions (grpc.aio) +from riva.client.asr_async import ASRServiceAsync, AsyncAuth diff --git a/riva/client/asr_async.py b/riva/client/asr_async.py new file mode 100644 index 00000000..bd939665 --- /dev/null +++ b/riva/client/asr_async.py @@ -0,0 +1,332 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Async ASR client using grpc.aio. + +This module provides async/await support for Riva ASR streaming, +enabling efficient high-concurrency scenarios without thread overhead. + +Example: + async with AsyncAuth(uri="localhost:50051") as auth: + service = ASRServiceAsync(auth) + async for response in service.streaming_recognize(audio_gen, config): + print(response.results) +""" + +from __future__ import annotations + +import asyncio +from typing import AsyncIterator, Sequence + +import grpc +import grpc.aio + +from riva.client.proto import riva_asr_pb2 as rasr +from riva.client.proto import riva_asr_pb2_grpc as rasr_srv + +__all__ = ["AsyncAuth", "ASRServiceAsync"] + + +class AsyncAuth: + """Async-compatible authentication and channel management. + + Provides lazy channel creation with thread-safe initialization. + Supports both insecure and SSL connections. + + Args: + uri: Riva server address (host:port) + use_ssl: Enable SSL/TLS + ssl_root_cert: Path to root CA certificate (optional) + ssl_client_cert: Path to client certificate for mTLS (optional) + ssl_client_key: Path to client private key for mTLS (optional) + metadata: List of (key, value) tuples for request metadata + options: Additional gRPC channel options + + Example: + # Simple insecure connection + auth = AsyncAuth(uri="localhost:50051") + + # SSL with custom cert + auth = AsyncAuth(uri="riva.example.com:443", use_ssl=True) + + # With API key metadata + auth = AsyncAuth( + uri="riva.example.com:443", + use_ssl=True, + metadata=[("x-api-key", "your-key")] + ) + + # As context manager (recommended) + async with AsyncAuth(uri="localhost:50051") as auth: + service = ASRServiceAsync(auth) + # use service... 
+ """ + + # Default channel options for real-time streaming + DEFAULT_OPTIONS: Sequence[tuple[str, int | bool]] = ( + ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB + ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB + ("grpc.keepalive_time_ms", 10_000), # 10 sec + ("grpc.keepalive_timeout_ms", 5_000), # 5 sec + ("grpc.keepalive_permit_without_calls", True), + ("grpc.http2.min_ping_interval_without_data_ms", 5_000), + ) + + def __init__( + self, + uri: str, + use_ssl: bool = False, + ssl_root_cert: str | None = None, + ssl_client_cert: str | None = None, + ssl_client_key: str | None = None, + metadata: Sequence[tuple[str, str]] | None = None, + options: Sequence[tuple[str, int | bool | str]] | None = None, + ) -> None: + self.uri = uri + self.use_ssl = use_ssl + self.ssl_root_cert = ssl_root_cert + self.ssl_client_cert = ssl_client_cert + self.ssl_client_key = ssl_client_key + self.metadata = list(metadata) if metadata else [] + self._options = list(options) if options else list(self.DEFAULT_OPTIONS) + + self._channel: grpc.aio.Channel | None = None + self._lock = asyncio.Lock() + + async def get_channel(self) -> grpc.aio.Channel: + """Get or create the async gRPC channel. + + Thread-safe: uses asyncio.Lock to ensure single channel creation + even under concurrent access. Uses double-checked locking for + fast-path optimization when channel already exists. + + Returns: + The async gRPC channel + """ + # Fast path: channel already exists + if self._channel is not None: + return self._channel + # Slow path: acquire lock and create channel + async with self._lock: + if self._channel is None: + self._channel = await self._create_channel() + return self._channel + + async def _create_channel(self) -> grpc.aio.Channel: + """Create the appropriate channel type based on SSL settings.""" + if self.use_ssl: + credentials = await self._create_ssl_credentials() + return grpc.aio.secure_channel( + self.uri, + credentials, + options=self._options, + ) + else: + return grpc.aio.insecure_channel( + self.uri, + options=self._options, + ) + + async def _create_ssl_credentials(self) -> grpc.ChannelCredentials: + """Create SSL credentials from certificate files. + + Uses asyncio.to_thread() for non-blocking file I/O. + """ + + def _read_file(path: str) -> bytes: + with open(path, "rb") as f: + return f.read() + + root_cert = None + client_cert = None + client_key = None + + if self.ssl_root_cert: + root_cert = await asyncio.to_thread(_read_file, self.ssl_root_cert) + + if self.ssl_client_cert: + client_cert = await asyncio.to_thread(_read_file, self.ssl_client_cert) + + if self.ssl_client_key: + client_key = await asyncio.to_thread(_read_file, self.ssl_client_key) + + return grpc.ssl_channel_credentials( + root_certificates=root_cert, + private_key=client_key, + certificate_chain=client_cert, + ) + + def get_auth_metadata(self) -> list[tuple[str, str]]: + """Get metadata to include with RPC calls. + + Returns: + List of (key, value) metadata tuples + """ + return self.metadata + + async def close(self) -> None: + """Close the channel and release resources.""" + async with self._lock: + if self._channel is not None: + await self._channel.close() + self._channel = None + + async def __aenter__(self) -> "AsyncAuth": + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + """Async context manager exit - ensures cleanup.""" + await self.close() + + +class ASRServiceAsync: + """Async ASR service using grpc.aio. 
+ + Provides async streaming and batch recognition methods that can handle + many concurrent streams without thread overhead. + + Args: + auth: AsyncAuth instance for channel management + + Example: + auth = AsyncAuth(uri="localhost:50051") + service = ASRServiceAsync(auth) + + # Streaming recognition + async def audio_generator(): + while audio_available: + yield audio_chunk + + async for response in service.streaming_recognize( + audio_generator(), + streaming_config + ): + for result in response.results: + print(result.alternatives[0].transcript) + + await auth.close() + """ + + def __init__(self, auth: AsyncAuth) -> None: + self.auth = auth + self._stub: "rasr_srv.RivaSpeechRecognitionStub | None" = None + self._stub_lock = asyncio.Lock() + # Cache metadata reference to avoid repeated method calls + self._metadata = auth.get_auth_metadata() or None + + async def _get_stub(self) -> "rasr_srv.RivaSpeechRecognitionStub": + """Get or create the gRPC stub. + + Thread-safe stub creation with double-checked locking for + fast-path optimization when stub already exists. + """ + # Fast path: stub already exists + if self._stub is not None: + return self._stub + # Slow path: acquire lock and create stub + async with self._stub_lock: + if self._stub is None: + channel = await self.auth.get_channel() + self._stub = rasr_srv.RivaSpeechRecognitionStub(channel) + return self._stub + + async def streaming_recognize( + self, + audio_chunks: AsyncIterator[bytes], + streaming_config: "rasr.StreamingRecognitionConfig", + ) -> AsyncIterator["rasr.StreamingRecognizeResponse"]: + """Perform async streaming speech recognition. + + This is the primary method for real-time speech recognition. + Audio is streamed to the server and partial/final results are + yielded as they become available. + + Args: + audio_chunks: Async iterator yielding raw audio bytes + (LINEAR_PCM format recommended, 16-bit, mono) + streaming_config: Configuration including sample rate, + language, and interim_results setting + + Yields: + StreamingRecognizeResponse objects containing transcription + results. Check result.is_final to distinguish partial from + final results. + + Raises: + grpc.aio.AioRpcError: On gRPC communication errors + + Example: + config = StreamingRecognitionConfig( + config=RecognitionConfig( + encoding=AudioEncoding.LINEAR_PCM, + sample_rate_hertz=16000, + language_code="en-US", + ), + interim_results=True, + ) + + async for response in service.streaming_recognize( + audio_generator(), config + ): + for result in response.results: + transcript = result.alternatives[0].transcript + if result.is_final: + print(f"Final: {transcript}") + else: + print(f"Partial: {transcript}") + """ + stub = await self._get_stub() + metadata = self._metadata + + async def request_generator() -> AsyncIterator[rasr.StreamingRecognizeRequest]: + # First request: config only (no audio) + yield rasr.StreamingRecognizeRequest(streaming_config=streaming_config) + # Subsequent requests: audio only + async for chunk in audio_chunks: + yield rasr.StreamingRecognizeRequest(audio_content=chunk) + + call = stub.StreamingRecognize( + request_generator(), + metadata=metadata, + ) + + async for response in call: + yield response + + async def recognize( + self, + audio_bytes: bytes, + config: "rasr.RecognitionConfig", + ) -> "rasr.RecognizeResponse": + """Perform async batch (offline) speech recognition. + + Use this for complete audio files rather than streaming. 
+ + Args: + audio_bytes: Complete audio data + config: Recognition configuration + + Returns: + RecognizeResponse with transcription results + + Raises: + grpc.aio.AioRpcError: On gRPC communication errors + """ + stub = await self._get_stub() + metadata = self._metadata + + request = rasr.RecognizeRequest(config=config, audio=audio_bytes) + return await stub.Recognize(request, metadata=metadata) + + async def get_config(self) -> "rasr.RivaSpeechRecognitionConfigResponse": + """Get the server's speech recognition configuration. + + Returns: + Configuration response with available models and settings + """ + stub = await self._get_stub() + metadata = self._metadata + + request = rasr.RivaSpeechRecognitionConfigRequest() + return await stub.GetRivaSpeechRecognitionConfig(request, metadata=metadata) diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 00000000..af04eba9 --- /dev/null +++ b/tests/benchmarks/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT diff --git a/tests/benchmarks/test_asr_async_benchmarks.py b/tests/benchmarks/test_asr_async_benchmarks.py new file mode 100644 index 00000000..3bc9e95a --- /dev/null +++ b/tests/benchmarks/test_asr_async_benchmarks.py @@ -0,0 +1,294 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Performance benchmarks for async ASR client. + +Run with a Riva server: + RIVA_URI=localhost:50051 pytest tests/benchmarks/ -v -s + +Run client overhead tests (no server required): + pytest tests/benchmarks/ -v -s -k "overhead" +""" + +from __future__ import annotations + +import asyncio +import os +import time +import wave +from pathlib import Path +from typing import AsyncIterator, List + +import pytest + +from riva.client.proto import riva_asr_pb2 as rasr +from riva.client.proto import riva_audio_pb2 as riva_audio + +# Audio sample directory +AUDIO_DIR = Path(__file__).parent.parent.parent / "data" / "examples" + +pytestmark = pytest.mark.benchmark + + +class TestClientOverhead: + """Measure client-side overhead without server.""" + + def test_protobuf_creation_overhead(self) -> None: + """Measure time to create StreamingRecognizeRequest messages. + + Target: < 100μs per request creation. + """ + iterations = 10000 + audio_chunk = b"x" * 3200 # 100ms of 16-bit 16kHz audio + + start = time.perf_counter() + for _ in range(iterations): + request = rasr.StreamingRecognizeRequest(audio_content=audio_chunk) + elapsed = time.perf_counter() - start + + avg_time_us = (elapsed / iterations) * 1_000_000 + print(f"\nProtobuf creation: {avg_time_us:.2f} μs/request") + print(f"Total time for {iterations} requests: {elapsed * 1000:.2f} ms") + + # Should be under 100μs per request + assert avg_time_us < 100, f"Protobuf creation too slow: {avg_time_us:.2f} μs" + + def test_config_creation_overhead(self) -> None: + """Measure time to create recognition config. + + Target: < 200μs per config creation. 
+ """ + iterations = 5000 + + start = time.perf_counter() + for _ in range(iterations): + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=16000, + language_code="en-US", + max_alternatives=1, + enable_automatic_punctuation=True, + ), + interim_results=True, + ) + elapsed = time.perf_counter() - start + + avg_time_us = (elapsed / iterations) * 1_000_000 + print(f"\nConfig creation: {avg_time_us:.2f} μs/config") + + assert avg_time_us < 200, f"Config creation too slow: {avg_time_us:.2f} μs" + + @pytest.mark.asyncio + async def test_async_generator_overhead(self) -> None: + """Measure overhead of async generator iteration. + + Target: < 50μs per async yield. + """ + iterations = 10000 + chunk = b"x" * 3200 + + async def audio_generator() -> AsyncIterator[bytes]: + for _ in range(iterations): + yield chunk + + start = time.perf_counter() + count = 0 + async for _ in audio_generator(): + count += 1 + elapsed = time.perf_counter() - start + + avg_time_us = (elapsed / count) * 1_000_000 + print(f"\nAsync generator: {avg_time_us:.2f} μs/yield") + + assert avg_time_us < 50, f"Async generator too slow: {avg_time_us:.2f} μs" + + +@pytest.mark.skipif( + not os.getenv("RIVA_URI"), + reason="RIVA_URI not set - skipping server benchmarks" +) +class TestEndToEndLatency: + """Benchmark end-to-end latency with a running server.""" + + @pytest.fixture + def en_us_sample(self) -> Path: + """Path to en-US sample audio file.""" + path = AUDIO_DIR / "en-US_sample.wav" + if not path.exists(): + pytest.skip(f"Audio sample not found: {path}") + return path + + @pytest.fixture + def riva_uri(self) -> str: + """Get Riva server URI from environment.""" + return os.environ["RIVA_URI"] + + @pytest.mark.asyncio + async def test_time_to_first_result( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Measure time from first audio chunk to first response.""" + from riva.client.asr_async import ASRServiceAsync, AsyncAuth + + with wave.open(str(en_us_sample), "rb") as wf: + sample_rate = wf.getframerate() + + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=sample_rate, + language_code="en-US", + max_alternatives=1, + ), + interim_results=True, + ) + + async def timed_audio_generator() -> AsyncIterator[bytes]: + with wave.open(str(en_us_sample), "rb") as wf: + while True: + data = wf.readframes(1600) # 100ms chunks + if not data: + break + yield data + + # Measure time to first response + start = time.perf_counter() + first_response_time = None + + async for response in service.streaming_recognize( + timed_audio_generator(), config + ): + if first_response_time is None: + first_response_time = time.perf_counter() - start + # Continue to consume all responses + + print(f"\nTime to first response: {first_response_time * 1000:.2f} ms") + + @pytest.mark.asyncio + async def test_concurrent_stream_throughput( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Compare 4 concurrent streams vs 4 sequential streams.""" + from riva.client.asr_async import ASRServiceAsync, AsyncAuth + + with wave.open(str(en_us_sample), "rb") as wf: + sample_rate = wf.getframerate() + + async def run_single_stream() -> float: + """Run one stream and return duration.""" + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = 
rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=sample_rate, + language_code="en-US", + ), + interim_results=False, + ) + + async def audio_gen() -> AsyncIterator[bytes]: + with wave.open(str(en_us_sample), "rb") as wf: + while True: + data = wf.readframes(1600) + if not data: + break + yield data + + start = time.perf_counter() + async for _ in service.streaming_recognize(audio_gen(), config): + pass + return time.perf_counter() - start + + num_streams = 4 + + # Sequential execution + seq_start = time.perf_counter() + for _ in range(num_streams): + await run_single_stream() + sequential_time = time.perf_counter() - seq_start + + # Concurrent execution + conc_start = time.perf_counter() + await asyncio.gather(*[run_single_stream() for _ in range(num_streams)]) + concurrent_time = time.perf_counter() - conc_start + + speedup = sequential_time / concurrent_time + print(f"\n{num_streams} sequential streams: {sequential_time * 1000:.2f} ms") + print(f"{num_streams} concurrent streams: {concurrent_time * 1000:.2f} ms") + print(f"Speedup: {speedup:.2f}x") + + # Concurrent should be faster (expect at least 1.5x speedup) + assert speedup > 1.5, f"Concurrent speedup too low: {speedup:.2f}x" + + +@pytest.mark.skipif( + not os.getenv("RIVA_URI"), + reason="RIVA_URI not set - skipping server benchmarks" +) +class TestMemoryUsage: + """Benchmark memory usage patterns.""" + + @pytest.fixture + def en_us_sample(self) -> Path: + """Path to en-US sample audio file.""" + path = AUDIO_DIR / "en-US_sample.wav" + if not path.exists(): + pytest.skip(f"Audio sample not found: {path}") + return path + + @pytest.fixture + def riva_uri(self) -> str: + """Get Riva server URI from environment.""" + return os.environ["RIVA_URI"] + + @pytest.mark.asyncio + async def test_many_short_streams_no_leak( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Run many short streams to check for resource leaks.""" + import gc + + from riva.client.asr_async import ASRServiceAsync, AsyncAuth + + with wave.open(str(en_us_sample), "rb") as wf: + sample_rate = wf.getframerate() + # Read just first 0.5 seconds + short_audio = wf.readframes(sample_rate // 2) + + num_streams = 20 + + async def run_short_stream() -> None: + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=sample_rate, + language_code="en-US", + ), + interim_results=False, + ) + + async def audio_gen() -> AsyncIterator[bytes]: + # Single chunk + yield short_audio + + async for _ in service.streaming_recognize(audio_gen(), config): + pass + + # Run many streams + for i in range(num_streams): + await run_short_stream() + if i % 5 == 0: + gc.collect() + + # Force GC + gc.collect() + print(f"\nCompleted {num_streams} short streams") diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..fcf7fef7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: MIT + +"""Pytest configuration and shared fixtures.""" + +import pytest + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line( + "markers", "integration: tests requiring a running Riva server (deselect with '-m \"not integration\"')" + ) + config.addinivalue_line( + "markers", "benchmark: performance benchmark tests (select with '-m benchmark')" + ) + config.addinivalue_line( + "markers", "asyncio: mark test as async (auto-handled by pytest-asyncio)" + ) + + +def pytest_collection_modifyitems(config, items): + """Auto-skip integration tests when RIVA_URI is not set.""" + import os + + if os.getenv("RIVA_URI"): + # RIVA_URI is set, don't skip integration tests + return + + skip_integration = pytest.mark.skip(reason="RIVA_URI not set") + for item in items: + if "integration" in item.keywords: + item.add_marker(skip_integration) diff --git a/tests/integration/python/__init__.py b/tests/integration/python/__init__.py new file mode 100644 index 00000000..af04eba9 --- /dev/null +++ b/tests/integration/python/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT diff --git a/tests/integration/python/conftest.py b/tests/integration/python/conftest.py new file mode 100644 index 00000000..92b1c591 --- /dev/null +++ b/tests/integration/python/conftest.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Integration test fixtures.""" + +import os +import wave +from pathlib import Path +from typing import AsyncIterator + +import pytest + +# Audio sample directory +AUDIO_DIR = Path(__file__).parent.parent.parent.parent / "data" / "examples" + + +@pytest.fixture +def riva_uri() -> str: + """Get Riva server URI from environment.""" + uri = os.getenv("RIVA_URI") + if not uri: + pytest.skip("RIVA_URI environment variable not set") + return uri + + +@pytest.fixture +def en_us_sample() -> Path: + """Path to en-US sample audio file.""" + path = AUDIO_DIR / "en-US_sample.wav" + if not path.exists(): + pytest.skip(f"Audio sample not found: {path}") + return path + + +@pytest.fixture +def de_de_sample() -> Path: + """Path to de-DE sample audio file.""" + path = AUDIO_DIR / "de-DE_sample.wav" + if not path.exists(): + pytest.skip(f"Audio sample not found: {path}") + return path + + +def get_wav_params(wav_path: Path) -> dict: + """Get WAV file parameters.""" + with wave.open(str(wav_path), "rb") as wf: + return { + "sample_rate": wf.getframerate(), + "sample_width": wf.getsampwidth(), + "channels": wf.getnchannels(), + "n_frames": wf.getnframes(), + } + + +async def audio_chunk_generator( + wav_path: Path, chunk_frames: int = 1600 +) -> AsyncIterator[bytes]: + """Yield audio chunks from WAV file. 
+ + Args: + wav_path: Path to WAV file + chunk_frames: Number of frames per chunk (default 1600 = 100ms at 16kHz) + + Yields: + Audio data chunks + """ + with wave.open(str(wav_path), "rb") as wf: + while True: + data = wf.readframes(chunk_frames) + if not data: + break + yield data + + +def read_wav_audio(wav_path: Path) -> bytes: + """Read entire audio content from WAV file.""" + with wave.open(str(wav_path), "rb") as wf: + return wf.readframes(wf.getnframes()) diff --git a/tests/integration/python/test_asr_async_integration.py b/tests/integration/python/test_asr_async_integration.py new file mode 100644 index 00000000..20d0a473 --- /dev/null +++ b/tests/integration/python/test_asr_async_integration.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Integration tests for async ASR client. + +These tests require a running Riva server. Set RIVA_URI environment variable +to point to your server (e.g., RIVA_URI=localhost:50051). + +Run: RIVA_URI=localhost:50051 pytest -m integration -v +Skip: pytest -m "not integration" +""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import pytest + +from riva.client.asr_async import ASRServiceAsync, AsyncAuth +from riva.client.proto import riva_asr_pb2 as rasr +from riva.client.proto import riva_audio_pb2 as riva_audio + +from .conftest import audio_chunk_generator, get_wav_params, read_wav_audio + +pytestmark = pytest.mark.integration + + +class TestStreamingRecognitionIntegration: + """Integration tests for streaming recognition.""" + + @pytest.mark.asyncio + async def test_streaming_recognize_returns_transcript( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Stream en-US_sample.wav, verify non-empty transcript.""" + wav_params = get_wav_params(en_us_sample) + + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=wav_params["sample_rate"], + language_code="en-US", + max_alternatives=1, + enable_automatic_punctuation=True, + ), + interim_results=False, + ) + + transcripts = [] + async for response in service.streaming_recognize( + audio_chunk_generator(en_us_sample), config + ): + for result in response.results: + if result.is_final and result.alternatives: + transcripts.append(result.alternatives[0].transcript) + + # Should have at least one transcript + assert len(transcripts) > 0 + full_transcript = " ".join(transcripts) + assert len(full_transcript) > 0 + print(f"Transcript: {full_transcript}") + + @pytest.mark.asyncio + async def test_interim_results_received( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """With interim_results=True, partial results are yielded.""" + wav_params = get_wav_params(en_us_sample) + + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=wav_params["sample_rate"], + language_code="en-US", + max_alternatives=1, + ), + interim_results=True, # Enable interim results + ) + + interim_count = 0 + final_count = 0 + + async for response in service.streaming_recognize( + audio_chunk_generator(en_us_sample), config + ): + for result in response.results: + if result.is_final: + final_count += 1 + else: + interim_count += 1 + 
+ # Should have received both interim and final results + assert final_count > 0 + # Note: interim results depend on audio length and server config + print(f"Interim: {interim_count}, Final: {final_count}") + + @pytest.mark.asyncio + async def test_concurrent_streams( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Multiple concurrent streams all succeed.""" + wav_params = get_wav_params(en_us_sample) + num_streams = 3 + + async def run_stream(stream_id: int) -> str: + """Run a single streaming recognition.""" + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=wav_params["sample_rate"], + language_code="en-US", + max_alternatives=1, + ), + interim_results=False, + ) + + transcripts = [] + async for response in service.streaming_recognize( + audio_chunk_generator(en_us_sample), config + ): + for result in response.results: + if result.is_final and result.alternatives: + transcripts.append(result.alternatives[0].transcript) + + return " ".join(transcripts) + + # Run streams concurrently + results = await asyncio.gather(*[ + run_stream(i) for i in range(num_streams) + ]) + + # All streams should succeed with non-empty transcripts + assert len(results) == num_streams + for i, transcript in enumerate(results): + assert len(transcript) > 0, f"Stream {i} returned empty transcript" + print(f"Stream {i}: {transcript[:50]}...") + + +class TestBatchRecognitionIntegration: + """Integration tests for batch recognition. + + Note: These tests require a Riva server with offline/batch recognition support. + The parakeet model only supports streaming, so these tests may be skipped. + """ + + @pytest.mark.asyncio + async def test_batch_recognize_returns_transcript( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Batch recognize en-US_sample.wav. + + This test requires offline recognition support. + Skip if server only supports streaming. 
+ """ + import grpc + + wav_params = get_wav_params(en_us_sample) + audio_data = read_wav_audio(en_us_sample) + + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + + config = rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=wav_params["sample_rate"], + language_code="en-US", + max_alternatives=1, + enable_automatic_punctuation=True, + ) + + try: + response = await service.recognize(audio_data, config) + + # Should have results + assert len(response.results) > 0 + transcript = response.results[0].alternatives[0].transcript + assert len(transcript) > 0 + print(f"Batch transcript: {transcript}") + except grpc.aio.AioRpcError as e: + if "offline" in str(e.details()).lower() or "unavailable model" in str(e.details()).lower(): + pytest.skip("Server does not support offline/batch recognition") + + +class TestConnectionIntegration: + """Integration tests for connection handling.""" + + @pytest.mark.asyncio + async def test_reconnect_after_close( + self, riva_uri: str, en_us_sample: Path + ) -> None: + """Connection reestablishes after close().""" + wav_params = get_wav_params(en_us_sample) + + def make_config(): + return rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=wav_params["sample_rate"], + language_code="en-US", + max_alternatives=1, + ), + interim_results=True, # Enable interim for faster response + ) + + async def run_recognition(): + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + transcripts = [] + async for response in service.streaming_recognize( + audio_chunk_generator(en_us_sample), make_config() + ): + for result in response.results: + if result.is_final and result.alternatives: + transcripts.append(result.alternatives[0].transcript) + return " ".join(transcripts) + + # First recognition + transcript1 = await run_recognition() + print(f"Transcript 1: {transcript1}") + + # Connection is closed after first context manager exits + + # Second recognition with new connection + transcript2 = await run_recognition() + print(f"Transcript 2: {transcript2}") + + # Both should have succeeded + assert len(transcript1) > 0, "First recognition returned no results" + assert len(transcript2) > 0, "Second recognition returned no results" + + @pytest.mark.asyncio + async def test_get_config_returns_available_models( + self, riva_uri: str + ) -> None: + """get_config returns server configuration.""" + async with AsyncAuth(uri=riva_uri) as auth: + service = ASRServiceAsync(auth) + config = await service.get_config() + + # Should have some configuration data + assert config is not None + # The exact fields depend on server configuration + print(f"Server config: {config}") diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index fe0ea3e9..1d788d97 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -1,12 +1,72 @@ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: MIT -from typing import Tuple -from unittest.mock import Mock +from typing import List, Tuple +from unittest.mock import AsyncMock, Mock, MagicMock def set_auth_mock() -> Tuple[Mock, str]: + """Create mock Auth for synchronous testing.""" auth = Mock() return_value_of_get_auth_metadata = 'return_value_of_get_auth_metadata' auth.get_auth_metadata = Mock(return_value=return_value_of_get_auth_metadata) return auth, return_value_of_get_auth_metadata + + +def set_async_auth_mock() -> Tuple[AsyncMock, List[Tuple[str, str]]]: + """Create mock AsyncAuth for async testing. + + Returns: + Tuple of (auth mock, metadata list) + """ + auth = AsyncMock() + metadata = [("x-api-key", "test")] + auth.get_auth_metadata.return_value = metadata + auth.get_channel = AsyncMock(return_value=MagicMock()) + auth._channel = MagicMock() + return auth, metadata + + +def create_mock_streaming_response(transcript: str, is_final: bool) -> MagicMock: + """Create a mock StreamingRecognizeResponse. + + Args: + transcript: The transcript text + is_final: Whether this is a final result + + Returns: + MagicMock configured like StreamingRecognizeResponse + """ + response = MagicMock() + + # Create result structure + result = MagicMock() + result.is_final = is_final + + # Create alternative with transcript + alternative = MagicMock() + alternative.transcript = transcript + result.alternatives = [alternative] + + response.results = [result] + return response + + +def create_mock_recognize_response(transcript: str) -> MagicMock: + """Create a mock RecognizeResponse for batch recognition. + + Args: + transcript: The transcript text + + Returns: + MagicMock configured like RecognizeResponse + """ + response = MagicMock() + + result = MagicMock() + alternative = MagicMock() + alternative.transcript = transcript + result.alternatives = [alternative] + + response.results = [result] + return response diff --git a/tests/unit/test_asr_async.py b/tests/unit/test_asr_async.py new file mode 100644 index 00000000..79d122be --- /dev/null +++ b/tests/unit/test_asr_async.py @@ -0,0 +1,598 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: MIT + +"""Tests for async ASR client. + +Unit tests focus on observable behavior and contracts, not internal implementation. +Integration tests require RIVA_URI environment variable. 
+""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from riva.client.asr_async import ASRServiceAsync, AsyncAuth + +from .helpers import create_mock_streaming_response + + +# Async iterator helper for tests +async def aiter(items): + for item in items: + yield item + + +class TestAsyncAuthChannel: + """Tests for AsyncAuth channel creation behavior.""" + + @pytest.mark.asyncio + async def test_insecure_channel_created(self) -> None: + """Insecure channel created when use_ssl=False.""" + with patch("grpc.aio.insecure_channel") as mock_channel: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + + auth = AsyncAuth(uri="localhost:50051") + channel = await auth.get_channel() + + mock_channel.assert_called_once() + assert channel is not None + await auth.close() + + @pytest.mark.asyncio + async def test_secure_channel_created_with_ssl(self) -> None: + """Secure channel created when use_ssl=True.""" + with patch("grpc.aio.secure_channel") as mock_channel, \ + patch("grpc.ssl_channel_credentials") as mock_creds: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + mock_creds.return_value = MagicMock() + + auth = AsyncAuth(uri="localhost:50051", use_ssl=True) + channel = await auth.get_channel() + + mock_creds.assert_called_once() + mock_channel.assert_called_once() + await auth.close() + + @pytest.mark.asyncio + async def test_close_allows_reconnection(self) -> None: + """After close(), get_channel() creates new channel.""" + with patch("grpc.aio.insecure_channel") as mock_channel: + mock_ch1 = MagicMock() + mock_ch1.close = AsyncMock() + mock_ch2 = MagicMock() + mock_ch2.close = AsyncMock() + mock_channel.side_effect = [mock_ch1, mock_ch2] + + auth = AsyncAuth(uri="localhost:50051") + channel1 = await auth.get_channel() + await auth.close() + + # Should be able to get a new channel after close + channel2 = await auth.get_channel() + + assert mock_channel.call_count == 2 + await auth.close() + + @pytest.mark.asyncio + async def test_context_manager_closes_channel(self) -> None: + """Async context manager properly closes channel.""" + with patch("grpc.aio.insecure_channel") as mock_channel: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + + async with AsyncAuth(uri="localhost:50051") as auth: + await auth.get_channel() + + mock_ch.close.assert_awaited_once() + + @pytest.mark.asyncio + async def test_same_channel_returned_on_multiple_calls(self) -> None: + """Multiple get_channel() calls return the same channel instance.""" + with patch("grpc.aio.insecure_channel") as mock_channel: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + + auth = AsyncAuth(uri="localhost:50051") + channel1 = await auth.get_channel() + channel2 = await auth.get_channel() + channel3 = await auth.get_channel() + + # Behavioral assertion: same channel returned + assert channel1 is channel2 is channel3 + await auth.close() + + +class TestAsyncAuthMetadata: + """Tests for AsyncAuth metadata handling.""" + + @pytest.mark.asyncio + async def test_metadata_preserved(self) -> None: + """Metadata is stored and retrievable.""" + metadata = [("x-api-key", "test-key"), ("x-custom", "value")] + auth = AsyncAuth(uri="localhost:50051", metadata=metadata) + + assert auth.get_auth_metadata() == metadata + await auth.close() + + @pytest.mark.asyncio + async def 
test_empty_metadata_returns_empty_list(self) -> None: + """No metadata returns empty list.""" + auth = AsyncAuth(uri="localhost:50051") + assert auth.get_auth_metadata() == [] + await auth.close() + + +class TestAsyncAuthChannelOptions: + """Tests for channel options behavior.""" + + @pytest.mark.asyncio + async def test_default_options_applied(self) -> None: + """Default channel options are applied.""" + with patch("grpc.aio.insecure_channel") as mock_channel: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + + auth = AsyncAuth(uri="localhost:50051") + await auth.get_channel() + + call_kwargs = mock_channel.call_args + options = call_kwargs[1]["options"] + option_names = [o[0] for o in options] + assert "grpc.keepalive_time_ms" in option_names + await auth.close() + + @pytest.mark.asyncio + async def test_custom_options_override(self) -> None: + """Custom options can be provided.""" + with patch("grpc.aio.insecure_channel") as mock_channel: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + + custom_options = [("grpc.max_send_message_length", 100)] + auth = AsyncAuth(uri="localhost:50051", options=custom_options) + await auth.get_channel() + + call_kwargs = mock_channel.call_args + options = call_kwargs[1]["options"] + assert ("grpc.max_send_message_length", 100) in options + await auth.close() + + +class TestAsyncAuthSSL: + """Tests for SSL credential handling.""" + + @pytest.mark.asyncio + async def test_ssl_credentials_reads_cert_files(self, tmp_path) -> None: + """SSL credential loading from actual files.""" + root_cert = tmp_path / "root.pem" + client_cert = tmp_path / "client.pem" + client_key = tmp_path / "client.key" + + root_cert.write_bytes(b"root-cert-content") + client_cert.write_bytes(b"client-cert-content") + client_key.write_bytes(b"client-key-content") + + with patch("grpc.aio.secure_channel") as mock_channel, \ + patch("grpc.ssl_channel_credentials") as mock_creds: + mock_ch = MagicMock() + mock_ch.close = AsyncMock() + mock_channel.return_value = mock_ch + mock_creds.return_value = MagicMock() + + auth = AsyncAuth( + uri="localhost:50051", + use_ssl=True, + ssl_root_cert=str(root_cert), + ssl_client_cert=str(client_cert), + ssl_client_key=str(client_key), + ) + await auth.get_channel() + + mock_creds.assert_called_once_with( + root_certificates=b"root-cert-content", + private_key=b"client-key-content", + certificate_chain=b"client-cert-content", + ) + await auth.close() + + @pytest.mark.asyncio + async def test_ssl_credentials_file_not_found(self) -> None: + """Error handling for missing cert files.""" + auth = AsyncAuth( + uri="localhost:50051", + use_ssl=True, + ssl_root_cert="/nonexistent/path/root.pem", + ) + + with pytest.raises(FileNotFoundError): + await auth.get_channel() + + +class TestStreamingRecognizeContract: + """Test the public contract of streaming_recognize.""" + + @pytest.fixture + def mock_auth(self) -> AsyncAuth: + """Create mock auth with mocked channel.""" + auth = AsyncAuth(uri="localhost:50051") + auth._channel = MagicMock() + return auth + + @pytest.mark.asyncio + async def test_streaming_recognize_calls_stub_with_generator( + self, mock_auth: AsyncAuth + ) -> None: + """streaming_recognize passes request generator to gRPC stub.""" + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: aiter([]) + mock_stub.StreamingRecognize.return_value 
= mock_call + mock_stub_cls.return_value = mock_stub + + # Import actual proto to create real config + from riva.client.proto import riva_asr_pb2 as rasr + from riva.client.proto import riva_audio_pb2 as riva_audio + + service = ASRServiceAsync(mock_auth) + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + encoding=riva_audio.AudioEncoding.LINEAR_PCM, + sample_rate_hertz=16000, + language_code="en-US", + ) + ) + + async def audio_gen(): + yield b"audio1" + yield b"audio2" + + async for _ in service.streaming_recognize(audio_gen(), config): + pass + + # Verify StreamingRecognize was called with a generator + mock_stub.StreamingRecognize.assert_called_once() + call_args = mock_stub.StreamingRecognize.call_args + # First positional arg should be the request generator (async generator) + assert call_args[0][0] is not None + + @pytest.mark.asyncio + async def test_streaming_recognize_sends_requests(self, mock_auth: AsyncAuth) -> None: + """StreamingRecognize sends config first, then audio chunks.""" + # Import actual proto + from riva.client.proto import riva_asr_pb2 as rasr + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: aiter([]) + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + config = rasr.StreamingRecognitionConfig( + config=rasr.RecognitionConfig( + sample_rate_hertz=16000, + language_code="en-US", + ) + ) + + async def audio_gen(): + yield b"chunk1" + yield b"chunk2" + + async for _ in service.streaming_recognize(audio_gen(), config): + pass + + mock_stub.StreamingRecognize.assert_called_once() + + @pytest.mark.asyncio + async def test_responses_yielded_as_received(self, mock_auth: AsyncAuth) -> None: + """Responses from server are yielded to caller.""" + mock_responses = [ + create_mock_streaming_response("partial", is_final=False), + create_mock_streaming_response("final transcript", is_final=True), + ] + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: aiter(mock_responses) + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + mock_config = MagicMock() + + async def audio_gen(): + yield b"audio" + + responses = [] + async for response in service.streaming_recognize(audio_gen(), mock_config): + responses.append(response) + + assert len(responses) == 2 + assert responses == mock_responses + + @pytest.mark.asyncio + async def test_metadata_passed_to_streaming_call(self, mock_auth: AsyncAuth) -> None: + """Auth metadata is passed to streaming calls.""" + mock_auth.metadata = [("x-api-key", "test-key")] + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: aiter([]) + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + mock_config = MagicMock() + + async def audio_gen(): + yield b"audio" + + async for _ in service.streaming_recognize(audio_gen(), mock_config): + pass + + call_kwargs = mock_stub.StreamingRecognize.call_args + assert call_kwargs[1]["metadata"] == [("x-api-key", "test-key")] + + +class 
TestRecognizeContract: + """Test batch recognition contract.""" + + @pytest.fixture + def mock_auth(self) -> AsyncAuth: + """Create mock auth with mocked channel.""" + auth = AsyncAuth(uri="localhost:50051") + auth._channel = MagicMock() + return auth + + @pytest.mark.asyncio + async def test_recognize_calls_stub_with_request(self, mock_auth: AsyncAuth) -> None: + """recognize() creates RecognizeRequest with config and audio.""" + from riva.client.proto import riva_asr_pb2 as rasr + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_response = MagicMock() + captured_request = None + + async def capture_recognize(request, **kwargs): + nonlocal captured_request + captured_request = request + return mock_response + + mock_stub.Recognize = capture_recognize + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + config = rasr.RecognitionConfig( + sample_rate_hertz=16000, + language_code="en-US", + ) + + await service.recognize(b"audio_data", config) + + # Verify request was created with config and audio + assert captured_request is not None + assert captured_request.audio == b"audio_data" + assert captured_request.config.sample_rate_hertz == 16000 + assert captured_request.config.language_code == "en-US" + + @pytest.mark.asyncio + async def test_metadata_passed_to_recognize_call(self, mock_auth: AsyncAuth) -> None: + """Auth metadata is passed to batch recognition.""" + from riva.client.proto import riva_asr_pb2 as rasr + + mock_auth.metadata = [("x-api-key", "test-key"), ("x-custom", "value")] + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_response = MagicMock() + mock_stub.Recognize = AsyncMock(return_value=mock_response) + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + config = rasr.RecognitionConfig( + sample_rate_hertz=16000, + language_code="en-US", + ) + + await service.recognize(b"audio_data", config) + + call_kwargs = mock_stub.Recognize.call_args + assert call_kwargs[1]["metadata"] == [("x-api-key", "test-key"), ("x-custom", "value")] + + +class TestGetConfigContract: + """Test get_config method contract.""" + + @pytest.fixture + def mock_auth(self) -> AsyncAuth: + """Create mock auth with mocked channel.""" + auth = AsyncAuth(uri="localhost:50051") + auth._channel = MagicMock() + return auth + + @pytest.mark.asyncio + async def test_returns_config_response(self, mock_auth: AsyncAuth) -> None: + """get_config returns the server response.""" + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_response = MagicMock() + mock_response.model_config = ["model1", "model2"] # Add some data + mock_stub.GetRivaSpeechRecognitionConfig = AsyncMock(return_value=mock_response) + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + result = await service.get_config() + + assert result is mock_response + mock_stub.GetRivaSpeechRecognitionConfig.assert_awaited_once() + + @pytest.mark.asyncio + async def test_metadata_passed_to_get_config_call(self, mock_auth: AsyncAuth) -> None: + """Auth metadata is passed to get_config.""" + mock_auth.metadata = [("x-api-key", "test-key")] + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_stub.GetRivaSpeechRecognitionConfig = AsyncMock(return_value=MagicMock()) + 
mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + await service.get_config() + + call_kwargs = mock_stub.GetRivaSpeechRecognitionConfig.call_args + assert call_kwargs[1]["metadata"] == [("x-api-key", "test-key")] + + +class TestStreamingRecognizeEdgeCases: + """Tests for edge cases in streaming recognition.""" + + @pytest.fixture + def mock_auth(self) -> AsyncAuth: + """Create mock auth with mocked channel.""" + auth = AsyncAuth(uri="localhost:50051") + auth._channel = MagicMock() + return auth + + @pytest.mark.asyncio + async def test_empty_audio_generator(self, mock_auth: AsyncAuth) -> None: + """Streaming with no audio chunks still sends config.""" + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_response = MagicMock() + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: aiter([mock_response]) + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + mock_config = MagicMock() + + async def empty_audio_gen(): + return + yield # Make it a generator + + responses = [] + async for response in service.streaming_recognize(empty_audio_gen(), mock_config): + responses.append(response) + + # Should still process (config-only request) + assert len(responses) == 1 + mock_stub.StreamingRecognize.assert_called_once() + + @pytest.mark.asyncio + async def test_single_audio_chunk(self, mock_auth: AsyncAuth) -> None: + """Streaming with exactly one audio chunk.""" + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + mock_response = MagicMock() + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: aiter([mock_response]) + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + mock_config = MagicMock() + + async def single_chunk_gen(): + yield b"single_audio_chunk" + + responses = [] + async for response in service.streaming_recognize(single_chunk_gen(), mock_config): + responses.append(response) + + assert len(responses) == 1 + mock_stub.StreamingRecognize.assert_called_once() + + +class TestErrorHandling: + """Tests for error handling behavior.""" + + @pytest.fixture + def mock_auth(self) -> AsyncAuth: + """Create mock auth with mocked channel.""" + auth = AsyncAuth(uri="localhost:50051") + auth._channel = MagicMock() + return auth + + @pytest.mark.asyncio + async def test_streaming_recognize_handles_grpc_error(self, mock_auth: AsyncAuth) -> None: + """gRPC errors propagate correctly.""" + import grpc + + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + + async def error_iter(): + raise grpc.aio.AioRpcError( + grpc.StatusCode.UNAVAILABLE, + initial_metadata=None, + trailing_metadata=None, + details="Server unavailable", + debug_error_string=None, + ) + yield # Make it a generator + + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: error_iter() + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + mock_config = MagicMock() + + async def audio_gen(): + yield b"audio" + + with pytest.raises(grpc.aio.AioRpcError): + async for _ in service.streaming_recognize(audio_gen(), mock_config): + pass + + @pytest.mark.asyncio + async def test_streaming_recognize_cancellation(self, mock_auth: AsyncAuth) 
-> None: + """Cancellation is handled gracefully.""" + with patch("riva.client.asr_async.rasr_srv.RivaSpeechRecognitionStub") as mock_stub_cls: + mock_stub = MagicMock() + + async def slow_iter(): + yield MagicMock() + await asyncio.sleep(10) # Long delay + yield MagicMock() + + mock_call = MagicMock() + mock_call.__aiter__ = lambda self: slow_iter() + mock_stub.StreamingRecognize.return_value = mock_call + mock_stub_cls.return_value = mock_stub + + service = ASRServiceAsync(mock_auth) + mock_config = MagicMock() + + async def audio_gen(): + yield b"audio" + + async def consume_stream(): + async for _ in service.streaming_recognize(audio_gen(), mock_config): + pass + + task = asyncio.create_task(consume_stream()) + await asyncio.sleep(0.01) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task
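
For reviewers who want to try the new API end to end, below is a minimal usage sketch (not part of the diff). It assumes a Riva server at localhost:50051 and the repo's data/examples/en-US_sample.wav file; adjust both for your environment. Only symbols introduced by this change (AsyncAuth, ASRServiceAsync) and the existing Riva protos are used.

# Usage sketch: stream a WAV file through the new async client.
# Assumptions: server at localhost:50051, 16-bit mono PCM WAV at WAV_PATH.
import asyncio
import wave

from riva.client.asr_async import ASRServiceAsync, AsyncAuth
from riva.client.proto import riva_asr_pb2 as rasr
from riva.client.proto import riva_audio_pb2 as riva_audio

WAV_PATH = "data/examples/en-US_sample.wav"  # sample shipped with the repo


async def main() -> None:
    with wave.open(WAV_PATH, "rb") as wf:
        sample_rate = wf.getframerate()

    config = rasr.StreamingRecognitionConfig(
        config=rasr.RecognitionConfig(
            encoding=riva_audio.AudioEncoding.LINEAR_PCM,
            sample_rate_hertz=sample_rate,
            language_code="en-US",
            max_alternatives=1,
        ),
        interim_results=True,
    )

    async def audio_chunks():
        # Stream the file in 1600-frame chunks (~100 ms at 16 kHz).
        with wave.open(WAV_PATH, "rb") as wf:
            while True:
                data = wf.readframes(1600)
                if not data:
                    break
                yield data

    # AsyncAuth owns the grpc.aio channel; the context manager closes it on exit.
    async with AsyncAuth(uri="localhost:50051") as auth:
        service = ASRServiceAsync(auth)
        async for response in service.streaming_recognize(audio_chunks(), config):
            for result in response.results:
                if result.alternatives:
                    tag = "final" if result.is_final else "partial"
                    print(f"[{tag}] {result.alternatives[0].transcript}")


if __name__ == "__main__":
    asyncio.run(main())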