From 8f324d99a9819e3bd7707516a07cdf9bd2d38398 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Thu, 5 Sep 2024 13:59:17 -0700 Subject: [PATCH 1/7] wip --- Cargo.lock | 2 +- libwebrtc/src/native/audio_source.rs | 3 +- webrtc-sys/include/livekit/audio_track.h | 54 ++++++-- .../include/livekit/peer_connection_factory.h | 9 ++ webrtc-sys/src/audio_track.cpp | 119 ++++++++++++++---- webrtc-sys/src/audio_track.rs | 6 +- webrtc-sys/src/peer_connection_factory.cpp | 15 +++ webrtc-sys/src/peer_connection_factory.rs | 11 ++ 8 files changed, 177 insertions(+), 42 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cd647d60e..869a0da17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1591,7 +1591,7 @@ dependencies = [ [[package]] name = "livekit-ffi" -version = "0.8.2" +version = "0.8.3" dependencies = [ "console-subscriber", "dashmap", diff --git a/libwebrtc/src/native/audio_source.rs b/libwebrtc/src/native/audio_source.rs index 273853b10..e1ed81178 100644 --- a/libwebrtc/src/native/audio_source.rs +++ b/libwebrtc/src/native/audio_source.rs @@ -96,7 +96,8 @@ impl NativeAudioSource { num_channels, blank_data.len() / num_channels as usize, ); - } + + continue; } diff --git a/webrtc-sys/include/livekit/audio_track.h b/webrtc-sys/include/livekit/audio_track.h index 3e074785b..31ab38229 100644 --- a/webrtc-sys/include/livekit/audio_track.h +++ b/webrtc-sys/include/livekit/audio_track.h @@ -20,12 +20,15 @@ #include "api/audio/audio_frame.h" #include "api/audio_options.h" +#include "api/task_queue/task_queue_factory.h" #include "common_audio/resampler/include/push_resampler.h" #include "livekit/helper.h" #include "livekit/media_stream_track.h" #include "livekit/webrtc.h" #include "pc/local_audio_source.h" #include "rtc_base/synchronization/mutex.h" +#include "rtc_base/task_queue.h" +#include "rtc_base/task_utils/repeating_task.h" #include "rust/cxx.h" namespace livekit { @@ -91,7 +94,12 @@ std::shared_ptr new_native_audio_sink( class AudioTrackSource { class InternalSource : public webrtc::LocalAudioSource { public: - InternalSource(const cricket::AudioOptions& options); + InternalSource(const cricket::AudioOptions& options, + int sample_rate, + int num_channels, + int buffer_size_ms, + rust::Fn data_needed, + webrtc::TaskQueueFactory* task_queue_factory); SourceState state() const override; bool remote() const override; @@ -105,28 +113,48 @@ class AudioTrackSource { // AudioFrame should always contain 10 ms worth of data (see index.md of // acm) - void on_captured_frame(rust::Slice audio_data, - uint32_t sample_rate, - uint32_t number_of_channels, - size_t number_of_frames); + bool capture_frame(rust::Slice audio_data, + uint32_t sample_rate, + uint32_t number_of_channels, + size_t number_of_frames); + + void clear_buffer(); private: mutable webrtc::Mutex mutex_; + std::unique_ptr audio_queue_; + webrtc::RepeatingTaskHandle audio_task_; std::vector sinks_; + std::vector buffer_; + rust::Fn data_needed_; + + int sample_rate_; + int num_channels_; + int queue_size_samples_; + + bool data_requested_; + cricket::AudioOptions options_{}; }; public: - AudioTrackSource(AudioSourceOptions options); + AudioTrackSource(AudioSourceOptions options, + int sample_rate, + int num_channels, + int queue_size_ms, + rust::Fn data_needed, + webrtc::TaskQueueFactory* task_queue_factory); AudioSourceOptions audio_options() const; void set_audio_options(const AudioSourceOptions& options) const; - void on_captured_frame(rust::Slice audio_data, - uint32_t sample_rate, - uint32_t number_of_channels, - size_t number_of_frames) const; + bool capture_frame(rust::Slice audio_data, + uint32_t sample_rate, + uint32_t number_of_channels, + size_t number_of_frames) const; + + void clear_buffer(); rtc::scoped_refptr get() const; @@ -134,8 +162,6 @@ class AudioTrackSource { rtc::scoped_refptr source_; }; -std::shared_ptr new_audio_track_source( - AudioSourceOptions options); static std::shared_ptr audio_to_media( std::shared_ptr track) { @@ -151,4 +177,8 @@ static std::shared_ptr _shared_audio_track() { return nullptr; // Ignore } +static std::shared_ptr _shared_audio_track_source() { + return nullptr; // Ignore +} + } // namespace livekit diff --git a/webrtc-sys/include/livekit/peer_connection_factory.h b/webrtc-sys/include/livekit/peer_connection_factory.h index 7d22c4e81..1e25acca2 100644 --- a/webrtc-sys/include/livekit/peer_connection_factory.h +++ b/webrtc-sys/include/livekit/peer_connection_factory.h @@ -18,6 +18,7 @@ #include "api/peer_connection_interface.h" #include "api/scoped_refptr.h" +#include "api/task_queue/task_queue_factory.h" #include "livekit/audio_device.h" #include "media_stream.h" #include "rtp_parameters.h" @@ -55,6 +56,13 @@ class PeerConnectionFactory { rust::String label, std::shared_ptr source) const; + std::shared_ptr create_audio_source( + AudioSourceOptions options, + int sample_rate, + int num_channels, + int queue_size_ms, + rust::Fn data_needed) const; + RtpCapabilities rtp_sender_capabilities(MediaType type) const; RtpCapabilities rtp_receiver_capabilities(MediaType type) const; @@ -65,6 +73,7 @@ class PeerConnectionFactory { std::shared_ptr rtc_runtime_; rtc::scoped_refptr audio_device_; rtc::scoped_refptr peer_factory_; + webrtc::TaskQueueFactory* task_queue_factory_; }; std::shared_ptr create_peer_connection_factory(); diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp index 165f8e9c7..62705d0fd 100644 --- a/webrtc-sys/src/audio_track.cpp +++ b/webrtc-sys/src/audio_track.cpp @@ -18,12 +18,14 @@ #include #include +#include #include #include "api/audio_options.h" #include "api/media_stream_interface.h" #include "audio/remix_resample.h" #include "common_audio/include/audio_util.h" +#include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" #include "rtc_base/synchronization/mutex.h" @@ -123,7 +125,73 @@ std::shared_ptr new_native_audio_sink( } AudioTrackSource::InternalSource::InternalSource( - const cricket::AudioOptions& options) {} + const cricket::AudioOptions& options, + int sample_rate, + int num_channels, + int queue_size_ms, // must be a multiple of 10ms + rust::Fn data_needed, + webrtc::TaskQueueFactory* task_queue_factory) + : sample_rate_(sample_rate), + num_channels_(num_channels), + data_needed_(std::move(data_needed)), + data_requested_(false) { + if (queue_size_ms > 0) { + return; // no audio queue + } + + int samples10ms = sample_rate / 100 * num_channels; + int notify_threshold = samples10ms * 2; + + buffer_.resize(queue_size_ms / 10 * samples10ms); + + audio_queue_ = + std::make_unique(task_queue_factory->CreateTaskQueue( + "AudioSourceCapture", webrtc::TaskQueueFactory::Priority::NORMAL)); + + audio_task_ = webrtc::RepeatingTaskHandle::Start( + audio_queue_->Get(), [this, samples10ms, notify_threshold]() { + webrtc::MutexLock lock(&mutex_); + + if (buffer_.size() >= samples10ms) { + for (auto sink : sinks_) + sink->OnData(buffer_.data(), sizeof(int16_t), sample_rate_, + num_channels_, samples10ms / num_channels_); + + buffer_.erase(buffer_.begin(), buffer_.begin() + samples10ms); + } + + if (!data_requested_ && buffer_.size() <= notify_threshold) { + data_requested_ = true; + this->data_needed_(); + } + + return webrtc::TimeDelta::Millis(10); + }); +} + +bool AudioTrackSource::InternalSource::capture_frame( + rust::Slice data, + uint32_t sample_rate, + uint32_t number_of_channels, + size_t number_of_frames) { + webrtc::MutexLock lock(&mutex_); + + if (queue_size_samples_) { + int available = queue_size_samples_ - buffer_.size(); + if (available < data.size()) + return false; + + buffer_.insert(buffer_.end(), data.begin(), data.end()); + data_requested_ = false; + } else { + // capture directly (realtime source) + for (auto sink : sinks_) + sink->OnData(data.data(), sizeof(int16_t), sample_rate, + number_of_channels, number_of_frames); + } + + return true; +} webrtc::MediaSourceInterface::SourceState AudioTrackSource::InternalSource::state() const { @@ -157,22 +225,24 @@ void AudioTrackSource::InternalSource::RemoveSink( sinks_.erase(std::remove(sinks_.begin(), sinks_.end(), sink), sinks_.end()); } -void AudioTrackSource::InternalSource::on_captured_frame( - rust::Slice data, - uint32_t sample_rate, - uint32_t number_of_channels, - size_t number_of_frames) { +void AudioTrackSource::InternalSource::clear_buffer() { webrtc::MutexLock lock(&mutex_); - for (auto sink : sinks_) { - sink->OnData(data.data(), 16, sample_rate, number_of_channels, - number_of_frames); - } + buffer_.clear(); } -AudioTrackSource::AudioTrackSource(AudioSourceOptions options) { - source_ = - rtc::make_ref_counted(to_native_audio_options(options)); -} +AudioTrackSource::AudioTrackSource(AudioSourceOptions options, + int sample_rate, + int num_channels, + int queue_size_ms, + rust::Fn data_needed, + webrtc::TaskQueueFactory* task_queue_factory) + : source_(rtc::make_ref_counted( + to_native_audio_options(options), + sample_rate, + num_channels, + queue_size_ms, + std::move(data_needed), + task_queue_factory)) {} AudioSourceOptions AudioTrackSource::audio_options() const { return to_rust_audio_options(source_->options()); @@ -183,12 +253,16 @@ void AudioTrackSource::set_audio_options( source_->set_options(to_native_audio_options(options)); } -void AudioTrackSource::on_captured_frame(rust::Slice audio_data, - uint32_t sample_rate, - uint32_t number_of_channels, - size_t number_of_frames) const { - source_->on_captured_frame(audio_data, sample_rate, number_of_channels, - number_of_frames); +bool AudioTrackSource::capture_frame(rust::Slice audio_data, + uint32_t sample_rate, + uint32_t number_of_channels, + size_t number_of_frames) const { + return source_->capture_frame(audio_data, sample_rate, number_of_channels, + number_of_frames); +} + +void AudioTrackSource::clear_buffer() { + source_->clear_buffer(); } rtc::scoped_refptr AudioTrackSource::get() @@ -196,9 +270,4 @@ rtc::scoped_refptr AudioTrackSource::get() return source_; } -std::shared_ptr new_audio_track_source( - AudioSourceOptions options) { - return std::make_shared(options); -} - } // namespace livekit diff --git a/webrtc-sys/src/audio_track.rs b/webrtc-sys/src/audio_track.rs index 26b9f8308..5cdd3fc6e 100644 --- a/webrtc-sys/src/audio_track.rs +++ b/webrtc-sys/src/audio_track.rs @@ -46,20 +46,20 @@ pub mod ffi { num_channels: i32, ) -> SharedPtr; - fn on_captured_frame( + fn capture_frame( self: &AudioTrackSource, data: &[i16], sample_rate: u32, nb_channels: u32, nb_frames: usize, - ); + ) -> bool; fn audio_options(self: &AudioTrackSource) -> AudioSourceOptions; fn set_audio_options(self: &AudioTrackSource, options: &AudioSourceOptions); - fn new_audio_track_source(options: AudioSourceOptions) -> SharedPtr; fn audio_to_media(track: SharedPtr) -> SharedPtr; unsafe fn media_to_audio(track: SharedPtr) -> SharedPtr; fn _shared_audio_track() -> SharedPtr; + fn _shared_audio_track_source() -> SharedPtr; } extern "Rust" { diff --git a/webrtc-sys/src/peer_connection_factory.cpp b/webrtc-sys/src/peer_connection_factory.cpp index 28e400f63..9a5a9c313 100644 --- a/webrtc-sys/src/peer_connection_factory.cpp +++ b/webrtc-sys/src/peer_connection_factory.cpp @@ -28,6 +28,7 @@ #include "api/video_codecs/builtin_video_decoder_factory.h" #include "api/video_codecs/builtin_video_encoder_factory.h" #include "livekit/audio_device.h" +#include "livekit/audio_track.h" #include "livekit/peer_connection.h" #include "livekit/rtc_error.h" #include "livekit/rtp_parameters.h" @@ -83,6 +84,8 @@ PeerConnectionFactory::PeerConnectionFactory( peer_factory_ = webrtc::CreateModularPeerConnectionFactory(std::move(dependencies)); + task_queue_factory_ = dependencies.task_queue_factory.get(); + if (peer_factory_.get() == nullptr) { RTC_LOG_ERR(LS_ERROR) << "Failed to create PeerConnectionFactory"; return; @@ -126,6 +129,18 @@ std::shared_ptr PeerConnectionFactory::create_audio_track( rtc_runtime_->get_or_create_media_stream_track( peer_factory_->CreateAudioTrack(label.c_str(), source->get().get()))); } + +std::shared_ptr PeerConnectionFactory::create_audio_source( + AudioSourceOptions options, + int sample_rate, + int num_channels, + int queue_size_ms, + rust::Fn data_needed) const { + return std::make_shared( + options, sample_rate, num_channels, queue_size_ms, std::move(data_needed), + task_queue_factory_); +} + RtpCapabilities PeerConnectionFactory::rtp_sender_capabilities( MediaType type) const { return to_rust_rtp_capabilities(peer_factory_->GetRtpSenderCapabilities( diff --git a/webrtc-sys/src/peer_connection_factory.rs b/webrtc-sys/src/peer_connection_factory.rs index 51efc2f46..b6a9cc0e2 100644 --- a/webrtc-sys/src/peer_connection_factory.rs +++ b/webrtc-sys/src/peer_connection_factory.rs @@ -49,6 +49,7 @@ pub mod ffi { include!("livekit/jsep.h"); include!("livekit/webrtc.h"); include!("livekit/peer_connection.h"); + include!("livekit/audio_track.h"); type RtcConfiguration = crate::peer_connection::ffi::RtcConfiguration; type PeerConnectionState = crate::peer_connection::ffi::PeerConnectionState; @@ -78,6 +79,7 @@ pub mod ffi { type MediaStreamTrack = crate::media_stream::ffi::MediaStreamTrack; type SessionDescription = crate::jsep::ffi::SessionDescription; type MediaType = crate::webrtc::ffi::MediaType; + type AudioSourceOptions = crate::audio_track::ffi::AudioSourceOptions; } unsafe extern "C++" { @@ -106,6 +108,15 @@ pub mod ffi { source: SharedPtr, ) -> SharedPtr; + fn create_audio_source( + self: &PeerConnectionFactory, + options: AudioSourceOptions, + sample_rate: i32, + num_channels: i32, + queue_size_ms: i32, + data_needed: fn(), + ) -> SharedPtr; + fn rtp_sender_capabilities( self: &PeerConnectionFactory, kind: MediaType, From bbf5f70d3f904dbf70296a2763dc8e742d438953 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 11 Sep 2024 17:29:31 -0700 Subject: [PATCH 2/7] wip test --- Cargo.lock | 10 +- libwebrtc/src/audio_source.rs | 8 +- libwebrtc/src/native/audio_source.rs | 152 +++++------------- livekit-ffi/protocol/audio_frame.proto | 2 +- livekit-ffi/src/livekit.proto.rs | 4 +- livekit-ffi/src/server/audio_source.rs | 2 +- webrtc-sys/build.rs | 1 + webrtc-sys/include/livekit/audio_track.h | 32 ++-- .../include/livekit/global_task_queue.h | 9 ++ .../include/livekit/peer_connection_factory.h | 7 - webrtc-sys/src/audio_track.cpp | 65 +++++--- webrtc-sys/src/audio_track.rs | 27 +++- webrtc-sys/src/global_task_queue.cpp | 14 ++ webrtc-sys/src/peer_connection_factory.cpp | 12 +- webrtc-sys/src/peer_connection_factory.rs | 10 -- 15 files changed, 166 insertions(+), 189 deletions(-) create mode 100644 webrtc-sys/include/livekit/global_task_queue.h create mode 100644 webrtc-sys/src/global_task_queue.cpp diff --git a/Cargo.lock b/Cargo.lock index 869a0da17..f24ff4777 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1490,7 +1490,7 @@ dependencies = [ [[package]] name = "libwebrtc" -version = "0.3.5" +version = "0.3.7" dependencies = [ "cxx", "env_logger", @@ -1546,7 +1546,7 @@ checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "livekit" -version = "0.5.1" +version = "0.5.3" dependencies = [ "futures-util", "lazy_static", @@ -1591,7 +1591,7 @@ dependencies = [ [[package]] name = "livekit-ffi" -version = "0.8.3" +version = "0.8.2" dependencies = [ "console-subscriber", "dashmap", @@ -3234,7 +3234,7 @@ checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "webrtc-sys" -version = "0.3.3" +version = "0.3.5" dependencies = [ "cc", "cxx", @@ -3247,7 +3247,7 @@ dependencies = [ [[package]] name = "webrtc-sys-build" -version = "0.3.3" +version = "0.3.5" dependencies = [ "fs2", "regex", diff --git a/libwebrtc/src/audio_source.rs b/libwebrtc/src/audio_source.rs index b430f4fb0..f2f0ffac1 100644 --- a/libwebrtc/src/audio_source.rs +++ b/libwebrtc/src/audio_source.rs @@ -63,14 +63,14 @@ pub mod native { options: AudioSourceOptions, sample_rate: u32, num_channels: u32, - enable_queue: Option, + queue_size_ms: u32, ) -> NativeAudioSource { Self { handle: imp_as::NativeAudioSource::new( options, sample_rate, num_channels, - enable_queue, + queue_size_ms, ), } } @@ -94,9 +94,5 @@ pub mod native { pub fn num_channels(&self) -> u32 { self.handle.num_channels() } - - pub fn enable_queue(&self) -> bool { - self.handle.enable_queue() - } } } diff --git a/libwebrtc/src/native/audio_source.rs b/libwebrtc/src/native/audio_source.rs index e1ed81178..e01c411db 100644 --- a/libwebrtc/src/native/audio_source.rs +++ b/libwebrtc/src/native/audio_source.rs @@ -12,39 +12,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{sync::Arc, time::Duration}; - use cxx::SharedPtr; -use livekit_runtime::interval; -use tokio::sync::{ - mpsc::{self, error::TryRecvError}, - Mutex as AsyncMutex, -}; +use tokio::sync::oneshot; use webrtc_sys::audio_track as sys_at; use crate::{audio_frame::AudioFrame, audio_source::AudioSourceOptions, RtcError, RtcErrorType}; -const BUFFER_SIZE_MS: usize = 50; - #[derive(Clone)] pub struct NativeAudioSource { sys_handle: SharedPtr, - inner: Arc>, sample_rate: u32, num_channels: u32, - samples_10ms: usize, - // whether to queue audio frames or send them immediately - // defaults to true - enable_queue: bool, - po_tx: mpsc::Sender>, -} - -struct AudioSourceInner { - buf: Box<[i16]>, - - // Amount of data from the previous frame that hasn't been sent to the libwebrtc source - // (because it requires 10ms of data) - len: usize, + queue_size_samples: u32, } impl NativeAudioSource { @@ -52,67 +31,19 @@ impl NativeAudioSource { options: AudioSourceOptions, sample_rate: u32, num_channels: u32, - enable_queue: Option, + queue_size_ms: u32, ) -> NativeAudioSource { - let samples_10ms = (sample_rate / 100 * num_channels) as usize; - let (po_tx, mut po_rx) = mpsc::channel(BUFFER_SIZE_MS / 10); - - let source = Self { - sys_handle: sys_at::ffi::new_audio_track_source(options.into()), - inner: Arc::new(AsyncMutex::new(AudioSourceInner { - buf: vec![0; samples_10ms].into_boxed_slice(), - len: 0, - })), - sample_rate, - num_channels, - samples_10ms, - enable_queue: enable_queue.unwrap_or(true), - po_tx, - }; - - livekit_runtime::spawn({ - let source = source.clone(); - async move { - let mut interval = interval(Duration::from_millis(10)); - interval.set_missed_tick_behavior(livekit_runtime::MissedTickBehavior::Delay); - let blank_data = vec![0; samples_10ms]; - let enable_queue = source.enable_queue; - - loop { - if enable_queue { - interval.tick().await; - } - - let frame = po_rx.try_recv(); - if let Err(TryRecvError::Disconnected) = frame { - break; - } - - if let Err(TryRecvError::Empty) = frame { - if enable_queue { - source.sys_handle.on_captured_frame( - &blank_data, - sample_rate, - num_channels, - blank_data.len() / num_channels as usize, - ); - - - continue; - } - - let frame = frame.unwrap(); - source.sys_handle.on_captured_frame( - &frame, - sample_rate, - num_channels, - frame.len() / num_channels as usize, - ); - } - } - }); + assert!(queue_size_ms % 10 == 0, "queue_size_ms must be a multiple of 10"); - source + let sys_handle = sys_at::ffi::new_audio_track_source( + options.into(), + sample_rate.try_into().unwrap(), + num_channels.try_into().unwrap(), + queue_size_ms.try_into().unwrap(), + ); + + let queue_size_samples = (queue_size_ms * sample_rate / 1000) * num_channels; + Self { sys_handle, sample_rate, num_channels, queue_size_samples } } pub fn sys_handle(&self) -> SharedPtr { @@ -135,10 +66,6 @@ impl NativeAudioSource { self.num_channels } - pub fn enable_queue(&self) -> bool { - self.enable_queue - } - pub async fn capture_frame(&self, frame: &AudioFrame<'_>) -> Result<(), RtcError> { if self.sample_rate != frame.sample_rate || self.num_channels != frame.num_channels { return Err(RtcError { @@ -147,38 +74,35 @@ impl NativeAudioSource { }); } - let mut inner = self.inner.lock().await; - let mut samples = 0; - // split frames into 10ms chunks - loop { - let remaining_samples = frame.data.len() - samples; - if remaining_samples == 0 { - break; - } + extern "C" fn lk_audio_source_complete(userdata: *const sys_at::SourceContext) { + let tx = unsafe { Box::from_raw(userdata as *mut oneshot::Sender<()>) }; + let _ = tx.send(()); + } - if (inner.len != 0 && remaining_samples > 0) || remaining_samples < self.samples_10ms { - let missing_len = self.samples_10ms - inner.len; - let to_add = missing_len.min(remaining_samples); - let start = inner.len; - inner.buf[start..start + to_add] - .copy_from_slice(&frame.data[samples..samples + to_add]); - inner.len += to_add; - samples += to_add; - - if inner.len == self.samples_10ms { - let data = inner.buf.clone().to_vec(); - let _ = self.po_tx.send(data).await; - inner.len = 0; + // iterate over chunks of self._queue_size_samples + for chunk in frame.data.chunks(self.queue_size_samples as usize) { + let nb_frames = chunk.len() / self.num_channels as usize; + let (tx, rx) = oneshot::channel::<()>(); + let ctx = Box::new(tx); + let ctx_ptr = Box::into_raw(ctx) as *const sys_at::SourceContext; + + unsafe { + if !self.sys_handle.capture_frame( + chunk, + self.sample_rate, + self.num_channels, + nb_frames, + ctx_ptr, + sys_at::CompleteCallback(lk_audio_source_complete), + ) { + return Err(RtcError { + error_type: RtcErrorType::InvalidState, + message: "failed to capture frame".to_owned(), + }); } - continue; } - if remaining_samples >= self.samples_10ms { - // TODO(theomonnom): avoid copying - let data = frame.data[samples..samples + self.samples_10ms].to_vec(); - let _ = self.po_tx.send(data).await; - samples += self.samples_10ms; - } + let _ = rx.await; } Ok(()) diff --git a/livekit-ffi/protocol/audio_frame.proto b/livekit-ffi/protocol/audio_frame.proto index 35f6b8372..f12150c04 100644 --- a/livekit-ffi/protocol/audio_frame.proto +++ b/livekit-ffi/protocol/audio_frame.proto @@ -35,7 +35,7 @@ message NewAudioSourceRequest { optional AudioSourceOptions options = 2; uint32 sample_rate = 3; uint32 num_channels = 4; - optional bool enable_queue = 5; + uint32 queue_size_ms = 5; } message NewAudioSourceResponse { OwnedAudioSource source = 1; } diff --git a/livekit-ffi/src/livekit.proto.rs b/livekit-ffi/src/livekit.proto.rs index 141966bba..7dcb567aa 100644 --- a/livekit-ffi/src/livekit.proto.rs +++ b/livekit-ffi/src/livekit.proto.rs @@ -2903,8 +2903,8 @@ pub struct NewAudioSourceRequest { pub sample_rate: u32, #[prost(uint32, tag="4")] pub num_channels: u32, - #[prost(bool, optional, tag="5")] - pub enable_queue: ::core::option::Option, + #[prost(uint32, tag="5")] + pub queue_size_ms: u32, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/livekit-ffi/src/server/audio_source.rs b/livekit-ffi/src/server/audio_source.rs index ba0f21f7a..904cba1a4 100644 --- a/livekit-ffi/src/server/audio_source.rs +++ b/livekit-ffi/src/server/audio_source.rs @@ -43,7 +43,7 @@ impl FfiAudioSource { new_source.options.map(Into::into).unwrap_or_default(), new_source.sample_rate, new_source.num_channels, - new_source.enable_queue, + new_source.queue_size_ms, ); RtcAudioSource::Native(audio_source) } diff --git a/webrtc-sys/build.rs b/webrtc-sys/build.rs index beabe57c7..5e87b0c9c 100644 --- a/webrtc-sys/build.rs +++ b/webrtc-sys/build.rs @@ -72,6 +72,7 @@ fn main() { "src/audio_device.cpp", "src/audio_resampler.cpp", "src/frame_cryptor.cpp", + "src/global_task_queue.cpp", ]); let webrtc_dir = webrtc_sys_build::webrtc_dir(); diff --git a/webrtc-sys/include/livekit/audio_track.h b/webrtc-sys/include/livekit/audio_track.h index 31ab38229..1ed24664c 100644 --- a/webrtc-sys/include/livekit/audio_track.h +++ b/webrtc-sys/include/livekit/audio_track.h @@ -29,12 +29,16 @@ #include "rtc_base/synchronization/mutex.h" #include "rtc_base/task_queue.h" #include "rtc_base/task_utils/repeating_task.h" +#include "rtc_base/thread_annotations.h" #include "rust/cxx.h" namespace livekit { class AudioTrack; class NativeAudioSink; class AudioTrackSource; +class SourceContext; + +using CompleteCallback = void (*)(const livekit::SourceContext*); } // namespace livekit #include "webrtc-sys/src/audio_track.rs.h" @@ -98,7 +102,6 @@ class AudioTrackSource { int sample_rate, int num_channels, int buffer_size_ms, - rust::Fn data_needed, webrtc::TaskQueueFactory* task_queue_factory); SourceState state() const override; @@ -111,12 +114,12 @@ class AudioTrackSource { void set_options(const cricket::AudioOptions& options); - // AudioFrame should always contain 10 ms worth of data (see index.md of - // acm) bool capture_frame(rust::Slice audio_data, uint32_t sample_rate, uint32_t number_of_channels, - size_t number_of_frames); + size_t number_of_frames, + const SourceContext* ctx, + void (*on_complete)(const SourceContext*)); void clear_buffer(); @@ -124,16 +127,17 @@ class AudioTrackSource { mutable webrtc::Mutex mutex_; std::unique_ptr audio_queue_; webrtc::RepeatingTaskHandle audio_task_; - std::vector sinks_; - std::vector buffer_; - rust::Fn data_needed_; + + std::vector sinks_ RTC_GUARDED_BY(mutex_); + std::vector buffer_ RTC_GUARDED_BY(mutex_); + + const SourceContext* capture_userdata_ RTC_GUARDED_BY(mutex_); + void (*on_complete_)(const SourceContext*) RTC_GUARDED_BY(mutex_); int sample_rate_; int num_channels_; int queue_size_samples_; - bool data_requested_; - cricket::AudioOptions options_{}; }; @@ -142,7 +146,6 @@ class AudioTrackSource { int sample_rate, int num_channels, int queue_size_ms, - rust::Fn data_needed, webrtc::TaskQueueFactory* task_queue_factory); AudioSourceOptions audio_options() const; @@ -152,7 +155,9 @@ class AudioTrackSource { bool capture_frame(rust::Slice audio_data, uint32_t sample_rate, uint32_t number_of_channels, - size_t number_of_frames) const; + size_t number_of_frames, + const SourceContext* ctx, + CompleteCallback on_complete) const; void clear_buffer(); @@ -162,6 +167,11 @@ class AudioTrackSource { rtc::scoped_refptr source_; }; +std::shared_ptr new_audio_track_source( + AudioSourceOptions options, + int sample_rate, + int num_channels, + int queue_size_ms); static std::shared_ptr audio_to_media( std::shared_ptr track) { diff --git a/webrtc-sys/include/livekit/global_task_queue.h b/webrtc-sys/include/livekit/global_task_queue.h new file mode 100644 index 000000000..1cf9a7ac1 --- /dev/null +++ b/webrtc-sys/include/livekit/global_task_queue.h @@ -0,0 +1,9 @@ +#pragma once + +#include "api/task_queue/task_queue_factory.h" + +namespace livekit { + +webrtc::TaskQueueFactory* GetGlobalTaskQueueFactory(); + +} // namespace livekit diff --git a/webrtc-sys/include/livekit/peer_connection_factory.h b/webrtc-sys/include/livekit/peer_connection_factory.h index 1e25acca2..ae49842ba 100644 --- a/webrtc-sys/include/livekit/peer_connection_factory.h +++ b/webrtc-sys/include/livekit/peer_connection_factory.h @@ -56,13 +56,6 @@ class PeerConnectionFactory { rust::String label, std::shared_ptr source) const; - std::shared_ptr create_audio_source( - AudioSourceOptions options, - int sample_rate, - int num_channels, - int queue_size_ms, - rust::Fn data_needed) const; - RtpCapabilities rtp_sender_capabilities(MediaType type) const; RtpCapabilities rtp_receiver_capabilities(MediaType type) const; diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp index 62705d0fd..2f5cacffe 100644 --- a/webrtc-sys/src/audio_track.cpp +++ b/webrtc-sys/src/audio_track.cpp @@ -23,8 +23,10 @@ #include "api/audio_options.h" #include "api/media_stream_interface.h" +#include "api/task_queue/task_queue_base.h" #include "audio/remix_resample.h" #include "common_audio/include/audio_util.h" +#include "livekit/global_task_queue.h" #include "rtc_base/checks.h" #include "rtc_base/logging.h" #include "rtc_base/ref_counted_object.h" @@ -129,27 +131,29 @@ AudioTrackSource::InternalSource::InternalSource( int sample_rate, int num_channels, int queue_size_ms, // must be a multiple of 10ms - rust::Fn data_needed, webrtc::TaskQueueFactory* task_queue_factory) : sample_rate_(sample_rate), num_channels_(num_channels), - data_needed_(std::move(data_needed)), - data_requested_(false) { + capture_userdata_(nullptr), + on_complete_(nullptr) { if (queue_size_ms > 0) { return; // no audio queue } int samples10ms = sample_rate / 100 * num_channels; - int notify_threshold = samples10ms * 2; + int notify_threshold = + sample_rate / 10; // notify when buffer is less than 100ms - buffer_.resize(queue_size_ms / 10 * samples10ms); + queue_size_samples_ = queue_size_ms / 10 * samples10ms; + buffer_.resize(queue_size_samples_ + notify_threshold); audio_queue_ = std::make_unique(task_queue_factory->CreateTaskQueue( "AudioSourceCapture", webrtc::TaskQueueFactory::Priority::NORMAL)); audio_task_ = webrtc::RepeatingTaskHandle::Start( - audio_queue_->Get(), [this, samples10ms, notify_threshold]() { + audio_queue_->Get(), + [this, samples10ms, notify_threshold]() { webrtc::MutexLock lock(&mutex_); if (buffer_.size() >= samples10ms) { @@ -160,20 +164,24 @@ AudioTrackSource::InternalSource::InternalSource( buffer_.erase(buffer_.begin(), buffer_.begin() + samples10ms); } - if (!data_requested_ && buffer_.size() <= notify_threshold) { - data_requested_ = true; - this->data_needed_(); + if (on_complete_ && buffer_.size() <= notify_threshold) { + on_complete_(capture_userdata_); + on_complete_ = nullptr; + capture_userdata_ = nullptr; } return webrtc::TimeDelta::Millis(10); - }); + }, + webrtc::TaskQueueBase::DelayPrecision::kHigh); } bool AudioTrackSource::InternalSource::capture_frame( rust::Slice data, uint32_t sample_rate, uint32_t number_of_channels, - size_t number_of_frames) { + size_t number_of_frames, + const SourceContext* ctx, + void (*on_complete)(const SourceContext*)) { webrtc::MutexLock lock(&mutex_); if (queue_size_samples_) { @@ -181,10 +189,16 @@ bool AudioTrackSource::InternalSource::capture_frame( if (available < data.size()) return false; + if (on_complete_ || capture_userdata_) + return false; + buffer_.insert(buffer_.end(), data.begin(), data.end()); - data_requested_ = false; + + on_complete_ = on_complete; + capture_userdata_ = ctx; + } else { - // capture directly (realtime source) + // capture directly when the queue buffer is 0 (frame size must be 10ms) for (auto sink : sinks_) sink->OnData(data.data(), sizeof(int16_t), sample_rate, number_of_channels, number_of_frames); @@ -234,14 +248,12 @@ AudioTrackSource::AudioTrackSource(AudioSourceOptions options, int sample_rate, int num_channels, int queue_size_ms, - rust::Fn data_needed, webrtc::TaskQueueFactory* task_queue_factory) : source_(rtc::make_ref_counted( to_native_audio_options(options), sample_rate, num_channels, queue_size_ms, - std::move(data_needed), task_queue_factory)) {} AudioSourceOptions AudioTrackSource::audio_options() const { @@ -253,18 +265,31 @@ void AudioTrackSource::set_audio_options( source_->set_options(to_native_audio_options(options)); } -bool AudioTrackSource::capture_frame(rust::Slice audio_data, - uint32_t sample_rate, - uint32_t number_of_channels, - size_t number_of_frames) const { +bool AudioTrackSource::capture_frame( + rust::Slice audio_data, + uint32_t sample_rate, + uint32_t number_of_channels, + size_t number_of_frames, + const SourceContext* ctx, + void (*on_complete)(const SourceContext*)) const { return source_->capture_frame(audio_data, sample_rate, number_of_channels, - number_of_frames); + number_of_frames, ctx, on_complete); } void AudioTrackSource::clear_buffer() { source_->clear_buffer(); } +std::shared_ptr new_audio_track_source( + AudioSourceOptions options, + int sample_rate, + int num_channels, + int queue_size_ms) { + return std::make_shared(options, sample_rate, num_channels, + queue_size_ms, + GetGlobalTaskQueueFactory()); +} + rtc::scoped_refptr AudioTrackSource::get() const { return source_; diff --git a/webrtc-sys/src/audio_track.rs b/webrtc-sys/src/audio_track.rs index 5cdd3fc6e..88f68f779 100644 --- a/webrtc-sys/src/audio_track.rs +++ b/webrtc-sys/src/audio_track.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +use cxx::type_id; +use cxx::ExternType; +use std::any::Any; use std::sync::Arc; use crate::impl_thread_safety; @@ -29,6 +32,7 @@ pub mod ffi { include!("livekit/media_stream_track.h"); type MediaStreamTrack = crate::media_stream_track::ffi::MediaStreamTrack; + type CompleteCallback = crate::audio_track::CompleteCallback; } unsafe extern "C++" { @@ -46,16 +50,25 @@ pub mod ffi { num_channels: i32, ) -> SharedPtr; - fn capture_frame( + unsafe fn capture_frame( self: &AudioTrackSource, data: &[i16], sample_rate: u32, nb_channels: u32, nb_frames: usize, + userdata: *const SourceContext, + on_complete: CompleteCallback, ) -> bool; fn audio_options(self: &AudioTrackSource) -> AudioSourceOptions; fn set_audio_options(self: &AudioTrackSource, options: &AudioSourceOptions); + fn new_audio_track_source( + options: AudioSourceOptions, + sample_rate: i32, + num_channels: i32, + queue_size_ms: i32, + ) -> SharedPtr; + fn audio_to_media(track: SharedPtr) -> SharedPtr; unsafe fn media_to_audio(track: SharedPtr) -> SharedPtr; fn _shared_audio_track() -> SharedPtr; @@ -64,6 +77,7 @@ pub mod ffi { extern "Rust" { type AudioSinkWrapper; + type SourceContext; fn on_data( self: &AudioSinkWrapper, @@ -79,6 +93,17 @@ impl_thread_safety!(ffi::AudioTrack, Send + Sync); impl_thread_safety!(ffi::NativeAudioSink, Send + Sync); impl_thread_safety!(ffi::AudioTrackSource, Send + Sync); +#[repr(transparent)] +pub struct SourceContext(pub Box); + +#[repr(transparent)] +pub struct CompleteCallback(pub extern "C" fn(ctx: *const SourceContext)); + +unsafe impl ExternType for CompleteCallback { + type Id = type_id!("livekit::CompleteCallback"); + type Kind = cxx::kind::Trivial; +} + pub trait AudioSink: Send { fn on_data(&self, data: &[i16], sample_rate: i32, nb_channels: usize, nb_frames: usize); } diff --git a/webrtc-sys/src/global_task_queue.cpp b/webrtc-sys/src/global_task_queue.cpp new file mode 100644 index 000000000..0f0cfc2b3 --- /dev/null +++ b/webrtc-sys/src/global_task_queue.cpp @@ -0,0 +1,14 @@ +#include "livekit/global_task_queue.h" + +#include "api/task_queue/default_task_queue_factory.h" +#include "api/task_queue/task_queue_factory.h" + +namespace livekit { + +webrtc::TaskQueueFactory* GetGlobalTaskQueueFactory() { + static std::unique_ptr global_task_queue_factory = + webrtc::CreateDefaultTaskQueueFactory(); + return global_task_queue_factory.get(); +} + +} // namespace livekit diff --git a/webrtc-sys/src/peer_connection_factory.cpp b/webrtc-sys/src/peer_connection_factory.cpp index 9a5a9c313..bbc73bfdc 100644 --- a/webrtc-sys/src/peer_connection_factory.cpp +++ b/webrtc-sys/src/peer_connection_factory.cpp @@ -84,6 +84,7 @@ PeerConnectionFactory::PeerConnectionFactory( peer_factory_ = webrtc::CreateModularPeerConnectionFactory(std::move(dependencies)); + task_queue_factory_ = dependencies.task_queue_factory.get(); if (peer_factory_.get() == nullptr) { @@ -130,17 +131,6 @@ std::shared_ptr PeerConnectionFactory::create_audio_track( peer_factory_->CreateAudioTrack(label.c_str(), source->get().get()))); } -std::shared_ptr PeerConnectionFactory::create_audio_source( - AudioSourceOptions options, - int sample_rate, - int num_channels, - int queue_size_ms, - rust::Fn data_needed) const { - return std::make_shared( - options, sample_rate, num_channels, queue_size_ms, std::move(data_needed), - task_queue_factory_); -} - RtpCapabilities PeerConnectionFactory::rtp_sender_capabilities( MediaType type) const { return to_rust_rtp_capabilities(peer_factory_->GetRtpSenderCapabilities( diff --git a/webrtc-sys/src/peer_connection_factory.rs b/webrtc-sys/src/peer_connection_factory.rs index b6a9cc0e2..1ae912848 100644 --- a/webrtc-sys/src/peer_connection_factory.rs +++ b/webrtc-sys/src/peer_connection_factory.rs @@ -79,7 +79,6 @@ pub mod ffi { type MediaStreamTrack = crate::media_stream::ffi::MediaStreamTrack; type SessionDescription = crate::jsep::ffi::SessionDescription; type MediaType = crate::webrtc::ffi::MediaType; - type AudioSourceOptions = crate::audio_track::ffi::AudioSourceOptions; } unsafe extern "C++" { @@ -108,15 +107,6 @@ pub mod ffi { source: SharedPtr, ) -> SharedPtr; - fn create_audio_source( - self: &PeerConnectionFactory, - options: AudioSourceOptions, - sample_rate: i32, - num_channels: i32, - queue_size_ms: i32, - data_needed: fn(), - ) -> SharedPtr; - fn rtp_sender_capabilities( self: &PeerConnectionFactory, kind: MediaType, From 21eef3cf60c80c0756b5a7d8660328002f72b4b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 11 Sep 2024 18:07:54 -0700 Subject: [PATCH 3/7] Update livekit.proto.rs --- livekit-ffi/src/livekit.proto.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/livekit-ffi/src/livekit.proto.rs b/livekit-ffi/src/livekit.proto.rs index af2c82e89..1635c130c 100644 --- a/livekit-ffi/src/livekit.proto.rs +++ b/livekit-ffi/src/livekit.proto.rs @@ -1,5 +1,4 @@ // @generated -// This file is @generated by prost-build. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FrameCryptor { From 04fd97ecc625270609346e6761047f7e4ad131e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 11 Sep 2024 20:21:27 -0700 Subject: [PATCH 4/7] also optimize for small frames --- examples/Cargo.lock | 14 +++++------ examples/wgpu_room/src/sine_track.rs | 32 ++++++------------------ libwebrtc/src/native/audio_source.rs | 9 +++++++ webrtc-sys/include/livekit/audio_track.h | 1 + webrtc-sys/src/audio_track.cpp | 23 +++++++++-------- 5 files changed, 37 insertions(+), 42 deletions(-) diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 21ec559f1..94d6b303f 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -2108,7 +2108,7 @@ dependencies = [ [[package]] name = "libwebrtc" -version = "0.3.2" +version = "0.3.7" dependencies = [ "cxx", "jni", @@ -2122,7 +2122,6 @@ dependencies = [ "serde_json", "thiserror", "tokio", - "tokio-stream", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", @@ -2152,7 +2151,7 @@ checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "livekit" -version = "0.3.2" +version = "0.6.0" dependencies = [ "futures-util", "lazy_static", @@ -2171,7 +2170,7 @@ dependencies = [ [[package]] name = "livekit-api" -version = "0.3.2" +version = "0.4.0" dependencies = [ "async-tungstenite", "base64", @@ -2196,7 +2195,7 @@ dependencies = [ [[package]] name = "livekit-protocol" -version = "0.3.2" +version = "0.3.5" dependencies = [ "futures-util", "livekit-runtime", @@ -2215,6 +2214,7 @@ name = "livekit-runtime" version = "0.3.0" dependencies = [ "tokio", + "tokio-stream", ] [[package]] @@ -4455,7 +4455,7 @@ checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "webrtc-sys" -version = "0.3.2" +version = "0.3.5" dependencies = [ "cc", "cxx", @@ -4467,7 +4467,7 @@ dependencies = [ [[package]] name = "webrtc-sys-build" -version = "0.3.2" +version = "0.3.5" dependencies = [ "fs2", "regex", diff --git a/examples/wgpu_room/src/sine_track.rs b/examples/wgpu_room/src/sine_track.rs index becaf8217..1825bb8a7 100644 --- a/examples/wgpu_room/src/sine_track.rs +++ b/examples/wgpu_room/src/sine_track.rs @@ -17,12 +17,7 @@ pub struct SineParameters { impl Default for SineParameters { fn default() -> Self { - Self { - sample_rate: 48000, - freq: 440.0, - amplitude: 1.0, - num_channels: 2, - } + Self { sample_rate: 48000, freq: 440.0, amplitude: 1.0, num_channels: 2 } } } @@ -46,7 +41,7 @@ impl SineTrack { AudioSourceOptions::default(), params.sample_rate, params.num_channels, - None, + 1000, ), params, room, @@ -65,28 +60,18 @@ impl SineTrack { RtcAudioSource::Native(self.rtc_source.clone()), ); - let task = tokio::spawn(Self::track_task( - close_rx, - self.rtc_source.clone(), - self.params.clone(), - )); + let task = + tokio::spawn(Self::track_task(close_rx, self.rtc_source.clone(), self.params.clone())); self.room .local_participant() .publish_track( LocalTrack::Audio(track.clone()), - TrackPublishOptions { - source: TrackSource::Microphone, - ..Default::default() - }, + TrackPublishOptions { source: TrackSource::Microphone, ..Default::default() }, ) .await?; - let handle = TrackHandle { - close_tx, - track, - task, - }; + let handle = TrackHandle { close_tx, track, task }; self.handle = Some(handle); Ok(()) @@ -96,10 +81,7 @@ impl SineTrack { if let Some(handle) = self.handle.take() { handle.close_tx.send(()).ok(); handle.task.await.ok(); - self.room - .local_participant() - .unpublish_track(&handle.track.sid()) - .await?; + self.room.local_participant().unpublish_track(&handle.track.sid()).await?; } Ok(()) diff --git a/libwebrtc/src/native/audio_source.rs b/libwebrtc/src/native/audio_source.rs index e01c411db..249e55e11 100644 --- a/libwebrtc/src/native/audio_source.rs +++ b/libwebrtc/src/native/audio_source.rs @@ -35,6 +35,11 @@ impl NativeAudioSource { ) -> NativeAudioSource { assert!(queue_size_ms % 10 == 0, "queue_size_ms must be a multiple of 10"); + print!( + "new audio source {} {} {} {}", + sample_rate, num_channels, queue_size_ms, options.echo_cancellation + ); + let sys_handle = sys_at::ffi::new_audio_track_source( options.into(), sample_rate.try_into().unwrap(), @@ -75,12 +80,14 @@ impl NativeAudioSource { } extern "C" fn lk_audio_source_complete(userdata: *const sys_at::SourceContext) { + println!("lk_audio_source_complete"); let tx = unsafe { Box::from_raw(userdata as *mut oneshot::Sender<()>) }; let _ = tx.send(()); } // iterate over chunks of self._queue_size_samples for chunk in frame.data.chunks(self.queue_size_samples as usize) { + println!("capturing frame {}", chunk.len()); let nb_frames = chunk.len() / self.num_channels as usize; let (tx, rx) = oneshot::channel::<()>(); let ctx = Box::new(tx); @@ -95,6 +102,7 @@ impl NativeAudioSource { ctx_ptr, sys_at::CompleteCallback(lk_audio_source_complete), ) { + print!("failed to capture frame"); return Err(RtcError { error_type: RtcErrorType::InvalidState, message: "failed to capture frame".to_owned(), @@ -103,6 +111,7 @@ impl NativeAudioSource { } let _ = rx.await; + println!("captured frame"); } Ok(()) diff --git a/webrtc-sys/include/livekit/audio_track.h b/webrtc-sys/include/livekit/audio_track.h index 1ed24664c..b2ee28986 100644 --- a/webrtc-sys/include/livekit/audio_track.h +++ b/webrtc-sys/include/livekit/audio_track.h @@ -137,6 +137,7 @@ class AudioTrackSource { int sample_rate_; int num_channels_; int queue_size_samples_; + int notify_threshold_samples_; cricket::AudioOptions options_{}; }; diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp index 2f5cacffe..e83b4bf34 100644 --- a/webrtc-sys/src/audio_track.cpp +++ b/webrtc-sys/src/audio_track.cpp @@ -136,16 +136,14 @@ AudioTrackSource::InternalSource::InternalSource( num_channels_(num_channels), capture_userdata_(nullptr), on_complete_(nullptr) { - if (queue_size_ms > 0) { + if (!queue_size_ms) { return; // no audio queue } int samples10ms = sample_rate / 100 * num_channels; - int notify_threshold = - sample_rate / 10; // notify when buffer is less than 100ms - queue_size_samples_ = queue_size_ms / 10 * samples10ms; - buffer_.resize(queue_size_samples_ + notify_threshold); + notify_threshold_samples_ = queue_size_samples_; + buffer_.resize(queue_size_samples_ + notify_threshold_samples_); audio_queue_ = std::make_unique(task_queue_factory->CreateTaskQueue( @@ -153,7 +151,7 @@ AudioTrackSource::InternalSource::InternalSource( audio_task_ = webrtc::RepeatingTaskHandle::Start( audio_queue_->Get(), - [this, samples10ms, notify_threshold]() { + [this, samples10ms]() { webrtc::MutexLock lock(&mutex_); if (buffer_.size() >= samples10ms) { @@ -164,7 +162,7 @@ AudioTrackSource::InternalSource::InternalSource( buffer_.erase(buffer_.begin(), buffer_.begin() + samples10ms); } - if (on_complete_ && buffer_.size() <= notify_threshold) { + if (on_complete_ && buffer_.size() <= notify_threshold_samples_) { on_complete_(capture_userdata_); on_complete_ = nullptr; capture_userdata_ = nullptr; @@ -185,7 +183,8 @@ bool AudioTrackSource::InternalSource::capture_frame( webrtc::MutexLock lock(&mutex_); if (queue_size_samples_) { - int available = queue_size_samples_ - buffer_.size(); + int available = + (queue_size_samples_ + notify_threshold_samples_) - buffer_.size(); if (available < data.size()) return false; @@ -194,8 +193,12 @@ bool AudioTrackSource::InternalSource::capture_frame( buffer_.insert(buffer_.end(), data.begin(), data.end()); - on_complete_ = on_complete; - capture_userdata_ = ctx; + if (buffer_.size() <= notify_threshold_samples_) { + on_complete(ctx); // complete directly + } else { + on_complete_ = on_complete; + capture_userdata_ = ctx; + } } else { // capture directly when the queue buffer is 0 (frame size must be 10ms) From c0723f0a037ace77063c5d06efb202f57e240a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 11 Sep 2024 20:23:12 -0700 Subject: [PATCH 5/7] Update audio_track.cpp --- webrtc-sys/src/audio_track.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp index e83b4bf34..d466ecf3d 100644 --- a/webrtc-sys/src/audio_track.cpp +++ b/webrtc-sys/src/audio_track.cpp @@ -142,7 +142,8 @@ AudioTrackSource::InternalSource::InternalSource( int samples10ms = sample_rate / 100 * num_channels; queue_size_samples_ = queue_size_ms / 10 * samples10ms; - notify_threshold_samples_ = queue_size_samples_; + notify_threshold_samples_ = queue_size_samples_; // TODO: this is currently using x2 the queue + // size buffer_.resize(queue_size_samples_ + notify_threshold_samples_); audio_queue_ = From bc3bd8b71571a7b264676005e42779e3b3c0bb6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Wed, 11 Sep 2024 20:27:59 -0700 Subject: [PATCH 6/7] fmt --- webrtc-sys/src/audio_track.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp index d466ecf3d..7804f2304 100644 --- a/webrtc-sys/src/audio_track.cpp +++ b/webrtc-sys/src/audio_track.cpp @@ -136,14 +136,13 @@ AudioTrackSource::InternalSource::InternalSource( num_channels_(num_channels), capture_userdata_(nullptr), on_complete_(nullptr) { - if (!queue_size_ms) { + if (!queue_size_ms) return; // no audio queue - } int samples10ms = sample_rate / 100 * num_channels; queue_size_samples_ = queue_size_ms / 10 * samples10ms; - notify_threshold_samples_ = queue_size_samples_; // TODO: this is currently using x2 the queue - // size + notify_threshold_samples_ = queue_size_samples_; // TODO: this is currently + // using x2 the queue size buffer_.resize(queue_size_samples_ + notify_threshold_samples_); audio_queue_ = From d6948b1ca0da73a533cf5efc8c50d55ea7c9e85c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?The=CC=81o=20Monnom?= Date: Thu, 12 Sep 2024 11:01:21 -0700 Subject: [PATCH 7/7] clear_buffer --- libwebrtc/src/audio_source.rs | 4 ++++ libwebrtc/src/native/audio_source.rs | 7 +++--- livekit-ffi/protocol/audio_frame.proto | 5 ++++ livekit-ffi/protocol/ffi.proto | 18 +++++++------- livekit-ffi/src/livekit.proto.rs | 30 +++++++++++++++++------- livekit-ffi/src/server/audio_source.rs | 8 +++++++ livekit-ffi/src/server/requests.rs | 13 ++++++++++ webrtc-sys/include/livekit/audio_track.h | 2 +- webrtc-sys/src/audio_track.cpp | 12 +++++----- webrtc-sys/src/audio_track.rs | 1 + 10 files changed, 74 insertions(+), 26 deletions(-) diff --git a/libwebrtc/src/audio_source.rs b/libwebrtc/src/audio_source.rs index f2f0ffac1..1df5c4878 100644 --- a/libwebrtc/src/audio_source.rs +++ b/libwebrtc/src/audio_source.rs @@ -75,6 +75,10 @@ pub mod native { } } + pub fn clear_buffer(&self) { + self.handle.clear_buffer() + } + pub async fn capture_frame(&self, frame: &AudioFrame<'_>) -> Result<(), RtcError> { self.handle.capture_frame(frame).await } diff --git a/libwebrtc/src/native/audio_source.rs b/libwebrtc/src/native/audio_source.rs index 249e55e11..c7652c676 100644 --- a/libwebrtc/src/native/audio_source.rs +++ b/libwebrtc/src/native/audio_source.rs @@ -71,6 +71,10 @@ impl NativeAudioSource { self.num_channels } + pub fn clear_buffer(&self) { + self.sys_handle.clear_buffer(); + } + pub async fn capture_frame(&self, frame: &AudioFrame<'_>) -> Result<(), RtcError> { if self.sample_rate != frame.sample_rate || self.num_channels != frame.num_channels { return Err(RtcError { @@ -80,14 +84,12 @@ impl NativeAudioSource { } extern "C" fn lk_audio_source_complete(userdata: *const sys_at::SourceContext) { - println!("lk_audio_source_complete"); let tx = unsafe { Box::from_raw(userdata as *mut oneshot::Sender<()>) }; let _ = tx.send(()); } // iterate over chunks of self._queue_size_samples for chunk in frame.data.chunks(self.queue_size_samples as usize) { - println!("capturing frame {}", chunk.len()); let nb_frames = chunk.len() / self.num_channels as usize; let (tx, rx) = oneshot::channel::<()>(); let ctx = Box::new(tx); @@ -102,7 +104,6 @@ impl NativeAudioSource { ctx_ptr, sys_at::CompleteCallback(lk_audio_source_complete), ) { - print!("failed to capture frame"); return Err(RtcError { error_type: RtcErrorType::InvalidState, message: "failed to capture frame".to_owned(), diff --git a/livekit-ffi/protocol/audio_frame.proto b/livekit-ffi/protocol/audio_frame.proto index 14a225cb9..8307d265d 100644 --- a/livekit-ffi/protocol/audio_frame.proto +++ b/livekit-ffi/protocol/audio_frame.proto @@ -64,6 +64,11 @@ message CaptureAudioFrameCallback { optional string error = 2; } +message ClearAudioBufferRequest { + uint64 source_handle = 1; +} +message ClearAudioBufferResponse {} + // Create a new AudioResampler message NewAudioResamplerRequest {} message NewAudioResamplerResponse { diff --git a/livekit-ffi/protocol/ffi.proto b/livekit-ffi/protocol/ffi.proto index ba87d8417..173d36286 100644 --- a/livekit-ffi/protocol/ffi.proto +++ b/livekit-ffi/protocol/ffi.proto @@ -88,10 +88,11 @@ message FfiRequest { NewAudioStreamRequest new_audio_stream = 25; NewAudioSourceRequest new_audio_source = 26; CaptureAudioFrameRequest capture_audio_frame = 27; - NewAudioResamplerRequest new_audio_resampler = 28; - RemixAndResampleRequest remix_and_resample = 29; - E2eeRequest e2ee = 30; - AudioStreamFromParticipantRequest audio_stream_from_participant = 31; + ClearAudioBufferRequest clear_audio_buffer = 28; + NewAudioResamplerRequest new_audio_resampler = 29; + RemixAndResampleRequest remix_and_resample = 30; + E2eeRequest e2ee = 31; + AudioStreamFromParticipantRequest audio_stream_from_participant = 32; } } @@ -132,10 +133,11 @@ message FfiResponse { NewAudioStreamResponse new_audio_stream = 25; NewAudioSourceResponse new_audio_source = 26; CaptureAudioFrameResponse capture_audio_frame = 27; - NewAudioResamplerResponse new_audio_resampler = 28; - RemixAndResampleResponse remix_and_resample = 29; - AudioStreamFromParticipantResponse audio_stream_from_participant = 30; - E2eeResponse e2ee = 31; + ClearAudioBufferResponse clear_audio_buffer = 28; + NewAudioResamplerResponse new_audio_resampler = 29; + RemixAndResampleResponse remix_and_resample = 30; + AudioStreamFromParticipantResponse audio_stream_from_participant = 31; + E2eeResponse e2ee = 32; } } diff --git a/livekit-ffi/src/livekit.proto.rs b/livekit-ffi/src/livekit.proto.rs index 1635c130c..1cfce9572 100644 --- a/livekit-ffi/src/livekit.proto.rs +++ b/livekit-ffi/src/livekit.proto.rs @@ -3014,6 +3014,16 @@ pub struct CaptureAudioFrameCallback { #[prost(string, optional, tag="2")] pub error: ::core::option::Option<::prost::alloc::string::String>, } +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearAudioBufferRequest { + #[prost(uint64, tag="1")] + pub source_handle: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClearAudioBufferResponse { +} /// Create a new AudioResampler #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -3240,7 +3250,7 @@ impl AudioSourceType { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiRequest { - #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31")] + #[prost(oneof="ffi_request::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiRequest`. @@ -3305,12 +3315,14 @@ pub mod ffi_request { #[prost(message, tag="27")] CaptureAudioFrame(super::CaptureAudioFrameRequest), #[prost(message, tag="28")] - NewAudioResampler(super::NewAudioResamplerRequest), + ClearAudioBuffer(super::ClearAudioBufferRequest), #[prost(message, tag="29")] - RemixAndResample(super::RemixAndResampleRequest), + NewAudioResampler(super::NewAudioResamplerRequest), #[prost(message, tag="30")] - E2ee(super::E2eeRequest), + RemixAndResample(super::RemixAndResampleRequest), #[prost(message, tag="31")] + E2ee(super::E2eeRequest), + #[prost(message, tag="32")] AudioStreamFromParticipant(super::AudioStreamFromParticipantRequest), } } @@ -3318,7 +3330,7 @@ pub mod ffi_request { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FfiResponse { - #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31")] + #[prost(oneof="ffi_response::Message", tags="2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32")] pub message: ::core::option::Option, } /// Nested message and enum types in `FfiResponse`. @@ -3383,12 +3395,14 @@ pub mod ffi_response { #[prost(message, tag="27")] CaptureAudioFrame(super::CaptureAudioFrameResponse), #[prost(message, tag="28")] - NewAudioResampler(super::NewAudioResamplerResponse), + ClearAudioBuffer(super::ClearAudioBufferResponse), #[prost(message, tag="29")] - RemixAndResample(super::RemixAndResampleResponse), + NewAudioResampler(super::NewAudioResamplerResponse), #[prost(message, tag="30")] - AudioStreamFromParticipant(super::AudioStreamFromParticipantResponse), + RemixAndResample(super::RemixAndResampleResponse), #[prost(message, tag="31")] + AudioStreamFromParticipant(super::AudioStreamFromParticipantResponse), + #[prost(message, tag="32")] E2ee(super::E2eeResponse), } } diff --git a/livekit-ffi/src/server/audio_source.rs b/livekit-ffi/src/server/audio_source.rs index 904cba1a4..e07acee8e 100644 --- a/livekit-ffi/src/server/audio_source.rs +++ b/livekit-ffi/src/server/audio_source.rs @@ -62,6 +62,14 @@ impl FfiAudioSource { }) } + pub fn clear_buffer(&self) { + match self.source { + #[cfg(not(target_arch = "wasm32"))] + RtcAudioSource::Native(ref source) => source.clear_buffer(), + _ => {} + } + } + pub fn capture_frame( &self, server: &'static server::FfiServer, diff --git a/livekit-ffi/src/server/requests.rs b/livekit-ffi/src/server/requests.rs index 705e86a1d..89e395769 100644 --- a/livekit-ffi/src/server/requests.rs +++ b/livekit-ffi/src/server/requests.rs @@ -426,6 +426,16 @@ fn on_capture_audio_frame( source.capture_frame(server, push) } +// Clear the internal audio buffer (cancel all pending frames from being played) +fn on_clear_audio_buffer( + server: &'static FfiServer, + clear: proto::ClearAudioBufferRequest, +) -> FfiResult { + let source = server.retrieve_handle::(clear.source_handle)?; + source.clear_buffer(); + Ok(proto::ClearAudioBufferResponse {}) +} + /// Create a new audio resampler fn new_audio_resampler( server: &'static FfiServer, @@ -744,6 +754,9 @@ pub fn handle_request( proto::ffi_request::Message::CaptureAudioFrame(push) => { proto::ffi_response::Message::CaptureAudioFrame(on_capture_audio_frame(server, push)?) } + proto::ffi_request::Message::ClearAudioBuffer(clear) => { + proto::ffi_response::Message::ClearAudioBuffer(on_clear_audio_buffer(server, clear)?) + } proto::ffi_request::Message::NewAudioResampler(new_res) => { proto::ffi_response::Message::NewAudioResampler(new_audio_resampler(server, new_res)?) } diff --git a/webrtc-sys/include/livekit/audio_track.h b/webrtc-sys/include/livekit/audio_track.h index b2ee28986..029158fd5 100644 --- a/webrtc-sys/include/livekit/audio_track.h +++ b/webrtc-sys/include/livekit/audio_track.h @@ -160,7 +160,7 @@ class AudioTrackSource { const SourceContext* ctx, CompleteCallback on_complete) const; - void clear_buffer(); + void clear_buffer() const; rtc::scoped_refptr get() const; diff --git a/webrtc-sys/src/audio_track.cpp b/webrtc-sys/src/audio_track.cpp index 7804f2304..980c22efb 100644 --- a/webrtc-sys/src/audio_track.cpp +++ b/webrtc-sys/src/audio_track.cpp @@ -210,6 +210,11 @@ bool AudioTrackSource::InternalSource::capture_frame( return true; } +void AudioTrackSource::InternalSource::clear_buffer() { + webrtc::MutexLock lock(&mutex_); + buffer_.clear(); +} + webrtc::MediaSourceInterface::SourceState AudioTrackSource::InternalSource::state() const { return webrtc::MediaSourceInterface::SourceState::kLive; @@ -242,11 +247,6 @@ void AudioTrackSource::InternalSource::RemoveSink( sinks_.erase(std::remove(sinks_.begin(), sinks_.end(), sink), sinks_.end()); } -void AudioTrackSource::InternalSource::clear_buffer() { - webrtc::MutexLock lock(&mutex_); - buffer_.clear(); -} - AudioTrackSource::AudioTrackSource(AudioSourceOptions options, int sample_rate, int num_channels, @@ -279,7 +279,7 @@ bool AudioTrackSource::capture_frame( number_of_frames, ctx, on_complete); } -void AudioTrackSource::clear_buffer() { +void AudioTrackSource::clear_buffer() const { source_->clear_buffer(); } diff --git a/webrtc-sys/src/audio_track.rs b/webrtc-sys/src/audio_track.rs index 88f68f779..d24f87520 100644 --- a/webrtc-sys/src/audio_track.rs +++ b/webrtc-sys/src/audio_track.rs @@ -59,6 +59,7 @@ pub mod ffi { userdata: *const SourceContext, on_complete: CompleteCallback, ) -> bool; + fn clear_buffer(self: &AudioTrackSource); fn audio_options(self: &AudioTrackSource) -> AudioSourceOptions; fn set_audio_options(self: &AudioTrackSource, options: &AudioSourceOptions);