diff --git a/AGENTS.md b/AGENTS.md index f57a537..add3d2a 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -21,6 +21,7 @@ Main flow: - `src/index.ts`: public exports. - `src/KittenTTS.ts`: main SDK class and lifecycle. - `src/KittenTTSConfig.ts`: user config and defaults. +- `src/*.web.ts`: React Native Web entrypoints and platform-specific browser implementations. - `src/KittenTTSError.ts`: SDK error codes and helpers. - `src/KittenModel.ts`: model names, download URLs, sizes, speed priors. - `src/KittenVoice.ts`: voice enum and display helpers. @@ -29,6 +30,7 @@ Main flow: - `src/engine/TTSEngine.ts`: text-to-token-to-ONNX inference. - `src/phonemizer/CEPhonemizer.ts`: JS/Emscripten phonemizer adapter. - `src/audio/AudioOutput.ts`: optional playback helpers. +- `src/storage/AssetStorage.ts`: web/Node asset cache abstraction used by the web platform files. - `vendor/cephonemizer/`: vendored C++ phonemizer source. - `scripts/build-cephonemizer.js`: builds generated phonemizer runtime. - `scripts/patch-onnxruntime-react-native.js`: postinstall ONNX Runtime compatibility patches. diff --git a/CHANGELOG.md b/CHANGELOG.md index 45fa833..d75fd3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ - Added Swift-parity word timing metadata via `KittenTTSResult.wordTimings`. - Added `KittenTTS.generateStreaming()` for sentence-by-sentence generation. - Added `tts.play(result)` so apps can inspect timings before playback. +- Added React Native Web support through browser-specific ONNX Runtime Web, + Cache API asset storage, CE phonemizer, and audio playback implementations. ## 0.8.0 diff --git a/README.md b/README.md index 2818c6c..6744dd5 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@

- On-device text-to-speech for React Native. + On-device text-to-speech for React Native and React Native Web.
- Generate speech on iOS and Android without sending text to a cloud TTS API. + Generate speech on iOS, Android, and web without sending text to a cloud TTS API.

@@ -21,9 +21,15 @@ > Developer preview. APIs may change between releases. -> Expo Go will not work. KittenTTS uses native modules -> (`onnxruntime-react-native` and `react-native-fs`), so Expo apps need a -> development build or a prebuilt native project. +> Expo Go will not work for native iOS/Android. KittenTTS uses native modules +> (`onnxruntime-react-native` and `react-native-fs`) on mobile, so Expo apps +> need a development build or a prebuilt native project. Web builds use +> `onnxruntime-web` and browser storage instead. + +> React Native Web loads a pinned ONNX Runtime Web script and WASM assets from +> jsDelivr by default. For production apps that need CDN independence or stricter +> supply-chain controls, self-host those ONNX Runtime assets and set +> `ortWasmPath`. ## See It In Action @@ -36,6 +42,14 @@ Device: iOS · Expo example     Device: Android · Word timings

+

+ KittenTTS React Native Web example running in a browser +

+ +

+ Web · Browser example +

