diff --git a/examples/Cargo.lock b/examples/Cargo.lock
index 21ec559f1..94d6b303f 100644
--- a/examples/Cargo.lock
+++ b/examples/Cargo.lock
@@ -2108,7 +2108,7 @@ dependencies = [
 
 [[package]]
 name = "libwebrtc"
-version = "0.3.2"
+version = "0.3.7"
 dependencies = [
  "cxx",
  "jni",
@@ -2122,7 +2122,6 @@ dependencies = [
  "serde_json",
  "thiserror",
  "tokio",
- "tokio-stream",
  "wasm-bindgen",
  "wasm-bindgen-futures",
  "web-sys",
@@ -2152,7 +2151,7 @@ checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
 
 [[package]]
 name = "livekit"
-version = "0.3.2"
+version = "0.6.0"
 dependencies = [
  "futures-util",
  "lazy_static",
@@ -2171,7 +2170,7 @@ dependencies = [
 
 [[package]]
 name = "livekit-api"
-version = "0.3.2"
+version = "0.4.0"
 dependencies = [
  "async-tungstenite",
  "base64",
@@ -2196,7 +2195,7 @@ dependencies = [
 
 [[package]]
 name = "livekit-protocol"
-version = "0.3.2"
+version = "0.3.5"
 dependencies = [
  "futures-util",
  "livekit-runtime",
@@ -2215,6 +2214,7 @@ name = "livekit-runtime"
 version = "0.3.0"
 dependencies = [
  "tokio",
+ "tokio-stream",
 ]
 
 [[package]]
@@ -4455,7 +4455,7 @@ checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
 
 [[package]]
 name = "webrtc-sys"
-version = "0.3.2"
+version = "0.3.5"
 dependencies = [
  "cc",
  "cxx",
@@ -4467,7 +4467,7 @@ dependencies = [
 
 [[package]]
 name = "webrtc-sys-build"
-version = "0.3.2"
+version = "0.3.5"
 dependencies = [
  "fs2",
  "regex",
diff --git a/examples/wgpu_room/src/sine_track.rs b/examples/wgpu_room/src/sine_track.rs
index becaf8217..1825bb8a7 100644
--- a/examples/wgpu_room/src/sine_track.rs
+++ b/examples/wgpu_room/src/sine_track.rs
@@ -17,12 +17,7 @@ pub struct SineParameters {
 
 impl Default for SineParameters {
     fn default() -> Self {
-        Self {
-            sample_rate: 48000,
-            freq: 440.0,
-            amplitude: 1.0,
-            num_channels: 2,
-        }
+        Self { sample_rate: 48000, freq: 440.0, amplitude: 1.0, num_channels: 2 }
     }
 }
 
@@ -46,7 +41,7 @@ impl SineTrack {
                 AudioSourceOptions::default(),
                 params.sample_rate,
                 params.num_channels,
-                None,
+                1000,
             ),
             params,
             room,
@@ -65,28 +60,18 @@ impl SineTrack {
             RtcAudioSource::Native(self.rtc_source.clone()),
         );
 
-        let task = tokio::spawn(Self::track_task(
-            close_rx,
-            self.rtc_source.clone(),
-            self.params.clone(),
-        ));
+        let task =
+            tokio::spawn(Self::track_task(close_rx, self.rtc_source.clone(), self.params.clone()));
 
         self.room
             .local_participant()
             .publish_track(
                 LocalTrack::Audio(track.clone()),
-                TrackPublishOptions {
-                    source: TrackSource::Microphone,
-                    ..Default::default()
-                },
+                TrackPublishOptions { source: TrackSource::Microphone, ..Default::default() },
             )
             .await?;
 
-        let handle = TrackHandle {
-            close_tx,
-            track,
-            task,
-        };
+        let handle = TrackHandle { close_tx, track, task };
 
         self.handle = Some(handle);
         Ok(())
@@ -96,10 +81,7 @@ impl SineTrack {
         if let Some(handle) = self.handle.take() {
             handle.close_tx.send(()).ok();
             handle.task.await.ok();
-            self.room
-                .local_participant()
-                .unpublish_track(&handle.track.sid())
-                .await?;
+            self.room.local_participant().unpublish_track(&handle.track.sid()).await?;
         }
 
         Ok(())
diff --git a/libwebrtc/src/audio_source.rs b/libwebrtc/src/audio_source.rs
index b430f4fb0..1df5c4878 100644
--- a/libwebrtc/src/audio_source.rs
+++ b/libwebrtc/src/audio_source.rs
@@ -63,18 +63,22 @@ pub mod native {
             options: AudioSourceOptions,
             sample_rate: u32,
             num_channels: u32,
-            enable_queue: Option<bool>,
+            queue_size_ms: u32,
         ) -> NativeAudioSource {
             Self {
                 handle: imp_as::NativeAudioSource::new(
                     options,
                     sample_rate,
                     num_channels,
-                    enable_queue,
+                    queue_size_ms,
                 ),
             }
         }
 
+        pub fn clear_buffer(&self) {
+            self.handle.clear_buffer()
+        }
+
         pub async fn capture_frame(&self, frame: &AudioFrame<'_>) -> Result<(), RtcError> {
             self.handle.capture_frame(frame).await
         }
@@ -94,9 +98,5 @@ pub mod native {
         pub fn num_channels(&self) -> u32 {
             self.handle.num_channels()
         }
-
-        pub fn enable_queue(&self) -> bool {
-            self.handle.enable_queue()
-        }
     }
 }
diff --git a/libwebrtc/src/native/audio_source.rs b/libwebrtc/src/native/audio_source.rs
index 273853b10..c7652c676 100644
--- a/libwebrtc/src/native/audio_source.rs
+++ b/libwebrtc/src/native/audio_source.rs
@@ -12,39 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::{sync::Arc, time::Duration};
-
 use cxx::SharedPtr;
-use livekit_runtime::interval;
-use tokio::sync::{
-    mpsc::{self, error::TryRecvError},
-    Mutex as AsyncMutex,
-};
+use tokio::sync::oneshot;
 use webrtc_sys::audio_track as sys_at;
 
 use crate::{audio_frame::AudioFrame, audio_source::AudioSourceOptions, RtcError, RtcErrorType};
 
-const BUFFER_SIZE_MS: usize = 50;
-
 #[derive(Clone)]
 pub struct NativeAudioSource {
     sys_handle: SharedPtr<sys_at::ffi::AudioTrackSource>,
-    inner: Arc<AsyncMutex<AudioSourceInner>>,
     sample_rate: u32,
     num_channels: u32,
-    samples_10ms: usize,
-    // whether to queue audio frames or send them immediately
-    // defaults to true
-    enable_queue: bool,
-    po_tx: mpsc::Sender<Vec<i16>>,
-}
-
-struct AudioSourceInner {
-    buf: Box<[i16]>,
-
-    // Amount of data from the previous frame that hasn't been sent to the libwebrtc source
-    // (because it requires 10ms of data)
-    len: usize,
+    queue_size_samples: u32,
 }
 
 impl NativeAudioSource {
     pub fn new(
@@ -52,66 +31,24 @@ impl NativeAudioSource {
         options: AudioSourceOptions,
         sample_rate: u32,
         num_channels: u32,
-        enable_queue: Option<bool>,
+        queue_size_ms: u32,
     ) -> NativeAudioSource {
-        let samples_10ms = (sample_rate / 100 * num_channels) as usize;
-        let (po_tx, mut po_rx) = mpsc::channel(BUFFER_SIZE_MS / 10);
-
-        let source = Self {
-            sys_handle: sys_at::ffi::new_audio_track_source(options.into()),
-            inner: Arc::new(AsyncMutex::new(AudioSourceInner {
-                buf: vec![0; samples_10ms].into_boxed_slice(),
-                len: 0,
-            })),
-            sample_rate,
-            num_channels,
-            samples_10ms,
-            enable_queue: enable_queue.unwrap_or(true),
-            po_tx,
-        };
-
-        livekit_runtime::spawn({
-            let source = source.clone();
-            async move {
-                let mut interval = interval(Duration::from_millis(10));
-                interval.set_missed_tick_behavior(livekit_runtime::MissedTickBehavior::Delay);
-                let blank_data = vec![0; samples_10ms];
-                let enable_queue = source.enable_queue;
-
-                loop {
-                    if enable_queue {
-                        interval.tick().await;
-                    }
-
-                    let frame = po_rx.try_recv();
-                    if let Err(TryRecvError::Disconnected) = frame {
-                        break;
-                    }
-
-                    if let Err(TryRecvError::Empty) = frame {
-                        if enable_queue {
-                            source.sys_handle.on_captured_frame(
-                                &blank_data,
-                                sample_rate,
-                                num_channels,
-                                blank_data.len() / num_channels as usize,
-                            );
-                        }
-                        continue;
-                    }
-
-                    let frame = frame.unwrap();
-                    source.sys_handle.on_captured_frame(
-                        &frame,
-                        sample_rate,
-                        num_channels,
-                        frame.len() / num_channels as usize,
-                    );
-                }
-            }
-        });
-
-        source
+        assert!(queue_size_ms % 10 == 0, "queue_size_ms must be a multiple of 10");
+
+        let sys_handle = sys_at::ffi::new_audio_track_source(
+            options.into(),
+            sample_rate.try_into().unwrap(),
+            num_channels.try_into().unwrap(),
+            queue_size_ms.try_into().unwrap(),
+        );
+
+        let queue_size_samples = (queue_size_ms * sample_rate / 1000) * num_channels;
+        Self { sys_handle, sample_rate, num_channels, queue_size_samples }
     }
 
     pub fn sys_handle(&self) -> SharedPtr<sys_at::ffi::AudioTrackSource> {
@@ -134,8 +71,8 @@ impl NativeAudioSource {
         self.num_channels
     }
 
-    pub fn enable_queue(&self) -> bool {
-        self.enable_queue
+    pub fn clear_buffer(&self) {
+        self.sys_handle.clear_buffer();
     }
 
     pub async fn capture_frame(&self, frame: &AudioFrame<'_>) -> Result<(), RtcError> {
@@ -146,38 +83,36 @@ impl NativeAudioSource {
             });
         }
 
-        let mut inner = self.inner.lock().await;
-        let mut samples = 0;
-
-        // split frames into 10ms chunks
-        loop {
-            let remaining_samples = frame.data.len() - samples;
-            if remaining_samples == 0 {
-                break;
-            }
-
-            if (inner.len != 0 && remaining_samples > 0) || remaining_samples < self.samples_10ms {
-                let missing_len = self.samples_10ms - inner.len;
-                let to_add = missing_len.min(remaining_samples);
-                let start = inner.len;
-                inner.buf[start..start + to_add]
-                    .copy_from_slice(&frame.data[samples..samples + to_add]);
-                inner.len += to_add;
-                samples += to_add;
-
-                if inner.len == self.samples_10ms {
-                    let data = inner.buf.clone().to_vec();
-                    let _ = self.po_tx.send(data).await;
-                    inner.len = 0;
-                }
-                continue;
-            }
-
-            if remaining_samples >= self.samples_10ms {
-                // TODO(theomonnom): avoid copying
-                let data = frame.data[samples..samples + self.samples_10ms].to_vec();
-                let _ = self.po_tx.send(data).await;
-                samples += self.samples_10ms;
-            }
+        extern "C" fn lk_audio_source_complete(userdata: *const sys_at::SourceContext) {
+            // Reclaim the oneshot sender leaked via Box::into_raw below
+            let tx = unsafe { Box::from_raw(userdata as *mut oneshot::Sender<()>) };
+            let _ = tx.send(());
+        }
+
+        // split the frame into chunks of at most queue_size_samples
+        for chunk in frame.data.chunks(self.queue_size_samples as usize) {
+            let nb_frames = chunk.len() / self.num_channels as usize;
+            let (tx, rx) = oneshot::channel::<()>();
+            let ctx = Box::new(tx);
+            let ctx_ptr = Box::into_raw(ctx) as *const sys_at::SourceContext;
+
+            unsafe {
+                if !self.sys_handle.capture_frame(
+                    chunk,
+                    self.sample_rate,
+                    self.num_channels,
+                    nb_frames,
+                    ctx_ptr,
+                    sys_at::CompleteCallback(lk_audio_source_complete),
+                ) {
+                    return Err(RtcError {
+                        error_type: RtcErrorType::InvalidState,
+                        message: "failed to capture frame".to_owned(),
+                    });
+                }
+            }
+
+            // wait until the native source has drained enough to accept the next chunk
+            let _ = rx.await;
         }
 
         Ok(())
    }
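The 10 ms pacing loop that used to live on the Rust side is gone: pacing now happens on a native task queue (see webrtc-sys/src/audio_track.cpp below), and `capture_frame` awaits a completion callback that fires once the native queue has drained below its notify threshold. A minimal caller-side sketch — assuming the `Cow`-based `AudioFrame` layout from libwebrtc/src/audio_frame.rs; the helper names are illustrative:

    use std::borrow::Cow;

    use libwebrtc::audio_frame::AudioFrame;
    use libwebrtc::audio_source::{native::NativeAudioSource, AudioSourceOptions};

    // A 48 kHz stereo source backed by a 1000 ms native queue, as in sine_track.rs.
    fn new_source() -> NativeAudioSource {
        NativeAudioSource::new(AudioSourceOptions::default(), 48000, 2, 1000)
    }

    // capture_frame() returns quickly while the queue fills; once the buffer is
    // past the notify threshold it awaits the C++ completion callback, so the
    // caller is naturally throttled to real time.
    async fn push_10ms_of_silence(source: &NativeAudioSource) -> Result<(), libwebrtc::RtcError> {
        const SAMPLES_PER_CHANNEL: usize = 480; // 10 ms at 48 kHz
        let data = vec![0i16; SAMPLES_PER_CHANNEL * 2]; // interleaved stereo
        source
            .capture_frame(&AudioFrame {
                data: Cow::Owned(data),
                sample_rate: 48000,
                num_channels: 2,
                samples_per_channel: SAMPLES_PER_CHANNEL as u32,
            })
            .await
    }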
diff --git a/livekit-ffi/protocol/audio_frame.proto b/livekit-ffi/protocol/audio_frame.proto
index abe304ff5..8307d265d 100644
--- a/livekit-ffi/protocol/audio_frame.proto
+++ b/livekit-ffi/protocol/audio_frame.proto
@@ -46,7 +46,7 @@ message NewAudioSourceRequest {
   optional AudioSourceOptions options = 2;
   uint32 sample_rate = 3;
   uint32 num_channels = 4;
-  optional bool enable_queue = 5;
+  uint32 queue_size_ms = 5;
 }
 message NewAudioSourceResponse { OwnedAudioSource source = 1; }
 
@@ -64,6 +64,11 @@ message CaptureAudioFrameCallback {
   optional string error = 2;
 }
 
+message ClearAudioBufferRequest {
+  uint64 source_handle = 1;
+}
+message ClearAudioBufferResponse {}
+
 // Create a new AudioResampler
 message NewAudioResamplerRequest {}
 message NewAudioResamplerResponse {
diff --git a/livekit-ffi/protocol/ffi.proto b/livekit-ffi/protocol/ffi.proto
index ba87d8417..173d36286 100644
--- a/livekit-ffi/protocol/ffi.proto
+++ b/livekit-ffi/protocol/ffi.proto
@@ -88,10 +88,11 @@ message FfiRequest {
     NewAudioStreamRequest new_audio_stream = 25;
     NewAudioSourceRequest new_audio_source = 26;
     CaptureAudioFrameRequest capture_audio_frame = 27;
-    NewAudioResamplerRequest new_audio_resampler = 28;
-    RemixAndResampleRequest remix_and_resample = 29;
-    E2eeRequest e2ee = 30;
-    AudioStreamFromParticipantRequest audio_stream_from_participant = 31;
+    ClearAudioBufferRequest clear_audio_buffer = 28;
+    NewAudioResamplerRequest new_audio_resampler = 29;
+    RemixAndResampleRequest remix_and_resample = 30;
+    E2eeRequest e2ee = 31;
+    AudioStreamFromParticipantRequest audio_stream_from_participant = 32;
   }
 }
 
@@ -132,10 +133,11 @@ message FfiResponse {
     NewAudioStreamResponse new_audio_stream = 25;
     NewAudioSourceResponse new_audio_source = 26;
     CaptureAudioFrameResponse capture_audio_frame = 27;
-    NewAudioResamplerResponse new_audio_resampler = 28;
-    RemixAndResampleResponse remix_and_resample = 29;
-    AudioStreamFromParticipantResponse audio_stream_from_participant = 30;
-    E2eeResponse e2ee = 31;
+    ClearAudioBufferResponse clear_audio_buffer = 28;
+    NewAudioResamplerResponse new_audio_resampler = 29;
+    RemixAndResampleResponse remix_and_resample = 30;
+    AudioStreamFromParticipantResponse audio_stream_from_participant = 31;
+    E2eeResponse e2ee = 32;
   }
 }
diff --git a/livekit-ffi/src/livekit.proto.rs b/livekit-ffi/src/livekit.proto.rs
index 18d3eb385..1cfce9572 100644
--- a/livekit-ffi/src/livekit.proto.rs
+++ b/livekit-ffi/src/livekit.proto.rs
@@ -1,5 +1,4 @@
 // @generated
-// This file is @generated by prost-build.
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct FrameCryptor {
@@ -2982,8 +2981,8 @@ pub struct NewAudioSourceRequest {
     pub sample_rate: u32,
     #[prost(uint32, tag="4")]
     pub num_channels: u32,
-    #[prost(bool, optional, tag="5")]
-    pub enable_queue: ::core::option::Option<bool>,
+    #[prost(uint32, tag="5")]
+    pub queue_size_ms: u32,
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -3015,6 +3014,16 @@ pub struct CaptureAudioFrameCallback {
     #[prost(string, optional, tag="2")]
     pub error: ::core::option::Option<::prost::alloc::string::String>,
 }
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ClearAudioBufferRequest {
+    #[prost(uint64, tag="1")]
+    pub source_handle: u64,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ClearAudioBufferResponse {
+}
 /// Create a new AudioResampler
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -3241,7 +3250,7 @@ impl AudioSourceType {
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct FfiRequest {
-    #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31")]
+    #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32")]
     pub message: ::core::option::Option<ffi_request::Message>,
 }
 /// Nested message and enum types in `FfiRequest`.
@@ -3306,12 +3315,14 @@ pub mod ffi_request {
         #[prost(message, tag="27")]
         CaptureAudioFrame(super::CaptureAudioFrameRequest),
         #[prost(message, tag="28")]
-        NewAudioResampler(super::NewAudioResamplerRequest),
+        ClearAudioBuffer(super::ClearAudioBufferRequest),
         #[prost(message, tag="29")]
-        RemixAndResample(super::RemixAndResampleRequest),
+        NewAudioResampler(super::NewAudioResamplerRequest),
         #[prost(message, tag="30")]
-        E2ee(super::E2eeRequest),
+        RemixAndResample(super::RemixAndResampleRequest),
         #[prost(message, tag="31")]
+        E2ee(super::E2eeRequest),
+        #[prost(message, tag="32")]
         AudioStreamFromParticipant(super::AudioStreamFromParticipantRequest),
     }
 }
@@ -3319,7 +3330,7 @@ pub mod ffi_request {
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct FfiResponse {
-    #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31")]
+    #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32")]
     pub message: ::core::option::Option<ffi_response::Message>,
 }
 /// Nested message and enum types in `FfiResponse`.
@@ -3384,12 +3395,14 @@ pub mod ffi_response {
         #[prost(message, tag="27")]
         CaptureAudioFrame(super::CaptureAudioFrameResponse),
         #[prost(message, tag="28")]
-        NewAudioResampler(super::NewAudioResamplerResponse),
+        ClearAudioBuffer(super::ClearAudioBufferResponse),
         #[prost(message, tag="29")]
-        RemixAndResample(super::RemixAndResampleResponse),
+        NewAudioResampler(super::NewAudioResamplerResponse),
         #[prost(message, tag="30")]
-        AudioStreamFromParticipant(super::AudioStreamFromParticipantResponse),
+        RemixAndResample(super::RemixAndResampleResponse),
         #[prost(message, tag="31")]
+        AudioStreamFromParticipant(super::AudioStreamFromParticipantResponse),
+        #[prost(message, tag="32")]
         E2ee(super::E2eeResponse),
     }
 }
diff --git a/livekit-ffi/src/server/audio_source.rs b/livekit-ffi/src/server/audio_source.rs
index ba0f21f7a..e07acee8e 100644
--- a/livekit-ffi/src/server/audio_source.rs
+++ b/livekit-ffi/src/server/audio_source.rs
@@ -43,7 +43,7 @@ impl FfiAudioSource {
                     new_source.options.map(Into::into).unwrap_or_default(),
                     new_source.sample_rate,
                     new_source.num_channels,
-                    new_source.enable_queue,
+                    new_source.queue_size_ms,
                 );
                 RtcAudioSource::Native(audio_source)
             }
@@ -62,6 +62,14 @@ impl FfiAudioSource {
         })
     }
 
+    pub fn clear_buffer(&self) {
+        match self.source {
+            #[cfg(not(target_arch = "wasm32"))]
+            RtcAudioSource::Native(ref source) => source.clear_buffer(),
+            _ => {}
+        }
+    }
+
     pub fn capture_frame(
         &self,
         server: &'static server::FfiServer,
diff --git a/livekit-ffi/src/server/requests.rs b/livekit-ffi/src/server/requests.rs
index 705e86a1d..89e395769 100644
--- a/livekit-ffi/src/server/requests.rs
+++ b/livekit-ffi/src/server/requests.rs
@@ -426,6 +426,16 @@ fn on_capture_audio_frame(
     source.capture_frame(server, push)
 }
 
+// Clear the internal audio buffer (cancel all pending frames from being played)
+fn on_clear_audio_buffer(
+    server: &'static FfiServer,
+    clear: proto::ClearAudioBufferRequest,
+) -> FfiResult<proto::ClearAudioBufferResponse> {
+    let source = server.retrieve_handle::<audio_source::FfiAudioSource>(clear.source_handle)?;
+    source.clear_buffer();
+    Ok(proto::ClearAudioBufferResponse {})
+}
+
 /// Create a new audio resampler
 fn new_audio_resampler(
     server: &'static FfiServer,
@@ -744,6 +754,9 @@ pub fn handle_request(
         proto::ffi_request::Message::CaptureAudioFrame(push) => {
             proto::ffi_response::Message::CaptureAudioFrame(on_capture_audio_frame(server, push)?)
         }
+        proto::ffi_request::Message::ClearAudioBuffer(clear) => {
+            proto::ffi_response::Message::ClearAudioBuffer(on_clear_audio_buffer(server, clear)?)
+        }
         proto::ffi_request::Message::NewAudioResampler(new_res) => {
             proto::ffi_response::Message::NewAudioResampler(new_audio_resampler(server, new_res)?)
         }
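On the FFI surface, clearing the queue is an ordinary request/response pair. A sketch of the message a foreign client would send to reach `on_clear_audio_buffer` above — assuming the generated `proto` module is used the same way as for the other requests; `source_handle` must identify an already-created audio source:

    use livekit_ffi::proto;

    // Build the FfiRequest that handle_request() dispatches to
    // on_clear_audio_buffer(). The response carries no payload.
    fn clear_buffer_request(source_handle: u64) -> proto::FfiRequest {
        proto::FfiRequest {
            message: Some(proto::ffi_request::Message::ClearAudioBuffer(
                proto::ClearAudioBufferRequest { source_handle },
            )),
        }
    }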
diff --git a/webrtc-sys/build.rs b/webrtc-sys/build.rs
index beabe57c7..5e87b0c9c 100644
--- a/webrtc-sys/build.rs
+++ b/webrtc-sys/build.rs
@@ -72,6 +72,7 @@ fn main() {
         "src/audio_device.cpp",
         "src/audio_resampler.cpp",
         "src/frame_cryptor.cpp",
+        "src/global_task_queue.cpp",
     ]);
 
     let webrtc_dir = webrtc_sys_build::webrtc_dir();
diff --git a/webrtc-sys/include/livekit/audio_track.h b/webrtc-sys/include/livekit/audio_track.h
index 3e074785b..029158fd5 100644
--- a/webrtc-sys/include/livekit/audio_track.h
+++ b/webrtc-sys/include/livekit/audio_track.h
@@ -20,18 +20,25 @@
 
 #include "api/audio/audio_frame.h"
 #include "api/audio_options.h"
+#include "api/task_queue/task_queue_factory.h"
 #include "common_audio/resampler/include/push_resampler.h"
 #include "livekit/helper.h"
 #include "livekit/media_stream_track.h"
 #include "livekit/webrtc.h"
 #include "pc/local_audio_source.h"
 #include "rtc_base/synchronization/mutex.h"
+#include "rtc_base/task_queue.h"
+#include "rtc_base/task_utils/repeating_task.h"
+#include "rtc_base/thread_annotations.h"
 #include "rust/cxx.h"
 
 namespace livekit {
 class AudioTrack;
 class NativeAudioSink;
 class AudioTrackSource;
+class SourceContext;
+
+using CompleteCallback = void (*)(const livekit::SourceContext*);
 }  // namespace livekit
 
 #include "webrtc-sys/src/audio_track.rs.h"
@@ -91,7 +98,11 @@ std::shared_ptr<NativeAudioSink> new_native_audio_sink(
 class AudioTrackSource {
  class InternalSource : public webrtc::LocalAudioSource {
   public:
-   InternalSource(const cricket::AudioOptions& options);
+   InternalSource(const cricket::AudioOptions& options,
+                  int sample_rate,
+                  int num_channels,
+                  int buffer_size_ms,
+                  webrtc::TaskQueueFactory* task_queue_factory);
 
    SourceState state() const override;
    bool remote() const override;
@@ -103,30 +114,53 @@ class AudioTrackSource {
 
    void set_options(const cricket::AudioOptions& options);
 
-   // AudioFrame should always contain 10 ms worth of data (see index.md of
-   // acm)
-   void on_captured_frame(rust::Slice<const int16_t> audio_data,
-                          uint32_t sample_rate,
-                          uint32_t number_of_channels,
-                          size_t number_of_frames);
+   bool capture_frame(rust::Slice<const int16_t> audio_data,
+                      uint32_t sample_rate,
+                      uint32_t number_of_channels,
+                      size_t number_of_frames,
+                      const SourceContext* ctx,
+                      void (*on_complete)(const SourceContext*));
+
+   void clear_buffer();
 
   private:
    mutable webrtc::Mutex mutex_;
-   std::vector<webrtc::AudioTrackSinkInterface*> sinks_;
+   std::unique_ptr<rtc::TaskQueue> audio_queue_;
+   webrtc::RepeatingTaskHandle audio_task_;
+
+   std::vector<webrtc::AudioTrackSinkInterface*> sinks_ RTC_GUARDED_BY(mutex_);
+   std::vector<int16_t> buffer_ RTC_GUARDED_BY(mutex_);
+
+   const SourceContext* capture_userdata_ RTC_GUARDED_BY(mutex_);
+   void (*on_complete_)(const SourceContext*) RTC_GUARDED_BY(mutex_);
+
+   int sample_rate_;
+   int num_channels_;
+   int queue_size_samples_ = 0;
+   int notify_threshold_samples_ = 0;
+
   cricket::AudioOptions options_{};
 };

 public:
-  AudioTrackSource(AudioSourceOptions options);
+  AudioTrackSource(AudioSourceOptions options,
+                   int sample_rate,
+                   int num_channels,
+                   int queue_size_ms,
+                   webrtc::TaskQueueFactory* task_queue_factory);
 
   AudioSourceOptions audio_options() const;
   void set_audio_options(const AudioSourceOptions& options) const;
 
-  void on_captured_frame(rust::Slice<const int16_t> audio_data,
-                         uint32_t sample_rate,
-                         uint32_t number_of_channels,
-                         size_t number_of_frames) const;
+  bool capture_frame(rust::Slice<const int16_t> audio_data,
+                     uint32_t sample_rate,
+                     uint32_t number_of_channels,
+                     size_t number_of_frames,
+                     const SourceContext* ctx,
+                     CompleteCallback on_complete) const;
+
+  void clear_buffer() const;
 
   rtc::scoped_refptr<webrtc::AudioSourceInterface> get() const;
 
@@ -135,7 +169,10 @@ class AudioTrackSource {
 };
 
 std::shared_ptr<AudioTrackSource> new_audio_track_source(
-    AudioSourceOptions options);
+    AudioSourceOptions options,
+    int sample_rate,
+    int num_channels,
+    int queue_size_ms);
 
 static std::shared_ptr<MediaStreamTrack> audio_to_media(
     std::shared_ptr<AudioTrack> track) {
@@ -151,4 +188,8 @@ static std::shared_ptr<AudioTrack> _shared_audio_track() {
   return nullptr;  // Ignore
 }
 
+static std::shared_ptr<AudioTrackSource> _shared_audio_track_source() {
+  return nullptr;  // Ignore
+}
+
 }  // namespace livekit
diff --git a/webrtc-sys/include/livekit/global_task_queue.h b/webrtc-sys/include/livekit/global_task_queue.h
new file mode 100644
index 000000000..1cf9a7ac1
--- /dev/null
+++ b/webrtc-sys/include/livekit/global_task_queue.h
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "api/task_queue/task_queue_factory.h"
+
+namespace livekit {
+
+webrtc::TaskQueueFactory* GetGlobalTaskQueueFactory();
+
+}  // namespace livekit
diff --git a/webrtc-sys/include/livekit/peer_connection_factory.h b/webrtc-sys/include/livekit/peer_connection_factory.h
index 7d22c4e81..ae49842ba 100644
--- a/webrtc-sys/include/livekit/peer_connection_factory.h
+++ b/webrtc-sys/include/livekit/peer_connection_factory.h
@@ -18,6 +18,7 @@
 
 #include "api/peer_connection_interface.h"
 #include "api/scoped_refptr.h"
+#include "api/task_queue/task_queue_factory.h"
 #include "livekit/audio_device.h"
 #include "media_stream.h"
 #include "rtp_parameters.h"
@@ -65,6 +66,7 @@ class PeerConnectionFactory {
   std::shared_ptr<RtcRuntime> rtc_runtime_;
   rtc::scoped_refptr<AudioDevice> audio_device_;
   rtc::scoped_refptr<webrtc::PeerConnectionFactoryInterface> peer_factory_;
+  webrtc::TaskQueueFactory* task_queue_factory_;
 };
 
 std::shared_ptr<PeerConnectionFactory> create_peer_connection_factory();
diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp
index 165f8e9c7..980c22efb 100644
--- a/webrtc-sys/src/audio_track.cpp
+++ b/webrtc-sys/src/audio_track.cpp
@@ -18,12 +18,16 @@
 
 #include <algorithm>
 #include <memory>
+#include <vector>
 
 #include "api/audio_options.h"
 #include "api/media_stream_interface.h"
+#include "api/task_queue/task_queue_base.h"
 #include "audio/remix_resample.h"
 #include "common_audio/include/audio_util.h"
+#include "livekit/global_task_queue.h"
+#include "rtc_base/checks.h"
 #include "rtc_base/logging.h"
 #include "rtc_base/ref_counted_object.h"
 #include "rtc_base/synchronization/mutex.h"
@@ -123,7 +127,93 @@ std::shared_ptr<NativeAudioSink> new_native_audio_sink(
 }
 
 AudioTrackSource::InternalSource::InternalSource(
-    const cricket::AudioOptions& options) {}
+    const cricket::AudioOptions& options,
+    int sample_rate,
+    int num_channels,
+    int queue_size_ms,  // must be a multiple of 10ms
+    webrtc::TaskQueueFactory* task_queue_factory)
+    : sample_rate_(sample_rate),
+      num_channels_(num_channels),
+      capture_userdata_(nullptr),
+      on_complete_(nullptr) {
+  if (!queue_size_ms)
+    return;  // no audio queue
+
+  int samples10ms = sample_rate / 100 * num_channels;
+  queue_size_samples_ = queue_size_ms / 10 * samples10ms;
+  notify_threshold_samples_ = queue_size_samples_;  // TODO: this currently
+                                                    // doubles the queue size
+  buffer_.reserve(queue_size_samples_ + notify_threshold_samples_);
+
+  audio_queue_ =
+      std::make_unique<rtc::TaskQueue>(task_queue_factory->CreateTaskQueue(
+          "AudioSourceCapture", webrtc::TaskQueueFactory::Priority::NORMAL));
+
+  audio_task_ = webrtc::RepeatingTaskHandle::Start(
+      audio_queue_->Get(),
+      [this, samples10ms]() {
+        webrtc::MutexLock lock(&mutex_);
+
+        // forward one 10ms chunk to the sinks on every tick
+        if (buffer_.size() >= samples10ms) {
+          for (auto sink : sinks_)
+            sink->OnData(buffer_.data(), sizeof(int16_t) * 8, sample_rate_,
+                         num_channels_, samples10ms / num_channels_);
+
+          buffer_.erase(buffer_.begin(), buffer_.begin() + samples10ms);
+        }
+
+        if (on_complete_ && buffer_.size() <= notify_threshold_samples_) {
+          on_complete_(capture_userdata_);
+          on_complete_ = nullptr;
+          capture_userdata_ = nullptr;
+        }
+
+        return webrtc::TimeDelta::Millis(10);
+      },
+      webrtc::TaskQueueBase::DelayPrecision::kHigh);
+}
+
+bool AudioTrackSource::InternalSource::capture_frame(
+    rust::Slice<const int16_t> data,
+    uint32_t sample_rate,
+    uint32_t number_of_channels,
+    size_t number_of_frames,
+    const SourceContext* ctx,
+    void (*on_complete)(const SourceContext*)) {
+  webrtc::MutexLock lock(&mutex_);
+
+  if (queue_size_samples_) {
+    size_t available =
+        (queue_size_samples_ + notify_threshold_samples_) - buffer_.size();
+    if (available < data.size())
+      return false;
+
+    // only one in-flight capture is supported
+    if (on_complete_ || capture_userdata_)
+      return false;
+
+    buffer_.insert(buffer_.end(), data.begin(), data.end());
+
+    if (buffer_.size() <= notify_threshold_samples_) {
+      on_complete(ctx);  // complete directly
+    } else {
+      on_complete_ = on_complete;
+      capture_userdata_ = ctx;
+    }
+
+  } else {
+    // capture directly when the queue size is 0 (frame must be 10ms of data)
+    for (auto sink : sinks_)
+      sink->OnData(data.data(), sizeof(int16_t) * 8, sample_rate,
+                   number_of_channels, number_of_frames);
+  }
+
+  return true;
+}
+
+void AudioTrackSource::InternalSource::clear_buffer() {
+  webrtc::MutexLock lock(&mutex_);
+  buffer_.clear();
+}
 
 webrtc::MediaSourceInterface::SourceState
 AudioTrackSource::InternalSource::state() const {
@@ -157,22 +247,17 @@ void AudioTrackSource::InternalSource::RemoveSink(
   sinks_.erase(std::remove(sinks_.begin(), sinks_.end(), sink), sinks_.end());
 }
 
-void AudioTrackSource::InternalSource::on_captured_frame(
-    rust::Slice<const int16_t> data,
-    uint32_t sample_rate,
-    uint32_t number_of_channels,
-    size_t number_of_frames) {
-  webrtc::MutexLock lock(&mutex_);
-  for (auto sink : sinks_) {
-    sink->OnData(data.data(), 16, sample_rate, number_of_channels,
-                 number_of_frames);
-  }
-}
-
-AudioTrackSource::AudioTrackSource(AudioSourceOptions options) {
-  source_ =
-      rtc::make_ref_counted<InternalSource>(to_native_audio_options(options));
-}
+AudioTrackSource::AudioTrackSource(AudioSourceOptions options,
+                                   int sample_rate,
+                                   int num_channels,
+                                   int queue_size_ms,
+                                   webrtc::TaskQueueFactory* task_queue_factory)
+    : source_(rtc::make_ref_counted<InternalSource>(
+          to_native_audio_options(options),
+          sample_rate,
+          num_channels,
+          queue_size_ms,
+          task_queue_factory)) {}
 
 AudioSourceOptions AudioTrackSource::audio_options() const {
   return to_rust_audio_options(source_->options());
@@ -183,22 +268,34 @@ void AudioTrackSource::set_audio_options(
   source_->set_options(to_native_audio_options(options));
 }
 
-void AudioTrackSource::on_captured_frame(rust::Slice<const int16_t> audio_data,
-                                         uint32_t sample_rate,
-                                         uint32_t number_of_channels,
-                                         size_t number_of_frames) const {
-  source_->on_captured_frame(audio_data, sample_rate, number_of_channels,
-                             number_of_frames);
+bool AudioTrackSource::capture_frame(
+    rust::Slice<const int16_t> audio_data,
+    uint32_t sample_rate,
+    uint32_t number_of_channels,
+    size_t number_of_frames,
+    const SourceContext* ctx,
+    void (*on_complete)(const SourceContext*)) const {
+  return source_->capture_frame(audio_data, sample_rate, number_of_channels,
+                                number_of_frames, ctx, on_complete);
 }
 
-rtc::scoped_refptr<webrtc::AudioSourceInterface> AudioTrackSource::get()
-    const {
-  return source_;
+void AudioTrackSource::clear_buffer() const {
+  source_->clear_buffer();
 }
 
 std::shared_ptr<AudioTrackSource> new_audio_track_source(
-    AudioSourceOptions options) {
-  return std::make_shared<AudioTrackSource>(options);
+    AudioSourceOptions options,
+    int sample_rate,
+    int num_channels,
+    int queue_size_ms) {
+  return std::make_shared<AudioTrackSource>(options, sample_rate, num_channels,
+                                            queue_size_ms,
+                                            GetGlobalTaskQueueFactory());
+}
+
+rtc::scoped_refptr<webrtc::AudioSourceInterface> AudioTrackSource::get()
+    const {
+  return source_;
 }
 
 }  // namespace livekit
diff --git a/webrtc-sys/src/audio_track.rs b/webrtc-sys/src/audio_track.rs
index 26b9f8308..d24f87520 100644
--- a/webrtc-sys/src/audio_track.rs
+++ b/webrtc-sys/src/audio_track.rs
@@ -12,6 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use cxx::type_id;
+use cxx::ExternType;
+use std::any::Any;
 use std::sync::Arc;
 
 use crate::impl_thread_safety;
@@ -29,6 +32,7 @@ pub mod ffi {
         include!("livekit/media_stream_track.h");
 
         type MediaStreamTrack = crate::media_stream_track::ffi::MediaStreamTrack;
+        type CompleteCallback = crate::audio_track::CompleteCallback;
     }
 
     unsafe extern "C++" {
@@ -46,24 +50,35 @@ pub mod ffi {
             num_channels: i32,
         ) -> SharedPtr<NativeAudioSink>;
 
-        fn on_captured_frame(
+        unsafe fn capture_frame(
             self: &AudioTrackSource,
             data: &[i16],
             sample_rate: u32,
             nb_channels: u32,
             nb_frames: usize,
-        );
+            userdata: *const SourceContext,
+            on_complete: CompleteCallback,
+        ) -> bool;
+
+        fn clear_buffer(self: &AudioTrackSource);
 
         fn audio_options(self: &AudioTrackSource) -> AudioSourceOptions;
         fn set_audio_options(self: &AudioTrackSource, options: &AudioSourceOptions);
-        fn new_audio_track_source(options: AudioSourceOptions) -> SharedPtr<AudioTrackSource>;
+
+        fn new_audio_track_source(
+            options: AudioSourceOptions,
+            sample_rate: i32,
+            num_channels: i32,
+            queue_size_ms: i32,
+        ) -> SharedPtr<AudioTrackSource>;
 
         fn audio_to_media(track: SharedPtr<AudioTrack>) -> SharedPtr<MediaStreamTrack>;
         unsafe fn media_to_audio(track: SharedPtr<MediaStreamTrack>) -> SharedPtr<AudioTrack>;
         fn _shared_audio_track() -> SharedPtr<AudioTrack>;
+        fn _shared_audio_track_source() -> SharedPtr<AudioTrackSource>;
     }
 
     extern "Rust" {
         type AudioSinkWrapper;
+        type SourceContext;
 
         fn on_data(
             self: &AudioSinkWrapper,
@@ -79,6 +94,17 @@ impl_thread_safety!(ffi::AudioTrack, Send + Sync);
 impl_thread_safety!(ffi::NativeAudioSink, Send + Sync);
 impl_thread_safety!(ffi::AudioTrackSource, Send + Sync);
 
+#[repr(transparent)]
+pub struct SourceContext(pub Box<dyn Any + Send>);
+
+#[repr(transparent)]
+pub struct CompleteCallback(pub extern "C" fn(ctx: *const SourceContext));
+
+unsafe impl ExternType for CompleteCallback {
+    type Id = type_id!("livekit::CompleteCallback");
+    type Kind = cxx::kind::Trivial;
+}
+
 pub trait AudioSink: Send {
     fn on_data(&self, data: &[i16], sample_rate: i32, nb_channels: usize, nb_frames: usize);
 }
diff --git a/webrtc-sys/src/global_task_queue.cpp b/webrtc-sys/src/global_task_queue.cpp
new file mode 100644
index 000000000..0f0cfc2b3
--- /dev/null
+++ b/webrtc-sys/src/global_task_queue.cpp
@@ -0,0 +1,14 @@
+#include "livekit/global_task_queue.h"
+
+#include "api/task_queue/default_task_queue_factory.h"
+#include "api/task_queue/task_queue_factory.h"
+
+namespace livekit {
+
+webrtc::TaskQueueFactory* GetGlobalTaskQueueFactory() {
+  static std::unique_ptr<webrtc::TaskQueueFactory> global_task_queue_factory =
+      webrtc::CreateDefaultTaskQueueFactory();
+  return global_task_queue_factory.get();
+}
+
+}  // namespace livekit
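`SourceContext`/`CompleteCallback` is how the tokio oneshot sender crosses the cxx boundary: `capture_frame` leaks the sender with `Box::into_raw`, the C++ source stores the raw pointer, and the extern "C" callback reclaims it exactly once. The same ownership handoff in isolation, with the native call simulated — this sketch bypasses the real `SourceContext` wrapper:

    use tokio::sync::oneshot;

    // Reclaims the sender leaked below; sound because the native side invokes
    // the callback exactly once with the pointer produced by Box::into_raw.
    extern "C" fn on_complete(userdata: *mut oneshot::Sender<()>) {
        let tx = unsafe { Box::from_raw(userdata) };
        let _ = tx.send(());
    }

    #[tokio::main]
    async fn main() {
        let (tx, rx) = oneshot::channel::<()>();
        let raw = Box::into_raw(Box::new(tx));
        // ...in the real code, `raw` and the callback are handed to C++ here...
        on_complete(raw); // simulate the native side completing the capture
        rx.await.unwrap();
    }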
diff --git a/webrtc-sys/src/peer_connection_factory.cpp b/webrtc-sys/src/peer_connection_factory.cpp
index 28e400f63..bbc73bfdc 100644
--- a/webrtc-sys/src/peer_connection_factory.cpp
+++ b/webrtc-sys/src/peer_connection_factory.cpp
@@ -28,6 +28,7 @@
 #include "api/video_codecs/builtin_video_decoder_factory.h"
 #include "api/video_codecs/builtin_video_encoder_factory.h"
 #include "livekit/audio_device.h"
+#include "livekit/audio_track.h"
 #include "livekit/peer_connection.h"
 #include "livekit/rtc_error.h"
 #include "livekit/rtp_parameters.h"
@@ -83,6 +84,9 @@ PeerConnectionFactory::PeerConnectionFactory(
+  // read before `dependencies` is moved into CreateModularPeerConnectionFactory
+  task_queue_factory_ = dependencies.task_queue_factory.get();
+
   peer_factory_ =
       webrtc::CreateModularPeerConnectionFactory(std::move(dependencies));
 
   if (peer_factory_.get() == nullptr) {
     RTC_LOG_ERR(LS_ERROR) << "Failed to create PeerConnectionFactory";
     return;
@@ -126,6 +130,7 @@ std::shared_ptr<MediaStreamTrack> PeerConnectionFactory::create_audio_track(
       rtc_runtime_->get_or_create_media_stream_track(
           peer_factory_->CreateAudioTrack(label.c_str(), source->get().get())));
 }
+
 RtpCapabilities PeerConnectionFactory::rtp_sender_capabilities(
     MediaType type) const {
   return to_rust_rtp_capabilities(peer_factory_->GetRtpSenderCapabilities(
diff --git a/webrtc-sys/src/peer_connection_factory.rs b/webrtc-sys/src/peer_connection_factory.rs
index 51efc2f46..1ae912848 100644
--- a/webrtc-sys/src/peer_connection_factory.rs
+++ b/webrtc-sys/src/peer_connection_factory.rs
@@ -49,6 +49,7 @@ pub mod ffi {
         include!("livekit/jsep.h");
         include!("livekit/webrtc.h");
         include!("livekit/peer_connection.h");
+        include!("livekit/audio_track.h");
 
         type RtcConfiguration = crate::peer_connection::ffi::RtcConfiguration;
         type PeerConnectionState = crate::peer_connection::ffi::PeerConnectionState;
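Net effect: a publisher can run up to `queue_size_ms` ahead of real time, and `clear_buffer` drops whatever the native pacing task has not yet forwarded to the sinks, which is what enables interruption-style use cases. A short sketch against the high-level crate — the `livekit::webrtc` re-export path mirrors `sine_track.rs`, and the function names are illustrative:

    use livekit::webrtc::audio_source::{native::NativeAudioSource, AudioSourceOptions};

    // Pre-buffer up to one second of audio ahead of playback.
    fn new_prebuffered_source() -> NativeAudioSource {
        NativeAudioSource::new(AudioSourceOptions::default(), 48000, 2, 1000)
    }

    // Cancel pending audio: queued frames are dropped rather than played, and an
    // in-flight capture_frame() resolves on the next 10 ms tick of the native
    // task, since the buffer is then below the notify threshold.
    fn interrupt(source: &NativeAudioSource) {
        source.clear_buffer();
    }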