From 8477d342b3d38edefdeb3837a972e0e3b754f425 Mon Sep 17 00:00:00 2001 From: Lee Penkman Date: Sun, 8 Jun 2025 18:48:25 +1200 Subject: [PATCH] Allow chunked streaming control --- README.md | 1 + apiexamples/streaming_speech.py | 16 + questions/blog_fixtures.py | 5 + .../inference_server/inference_server.py | 631 ++++++++++++------ questions/models.py | 2 + static/openapi.json | 55 ++ static/openapi2-audio.json | 55 ++ .../shared/streaming-speech-api.jinja2 | 38 ++ tests/integ/test_audio_length_limit.py | 28 + tests/integ/test_inference_server_speech.py | 33 +- 10 files changed, 659 insertions(+), 205 deletions(-) create mode 100644 apiexamples/streaming_speech.py create mode 100644 static/templates/shared/streaming-speech-api.jinja2 create mode 100644 tests/integ/test_audio_length_limit.py diff --git a/README.md b/README.md index 73116ab..4a71885 100755 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ There's also more control of text generation via the Text-generator API, this in * max_sentences (generate only a set number of sentences at most) Text generator also has routes for speech to text and speech generation. +You can now stream generated speech in real time via `/api/v1/generate_speech_stream`. See https://text-generator.io/docs diff --git a/apiexamples/streaming_speech.py b/apiexamples/streaming_speech.py new file mode 100644 index 0000000..b960e0d --- /dev/null +++ b/apiexamples/streaming_speech.py @@ -0,0 +1,16 @@ +import requests +import os + +API_KEY = os.getenv("TEXT_GENERATOR_API_KEY") +if API_KEY is None: + raise Exception( + "Please set TEXT_GENERATOR_API_KEY environment variable, login to https://text-generator.io to get your API key") +headers = {"secret": API_KEY} + +params = { + "text": "Hello streaming world", + "speaker": "Male fast" +} +with requests.post("https://api.text-generator.io/api/v1/generate_speech_stream", json=params, headers=headers, stream=True) as r: + for chunk in r.iter_content(chunk_size=None): + print("Received", len(chunk), "bytes") diff --git a/questions/blog_fixtures.py b/questions/blog_fixtures.py index 59475d2..801af0b 100644 --- a/questions/blog_fixtures.py +++ b/questions/blog_fixtures.py @@ -54,4 +54,9 @@ "description": "Control summary length precisely by specifying maximum characters in the summarization API", "keywords": "text summarization, API, natural language processing, text generation, Machine Learning" }, + "streaming-speech-api": { + "title": "Stream Speech Output in Real Time", + "description": "Generate speech in chunks and stream it directly without buffering the whole audio.", + "keywords": "text to speech, streaming, API" + }, } diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py index 1811d1d..9369c93 100644 --- a/questions/inference_server/inference_server.py +++ b/questions/inference_server/inference_server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python import os import random +import asyncio from io import BytesIO from tempfile import NamedTemporaryFile from typing import Union, List, Iterator @@ -16,7 +17,14 @@ setup_logging() logger = logging.getLogger(__name__) -from starlette.responses import JSONResponse, RedirectResponse, Response, StreamingResponse, HTMLResponse +CHUNK_WORDS_DEFAULT = int(os.getenv("STREAM_CHUNK_WORDS", "100")) +from starlette.responses import ( + JSONResponse, + RedirectResponse, + Response, + StreamingResponse, + HTMLResponse, +) from questions.audio_server.audio_dl import request_get from questions.constants import weights_path_tgz @@ -27,10 
+35,17 @@ FeatureExtractParams, OpenaiParams, map_to_generate_params, - map_to_openai_response, AudioParams, GenerateSpeechParams, AudioParamsOrAudioFile, SummarizationParams, + map_to_openai_response, + AudioParams, + GenerateSpeechParams, + AudioParamsOrAudioFile, + SummarizationParams, +) +from questions.payments.payments import ( + get_subscription_item_id_for_user, + create_subscription_for_user, + get_subscription_item_id_for_user_email, ) -from questions.payments.payments import get_subscription_item_id_for_user, \ - create_subscription_for_user, get_subscription_item_id_for_user_email from questions.perplexity import DEVICE from questions.summarization import get_extractive_summary from questions.text_gen_pipeline import TextGenPipeline @@ -42,7 +57,7 @@ from questions.utils import log_time from sellerinfo import session_secret from .models import build_model -from .kokoro import generate +from .kokoro import generate, generate_full import librosa @@ -82,23 +97,90 @@ # async def create_files(files: bytes = File()): # return {"file_sizes": [len(file) for file in files]} -languages = {"af_za": "Afrikaans", "am_et": "Amharic", "ar_eg": "Arabic", "as_in": "Assamese", "az_az": "Azerbaijani", - "be_by": "Belarusian", "bg_bg": "Bulgarian", "bn_in": "Bengali", "bs_ba": "Bosnian", "ca_es": "Catalan", - "cmn_hans_cn": "Chinese", "cs_cz": "Czech", "cy_gb": "Welsh", "da_dk": "Danish", "de_de": "German", - "el_gr": "Greek", "en_us": "English", "es_419": "Spanish", "et_ee": "Estonian", "fa_ir": "Persian", - "fi_fi": "Finnish", "fil_ph": "Tagalog", "fr_fr": "French", "gl_es": "Galician", "gu_in": "Gujarati", - "ha_ng": "Hausa", "he_il": "Hebrew", "hi_in": "Hindi", "hr_hr": "Croatian", "hu_hu": "Hungarian", - "hy_am": "Armenian", "id_id": "Indonesian", "is_is": "Icelandic", "it_it": "Italian", "ja_jp": "Japanese", - "jv_id": "Javanese", "ka_ge": "Georgian", "kk_kz": "Kazakh", "km_kh": "Khmer", "kn_in": "Kannada", - "ko_kr": "Korean", "lb_lu": "Luxembourgish", "ln_cd": "Lingala", "lo_la": "Lao", "lt_lt": "Lithuanian", - "lv_lv": "Latvian", "mi_nz": "Maori", "mk_mk": "Macedonian", "ml_in": "Malayalam", "mn_mn": "Mongolian", - "mr_in": "Marathi", "ms_my": "Malay", "mt_mt": "Maltese", "my_mm": "Myanmar", "nb_no": "Norwegian", - "ne_np": "Nepali", "nl_nl": "Dutch", "oc_fr": "Occitan", "pa_in": "Punjabi", "pl_pl": "Polish", - "ps_af": "Pashto", "pt_br": "Portuguese", "ro_ro": "Romanian", "ru_ru": "Russian", "sd_in": "Sindhi", - "sk_sk": "Slovak", "sl_si": "Slovenian", "sn_zw": "Shona", "so_so": "Somali", "sr_rs": "Serbian", - "sv_se": "Swedish", "sw_ke": "Swahili", "ta_in": "Tamil", "te_in": "Telugu", "tg_tj": "Tajik", - "th_th": "Thai", "tr_tr": "Turkish", "uk_ua": "Ukrainian", "ur_pk": "Urdu", "uz_uz": "Uzbek", - "vi_vn": "Vietnamese", "yo_ng": "Yoruba"} +languages = { + "af_za": "Afrikaans", + "am_et": "Amharic", + "ar_eg": "Arabic", + "as_in": "Assamese", + "az_az": "Azerbaijani", + "be_by": "Belarusian", + "bg_bg": "Bulgarian", + "bn_in": "Bengali", + "bs_ba": "Bosnian", + "ca_es": "Catalan", + "cmn_hans_cn": "Chinese", + "cs_cz": "Czech", + "cy_gb": "Welsh", + "da_dk": "Danish", + "de_de": "German", + "el_gr": "Greek", + "en_us": "English", + "es_419": "Spanish", + "et_ee": "Estonian", + "fa_ir": "Persian", + "fi_fi": "Finnish", + "fil_ph": "Tagalog", + "fr_fr": "French", + "gl_es": "Galician", + "gu_in": "Gujarati", + "ha_ng": "Hausa", + "he_il": "Hebrew", + "hi_in": "Hindi", + "hr_hr": "Croatian", + "hu_hu": "Hungarian", + "hy_am": "Armenian", + "id_id": "Indonesian", + "is_is": "Icelandic", + 
"it_it": "Italian", + "ja_jp": "Japanese", + "jv_id": "Javanese", + "ka_ge": "Georgian", + "kk_kz": "Kazakh", + "km_kh": "Khmer", + "kn_in": "Kannada", + "ko_kr": "Korean", + "lb_lu": "Luxembourgish", + "ln_cd": "Lingala", + "lo_la": "Lao", + "lt_lt": "Lithuanian", + "lv_lv": "Latvian", + "mi_nz": "Maori", + "mk_mk": "Macedonian", + "ml_in": "Malayalam", + "mn_mn": "Mongolian", + "mr_in": "Marathi", + "ms_my": "Malay", + "mt_mt": "Maltese", + "my_mm": "Myanmar", + "nb_no": "Norwegian", + "ne_np": "Nepali", + "nl_nl": "Dutch", + "oc_fr": "Occitan", + "pa_in": "Punjabi", + "pl_pl": "Polish", + "ps_af": "Pashto", + "pt_br": "Portuguese", + "ro_ro": "Romanian", + "ru_ru": "Russian", + "sd_in": "Sindhi", + "sk_sk": "Slovak", + "sl_si": "Slovenian", + "sn_zw": "Shona", + "so_so": "Somali", + "sr_rs": "Serbian", + "sv_se": "Swedish", + "sw_ke": "Swahili", + "ta_in": "Tamil", + "te_in": "Telugu", + "tg_tj": "Tajik", + "th_th": "Thai", + "tr_tr": "Turkish", + "uk_ua": "Ukrainian", + "ur_pk": "Urdu", + "uz_uz": "Uzbek", + "vi_vn": "Vietnamese", + "yo_ng": "Yoruba", +} # expose best of? options = {} # dict(beam_size=5, best_of=5) (makes it slower) @@ -116,7 +198,9 @@ # result = audio_model.transcribe("/media/lee/78ca132e-d181-4406-aea5-3c9665f486cc/Videos/intro-bitbanknz.mp3") 40/6s app.mount("/static", StaticFiles(directory="static"), name="static") -app.mount("/gameon/static", StaticFiles(directory="gameon/static"), name="gameon/static") +app.mount( + "/gameon/static", StaticFiles(directory="gameon/static"), name="gameon/static" +) app.add_middleware( CORSMiddleware, @@ -237,9 +321,7 @@ def track_stripe_request_usage(secret, quantity: int): f"no subscription item id for user: {existing_user.email} {existing_user.stripe_id}" ) # Create a new subscription item - subscription = create_subscription_for_user( - existing_user.stripe_id - ) + subscription = create_subscription_for_user(existing_user.stripe_id) logger.info(f"created subscription: {subscription}") try: subscription_item_id = subscription["items"].data[0]["id"] @@ -263,13 +345,17 @@ def validate_generate_params(generate_params): audio_model = None + + def load_audio_model(): global audio_model # about 10s with log_time("load whisper model"): # audio_model = whisper.load_model("large") # todo specify download_root to a fast ssd if not audio_model: - audio_model = whisper.load_model("medium", download_root="models") # todo specify download_root to a fast ssd + audio_model = whisper.load_model( + "medium", download_root="models" + ) # todo specify download_root to a fast ssd audio_model.eval() audio_model = audio_model.to("cuda") logger.info( @@ -291,26 +377,29 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile): # todo download youtube video ydl_opts = { - 'format': 'bestaudio/best', - 'postprocessors': [{ - 'key': 'FFmpegExtractAudio', - 'preferredcodec': 'wav', - 'preferredquality': '192' - }], - 'postprocessor_args': [ - '-ar', '16000' + "format": "bestaudio/best", + "postprocessors": [ + { + "key": "FFmpegExtractAudio", + "preferredcodec": "wav", + "preferredquality": "192", + } ], - 'prefer_ffmpeg': True, - 'keepvideo': False, + "postprocessor_args": ["-ar", "16000"], + "prefer_ffmpeg": True, + "keepvideo": False, # download to temp file - 'outtmpl': '/tmp/audio.wav' + "outtmpl": "/tmp/audio.wav", # download to memory - } with log_time("download youtube"): with youtube_dl.YoutubeDL(ydl_opts) as ydl: - ydl.download([audio_params.audio_url], ) - audio_bytes = 
ydl.prepare_filename(ydl.extract_info(audio_params.audio_url)) + ydl.download( + [audio_params.audio_url], + ) + audio_bytes = ydl.prepare_filename( + ydl.extract_info(audio_params.audio_url) + ) with open(audio_bytes, "rb") as f: audio_bytes = f.read() else: @@ -319,7 +408,8 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile): response = audio_request.result() if response.status_code != 200: raise HTTPException( - status_code=500, detail=f"Failed to download audio file at {audio_params.audio_url}" + status_code=500, + detail=f"Failed to download audio file at {audio_params.audio_url}", ) response.raw.decode_content = True audio_bytes = response.content @@ -335,9 +425,9 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile): # clean data tmp_file.close() - for segment in result['segments']: - del segment['tokens'] - result['text'] = result['text'].strip() + for segment in result["segments"]: + del segment["tokens"] + result["text"] = result["text"].strip() return result @@ -372,35 +462,43 @@ def write_srt(transcript: Iterator[dict]): @app.post("/api/v1/audio-file-extraction") async def audio_file_extraction( - background_tasks: BackgroundTasks, - request: Request, - response: Response, - # files: List[UploadFile] = File(...), - # files: list[bytes] = File(None, description="Multiple files as bytes"), - audio_file: UploadFile = File(None, description="Audio file"), - translate_to_english: bool = Form(False), - output_filetype: str = Form("txt"), - secret: Union[str, None] = Header(default=None), + background_tasks: BackgroundTasks, + request: Request, + response: Response, + # files: List[UploadFile] = File(...), + # files: list[bytes] = File(None, description="Multiple files as bytes"), + audio_file: UploadFile = File(None, description="Audio file"), + translate_to_english: bool = Form(False), + output_filetype: str = Form("txt"), + secret: Union[str, None] = Header(default=None), ): - audio_params = AudioParamsOrAudioFile(audio_file=audio_file, audio_url=None, - translate_to_english=translate_to_english, output_filetype=output_filetype) - return await audio_extract_shared(background_tasks, audio_params, request, response, secret) + audio_params = AudioParamsOrAudioFile( + audio_file=audio_file, + audio_url=None, + translate_to_english=translate_to_english, + output_filetype=output_filetype, + ) + return await audio_extract_shared( + background_tasks, audio_params, request, response, secret + ) @app.post("/api/v1/audio-extraction") async def audio_extraction( - feature_extract_params: AudioParams, # wav files or mp3 supported - # audio_file: UploadFile, - background_tasks: BackgroundTasks, - request: Request, - response: Response, - secret: Union[str, None] = Header(default=None), + feature_extract_params: AudioParams, # wav files or mp3 supported + # audio_file: UploadFile, + background_tasks: BackgroundTasks, + request: Request, + response: Response, + secret: Union[str, None] = Header(default=None), ): # if not request_authorized(request, secret): # return HTTPException( # status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first" # ) - return await audio_extract_shared(background_tasks, feature_extract_params, request, response, secret) + return await audio_extract_shared( + background_tasks, feature_extract_params, request, response, secret + ) # @app.get("/restart-server") @@ -410,7 +508,10 @@ async def audio_extraction( # os.system("kill -SIGHUP `pgrep gunicorn`") # return "restarting server" -async def 
audio_extract_shared(background_tasks, feature_extract_params, request, response, secret): + +async def audio_extract_shared( + background_tasks, feature_extract_params, request, response, secret +): # if not request_authorized(request, secret): # return HTTPException( # status_code=401, @@ -419,41 +520,49 @@ async def audio_extract_shared(background_tasks, feature_extract_params, request inference_result = fast_audio_extract_inference(feature_extract_params) if "X-Rapid-API-Key" not in request.headers: # todo fix - seconds_taken = inference_result['segments'][-1]['end'] + seconds_taken = inference_result["segments"][-1]["end"] # price of quantity is 1 for .01 price = seconds_taken * 0.00005 - quantity = price // .01 - remainder = price % .01 + quantity = price // 0.01 + remainder = price % 0.01 if random.random() < remainder * 100: quantity += 1 if quantity: - if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: # skip logging for our secret - background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=int(quantity)) + if ( + not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET + ): # skip logging for our secret + background_tasks.add_task( + track_stripe_request_usage, secret=secret, quantity=int(quantity) + ) if feature_extract_params.output_filetype == "srt": # response = StreamingResponse( # non streaming response response.headers["Content-Disposition"] = "attachment; filename=audio.srt" - return write_srt(inference_result['segments']) + return write_srt(inference_result["segments"]) return inference_result @app.post("/api/v1/feature-extraction") async def feature_extraction( - feature_extract_params: FeatureExtractParams, - background_tasks: BackgroundTasks, - request: Request, - secret: Union[str, None] = Header(default=None), + feature_extract_params: FeatureExtractParams, + background_tasks: BackgroundTasks, + request: Request, + secret: Union[str, None] = Header(default=None), ): # global daemon # slow warmup on new servers # model = MODEL_CACHE.add_or_get("text_model", load_pipelines_and_model) # daemon.join() - inference_result = fast_feature_extract_inference(feature_extract_params, MODEL_CACHE) + inference_result = fast_feature_extract_inference( + feature_extract_params, MODEL_CACHE + ) if "X-Rapid-API-Key" not in request.headers: # todo fix if random.randint(1, 10) == 10: if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: - background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1) + background_tasks.add_task( + track_stripe_request_usage, secret=secret, quantity=1 + ) return inference_result[: feature_extract_params.num_features] @@ -473,17 +582,19 @@ async def feature_extraction( # todo fix if random.randint(1, 10) == 10: if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: - background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1) - return JSONResponse({ - "generated_text": text - }) + background_tasks.add_task( + track_stripe_request_usage, secret=secret, quantity=1 + ) + return JSONResponse({"generated_text": text}) @app.get("/liveness_check") async def liveness_check(request: Request): # global daemon inference_result = fast_inference( - generate_params=GenerateParams(text="hi my friend", min_probability=0.9, max_length=1, model='any'), + generate_params=GenerateParams( + text="hi my friend", min_probability=0.9, max_length=1, model="any" + ), model_cache=MODEL_CACHE, ) return JSONResponse(inference_result) @@ -522,6 +633,8 @@ def request_authorized(request: Request, 
secret): speech_processor = None speechgen_model = None speech_vocoder = None + + def load_speechgen_model(): global speech_processor global speechgen_model @@ -529,18 +642,30 @@ def load_speechgen_model(): if not speechgen_model: # Load Kokoro model - device = 'cuda' if torch.cuda.is_available() else 'cpu' - speechgen_model = build_model('models/kokoro-v0_19.pth', device) - + device = "cuda" if torch.cuda.is_available() else "cpu" + speechgen_model = build_model("models/kokoro-v0_19.pth", device) + # Load voice packs voicepacks = {} - voice_names = ['af', 'af_bella', 'af_sarah', 'am_adam', 'am_michael', - 'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis', - 'af_nicole', 'af_sky'] - + voice_names = [ + "af", + "af_bella", + "af_sarah", + "am_adam", + "am_michael", + "bf_emma", + "bf_isabella", + "bm_george", + "bm_lewis", + "af_nicole", + "af_sky", + ] + for voice in voice_names: - voicepacks[voice] = torch.load(f'models/voices/{voice}.pt', weights_only=True).to(device) - + voicepacks[voice] = torch.load( + f"models/voices/{voice}.pt", weights_only=True + ).to(device) + return speechgen_model, voicepacks @@ -557,23 +682,23 @@ def write_wav(processed_np_speech, rate): # todo fix to use io.BytesIO bytes = BytesIO() bytes.name = "audio.wav" - sf.write(bytes, processed_np_speech, rate, subtype='PCM_24') + sf.write(bytes, processed_np_speech, rate, subtype="PCM_24") # bytesio to bytes return bytes.getvalue() @app.post("/api/v1/generate_speech") async def generate_speech( - generate_speech_params: GenerateSpeechParams, - background_tasks: BackgroundTasks, - request: Request, - response: Response, - secret: Union[str, None] = Header(default=None), + generate_speech_params: GenerateSpeechParams, + background_tasks: BackgroundTasks, + request: Request, + response: Response, + secret: Union[str, None] = Header(default=None), ): if not request_authorized(request, secret): return HTTPException( status_code=401, - detail="Invalid Secret, please use the secret found in /account also subscribe at https://text-generator.io/subscribe first, also make sure there is an up to date credit card saved in your account" + detail="Invalid Secret, please use the secret found in /account also subscribe at https://text-generator.io/subscribe first, also make sure there is an up to date credit card saved in your account", ) text = generate_speech_params.text @@ -586,7 +711,9 @@ async def generate_speech( if "X-Rapid-API-Key" not in request.headers: # todo fix if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: - background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1) + background_tasks.add_task( + track_stripe_request_usage, secret=secret, quantity=1 + ) # write np array to wav wav = write_wav(processed_np_speech, rate) file = Response(wav, media_type="audio/wav") @@ -594,19 +721,48 @@ async def generate_speech( # return Response(wav, media_type="audio/wav") +@app.post("/api/v1/generate_speech_stream") +async def generate_speech_stream( + generate_speech_params: GenerateSpeechParams, + background_tasks: BackgroundTasks, + request: Request, + secret: Union[str, None] = Header(default=None), +): + if not request_authorized(request, secret): + return HTTPException( + status_code=401, + detail="Invalid Secret, please use the secret found in /account also subscribe at https://text-generator.io/subscribe first, also make sure there is an up to date credit card saved in your account", + ) + + if "X-Rapid-API-Key" not in request.headers: + if not API_KEY and secret != 
sellerinfo.TEXT_GENERATOR_SECRET: + background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1) + + generator = stream_audio_chunks( + generate_speech_params.text, + generate_speech_params.voice, + generate_speech_params.speed, + generate_speech_params.chunk_words, + ) + response = StreamingResponse(generator, media_type="audio/wav") + response.headers["Content-Disposition"] = "attachment; filename=audio.wav" + return response + + def gradio_audio_process(text, voice, speed=1.0): """Simplified function that only takes the required parameters""" if len(text.strip()) == 0: return (24000, np.zeros(0).astype(np.int16)) model, voicepacks = MODEL_CACHE.add_or_get("speech_model", load_speechgen_model) - + # Get the voicepack - voicepack = voicepacks.get(voice, voicepacks['af_nicole']) - + voicepack = voicepacks.get(voice, voicepacks["af_nicole"]) + # Generate audio using Kokoro - audio, phonemes = generate(model, text, voicepack, lang=voice[0], speed=speed) - + # generate_full handles inputs longer than the default token limit + audio, phonemes = generate_full(model, text, voicepack, lang=voice[0], speed=speed) + return (24000, audio) @@ -617,9 +773,7 @@ def gradio_audio_process(text, voice, speed=1.0): "Male slower": "RMS (male)", "Female 2": "SLT (female)", } -speaker_embeddings_loaded = { - -} +speaker_embeddings_loaded = {} def load_speaker_embedding(speaker): @@ -635,16 +789,17 @@ def load_speaker_embedding(speaker): def audio_process(text, voice="af_nicole", speed=1.0): model, voicepacks = MODEL_CACHE.add_or_get("speech_model", load_speechgen_model) - + if len(text.strip()) == 0: return (24000, np.zeros(0).astype(np.int16)) # Get the voicepack - voicepack = voicepacks.get(voice, voicepacks['af_nicole']) - + voicepack = voicepacks.get(voice, voicepacks["af_nicole"]) + # Generate audio using Kokoro - audio, phonemes = generate(model, text, voicepack, lang=voice[0], speed=speed) - + # generate_full handles inputs longer than the default token limit + audio, phonemes = generate_full(model, text, voicepack, lang=voice[0], speed=speed) + # we could do this but use speed instead # Convert to float32 for time-stretch # audio_float = audio.astype(np.float32) / 32767.0 @@ -658,7 +813,48 @@ def audio_process(text, voice="af_nicole", speed=1.0): return (24000, audio) +def synthesize_full_text( + text: str, voice: str = "af_nicole", speed: float = 1.0, chunk_words: int = CHUNK_WORDS_DEFAULT +): + """Generate speech for arbitrarily long text by chunking.""" + if not text.strip(): + return (24000, np.zeros(0, dtype=np.int16)) + + words = text.split() + segments = [] + for start in range(0, len(words), chunk_words): + chunk = " ".join(words[start : start + chunk_words]) + _, audio = audio_process(chunk, voice=voice, speed=speed) + segments.append(audio) + + if segments: + full_audio = np.concatenate(segments) + else: + full_audio = np.zeros(0, dtype=np.int16) + + return (24000, full_audio) + +def stream_audio_chunks( + text: str, voice: str = "af_nicole", speed: float = 1.0, chunk_words: int = CHUNK_WORDS_DEFAULT +): + """Yield audio wav bytes for each chunk of text.""" + words = text.split() + for start in range(0, len(words), chunk_words): + chunk = " ".join(words[start : start + chunk_words]) + _, audio = audio_process(chunk, voice=voice, speed=speed) + yield write_wav(audio, 24000) + + +async def stream_audio_chunks_async( + text: str, voice: str = "af_nicole", speed: float = 1.0, chunk_words: int = CHUNK_WORDS_DEFAULT +): + """Asynchronously yield audio wav bytes for each chunk of 
text.""" + words = text.split() + for start in range(0, len(words), chunk_words): + chunk = " ".join(words[start : start + chunk_words]) + _, audio = await asyncio.to_thread(audio_process, chunk, voice, speed) + yield write_wav(audio, 24000) # gradio web app at https://text-generator.io/gradio_tts @@ -666,16 +862,24 @@ def audio_process(text, voice="af_nicole", speed=1.0): examples = [ ["It is not in the stars to hold our destiny but in ourselves.", "Male default"], ["Oliver went to the opera in October.", "Female 2"], - ["She sells seashells by the seashore. I saw a kitten eating chicken in the kitchen.", "Male slower"], - ["Brisk brave brigadiers brandished broad bright blades, blunderbusses, and bludgeons—balancing them badly.", - "Female 1"], + [ + "She sells seashells by the seashore. I saw a kitten eating chicken in the kitchen.", + "Male slower", + ], + [ + "Brisk brave brigadiers brandished broad bright blades, blunderbusses, and bludgeons—balancing them badly.", + "Female 1", + ], ["A synonym for cinnamon is a cinnamon synonym.", "Male fast"], [ "How much wood would a woodchuck chuck if a woodchuck could chuck wood? He would chuck, he would, as much as he could, and chuck as much wood as a woodchuck would if a woodchuck could chuck wood.", - "Female 2"], + "Female 2", + ], ] title = "Text To Speech" -description = "How to use: Enter some English text and choose a speaker. Click Submit " +description = ( + "How to use: Enter some English text and choose a speaker. Click Submit " +) article = """

Check out the API docs at /docs

""" @@ -828,22 +1032,26 @@ def audio_process(text, voice="af_nicole", speed=1.0): # async def config_route(): # return RedirectResponse("/gradio_tts/info") + @app.get("/setcookie") -async def setcookie(secret:str, request: Request, response: Response): - response.set_cookie('secret', value=secret)#, httponly=True) +async def setcookie(secret: str, request: Request, response: Response): + response.set_cookie("secret", value=secret) # , httponly=True) return {"message": "Cookie set"} + @app.get("/gradio_frame") -async def gradio_frame_route(secret:str, request: Request): - return templates.TemplateResponse("templates/gradio-frame.jinja2", {"request": request, "secret": secret}) +async def gradio_frame_route(secret: str, request: Request): + return templates.TemplateResponse( + "templates/gradio-frame.jinja2", {"request": request, "secret": secret} + ) @app.post("/api/v1/generate") async def generate_route( - generate_params: GenerateParams, - background_tasks: BackgroundTasks = None, - request: Request = None, - secret: Union[str, None] = Header(default=None), + generate_params: GenerateParams, + background_tasks: BackgroundTasks = None, + request: Request = None, + secret: Union[str, None] = Header(default=None), ): validation_result = validate_generate_params(generate_params) if validation_result: @@ -863,50 +1071,59 @@ async def generate_route( inference_result = fast_inference(generate_params, MODEL_CACHE) # todo vuln if request and background_tasks: - if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: + if ( + "X-Rapid-API-Key" not in request.headers + and "x-rapid-api-key" not in request.headers + ): if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: - background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1) + background_tasks.add_task( + track_stripe_request_usage, secret=secret, quantity=1 + ) return inference_result @app.post("/api/discord") async def discord_route( - type: str = Form(...), - id: str = Form(...), - data: str = Form(...), - - # text: str, - # background_tasks: BackgroundTasks = None, - # request: Request = None, - # secret: Union[str, None] = Header(default=None), - # response: Response = None, + type: str = Form(...), + id: str = Form(...), + data: str = Form(...), + # text: str, + # background_tasks: BackgroundTasks = None, + # request: Request = None, + # secret: Union[str, None] = Header(default=None), + # response: Response = None, ): - generate_params = GenerateParams(**{ - "text": "in 2022 the stock market has been expected to reach a record high.", - "number_of_results": 1, - "max_length": 100, - "max_sentences": 1, - "min_probability": 0, - "stop_sequences": [], - "top_p": 0.9, - "top_k": 40, - "temperature": 0.7, - "repetition_penalty": 1.17, - "seed": 0 - }) + generate_params = GenerateParams( + **{ + "text": "in 2022 the stock market has been expected to reach a record high.", + "number_of_results": 1, + "max_length": 100, + "max_sentences": 1, + "min_probability": 0, + "stop_sequences": [], + "top_p": 0.9, + "top_k": 40, + "temperature": 0.7, + "repetition_penalty": 1.17, + "seed": 0, + } + ) + @app.post("/api/v1/generate-bulk") async def generate_route_bulk( - bulk_params: List[GenerateParams], - background_tasks: BackgroundTasks = None, - request: Request = None, - secret: Union[str, None] = Header(default=None), + bulk_params: List[GenerateParams], + background_tasks: BackgroundTasks = None, + request: Request = None, + secret: Union[str, None] = Header(default=None), ): # 
global daemon # print(model.config.max_length) # print(tokenizer.model_max_length) # model.config.max_length = tokenizer.model_max_length - model = MODEL_CACHE.add_or_get("text_model", lambda: load_pipelines_and_model(weights_path_tgz)) + model = MODEL_CACHE.add_or_get( + "text_model", lambda: load_pipelines_and_model(weights_path_tgz) + ) # daemon.join() inference_results = [] for generate_params in bulk_params: @@ -965,11 +1182,11 @@ async def generate_route_bulk( @app.post("/v1/engines/{engine_name}/completions") async def openai_route_named( - engine_name: str, - openai_params: OpenaiParams, - background_tasks: BackgroundTasks = None, - request: Request = None, - secret: Union[str, None] = Header(default=None), + engine_name: str, + openai_params: OpenaiParams, + background_tasks: BackgroundTasks = None, + request: Request = None, + secret: Union[str, None] = Header(default=None), ): # global daemon # print(model.config.max_length) @@ -982,7 +1199,10 @@ async def openai_route_named( if validation_result: # return a 400 bad request from fast api return HTTPException(status_code=400, detail=validation_result) - if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: + if ( + "X-Rapid-API-Key" not in request.headers + and "x-rapid-api-key" not in request.headers + ): header_auth = request.headers.get("Authorization", " ") authorization_split = header_auth.split(" ") if len(authorization_split) == 2: @@ -991,32 +1211,39 @@ async def openai_route_named( if not request_authorized(request, secret): return HTTPException( status_code=401, - detail="Please subscribe at https://text-generator.io/subscribe first, also ensure you have a credit card on file" + detail="Please subscribe at https://text-generator.io/subscribe first, also ensure you have a credit card on file", ) inference_result = fast_inference(generate_params, MODEL_CACHE) if not openai_params.echo: ## remove all the inputs from the generated texts for i in range(len(inference_result)): - inference_result[i]["generated_text"] = inference_result[i]["generated_text"][ - len(openai_params.prompt): - ] + inference_result[i]["generated_text"] = inference_result[i][ + "generated_text" + ][len(openai_params.prompt) :] # todo vuln if request and background_tasks: - if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: + if ( + "X-Rapid-API-Key" not in request.headers + and "x-rapid-api-key" not in request.headers + ): if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: - background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1) + background_tasks.add_task( + track_stripe_request_usage, secret=secret, quantity=1 + ) # map to openai response return map_to_openai_response(inference_result, generate_params) @app.post("/v1/completions") async def openai_route( - openai_params: OpenaiParams, - background_tasks: BackgroundTasks = None, - request: Request = None, - secret: Union[str, None] = Header(default=None), + openai_params: OpenaiParams, + background_tasks: BackgroundTasks = None, + request: Request = None, + secret: Union[str, None] = Header(default=None), ): - return await openai_route_named("engine", openai_params, background_tasks, request, secret) + return await openai_route_named( + "engine", openai_params, background_tasks, request, secret + ) # redirect / to /docs @@ -1032,14 +1259,17 @@ async def root(): # return RedirectResponse(url="/gradio_tts") -logger.info(""" +logger.info( + """ Welcome to the Text Generator API! 
Thanks for using our service. We hope you enjoy it and it means a lot to us that you are here. Let me know how it goes and if you have any questions or feedback, please reach out at lee.penkman@wordsmashing.com Downloading models may take a long time on the first run. -""") +""" +) + @app.get("/text-to-speech-demo") def tts_demo(request: Request): @@ -1051,6 +1281,7 @@ def tts_demo(request: Request): }, ) + # @app.post("/api/v1/generate-long") # async def generate_long_text( # generate_params: GenerateParams, @@ -1065,29 +1296,29 @@ def tts_demo(request: Request): # validation_result = validate_generate_params(generate_params) # if validation_result: # return HTTPException(status_code=400, detail=validation_result) -# +# # # Authorize the request # if request and "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: # if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: # if not request_authorized(request, secret): # return HTTPException( -# status_code=401, +# status_code=401, # detail="Please subscribe at https://text-generator.io/subscribe first" # ) -# +# # try: # # Prepare the prompt for Claude # prompt = generate_params.text -# +# # # Set up system message to control generation parameters # system_message = f""" # You are a creative text generation assistant. Generate text that continues from the given prompt. -# +# # Parameters to follow: # - Temperature: {generate_params.temperature} # - Creativity level: {"high" if generate_params.temperature > 0.7 else "medium" if generate_params.temperature > 0.3 else "low"} # - Maximum length: {generate_params.max_length} tokens -# +# # Important instructions: # - Continue the text naturally from where the prompt ends # - Do not repeat the prompt in your response @@ -1095,43 +1326,43 @@ def tts_demo(request: Request): # - Do not use phrases like "Here's a continuation" or "Continuing from the prompt" # - Just generate the continuation text directly # """ -# +# # # Set up stop sequences # stop_sequences = None # if generate_params.stop_sequences and len(generate_params.stop_sequences) > 0: # stop_sequences = frozenset(generate_params.stop_sequences) -# +# # # Call Claude API # generated_text = await query_to_claude_async( # prompt=prompt, # stop_sequences=stop_sequences, # system_message=system_message, # ) -# +# # # Handle the response # if generated_text is None: # return HTTPException(status_code=500, detail="Failed to generate text with Claude") -# +# # # Format the response to match the standard API format # result = [{ # "generated_text": prompt + generated_text, # "finished_reason": "length", # "model": "claude-3-sonnet-20240229" # }] -# +# # # Track usage if needed # if request and background_tasks: # if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: # if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: # # Claude is more expensive, so we charge 3 units # background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=3) -# +# # return result -# +# # except Exception as e: # logger.error(f"Error generating text with Claude: {e}") # return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") -# +# # @app.post("/api/v1/generate-large") # async def generate_large_text( # generate_params: GenerateParams, @@ -1146,30 +1377,30 @@ def tts_demo(request: Request): # validation_result = validate_generate_params(generate_params) # if validation_result: # return HTTPException(status_code=400, detail=validation_result) -# 
+# # # Authorize the request # if request and "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: # if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: # if not request_authorized(request, secret): # return HTTPException( -# status_code=401, +# status_code=401, # detail="Please subscribe at https://text-generator.io/subscribe first" # ) -# +# # try: # # Prepare the prompt for Claude # prompt = generate_params.text # model_name = "claude-3-7-sonnet-20250219" -# +# # # Set up system message to control generation parameters # system_message = f""" # You are a creative text generation assistant. Generate text that continues from the given prompt. -# +# # Parameters to follow: # - Temperature: {generate_params.temperature} # - Creativity level: {"high" if generate_params.temperature > 0.7 else "medium" if generate_params.temperature > 0.3 else "low"} # - Maximum length: {generate_params.max_length} tokens -# +# # Important instructions: # - Continue the text naturally from where the prompt ends # - Do not repeat the prompt in your response @@ -1177,12 +1408,12 @@ def tts_demo(request: Request): # - Do not use phrases like "Here's a continuation" or "Continuing from the prompt" # - Just generate the continuation text directly # """ -# +# # # Set up stop sequences # stop_sequences = None # if generate_params.stop_sequences and len(generate_params.stop_sequences) > 0: # stop_sequences = frozenset(generate_params.stop_sequences) -# +# # # Call Claude API with the specified model # generated_text = await query_to_claude_async( # prompt=prompt, @@ -1190,18 +1421,18 @@ def tts_demo(request: Request): # system_message=system_message, # model=model_name, # Pass the model name to the Claude API function # ) -# +# # # Handle the response # if generated_text is None: # return HTTPException(status_code=500, detail="Failed to generate text with Claude") -# +# # # Format the response to match the standard API format # result = [{ # "generated_text": prompt + generated_text, # "finished_reason": "length", # "model": model_name # }] -# +# # # Track usage if needed # if request and background_tasks: # if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers: @@ -1209,9 +1440,9 @@ def tts_demo(request: Request): # # Claude is more expensive, so we charge more units based on the model # quantity = 5 if "opus" in model_name else 3 if "sonnet" in model_name else 2 # background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=quantity) -# +# # return result -# +# # except Exception as e: # logger.error(f"Error generating text with Claude: {e}") # return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") diff --git a/questions/models.py b/questions/models.py index 2268c6b..b171572 100755 --- a/questions/models.py +++ b/questions/models.py @@ -1,5 +1,6 @@ from time import time from typing import Optional, List +import os from fastapi import UploadFile from pydantic import BaseModel @@ -73,6 +74,7 @@ class GenerateSpeechParams(BaseModel): speed: float = 1.0 volume: float = 1.0 sample_rate: int = 24000 # Kokoro outputs 24kHz audio + chunk_words: int = int(os.getenv("STREAM_CHUNK_WORDS", "100")) class OpenaiParams(BaseModel): diff --git a/static/openapi.json b/static/openapi.json index a2a4735..271db17 100644 --- a/static/openapi.json +++ b/static/openapi.json @@ -150,6 +150,56 @@ } } }, + "/api/v1/generate_speech_stream": { + "post": { + "summary": "Stream Speech", + "operationId": 
"generate_speech_stream_api_v1_generate_speech_stream_post", + "parameters": [ + { + "required": false, + "schema": { + "title": "Secret", + "type": "string" + }, + "name": "secret", + "in": "header" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateSpeechParams" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/octet-stream": { + "schema": { + "type": "string", + "format": "binary" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/api/v1/audio-file-extraction": { "post": { "summary": "Audio File Extraction", @@ -574,6 +624,11 @@ "title": "Speaker", "type": "string", "default": "Female 1" + }, + "chunk_words": { + "title": "Chunk Words", + "type": "integer", + "default": 100 } } }, diff --git a/static/openapi2-audio.json b/static/openapi2-audio.json index e762a87..c3894a5 100644 --- a/static/openapi2-audio.json +++ b/static/openapi2-audio.json @@ -209,6 +209,56 @@ } } }, + "/api/v1/generate_speech_stream": { + "post": { + "summary": "Stream Speech", + "operationId": "generate_speech_stream_api_v1_generate_speech_stream_post", + "parameters": [ + { + "required": false, + "schema": { + "title": "Secret", + "type": "string" + }, + "name": "secret", + "in": "header" + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/GenerateSpeechParams" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/octet-stream": { + "schema": { + "type": "string", + "format": "binary" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/api/v1/generate": { "post": { "summary": "Generate Route", @@ -572,6 +622,11 @@ "title": "Speaker", "type": "string", "default": "Female 1" + }, + "chunk_words": { + "title": "Chunk Words", + "type": "integer", + "default": 100 } } }, diff --git a/static/templates/shared/streaming-speech-api.jinja2 b/static/templates/shared/streaming-speech-api.jinja2 new file mode 100644 index 0000000..fa82b20 --- /dev/null +++ b/static/templates/shared/streaming-speech-api.jinja2 @@ -0,0 +1,38 @@ + + + +
+Text Generator > Blog > Streaming Speech API
+
+Stream text to speech output
+
+No more waiting for long texts
+
+Our new endpoint /api/v1/generate_speech_stream lets you begin playback as soon as audio is generated. The backend automatically chunks long text so memory use stays low.
+
+Example usage
+
+import requests
+import os
+
+API_KEY = os.getenv("TEXT_GENERATOR_API_KEY")
+headers = {"secret": API_KEY}
+
+params = {"text": "Hello streaming world", "speaker": "Male fast"}
+with requests.post("https://api.text-generator.io/api/v1/generate_speech_stream", json=params, headers=headers, stream=True) as r:
+    for chunk in r.iter_content(chunk_size=None):
+        # Play or save chunk
+        print("Received", len(chunk), "bytes")
+
+
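+The request body also accepts a chunk_words field, which is how the endpoint decides how many words to synthesise per chunk; it defaults to 100, and the server-side default can be changed with the STREAM_CHUNK_WORDS environment variable. Smaller values should get the first audio back to you sooner, at the cost of more synthesis calls. For example, with the same request shape as above:
+
+params = {
+    "text": "A long article you want to start hearing right away ...",
+    "speaker": "Male fast",
+    "chunk_words": 25,  # synthesise 25 words per chunk so the first chunk arrives quickly
+}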

+Chunks are standard WAV fragments with a 24 kHz sample rate.
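+
+Because each fragment is a self-contained WAV clip, a client can decode the fragments individually and join the samples back into a single file. Below is a minimal sketch of that idea; it assumes the numpy and soundfile packages are available on the client, and that each iterated chunk maps to exactly one fragment (true when talking to the endpoint directly, though a proxy in between could re-chunk the stream):
+
+import io
+import os
+
+import numpy as np
+import requests
+import soundfile as sf
+
+headers = {"secret": os.getenv("TEXT_GENERATOR_API_KEY")}
+params = {"text": "Hello streaming world", "speaker": "Male fast"}
+
+segments = []
+with requests.post("https://api.text-generator.io/api/v1/generate_speech_stream", json=params, headers=headers, stream=True) as r:
+    for chunk in r.iter_content(chunk_size=None):
+        # Assumes each iterated chunk is one complete WAV fragment;
+        # decode it into a float sample array
+        samples, _ = sf.read(io.BytesIO(chunk))
+        segments.append(samples)
+
+# Join the fragments and write a single 24 kHz WAV file
+sf.write("speech.wav", np.concatenate(segments), 24000)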

+

+See the API docs for more details.

+ +
+
+ {% import "templates/macros.jinja2" as macros with context %} + {{ macros.svgstyled() }} +
diff --git a/tests/integ/test_audio_length_limit.py b/tests/integ/test_audio_length_limit.py new file mode 100644 index 0000000..9d600f3 --- /dev/null +++ b/tests/integ/test_audio_length_limit.py @@ -0,0 +1,28 @@ +import os +import pytest + +from starlette.testclient import TestClient + +from questions.inference_server.inference_server import app, audio_process + +client = TestClient(app) + +API_KEY = os.environ.get("TEXT_GENERATOR_API_KEY", "test") +headers = {"secret": API_KEY} + + +@pytest.mark.skipif(os.getenv("CI") == "true", reason="skip heavy audio test in CI") +def test_audio_word_limit(): + text = "hello" + prev_len = 0 + max_words = None + for i in range(1, 30): + rate, audio = audio_process(text, "af_nicole") + assert rate == 24000 + cur_len = len(audio) + if cur_len <= prev_len: + max_words = len(text.split()) - 1 + break + prev_len = cur_len + text += " hello" + assert max_words is not None diff --git a/tests/integ/test_inference_server_speech.py b/tests/integ/test_inference_server_speech.py index 08c1298..3e8ba55 100644 --- a/tests/integ/test_inference_server_speech.py +++ b/tests/integ/test_inference_server_speech.py @@ -1,4 +1,5 @@ import os +import pytest from questions.utils import log_time @@ -19,17 +20,22 @@ def test_speech_creation(): with log_time("speech creation"): - result = audio_process("It is not in the stars to hold our destiny but in ourselves." * 10, "Male fast") + result = audio_process( + "It is not in the stars to hold our destiny but in ourselves." * 10, + "Male fast", + ) assert result is not None assert result[0] is not None -def test_speech_creation_route(): +def test_generate_speech_route_single_voice(): audio_params = GenerateSpeechParams( text="Text-Generator.io is bringing the cost of intelligence toward zero.", speaker="Male fast", ) - response = client.post("/api/v1/generate_speech", json=audio_params.__dict__, headers=headers) + response = client.post( + "/api/v1/generate_speech", json=audio_params.__dict__, headers=headers + ) assert response.status_code == 200, response.text binary_file_response = response.content assert binary_file_response is not None @@ -37,6 +43,21 @@ def test_speech_creation_route(): f.write(binary_file_response) +@pytest.mark.skipif(os.getenv("CI") == "true", reason="skip heavy audio test in CI") +def test_generate_speech_stream_route(): + audio_params = GenerateSpeechParams( + text="Streaming audio makes responses faster.", + speaker="Male fast", + chunk_words=5, + ) + response = client.post( + "/api/v1/generate_speech_stream", json=audio_params.__dict__, headers=headers, stream=True + ) + assert response.status_code == 200, response.text + data = b"".join(response.iter_content(chunk_size=None)) + assert len(data) > 0 + + speakers = [ "Male fast", "Female 1", @@ -46,13 +67,15 @@ def test_speech_creation_route(): ] -def test_speech_creation_route(): +def test_generate_speech_route_all_voices(): for speaker in speakers: audio_params = GenerateSpeechParams( text="Text-Generator.io is bringing the cost of intelligence toward zero.", speaker=speaker, ) - response = client.post("/api/v1/generate_speech", json=audio_params.__dict__, headers=headers) + response = client.post( + "/api/v1/generate_speech", json=audio_params.__dict__, headers=headers + ) assert response.status_code == 200, response.text binary_file_response = response.content assert binary_file_response is not None