Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl Drop for ActiveConnection {
/// Number of currently active connections (a relaxed-ordering snapshot of
/// the global counter, so it may lag concurrent opens/closes slightly).
pub fn active_connection_count() -> u64 { ACTIVE_CONNECTIONS.load(Ordering::Relaxed) }

use crate::{
fairness::Fairness,
fairness::FairnessTracker,
hooks::{ConnectionContext, ProtocolHooks},
push::{FrameLike, PushHandle, PushQueues},
response::{FrameStream, WireframeError},
Expand Down Expand Up @@ -111,7 +111,7 @@ pub struct ConnectionActor<F, E> {
counter: Option<ActiveConnection>,
hooks: ProtocolHooks<F, E>,
ctx: ConnectionContext,
fairness: Fairness,
fairness: FairnessTracker,
connection_id: Option<ConnectionId>,
peer_addr: Option<SocketAddr>,
}
Expand Down Expand Up @@ -169,7 +169,7 @@ where
counter: Some(counter),
hooks,
ctx,
fairness: Fairness::new(FairnessConfig::default()),
fairness: FairnessTracker::new(FairnessConfig::default()),
connection_id: None,
peer_addr: None,
};
Expand Down Expand Up @@ -369,9 +369,9 @@ where

/// Update counters and opportunistically drain the low-priority queue.
fn after_high(&mut self, out: &mut Vec<F>, state: &mut ActorState) {
self.fairness.after_high();
self.fairness.record_high_priority();

if self.fairness.should_yield() {
if self.fairness.should_yield_to_low_priority() {
let res = self.low_rx.as_mut().map(mpsc::Receiver::try_recv);
if let Some(res) = res {
match res {
Expand All @@ -390,7 +390,7 @@ where
}

/// Reset counters after processing a low-priority frame.
fn after_low(&mut self) { self.fairness.after_low(); }
fn after_low(&mut self) { self.fairness.reset(); }

/// Push a frame from the response stream into `out` or handle completion.
///
Expand Down
120 changes: 85 additions & 35 deletions src/fairness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,44 @@
//!
//! This module encapsulates the logic for deciding when high-priority
//! processing should yield to low-priority traffic based on configured
//! thresholds and optional time slices.
//! thresholds and optional time slices. A pluggable [`Clock`] allows the
//! timing logic to be tested without depending on Tokio's global time.

use tokio::time::Instant;

use crate::connection::FairnessConfig;

/// Time source used by [`FairnessTracker`].
///
/// Abstracting the clock lets the time-slice logic be exercised with a
/// deterministic mock in tests instead of Tokio's global time. `Clone` is
/// required so a test can keep a handle to the same clock it hands to the
/// tracker and advance it independently.
pub(crate) trait Clock: Clone {
/// Return the current instant.
fn now(&self) -> Instant;
}

/// Production [`Clock`] that defers to [`tokio::time::Instant::now`].
#[derive(Clone, Debug, Default)]
pub(crate) struct TokioClock;

impl Clock for TokioClock {
    /// Delegate straight to Tokio's notion of "now".
    fn now(&self) -> Instant {
        Instant::now()
    }
}

#[derive(Debug)]
pub(crate) struct Fairness {
pub(crate) struct FairnessTracker<C: Clock = TokioClock> {
config: FairnessConfig,
clock: C,
high_counter: usize,
high_start: Option<Instant>,
}

impl Fairness {
pub(crate) fn new(config: FairnessConfig) -> Self {
impl FairnessTracker {
    /// Build a tracker that reads time from Tokio via [`TokioClock`].
    pub(crate) fn new(config: FairnessConfig) -> Self {
        Self::with_clock(config, TokioClock)
    }
}

impl<C: Clock> FairnessTracker<C> {
pub(crate) fn with_clock(config: FairnessConfig, clock: C) -> Self {
Self {
config,
clock,
high_counter: 0,
high_start: None,
}
Expand All @@ -29,36 +50,41 @@ impl Fairness {
self.reset();
}

pub(crate) fn after_high(&mut self) {
pub(crate) fn record_high_priority(&mut self) {
self.high_counter += 1;
if self.high_counter == 1 {
self.high_start = Some(Instant::now());
self.high_start = Some(self.clock.now());
}
}

pub(crate) fn should_yield(&self) -> bool {
let threshold_hit = self.config.max_high_before_low > 0
&& self.high_counter >= self.config.max_high_before_low;
let time_hit = self
.config
.time_slice
.zip(self.high_start)
.is_some_and(|(slice, start)| start.elapsed() >= slice);
threshold_hit || time_hit
/// Decide whether high-priority processing should pause for low-priority
/// traffic.
///
/// Yields when either fairness limit has been exceeded:
/// - the burst counter reached `max_high_before_low` (a value of 0 disables
///   the counter check), or
/// - the configured `time_slice` has elapsed since the burst began.
///
/// Short-circuits so the clock is only consulted when the counter check
/// alone does not force a yield.
pub(crate) fn should_yield_to_low_priority(&self) -> bool {
    let limit = self.config.max_high_before_low;
    (limit > 0 && self.high_counter >= limit)
        || self
            .config
            .time_slice
            .zip(self.high_start)
            .is_some_and(|(slice, start)| self.clock.now().duration_since(start) >= slice)
}

pub(crate) fn after_low(&mut self) { self.reset(); }
/// Reset the fairness counters, e.g. after a low-priority frame is serviced.
pub(crate) fn reset(&mut self) { self.clear(); }

pub(crate) fn reset(&mut self) {
/// Zero the burst bookkeeping so the next high-priority frame starts a
/// fresh accounting window.
fn clear(&mut self) {
    self.high_start = None;
    self.high_counter = 0;
}
}

#[cfg(test)]
mod tests {
use std::sync::{Arc, Mutex};

use rstest::rstest;
use tokio::time::{self, Duration};
use tokio::time::{Duration, Instant};

use super::*;

Expand All @@ -69,11 +95,11 @@ mod tests {
max_high_before_low: 2,
time_slice: None,
};
let mut fairness = Fairness::new(cfg);
fairness.after_high();
assert!(!fairness.should_yield());
fairness.after_high();
assert!(fairness.should_yield());
let mut fairness = FairnessTracker::new(cfg);
fairness.record_high_priority();
assert!(!fairness.should_yield_to_low_priority());
fairness.record_high_priority();
assert!(fairness.should_yield_to_low_priority());
}

#[rstest]
Expand All @@ -83,24 +109,48 @@ mod tests {
max_high_before_low: 1,
time_slice: None,
};
let mut fairness = Fairness::new(cfg);
fairness.after_high();
assert!(fairness.should_yield());
fairness.after_low();
assert!(!fairness.should_yield());
let mut fairness = FairnessTracker::new(cfg);
fairness.record_high_priority();
assert!(fairness.should_yield_to_low_priority());
fairness.reset();
assert!(!fairness.should_yield_to_low_priority());
}

/// Deterministic test [`Clock`]: time only moves when `advance` is called.
#[derive(Clone, Debug)]
struct MockClock {
// Shared behind Arc<Mutex<..>> so clones handed to the tracker observe
// advances made through the test's own handle.
now: Arc<Mutex<Instant>>,
}

impl MockClock {
    /// Create a mock clock frozen at `start`.
    fn new(start: Instant) -> Self {
        let now = Arc::new(Mutex::new(start));
        Self { now }
    }

    /// Move the mock clock forward by `dur`.
    fn advance(&self, dur: Duration) {
        let mut guard = self.now.lock().expect("lock poisoned");
        *guard = *guard + dur;
    }
}

impl Clock for MockClock {
    /// Report the mock's current frozen instant.
    fn now(&self) -> Instant {
        *self.now.lock().expect("lock poisoned")
    }
}

#[rstest]
#[tokio::test]
async fn time_slice_triggers_yield() {
time::pause();
#[test]
fn time_slice_triggers_yield() {
let start = Instant::now();
let clock = MockClock::new(start);
let cfg = FairnessConfig {
max_high_before_low: 0,
time_slice: Some(Duration::from_millis(5)),
};
let mut fairness = Fairness::new(cfg);
fairness.after_high();
time::advance(Duration::from_millis(6)).await;
assert!(fairness.should_yield());
let mut fairness = FairnessTracker::with_clock(cfg, clock.clone());
fairness.record_high_priority();
assert!(!fairness.should_yield_to_low_priority());
clock.advance(Duration::from_millis(5));
assert!(fairness.should_yield_to_low_priority());
}
}
77 changes: 31 additions & 46 deletions src/server/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,38 +315,37 @@ mod tests {
}

#[rstest]
#[case("success")]
#[case("failure")]
#[tokio::test]
async fn test_preamble_success_callback(
async fn test_preamble_callback_registration(
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
#[case] callback_type: &str,
) {
let counter = Arc::new(AtomicUsize::new(0));
let c = counter.clone();
let server = server_with_preamble(factory).on_preamble_decode_success(
move |_p: &TestPreamble, _| {

let server = server_with_preamble(factory);
let server = match callback_type {
"success" => server.on_preamble_decode_success(move |_p: &TestPreamble, _| {
let c = c.clone();
Box::pin(async move {
c.fetch_add(1, Ordering::SeqCst);
Ok(())
})
},
);
assert_eq!(counter.load(Ordering::SeqCst), 0);
assert!(server.on_preamble_success.is_some());
}

#[rstest]
#[tokio::test]
async fn test_preamble_failure_callback(
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
) {
let counter = Arc::new(AtomicUsize::new(0));
let c = counter.clone();
let server =
server_with_preamble(factory).on_preamble_decode_failure(move |_err: &DecodeError| {
}),
"failure" => server.on_preamble_decode_failure(move |_err: &DecodeError| {
c.fetch_add(1, Ordering::SeqCst);
});
}),
_ => panic!("Invalid callback type"),
};

assert_eq!(counter.load(Ordering::SeqCst), 0);
assert!(server.on_preamble_failure.is_some());
match callback_type {
"success" => assert!(server.on_preamble_success.is_some()),
"failure" => assert!(server.on_preamble_failure.is_some()),
_ => unreachable!(),
}
}

#[rstest]
Expand All @@ -355,13 +354,13 @@ mod tests {
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
free_port: SocketAddr,
) {
let counter = Arc::new(AtomicUsize::new(0));
let c = counter.clone();
let callback_invoked = Arc::new(AtomicUsize::new(0));
let counter = callback_invoked.clone();
let server = WireframeServer::new(factory)
.workers(2)
.with_preamble::<TestPreamble>()
.on_preamble_decode_success(move |_p: &TestPreamble, _| {
let c = c.clone();
let c = counter.clone();
Box::pin(async move {
c.fetch_add(1, Ordering::SeqCst);
Ok(())
Expand All @@ -372,7 +371,7 @@ mod tests {
.expect("Failed to bind");
assert_eq!(server.worker_count(), 2);
assert!(server.local_addr().is_some());
assert!(server.on_preamble_success.is_some() && server.on_preamble_failure.is_some());
assert_eq!(callback_invoked.load(Ordering::SeqCst), 0);
}

#[rstest]
Expand All @@ -389,18 +388,6 @@ mod tests {
assert!(server.local_addr().is_some());
}

#[rstest]
fn test_preamble_callbacks_reset_on_type_change(
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
) {
let server = WireframeServer::new(factory)
.on_preamble_decode_success(|&(), _| Box::pin(async { Ok(()) }))
.on_preamble_decode_failure(|_: &DecodeError| {});
assert!(server.on_preamble_success.is_some() && server.on_preamble_failure.is_some());
let server = server.with_preamble::<TestPreamble>();
assert!(server.on_preamble_success.is_none() && server.on_preamble_failure.is_none());
}

#[rstest]
fn test_extreme_worker_counts(
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
Expand All @@ -417,16 +404,14 @@ mod tests {
factory: impl Fn() -> WireframeApp + Send + Sync + Clone + 'static,
free_port: SocketAddr,
) {
let addr2 = {
let listener = std::net::TcpListener::bind(SocketAddr::new(
std::net::Ipv4Addr::LOCALHOST.into(),
0,
))
.expect("failed to bind second listener");
listener
.local_addr()
.expect("failed to get second listener address")
};
let listener2 =
std::net::TcpListener::bind(SocketAddr::new(std::net::Ipv4Addr::LOCALHOST.into(), 0))
.expect("failed to bind second listener");
let addr2 = listener2
.local_addr()
.expect("failed to get second listener address");
drop(listener2);

let server = WireframeServer::new(factory);
let server = server
.bind(free_port)
Expand Down
Loading