Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
e407016
wip
chmjkb Jan 16, 2025
1f4b875
chore: remove unused function
chmjkb Jan 16, 2025
b5970ea
refactor: remove unintended for loop and data structures
chmjkb Jan 16, 2025
f24d1bb
fix: add error handling & lint
chmjkb Jan 16, 2025
bfeea95
fix: (native) update ETModule's forward to accept multiple inputs
chmjkb Jan 17, 2025
19f4fb5
refactor: get rid of InputType
chmjkb Jan 20, 2025
60828c9
replace single input forward with multiple inputs
chmjkb Jan 22, 2025
8408656
lint
chmjkb Jan 22, 2025
4d904ca
fix: make use of existing functions, return actual output
chmjkb Jan 22, 2025
4ab655f
fix: update rnexecutorchmodules.ts
chmjkb Jan 22, 2025
51c3e71
fix: update BaseModel to match new native implementation, remove Inpu…
chmjkb Jan 22, 2025
ec4d322
fix: int8_T -> char
chmjkb Jan 22, 2025
27af316
feat: make Android accept multiple inputs
chmjkb Jan 23, 2025
b8d71ae
refactor: remove unused function
chmjkb Jan 23, 2025
4daa790
chore: remove useless continue
chmjkb Jan 24, 2025
2719ca4
feat: make use of the new interface in BaseModel
chmjkb Jan 27, 2025
c55bbe8
chore: remove log import
chmjkb Jan 28, 2025
4acdd1b
fix: unsqueeze shapes if a single number is passed
chmjkb Jan 28, 2025
6974d38
chore: add a comment for a hack in forward()
chmjkb Jan 28, 2025
5934228
fix: fix types
chmjkb Jan 28, 2025
55f4b1c
chore: add comment for nsarraytovoidptr
chmjkb Jan 28, 2025
b73c8dc
fix: unsqueeze input before forwarding it to cpp
chmjkb Feb 21, 2025
5e748b3
feat: add multiple inputs to hookless api
chmjkb Feb 24, 2025
3217786
fix: minor fix
chmjkb Feb 24, 2025
2729838
fix: add --noEmit flag to lefthook
chmjkb Feb 24, 2025
c9a478a
fix: (native) update ETModule's forward to accept multiple inputs
chmjkb Jan 17, 2025
a0f2a1d
lint
chmjkb Jan 22, 2025
5686f48
feat: make Android accept multiple inputs
chmjkb Jan 23, 2025
daed011
wip
chmjkb Jan 27, 2025
490d576
fix: fix types
chmjkb Jan 28, 2025
ec176dd
wip
chmjkb Feb 12, 2025
5ef2484
wip
chmjkb Feb 12, 2025
fd93f53
wip
chmjkb Feb 14, 2025
cad049b
wip
chmjkb Feb 20, 2025
4f5faf9
feat: add fft lib to build.gradle
chmjkb Feb 20, 2025
a389ee1
feat: add array utils, finish android impl
chmjkb Feb 20, 2025
e02ce71
chore: move magic numbers to varibales
chmjkb Feb 20, 2025
6858c5d
feat: add multiple inputs forward to BaseModel
chmjkb Feb 20, 2025
cdfb6a3
chore: remove unused import from ClassificationModel
chmjkb Feb 20, 2025
28e0d5c
feat: add ScalarType to objc
chmjkb Feb 20, 2025
467f5ee
fix
chmjkb Feb 20, 2025
4d2864d
fix: fix scalartype enum
chmjkb Feb 20, 2025
76830a8
chore: apply ScalarType to stt.mm
chmjkb Feb 20, 2025
a96682c
chore: remove redundant params from native stt spec
chmjkb Feb 20, 2025
4488fd6
feat: add stft module to kotlin
chmjkb Feb 20, 2025
175ec56
fix: align rn executorch modules with native spec
chmjkb Feb 20, 2025
f8c6d2d
fix: minor stt changes
chmjkb Feb 20, 2025
b39f57b
chore: remove unused comments
chmjkb Feb 20, 2025
1e2fc64
feat: Make stt controller use a promise
chmjkb Feb 20, 2025
72f9f8e
fix: remove prevTokens arg from native generate() in Android
chmjkb Feb 20, 2025
e443378
feat: add whisper decoding map
chmjkb Feb 21, 2025
f4423c6
feat: use the decoder in stt controller
chmjkb Feb 21, 2025
f33d0bb
Merge branch 'main' into @chmjkb/whisper
chmjkb Feb 24, 2025
628ace1
yarn lock
chmjkb Feb 24, 2025
64e6f19
fix: make eslint ignore decoders dir
chmjkb Feb 24, 2025
f120af9
chore: move back getTypeIdentifier to common/types to avoid exposing …
chmjkb Feb 24, 2025
37aaa85
lint
chmjkb Feb 24, 2025
4e4bff4
fix: update third party
chmjkb Feb 24, 2025
bfe7ae5
fix: remove fetcher import
chmjkb Feb 24, 2025
1431131
chore: add JTransforms mention to license
chmjkb Feb 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,28 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

