From 3c2f65ff733cdc1d7bb40412738ce091ff5839ef Mon Sep 17 00:00:00 2001
From: AlexTMjugador
Date: Sat, 23 Apr 2022 23:39:51 +0200
Subject: [PATCH] Add async PacketWriter

Given that this crate already supports async reads, it makes sense to
complete its async support with writes, for API parity.

To achieve this, the packet writing logic was extracted to a new private
struct, `BasePacketWriter`, which is largely the same as before, except for
a refactor to remove the data of ended streams from the map (previously a
TODO comment) and for dropping byteorder and fallible I/O operations. This
new struct is not responsible for doing any I/O: it just invokes a callback,
which can return `false` to indicate a failure and stop the writing process.

The blocking `PacketWriter` now uses `BasePacketWriter` with a callback that
writes everything to the sink as before. Any error that might happen is
stored and returned, so the user-visible behavior does not change.

The new async `PacketWriter` is pretty trivial thanks to Tokio encoders and
the callback-based `BasePacketWriter`: written pages are simply copied to a
Tokio-provided buffer, which always succeeds. Tokio takes care of all the
dirty details of actual I/O.
---
 Cargo.toml     |   3 +-
 src/writing.rs | 422 +++++++++++++++++++++++++++++++++----------------
 2 files changed, 290 insertions(+), 135 deletions(-)

diff --git a/Cargo.toml b/Cargo.toml
index f290ee0..67f9ca1 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -17,6 +17,7 @@ name = "ogg"
 [dependencies]
 byteorder = "1.0"
 futures-core = { version = "0.3", optional = true }
+futures-sink = { version = "0.3", optional = true }
 futures-io = { version = "0.3", optional = true }
 tokio = { version = "1", optional = true }
 tokio-util = { version = "0.6", features = ["codec", "compat"], optional = true }
@@ -29,7 +30,7 @@ tokio = { version = "1", features = ["full"] }
 futures-util = "0.3"
 
 [features]
-async = ["futures-core", "futures-io", "tokio", "tokio-util", "bytes", "pin-project"]
+async = ["futures-core", "futures-sink", "futures-io", "tokio", "tokio-util", "bytes", "pin-project"]
 
 [package.metadata.docs.rs]
 all-features = true
diff --git a/src/writing.rs b/src/writing.rs
index e4c919b..c822283 100644
--- a/src/writing.rs
+++ b/src/writing.rs
@@ -11,18 +11,18 @@ Writing logic
 */
 
 use std::borrow::Cow;
-use std::result;
-use std::io::{self, Cursor, Write, Seek, SeekFrom};
-use byteorder::{WriteBytesExt, LittleEndian};
+use std::io::{Write, Seek, SeekFrom, Result};
 use std::collections::HashMap;
 use crate::crc::vorbis_crc32_update;
 
-/// Ogg version of the `std::io::Result` type.
-///
-/// We need `std::result::Result` at other points
-/// too, so we can't use `Result` as the name.
-type IoResult<T> = result::Result<T, std::io::Error>;
+/// Returns `false` from the enclosing function if the specified expression is false.
+macro_rules! bail_out_on_fail {
+	($success:expr) => {
+		if !$success {
+			return false;
+		}
+	}
+}
 
 /**
 Writer for packets into an Ogg stream.
@@ -30,9 +30,14 @@ Writer for packets into an Ogg stream.
 Note that the functionality of this struct isn't as well tested as for the
 `PacketReader` struct.
 */
-pub struct PacketWriter<'writer, T :io::Write> {
+pub struct PacketWriter<'writer, T :Write> {
 	wtr :T,
+
+	base_pck_wtr :BasePacketWriter<'writer>,
+}
+
+/// Internal base packet writer that contains common packet writing logic.
+struct BasePacketWriter<'writer> {
 	page_vals :HashMap<u32, CurrentPageValues<'writer>>,
 }
 
@@ -90,46 +95,17 @@ pub enum PacketWriteEndInfo {
 	EndStream,
 }
 
-impl <'writer, T :io::Write> PacketWriter<'writer, T> {
-	pub fn new(wtr :T) -> Self {
-		return PacketWriter {
-			wtr,
+impl<'writer> BasePacketWriter<'writer> {
+	fn new() -> Self {
+		Self {
 			page_vals : HashMap::new(),
-		};
-	}
-	pub fn into_inner(self) -> T {
-		self.wtr
-	}
-	/// Access the interior writer
-	///
-	/// This allows access of the writer contained inside.
-	/// No guarantees are given onto the pattern of the writes.
-	/// They may change in the future.
-	pub fn inner(&self) -> &T {
-		&self.wtr
-	}
-	/// Access the interior writer mutably
-	///
-	/// This allows access of the writer contained inside.
-	/// No guarantees are given onto the pattern of the writes.
-	/// They may change in the future.
-	pub fn inner_mut(&mut self) -> &mut T {
-		&mut self.wtr
+		}
 	}
-	/// Write a packet
-	///
-	///
-	pub fn write_packet<P :Into<Cow<'writer, [u8]>>>(&mut self, pck_cont :P,
-			serial :u32,
-			inf :PacketWriteEndInfo,
-			/* TODO find a better way to design the API around
-				passing the absgp to the underlying implementation.
-				e.g. the caller passes a closure on init which gets
-				called when we encounter a new page... with the param
-				the index inside the current page, or something.
-			*/
-			absgp :u64) -> IoResult<()> {
-		let is_end_stream :bool = inf == PacketWriteEndInfo::EndStream;
+
+	fn write_packet(&mut self, pck_cont :Cow<'writer, [u8]>, serial :u32,
+			inf :PacketWriteEndInfo, absgp :u64,
+			mut sink_func :impl FnMut(&[u8]) -> bool) -> bool {
+		let is_end_stream = inf == PacketWriteEndInfo::EndStream;
 
 		let pg = self.page_vals.entry(serial).or_insert(
 			CurrentPageValues {
 				first_page : true,
@@ -142,14 +118,13 @@ impl <'writer, T :io::Write> PacketWriter<'writer, T> {
 			}
 		);
 
-		let pck_cont = pck_cont.into();
 		let cont_len = pck_cont.len();
 		pg.cur_pg_data.push((pck_cont, absgp));
 
 		let last_data_segment_size = (cont_len % 255) as u8;
-		let needed_segments :usize = (cont_len / 255) + 1;
-		let mut segment_in_page_i :u8 = pg.segment_cnt;
-		let mut at_page_end :bool = false;
+		let needed_segments = (cont_len / 255) + 1;
+		let mut segment_in_page_i = pg.segment_cnt;
+		let mut at_page_end = false;
 		for segment_i in 0 .. needed_segments {
 			at_page_end = false;
 			if segment_i + 1 < needed_segments {
@@ -166,13 +141,11 @@ impl <'writer, T :io::Write> PacketWriter<'writer, T> {
 			if segment_i + 1 < needed_segments {
 				// We have to flush a page, but we know there are more to come...
 				pg.pck_this_overflow_idx = Some((segment_i + 1) * 255);
-				tri!(PacketWriter::write_page(&mut self.wtr, serial, pg,
-					false));
+				bail_out_on_fail!(Self::write_page(serial, pg, false, &mut sink_func));
 			} else {
 				// We have to write a page end, and it's the very last
 				// we need to write
-				tri!(PacketWriter::write_page(&mut self.wtr,
-					serial, pg, is_end_stream));
+				bail_out_on_fail!(Self::write_page(serial, pg, is_end_stream, &mut sink_func));
 				// Not actually required
 				// (it is always None except if we set it to Some directly
 				// before we call write_page)
@@ -183,98 +156,102 @@ impl <'writer, T :io::Write> PacketWriter<'writer, T> {
 				at_page_end = true;
 			}
 		}
-		if (inf != PacketWriteEndInfo::NormalPacket) && !at_page_end {
+		if inf != PacketWriteEndInfo::NormalPacket && !at_page_end {
 			// Write a page end
-			tri!(PacketWriter::write_page(&mut self.wtr, serial, pg,
-				is_end_stream));
-
-			pg.pck_last_overflow_idx = None;
-
-			// TODO if inf was PacketWriteEndInfo::EndStream, we have to
-			// somehow erase pg from the hashmap...
-			// any ideas? perhaps needs external scope...
+			bail_out_on_fail!(Self::write_page(serial, pg, is_end_stream, &mut sink_func));
+		}
+		// When ending the logical bitstream there is no point in keeping
+		// around page data.
+		if is_end_stream {
+			self.page_vals.remove(&serial);
 		}
+
 		// All went fine.
-		Ok(())
+		true
 	}
 
-	fn write_page(wtr :&mut T, serial :u32, pg :&mut CurrentPageValues,
-		last_page :bool) -> IoResult<()> {
-		{
-			// The page header with everything but the lacing values:
-			let mut hdr_cur = Cursor::new(Vec::with_capacity(27));
-			tri!(hdr_cur.write_all(&[0x4f, 0x67, 0x67, 0x53, 0x00]));
-			let mut flags :u8 = 0;
-			if pg.pck_last_overflow_idx.is_some() { flags |= 0x01; }
-			if pg.first_page { flags |= 0x02; }
-			if last_page { flags |= 0x04; }
-
-			tri!(hdr_cur.write_u8(flags));
-
-			let pck_data = &pg.cur_pg_data;
-
-			let mut last_finishing_pck_absgp = (-1i64) as u64;
-			for (idx, &(_, absgp)) in pck_data.iter().enumerate() {
-				if !(idx + 1 == pck_data.len() &&
-						pg.pck_this_overflow_idx.is_some()) {
-					last_finishing_pck_absgp = absgp;
-				}
+	fn write_page(serial :u32, pg :&mut CurrentPageValues, last_page :bool,
+			mut sink_func :impl FnMut(&[u8]) -> bool) -> bool {
+		// The page header with everything but the lacing values:
+		let mut hdr = Vec::with_capacity(27);
+
+		// Capture pattern.
+		hdr.extend_from_slice(b"OggS");
+
+		// Ogg format version, always zero.
+		hdr.push(0);
+
+		let mut flags = 0;
+		if pg.pck_last_overflow_idx.is_some() { flags |= 0x01; }
+		if pg.first_page { flags |= 0x02; }
+		if last_page { flags |= 0x04; }
+		hdr.push(flags);
+
+		let pck_data = &pg.cur_pg_data;
+
+		let mut last_finishing_pck_absgp = (-1i64) as u64;
+		for (idx, (_, absgp)) in pck_data.iter().enumerate() {
+			if !(idx + 1 == pck_data.len() &&
+					pg.pck_this_overflow_idx.is_some()) {
+				last_finishing_pck_absgp = *absgp;
 			}
+		}
 
-			tri!(hdr_cur.write_u64::<LittleEndian>(last_finishing_pck_absgp));
-			tri!(hdr_cur.write_u32::<LittleEndian>(serial));
-			tri!(hdr_cur.write_u32::<LittleEndian>(pg.sequence_num));
+		macro_rules! write_le {
+			($sink:expr, $number:expr) => {
+				$sink.extend_from_slice(&$number.to_le_bytes()[..])
+			}
+		}
 
-			// checksum, calculated later on :)
-			tri!(hdr_cur.write_u32::<LittleEndian>(0));
+		write_le!(hdr, last_finishing_pck_absgp);
+		write_le!(hdr, serial);
+		write_le!(hdr, pg.sequence_num);
 
-			tri!(hdr_cur.write_u8(pg.segment_cnt));
+		// checksum, calculated later on :)
+		write_le!(hdr, 0_u32);
 
-			let mut hash_calculated :u32;
+		write_le!(hdr, pg.segment_cnt);
 
-			let pg_lacing = &pg.cur_pg_lacing[0 .. pg.segment_cnt as usize];
+		let pg_lacing = &pg.cur_pg_lacing[0 .. pg.segment_cnt as usize];
 
-			hash_calculated = vorbis_crc32_update(0, hdr_cur.get_ref());
-			hash_calculated = vorbis_crc32_update(hash_calculated, pg_lacing);
+		let mut hash_calculated = vorbis_crc32_update(0, &hdr);
+		hash_calculated = vorbis_crc32_update(hash_calculated, pg_lacing);
 
-			for (idx, &(ref pck, _)) in pck_data.iter().enumerate() {
-				let mut start :usize = 0;
-				if idx == 0 { if let Some(idx) = pg.pck_last_overflow_idx {
-					start = idx;
-				}}
-				let mut end :usize = pck.len();
-				if idx + 1 == pck_data.len() {
-					if let Some(idx) = pg.pck_this_overflow_idx {
-						end = idx;
-					}
+		for (idx, (pck, _)) in pck_data.iter().enumerate() {
+			let mut start = 0;
+			if idx == 0 { if let Some(idx) = pg.pck_last_overflow_idx {
+				start = idx;
+			}}
+			let mut end = pck.len();
+			if idx + 1 == pck_data.len() {
+				if let Some(idx) = pg.pck_this_overflow_idx {
+					end = idx;
 				}
-				hash_calculated = vorbis_crc32_update(hash_calculated,
-					&pck[start .. end]);
 			}
+			hash_calculated = vorbis_crc32_update(hash_calculated,
+				&pck[start .. end]);
+		}
 
-			// Go back to enter the checksum
-			// Don't do excessive checking here (that the seek
-			// succeeded & we are at the right pos now).
-			// It's hopefully not required.
-			tri!(hdr_cur.seek(SeekFrom::Start(22)));
-			tri!(hdr_cur.write_u32::<LittleEndian>(hash_calculated));
-
-			// Now all is done, write the stuff!
-			tri!(wtr.write_all(hdr_cur.get_ref()));
-			tri!(wtr.write_all(pg_lacing));
-			for (idx, &(ref pck, _)) in pck_data.iter().enumerate() {
-				let mut start :usize = 0;
-				if idx == 0 { if let Some(idx) = pg.pck_last_overflow_idx {
-					start = idx;
-				}}
-				let mut end :usize = pck.len();
-				if idx + 1 == pck_data.len() {
-					if let Some(idx) = pg.pck_this_overflow_idx {
-						end = idx;
-					}
+		// Go back to enter the checksum
+		hdr[22..26].copy_from_slice(&hash_calculated.to_le_bytes()[..]);
+
+		// Now all is done, write the stuff!
+		// Bail out the function with a failure if some write error happens.
+		// This way the calling code could retry writing the page.
+		bail_out_on_fail!(sink_func(&hdr));
+		bail_out_on_fail!(sink_func(pg_lacing));
+		for (idx, (pck, _)) in pck_data.iter().enumerate() {
+			let mut start = 0;
+			if idx == 0 { if let Some(idx) = pg.pck_last_overflow_idx {
+				start = idx;
+			}}
+			let mut end = pck.len();
+			if idx + 1 == pck_data.len() {
+				if let Some(idx) = pg.pck_this_overflow_idx {
+					end = idx;
 				}
-				tri!(wtr.write_all(&pck[start .. end]));
 			}
+			bail_out_on_fail!(sink_func(&pck[start .. end]));
 		}
 
 		// Reset the page.
@@ -296,12 +273,65 @@ impl <'writer, T :io::Write> PacketWriter<'writer, T> {
 		pg.pck_last_overflow_idx = pg.pck_this_overflow_idx;
 		pg.pck_this_overflow_idx = None;
 
-		return Ok(());
+		// All went fine.
+		true
+	}
+}
+
+impl <'writer, T :Write> PacketWriter<'writer, T> {
+	/// Constructs a new `PacketWriter` with a given `Write`.
+	pub fn new(wtr :T) -> Self {
+		Self {
+			wtr,
+			base_pck_wtr : BasePacketWriter::new()
+		}
+	}
+	/// Returns the wrapped writer, consuming the PacketWriter.
+	pub fn into_inner(self) -> T {
+		self.wtr
+	}
+	/// Access the interior writer.
+	///
+	/// This allows access of the writer contained inside.
+	/// No guarantees are given onto the pattern of the writes.
+	/// They may change in the future.
+	pub fn inner(&self) -> &T {
+		&self.wtr
+	}
+	/// Access the interior writer mutably.
+	///
+	/// This allows access of the writer contained inside.
+	/// No guarantees are given onto the pattern of the writes.
+	/// They may change in the future.
+	pub fn inner_mut(&mut self) -> &mut T {
+		&mut self.wtr
+	}
+	/// Write a packet.
+	pub fn write_packet<P :Into<Cow<'writer, [u8]>>>(&mut self, pck_cont :P,
+			serial :u32,
+			inf :PacketWriteEndInfo,
+			/* TODO find a better way to design the API around
+				passing the absgp to the underlying implementation.
+				e.g. the caller passes a closure on init which gets
+				called when we encounter a new page... with the param
+				the index inside the current page, or something.
+			*/
+			absgp :u64) -> Result<()> {
+		let pck_cont = pck_cont.into();
+
+		let mut io_err = None;
+		self.base_pck_wtr.write_packet(pck_cont, serial, inf, absgp,
+			|ogg_data| match self.wtr.write_all(ogg_data) {
+				Ok(()) => true,
+				Err(err) => { io_err = Some(err); false }
+			});
+
+		io_err.map_or(Ok(()), Err)
 	}
 }
 
-impl<T :io::Write + io::Seek> PacketWriter<'_, T> {
-	pub fn get_current_offs(&mut self) -> Result<u64, io::Error> {
+impl<T :Write + Seek> PacketWriter<'_, T> {
+	pub fn get_current_offs(&mut self) -> Result<u64> {
 		self.wtr.seek(SeekFrom::Current(0))
 	}
 }
 
@@ -313,8 +343,9 @@ fn test_recapture() {
 	// Test that we can deal with recapture
 	// at varying distances.
 	// This is a regression test
-	use std::io::Write;
+	use std::io::{Cursor, Write};
 	use super::PacketReader;
+
 	let mut c = Cursor::new(Vec::new());
 	let test_arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
 	let test_arr_2 = [2, 4, 8, 16, 32, 64, 128, 127, 126, 125, 124];
@@ -345,3 +376,126 @@ fn test_recapture() {
 		assert_eq!(test_arr_3, *p3.data);
 	}
 }
+
+/// Asynchronous Ogg encoding.
+#[cfg(feature = "async")]
+pub mod async_api {
+	use std::io;
+	use std::pin::Pin;
+	use std::task::{Context, Poll};
+
+	use super::*;
+	use futures_sink::Sink;
+	use futures_io::AsyncWrite as FuturesAsyncWrite;
+	use tokio::io::AsyncWrite as TokioAsyncWrite;
+	use bytes::BytesMut;
+	use pin_project::pin_project;
+	use tokio_util::codec::{Encoder, FramedWrite};
+	use tokio_util::compat::{Compat, FuturesAsyncWriteCompatExt};
+
+	struct PacketEncoder<'writer> {
+		base_pck_wtr :BasePacketWriter<'writer>,
+	}
+
+	impl<'writer, 'packet :'writer> Encoder<Packet<'packet>> for PacketEncoder<'writer> {
+		type Error = io::Error;
+
+		fn encode(&mut self, item :Packet<'packet>, dst :&mut BytesMut) -> Result<()> {
+			// An encoder only cares about encapsulating data in the proper format,
+			// which in this case is Ogg packets in Ogg pages, to a memory buffer.
+			// Memory operations are assumed to be infallible, so the base packet
+			// writer sink function can be a simple passthrough. Flushing and
+			// actual I/O is taken care of by FramedWrite. Observe that writing
+			// a packet is not guaranteed to generate any page, and thus any data
+			// to encode, unless it ends a bitstream or forces the end of the page
+			// it belongs to.
+			self.base_pck_wtr.write_packet(item.data, item.serial, item.inf, item.absgp,
+				|ogg_data| { dst.extend_from_slice(ogg_data); true }
+			);
+
+			Ok(())
+		}
+	}
+
+	/// Asynchronous writer for packets into an Ogg stream.
+	///
+	/// Please read the documentation of the [`Packet::inf`] field for more
+	/// information about the not-so-obvious semantics of flushing this sink.
+	#[pin_project]
+	pub struct PacketWriter<'writer, W :TokioAsyncWrite> {
+		#[pin]
+		sink :FramedWrite<W, PacketEncoder<'writer>>,
+	}
+
+	/// An Ogg packet that may be fed to a [`PacketWriter`].
+	pub struct Packet<'packet> {
+		/// The data the packet contains.
+		pub data :Cow<'packet, [u8]>,
+		/// The serial of the stream this packet belongs to.
+		pub serial :u32,
+		/// Specifies whether to end something with the write of the packet.
+		///
+		/// Note that flushing a [`PacketWriter`] alone does not guarantee that
+		/// every Ogg packet so far has made it to the destination: normally,
+		/// packets are stuffed into pages as much as possible, and then those pages are
+		/// written. Flushing will only write pending pages, thus packets that
+		/// belong to yet incomplete pages will not immediately generate anything
+		/// to flush.
+		///
+		/// A packet has to forcibly end the page or stream it belongs to in order
+		/// to ensure that it is written on a flush. That can be done by setting
+		/// this value accordingly.
+		pub inf :PacketWriteEndInfo,
+		/// The granule position of the packet.
+		pub absgp :u64,
+	}
+
+	impl<W :TokioAsyncWrite> PacketWriter<'_, W> {
+		/// Wraps the specified Tokio runtime `AsyncWrite` into an Ogg packet
+		/// writer.
+		///
+		/// This is the recommended constructor when using the Tokio runtime
+		/// types.
+		pub fn new(inner :W) -> Self {
+			Self {
+				sink : FramedWrite::new(inner, PacketEncoder {
+					base_pck_wtr : BasePacketWriter::new()
+				}),
+			}
+		}
+	}
+
+	impl<W :FuturesAsyncWrite> PacketWriter<'_, Compat<W>> {
+		/// Wraps the specified futures_io `AsyncWrite` into an Ogg packet
+		/// writer.
+		///
+		/// This crate uses Tokio internally, so a wrapper that may have
+		/// some performance cost will be used. Therefore, this constructor
+		/// is to be used only when dealing with `AsyncWrite` implementations
+		/// from other runtimes, and implementing a Tokio `AsyncWrite`
+		/// compatibility layer oneself is not desired.
+		pub fn new_compat(inner :W) -> Self {
+			Self::new(inner.compat_write())
+		}
+	}
+
+	impl<'writer, 'packet :'writer, W :TokioAsyncWrite> Sink<Packet<'packet>> for PacketWriter<'writer, W> {
+		type Error = io::Error;
+
+		fn poll_ready(self :Pin<&mut Self>, cx :&mut Context<'_>) -> Poll<Result<()>> {
+			self.project().sink.poll_ready(cx)
+		}
+
+		fn start_send(self :Pin<&mut Self>, item :Packet<'packet>) -> Result<()> {
+			self.project().sink.start_send(item)
+		}
+
+		fn poll_flush(self :Pin<&mut Self>, cx :&mut Context<'_>) -> Poll<Result<()>> {
+			self.project().sink.poll_flush(cx)
+		}
+
+		fn poll_close(self :Pin<&mut Self>, cx :&mut Context<'_>) -> Poll<Result<()>> {
+			self.project().sink.poll_close(cx)
+		}
+	}
+}
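
For reviewers, here is a minimal usage sketch of the blocking API after this change. It is not part of the patch itself: it assumes the usual `ogg::writing` re-exports, and the in-memory buffer and serial number are chosen purely for illustration. The point is that the user-visible blocking API keeps its previous shape, with I/O errors still surfaced as `io::Result`, even though page assembly now goes through the infallible `BasePacketWriter` callback internally.

use ogg::writing::{PacketWriteEndInfo, PacketWriter};

fn main() -> std::io::Result<()> {
	// Any `std::io::Write` works; an in-memory buffer keeps the sketch short.
	let mut writer = PacketWriter::new(Vec::new());

	// Same call shape as before the patch: data, stream serial, end info,
	// granule position. Errors from the underlying writer bubble up here.
	writer.write_packet(&b"example packet"[..], 0x1234_5678,
		PacketWriteEndInfo::EndStream, 0)?;

	let ogg_bytes = writer.into_inner();
	assert!(!ogg_bytes.is_empty());
	Ok(())
}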
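
And a rough sketch of the new async API in use, also not part of the patch: the `ogg::writing::async_api` path, the output file name and the Tokio file setup are assumptions made for illustration, and `send` comes from `futures_util::SinkExt`, which is already a dev-dependency.

use std::borrow::Cow;

use futures_util::SinkExt;
use ogg::writing::PacketWriteEndInfo;
use ogg::writing::async_api::{Packet, PacketWriter};

#[tokio::main]
async fn main() -> std::io::Result<()> {
	// Any Tokio `AsyncWrite` works; a file is only one possible sink.
	let file = tokio::fs::File::create("out.ogg").await?;
	let mut writer = PacketWriter::new(file);

	// `send` drives poll_ready/start_send/poll_flush on the `Sink` impl.
	// A packet only produces page bytes once it ends its page or stream,
	// which is why this one forcibly ends the logical bitstream.
	writer.send(Packet {
		data: Cow::Borrowed(&b"example packet"[..]),
		serial: 0x1234_5678,
		inf: PacketWriteEndInfo::EndStream,
		absgp: 0,
	}).await?;

	Ok(())
}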