feat: Whisper #101
Merged
60 commits
e407016
wip
chmjkb 1f4b875
chore: remove unused function
chmjkb b5970ea
refactor: remove unintended for loop and data structures
chmjkb f24d1bb
fix: add error handling & lint
chmjkb bfeea95
fix: (native) update ETModule's forward to accept multiple inputs
chmjkb 19f4fb5
refactor: get rid of InputType
chmjkb 60828c9
replace single input forward with multiple inputs
chmjkb 8408656
lint
chmjkb 4d904ca
fix: make use of existing functions, return actual output
chmjkb 4ab655f
fix: update rnexecutorchmodules.ts
chmjkb 51c3e71
fix: update BaseModel to match new native implementation, remove Inpu…
chmjkb ec4d322
fix: int8_T -> char
chmjkb 27af316
feat: make Android accept multiple inputs
chmjkb b8d71ae
refactor: remove unused function
chmjkb 4daa790
chore: remove useless continue
chmjkb 2719ca4
feat: make use of the new interface in BaseModel
chmjkb c55bbe8
chore: remove log import
chmjkb 4acdd1b
fix: unsqueeze shapes if a single number is passed
chmjkb 6974d38
chore: add a comment for a hack in forward()
chmjkb 5934228
fix: fix types
chmjkb 55f4b1c
chore: add comment for nsarraytovoidptr
chmjkb b73c8dc
fix: unsqueeze input before forwarding it to cpp
chmjkb 5e748b3
feat: add multiple inputs to hookless api
chmjkb 3217786
fix: minor fix
chmjkb 2729838
fix: add --noEmit flag to lefthook
chmjkb c9a478a
fix: (native) update ETModule's forward to accept multiple inputs
chmjkb a0f2a1d
lint
chmjkb 5686f48
feat: make Android accept multiple inputs
chmjkb daed011
wip
chmjkb 490d576
fix: fix types
chmjkb ec176dd
wip
chmjkb 5ef2484
wip
chmjkb fd93f53
wip
chmjkb cad049b
wip
chmjkb 4f5faf9
feat: add fft lib to build.gradle
chmjkb a389ee1
feat: add array utils, finish android impl
chmjkb e02ce71
chore: move magic numbers to varibales
chmjkb 6858c5d
feat: add multiple inputs forward to BaseModel
chmjkb cdfb6a3
chore: remove unused import from ClassificationModel
chmjkb 28e0d5c
feat: add ScalarType to objc
chmjkb 467f5ee
fix
chmjkb 4d2864d
fix: fix scalartype enum
chmjkb 76830a8
chore: apply ScalarType to stt.mm
chmjkb a96682c
chore: remove redundant params from native stt spec
chmjkb 4488fd6
feat: add stft module to kotlin
chmjkb 175ec56
fix: align rn executorch modules with native spec
chmjkb f8c6d2d
fix: minor stt changes
chmjkb b39f57b
chore: remove unused comments
chmjkb 1e2fc64
feat: Make stt controller use a promise
chmjkb 72f9f8e
fix: remove prevTokens arg from native generate() in Android
chmjkb e443378
feat: add whisper decoding map
chmjkb f4423c6
feat: use the decoder in stt controller
chmjkb f33d0bb
Merge branch 'main' into @chmjkb/whisper
chmjkb 628ace1
yarn lock
chmjkb 64e6f19
fix: make eslint ignore decoders dir
chmjkb f120af9
chore: move back getTypeIdentifier to common/types to avoid exposing …
chmjkb 37aaa85
lint
chmjkb 4e4bff4
fix: update third party
chmjkb bfe7ae5
fix: remove fetcher import
chmjkb 1431131
chore: add JTransforms mention to license
chmjkb
55 changes: 55 additions & 0 deletions
android/src/main/java/com/swmansion/rnexecutorch/SpeechToText.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| package com.swmansion.rnexecutorch | ||
|
|
||
| import com.facebook.react.bridge.Promise | ||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.facebook.react.bridge.ReadableArray | ||
| import com.swmansion.rnexecutorch.models.speechToText.WhisperDecoder | ||
| import com.swmansion.rnexecutorch.models.speechToText.WhisperEncoder | ||
| import com.swmansion.rnexecutorch.models.speechToText.WhisperPreprocessor | ||
| import com.swmansion.rnexecutorch.utils.ArrayUtils | ||
| import com.swmansion.rnexecutorch.utils.ETError | ||
|
|
||
| class SpeechToText(reactContext: ReactApplicationContext) : | ||
| NativeSpeechToTextSpec(reactContext) { | ||
| private var whisperPreprocessor = WhisperPreprocessor(reactContext) | ||
| private var whisperEncoder = WhisperEncoder(reactContext) | ||
| private var whisperDecoder = WhisperDecoder(reactContext) | ||
| private var START_TOKEN = 50257 | ||
| private var EOS_TOKEN = 50256 | ||
|
|
||
| companion object { | ||
| const val NAME = "SpeechToText" | ||
| } | ||
|
|
||
| override fun loadModule(preprocessorSource: String, encoderSource: String, decoderSource: String, promise: Promise) { | ||
| try { | ||
| this.whisperPreprocessor.loadModel(preprocessorSource) | ||
| this.whisperEncoder.loadModel(encoderSource) | ||
| this.whisperDecoder.loadModel(decoderSource) | ||
| promise.resolve(0) | ||
| } catch (e: Exception) { | ||
| promise.reject(e.message!!, ETError.InvalidModelSource.toString()) | ||
| } | ||
| } | ||
|
|
||
| override fun generate(waveform: ReadableArray, promise: Promise) { | ||
| val logMel = this.whisperPreprocessor.runModel(waveform) | ||
| val encoding = this.whisperEncoder.runModel(logMel) | ||
| val generatedTokens = mutableListOf(this.START_TOKEN) | ||
| var lastToken = 0 | ||
| Thread { | ||
| while (lastToken != this.EOS_TOKEN) { | ||
| this.whisperDecoder.setGeneratedTokens(generatedTokens) | ||
| lastToken = this.whisperDecoder.runModel(encoding) | ||
| emitOnToken(lastToken.toDouble()) | ||
| generatedTokens.add(lastToken) | ||
| } | ||
| val generatedTokensReadableArray = ArrayUtils.createReadableArrayFromIntArray(generatedTokens.toIntArray()) | ||
| promise.resolve(generatedTokensReadableArray) | ||
| }.start() | ||
| } | ||
|
|
||
| override fun getName(): String { | ||
| return NAME | ||
| } | ||
| } | ||
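
The `generate` method above is a greedy autoregressive loop: the decoder is fed every token produced so far, its prediction is appended, and the loop ends once the EOS token appears. A minimal, framework-free sketch of that loop follows. The `nextToken` parameter is a hypothetical stand-in for `WhisperDecoder.runModel`, and the `maxTokens` cap is an assumption added for illustration (the loop in the diff has no such bound).

```kotlin
// Token IDs taken from the diff above; everything else here is illustrative.
const val START_TOKEN = 50257
const val EOS_TOKEN = 50256

/**
 * Greedy decode: pass all tokens generated so far to the decoder, append
 * its prediction, and stop on EOS or after maxTokens predictions.
 * `nextToken` is a hypothetical stand-in for WhisperDecoder.runModel.
 */
fun greedyDecode(
    maxTokens: Int = 128, // hypothetical safety cap, not part of the PR
    onToken: (Int) -> Unit = {},
    nextToken: (List<Int>) -> Int,
): List<Int> {
    val tokens = mutableListOf(START_TOKEN)
    var produced = 0
    while (produced < maxTokens) {
        val token = nextToken(tokens)
        onToken(token) // mirrors emitOnToken in the diff
        tokens.add(token)
        produced++
        if (token == EOS_TOKEN) break
    }
    return tokens
}

fun main() {
    // Fake decoder: emits 10, 11, 12 and then EOS.
    val script = listOf(10, 11, 12, EOS_TOKEN)
    val result = greedyDecode { generated -> script[generated.size - 1] }
    println(result) // [50257, 10, 11, 12, 50256]
}
```

As in the diff, the returned list includes both the start token and the final EOS token. The cap only matters if the decoder never predicts EOS.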
File renamed without changes.
33 changes: 33 additions & 0 deletions
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperDecoder.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| package com.swmansion.rnexecutorch.models.speechToText | ||
|
|
||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.swmansion.rnexecutorch.models.BaseModel | ||
| import org.pytorch.executorch.EValue | ||
| import org.pytorch.executorch.Tensor | ||
|
|
||
| class WhisperDecoder( | ||
| reactApplicationContext: ReactApplicationContext, | ||
| ) : BaseModel<EValue, Int>(reactApplicationContext) { | ||
| private var generatedTokens: MutableList<Int> = mutableListOf() | ||
|
|
||
| fun setGeneratedTokens(tokens: MutableList<Int>) { | ||
| this.generatedTokens = tokens | ||
| } | ||
|
|
||
| override fun runModel(input: EValue): Int { | ||
| val tokensEValue = EValue.from(Tensor.fromBlob(this.generatedTokens.toIntArray(), longArrayOf(1, generatedTokens.size.toLong()))) | ||
| return this.module | ||
| .forward(tokensEValue, input)[0] | ||
| .toTensor() | ||
| .dataAsLongArray[0] | ||
| .toInt() | ||
| } | ||
|
|
||
| override fun preprocess(input: EValue): EValue { | ||
| TODO("Not yet implemented") | ||
| } | ||
|
|
||
| override fun postprocess(output: Array<EValue>): Int { | ||
| TODO("Not yet implemented") | ||
| } | ||
| } |
26 changes: 26 additions & 0 deletions
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperEncoder.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| package com.swmansion.rnexecutorch.models.speechToText | ||
|
|
||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.swmansion.rnexecutorch.models.BaseModel | ||
| import org.pytorch.executorch.EValue | ||
| import org.pytorch.executorch.Tensor | ||
|
|
||
| class WhisperEncoder(reactApplicationContext: ReactApplicationContext) : | ||
| BaseModel<EValue, EValue>(reactApplicationContext) { | ||
| private val encoderInputShape = longArrayOf(1L, 80L, 3000L) | ||
|
|
||
| override fun runModel(input: EValue): EValue { | ||
| val inputEValue = this.preprocess(input) | ||
| val hiddenState = this.module.forward(inputEValue) | ||
| return hiddenState[0] | ||
| } | ||
|
|
||
| override fun preprocess(input: EValue): EValue { | ||
| val inputTensor = Tensor.fromBlob(input.toTensor().dataAsFloatArray, this.encoderInputShape) | ||
| return EValue.from(inputTensor) | ||
| } | ||
|
|
||
| override fun postprocess(output: Array<EValue>): EValue { | ||
| TODO("Not yet implemented") | ||
| } | ||
| } |
36 changes: 36 additions & 0 deletions
android/src/main/java/com/swmansion/rnexecutorch/models/speechToText/WhisperPreprocessor.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,36 @@ | ||
| package com.swmansion.rnexecutorch.models.speechToText | ||
|
|
||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.facebook.react.bridge.ReadableArray | ||
| import com.swmansion.rnexecutorch.models.BaseModel | ||
| import com.swmansion.rnexecutorch.utils.STFT | ||
| import org.pytorch.executorch.EValue | ||
| import org.pytorch.executorch.Tensor | ||
|
|
||
| class WhisperPreprocessor(reactApplicationContext: ReactApplicationContext) : | ||
| BaseModel<ReadableArray, EValue>(reactApplicationContext) { | ||
| private val fftSize = 512 | ||
| private val hopLength = 160 | ||
| private val stft = STFT(fftSize, hopLength) | ||
|
|
||
| override fun runModel(input: ReadableArray): EValue { | ||
| val size = input.size() | ||
| val inputFloatArray = FloatArray(size) | ||
| for (i in 0 until size) { | ||
| inputFloatArray[i] = input.getDouble(i).toFloat() | ||
| } | ||
| val stftResult = this.stft.fromWaveform(inputFloatArray) | ||
| val numStftFrames = stftResult.size / (this.fftSize / 2) | ||
| val preprocessorInputShape = longArrayOf(numStftFrames.toLong(), (this.fftSize / 2).toLong()) | ||
| val melSpectrogram = this.module.forward(EValue.from(Tensor.fromBlob(stftResult, preprocessorInputShape))) | ||
| return melSpectrogram[0] | ||
| } | ||
|
|
||
| override fun preprocess(input: ReadableArray): EValue { | ||
| TODO("Not yet implemented") | ||
| } | ||
|
|
||
| override fun postprocess(output: Array<EValue>): EValue { | ||
| TODO("Not yet implemented") | ||
| } | ||
| } |
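
The shape arithmetic in `runModel` above can be checked in isolation. The sketch below assumes, as the diff does, that the flat STFT buffer holds `fftSize / 2` values per frame. The function names `toFloatArray` and `stftShape` are hypothetical, and a plain `List<Double>` stands in for React Native's `ReadableArray`.

```kotlin
// Converts bridge doubles to floats; List<Double> is a stand-in for ReadableArray.
fun toFloatArray(input: List<Double>): FloatArray =
    FloatArray(input.size) { i -> input[i].toFloat() }

// Derives the [frames, bins] tensor shape from a flat STFT buffer length.
fun stftShape(stftSize: Int, fftSize: Int = 512): LongArray {
    val bins = fftSize / 2
    // Guard against a buffer that does not hold a whole number of frames.
    require(stftSize % bins == 0) { "STFT buffer is not a whole number of frames" }
    return longArrayOf((stftSize / bins).toLong(), bins.toLong())
}

fun main() {
    val samples = toFloatArray(listOf(0.0, 0.5, -0.5))
    println(samples.size)                            // 3
    println(stftShape(97 * 256).joinToString())      // 97, 256
}
```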
48 changes: 48 additions & 0 deletions
android/src/main/java/com/swmansion/rnexecutorch/utils/STFT.kt
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| package com.swmansion.rnexecutorch.utils | ||
|
|
||
| import java.util.Vector | ||
| import kotlin.math.cos | ||
| import kotlin.math.PI | ||
| import org.jtransforms.fft.FloatFFT_1D | ||
| import kotlin.math.sqrt | ||
|
|
||
| class STFT public constructor(var fftSize: Int = 512, var hopLength: Int = 160) { | ||
| private val fftModule = FloatFFT_1D(this.fftSize.toLong()) | ||
| private val magnitudeScale = 1.0 / this.fftSize | ||
| // https://www.mathworks.com/help/signal/ref/hann.html | ||
| private val hannWindow = FloatArray(this.fftSize) { i ->0.5f - 0.5f * cos(2f * PI.toFloat() * i / this.fftSize) } | ||
|
|
||
|
|
||
| fun fromWaveform(signal: FloatArray): FloatArray { | ||
| val numFftFrames = (signal.size - this.fftSize) / this.hopLength | ||
| // The output of FFT is always 2x smaller | ||
| val stft = FloatArray(numFftFrames * (this.fftSize / 2)) | ||
|
|
||
| var windowStartIdx = 0 | ||
| var outputIndex = 0 | ||
| // TODO: I don't think the subtraction at the end is needed | ||
| while (windowStartIdx + this.fftSize < signal.size - this.fftSize) { | ||
| val currentWindow = signal.copyOfRange(windowStartIdx, windowStartIdx + this.fftSize) | ||
| // Apply Hann window to the current slice | ||
| for (i in currentWindow.indices) currentWindow[i] *= this.hannWindow[i] | ||
|
|
||
| // Perform in-place FFT | ||
| this.fftModule.realForward(currentWindow) | ||
|
|
||
| stft[outputIndex++] = kotlin.math.abs(currentWindow[0]) | ||
| for (i in 1 until this.fftSize / 2 - 1) { | ||
| val real = currentWindow[2 * i] | ||
| val imag = currentWindow[2 * i + 1] | ||
|
|
||
| val currentMagnitude = (sqrt(real * real + imag * imag) * this.magnitudeScale).toFloat() | ||
| // FIXME: we don't need that, but if we remove this we have to get rid of | ||
| // reversing this operation in the preprocessing part | ||
| stft[outputIndex++] = 20 * kotlin.math.log10(currentMagnitude) | ||
| } | ||
| // Nyquist frequency | ||
| stft[outputIndex++] = kotlin.math.abs(currentWindow[1]) | ||
| windowStartIdx += this.hopLength | ||
| } | ||
| return stft | ||
| } | ||
| } |
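
The framing step in `fromWaveform` can be separated from the FFT itself. The sketch below reuses the Hann window formula from the diff and counts frames under the assumption that every frame must fit entirely inside the signal, with no padding. That convention is an assumption made here for illustration and may not be what the PR intends. The helper names `hannWindow`, `frameCount` and `windowedFrames` are hypothetical.

```kotlin
import kotlin.math.PI
import kotlin.math.cos

// Hann window, same formula as in the diff above.
fun hannWindow(size: Int): FloatArray =
    FloatArray(size) { i -> 0.5f - 0.5f * cos(2f * PI.toFloat() * i / size) }

// Number of complete frames, assuming no padding at either end.
fun frameCount(signalSize: Int, fftSize: Int, hopLength: Int): Int =
    if (signalSize < fftSize) 0 else (signalSize - fftSize) / hopLength + 1

// Splits a signal into overlapping frames and applies the window to each one.
fun windowedFrames(
    signal: FloatArray,
    fftSize: Int = 512,
    hopLength: Int = 160,
): List<FloatArray> {
    val window = hannWindow(fftSize)
    return List(frameCount(signal.size, fftSize, hopLength)) { frame ->
        val start = frame * hopLength
        FloatArray(fftSize) { i -> signal[start + i] * window[i] }
    }
}

fun main() {
    val oneSecond = FloatArray(16_000) { 1f } // assumes 16 kHz mono audio
    println(windowedFrames(oneSecond).size)   // 97
}
```

Each windowed frame could then be handed to an in-place real FFT such as the `FloatFFT_1D.realForward` call used in the diff.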
Binary file modified
BIN
-6.58 KB
(100%)
ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/ExecutorchLib
Binary file not shown.
Binary file modified
BIN
+1 Byte
(100%)
ios/ExecutorchLib.xcframework/ios-arm64-simulator/ExecutorchLib.framework/Info.plist
Binary file not shown.
I think we should move those inside the Whisper model, as they can vary from tokenizer to tokenizer. Alternatively, we could move them to a separate tokenizer-based config file, but I like the first option better right now.
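
One way the first option could look is sketched below, assuming "those" refers to the `START_TOKEN` and `EOS_TOKEN` fields in `SpeechToText.kt`. The names `TokenizerConfig`, `WHISPER_TOKENIZER` and `SpeechModel` are hypothetical and do not appear in the PR. Only the two token values come from the diff.

```kotlin
// Hypothetical per-tokenizer configuration; the names are illustrative.
data class TokenizerConfig(val startToken: Int, val eosToken: Int)

// The values currently hard-coded in SpeechToText.kt.
val WHISPER_TOKENIZER = TokenizerConfig(startToken = 50257, eosToken = 50256)

// A model wrapper that owns its tokenizer config, so the module class
// no longer needs to know which token IDs a given tokenizer uses.
class SpeechModel(val tokenizer: TokenizerConfig) {
    fun initialTokens(): MutableList<Int> = mutableListOf(tokenizer.startToken)

    fun isFinished(lastToken: Int): Boolean = lastToken == tokenizer.eosToken
}

fun main() {
    val model = SpeechModel(WHISPER_TOKENIZER)
    println(model.initialTokens())   // [50257]
    println(model.isFinished(50256)) // true
}
```

Supporting a different tokenizer would then mean supplying another `TokenizerConfig` rather than editing the module.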