This software includes components from the JTransforms library. The license and copyright notice for this library are as follows:
JTransforms
Copyright (c) 2007 onward, Piotr Wendykier
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
1 change: 1 addition & 0 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ dependencies {
// For < 0.71, this will be from the local maven repo
// For > 0.71, this will be replaced by `com.facebook.react:react-android:$version` by react gradle plugin
//noinspection GradleDynamicVersion
implementation 'com.github.wendykierp:JTransforms:3.1'
implementation "com.facebook.react:react-android:+"
implementation 'org.opencv:opencv:4.10.0'
implementation "org.jetbrains.kotlin:kotlin-stdlib:$kotlin_version"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class RnExecutorchPackage : TurboReactPackage() {
Classification(reactContext)
} else if (name == ObjectDetection.NAME) {
ObjectDetection(reactContext)
} else if (name == SpeechToText.NAME) {
SpeechToText(reactContext)
}
else {
null
Expand Down Expand Up @@ -74,6 +76,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[SpeechToText.NAME] = ReactModuleInfo(
SpeechToText.NAME,
SpeechToText.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
55 changes: 55 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/SpeechToText.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package com.swmansion.rnexecutorch

import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder
import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder
import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor
import com.swmansion.rnexecutorch.utils.ArrayUtils
import com.swmansion.rnexecutorch.utils.ETError

class SpeechToText(reactContext: ReactApplicationContext) :
NativeSpeechToTextSpec(reactContext) {
private var whisperPreprocessor = WhisperPreprocessor(reactContext)
private var whisperEncoder = WhisperEncoder(reactContext)
private var whisperDecoder = WhisperDecoder(reactContext)
private var START_TOKEN = 50257
private var EOS_TOKEN = 50256
Comment on lines +17 to +18
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we cshould move those inside Whisper model, as they can vary from tokenizer to tokenizer. Or move it to another config file tokenizer based, but I like the first one better right now


companion object {
const val NAME = "SpeechToText"
}

override fun loadModule(preprocessorSource: String, encoderSource: String, decoderSource: String, promise: Promise) {
try {
this.whisperPreprocessor.loadModel(preprocessorSource)
this.whisperEncoder.loadModel(encoderSource)
this.whisperDecoder.loadModel(decoderSource)
promise.resolve(0)
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

override fun generate(waveform: ReadableArray, promise: Promise) {
val logMel = this.whisperPreprocessor.runModel(waveform)
val encoding = this.whisperEncoder.runModel(logMel)
val generatedTokens = mutableListOf(this.START_TOKEN)
var lastToken = 0
Thread {
while (lastToken != this.EOS_TOKEN) {
this.whisperDecoder.setGeneratedTokens(generatedTokens)
lastToken = this.whisperDecoder.runModel(encoding)
emitOnToken(lastToken.toDouble())
generatedTokens.add(lastToken)
}
val generatedTokensReadableArray = ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray())
promise.resolve(generatedTokensReadableArray)
}.start()
}

override fun getName(): String {
return NAME
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,21 @@ abstract class BaseModel<Input, Output>(val context: Context) {
//The error is thrown when transformation to Tensor fails
throw Error(ETError.InvalidArgument.code.toString())
} catch (e: Exception) {
throw Error(e.message!!)
throw Error(e.message)
}
}

protected fun forward(inputs: Array<FloatArray>, shapes: Array<LongArray>) : Array<EValue> {
// We want to convert each input to EValue, a data structure accepted by ExecuTorch's
// Module. The array below keeps track of that values.
try {
val executorchInputs = inputs.mapIndexed { index, _ -> EValue.from(Tensor.fromBlob(inputs[index], shapes[index]))}
val forwardResult = module.forward(*executorchInputs.toTypedArray())
return forwardResult
} catch (e: IllegalArgumentException) {
throw Error(ETError.InvalidArgument.code.toString())
} catch (e: Exception) {
throw Error(e.message)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.opencv.imgproc.Imgproc
import org.pytorch.executorch.Tensor
import org.pytorch.executorch.EValue
import com.swmansion.rnexecutorch.models.BaseModel

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class WhisperDecoder(
reactApplicationContext: ReactApplicationContext,
) : BaseModel<EValue, Int>(reactApplicationContext) {
private var generatedTokens: MutableList<Int> = mutableListOf()

fun setGeneratedTokens(tokens: MutableList<Int>) {
this.generatedTokens = tokens
}

override fun runModel(input: EValue): Int {
val tokensEValue = EValue.from(Tensor.fromBlob(this.generatedTokens.toIntArray(), longArrayOf(1, generatedTokens.size.toLong())))
return this.module
.forward(tokensEValue, input)[0]
.toTensor()
.dataAsLongArray[0]
.toInt()
}

override fun preprocess(input: EValue): EValue {
TODO("Not yet implemented")
}

override fun postprocess(output: Array<EValue>): Int {
TODO("Not yet implemented")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class WhisperEncoder(reactApplicationContext: ReactApplicationContext) :
BaseModel<EValue, EValue>(reactApplicationContext) {
private val encoderInputShape = longArrayOf(1L, 80L, 3000L)

override fun runModel(input: EValue): EValue {
val inputEValue = this.preprocess(input)
val hiddenState = this.module.forward(inputEValue)
return hiddenState[0]
}

override fun preprocess(input: EValue): EValue {
val inputTensor = Tensor.fromBlob(input.toTensor().dataAsFloatArray, this.encoderInputShape)
return EValue.from(inputTensor)
}

override fun postprocess(output: Array<EValue>): EValue {
TODO("Not yet implemented")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.swmansion.rnexecutorch.models.speechToText

import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.utils.STFT
import org.pytorch.executorch.EValue
import org.pytorch.executorch.Tensor

class WhisperPreprocessor(reactApplicationContext: ReactApplicationContext) :
BaseModel<ReadableArray, EValue>(reactApplicationContext) {
private val fftSize = 512
private val hopLength = 160
private val stft = STFT(fftSize, hopLength)

override fun runModel(input: ReadableArray): EValue {
val size = input.size()
val inputFloatArray = FloatArray(size)
for (i in 0 until size) {
inputFloatArray[i] = input.getDouble(i).toFloat()
}
val stftResult = this.stft.fromWaveform(inputFloatArray)
val numStftFrames = stftResult.size / (this.fftSize / 2)
val preprocessorInputShape = longArrayOf(numStftFrames.toLong(), (this.fftSize / 2).toLong())
val melSpectrogram = this.module.forward(EValue.from(Tensor.fromBlob(stftResult, preprocessorInputShape)))
return melSpectrogram[0]
}

override fun preprocess(input: ReadableArray): EValue {
TODO("Not yet implemented")
}

override fun postprocess(output: Array<EValue>): EValue {
TODO("Not yet implemented")
}
}
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
package com.swmansion.rnexecutorch.utils

import android.util.Log
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReadableArray
import org.pytorch.executorch.DType
import org.pytorch.executorch.Tensor

class ArrayUtils {
companion object {
private inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
inline fun <reified T> createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array<T> {
return Array(input.size()) { index -> transform(input, index) }
}

fun createByteArray(input: ReadableArray): ByteArray {
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
}

fun createCharArray(input: ReadableArray): CharArray {
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toChar() }.toCharArray()
}
fun createByteArray(input: ReadableArray): ByteArray {
return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray()
}
Expand Down Expand Up @@ -62,5 +70,18 @@ class ArrayUtils {

return resultArray
}

fun createReadableArrayFromFloatArray(input: FloatArray): ReadableArray {
val resultArray = Arguments.createArray()
input.forEach { resultArray.pushDouble(it.toDouble()) }
return resultArray
}

fun createReadableArrayFromIntArray(input: IntArray): ReadableArray {
val resultArray = Arguments.createArray()
input.forEach { resultArray.pushInt(it) }
return resultArray
}

}
}
48 changes: 48 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/utils/STFT.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package com.swmansion.rnexecutorch.utils

import java.util.Vector
import kotlin.math.cos
import kotlin.math.PI
import org.jtransforms.fft.FloatFFT_1D
import kotlin.math.sqrt

class STFT public constructor(var fftSize: Int = 512, var hopLength: Int = 160) {
private val fftModule = FloatFFT_1D(this.fftSize.toLong())
private val magnitudeScale = 1.0 / this.fftSize
// https://www.mathworks.com/help/signal/ref/hann.html
private val hannWindow = FloatArray(this.fftSize) { i ->0.5f - 0.5f * cos(2f * PI.toFloat() * i / this.fftSize) }


fun fromWaveform(signal: FloatArray): FloatArray {
val numFftFrames = (signal.size - this.fftSize) / this.hopLength
// The output of FFT is always 2x smaller
val stft = FloatArray(numFftFrames * (this.fftSize / 2))

var windowStartIdx = 0
var outputIndex = 0
// TODO: i dont think the substraction at the end is needed
while (windowStartIdx + this.fftSize < signal.size - this.fftSize) {
val currentWindow = signal.copyOfRange(windowStartIdx, windowStartIdx + this.fftSize)
// Apply Hann window to the current slice
for (i in currentWindow.indices) currentWindow[i] *= this.hannWindow[i]

// Perform in-place FFT
this.fftModule.realForward(currentWindow)

stft[outputIndex++] = kotlin.math.abs(currentWindow[0])
for (i in 1 until this.fftSize / 2 - 1) {
val real = currentWindow[2 * i]
val imag = currentWindow[2 * i + 1]

val currentMagnitude = (sqrt(real * real + imag * imag) * this.magnitudeScale).toFloat()
// FIXME: we don't need that, but if we remove this we have to get rid of
// reversing this operation in the preprocessing part
stft[outputIndex++] = 20 * kotlin.math.log10(currentMagnitude)
}
// Nyquist frequency
stft[outputIndex++] = kotlin.math.abs(currentWindow[1])
windowStartIdx += this.hopLength
}
return stft
}
}
8 changes: 4 additions & 4 deletions ios/ExecutorchLib.xcframework/Info.plist
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
<key>BinaryPath</key>
<string>ExecutorchLib.framework/ExecutorchLib</string>
<key>LibraryIdentifier</key>
<string>ios-arm64</string>
<string>ios-arm64-simulator</string>
<key>LibraryPath</key>
<string>ExecutorchLib.framework</string>
<key>SupportedArchitectures</key>
Expand All @@ -17,12 +17,14 @@
</array>
<key>SupportedPlatform</key>
<string>ios</string>
<key>SupportedPlatformVariant</key>
<string>simulator</string>
</dict>
<dict>
<key>BinaryPath</key>
<string>ExecutorchLib.framework/ExecutorchLib</string>
<key>LibraryIdentifier</key>
<string>ios-arm64-simulator</string>
<string>ios-arm64</string>
<key>LibraryPath</key>
<string>ExecutorchLib.framework</string>
<key>SupportedArchitectures</key>
Expand All @@ -31,8 +33,6 @@
</array>
<key>SupportedPlatform</key>
<string>ios</string>
<key>SupportedPlatformVariant</key>
<string>simulator</string>
</dict>
</array>
<key>CFBundlePackageType</key>
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
- (NSNumber *)loadModel:(NSString *)filePath;
- (NSNumber *)loadMethod:(NSString *)methodName;
- (NSNumber *)loadForward;
- (NSArray *)forward:(NSArray *)input
shape:(NSArray *)shape
inputType:(NSNumber *)inputType;
- (NSArray *)forward:(NSArray *)inputs
shapes:(NSArray *)shapes
inputTypes: (NSArray *)inputTypes;
- (NSNumber *)getNumberOfInputs;
- (NSNumber *)getInputType:(NSNumber *)index;
- (NSArray *)getInputShape:(NSNumber *)index;
Expand Down
Binary file not shown.
Loading