+ --- ## What Is KittenTTS React Native? @@ -60,6 +74,7 @@ No cloud. No API key. No text leaving the device for speech generation. | --- | --- | --- | | React Native iOS | Developer preview | [Getting started](docs/getting-started.md) | | React Native Android | Developer preview | [Getting started](docs/getting-started.md) | +| React Native Web | Developer preview | [Getting started](docs/getting-started.md#web) | | Expo development build | Supported | [Expo setup](docs/getting-started.md#expo-development-build) | | Expo Go | Not supported | [Why not?](docs/troubleshooting.md#expo-go-fails) | @@ -109,6 +124,21 @@ const tts = await KittenTTS.create({ await tts.speak('This voice is generated on the device.'); ``` +Play audio in a web build: + +```tsx +import { + KittenTTS, + createBrowserAudioPlayer, +} from '@kittentts/react-native'; + +const tts = await KittenTTS.create({ + player: createBrowserAudioPlayer(), +}); + +await tts.speak('This voice is generated in the browser.'); +``` + [Full getting started guide →](docs/getting-started.md) --- @@ -153,7 +183,7 @@ If the app opens in Expo Go, stop it and run `npx expo run:ios` or ## Features -- [On-device TTS inference](docs/getting-started.md) on iOS and Android. +- [On-device TTS inference](docs/getting-started.md) on iOS, Android, and web. - [Model download and cache](docs/reference/api.md#cache-methods) with progress callbacks. - [Bundled offline assets](docs/guides/offline-assets.md) for apps that cannot depend on a first-run download. - [Expo development builds](docs/getting-started.md#expo-development-build); Expo Go is [not supported](docs/troubleshooting.md#expo-go-fails). 
diff --git a/assets/web-example.gif b/assets/web-example.gif new file mode 100644 index 0000000..d8244b0 Binary files /dev/null and b/assets/web-example.gif differ diff --git a/docs/getting-started.md b/docs/getting-started.md index 84ce3d8..ed4ae80 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -10,6 +10,7 @@ instance, and generate speech. | React Native | `>= 0.72` | | iOS | `15.1+` | | Android | API `24+` | +| Web | modern browser with WebAssembly support | | Node.js | `20+` recommended for examples | Expo Go will not work. KittenTTS depends on native modules: @@ -18,6 +19,8 @@ Expo Go will not work. KittenTTS depends on native modules: - `react-native-fs` Use a bare React Native app, an Expo development build, or a prebuilt Expo app. +React Native Web builds use `onnxruntime-web` and do not require those native +modules at runtime. ## Install @@ -57,6 +60,35 @@ npm install react-native-sound cd ios && pod install && cd .. ``` +## Web + +React Native Web builds resolve the package's browser entrypoint. The web +runtime uses `onnxruntime-web`, Cache API storage for downloaded model files, +and the same JavaScript CE phonemizer. + +```tsx +import { + KittenTTS, + createBrowserAudioPlayer, +} from '@kittentts/react-native'; + +const tts = await KittenTTS.create({ + player: createBrowserAudioPlayer(), +}); + +await tts.speak('Hello from KittenTTS on web.'); +await tts.dispose(); +``` + +The browser path also supports `generate()`, `wordTimings`, `wavData()`, and +`wavBase64()`. Pass `ortWasmPath` if your app needs to self-host ONNX Runtime +WASM assets instead of using the SDK defaults. + +By default, browser builds load the pinned ONNX Runtime Web script and WASM +assets from jsDelivr. That keeps the SDK simple to drop into React Native Web, +but production apps that require tighter supply-chain control or CDN outage +isolation should self-host those files and set `ortWasmPath` to that directory. 
+ ## Generate Audio Use `generate()` when you want audio data back without playing it immediately. @@ -117,6 +149,8 @@ await tts.dispose(); The first `KittenTTS.create()` downloads the selected model, `voices.npz`, and phonemizer files. Later calls reuse the device cache. +On web, the cache is stored through the browser Cache API when available and +falls back to memory storage. Default model cache: diff --git a/docs/guides/playback.md b/docs/guides/playback.md index efffa23..da50256 100644 --- a/docs/guides/playback.md +++ b/docs/guides/playback.md @@ -52,6 +52,23 @@ const tts = await KittenTTS.create({ await tts.speak('This plays through react-native-sound.'); ``` +## Browser Audio + +React Native Web builds can use the browser audio helper: + +```tsx +import { + KittenTTS, + createBrowserAudioPlayer, +} from '@kittentts/react-native'; + +const tts = await KittenTTS.create({ + player: createBrowserAudioPlayer(), +}); + +await tts.speak('This plays through an HTML audio element.'); +``` + ## Generate First, Then Play This is useful when the UI needs metadata from the generated result before diff --git a/package-lock.json b/package-lock.json index ad15167..ee5a43c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,22 +1,23 @@ { "name": "@kittentts/react-native", - "version": "1.1.0", + "version": "1.2.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "@kittentts/react-native", - "version": "1.1.0", + "version": "1.2.0", "hasInstallScript": true, "license": "Apache-2.0", - "bin": { - "kittentts-react-native": "bin/kittentts-react-native.js" - }, "dependencies": { "onnxruntime-react-native": "^1.24.3", + "onnxruntime-web": "^1.24.3", "pako": "^2.1.0", "react-native-fs": "npm:@dr.pogodin/react-native-fs@^2.38.2" }, + "bin": { + "kittentts-react-native": "bin/kittentts-react-native.js" + }, "devDependencies": { "@types/pako": "^2.0.3", "@types/react": "^18.2.0", @@ -2157,6 +2158,70 @@ "@jridgewell/sourcemap-codec": "^1.4.14" } }, + 
"node_modules/@protobufjs/aspromise": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/aspromise/-/aspromise-1.1.2.tgz", + "integrity": "sha512-j+gKExEuLmKwvz3OgROXtrJ2UG2x8Ch2YZUxahh+s1F2HZ+wAceUNLkvy6zKCPVRkU++ZWQrdxsUeQXmcg4uoQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/base64": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/base64/-/base64-1.1.2.tgz", + "integrity": "sha512-AZkcAA5vnN/v4PDqKyMR5lx7hZttPDgClv83E//FMNhR2TMcLUhfRUBHCmSl0oi9zMgDDqRUJkSxO3wm85+XLg==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/codegen": { + "version": "2.0.5", + "resolved": "https://registry.npmjs.org/@protobufjs/codegen/-/codegen-2.0.5.tgz", + "integrity": "sha512-zgXFLzW3Ap33e6d0Wlj4MGIm6Ce8O89n/apUaGNB/jx+hw+ruWEp7EwGUshdLKVRCxZW12fp9r40E1mQrf/34g==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/eventemitter": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/eventemitter/-/eventemitter-1.1.0.tgz", + "integrity": "sha512-j9ednRT81vYJ9OfVuXG6ERSTdEL1xVsNgqpkxMsbIabzSo3goCjDIveeGv5d03om39ML71RdmrGNjG5SReBP/Q==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/fetch": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/fetch/-/fetch-1.1.0.tgz", + "integrity": "sha512-lljVXpqXebpsijW71PZaCYeIcE5on1w5DlQy5WH6GLbFryLUrBD4932W/E2BSpfRJWseIL4v/KPgBFxDOIdKpQ==", + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.1", + "@protobufjs/inquire": "^1.1.0" + } + }, + "node_modules/@protobufjs/float": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@protobufjs/float/-/float-1.0.2.tgz", + "integrity": "sha512-Ddb+kVXlXst9d+R9PfTIxh1EdNkgoRe5tOX6t01f1lYWOvJnSPDBlG241QLzcyPdoNTsblLUdujGSE4RzrTZGQ==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/inquire": { + "version": "1.1.1", + "resolved": 
"https://registry.npmjs.org/@protobufjs/inquire/-/inquire-1.1.1.tgz", + "integrity": "sha512-mnzgDV26ueAvk7rsbt9L7bE0SuAoqyuys/sMMrmVcN5x9VsxpcG3rqAUSgDyLp0UZlmNfIbQ4fHfCtreVBk8Ew==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/path": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/@protobufjs/path/-/path-1.1.2.tgz", + "integrity": "sha512-6JOcJ5Tm08dOHAbdR3GrvP+yUUfkjG5ePsHYczMFLq3ZmMkAD98cDgcT2iA1lJ9NVwFd4tH/iSSoe44YWkltEA==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/pool": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/@protobufjs/pool/-/pool-1.1.0.tgz", + "integrity": "sha512-0kELaGSIDBKvcgS4zkjz1PeddatrjYcmMWOlAuAPwAeccUrPHdUqo/J6LiymHHEiJT5NrF1UVwxY14f+fy4WQw==", + "license": "BSD-3-Clause" + }, + "node_modules/@protobufjs/utf8": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/@protobufjs/utf8/-/utf8-1.1.1.tgz", + "integrity": "sha512-oOAWABowe8EAbMyWKM0tYDKi8Yaox52D+HWZhAIJqQXbqe0xI/GV7FhLWqlEKreMkfDjshR5FKgi3mnle0h6Eg==", + "license": "BSD-3-Clause" + }, "node_modules/@react-native-community/cli": { "version": "12.3.7", "resolved": "https://registry.npmjs.org/@react-native-community/cli/-/cli-12.3.7.tgz", @@ -3948,6 +4013,12 @@ "node": ">=8" } }, + "node_modules/flatbuffers": { + "version": "25.9.23", + "resolved": "https://registry.npmjs.org/flatbuffers/-/flatbuffers-25.9.23.tgz", + "integrity": "sha512-MI1qs7Lo4Syw0EOzUl0xjs2lsoeqFku44KpngfIduHBYvzm8h2+7K8YMQh1JtVVVrUvhLpNwqVi4DERegUJhPQ==", + "license": "Apache-2.0" + }, "node_modules/flow-enums-runtime": { "version": "0.0.6", "resolved": "https://registry.npmjs.org/flow-enums-runtime/-/flow-enums-runtime-0.0.6.tgz", @@ -4072,6 +4143,12 @@ "integrity": "sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==", "license": "ISC" }, + "node_modules/guid-typescript": { + "version": "1.0.9", + "resolved": "https://registry.npmjs.org/guid-typescript/-/guid-typescript-1.0.9.tgz", + 
"integrity": "sha512-Y8T4vYhEfwJOTbouREvG+3XDsjr8E3kIr7uf+JZ0BYloFsttiHU0WfvANVsR7TxNUJa/WpCnw/Ino/p+DeBhBQ==", + "license": "ISC" + }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", @@ -4885,6 +4962,12 @@ "node": ">=6" } }, + "node_modules/long": { + "version": "5.3.2", + "resolved": "https://registry.npmjs.org/long/-/long-5.3.2.tgz", + "integrity": "sha512-mNAgZ1GmyNhD7AuqnTG3/VQ26o760+ZYBPKjPvugO8+nLbYfX6TVpJPseBvopbdY+qpZ/lKUnmEc1LeZYS3QAA==", + "license": "Apache-2.0" + }, "node_modules/loose-envify": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/loose-envify/-/loose-envify-1.4.0.tgz", @@ -5661,6 +5744,26 @@ "react-native": "*" } }, + "node_modules/onnxruntime-web": { + "version": "1.26.0", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.26.0.tgz", + "integrity": "sha512-LbRr/8zZt2xilI2smrVQGGKINo0U46i8qJp+UXyMBGfqN7KjnH1BiwCwLwyNIVV4i9CKFv7Sf4PwLKWnT8/bEA==", + "license": "MIT", + "dependencies": { + "flatbuffers": "^25.1.24", + "guid-typescript": "^1.0.9", + "long": "^5.2.3", + "onnxruntime-common": "1.26.0", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" + } + }, + "node_modules/onnxruntime-web/node_modules/onnxruntime-common": { + "version": "1.26.0", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.26.0.tgz", + "integrity": "sha512-qVyMR4lcWgbkc4getFV+GQijsTnbg/siteoqcDwa3sI/LxbrMSNw4ePyvCq/ymdQaRomCA7YuWmhzsswxvymdw==", + "license": "MIT" + }, "node_modules/open": { "version": "6.4.0", "resolved": "https://registry.npmjs.org/open/-/open-6.4.0.tgz", @@ -5908,6 +6011,12 @@ "node": ">=4" } }, + "node_modules/platform": { + "version": "1.3.6", + "resolved": "https://registry.npmjs.org/platform/-/platform-1.3.6.tgz", + "integrity": "sha512-fnWVljUchTro6RiCFvCXBbNhJc2NijN7oIQxbwsyL0buWJPG85v81ehlHI9fXrJsMNgTofEoWIQeClKpgxFLrg==", + "license": "MIT" + }, "node_modules/pretty-format": { "version": "26.6.2", 
"resolved": "https://registry.npmjs.org/pretty-format/-/pretty-format-26.6.2.tgz", @@ -5993,6 +6102,30 @@ "integrity": "sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==", "license": "MIT" }, + "node_modules/protobufjs": { + "version": "7.5.8", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.5.8.tgz", + "integrity": "sha512-dvpCIeLPbXZS/Ete7yLaO7RenOdken2NHKykBXbsaGxZT0UTltcarBciw+A78SRQs9iMAAVpsYA+l8b1hTePIA==", + "hasInstallScript": true, + "license": "BSD-3-Clause", + "dependencies": { + "@protobufjs/aspromise": "^1.1.2", + "@protobufjs/base64": "^1.1.2", + "@protobufjs/codegen": "^2.0.5", + "@protobufjs/eventemitter": "^1.1.0", + "@protobufjs/fetch": "^1.1.0", + "@protobufjs/float": "^1.0.2", + "@protobufjs/inquire": "^1.1.1", + "@protobufjs/path": "^1.1.2", + "@protobufjs/pool": "^1.1.0", + "@protobufjs/utf8": "^1.1.1", + "@types/node": ">=13.7.0", + "long": "^5.0.0" + }, + "engines": { + "node": ">=12.0.0" + } + }, "node_modules/queue": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/queue/-/queue-6.0.2.tgz", diff --git a/package.json b/package.json index e8139d3..e7ff31a 100644 --- a/package.json +++ b/package.json @@ -4,9 +4,24 @@ "description": "On-device text-to-speech for React Native, powered by KittenTTS + ONNX Runtime", "main": "lib/index.js", "types": "lib/index.d.ts", + "browser": { + "./lib/index.js": "./lib/index.web.js", + "./lib/KittenTTS.js": "./lib/KittenTTS.web.js", + "./lib/KittenTTSBundledAssets.js": "./lib/KittenTTSBundledAssets.web.js", + "./lib/KittenTTSConfig.js": "./lib/KittenTTSConfig.web.js", + "./lib/audio/AudioOutput.js": "./lib/audio/AudioOutput.web.js", + "./lib/engine/TTSEngine.js": "./lib/engine/TTSEngine.web.js", + "./lib/loader/ModelDownloader.js": "./lib/loader/ModelDownloader.web.js", + "./lib/loader/NPZLoader.js": "./lib/loader/NPZLoader.web.js", + "./lib/phonemizer/CEPhonemizer.js": "./lib/phonemizer/CEPhonemizer.web.js" + }, "exports": { ".": 
{ "types": "./lib/index.d.ts", + "browser": { + "types": "./lib/index.web.d.ts", + "default": "./lib/index.web.js" + }, "react-native": "./lib/index.js", "require": "./lib/index.js", "default": "./lib/index.js" @@ -92,6 +107,7 @@ }, "dependencies": { "onnxruntime-react-native": "^1.24.3", + "onnxruntime-web": "^1.24.3", "react-native-fs": "npm:@dr.pogodin/react-native-fs@^2.38.2", "pako": "^2.1.0" } diff --git a/src/KittenTTS.web.ts b/src/KittenTTS.web.ts new file mode 100644 index 0000000..1780868 --- /dev/null +++ b/src/KittenTTS.web.ts @@ -0,0 +1,519 @@ +import { + KittenTTSConfig, + OUTPUT_SAMPLE_RATE, + type ResolvedKittenTTSConfig, + resolveConfig, +} from './KittenTTSConfig.web'; +import { + KittenTTSError, + KittenTTSErrorCode, + isKittenTTSError, +} from './KittenTTSError'; +import { KittenTTSResult } from './KittenTTSResult'; +import { KittenModel, speedPrior } from './KittenModel'; +import { KittenVoice } from './KittenVoice'; +import type { KittenWordTiming } from './KittenWordTiming'; +import { TTSEngine } from './engine/TTSEngine.web'; +import { splitSentences } from './engine/SentenceSplitter'; +import { joinTimestamps } from './engine/TimestampJoiner'; +import { loadNPZ, loadNPZData } from './loader/NPZLoader.web'; +import { + clearModelCache as deleteCachedModel, + getModelCacheInfo, + getProvidedModelCacheInfo, + isModelCached as checkModelCached, + type ModelCacheInfo, + type ModelPaths, + type ProgressHandler, + resolveModelPaths, +} from './loader/ModelDownloader.web'; +import { + AudioOutput, + type AudioPlayer, + type AudioPlayOptions, +} from './audio/AudioOutput.web'; + +/** Options for {@link KittenTTS.create}. */ +export interface KittenTTSCreateOptions extends KittenTTSConfig { + /** + * Delete cached model files and download fresh copies before initialising. + * Useful after a failed/interrupted first-run setup. + */ + forceRedownload?: boolean; + + /** + * Audio player for the `speak()` and `play()` methods. 
+ * + * Use {@link createBrowserAudioPlayer} in browser apps, or provide your own + * implementation for frameworks, workers, or Node.js. + * + * @example + * ```typescript + * import { KittenTTS, createBrowserAudioPlayer } from '@kittentts/react-native'; + * + * const tts = await KittenTTS.create({ + * player: createBrowserAudioPlayer(), + * }); + * await tts.speak('Hello!'); + * ``` + */ + player?: AudioPlayer; +} + +/** + * The KittenTTS speech-synthesis engine for React Native Web runtimes. + * + * Downloads the model on first use, initialises ONNX Runtime inference, + * and exposes an async API for generating and playing speech. + * + * @example + * ```typescript + * import { KittenTTS, createBrowserAudioPlayer } from '@kittentts/react-native'; + * + * const tts = await KittenTTS.create({ + * player: createBrowserAudioPlayer(), + * }); + * + * // Generate audio + * const result = await tts.generate('Hello from KittenTTS!'); + * + * // Play through speakers + * await tts.speak('Good morning!'); + * ``` + */ +export class KittenTTS { + /** The configuration this instance was created with. */ + readonly config: ResolvedKittenTTSConfig; + + private engine: TTSEngine; + private audioOutput: AudioOutput; + private disposed = false; + private disposePromise: Promise | null = null; + + private constructor( + engine: TTSEngine, + config: ResolvedKittenTTSConfig, + player?: AudioPlayer, + ) { + this.engine = engine; + this.config = config; + this.audioOutput = new AudioOutput(player); + } + + /** + * Create and initialise a KittenTTS instance. + * + * Downloads all required files if not cached, loads the ONNX model, and + * prepares the engine for inference. + * + * @param options - Configuration and player for this session. + * @param onProgress - Optional callback for download progress [0, 1]. + * @returns A ready-to-use KittenTTS instance. 
+   *
+   * @remarks
+   * Cached downloads that fail to load are deleted and fetched again once.
+   * Caller-supplied `modelFiles` are never repaired: if they fail to load, the
+   * original error is rethrown instead of silently downloading replacements.
+   */
+  static async create(
+    options?: KittenTTSCreateOptions,
+    onProgress?: ProgressHandler,
+  ): Promise<KittenTTS> {
+    const resolved = resolveConfig(options);
+    const hasPhonemizerDownload =
+      typeof resolved.phonemizer.downloadIfNeeded === 'function';
+    const setupProgress = createAggregateProgress(onProgress);
+
+    // Both downloads are started before either is awaited so they run concurrently.
+    const phonemizerDownload = hasPhonemizerDownload
+      ? resolved.phonemizer.downloadIfNeeded?.(
+          resolved.storageDirectory,
+          setupProgress,
+        )
+      : Promise.resolve();
+
+    const modelDownload = resolveModelPaths(
+      resolved.model,
+      resolved.storageDirectory,
+      setupProgress,
+      {
+        modelFiles: resolved.modelFiles,
+        force: options?.forceRedownload ?? false,
+        retries: resolved.downloadRetries,
+        baseURL: resolved.modelBaseURL || undefined,
+        storage: resolved.storage,
+        fetch: resolved.fetch,
+      },
+    );
+
+    const [, downloadedPaths] = await Promise.all([
+      phonemizerDownload,
+      modelDownload,
+    ]);
+    setupProgress(1, { stage: 'complete' });
+
+    let paths = downloadedPaths;
+
+    // Drops the cached copy of the model and downloads it again from scratch.
+    const repairCache = async (): Promise<ModelPaths> => {
+      await deleteCachedModel(
+        resolved.model,
+        resolved.storageDirectory,
+        resolved.storage,
+      );
+      return resolveModelPaths(
+        resolved.model,
+        resolved.storageDirectory,
+        setupProgress,
+        {
+          force: true,
+          retries: resolved.downloadRetries,
+          baseURL: resolved.modelBaseURL || undefined,
+          storage: resolved.storage,
+          fetch: resolved.fetch,
+        },
+      );
+    };
+
+    // Same rule as the engine fallback below: never replace caller-supplied
+    // model files with a network download. When a repair does happen, keep
+    // `paths` in sync so the ONNX source used next comes from the repaired
+    // download rather than from the pre-repair result.
+    let voices = resolved.modelFiles
+      ? await loadVoicesFromModelPaths(paths)
+      : await loadVoicesWithCacheRepair(paths, async () => {
+          paths = await repairCache();
+          return paths;
+        });
+
+    let engine: TTSEngine;
+    try {
+      engine = await TTSEngine.create(resolveOnnxModelSource(paths), voices, resolved);
+    } catch (error) {
+      if (resolved.modelFiles || !isRepairableModelCacheError(error)) throw error;
+      // The cached model could not be loaded: fetch fresh copies and retry once.
+      paths = await repairCache();
+      voices = await loadVoicesFromModelPaths(paths);
+      engine = await TTSEngine.create(resolveOnnxModelSource(paths), voices, resolved);
+    }
+
+    return new KittenTTS(engine, resolved, options?.player);
+  }
+
+  /**
+   * Synthesise speech for the given text.
+ * + * @param text - The English text to synthesise. Must not be empty. + * @param voice - The voice to use. Defaults to the config's `defaultVoice`. + * @param speed - Speed multiplier (0.5--2.0). Defaults to the config's `speed`. + * @returns A {@link KittenTTSResult} containing PCM samples and metadata. + */ + async generate( + text: string, + voice?: KittenVoice, + speed?: number, + ): Promise { + if (this.disposed) throw KittenTTSError.engineNotReady(); + + const trimmed = text.trim(); + if (!trimmed) throw KittenTTSError.emptyInput(); + + const selectedVoice = voice ?? this.config.defaultVoice; + const selectedSpeed = Math.min(Math.max(speed ?? this.config.speed, 0.5), 2.0); + + const output = await this.engine.generate( + trimmed, + selectedVoice, + selectedSpeed, + ); + const effectiveSpeed = selectedSpeed * speedPrior(this.config.model, selectedVoice); + const wordTimings = normalizeWordTimingsToDuration( + joinTimestamps(trimmed, output.phonemes, output.durations), + output.samples.length / OUTPUT_SAMPLE_RATE, + ); + + return new KittenTTSResult( + output.samples, + OUTPUT_SAMPLE_RATE, + selectedVoice, + effectiveSpeed, + trimmed, + wordTimings, + ); + } + + /** + * Synthesise speech sentence by sentence. + * + * This is the streaming counterpart to {@link generate}. It yields each + * {@link KittenTTSResult} as soon as that sentence is ready, which lets apps + * start playback before a long text has fully generated. + */ + async *generateStreaming( + text: string, + voice?: KittenVoice, + speed?: number, + ): AsyncGenerator { + if (this.disposed) throw KittenTTSError.engineNotReady(); + + const trimmed = text.trim(); + if (!trimmed) throw KittenTTSError.emptyInput(); + + const selectedVoice = voice ?? this.config.defaultVoice; + const selectedSpeed = Math.min(Math.max(speed ?? 
this.config.speed, 0.5), 2.0); + for (const sentence of splitSentences(trimmed)) { + yield await this.generate(sentence, selectedVoice, selectedSpeed); + } + } + + /** + * Synthesise and play speech through the device speakers. + * + * Requires an {@link AudioPlayer} to be passed via `KittenTTS.create({ player })`. + * + * @param text - The English text to synthesise. + * @param voice - The voice to use. + * @param speed - Speed multiplier (0.5--2.0). + * @returns The generated {@link KittenTTSResult}. + */ + async speak( + text: string, + voice?: KittenVoice, + speed?: number, + ): Promise { + const result = await this.generate(text, voice, speed); + await this.play(result); + return result; + } + + /** + * Play a previously generated result. + * + * Use this when an app needs to inspect metadata such as `wordTimings` before + * playback starts. + */ + async play( + result: KittenTTSResult, + options: AudioPlayOptions = {}, + ): Promise { + if (this.disposed) throw KittenTTSError.engineNotReady(); + await this.audioOutput.play(result.samples, result.sampleRate, options); + } + + /** Stop any currently active audio playback. */ + async stopSpeaking(): Promise { + await this.audioOutput.stop(); + } + + /** Check if the model files are already cached on disk. */ + static async isModelCached(config?: KittenTTSConfig): Promise { + const resolved = resolveConfig(config); + if (resolved.modelFiles) { + return (await getProvidedModelCacheInfo( + resolved.model, + resolved.modelFiles, + )).isCached; + } + return checkModelCached( + resolved.model, + resolved.storageDirectory, + resolved.storage, + ); + } + + /** Detailed cache state for first-run UI. 
*/ + static async getModelCacheInfo( + config?: KittenTTSConfig, + ): Promise { + const resolved = resolveConfig(config); + if (resolved.modelFiles) { + return getProvidedModelCacheInfo(resolved.model, resolved.modelFiles); + } + return getModelCacheInfo( + resolved.model, + resolved.storageDirectory, + resolved.storage, + ); + } + + /** Alias for `isModelCached()` with clearer app-facing wording. */ + static async isModelDownloaded(config?: KittenTTSConfig): Promise { + return KittenTTS.isModelCached(config); + } + + /** Delete cached files for the selected model. */ + static async clearModelCache(config?: KittenTTSConfig): Promise { + const resolved = resolveConfig(config); + if (resolved.modelFiles) return; + await deleteCachedModel( + resolved.model, + resolved.storageDirectory, + resolved.storage, + ); + } + + /** Delete and download the selected model again. */ + static async redownloadModel( + config?: KittenTTSConfig, + onProgress?: ProgressHandler, + ): Promise { + const resolved = resolveConfig(config); + if (resolved.modelFiles) { + await resolveModelPaths( + resolved.model, + resolved.storageDirectory, + onProgress, + { + modelFiles: resolved.modelFiles, + storage: resolved.storage, + fetch: resolved.fetch, + }, + ); + return; + } + await deleteCachedModel( + resolved.model, + resolved.storageDirectory, + resolved.storage, + ); + await resolveModelPaths( + resolved.model, + resolved.storageDirectory, + onProgress, + { + force: true, + retries: resolved.downloadRetries, + baseURL: resolved.modelBaseURL || undefined, + storage: resolved.storage, + fetch: resolved.fetch, + }, + ); + } + + /** Download model and phonemizer assets without creating a long-lived engine. 
*/ + static async predownload( + config?: KittenTTSConfig, + onProgress?: ProgressHandler, + ): Promise { + const resolved = resolveConfig(config); + const hasPhonemizerDownload = + typeof resolved.phonemizer.downloadIfNeeded === 'function'; + const setupProgress = createAggregateProgress(onProgress); + + const phonemizerDownload = hasPhonemizerDownload + ? resolved.phonemizer.downloadIfNeeded?.( + resolved.storageDirectory, + setupProgress, + ) + : Promise.resolve(); + + const modelDownload = resolveModelPaths( + resolved.model, + resolved.storageDirectory, + setupProgress, + { + modelFiles: resolved.modelFiles, + retries: resolved.downloadRetries, + baseURL: resolved.modelBaseURL || undefined, + storage: resolved.storage, + fetch: resolved.fetch, + }, + ); + + await Promise.all([phonemizerDownload, modelDownload]); + setupProgress(1, { stage: 'complete' }); + } + + /** @deprecated Use `predownload()`. This method does not keep an engine warm. */ + static async prewarm( + config?: KittenTTSConfig, + onProgress?: ProgressHandler, + ): Promise { + await KittenTTS.predownload(config, onProgress); + } + + /** Release the ONNX session and free resources. 
*/ + async dispose(): Promise { + if (this.disposePromise) return this.disposePromise; + this.disposed = true; + this.disposePromise = (async () => { + await this.audioOutput.stop().catch(() => {}); + await this.engine.dispose(); + this.config.phonemizer.dispose?.(); + })(); + return this.disposePromise; + } +} + +function normalizeWordTimingsToDuration( + wordTimings: readonly KittenWordTiming[], + audioDuration: number, +): KittenWordTiming[] { + if (wordTimings.length === 0 || audioDuration <= 0) return [...wordTimings]; + + const lastEndTime = wordTimings[wordTimings.length - 1].endTime; + if (lastEndTime <= 0) return [...wordTimings]; + + const scale = audioDuration / lastEndTime; + return wordTimings.map(timing => ({ + ...timing, + startTime: clampTime(timing.startTime * scale, audioDuration), + endTime: clampTime(timing.endTime * scale, audioDuration), + })); +} + +function clampTime(value: number, audioDuration: number): number { + return Math.max(0, Math.min(audioDuration, value)); +} + +function resolveOnnxModelSource(paths: ModelPaths): string | Uint8Array { + if (paths.onnxData) return paths.onnxData; + if (paths.onnxPath) return paths.onnxPath; + throw KittenTTSError.modelFileNotFound(''); +} + +async function loadVoicesFromModelPaths( + paths: ModelPaths, +): Promise>> { + if (paths.voicesData) return loadNPZData(paths.voicesData); + if (paths.voicesPath) return loadNPZ(paths.voicesPath); + throw KittenTTSError.voicesFileNotFound(''); +} + +async function loadVoicesWithCacheRepair( + paths: ModelPaths, + repairCache: () => Promise, +): Promise>> { + try { + return await loadVoicesFromModelPaths(paths); + } catch (error) { + if (!isRepairableModelCacheError(error)) throw error; + const repairedPaths = await repairCache(); + return loadVoicesFromModelPaths(repairedPaths); + } +} + +function isRepairableModelCacheError(error: unknown): boolean { + return ( + isKittenTTSError(error) && + (error.code === KittenTTSErrorCode.InvalidModelData || + error.code 
=== KittenTTSErrorCode.VoicesFileNotFound || + error.code === KittenTTSErrorCode.ModelFileNotFound || + error.code === KittenTTSErrorCode.InferenceFailed) + ); +} + +function createAggregateProgress( + progressHandler?: ProgressHandler, +): ProgressHandler { + const files = new Map(); + + return (progress, info) => { + if (info?.asset && info.contentLength && info.contentLength > 0) { + files.set(info.asset, { + bytesWritten: Math.max(0, Math.min(info.bytesWritten ?? 0, info.contentLength)), + contentLength: info.contentLength, + }); + } + + const totalBytes = Array.from(files.values()).reduce( + (sum, file) => sum + file.contentLength, + 0, + ); + const writtenBytes = Array.from(files.values()).reduce( + (sum, file) => sum + file.bytesWritten, + 0, + ); + + if (totalBytes > 0) { + progressHandler?.( + Math.max(0, Math.min(1, writtenBytes / totalBytes)), + info, + ); + return; + } + + progressHandler?.(progress, info); + }; +} diff --git a/src/KittenTTSBundledAssets.web.ts b/src/KittenTTSBundledAssets.web.ts new file mode 100644 index 0000000..f9963a6 --- /dev/null +++ b/src/KittenTTSBundledAssets.web.ts @@ -0,0 +1,160 @@ +import { KittenModel } from './KittenModel'; +import { CEPhonemizer } from './phonemizer/CEPhonemizer.web'; +import type { KittenTTSConfig } from './KittenTTSConfig.web'; +import type { KittenPhonemizerProtocol } from './phonemizer/types'; + +export interface KittenTTSBundledModelFiles { + onnx: string; + voices: string; +} + +export interface KittenTTSBundledAssetsManifestV1 { + version: 1; + model: KittenModel | string; + files: KittenTTSBundledModelFiles & { + phonemizerRules: string; + phonemizerList: string; + }; +} + +export interface KittenTTSBundledAssetsManifestV2 { + version: 2; + defaultModel: KittenModel | string; + models: Record; + files: { + phonemizerRules: string; + phonemizerList: string; + }; +} + +export type KittenTTSBundledAssetsManifest = + | KittenTTSBundledAssetsManifestV1 + | KittenTTSBundledAssetsManifestV2; + +export 
/**
 * Create a KittenTTS web config from the manifest generated by the CLI.
 *
 * Asset bytes are fetched from `basePath` when provided, otherwise from
 * `assetRoot`. The returned config carries the model and voices as in-memory
 * bytes (`modelFiles.onnxData` / `modelFiles.voicesData`).
 *
 * When `options.phonemizer` is supplied it is used as-is and the manifest's
 * phonemizer rule/list files are not fetched, so a missing or unreachable
 * rules file cannot fail a config that never needed it.
 *
 * @param manifest - Bundled-assets manifest (version 1 or 2).
 * @param options - Asset location, model selection, phonemizer override, and
 *   any remaining config fields, which are copied into the result.
 * @returns A config ready to pass to KittenTTS.
 * @throws Error when the selected model is unknown or absent from the
 *   manifest, or when an asset request returns a non-OK HTTP status.
 */
export async function createBundledAssetConfig(
  manifest: KittenTTSBundledAssetsManifest,
  options: CreateBundledAssetConfigOptions = {},
): Promise<KittenTTSConfig> {
  const {
    basePath,
    assetRoot = 'kittentts',
    model: selectedModel,
    phonemizer,
    ...rest
  } = options;
  const root = stripTrailingSlash(basePath ?? assetRoot);
  const model = parseModel(selectedModel ?? defaultManifestModel(manifest));
  const modelFiles = manifestModelFiles(manifest, model);

  // Only download the phonemizer tables when we actually have to build one.
  const loadPhonemizer = async (): Promise<KittenPhonemizerProtocol> => {
    if (phonemizer) return phonemizer;
    const [rulesText, listText] = await Promise.all([
      readBundledAssetText(root, manifest.files.phonemizerRules),
      readBundledAssetText(root, manifest.files.phonemizerList),
    ]);
    return new CEPhonemizer({ rulesText, listText });
  };

  // All requests are started together so they download in parallel.
  const [onnxData, voicesData, resolvedPhonemizer] = await Promise.all([
    readBundledAssetBinary(root, modelFiles.onnx),
    readBundledAssetBinary(root, modelFiles.voices),
    loadPhonemizer(),
  ]);

  return {
    ...rest,
    model,
    modelFiles: {
      onnxData,
      voicesData,
    },
    phonemizer: resolvedPhonemizer,
  };
}
/**
 * Join a base directory/URL prefix and a file path with exactly one `/`.
 *
 * An empty `basePath` returns `filePath` untouched. A `basePath` that already
 * ends in `/` (for example the site root `/`) does not get a second slash,
 * which would otherwise produce a protocol-relative `//file` URL.
 */
function joinPath(basePath: string, filePath: string): string {
  if (!basePath) return filePath;
  const separator = basePath.endsWith('/') ? '' : '/';
  return `${basePath}${separator}${stripLeadingSlash(filePath)}`;
}

/** Remove every leading `/` from `filePath`. */
function stripLeadingSlash(filePath: string): string {
  return filePath.replace(/^\/+/, '');
}

/**
 * Remove trailing `/` characters from `filePath`.
 *
 * A path consisting only of slashes is the root and is returned as `/`, so
 * that `basePath: '/'` stays root-absolute instead of collapsing to an empty
 * (page-relative) prefix.
 */
function stripTrailingSlash(filePath: string): string {
  const stripped = filePath.replace(/\/+$/, '');
  if (stripped === '' && filePath.startsWith('/')) return '/';
  return stripped;
}
+ * + * @example + * ```typescript + * const config: KittenTTSConfig = { + * model: KittenModel.Nano, + * defaultVoice: KittenVoice.Luna, + * speed: 1.1, + * }; + * const tts = await KittenTTS.create(config); + * ``` + */ +export interface KittenTTSConfig { + /** The model variant to use. Defaults to {@link KittenModel.Nano}. */ + model?: KittenModel; + + /** Default voice when `voice` is omitted from generate/speak calls. Defaults to {@link KittenVoice.Bella}. */ + defaultVoice?: KittenVoice; + + /** Default speed multiplier (0.5--2.0). Defaults to 1.0 (natural speed). */ + speed?: number; + + /** + * Root directory where downloaded SDK assets are cached. + * Model files live under `//`. + */ + storageDirectory?: string; + + /** + * Override the model file host. The URL must point at a directory containing + * the ONNX file and voices.npz for the selected model. + */ + modelBaseURL?: string; + + /** + * Local ONNX model and voices.npz paths. When provided, KittenTTS uses these + * files directly and skips model downloads/cache lookup. + */ + modelFiles?: KittenTTSModelFiles; + + /** Total download attempts per model file before failing. Defaults to 4. */ + downloadRetries?: number; + + /** Number of ONNX Runtime intra-op threads. Defaults to 4. */ + ortNumThreads?: number; + + /** Maximum tokens per inference chunk. Long texts are split to prevent OOM. Defaults to 400. */ + maxTokensPerChunk?: number; + + /** Trim trailing near-silence from generated chunks. Defaults to true. */ + trimTrailingSilence?: boolean; + + /** Amplitude threshold used for trailing silence trimming. Defaults to 0.005. */ + silenceThreshold?: number; + + /** Maximum trailing silence to trim from each chunk, in milliseconds. Defaults to 250. */ + maxSilenceTrimMs?: number; + + /** Text-to-IPA phonemizer. Defaults to the JS-compiled CEPhonemizer. */ + phonemizer?: KittenPhonemizerProtocol; + + /** Asset cache implementation. Defaults to Cache API in browsers and filesystem cache in Node. 
/**
 * Resolve a user-supplied config into one with every default applied.
 *
 * Numeric options are sanitised: `NaN` (and `null`) fall back to the default,
 * `speed` is clamped to 0.5–2.0, and the count-like options
 * (`downloadRetries`, `ortNumThreads`, `maxTokensPerChunk`) are floored to
 * whole numbers and raised to their minimums.
 *
 * The asset storage and fetch implementation are resolved once and handed to
 * both the returned config and the default phonemizer, so the two cannot end
 * up with different instances.
 *
 * @param config - Partial user config; may be omitted entirely.
 * @returns The fully-populated config.
 */
export function resolveConfig(config?: KittenTTSConfig): ResolvedKittenTTSConfig {
  // `??` does not catch NaN, and NaN propagates through Math.min/Math.max.
  const numberOr = (value: number | null | undefined, fallback: number): number =>
    value == null || Number.isNaN(value) ? fallback : value;

  const storageDirectory = config?.storageDirectory ?? 'KittenTTS';
  const storage = config?.storage ?? defaultAssetStorage(storageDirectory);
  const fetchImpl = config?.fetch ?? globalThis.fetch?.bind(globalThis);

  return {
    model: config?.model ?? KittenModel.Nano,
    defaultVoice: config?.defaultVoice ?? KittenVoice.Bella,
    speed: Math.min(Math.max(numberOr(config?.speed, 1.0), 0.5), 2.0),
    storageDirectory,
    modelBaseURL: config?.modelBaseURL ?? '',
    modelFiles: config?.modelFiles,
    downloadRetries: Math.max(1, Math.floor(numberOr(config?.downloadRetries, 4))),
    // Thread and token counts are used as integers downstream.
    ortNumThreads: Math.max(1, Math.floor(numberOr(config?.ortNumThreads, 4))),
    maxTokensPerChunk: Math.max(50, Math.floor(numberOr(config?.maxTokensPerChunk, 400))),
    trimTrailingSilence: config?.trimTrailingSilence ?? true,
    silenceThreshold: Math.max(0, numberOr(config?.silenceThreshold, 0.005)),
    maxSilenceTrimMs: Math.max(0, numberOr(config?.maxSilenceTrimMs, 250)),
    phonemizer:
      config?.phonemizer ??
      defaultPhonemizer({ ...config, storage, fetch: fetchImpl }),
    storage,
    fetch: fetchImpl,
    ortWasmPath: config?.ortWasmPath,
  };
}
/**
 * Create an {@link AudioPlayer} that plays generated PCM through an HTML
 * audio element.
 *
 * Samples are encoded to WAV, wrapped in a blob URL, and played with
 * `new Audio(url)`. The blob URL is revoked as soon as playback ends, fails,
 * or is stopped.
 *
 * `play()` resolves when playback finishes or is stopped (by `stop()` or by a
 * newer `play()` call), and rejects when the browser reports a playback error
 * or refuses to start. It is always settled, so callers awaiting it are never
 * left hanging after a stop.
 */
export function createBrowserAudioPlayer(): AudioPlayer {
  interface ActivePlayback {
    audio: HTMLAudioElement;
    url: string;
    released: boolean;
    resolve: () => void;
    reject: (reason: unknown) => void;
  }

  let active: ActivePlayback | null = null;

  // Detach handlers and free the blob URL exactly once per playback.
  // Returns false when this playback was already released, so late events
  // from an old element can never touch a newer playback.
  const release = (playback: ActivePlayback): boolean => {
    if (playback.released) return false;
    playback.released = true;
    playback.audio.onplaying = null;
    playback.audio.onended = null;
    playback.audio.onerror = null;
    URL.revokeObjectURL(playback.url);
    if (active === playback) active = null;
    return true;
  };

  // Stop whatever is playing and settle its pending play() promise.
  const halt = (): void => {
    const playback = active;
    if (!playback) return;
    playback.audio.pause();
    playback.audio.currentTime = 0;
    if (release(playback)) playback.resolve();
  };

  return {
    async play(
      samples: Float32Array,
      sampleRate: number,
      options: AudioPlayOptions = {},
    ): Promise<void> {
      halt();
      const wav = WAVEncoder.encode(samples, sampleRate);
      const blob = new Blob([toArrayBuffer(wav) as any], { type: 'audio/wav' } as any);
      const url = URL.createObjectURL(blob);
      const audio = new Audio(url);

      return new Promise<void>((resolve, reject) => {
        const playback: ActivePlayback = {
          audio,
          url,
          released: false,
          resolve,
          reject,
        };
        active = playback;

        let started = false;
        audio.onplaying = () => {
          // `playing` fires again after buffering stalls; notify only once.
          if (started) return;
          started = true;
          options.onPlaybackStart?.();
        };
        audio.onended = () => {
          if (release(playback)) resolve();
        };
        audio.onerror = () => {
          if (release(playback)) reject(new Error('Browser audio playback failed.'));
        };
        audio.play().catch((error: unknown) => {
          // pause() during start-up aborts play(); if stop already released
          // this playback the promise is settled and the abort is ignored.
          if (release(playback)) reject(error);
        });
      });
    },

    async stop(): Promise<void> {
      halt();
    },
  };
}
*/ + phonemes: string; +} + +type OrtRuntime = typeof Ort; + +type BrowserDocument = { + createElement(tagName: 'script'): { + async: boolean; + src: string; + onload: (() => void) | null; + onerror: (() => void) | null; + }; + head: { + appendChild(element: unknown): void; + }; +}; + +const DEFAULT_ORT_WEB_VERSION = '1.26.0'; + +let browserOrtPromise: Promise | undefined; + +/** + * Internal ONNX inference engine. + * + * Orchestrates: text -> TextPreprocessor -> Phonemizer -> TextCleaner -> ONNX -> Float32 PCM + */ +export class TTSEngine { + private ort: OrtRuntime; + private session: Ort.InferenceSession; + private voices: VoiceEmbeddings; + private config: ResolvedKittenTTSConfig; + private waveformOutputName: string | undefined; + private durationOutputName: string | undefined; + private disposed = false; + + private constructor( + ortRuntime: OrtRuntime, + session: Ort.InferenceSession, + voices: VoiceEmbeddings, + config: ResolvedKittenTTSConfig, + waveformOutputName: string | undefined, + durationOutputName: string | undefined, + ) { + this.ort = ortRuntime; + this.session = session; + this.voices = voices; + this.config = config; + this.waveformOutputName = waveformOutputName; + this.durationOutputName = durationOutputName; + } + + /** + * Create a new TTSEngine by loading the ONNX model and voice embeddings. + */ + static async create( + model: string | Uint8Array, + voices: VoiceEmbeddings, + config: ResolvedKittenTTSConfig, + ): Promise { + try { + const ort = await loadOnnxRuntime(config); + await configureOnnxRuntime(ort, config); + const sessionOptions = { + executionProviders: ['wasm'], + graphOptimizationLevel: 'all', + intraOpNumThreads: config.ortNumThreads, + } as const; + const session = await ort.InferenceSession.create( + model as Uint8Array, + sessionOptions, + ); + const outputNames = session.outputNames ?? []; + const waveformOutputName = outputNames.includes('waveform') + ? 
'waveform' + : outputNames[0]; + const durationOutputName = outputNames.includes('duration') + ? 'duration' + : undefined; + return new TTSEngine( + ort, + session, + voices, + config, + waveformOutputName, + durationOutputName, + ); + } catch (error) { + throw KittenTTSError.inferenceFailed( + `Could not initialise ONNX Runtime: ${errorMessage(error)}`, + error, + ); + } + } + + /** + * Synthesise speech and return PCM samples plus optional timing metadata. + */ + async generate( + text: string, + voice: KittenVoice, + speed: number, + ): Promise { + if (this.disposed) throw KittenTTSError.engineNotReady(); + + const embedding = this.voices[voice]; + if (!embedding) { + throw KittenTTSError.noVoiceEmbedding(voice); + } + + const normalised = preprocess(text); + if (!normalised) throw KittenTTSError.emptyInput(); + + let phonemes: string; + try { + phonemes = await this.config.phonemizer.phonemize(normalised); + } catch (error) { + if (isKittenTTSError(error)) throw error; + throw KittenTTSError.phonemizerFailed(errorMessage(error), error); + } + + try { + const tokens = TextCleaner.encode(phonemes); + const chunks = this.splitIntoChunks(tokens); + const effectiveSpeed = speed * speedPrior(this.config.model, voice); + const singleChunk = chunks.length === 1; + + const allChunks: Float32Array[] = []; + let durations: number[] = []; + for (const chunk of chunks) { + const chunkTextLength = Math.max(0, chunk.length - 3); + const output = await this.runChunk( + chunk, + embedding, + chunkTextLength, + effectiveSpeed, + ); + allChunks.push(output.samples); + if (singleChunk) { + durations = output.durations; + } + } + + // Concatenate all chunks + const totalLength = allChunks.reduce((sum, c) => sum + c.length, 0); + if (totalLength === 0) throw KittenTTSError.emptyOutput(); + + const result = new Float32Array(totalLength); + let offset = 0; + for (const chunk of allChunks) { + result.set(chunk, offset); + offset += chunk.length; + } + return { samples: result, 
durations, phonemes }; + } catch (error) { + if (isKittenTTSError(error)) throw error; + throw KittenTTSError.inferenceFailed(errorMessage(error), error); + } + } + + private async runChunk( + tokens: number[], + embedding: { rows: number; cols: number; data: Float32Array }, + phonemeLength: number, + speed: number, + ): Promise<{ samples: Float32Array; durations: number[] }> { + // Get style vector for this text length + const rowIdx = Math.min(phonemeLength, embedding.rows - 1); + const styleVec = embedding.data.slice( + rowIdx * embedding.cols, + (rowIdx + 1) * embedding.cols, + ); + + // Create tensors + const inputIds = new this.ort.Tensor( + 'int64', + BigInt64Array.from(tokens.map(t => BigInt(t))), + [1, tokens.length], + ); + const styleTensor = new this.ort.Tensor('float32', styleVec, [1, styleVec.length]); + const speedTensor = new this.ort.Tensor('float32', Float32Array.of(speed), [1]); + + const feeds = { + input_ids: inputIds, + style: styleTensor, + speed: speedTensor, + }; + const fetches = this.createOutputFetches(); + const results = fetches + ? 
await this.session.run(feeds, fetches) + : await this.session.run(feeds); + + const outputKey = this.resolveWaveformOutputKey(results); + if (!outputKey) throw KittenTTSError.emptyOutput(); + + const outputTensor = results[outputKey]; + const samples = outputTensor.data as Float32Array; + if (samples.length === 0) throw KittenTTSError.emptyOutput(); + + return { + samples: this.trimTrailingSilence(samples), + durations: this.readDurations(results, outputKey), + }; + } + + private createOutputFetches(): Record | undefined { + const outputNames = [ + this.waveformOutputName, + this.durationOutputName, + ].filter((name): name is string => Boolean(name)); + + if (outputNames.length === 0) return undefined; + return Object.fromEntries(outputNames.map(name => [name, null])); + } + + private resolveWaveformOutputKey( + results: Awaited>, + ): string | undefined { + if (this.waveformOutputName && results[this.waveformOutputName]) { + return this.waveformOutputName; + } + + const keys = Object.keys(results); + return ( + keys.find(key => results[key].data instanceof Float32Array) ?? + keys.find(key => key !== this.durationOutputName) ?? + keys[0] + ); + } + + private readDurations( + results: Awaited>, + waveformOutputKey: string, + ): number[] { + const durationKey = + this.durationOutputName && results[this.durationOutputName] + ? this.durationOutputName + : Object.keys(results).find(key => { + if (key === waveformOutputKey) return false; + const data = results[key].data; + return ( + data instanceof BigInt64Array || + data instanceof BigUint64Array || + data instanceof Int32Array || + data instanceof Uint32Array + ); + }); + + if (!durationKey) return []; + + const durationTensor = results[durationKey]; + if (!durationTensor) return []; + + return Array.from(durationTensor.data as ArrayLike, value => + typeof value === 'bigint' ? 
/**
 * Load the ONNX Runtime Web module for the current environment.
 *
 * Outside a browser the `onnxruntime-web/wasm` package is imported
 * dynamically. In a browser an existing `globalThis.ort` is reused; otherwise
 * a `<script>` tag is injected and the result is shared through the
 * module-level `browserOrtPromise`, so concurrent callers trigger one load.
 *
 * If that script load fails, the cached promise is cleared so that a later
 * call can try again instead of receiving the same rejection forever.
 *
 * @param config - Resolved config; used to pick the script URL.
 * @returns The ONNX Runtime module object.
 * @throws Error when running in a browser-like runtime without a `document`,
 *   when the script fails to load, or when it loads without defining
 *   `globalThis.ort`.
 */
async function loadOnnxRuntime(config: ResolvedKittenTTSConfig): Promise<OrtRuntime> {
  if (!isBrowserRuntime()) {
    // Keep this import opaque to web bundlers. A plain dynamic import causes
    // Metro web to parse ONNX Runtime's generated wasm loader, which fails on
    // its dynamic import pattern before the SDK can configure the Node path.
    const importModule = new Function(
      'specifier',
      'return import(specifier)',
    ) as (specifier: string) => Promise<OrtRuntime>;
    return importModule('onnxruntime-web/wasm');
  }

  const scope = globalThis as {
    document?: BrowserDocument;
    ort?: OrtRuntime;
  };

  // The host page may already have loaded ONNX Runtime itself.
  if (scope.ort?.InferenceSession) return scope.ort;

  const browserDocument = scope.document;
  if (!browserDocument) {
    throw new Error('Browser ONNX Runtime requires a document to load its script.');
  }

  if (!browserOrtPromise) {
    const pending = new Promise<OrtRuntime>((resolve, reject) => {
      const script = browserDocument.createElement('script');
      script.async = true;
      script.src = defaultOrtScriptURL(config);
      script.onload = () => {
        if (scope.ort?.InferenceSession) {
          resolve(scope.ort);
        } else {
          reject(new Error('ONNX Runtime script loaded without exposing globalThis.ort.'));
        }
      };
      script.onerror = () => {
        reject(new Error(`Failed to load ONNX Runtime script: ${script.src}`));
      };
      browserDocument.head.appendChild(script);
    });
    browserOrtPromise = pending;

    // Do not let one failed load (offline, blocked CDN) poison every later
    // call: forget the rejected promise so the next call injects a new script.
    void pending.catch(() => {
      if (browserOrtPromise === pending) browserOrtPromise = undefined;
    });
  }

  return browserOrtPromise;
}
/**
 * Report whether the code is running in a browser-like environment.
 *
 * True when a global `window` exists, or when a global `self` exists and
 * `process.versions.node` does not.
 */
function isBrowserRuntime(): boolean {
  const host = globalThis as {
    window?: unknown;
    self?: unknown;
    process?: { versions?: { node?: string } };
  };

  if (typeof host.window !== 'undefined') return true;
  if (typeof host.self === 'undefined') return false;

  // `self` without a Node version string: a worker-style browser scope.
  return typeof host.process?.versions?.node === 'undefined';
}
KittenModel, modelDisplayName, approximateDownloadBytes } from './KittenModel'; +export { KittenVoice, ALL_VOICES, voiceDisplayName, isFemaleVoice } from './KittenVoice'; +export { OUTPUT_SAMPLE_RATE } from './KittenTTSConfig.web'; +export type { KittenTTSConfig, KittenTTSModelFiles } from './KittenTTSConfig.web'; +export { bundledAssetModels, createBundledAssetConfig } from './KittenTTSBundledAssets.web'; +export type { + CreateBundledAssetConfigOptions, + KittenTTSBundledAssetsManifest, +} from './KittenTTSBundledAssets.web'; +export type { + DownloadProgressInfo, + ModelCacheInfo, + ProgressHandler, +} from './loader/ModelDownloader.web'; +export type { AssetStorage } from './storage/AssetStorage'; +export { + BrowserCacheAssetStorage, + MemoryAssetStorage, + NodeFileAssetStorage, + defaultAssetStorage, +} from './storage/AssetStorage'; +export type { KittenPhonemizerProtocol } from './phonemizer/types'; +export { CEPhonemizer } from './phonemizer/CEPhonemizer.web'; +export { WAVEncoder } from './audio/WAVEncoder'; +export { + createBrowserAudioPlayer, + createExpoAudioPlayer, + createRNSoundPlayer, +} from './audio/AudioOutput.web'; +export type { AudioPlayer, AudioPlayOptions } from './audio/AudioOutput.web'; diff --git a/src/loader/ModelDownloader.web.ts b/src/loader/ModelDownloader.web.ts new file mode 100644 index 0000000..8197f6b --- /dev/null +++ b/src/loader/ModelDownloader.web.ts @@ -0,0 +1,511 @@ +import { + KittenModel, + huggingFaceBaseURL, + onnxFileName, + voicesFileName, +} from '../KittenModel'; +import { + KittenTTSError, + errorMessage, + isKittenTTSError, +} from '../KittenTTSError'; +import { + type AssetStorage, + defaultAssetStorage, + isNodeRuntime, +} from '../storage/AssetStorage'; + +export type DownloadProgressStage = + | 'checking-cache' + | 'cached' + | 'downloading' + | 'retrying' + | 'complete'; + +export type DownloadProgressAsset = + | 'model' + | 'voices' + | 'phonemizer-rules' + | 'phonemizer-list'; + +export interface 
/**
 * Describe the cache state of a model's two files.
 *
 * Builds the storage keys `<dir>/<onnx file>` and `<dir>/<voices file>` for
 * the model and checks both against `storage` in parallel.
 *
 * @param model - Model to look up.
 * @param storageDir - Root cache directory; also used to build the default storage.
 * @param storage - Asset storage to query. Defaults to `defaultAssetStorage(storageDir)`.
 * @returns The resolved keys plus per-file existence flags; `isCached` is true
 *   only when both files exist.
 */
export async function getModelCacheInfo(
  model: KittenModel,
  storageDir: string,
  storage = defaultAssetStorage(storageDir),
): Promise<ModelCacheInfo> {
  const directory = resolveDir(model, storageDir);
  const onnxPath = `${directory}/${onnxFileName(model)}`;
  const voicesPath = `${directory}/${voicesFileName(model)}`;

  const [onnxExists, voicesExists] = await Promise.all([
    hasStorageKey(storage, onnxPath),
    hasStorageKey(storage, voicesPath),
  ]);

  return {
    model,
    directory,
    onnxPath,
    voicesPath,
    onnxExists,
    voicesExists,
    isCached: onnxExists && voicesExists,
  };
}
/**
 * Ensure a model's files are in storage, downloading any that are missing.
 *
 * Concurrent calls for the same model, directory, base URL, force flag and
 * retry count share one in-flight download through `activeDownloads`. A
 * caller that joins an existing download receives a single final
 * `complete` progress event rather than incremental updates.
 *
 * @param model - Model to fetch.
 * @param storageDir - Root cache directory.
 * @param progressHandler - Optional progress callback.
 * @param options - Download options. An empty or omitted `baseURL` falls back
 *   to the default host for the model; trailing slashes are removed so file
 *   URLs never contain `//`.
 * @returns The resolved model paths/bytes.
 */
export async function downloadModelIfNeeded(
  model: KittenModel,
  storageDir: string,
  progressHandler?: ProgressHandler,
  options: ModelDownloadOptions = {},
): Promise<ModelPaths> {
  const storage = options.storage ?? defaultAssetStorage(storageDir);
  const retryCount = normalizeRetryCount(options.retries);
  // `||` rather than `??`: an empty string means "no override" (it is the
  // resolved-config default) and must not become a host-less URL.
  const baseURL = (options.baseURL || huggingFaceBaseURL(model)).replace(/\/+$/, '');
  const dir = resolveDir(model, storageDir);
  const cacheKey = `${model}:${dir}:${baseURL}:${options.force ? 'force' : 'cached'}:${retryCount}`;

  const activeDownload = activeDownloads.get(cacheKey);
  if (activeDownload) {
    const paths = await activeDownload;
    progressHandler?.(1, { stage: 'complete' });
    return paths;
  }

  const download = downloadModelFilesIfNeeded(model, dir, progressHandler, {
    force: options.force ?? false,
    retries: retryCount,
    baseURL,
    storage,
    fetch: options.fetch,
  });
  activeDownloads.set(cacheKey, download);
  try {
    return await download;
  } finally {
    // Always drop the entry so a failed download can be retried later.
    activeDownloads.delete(cacheKey);
  }
}
/**
 * Download one asset into storage, retrying failed attempts.
 *
 * Between attempts a `retrying` progress event is emitted and the function
 * waits `RETRY_DELAY_MS * attempt` milliseconds.
 *
 * @param fromURL - Source URL.
 * @param toKey - Storage key that receives the bytes.
 * @param asset - Asset label attached to progress events.
 * @param retries - Maximum number of attempts.
 * @param storage - Destination storage.
 * @param fetchImpl - Optional fetch override, passed through to each attempt.
 * @param progressHandler - Optional progress callback.
 * @throws A `downloadFailed` KittenTTSError wrapping the last failure once all
 *   attempts are used up.
 */
async function downloadFile(
  fromURL: string,
  toKey: string,
  asset: DownloadProgressAsset,
  retries: number,
  storage: AssetStorage,
  fetchImpl: typeof fetch | undefined,
  progressHandler?: ProgressHandler,
): Promise<void> {
  let lastError: unknown;

  for (let attempt = 1; attempt <= retries; attempt += 1) {
    try {
      await downloadFileOnce(fromURL, toKey, asset, attempt, retries, storage, fetchImpl, progressHandler);
      return;
    } catch (error) {
      lastError = error;
    }

    // Only reached after a failed attempt; the final failure skips the wait.
    if (attempt === retries) break;

    progressHandler?.(0, {
      stage: 'retrying',
      asset,
      attempt: attempt + 1,
      totalAttempts: retries,
      message: errorMessage(lastError),
    });
    await sleep(RETRY_DELAY_MS * attempt);
  }

  const reason = isKittenTTSError(lastError)
    ? lastError.message
    : errorMessage(lastError);
  throw KittenTTSError.downloadFailed(
    `Failed after ${retries} attempts: ${reason}`,
    lastError,
  );
}
/**
 * Read a fetch response body into one contiguous byte array, reporting the
 * running byte count after every chunk.
 *
 * `contentLength` is treated purely as a sizing hint. The Content-Length
 * header does not have to match the number of bytes the stream yields (for
 * example when the transport decodes a compressed body), so the returned
 * array is always sized to the bytes actually read rather than to the header.
 *
 * @param response - Response whose body is consumed.
 * @param contentLength - Declared body size, or 0 / a non-positive value when unknown.
 * @param onProgress - Called with the cumulative number of bytes read so far.
 * @returns Exactly the bytes received from the body.
 */
async function readResponseBytes(
  response: Response,
  contentLength: number,
  onProgress: (bytesWritten: number) => void,
): Promise<Uint8Array> {
  // Some runtimes expose no streaming body; fall back to a single read.
  if (!response.body || typeof response.body.getReader !== 'function') {
    const data = new Uint8Array(await response.arrayBuffer());
    onProgress(data.byteLength);
    return data;
  }

  const reader = response.body.getReader();
  const sizeHint =
    Number.isSafeInteger(contentLength) && contentLength > 0 ? contentLength : 0;
  let buffer = new Uint8Array(sizeHint);
  let total = 0;

  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    if (!value) continue;

    const needed = total + value.byteLength;
    if (needed > buffer.byteLength) {
      // The body is larger than the hint (or there was no hint): grow
      // geometrically so repeated chunks do not cause quadratic copying.
      const grown = new Uint8Array(Math.max(needed, buffer.byteLength * 2));
      grown.set(buffer.subarray(0, total));
      buffer = grown;
    }
    buffer.set(value, total);
    total = needed;
    onProgress(total);
  }

  // Trim unused capacity so callers never see padding bytes.
  return total === buffer.byteLength ? buffer : buffer.slice(0, total);
}
/**
 * Wrap a progress handler so that per-asset byte counts are combined into a
 * single overall fraction.
 *
 * Each asset that reports a positive `contentLength` contributes its clamped
 * `bytesWritten` and its length. An asset that reports `complete` without
 * ever reporting a length is counted as one finished unit. Until at least one
 * asset has been recorded, the incoming `progress` value is forwarded as is.
 *
 * @param progressHandler - Downstream handler that receives the combined fraction.
 * @returns A handler to pass to the individual downloads.
 */
function createAggregateProgress(
  progressHandler?: ProgressHandler,
): ProgressHandler {
  const perAsset = new Map<
    DownloadProgressAsset,
    { bytesWritten: number; contentLength: number }
  >();

  return (progress, info) => {
    const asset = info?.asset;
    const declaredLength = info?.contentLength;

    if (asset && declaredLength && declaredLength > 0) {
      const written = Math.max(0, Math.min(info?.bytesWritten ?? 0, declaredLength));
      perAsset.set(asset, { bytesWritten: written, contentLength: declaredLength });
    } else if (asset && info?.stage === 'complete' && !perAsset.has(asset)) {
      perAsset.set(asset, { bytesWritten: 1, contentLength: 1 });
    }

    let totalBytes = 0;
    let writtenBytes = 0;
    for (const entry of perAsset.values()) {
      totalBytes += entry.contentLength;
      writtenBytes += entry.bytesWritten;
    }

    let aggregateProgress = progress;
    if (totalBytes > 0) {
      aggregateProgress = Math.max(0, Math.min(1, writtenBytes / totalBytes));
    }

    progressHandler?.(aggregateProgress, info);
  };
}
/**
 * Walk the local file headers of a ZIP archive and decode every `.npy` entry.
 *
 * Only stored (method 0) and raw-DEFLATE (method 8) entries are decoded;
 * entries using any other method are skipped. Parsing stops at the first
 * position that is not a local file header or whose header/payload would run
 * past the end of the buffer, and whatever was decoded up to that point is
 * returned.
 *
 * @param data - Raw bytes of the `.npz` archive.
 * @returns Map from entry name (without the `.npy` suffix) to its embedding.
 * @throws An `invalidModelData` KittenTTSError when a ZIP64 entry declares a
 *   size of 4 GB or more.
 */
function parseZIP(data: Uint8Array): VoiceEmbeddings {
  const result: VoiceEmbeddings = {};
  const view = new DataView(data.buffer, data.byteOffset, data.byteLength);
  let offset = 0;

  while (offset + 30 <= data.length) {
    // Check local file header signature
    if (view.getUint32(offset, true) !== 0x04034b50) break;

    const method = view.getUint16(offset + 8, true);
    let compressedSize = view.getUint32(offset + 18, true);
    let uncompressedSize = view.getUint32(offset + 22, true);
    const nameLen = view.getUint16(offset + 26, true);
    const extraLen = view.getUint16(offset + 28, true);

    const nameStart = offset + 30;
    const extraStart = nameStart + nameLen;
    const dataStart = extraStart + extraLen;

    // A truncated header: the name/extra fields would extend past the buffer.
    // Stopping here also keeps the ZIP64 scan below inside the buffer.
    if (dataStart > data.length) break;

    // ZIP64: sizes are 0xFFFFFFFF -> read from ZIP64 extra field (tag 0x0001)
    if (compressedSize === 0xffffffff || uncompressedSize === 0xffffffff) {
      let exOff = extraStart;
      while (exOff + 4 <= dataStart) {
        const tag = view.getUint16(exOff, true);
        const size = view.getUint16(exOff + 2, true);
        // Require the two 8-byte sizes to lie entirely inside the extra area
        // before reading them.
        if (tag === 0x0001 && size >= 16 && exOff + 20 <= dataStart) {
          const uncompressedHigh = view.getUint32(exOff + 8, true);
          const compressedHigh = view.getUint32(exOff + 16, true);
          if (uncompressedHigh !== 0 || compressedHigh !== 0) {
            throw KittenTTSError.invalidModelData(
              'ZIP64 voice embedding entries larger than 4 GB are not supported.',
            );
          }
          uncompressedSize = view.getUint32(exOff + 4, true);
          compressedSize = view.getUint32(exOff + 12, true);
          break;
        }
        exOff += 4 + size;
      }
    }

    const dataEnd = dataStart + compressedSize;
    if (dataEnd > data.length) break;

    // Views are enough here: the name is decoded immediately and the payload
    // is only read, so there is no need to copy either out of the archive.
    const entryName = TEXT_DECODER.decode(data.subarray(nameStart, nameStart + nameLen));

    if (entryName.endsWith('.npy')) {
      const compressed = data.subarray(dataStart, dataEnd);
      let fileData: Uint8Array | null = null;
      if (method === 0) {
        fileData = compressed;
      } else if (method === 8) {
        fileData = pako.inflateRaw(compressed);
      }

      if (fileData) {
        const arrayName = entryName.slice(0, -4);
        const embedding = parseNPY(fileData);
        if (embedding) {
          result[arrayName] = embedding;
        }
      }
    }

    offset = dataEnd;
  }

  return result;
}
/**
 * Extract the `shape` tuple from an NPY header dictionary string.
 *
 * The tuple following the `'shape'` key is used when that key is present;
 * otherwise the first parenthesised group in the header is used. Every
 * dimension must be a plain non-negative decimal integer: tokens such as
 * `-1`, `2.5`, `1e2` or `0x10` are rejected instead of being coerced.
 *
 * @param header - Decoded NPY header text.
 * @returns The dimensions, `[1]` for an empty tuple `()`, or `null` when no
 *   tuple is found or a dimension is not a decimal integer.
 */
function parseShape(header: string): number[] | null {
  const keyIdx = header.indexOf("'shape'");
  const openIdx = header.indexOf('(', keyIdx === -1 ? 0 : keyIdx);
  if (openIdx === -1) return null;
  const closeIdx = header.indexOf(')', openIdx);
  if (closeIdx === -1) return null;

  const inside = header.substring(openIdx + 1, closeIdx).trim();
  if (!inside) return [1];

  const decimalInteger = /^\d+$/;
  const dims: number[] = [];
  for (const part of inside.split(',')) {
    const token = part.trim();
    // One-element tuples are written `(N,)`, which leaves an empty token.
    if (token.length === 0) continue;
    if (!decimalInteger.test(token)) return null;
    dims.push(Number(token));
  }
  return dims;
}
// Scratch space used to reinterpret a 32-bit pattern as a float.
const FLOAT32_BUFFER = new ArrayBuffer(4);
const FLOAT32_VIEW = new DataView(FLOAT32_BUFFER);

/**
 * Convert an IEEE 754 half-precision bit pattern to a number.
 *
 * The half is widened to the equivalent single-precision bit pattern (sign,
 * re-biased exponent, fraction shifted left by 13) and then reinterpreted.
 *
 * @param bits - The 16-bit pattern, as an unsigned integer.
 * @returns The represented value, including signed zero, infinities and NaN.
 */
function float16ToFloat(bits: number): number {
  const signBit = (bits & 0x8000) << 16;
  const exponent = (bits >>> 10) & 0x1f;
  const fraction = bits & 0x3ff;

  let exponent32: number;
  let fraction32 = fraction;

  if (exponent === 0x1f) {
    // Inf or NaN: all exponent bits set, fraction carried over.
    exponent32 = 0xff;
  } else if (exponent !== 0) {
    // Normal number: move from bias 15 to bias 127.
    exponent32 = exponent + 112;
  } else if (fraction === 0) {
    // Signed zero
    exponent32 = 0;
  } else {
    // Subnormal: shift the leading 1 up into the implicit-bit position
    // (bit 10) and lower the exponent by the same amount.
    const shift = Math.clz32(fraction) - 21;
    fraction32 = (fraction << shift) & 0x3ff;
    exponent32 = 113 - shift;
  }

  return bitsToFloat32((signBit | (exponent32 << 23) | (fraction32 << 13)) >>> 0);
}

/** Reinterpret an unsigned 32-bit pattern as a single-precision float. */
function bitsToFloat32(bits: number): number {
  FLOAT32_VIEW.setUint32(0, bits, true);
  return FLOAT32_VIEW.getFloat32(0, true);
}
/**
 * Build the byte-to-string decoder used for ZIP entry names and NPY headers.
 *
 * Uses the platform `TextDecoder` when one exists. Otherwise falls back to
 * mapping each byte to the character with the same code, which matches real
 * decoding only for single-byte (ASCII-range) text.
 */
function createTextDecoder(): { decode(input: Uint8Array): string } {
  if (typeof TextDecoder !== 'undefined') {
    return new TextDecoder();
  }
  return {
    decode(input: Uint8Array): string {
      const chars: string[] = [];
      input.forEach((byte) => {
        chars.push(String.fromCharCode(byte));
      });
      return chars.join('');
    },
  };
}

const TEXT_DECODER: { decode(input: Uint8Array): string } = createTextDecoder();
DEFAULT_DOWNLOAD_RETRIES = 4; +const RETRY_DELAY_MS = 750; + +export interface CEPhonemizerOptions { + /** Override the English pronunciation rules URL. Useful for tests or mirrors. */ + rulesURL?: string; + /** Override the English dictionary list URL. Useful for tests or mirrors. */ + listURL?: string; + /** Local English pronunciation rules file. Node.js only. Skips the rules download. */ + rulesPath?: string; + /** Local English dictionary list file. Node.js only. Skips the list download. */ + listPath?: string; + /** English pronunciation rules text. Skips the rules download and file read. */ + rulesText?: string; + /** English dictionary list text. Skips the list download and file read. */ + listText?: string; + /** Dialect passed through to the C++ engine, for example `en-us`. */ + dialect?: string; + /** Asset cache implementation. */ + storage?: AssetStorage; + /** Fetch implementation. Defaults to globalThis.fetch. */ + fetch?: typeof fetch; +} + +type CreateHandle = (rulesPath: string, listPath: string, dialect: string) => number; +type DestroyHandle = (handle: number) => void; +type PhonemizeHandle = (handle: number, text: string) => number; +type FreeString = (ptr: number) => void; +type PhonemizerAsset = Extract< + DownloadProgressAsset, + 'phonemizer-rules' | 'phonemizer-list' +>; + +/** + * Web/Node adapter for the original KittenTTS CEPhonemizer C++ engine. + * + * The C++ source is compiled to a JS-only Emscripten module, so browser and + * backend runtimes can use the same phonemizer logic without platform-native + * bindings. 
+ */ +export class CEPhonemizer implements KittenPhonemizerProtocol { + static readonly defaultRulesURL = DEFAULT_RULES_URL; + static readonly defaultListURL = DEFAULT_LIST_URL; + + private readonly rulesURL: string; + private readonly listURL: string; + private readonly rulesPath?: string; + private readonly listPath?: string; + private readonly rulesText?: string; + private readonly listText?: string; + private readonly dialect: string; + private readonly storage?: AssetStorage; + private readonly fetch?: typeof fetch; + + private module: CEPhonemizerModule | null = null; + private handle = 0; + private createHandle: CreateHandle | null = null; + private destroyHandle: DestroyHandle | null = null; + private phonemizeHandle: PhonemizeHandle | null = null; + private freeString: FreeString | null = null; + + constructor(options: CEPhonemizerOptions = {}) { + this.rulesURL = options.rulesURL ?? DEFAULT_RULES_URL; + this.listURL = options.listURL ?? DEFAULT_LIST_URL; + this.rulesPath = options.rulesPath; + this.listPath = options.listPath; + this.rulesText = options.rulesText; + this.listText = options.listText; + this.dialect = options.dialect ?? 'en-us'; + this.storage = options.storage; + this.fetch = options.fetch; + } + + async downloadIfNeeded( + storageDirectory: string, + progressHandler?: ProgressHandler, + ): Promise { + if (this.hasBundledText() || this.hasBundledPaths()) { + await this.loadBundled(progressHandler); + return; + } + + this.assertNoPartialBundledData(); + + const base = storageDirectory || 'KittenTTS'; + const rulesKey = `${base}/CEPhonemizer/en_rules`; + const listKey = `${base}/CEPhonemizer/en_list`; + const storage = this.storage ?? 
defaultAssetStorage(storageDirectory); + + try { + const [rulesCached, listCached] = await Promise.all([ + hasStorageKey(storage, rulesKey), + hasStorageKey(storage, listKey), + ]); + + if (rulesCached && listCached) { + progressHandler?.(1, { stage: 'cached', cached: true }); + } else { + progressHandler?.(0, { stage: 'checking-cache', cached: false }); + } + + const aggregateProgress = createAggregateProgress(progressHandler); + const downloads: Promise[] = []; + if (!rulesCached) { + downloads.push( + downloadTextFile(this.rulesURL, rulesKey, 'phonemizer-rules', storage, this.fetch, aggregateProgress), + ); + } + if (!listCached) { + downloads.push( + downloadTextFile(this.listURL, listKey, 'phonemizer-list', storage, this.fetch, aggregateProgress), + ); + } + + await Promise.all(downloads); + + const [rulesData, listData] = await Promise.all([ + requireStorageData(storage, rulesKey), + requireStorageData(storage, listKey), + ]); + + await this.load(TEXT_DECODER.decode(rulesData), TEXT_DECODER.decode(listData)); + progressHandler?.(1, { stage: 'complete', cached: rulesCached && listCached }); + } catch (error) { + if (isKittenTTSError(error)) throw error; + throw KittenTTSError.phonemizerFailed(errorMessage(error), error); + } + } + + async phonemize(text: string): Promise { + if (!this.module || !this.handle || !this.phonemizeHandle || !this.freeString) { + throw KittenTTSError.phonemizerFailed( + 'CEPhonemizer data is not ready. 
Call downloadIfNeeded() before phonemize().', + ); + } + + const resultPtr = this.phonemizeHandle(this.handle, text); + if (!resultPtr) { + throw KittenTTSError.phonemizerFailed('CEPhonemizer failed to phonemize text.'); + } + + try { + return this.module.UTF8ToString(resultPtr); + } finally { + this.freeString(resultPtr); + } + } + + dispose(): void { + if (this.handle && this.destroyHandle) { + this.destroyHandle(this.handle); + } + this.handle = 0; + this.module = null; + this.createHandle = null; + this.destroyHandle = null; + this.phonemizeHandle = null; + this.freeString = null; + } + + private async load(rules: string, list: string): Promise { + this.dispose(); + + const module = await createCEPhonemizerModule(); + ensureDir(module, '/cephonemizer'); + + module.FS.writeFile(VIRTUAL_RULES_PATH, rules); + module.FS.writeFile(VIRTUAL_LIST_PATH, list); + + const createHandle = module.cwrap( + 'phonemizer_create', + 'number', + ['string', 'string', 'string'], + ) as CreateHandle; + const destroyHandle = module.cwrap('phonemizer_destroy', null, ['number']) as DestroyHandle; + const phonemizeHandle = module.cwrap( + 'phonemizer_phonemize', + 'number', + ['number', 'string'], + ) as PhonemizeHandle; + const freeString = module.cwrap('phonemizer_free_string', null, ['number']) as FreeString; + + const handle = createHandle(VIRTUAL_RULES_PATH, VIRTUAL_LIST_PATH, this.dialect); + if (!handle) { + throw KittenTTSError.phonemizerFailed('CEPhonemizer failed to load en_rules/en_list.'); + } + + this.module = module; + this.handle = handle; + this.createHandle = createHandle; + this.destroyHandle = destroyHandle; + this.phonemizeHandle = phonemizeHandle; + this.freeString = freeString; + } + + private hasBundledText(): boolean { + return this.rulesText !== undefined || this.listText !== undefined; + } + + private hasBundledPaths(): boolean { + return this.rulesPath !== undefined || this.listPath !== undefined; + } + + private assertNoPartialBundledData(): void { + if 
(this.rulesText !== undefined || this.listText !== undefined) { + if (this.rulesText === undefined || this.listText === undefined) { + throw KittenTTSError.phonemizerFailed( + 'Both rulesText and listText must be provided for bundled CEPhonemizer data.', + ); + } + } + + if (this.rulesPath !== undefined || this.listPath !== undefined) { + if (this.rulesPath === undefined || this.listPath === undefined) { + throw KittenTTSError.phonemizerFailed( + 'Both rulesPath and listPath must be provided for bundled CEPhonemizer data.', + ); + } + } + } + + private async loadBundled(progressHandler?: ProgressHandler): Promise { + this.assertNoPartialBundledData(); + + try { + progressHandler?.(0, { stage: 'checking-cache', cached: false }); + + if (this.rulesText !== undefined && this.listText !== undefined) { + await this.load(this.rulesText, this.listText); + progressHandler?.(1, { stage: 'complete', cached: true }); + return; + } + + if (!this.rulesPath || !this.listPath) { + throw KittenTTSError.phonemizerFailed( + 'Bundled CEPhonemizer data must provide text or Node.js file paths.', + ); + } + if (!isNodeRuntime()) { + throw KittenTTSError.phonemizerFailed( + 'rulesPath/listPath are only supported in Node.js. 
/**
 * Download one phonemizer data file into storage, retrying failed attempts
 * up to `DEFAULT_DOWNLOAD_RETRIES` times.
 *
 * Between attempts a `retrying` progress event is emitted and the function
 * waits `RETRY_DELAY_MS * attempt` milliseconds.
 *
 * @param fromUrl - Source URL.
 * @param toKey - Storage key that receives the bytes.
 * @param asset - Asset label attached to progress events.
 * @param storage - Destination storage.
 * @param fetchImpl - Optional fetch override, passed through to each attempt.
 * @param progressHandler - Optional progress callback.
 * @throws A `phonemizerFailed` KittenTTSError wrapping the last failure once
 *   all attempts are used up.
 */
async function downloadTextFile(
  fromUrl: string,
  toKey: string,
  asset: PhonemizerAsset,
  storage: AssetStorage,
  fetchImpl: typeof fetch | undefined,
  progressHandler?: ProgressHandler,
): Promise<void> {
  let lastError: unknown;

  for (let attempt = 1; attempt <= DEFAULT_DOWNLOAD_RETRIES; attempt += 1) {
    try {
      await downloadTextFileOnce(
        fromUrl,
        toKey,
        asset,
        attempt,
        DEFAULT_DOWNLOAD_RETRIES,
        storage,
        fetchImpl,
        progressHandler,
      );
      return;
    } catch (error) {
      lastError = error;
    }

    // Only reached after a failed attempt; the final failure skips the wait.
    if (attempt === DEFAULT_DOWNLOAD_RETRIES) break;

    progressHandler?.(0, {
      stage: 'retrying',
      asset,
      attempt: attempt + 1,
      totalAttempts: DEFAULT_DOWNLOAD_RETRIES,
      message: errorMessage(lastError),
    });
    await sleep(RETRY_DELAY_MS * attempt);
  }

  throw KittenTTSError.phonemizerFailed(
    `Failed after ${DEFAULT_DOWNLOAD_RETRIES} attempts: ${errorMessage(lastError)}`,
    lastError,
  );
}
globalThis.fetch?.bind(globalThis); + if (!runFetch) { + throw KittenTTSError.phonemizerFailed('No fetch implementation is available.'); + } + + progressHandler?.(0, { stage: 'downloading', asset, attempt, totalAttempts }); + + const response = await runFetch(fromUrl); + if (!response.ok) { + throw KittenTTSError.phonemizerFailed(`HTTP ${response.status} downloading ${fromUrl}`); + } + + const contentLength = Number(response.headers.get('content-length') || 0); + const data = new Uint8Array(await response.arrayBuffer()); + progressHandler?.(1, { + stage: 'downloading', + asset, + attempt, + totalAttempts, + bytesWritten: data.byteLength, + contentLength: contentLength || data.byteLength, + }); + await storage.set(toKey, data); + progressHandler?.(1, { stage: 'complete', asset, attempt, totalAttempts }); +} + +async function hasStorageKey(storage: AssetStorage, key: string): Promise { + if (storage.has) return storage.has(key); + return (await storage.get(key)) !== null; +} + +async function requireStorageData(storage: AssetStorage, key: string): Promise { + const data = await storage.get(key); + if (!data) throw KittenTTSError.phonemizerFailed(`Cached phonemizer file not found: ${key}`); + return data; +} + +async function readNodeTextFile(filePath: string): Promise { + const fs = await import('node:fs/promises'); + return fs.readFile(stripFileScheme(filePath), 'utf8'); +} + +function sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +function stripFileScheme(filePath: string): string { + return filePath.startsWith('file://') ? 
filePath.slice('file://'.length) : filePath; +} + +function createAggregateProgress( + progressHandler?: ProgressHandler, +): ProgressHandler { + const files = new Map< + DownloadProgressAsset, + { bytesWritten: number; contentLength: number } + >(); + + return (progress, info) => { + if (info?.asset && info.contentLength && info.contentLength > 0) { + files.set(info.asset, { + bytesWritten: Math.max(0, Math.min(info.bytesWritten ?? 0, info.contentLength)), + contentLength: info.contentLength, + }); + } else if (info?.asset && info.stage === 'complete' && !files.has(info.asset)) { + files.set(info.asset, { bytesWritten: 1, contentLength: 1 }); + } + + const totalBytes = Array.from(files.values()).reduce( + (sum, file) => sum + file.contentLength, + 0, + ); + const writtenBytes = Array.from(files.values()).reduce( + (sum, file) => sum + file.bytesWritten, + 0, + ); + + const aggregateProgress = + totalBytes > 0 + ? Math.max(0, Math.min(1, writtenBytes / totalBytes)) + : progress; + + progressHandler?.(aggregateProgress, info); + }; +} + +function ensureDir(module: CEPhonemizerModule, path: string): void { + try { + module.FS.mkdir(path); + } catch { + // Emscripten throws if the directory already exists. + } +} + +const TEXT_DECODER = new TextDecoder(); diff --git a/src/storage/AssetStorage.ts b/src/storage/AssetStorage.ts new file mode 100644 index 0000000..6ba4c74 --- /dev/null +++ b/src/storage/AssetStorage.ts @@ -0,0 +1,138 @@ +export interface AssetStorage { + get(key: string): Promise; + set(key: string, data: Uint8Array): Promise; + delete(key: string): Promise; + has?(key: string): Promise; + pathForKey?(key: string): Promise; +} + +export class MemoryAssetStorage implements AssetStorage { + private readonly entries = new Map(); + + async get(key: string): Promise { + const data = this.entries.get(key); + return data ? 
new Uint8Array(data) : null; + } + + async set(key: string, data: Uint8Array): Promise { + this.entries.set(key, new Uint8Array(data)); + } + + async delete(key: string): Promise { + this.entries.delete(key); + } + + async has(key: string): Promise { + return this.entries.has(key); + } +} + +export class BrowserCacheAssetStorage implements AssetStorage { + constructor(private readonly cacheName = 'kittentts-web') {} + + async get(key: string): Promise { + if (!hasCacheStorage()) return null; + const cache = await caches.open(this.cacheName); + const response = await cache.match(cacheRequest(key)); + if (!response || !response.ok) return null; + return new Uint8Array(await response.arrayBuffer()); + } + + async set(key: string, data: Uint8Array): Promise { + if (!hasCacheStorage()) return; + const cache = await caches.open(this.cacheName); + await cache.put(cacheRequest(key), new Response(toArrayBuffer(data))); + } + + async delete(key: string): Promise { + if (!hasCacheStorage()) return; + const cache = await caches.open(this.cacheName); + await cache.delete(cacheRequest(key)); + } + + async has(key: string): Promise { + return (await this.get(key)) !== null; + } +} + +export class NodeFileAssetStorage implements AssetStorage { + constructor(private readonly rootDirectory?: string) {} + + async get(key: string): Promise { + const filePath = await this.pathForKey(key); + try { + const fs = await import('node:fs/promises'); + const data = await fs.readFile(filePath); + return new Uint8Array(data.buffer, data.byteOffset, data.byteLength); + } catch { + return null; + } + } + + async set(key: string, data: Uint8Array): Promise { + const filePath = await this.pathForKey(key); + const fs = await import('node:fs/promises'); + const path = await import('node:path'); + await fs.mkdir(path.dirname(filePath), { recursive: true }); + const tempPath = `${filePath}.download`; + await fs.writeFile(tempPath, data); + await fs.rename(tempPath, filePath); + } + + async delete(key: 
string): Promise { + const filePath = await this.pathForKey(key); + const fs = await import('node:fs/promises'); + await fs.unlink(filePath).catch(() => {}); + await fs.unlink(`${filePath}.download`).catch(() => {}); + } + + async has(key: string): Promise { + const filePath = await this.pathForKey(key); + const fs = await import('node:fs/promises'); + return fs.access(filePath).then(() => true, () => false); + } + + async pathForKey(key: string): Promise { + const path = await import('node:path'); + const root = this.rootDirectory ?? await defaultNodeCacheDirectory(); + return path.join(root, ...key.split('/').map(safeSegment)); + } +} + +let memoryFallback: MemoryAssetStorage | null = null; + +export function defaultAssetStorage(storageDirectory?: string): AssetStorage { + if (isNodeRuntime()) return new NodeFileAssetStorage(storageDirectory); + if (hasCacheStorage()) return new BrowserCacheAssetStorage(storageDirectory || 'kittentts-web'); + memoryFallback ??= new MemoryAssetStorage(); + return memoryFallback; +} + +export function isNodeRuntime(): boolean { + return typeof process !== 'undefined' && Boolean(process.versions?.node); +} + +function hasCacheStorage(): boolean { + return typeof caches !== 'undefined' && typeof Response !== 'undefined'; +} + +function cacheRequest(key: string): Request { + return new Request(`https://kittentts.local/cache/${encodeURIComponent(key)}`); +} + +async function defaultNodeCacheDirectory(): Promise { + const os = await import('node:os'); + const path = await import('node:path'); + return path.join(os.homedir(), '.cache', 'kittentts-web'); +} + +function safeSegment(segment: string): string { + return segment.replace(/[^a-zA-Z0-9._-]/g, '_') || '_'; +} + +function toArrayBuffer(bytes: Uint8Array): ArrayBuffer { + return bytes.buffer.slice( + bytes.byteOffset, + bytes.byteOffset + bytes.byteLength, + ) as ArrayBuffer; +} diff --git a/src/web-globals.d.ts b/src/web-globals.d.ts new file mode 100644 index 0000000..f92474d --- 
/dev/null +++ b/src/web-globals.d.ts @@ -0,0 +1,21 @@ +interface HTMLAudioElement { + src: string; + currentTime: number; + onended: (() => void) | null; + onerror: (() => void) | null; + onplaying: (() => void) | null; + pause(): void; + play(): Promise; +} + +declare const Audio: { + new(src?: string): HTMLAudioElement; +}; + +declare const caches: { + open(cacheName: string): Promise<{ + match(request: Request): Promise; + put(request: Request, response: Response): Promise; + delete(request: Request): Promise; + }>; +};