diff --git a/README.md b/README.md
index 73116ab..4a71885 100755
--- a/README.md
+++ b/README.md
@@ -42,6 +42,7 @@ There's also more control of text generation via the Text-generator API, this in
* max_sentences (generate only a set number of sentences at most)
Text generator also has routes for speech to text and speech generation.
+You can now stream generated speech in real time via `/api/v1/generate_speech_stream`.
See https://text-generator.io/docs
diff --git a/apiexamples/streaming_speech.py b/apiexamples/streaming_speech.py
new file mode 100644
index 0000000..b960e0d
--- /dev/null
+++ b/apiexamples/streaming_speech.py
@@ -0,0 +1,16 @@
+import os
+import requests
+
+API_KEY = os.getenv("TEXT_GENERATOR_API_KEY")
+if API_KEY is None:
+ raise Exception(
+ "Please set TEXT_GENERATOR_API_KEY environment variable, login to https://text-generator.io to get your API key")
+headers = {"secret": API_KEY}
+
+params = {
+ "text": "Hello streaming world",
+    "voice": "af_nicole"
+}
+with requests.post("https://api.text-generator.io/api/v1/generate_speech_stream", json=params, headers=headers, stream=True) as r:
+ for chunk in r.iter_content(chunk_size=None):
+ print("Received", len(chunk), "bytes")
diff --git a/questions/blog_fixtures.py b/questions/blog_fixtures.py
index 59475d2..801af0b 100644
--- a/questions/blog_fixtures.py
+++ b/questions/blog_fixtures.py
@@ -54,4 +54,9 @@
"description": "Control summary length precisely by specifying maximum characters in the summarization API",
"keywords": "text summarization, API, natural language processing, text generation, Machine Learning"
},
+ "streaming-speech-api": {
+ "title": "Stream Speech Output in Real Time",
+ "description": "Generate speech in chunks and stream it directly without buffering the whole audio.",
+ "keywords": "text to speech, streaming, API"
+ },
}
diff --git a/questions/inference_server/inference_server.py b/questions/inference_server/inference_server.py
index 1811d1d..9369c93 100644
--- a/questions/inference_server/inference_server.py
+++ b/questions/inference_server/inference_server.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import os
import random
+import asyncio
from io import BytesIO
from tempfile import NamedTemporaryFile
from typing import Union, List, Iterator
@@ -16,7 +17,14 @@
setup_logging()
logger = logging.getLogger(__name__)
-from starlette.responses import JSONResponse, RedirectResponse, Response, StreamingResponse, HTMLResponse
+from starlette.responses import (
+    JSONResponse,
+    RedirectResponse,
+    Response,
+    StreamingResponse,
+    HTMLResponse,
+)
+CHUNK_WORDS_DEFAULT = int(os.getenv("STREAM_CHUNK_WORDS", "100"))
from questions.audio_server.audio_dl import request_get
from questions.constants import weights_path_tgz
@@ -27,10 +35,17 @@
FeatureExtractParams,
OpenaiParams,
map_to_generate_params,
- map_to_openai_response, AudioParams, GenerateSpeechParams, AudioParamsOrAudioFile, SummarizationParams,
+ map_to_openai_response,
+ AudioParams,
+ GenerateSpeechParams,
+ AudioParamsOrAudioFile,
+ SummarizationParams,
+)
+from questions.payments.payments import (
+ get_subscription_item_id_for_user,
+ create_subscription_for_user,
+ get_subscription_item_id_for_user_email,
)
-from questions.payments.payments import get_subscription_item_id_for_user, \
- create_subscription_for_user, get_subscription_item_id_for_user_email
from questions.perplexity import DEVICE
from questions.summarization import get_extractive_summary
from questions.text_gen_pipeline import TextGenPipeline
@@ -42,7 +57,7 @@
from questions.utils import log_time
from sellerinfo import session_secret
from .models import build_model
-from .kokoro import generate
+from .kokoro import generate, generate_full
import librosa
@@ -82,23 +97,90 @@
# async def create_files(files: bytes = File()):
# return {"file_sizes": [len(file) for file in files]}
-languages = {"af_za": "Afrikaans", "am_et": "Amharic", "ar_eg": "Arabic", "as_in": "Assamese", "az_az": "Azerbaijani",
- "be_by": "Belarusian", "bg_bg": "Bulgarian", "bn_in": "Bengali", "bs_ba": "Bosnian", "ca_es": "Catalan",
- "cmn_hans_cn": "Chinese", "cs_cz": "Czech", "cy_gb": "Welsh", "da_dk": "Danish", "de_de": "German",
- "el_gr": "Greek", "en_us": "English", "es_419": "Spanish", "et_ee": "Estonian", "fa_ir": "Persian",
- "fi_fi": "Finnish", "fil_ph": "Tagalog", "fr_fr": "French", "gl_es": "Galician", "gu_in": "Gujarati",
- "ha_ng": "Hausa", "he_il": "Hebrew", "hi_in": "Hindi", "hr_hr": "Croatian", "hu_hu": "Hungarian",
- "hy_am": "Armenian", "id_id": "Indonesian", "is_is": "Icelandic", "it_it": "Italian", "ja_jp": "Japanese",
- "jv_id": "Javanese", "ka_ge": "Georgian", "kk_kz": "Kazakh", "km_kh": "Khmer", "kn_in": "Kannada",
- "ko_kr": "Korean", "lb_lu": "Luxembourgish", "ln_cd": "Lingala", "lo_la": "Lao", "lt_lt": "Lithuanian",
- "lv_lv": "Latvian", "mi_nz": "Maori", "mk_mk": "Macedonian", "ml_in": "Malayalam", "mn_mn": "Mongolian",
- "mr_in": "Marathi", "ms_my": "Malay", "mt_mt": "Maltese", "my_mm": "Myanmar", "nb_no": "Norwegian",
- "ne_np": "Nepali", "nl_nl": "Dutch", "oc_fr": "Occitan", "pa_in": "Punjabi", "pl_pl": "Polish",
- "ps_af": "Pashto", "pt_br": "Portuguese", "ro_ro": "Romanian", "ru_ru": "Russian", "sd_in": "Sindhi",
- "sk_sk": "Slovak", "sl_si": "Slovenian", "sn_zw": "Shona", "so_so": "Somali", "sr_rs": "Serbian",
- "sv_se": "Swedish", "sw_ke": "Swahili", "ta_in": "Tamil", "te_in": "Telugu", "tg_tj": "Tajik",
- "th_th": "Thai", "tr_tr": "Turkish", "uk_ua": "Ukrainian", "ur_pk": "Urdu", "uz_uz": "Uzbek",
- "vi_vn": "Vietnamese", "yo_ng": "Yoruba"}
+languages = {
+ "af_za": "Afrikaans",
+ "am_et": "Amharic",
+ "ar_eg": "Arabic",
+ "as_in": "Assamese",
+ "az_az": "Azerbaijani",
+ "be_by": "Belarusian",
+ "bg_bg": "Bulgarian",
+ "bn_in": "Bengali",
+ "bs_ba": "Bosnian",
+ "ca_es": "Catalan",
+ "cmn_hans_cn": "Chinese",
+ "cs_cz": "Czech",
+ "cy_gb": "Welsh",
+ "da_dk": "Danish",
+ "de_de": "German",
+ "el_gr": "Greek",
+ "en_us": "English",
+ "es_419": "Spanish",
+ "et_ee": "Estonian",
+ "fa_ir": "Persian",
+ "fi_fi": "Finnish",
+ "fil_ph": "Tagalog",
+ "fr_fr": "French",
+ "gl_es": "Galician",
+ "gu_in": "Gujarati",
+ "ha_ng": "Hausa",
+ "he_il": "Hebrew",
+ "hi_in": "Hindi",
+ "hr_hr": "Croatian",
+ "hu_hu": "Hungarian",
+ "hy_am": "Armenian",
+ "id_id": "Indonesian",
+ "is_is": "Icelandic",
+ "it_it": "Italian",
+ "ja_jp": "Japanese",
+ "jv_id": "Javanese",
+ "ka_ge": "Georgian",
+ "kk_kz": "Kazakh",
+ "km_kh": "Khmer",
+ "kn_in": "Kannada",
+ "ko_kr": "Korean",
+ "lb_lu": "Luxembourgish",
+ "ln_cd": "Lingala",
+ "lo_la": "Lao",
+ "lt_lt": "Lithuanian",
+ "lv_lv": "Latvian",
+ "mi_nz": "Maori",
+ "mk_mk": "Macedonian",
+ "ml_in": "Malayalam",
+ "mn_mn": "Mongolian",
+ "mr_in": "Marathi",
+ "ms_my": "Malay",
+ "mt_mt": "Maltese",
+ "my_mm": "Myanmar",
+ "nb_no": "Norwegian",
+ "ne_np": "Nepali",
+ "nl_nl": "Dutch",
+ "oc_fr": "Occitan",
+ "pa_in": "Punjabi",
+ "pl_pl": "Polish",
+ "ps_af": "Pashto",
+ "pt_br": "Portuguese",
+ "ro_ro": "Romanian",
+ "ru_ru": "Russian",
+ "sd_in": "Sindhi",
+ "sk_sk": "Slovak",
+ "sl_si": "Slovenian",
+ "sn_zw": "Shona",
+ "so_so": "Somali",
+ "sr_rs": "Serbian",
+ "sv_se": "Swedish",
+ "sw_ke": "Swahili",
+ "ta_in": "Tamil",
+ "te_in": "Telugu",
+ "tg_tj": "Tajik",
+ "th_th": "Thai",
+ "tr_tr": "Turkish",
+ "uk_ua": "Ukrainian",
+ "ur_pk": "Urdu",
+ "uz_uz": "Uzbek",
+ "vi_vn": "Vietnamese",
+ "yo_ng": "Yoruba",
+}
# expose best of?
options = {} # dict(beam_size=5, best_of=5) (makes it slower)
@@ -116,7 +198,9 @@
# result = audio_model.transcribe("/media/lee/78ca132e-d181-4406-aea5-3c9665f486cc/Videos/intro-bitbanknz.mp3") 40/6s
app.mount("/static", StaticFiles(directory="static"), name="static")
-app.mount("/gameon/static", StaticFiles(directory="gameon/static"), name="gameon/static")
+app.mount(
+ "/gameon/static", StaticFiles(directory="gameon/static"), name="gameon/static"
+)
app.add_middleware(
CORSMiddleware,
@@ -237,9 +321,7 @@ def track_stripe_request_usage(secret, quantity: int):
f"no subscription item id for user: {existing_user.email} {existing_user.stripe_id}"
)
# Create a new subscription item
- subscription = create_subscription_for_user(
- existing_user.stripe_id
- )
+ subscription = create_subscription_for_user(existing_user.stripe_id)
logger.info(f"created subscription: {subscription}")
try:
subscription_item_id = subscription["items"].data[0]["id"]
@@ -263,13 +345,17 @@ def validate_generate_params(generate_params):
audio_model = None
+
+
def load_audio_model():
global audio_model
# about 10s
with log_time("load whisper model"):
# audio_model = whisper.load_model("large") # todo specify download_root to a fast ssd
if not audio_model:
- audio_model = whisper.load_model("medium", download_root="models") # todo specify download_root to a fast ssd
+ audio_model = whisper.load_model(
+ "medium", download_root="models"
+ ) # todo specify download_root to a fast ssd
audio_model.eval()
audio_model = audio_model.to("cuda")
logger.info(
@@ -291,26 +377,29 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile):
# todo download youtube video
ydl_opts = {
- 'format': 'bestaudio/best',
- 'postprocessors': [{
- 'key': 'FFmpegExtractAudio',
- 'preferredcodec': 'wav',
- 'preferredquality': '192'
- }],
- 'postprocessor_args': [
- '-ar', '16000'
+ "format": "bestaudio/best",
+ "postprocessors": [
+ {
+ "key": "FFmpegExtractAudio",
+ "preferredcodec": "wav",
+ "preferredquality": "192",
+ }
],
- 'prefer_ffmpeg': True,
- 'keepvideo': False,
+ "postprocessor_args": ["-ar", "16000"],
+ "prefer_ffmpeg": True,
+ "keepvideo": False,
# download to temp file
- 'outtmpl': '/tmp/audio.wav'
+ "outtmpl": "/tmp/audio.wav",
# download to memory
-
}
with log_time("download youtube"):
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
- ydl.download([audio_params.audio_url], )
- audio_bytes = ydl.prepare_filename(ydl.extract_info(audio_params.audio_url))
+ ydl.download(
+ [audio_params.audio_url],
+ )
+ audio_bytes = ydl.prepare_filename(
+ ydl.extract_info(audio_params.audio_url)
+ )
with open(audio_bytes, "rb") as f:
audio_bytes = f.read()
else:
@@ -319,7 +408,8 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile):
response = audio_request.result()
if response.status_code != 200:
raise HTTPException(
- status_code=500, detail=f"Failed to download audio file at {audio_params.audio_url}"
+ status_code=500,
+ detail=f"Failed to download audio file at {audio_params.audio_url}",
)
response.raw.decode_content = True
audio_bytes = response.content
@@ -335,9 +425,9 @@ def fast_audio_extract_inference(audio_params: AudioParamsOrAudioFile):
# clean data
tmp_file.close()
- for segment in result['segments']:
- del segment['tokens']
- result['text'] = result['text'].strip()
+ for segment in result["segments"]:
+ del segment["tokens"]
+ result["text"] = result["text"].strip()
return result
@@ -372,35 +462,43 @@ def write_srt(transcript: Iterator[dict]):
@app.post("/api/v1/audio-file-extraction")
async def audio_file_extraction(
- background_tasks: BackgroundTasks,
- request: Request,
- response: Response,
- # files: List[UploadFile] = File(...),
- # files: list[bytes] = File(None, description="Multiple files as bytes"),
- audio_file: UploadFile = File(None, description="Audio file"),
- translate_to_english: bool = Form(False),
- output_filetype: str = Form("txt"),
- secret: Union[str, None] = Header(default=None),
+ background_tasks: BackgroundTasks,
+ request: Request,
+ response: Response,
+ # files: List[UploadFile] = File(...),
+ # files: list[bytes] = File(None, description="Multiple files as bytes"),
+ audio_file: UploadFile = File(None, description="Audio file"),
+ translate_to_english: bool = Form(False),
+ output_filetype: str = Form("txt"),
+ secret: Union[str, None] = Header(default=None),
):
- audio_params = AudioParamsOrAudioFile(audio_file=audio_file, audio_url=None,
- translate_to_english=translate_to_english, output_filetype=output_filetype)
- return await audio_extract_shared(background_tasks, audio_params, request, response, secret)
+ audio_params = AudioParamsOrAudioFile(
+ audio_file=audio_file,
+ audio_url=None,
+ translate_to_english=translate_to_english,
+ output_filetype=output_filetype,
+ )
+ return await audio_extract_shared(
+ background_tasks, audio_params, request, response, secret
+ )
@app.post("/api/v1/audio-extraction")
async def audio_extraction(
- feature_extract_params: AudioParams, # wav files or mp3 supported
- # audio_file: UploadFile,
- background_tasks: BackgroundTasks,
- request: Request,
- response: Response,
- secret: Union[str, None] = Header(default=None),
+ feature_extract_params: AudioParams, # wav files or mp3 supported
+ # audio_file: UploadFile,
+ background_tasks: BackgroundTasks,
+ request: Request,
+ response: Response,
+ secret: Union[str, None] = Header(default=None),
):
# if not request_authorized(request, secret):
# return HTTPException(
# status_code=401, detail="Please subscribe at https://text-generator.io/subscribe first"
# )
- return await audio_extract_shared(background_tasks, feature_extract_params, request, response, secret)
+ return await audio_extract_shared(
+ background_tasks, feature_extract_params, request, response, secret
+ )
# @app.get("/restart-server")
@@ -410,7 +508,10 @@ async def audio_extraction(
# os.system("kill -SIGHUP `pgrep gunicorn`")
# return "restarting server"
-async def audio_extract_shared(background_tasks, feature_extract_params, request, response, secret):
+
+async def audio_extract_shared(
+ background_tasks, feature_extract_params, request, response, secret
+):
# if not request_authorized(request, secret):
# return HTTPException(
# status_code=401,
@@ -419,41 +520,49 @@ async def audio_extract_shared(background_tasks, feature_extract_params, request
inference_result = fast_audio_extract_inference(feature_extract_params)
if "X-Rapid-API-Key" not in request.headers:
# todo fix
- seconds_taken = inference_result['segments'][-1]['end']
+ seconds_taken = inference_result["segments"][-1]["end"]
# price of quantity is 1 for .01
price = seconds_taken * 0.00005
- quantity = price // .01
- remainder = price % .01
+ quantity = price // 0.01
+ remainder = price % 0.01
if random.random() < remainder * 100:
quantity += 1
if quantity:
- if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET: # skip logging for our secret
- background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=int(quantity))
+ if (
+ not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET
+ ): # skip logging for our secret
+ background_tasks.add_task(
+ track_stripe_request_usage, secret=secret, quantity=int(quantity)
+ )
if feature_extract_params.output_filetype == "srt":
# response = StreamingResponse(
# non streaming response
response.headers["Content-Disposition"] = "attachment; filename=audio.srt"
- return write_srt(inference_result['segments'])
+ return write_srt(inference_result["segments"])
return inference_result
@app.post("/api/v1/feature-extraction")
async def feature_extraction(
- feature_extract_params: FeatureExtractParams,
- background_tasks: BackgroundTasks,
- request: Request,
- secret: Union[str, None] = Header(default=None),
+ feature_extract_params: FeatureExtractParams,
+ background_tasks: BackgroundTasks,
+ request: Request,
+ secret: Union[str, None] = Header(default=None),
):
# global daemon
# slow warmup on new servers
# model = MODEL_CACHE.add_or_get("text_model", load_pipelines_and_model)
# daemon.join()
- inference_result = fast_feature_extract_inference(feature_extract_params, MODEL_CACHE)
+ inference_result = fast_feature_extract_inference(
+ feature_extract_params, MODEL_CACHE
+ )
if "X-Rapid-API-Key" not in request.headers:
# todo fix
if random.randint(1, 10) == 10:
if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
- background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1)
+ background_tasks.add_task(
+ track_stripe_request_usage, secret=secret, quantity=1
+ )
return inference_result[: feature_extract_params.num_features]
@@ -473,17 +582,19 @@ async def feature_extraction(
# todo fix
if random.randint(1, 10) == 10:
if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
- background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1)
- return JSONResponse({
- "generated_text": text
- })
+ background_tasks.add_task(
+ track_stripe_request_usage, secret=secret, quantity=1
+ )
+ return JSONResponse({"generated_text": text})
@app.get("/liveness_check")
async def liveness_check(request: Request):
# global daemon
inference_result = fast_inference(
- generate_params=GenerateParams(text="hi my friend", min_probability=0.9, max_length=1, model='any'),
+ generate_params=GenerateParams(
+ text="hi my friend", min_probability=0.9, max_length=1, model="any"
+ ),
model_cache=MODEL_CACHE,
)
return JSONResponse(inference_result)
@@ -522,6 +633,8 @@ def request_authorized(request: Request, secret):
speech_processor = None
speechgen_model = None
speech_vocoder = None
+
+
def load_speechgen_model():
global speech_processor
global speechgen_model
@@ -529,18 +642,30 @@ def load_speechgen_model():
if not speechgen_model:
# Load Kokoro model
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
- speechgen_model = build_model('models/kokoro-v0_19.pth', device)
-
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ speechgen_model = build_model("models/kokoro-v0_19.pth", device)
+
# Load voice packs
voicepacks = {}
- voice_names = ['af', 'af_bella', 'af_sarah', 'am_adam', 'am_michael',
- 'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
- 'af_nicole', 'af_sky']
-
+ voice_names = [
+ "af",
+ "af_bella",
+ "af_sarah",
+ "am_adam",
+ "am_michael",
+ "bf_emma",
+ "bf_isabella",
+ "bm_george",
+ "bm_lewis",
+ "af_nicole",
+ "af_sky",
+ ]
+
for voice in voice_names:
- voicepacks[voice] = torch.load(f'models/voices/{voice}.pt', weights_only=True).to(device)
-
+ voicepacks[voice] = torch.load(
+ f"models/voices/{voice}.pt", weights_only=True
+ ).to(device)
+
return speechgen_model, voicepacks
@@ -557,23 +682,23 @@ def write_wav(processed_np_speech, rate):
# todo fix to use io.BytesIO
bytes = BytesIO()
bytes.name = "audio.wav"
- sf.write(bytes, processed_np_speech, rate, subtype='PCM_24')
+ sf.write(bytes, processed_np_speech, rate, subtype="PCM_24")
# bytesio to bytes
return bytes.getvalue()
@app.post("/api/v1/generate_speech")
async def generate_speech(
- generate_speech_params: GenerateSpeechParams,
- background_tasks: BackgroundTasks,
- request: Request,
- response: Response,
- secret: Union[str, None] = Header(default=None),
+ generate_speech_params: GenerateSpeechParams,
+ background_tasks: BackgroundTasks,
+ request: Request,
+ response: Response,
+ secret: Union[str, None] = Header(default=None),
):
if not request_authorized(request, secret):
return HTTPException(
status_code=401,
- detail="Invalid Secret, please use the secret found in /account also subscribe at https://text-generator.io/subscribe first, also make sure there is an up to date credit card saved in your account"
+ detail="Invalid Secret, please use the secret found in /account also subscribe at https://text-generator.io/subscribe first, also make sure there is an up to date credit card saved in your account",
)
text = generate_speech_params.text
@@ -586,7 +711,9 @@ async def generate_speech(
if "X-Rapid-API-Key" not in request.headers:
# todo fix
if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
- background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1)
+ background_tasks.add_task(
+ track_stripe_request_usage, secret=secret, quantity=1
+ )
# write np array to wav
wav = write_wav(processed_np_speech, rate)
file = Response(wav, media_type="audio/wav")
@@ -594,19 +721,48 @@ async def generate_speech(
# return Response(wav, media_type="audio/wav")
+@app.post("/api/v1/generate_speech_stream")
+async def generate_speech_stream(
+ generate_speech_params: GenerateSpeechParams,
+ background_tasks: BackgroundTasks,
+ request: Request,
+ secret: Union[str, None] = Header(default=None),
+):
+ if not request_authorized(request, secret):
+ return HTTPException(
+ status_code=401,
+ detail="Invalid Secret, please use the secret found in /account also subscribe at https://text-generator.io/subscribe first, also make sure there is an up to date credit card saved in your account",
+ )
+
+ if "X-Rapid-API-Key" not in request.headers:
+ if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
+ background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1)
+
+    generator = stream_audio_chunks_async(
+ generate_speech_params.text,
+ generate_speech_params.voice,
+ generate_speech_params.speed,
+ generate_speech_params.chunk_words,
+ )
+    response = StreamingResponse(generator, media_type="audio/wav")  # NOTE(review): each yielded chunk is a complete WAV (own RIFF header); the concatenation is not one valid .wav — clients must parse per-chunk, or the generator should emit raw PCM after a single header
+ response.headers["Content-Disposition"] = "attachment; filename=audio.wav"
+ return response
+
+
def gradio_audio_process(text, voice, speed=1.0):
"""Simplified function that only takes the required parameters"""
if len(text.strip()) == 0:
return (24000, np.zeros(0).astype(np.int16))
model, voicepacks = MODEL_CACHE.add_or_get("speech_model", load_speechgen_model)
-
+
# Get the voicepack
- voicepack = voicepacks.get(voice, voicepacks['af_nicole'])
-
+ voicepack = voicepacks.get(voice, voicepacks["af_nicole"])
+
# Generate audio using Kokoro
- audio, phonemes = generate(model, text, voicepack, lang=voice[0], speed=speed)
-
+ # generate_full handles inputs longer than the default token limit
+ audio, phonemes = generate_full(model, text, voicepack, lang=voice[0], speed=speed)
+
return (24000, audio)
@@ -617,9 +773,7 @@ def gradio_audio_process(text, voice, speed=1.0):
"Male slower": "RMS (male)",
"Female 2": "SLT (female)",
}
-speaker_embeddings_loaded = {
-
-}
+speaker_embeddings_loaded = {}
def load_speaker_embedding(speaker):
@@ -635,16 +789,17 @@ def load_speaker_embedding(speaker):
def audio_process(text, voice="af_nicole", speed=1.0):
model, voicepacks = MODEL_CACHE.add_or_get("speech_model", load_speechgen_model)
-
+
if len(text.strip()) == 0:
return (24000, np.zeros(0).astype(np.int16))
# Get the voicepack
- voicepack = voicepacks.get(voice, voicepacks['af_nicole'])
-
+ voicepack = voicepacks.get(voice, voicepacks["af_nicole"])
+
# Generate audio using Kokoro
- audio, phonemes = generate(model, text, voicepack, lang=voice[0], speed=speed)
-
+ # generate_full handles inputs longer than the default token limit
+ audio, phonemes = generate_full(model, text, voicepack, lang=voice[0], speed=speed)
+
# we could do this but use speed instead
# Convert to float32 for time-stretch
# audio_float = audio.astype(np.float32) / 32767.0
@@ -658,7 +813,48 @@ def audio_process(text, voice="af_nicole", speed=1.0):
return (24000, audio)
+def synthesize_full_text(
+ text: str, voice: str = "af_nicole", speed: float = 1.0, chunk_words: int = CHUNK_WORDS_DEFAULT
+):
+ """Generate speech for arbitrarily long text by chunking."""
+ if not text.strip():
+ return (24000, np.zeros(0, dtype=np.int16))
+
+ words = text.split()
+ segments = []
+ for start in range(0, len(words), chunk_words):
+ chunk = " ".join(words[start : start + chunk_words])
+ _, audio = audio_process(chunk, voice=voice, speed=speed)
+ segments.append(audio)
+
+ if segments:
+ full_audio = np.concatenate(segments)
+ else:
+ full_audio = np.zeros(0, dtype=np.int16)
+
+ return (24000, full_audio)
+
+def stream_audio_chunks(
+ text: str, voice: str = "af_nicole", speed: float = 1.0, chunk_words: int = CHUNK_WORDS_DEFAULT
+):
+ """Yield audio wav bytes for each chunk of text."""
+ words = text.split()
+ for start in range(0, len(words), chunk_words):
+ chunk = " ".join(words[start : start + chunk_words])
+ _, audio = audio_process(chunk, voice=voice, speed=speed)
+ yield write_wav(audio, 24000)
+
+
+async def stream_audio_chunks_async(
+ text: str, voice: str = "af_nicole", speed: float = 1.0, chunk_words: int = CHUNK_WORDS_DEFAULT
+):
+ """Asynchronously yield audio wav bytes for each chunk of text."""
+ words = text.split()
+ for start in range(0, len(words), chunk_words):
+ chunk = " ".join(words[start : start + chunk_words])
+ _, audio = await asyncio.to_thread(audio_process, chunk, voice, speed)
+ yield write_wav(audio, 24000)
# gradio web app at https://text-generator.io/gradio_tts
@@ -666,16 +862,24 @@ def audio_process(text, voice="af_nicole", speed=1.0):
examples = [
["It is not in the stars to hold our destiny but in ourselves.", "Male default"],
["Oliver went to the opera in October.", "Female 2"],
- ["She sells seashells by the seashore. I saw a kitten eating chicken in the kitchen.", "Male slower"],
- ["Brisk brave brigadiers brandished broad bright blades, blunderbusses, and bludgeons—balancing them badly.",
- "Female 1"],
+ [
+ "She sells seashells by the seashore. I saw a kitten eating chicken in the kitchen.",
+ "Male slower",
+ ],
+ [
+ "Brisk brave brigadiers brandished broad bright blades, blunderbusses, and bludgeons—balancing them badly.",
+ "Female 1",
+ ],
["A synonym for cinnamon is a cinnamon synonym.", "Male fast"],
[
"How much wood would a woodchuck chuck if a woodchuck could chuck wood? He would chuck, he would, as much as he could, and chuck as much wood as a woodchuck would if a woodchuck could chuck wood.",
- "Female 2"],
+ "Female 2",
+ ],
]
title = "Text To Speech"
-description = "How to use: Enter some English text and choose a speaker. Click Submit "
+description = (
+ "How to use: Enter some English text and choose a speaker. Click Submit "
+)
article = """
Checkout the API docs at /docs
"""
@@ -828,22 +1032,26 @@ def audio_process(text, voice="af_nicole", speed=1.0):
# async def config_route():
# return RedirectResponse("/gradio_tts/info")
+
@app.get("/setcookie")
-async def setcookie(secret:str, request: Request, response: Response):
- response.set_cookie('secret', value=secret)#, httponly=True)
+async def setcookie(secret: str, request: Request, response: Response):
+ response.set_cookie("secret", value=secret) # , httponly=True)
return {"message": "Cookie set"}
+
@app.get("/gradio_frame")
-async def gradio_frame_route(secret:str, request: Request):
- return templates.TemplateResponse("templates/gradio-frame.jinja2", {"request": request, "secret": secret})
+async def gradio_frame_route(secret: str, request: Request):
+ return templates.TemplateResponse(
+ "templates/gradio-frame.jinja2", {"request": request, "secret": secret}
+ )
@app.post("/api/v1/generate")
async def generate_route(
- generate_params: GenerateParams,
- background_tasks: BackgroundTasks = None,
- request: Request = None,
- secret: Union[str, None] = Header(default=None),
+ generate_params: GenerateParams,
+ background_tasks: BackgroundTasks = None,
+ request: Request = None,
+ secret: Union[str, None] = Header(default=None),
):
validation_result = validate_generate_params(generate_params)
if validation_result:
@@ -863,50 +1071,59 @@ async def generate_route(
inference_result = fast_inference(generate_params, MODEL_CACHE)
# todo vuln
if request and background_tasks:
- if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
+ if (
+ "X-Rapid-API-Key" not in request.headers
+ and "x-rapid-api-key" not in request.headers
+ ):
if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
- background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1)
+ background_tasks.add_task(
+ track_stripe_request_usage, secret=secret, quantity=1
+ )
return inference_result
@app.post("/api/discord")
async def discord_route(
- type: str = Form(...),
- id: str = Form(...),
- data: str = Form(...),
-
- # text: str,
- # background_tasks: BackgroundTasks = None,
- # request: Request = None,
- # secret: Union[str, None] = Header(default=None),
- # response: Response = None,
+ type: str = Form(...),
+ id: str = Form(...),
+ data: str = Form(...),
+ # text: str,
+ # background_tasks: BackgroundTasks = None,
+ # request: Request = None,
+ # secret: Union[str, None] = Header(default=None),
+ # response: Response = None,
):
- generate_params = GenerateParams(**{
- "text": "in 2022 the stock market has been expected to reach a record high.",
- "number_of_results": 1,
- "max_length": 100,
- "max_sentences": 1,
- "min_probability": 0,
- "stop_sequences": [],
- "top_p": 0.9,
- "top_k": 40,
- "temperature": 0.7,
- "repetition_penalty": 1.17,
- "seed": 0
- })
+ generate_params = GenerateParams(
+ **{
+ "text": "in 2022 the stock market has been expected to reach a record high.",
+ "number_of_results": 1,
+ "max_length": 100,
+ "max_sentences": 1,
+ "min_probability": 0,
+ "stop_sequences": [],
+ "top_p": 0.9,
+ "top_k": 40,
+ "temperature": 0.7,
+ "repetition_penalty": 1.17,
+ "seed": 0,
+ }
+ )
+
@app.post("/api/v1/generate-bulk")
async def generate_route_bulk(
- bulk_params: List[GenerateParams],
- background_tasks: BackgroundTasks = None,
- request: Request = None,
- secret: Union[str, None] = Header(default=None),
+ bulk_params: List[GenerateParams],
+ background_tasks: BackgroundTasks = None,
+ request: Request = None,
+ secret: Union[str, None] = Header(default=None),
):
# global daemon
# print(model.config.max_length)
# print(tokenizer.model_max_length)
# model.config.max_length = tokenizer.model_max_length
- model = MODEL_CACHE.add_or_get("text_model", lambda: load_pipelines_and_model(weights_path_tgz))
+ model = MODEL_CACHE.add_or_get(
+ "text_model", lambda: load_pipelines_and_model(weights_path_tgz)
+ )
# daemon.join()
inference_results = []
for generate_params in bulk_params:
@@ -965,11 +1182,11 @@ async def generate_route_bulk(
@app.post("/v1/engines/{engine_name}/completions")
async def openai_route_named(
- engine_name: str,
- openai_params: OpenaiParams,
- background_tasks: BackgroundTasks = None,
- request: Request = None,
- secret: Union[str, None] = Header(default=None),
+ engine_name: str,
+ openai_params: OpenaiParams,
+ background_tasks: BackgroundTasks = None,
+ request: Request = None,
+ secret: Union[str, None] = Header(default=None),
):
# global daemon
# print(model.config.max_length)
@@ -982,7 +1199,10 @@ async def openai_route_named(
if validation_result:
# return a 400 bad request from fast api
return HTTPException(status_code=400, detail=validation_result)
- if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
+ if (
+ "X-Rapid-API-Key" not in request.headers
+ and "x-rapid-api-key" not in request.headers
+ ):
header_auth = request.headers.get("Authorization", " ")
authorization_split = header_auth.split(" ")
if len(authorization_split) == 2:
@@ -991,32 +1211,39 @@ async def openai_route_named(
if not request_authorized(request, secret):
return HTTPException(
status_code=401,
- detail="Please subscribe at https://text-generator.io/subscribe first, also ensure you have a credit card on file"
+ detail="Please subscribe at https://text-generator.io/subscribe first, also ensure you have a credit card on file",
)
inference_result = fast_inference(generate_params, MODEL_CACHE)
if not openai_params.echo:
## remove all the inputs from the generated texts
for i in range(len(inference_result)):
- inference_result[i]["generated_text"] = inference_result[i]["generated_text"][
- len(openai_params.prompt):
- ]
+ inference_result[i]["generated_text"] = inference_result[i][
+ "generated_text"
+ ][len(openai_params.prompt) :]
# todo vuln
if request and background_tasks:
- if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
+ if (
+ "X-Rapid-API-Key" not in request.headers
+ and "x-rapid-api-key" not in request.headers
+ ):
if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
- background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=1)
+ background_tasks.add_task(
+ track_stripe_request_usage, secret=secret, quantity=1
+ )
# map to openai response
return map_to_openai_response(inference_result, generate_params)
@app.post("/v1/completions")
async def openai_route(
- openai_params: OpenaiParams,
- background_tasks: BackgroundTasks = None,
- request: Request = None,
- secret: Union[str, None] = Header(default=None),
+ openai_params: OpenaiParams,
+ background_tasks: BackgroundTasks = None,
+ request: Request = None,
+ secret: Union[str, None] = Header(default=None),
):
- return await openai_route_named("engine", openai_params, background_tasks, request, secret)
+ return await openai_route_named(
+ "engine", openai_params, background_tasks, request, secret
+ )
# redirect / to /docs
@@ -1032,14 +1259,17 @@ async def root():
# return RedirectResponse(url="/gradio_tts")
-logger.info("""
+logger.info(
+ """
Welcome to the Text Generator API!
Thanks for using our service. We hope you enjoy it and it means a lot to us that you are here.
Let me know how it goes and if you have any questions or feedback, please reach out at lee.penkman@wordsmashing.com
Downloading models may take a long time on the first run.
-""")
+"""
+)
+
@app.get("/text-to-speech-demo")
def tts_demo(request: Request):
@@ -1051,6 +1281,7 @@ def tts_demo(request: Request):
},
)
+
# @app.post("/api/v1/generate-long")
# async def generate_long_text(
# generate_params: GenerateParams,
@@ -1065,29 +1296,29 @@ def tts_demo(request: Request):
# validation_result = validate_generate_params(generate_params)
# if validation_result:
# return HTTPException(status_code=400, detail=validation_result)
-#
+#
# # Authorize the request
# if request and "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
# if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
# if not request_authorized(request, secret):
# return HTTPException(
-# status_code=401,
+# status_code=401,
# detail="Please subscribe at https://text-generator.io/subscribe first"
# )
-#
+#
# try:
# # Prepare the prompt for Claude
# prompt = generate_params.text
-#
+#
# # Set up system message to control generation parameters
# system_message = f"""
# You are a creative text generation assistant. Generate text that continues from the given prompt.
-#
+#
# Parameters to follow:
# - Temperature: {generate_params.temperature}
# - Creativity level: {"high" if generate_params.temperature > 0.7 else "medium" if generate_params.temperature > 0.3 else "low"}
# - Maximum length: {generate_params.max_length} tokens
-#
+#
# Important instructions:
# - Continue the text naturally from where the prompt ends
# - Do not repeat the prompt in your response
@@ -1095,43 +1326,43 @@ def tts_demo(request: Request):
# - Do not use phrases like "Here's a continuation" or "Continuing from the prompt"
# - Just generate the continuation text directly
# """
-#
+#
# # Set up stop sequences
# stop_sequences = None
# if generate_params.stop_sequences and len(generate_params.stop_sequences) > 0:
# stop_sequences = frozenset(generate_params.stop_sequences)
-#
+#
# # Call Claude API
# generated_text = await query_to_claude_async(
# prompt=prompt,
# stop_sequences=stop_sequences,
# system_message=system_message,
# )
-#
+#
# # Handle the response
# if generated_text is None:
# return HTTPException(status_code=500, detail="Failed to generate text with Claude")
-#
+#
# # Format the response to match the standard API format
# result = [{
# "generated_text": prompt + generated_text,
# "finished_reason": "length",
# "model": "claude-3-sonnet-20240229"
# }]
-#
+#
# # Track usage if needed
# if request and background_tasks:
# if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
# if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
# # Claude is more expensive, so we charge 3 units
# background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=3)
-#
+#
# return result
-#
+#
# except Exception as e:
# logger.error(f"Error generating text with Claude: {e}")
# return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
-#
+#
# @app.post("/api/v1/generate-large")
# async def generate_large_text(
# generate_params: GenerateParams,
@@ -1146,30 +1377,30 @@ def tts_demo(request: Request):
# validation_result = validate_generate_params(generate_params)
# if validation_result:
# return HTTPException(status_code=400, detail=validation_result)
-#
+#
# # Authorize the request
# if request and "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
# if not API_KEY and secret != sellerinfo.TEXT_GENERATOR_SECRET:
# if not request_authorized(request, secret):
# return HTTPException(
-# status_code=401,
+# status_code=401,
# detail="Please subscribe at https://text-generator.io/subscribe first"
# )
-#
+#
# try:
# # Prepare the prompt for Claude
# prompt = generate_params.text
# model_name = "claude-3-7-sonnet-20250219"
-#
+#
# # Set up system message to control generation parameters
# system_message = f"""
# You are a creative text generation assistant. Generate text that continues from the given prompt.
-#
+#
# Parameters to follow:
# - Temperature: {generate_params.temperature}
# - Creativity level: {"high" if generate_params.temperature > 0.7 else "medium" if generate_params.temperature > 0.3 else "low"}
# - Maximum length: {generate_params.max_length} tokens
-#
+#
# Important instructions:
# - Continue the text naturally from where the prompt ends
# - Do not repeat the prompt in your response
@@ -1177,12 +1408,12 @@ def tts_demo(request: Request):
# - Do not use phrases like "Here's a continuation" or "Continuing from the prompt"
# - Just generate the continuation text directly
# """
-#
+#
# # Set up stop sequences
# stop_sequences = None
# if generate_params.stop_sequences and len(generate_params.stop_sequences) > 0:
# stop_sequences = frozenset(generate_params.stop_sequences)
-#
+#
# # Call Claude API with the specified model
# generated_text = await query_to_claude_async(
# prompt=prompt,
@@ -1190,18 +1421,18 @@ def tts_demo(request: Request):
# system_message=system_message,
# model=model_name, # Pass the model name to the Claude API function
# )
-#
+#
# # Handle the response
# if generated_text is None:
# return HTTPException(status_code=500, detail="Failed to generate text with Claude")
-#
+#
# # Format the response to match the standard API format
# result = [{
# "generated_text": prompt + generated_text,
# "finished_reason": "length",
# "model": model_name
# }]
-#
+#
# # Track usage if needed
# if request and background_tasks:
# if "X-Rapid-API-Key" not in request.headers and "x-rapid-api-key" not in request.headers:
@@ -1209,9 +1440,9 @@ def tts_demo(request: Request):
# # Claude is more expensive, so we charge more units based on the model
# quantity = 5 if "opus" in model_name else 3 if "sonnet" in model_name else 2
# background_tasks.add_task(track_stripe_request_usage, secret=secret, quantity=quantity)
-#
+#
# return result
-#
+#
# except Exception as e:
# logger.error(f"Error generating text with Claude: {e}")
# return HTTPException(status_code=500, detail=f"Error generating text: {str(e)}")
diff --git a/questions/models.py b/questions/models.py
index 2268c6b..b171572 100755
--- a/questions/models.py
+++ b/questions/models.py
@@ -1,5 +1,6 @@
from time import time
from typing import Optional, List
+import os
from fastapi import UploadFile
from pydantic import BaseModel
@@ -73,6 +74,7 @@ class GenerateSpeechParams(BaseModel):
speed: float = 1.0
volume: float = 1.0
sample_rate: int = 24000 # Kokoro outputs 24kHz audio
+ chunk_words: int = int(os.getenv("STREAM_CHUNK_WORDS", "100"))
class OpenaiParams(BaseModel):
diff --git a/static/openapi.json b/static/openapi.json
index a2a4735..271db17 100644
--- a/static/openapi.json
+++ b/static/openapi.json
@@ -150,6 +150,56 @@
}
}
},
+ "/api/v1/generate_speech_stream": {
+ "post": {
+ "summary": "Stream Speech",
+ "operationId": "generate_speech_stream_api_v1_generate_speech_stream_post",
+ "parameters": [
+ {
+ "required": false,
+ "schema": {
+ "title": "Secret",
+ "type": "string"
+ },
+ "name": "secret",
+ "in": "header"
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateSpeechParams"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/octet-stream": {
+ "schema": {
+ "type": "string",
+ "format": "binary"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
"/api/v1/audio-file-extraction": {
"post": {
"summary": "Audio File Extraction",
@@ -574,6 +624,11 @@
"title": "Speaker",
"type": "string",
"default": "Female 1"
+ },
+ "chunk_words": {
+ "title": "Chunk Words",
+ "type": "integer",
+ "default": 100
}
}
},
diff --git a/static/openapi2-audio.json b/static/openapi2-audio.json
index e762a87..c3894a5 100644
--- a/static/openapi2-audio.json
+++ b/static/openapi2-audio.json
@@ -209,6 +209,56 @@
}
}
},
+ "/api/v1/generate_speech_stream": {
+ "post": {
+ "summary": "Stream Speech",
+ "operationId": "generate_speech_stream_api_v1_generate_speech_stream_post",
+ "parameters": [
+ {
+ "required": false,
+ "schema": {
+ "title": "Secret",
+ "type": "string"
+ },
+ "name": "secret",
+ "in": "header"
+ }
+ ],
+ "requestBody": {
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/GenerateSpeechParams"
+ }
+ }
+ },
+ "required": true
+ },
+ "responses": {
+ "200": {
+ "description": "Successful Response",
+ "content": {
+ "application/octet-stream": {
+ "schema": {
+ "type": "string",
+ "format": "binary"
+ }
+ }
+ }
+ },
+ "422": {
+ "description": "Validation Error",
+ "content": {
+ "application/json": {
+ "schema": {
+ "$ref": "#/components/schemas/HTTPValidationError"
+ }
+ }
+ }
+ }
+ }
+ }
+ },
"/api/v1/generate": {
"post": {
"summary": "Generate Route",
@@ -572,6 +622,11 @@
"title": "Speaker",
"type": "string",
"default": "Female 1"
+ },
+ "chunk_words": {
+ "title": "Chunk Words",
+ "type": "integer",
+ "default": 100
}
}
},
diff --git a/static/templates/shared/streaming-speech-api.jinja2 b/static/templates/shared/streaming-speech-api.jinja2
new file mode 100644
index 0000000..fa82b20
--- /dev/null
+++ b/static/templates/shared/streaming-speech-api.jinja2
@@ -0,0 +1,38 @@
+
+
+
+
+
+
+
+
+
+
+
+Text Generator >
+Blog > Streaming Speech API
+
+
+Stream text to speech output
+
+No more waiting for long texts
+
+Our new endpoint /api/v1/generate_speech_stream lets you begin playback as soon as audio is generated. The backend automatically chunks long text so memory use stays low.
+
+Example usage
+
+import requests
+import os
+
+API_KEY = os.getenv("TEXT_GENERATOR_API_KEY")
+headers = {"secret": API_KEY}
+
+params = {"text": "Hello streaming world", "speaker": "Male fast"}
+with requests.post("https://api.text-generator.io/api/v1/generate_speech_stream", json=params, headers=headers, stream=True) as r:
+ for chunk in r.iter_content(chunk_size=None):
+ # Play or save chunk
+ print("Received", len(chunk), "bytes")
+
+
+Chunks are standard WAV fragments with a 24 kHz sample rate.
+
+See the API docs for more details.
+
+ Sign up
+
+
+
+ {% import "templates/macros.jinja2" as macros with context %}
+ {{ macros.svgstyled() }}
+
diff --git a/tests/integ/test_audio_length_limit.py b/tests/integ/test_audio_length_limit.py
new file mode 100644
index 0000000..9d600f3
--- /dev/null
+++ b/tests/integ/test_audio_length_limit.py
@@ -0,0 +1,28 @@
+import os
+import pytest
+
+from starlette.testclient import TestClient
+
+from questions.inference_server.inference_server import app, audio_process
+
+client = TestClient(app)
+
+API_KEY = os.environ.get("TEXT_GENERATOR_API_KEY", "test")
+headers = {"secret": API_KEY}
+
+
+@pytest.mark.skipif(os.getenv("CI") == "true", reason="skip heavy audio test in CI")
+def test_audio_word_limit():
+ text = "hello"
+ prev_len = 0
+ max_words = None
+ for i in range(1, 30):
+ rate, audio = audio_process(text, "af_nicole")
+ assert rate == 24000
+ cur_len = len(audio)
+ if cur_len <= prev_len:
+ max_words = len(text.split()) - 1
+ break
+ prev_len = cur_len
+ text += " hello"
+ assert max_words is not None
diff --git a/tests/integ/test_inference_server_speech.py b/tests/integ/test_inference_server_speech.py
index 08c1298..3e8ba55 100644
--- a/tests/integ/test_inference_server_speech.py
+++ b/tests/integ/test_inference_server_speech.py
@@ -1,4 +1,5 @@
import os
+import pytest
from questions.utils import log_time
@@ -19,17 +20,22 @@
def test_speech_creation():
with log_time("speech creation"):
- result = audio_process("It is not in the stars to hold our destiny but in ourselves." * 10, "Male fast")
+ result = audio_process(
+ "It is not in the stars to hold our destiny but in ourselves." * 10,
+ "Male fast",
+ )
assert result is not None
assert result[0] is not None
-def test_speech_creation_route():
+def test_generate_speech_route_single_voice():
audio_params = GenerateSpeechParams(
text="Text-Generator.io is bringing the cost of intelligence toward zero.",
speaker="Male fast",
)
- response = client.post("/api/v1/generate_speech", json=audio_params.__dict__, headers=headers)
+ response = client.post(
+ "/api/v1/generate_speech", json=audio_params.__dict__, headers=headers
+ )
assert response.status_code == 200, response.text
binary_file_response = response.content
assert binary_file_response is not None
@@ -37,6 +43,21 @@ def test_speech_creation_route():
f.write(binary_file_response)
+@pytest.mark.skipif(os.getenv("CI") == "true", reason="skip heavy audio test in CI")
+def test_generate_speech_stream_route():
+ audio_params = GenerateSpeechParams(
+ text="Streaming audio makes responses faster.",
+ speaker="Male fast",
+ chunk_words=5,
+ )
+    response = client.post(
+        "/api/v1/generate_speech_stream", json=audio_params.__dict__, headers=headers
+    )
+    assert response.status_code == 200, response.text
+    data = response.content
+    assert len(data) > 0
+
+
speakers = [
"Male fast",
"Female 1",
@@ -46,13 +67,15 @@ def test_speech_creation_route():
]
-def test_speech_creation_route():
+def test_generate_speech_route_all_voices():
for speaker in speakers:
audio_params = GenerateSpeechParams(
text="Text-Generator.io is bringing the cost of intelligence toward zero.",
speaker=speaker,
)
- response = client.post("/api/v1/generate_speech", json=audio_params.__dict__, headers=headers)
+ response = client.post(
+ "/api/v1/generate_speech", json=audio_params.__dict__, headers=headers
+ )
assert response.status_code == 200, response.text
binary_file_response = response.content
assert binary_file_response is not None