From 6b31af70a8873418bfce5319de7dd93a68f0db99 Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 09:37:36 -0700 Subject: [PATCH 1/7] WIP --- Cargo.lock | 12 +++-- moq-relay/src/connection.rs | 4 +- moq-transport/Cargo.toml | 4 +- moq-transport/src/coding/decode.rs | 6 +-- moq-transport/src/coding/encode.rs | 2 +- moq-transport/src/coding/varint.rs | 7 --- moq-transport/src/session/announce.rs | 10 ++-- moq-transport/src/session/announced.rs | 26 +++++----- moq-transport/src/session/error.rs | 17 ++----- moq-transport/src/session/mod.rs | 67 +++++++++++-------------- moq-transport/src/session/publisher.rs | 44 ++++++++-------- moq-transport/src/session/subscribe.rs | 36 ++++++------- moq-transport/src/session/subscribed.rs | 39 +++++++------- moq-transport/src/session/subscriber.rs | 42 ++++++++-------- 14 files changed, 146 insertions(+), 170 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e8361bbe..7b3e16c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -944,7 +944,6 @@ dependencies = [ "log", "mp4", "paste", - "quinn", "rfc6381-codec", "rustls", "rustls-native-certs", @@ -955,7 +954,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webtransport-quinn", + "webtransport-generic 0.5.0", ] [[package]] @@ -2181,6 +2180,13 @@ version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" +[[package]] +name = "webtransport-generic" +version = "0.5.0" +dependencies = [ + "tokio", +] + [[package]] name = "webtransport-generic" version = "0.5.0" @@ -2218,7 +2224,7 @@ dependencies = [ "thiserror", "tokio", "url", - "webtransport-generic", + "webtransport-generic 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", "webtransport-proto", ] diff --git a/moq-relay/src/connection.rs b/moq-relay/src/connection.rs index ce99c322..708c25e8 100644 --- a/moq-relay/src/connection.rs +++ b/moq-relay/src/connection.rs @@ -58,7 +58,7 @@ impl Connection { Ok(()) } - async fn serve_publisher(mut publisher: Publisher, origin: Origin) -> Result<(), SessionError> { + async fn serve_publisher(mut publisher: Publisher, origin: Origin) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { @@ -77,7 +77,7 @@ impl Connection { } } - async fn serve_subscriber(mut subscriber: Subscriber, origin: Origin) -> Result<(), SessionError> { + async fn serve_subscriber(mut subscriber: Subscriber, origin: Origin) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { diff --git a/moq-transport/Cargo.toml b/moq-transport/Cargo.toml index d41af021..01cf9ec6 100644 --- a/moq-transport/Cargo.toml +++ b/moq-transport/Cargo.toml @@ -20,9 +20,7 @@ thiserror = "1" tokio = { version = "1", features = ["macros", "io-util", "sync"] } log = "0.4" -quinn = "0.10" -webtransport-quinn = "0.7" -#webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } +webtransport-generic = { path = "../../webtransport-rs/webtransport-generic" } async-trait = "0.1" paste = "1" diff --git a/moq-transport/src/coding/decode.rs b/moq-transport/src/coding/decode.rs index 01d98a6a..a4f5894e 100644 --- a/moq-transport/src/coding/decode.rs +++ b/moq-transport/src/coding/decode.rs @@ -6,9 +6,9 @@ use thiserror::Error; // I'm too lazy to add these trait bounds to every message type. // TODO Use trait aliases when they're stable, or add these bounds to every method. pub trait AsyncRead: tokio::io::AsyncRead + Unpin + Send {} -impl AsyncRead for webtransport_quinn::RecvStream {} -impl AsyncRead for tokio::io::Take<&mut T> where T: AsyncRead {} -impl + Unpin + Send> AsyncRead for io::Cursor {} +impl AsyncRead for T {} +//impl AsyncRead for tokio::io::Take<&mut T> where T: AsyncRead {} +//impl + Unpin + Send> AsyncRead for io::Cursor {} #[async_trait::async_trait] pub trait Decode: Sized { diff --git a/moq-transport/src/coding/encode.rs b/moq-transport/src/coding/encode.rs index 9510e05b..6455487d 100644 --- a/moq-transport/src/coding/encode.rs +++ b/moq-transport/src/coding/encode.rs @@ -5,7 +5,7 @@ use thiserror::Error; // I'm too lazy to add these trait bounds to every message type. // TODO Use trait aliases when they're stable, or add these bounds to every method. pub trait AsyncWrite: tokio::io::AsyncWrite + Unpin + Send {} -impl AsyncWrite for webtransport_quinn::SendStream {} +impl AsyncWrite for dyn webtransport_generic::SendStream {} impl AsyncWrite for Vec {} #[async_trait::async_trait] diff --git a/moq-transport/src/coding/varint.rs b/moq-transport/src/coding/varint.rs index 4f72f096..e0e927d6 100644 --- a/moq-transport/src/coding/varint.rs +++ b/moq-transport/src/coding/varint.rs @@ -232,13 +232,6 @@ impl Encode for VarInt { } } -// This is a fork of quinn::VarInt. -impl From for VarInt { - fn from(v: quinn::VarInt) -> Self { - Self(v.into_inner()) - } -} - #[async_trait::async_trait] impl Encode for u64 { /// Encode a varint to the given writer. diff --git a/moq-transport/src/session/announce.rs b/moq-transport/src/session/announce.rs index dba276f7..93a9f07d 100644 --- a/moq-transport/src/session/announce.rs +++ b/moq-transport/src/session/announce.rs @@ -2,14 +2,14 @@ use crate::{message, serve::ServeError, util::Watch}; use super::Publisher; -pub struct Announce { - session: Publisher, +pub struct Announce { + session: Publisher, msg: message::Announce, state: Watch, } -impl Announce { - pub(super) fn new(session: Publisher, msg: message::Announce) -> (Announce, AnnounceRecv) { +impl Announce { + pub(super) fn new(session: Publisher, msg: message::Announce) -> (Announce, AnnounceRecv) { let state = Watch::default(); let recv = AnnounceRecv { state: state.clone() }; @@ -49,7 +49,7 @@ impl Announce { } } -impl Drop for Announce { +impl Drop for Announce { fn drop(&mut self) { self.close().ok(); self.session.drop_announce(&self.msg.namespace); diff --git a/moq-transport/src/session/announced.rs b/moq-transport/src/session/announced.rs index 9ffacfee..40a38baa 100644 --- a/moq-transport/src/session/announced.rs +++ b/moq-transport/src/session/announced.rs @@ -2,14 +2,14 @@ use crate::{message, serve::ServeError, util::Watch}; use super::Subscriber; -pub struct Announced { - session: Subscriber, +pub struct Announced { + session: Subscriber, namespace: String, - state: Watch, + state: Watch>, } -impl Announced { - pub(super) fn new(session: Subscriber, namespace: String) -> (Announced, AnnouncedRecv) { +impl Announced { + pub(super) fn new(session: Subscriber, namespace: String) -> (Announced, AnnouncedRecv) { let state = Watch::new(State::new(session.clone(), namespace.clone())); let recv = AnnouncedRecv { state: state.clone() }; @@ -48,32 +48,32 @@ impl Announced { } } -impl Drop for Announced { +impl Drop for Announced { fn drop(&mut self) { self.close(ServeError::Done).ok(); self.session.drop_announce(&self.namespace); } } -pub(super) struct AnnouncedRecv { - state: Watch, +pub(super) struct AnnouncedRecv { + state: Watch>, } -impl AnnouncedRecv { +impl AnnouncedRecv { pub fn recv_unannounce(&mut self) -> Result<(), ServeError> { self.state.lock_mut().close(ServeError::Done) } } -struct State { +struct State { namespace: String, - session: Subscriber, + session: Subscriber, ok: bool, closed: Result<(), ServeError>, } -impl State { - fn new(session: Subscriber, namespace: String) -> Self { +impl State { + fn new(session: Subscriber, namespace: String) -> Self { Self { session, namespace, diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 1f2320a3..34a83b90 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -1,9 +1,10 @@ use crate::{coding, serve, setup}; #[derive(thiserror::Error, Debug, Clone)] -pub enum SessionError { +pub enum SessionError { + // We can't use #[from] here because it would conflict with #[error("webtransport error: {0}")] - WebTransport(#[from] webtransport_quinn::SessionError), + WebTransport(S::Error), #[error("encode error: {0}")] Encode(#[from] coding::EncodeError), @@ -19,14 +20,6 @@ pub enum SessionError { #[error("incompatible roles: client={0:?} server={1:?}")] RoleIncompatible(setup::Role, setup::Role), - /// An error occured while reading from the QUIC stream. - #[error("failed to read from stream: {0}")] - Read(#[from] webtransport_quinn::ReadError), - - /// An error occured while writing to the QUIC stream. - #[error("failed to write to stream: {0}")] - Write(#[from] webtransport_quinn::WriteError), - /// The role negiotiated in the handshake was violated. For example, a publisher sent a SUBSCRIBE, or a subscriber sent an OBJECT. #[error("role violation")] RoleViolation, @@ -49,14 +42,12 @@ pub enum SessionError { WrongSize, } -impl SessionError { +impl SessionError { /// An integer code that is sent over the wire. pub fn code(&self) -> u64 { match self { Self::RoleIncompatible(..) => 406, Self::RoleViolation => 405, - Self::Write(_) => 501, - Self::Read(_) => 400, Self::WebTransport(_) => 503, Self::Version(..) => 406, Self::Decode(_) => 400, diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index 6c3599c8..a7a89c63 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -16,28 +16,25 @@ pub use subscriber::*; use futures::FutureExt; use futures::{stream::FuturesUnordered, StreamExt}; -use webtransport_quinn::{RecvStream, SendStream}; use crate::{message, setup, util::Queue}; -type Messages = Queue; +pub struct Session { + webtransport: S, + control: (S::SendStream, S::RecvStream), -pub struct Session { - webtransport: webtransport_quinn::Session, - control: (SendStream, RecvStream), - - publisher: Option, - subscriber: Option, - outgoing: Messages, + publisher: Option>, + subscriber: Option>, + outgoing: Queue>, } -impl Session { +impl Session { fn new( - webtransport: webtransport_quinn::Session, - control: (SendStream, RecvStream), + webtransport: S, + control: (S::SendStream, S::RecvStream), role: setup::Role, - ) -> (Self, Option, Option) { - let outgoing = Messages::::default(); + ) -> (Self, Option>, Option>) { + let outgoing = Default::default(); let publisher = role .is_publisher() @@ -56,15 +53,15 @@ impl Session { } pub async fn connect( - session: webtransport_quinn::Session, - ) -> Result<(Session, Option, Option), SessionError> { + session: S, + ) -> Result<(Session, Option>, Option>), SessionError> { Self::connect_role(session, setup::Role::Both).await } pub async fn connect_role( - session: webtransport_quinn::Session, + session: S, role: setup::Role, - ) -> Result<(Session, Option, Option), SessionError> { + ) -> Result<(Session, Option>, Option>), SessionError> { let mut control = session.open_bi().await?; let versions: setup::Versions = [setup::Version::DRAFT_03].into(); @@ -101,15 +98,15 @@ impl Session { } pub async fn accept( - session: webtransport_quinn::Session, - ) -> Result<(Session, Option, Option), SessionError> { + session: S, + ) -> Result<(Session, Option>, Option>), SessionError> { Self::accept_role(session, setup::Role::Both).await } pub async fn accept_role( - session: webtransport_quinn::Session, + session: S, role: setup::Role, - ) -> Result<(Session, Option, Option), SessionError> { + ) -> Result<(Session, Option>, Option>), SessionError> { let mut control = session.accept_bi().await?; let client = setup::Client::decode(&mut control.1).await?; @@ -151,7 +148,7 @@ impl Session { Ok(Session::new(session, control, role)) } - pub async fn run(self) -> Result<(), SessionError> { + pub async fn run(self) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); tasks.push(Self::run_send(self.outgoing, self.control.0).boxed()); tasks.push(Self::run_recv(self.control.1, self.publisher, self.subscriber.clone()).boxed()); @@ -166,9 +163,9 @@ impl Session { } async fn run_send( - outgoing: Queue, - mut stream: SendStream, - ) -> Result<(), SessionError> { + outgoing: Queue>, + mut stream: S::SendStream, + ) -> Result<(), SessionError> { loop { let msg = outgoing.pop().await?; msg.encode(&mut stream).await?; @@ -176,10 +173,10 @@ impl Session { } async fn run_recv( - mut stream: RecvStream, - mut publisher: Option, - mut subscriber: Option, - ) -> Result<(), SessionError> { + mut stream: S::RecvStream, + mut publisher: Option>, + mut subscriber: Option>, + ) -> Result<(), SessionError> { loop { let msg = message::Message::decode(&mut stream).await?; @@ -210,10 +207,7 @@ impl Session { } } - async fn run_streams( - webtransport: webtransport_quinn::Session, - subscriber: Subscriber, - ) -> Result<(), SessionError> { + async fn run_streams(webtransport: S, subscriber: Subscriber) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { @@ -227,10 +221,7 @@ impl Session { } } - async fn run_datagrams( - webtransport: webtransport_quinn::Session, - mut subscriber: Subscriber, - ) -> Result<(), SessionError> { + async fn run_datagrams(webtransport: S, mut subscriber: Subscriber) -> Result<(), SessionError> { loop { let datagram = webtransport.read_datagram().await?; subscriber.recv_datagram(datagram).await?; diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 06ff8f60..154f55c3 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -16,18 +16,18 @@ use super::{Announce, AnnounceRecv, Session, SessionError, Subscribed, Subscribe // TODO remove Clone. #[derive(Clone)] -pub struct Publisher { - webtransport: webtransport_quinn::Session, +pub struct Publisher { + webtransport: S, announces: Arc>>, - subscribed: Arc>>, - subscribed_queue: Queue, + subscribed: Arc>>>, + subscribed_queue: Queue, SessionError>, - outgoing: Queue, + outgoing: Queue>, } -impl Publisher { - pub(crate) fn new(webtransport: webtransport_quinn::Session, outgoing: Queue) -> Self { +impl Publisher { + pub(crate) fn new(webtransport: S, outgoing: Queue>) -> Self { Self { webtransport, announces: Default::default(), @@ -37,17 +37,17 @@ impl Publisher { } } - pub async fn accept(session: webtransport_quinn::Session) -> Result<(Session, Self), SessionError> { + pub async fn accept(session: S) -> Result<(Session, Publisher), SessionError> { let (session, publisher, _) = Session::accept_role(session, setup::Role::Publisher).await?; Ok((session, publisher.unwrap())) } - pub async fn connect(session: webtransport_quinn::Session) -> Result<(Session, Self), SessionError> { + pub async fn connect(session: S) -> Result<(Session, Publisher), SessionError> { let (session, publisher, _) = Session::connect_role(session, setup::Role::Publisher).await?; Ok((session, publisher.unwrap())) } - pub fn announce(&mut self, namespace: &str) -> Result { + pub fn announce(&mut self, namespace: &str) -> Result, SessionError> { let mut announces = self.announces.lock().unwrap(); // Insert the abort handle into the lookup table. @@ -68,13 +68,13 @@ impl Publisher { Ok(announce) } - pub async fn subscribed(&mut self) -> Result { + pub async fn subscribed(&mut self) -> Result, SessionError> { self.subscribed_queue.pop().await } // Helper to announce and serve any matching subscribers. // TODO this currently takes over the connection; definitely remove Clone - pub async fn serve(mut self, broadcast: serve::BroadcastSubscriber) -> Result<(), SessionError> { + pub async fn serve(mut self, broadcast: serve::BroadcastSubscriber) -> Result<(), SessionError> { log::info!("serving broadcast: {}", broadcast.namespace); let announce = self.announce(&broadcast.namespace)?; @@ -107,7 +107,7 @@ impl Publisher { fn serve_track( &self, broadcast: &serve::BroadcastSubscriber, - subscribe: &Subscribed, + subscribe: &Subscribed, ) -> Result { if subscribe.namespace() != broadcast.namespace { return Err(ServeError::NotFound); @@ -116,7 +116,7 @@ impl Publisher { broadcast.get_track(subscribe.name())?.ok_or(ServeError::NotFound) } - pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { log::debug!("received message: {:?}", msg); match msg { @@ -128,13 +128,13 @@ impl Publisher { } } - fn recv_announce_ok(&mut self, _msg: message::AnnounceOk) -> Result<(), SessionError> { + fn recv_announce_ok(&mut self, _msg: message::AnnounceOk) -> Result<(), SessionError> { // Who cares // TODO make AnnouncePending so we're forced to care Ok(()) } - fn recv_announce_error(&mut self, msg: message::AnnounceError) -> Result<(), SessionError> { + fn recv_announce_error(&mut self, msg: message::AnnounceError) -> Result<(), SessionError> { if let Some(announce) = self.announces.lock().unwrap().get_mut(&msg.namespace) { announce.recv_error(ServeError::Closed(msg.code)).ok(); } @@ -142,11 +142,11 @@ impl Publisher { Ok(()) } - fn recv_announce_cancel(&mut self, _msg: message::AnnounceCancel) -> Result<(), SessionError> { + fn recv_announce_cancel(&mut self, _msg: message::AnnounceCancel) -> Result<(), SessionError> { unimplemented!("recv_announce_cancel") } - fn recv_subscribe(&mut self, msg: message::Subscribe) -> Result<(), SessionError> { + fn recv_subscribe(&mut self, msg: message::Subscribe) -> Result<(), SessionError> { let mut subscribes = self.subscribed.lock().unwrap(); // Insert the abort handle into the lookup table. @@ -160,7 +160,7 @@ impl Publisher { self.subscribed_queue.push(subscribe) } - fn recv_unsubscribe(&mut self, msg: message::Unsubscribe) -> Result<(), SessionError> { + fn recv_unsubscribe(&mut self, msg: message::Unsubscribe) -> Result<(), SessionError> { if let Some(subscribed) = self.subscribed.lock().unwrap().get_mut(&msg.id) { subscribed.recv_unsubscribe().ok(); } @@ -168,7 +168,7 @@ impl Publisher { Ok(()) } - pub fn send_message>(&self, msg: T) -> Result<(), SessionError> { + pub fn send_message>(&self, msg: T) -> Result<(), SessionError> { let msg = msg.into(); log::debug!("sending message: {:?}", msg); self.outgoing.push(msg.into()) @@ -182,11 +182,11 @@ impl Publisher { self.announces.lock().unwrap().remove(namespace); } - pub(super) fn webtransport(&mut self) -> &mut webtransport_quinn::Session { + pub(super) fn webtransport(&mut self) -> &mut S { &mut self.webtransport } - pub fn close(self, err: SessionError) { + pub fn close(self, err: SessionError) { self.outgoing.close(err.clone()).ok(); self.subscribed_queue.close(err).ok(); } diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index b0471017..f27bf8c9 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -9,15 +9,15 @@ use crate::{ use super::{SessionError, Subscriber}; -pub struct Subscribe { - session: Subscriber, +pub struct Subscribe { + session: Subscriber, id: u64, track: serve::TrackSubscriber, state: Watch, } -impl Subscribe { - pub(super) fn new(session: Subscriber, msg: message::Subscribe) -> (SubscribeRecv, Subscribe) { +impl Subscribe { + pub(super) fn new(session: Subscriber, msg: message::Subscribe) -> (Subscribe, SubscribeRecv) { let state = Watch::new(State::default()); let (publisher, subscriber) = serve::Track { @@ -37,7 +37,7 @@ impl Subscribe { let publisher = SubscribeRecv::new(state, publisher); - (publisher, subscriber) + (subscriber, publisher) } // Waits until an OK message is received. @@ -78,7 +78,7 @@ impl Subscribe { } } -impl Drop for Subscribe { +impl Drop for Subscribe { fn drop(&mut self) { let msg = message::Unsubscribe { id: self.id }; self.session.send_message(msg).ok(); @@ -88,12 +88,12 @@ impl Drop for Subscribe { } #[derive(Clone)] -pub(super) struct SubscribeRecv { +pub(super) struct SubscribeRecv { publisher: Arc>, state: Watch, } -impl SubscribeRecv { +impl SubscribeRecv { fn new(state: Watch, publisher: serve::TrackPublisher) -> Self { Self { publisher: Arc::new(Mutex::new(publisher)), @@ -117,11 +117,7 @@ impl SubscribeRecv { Ok(()) } - pub async fn recv_stream( - &mut self, - header: data::Header, - stream: webtransport_quinn::RecvStream, - ) -> Result<(), SessionError> { + pub async fn recv_stream(&mut self, header: data::Header, stream: S::RecvStream) -> Result<(), SessionError> { match header { data::Header::Track(track) => self.recv_track(track, stream).await, data::Header::Group(group) => self.recv_group(group, stream).await, @@ -132,8 +128,8 @@ impl SubscribeRecv { async fn recv_track( &mut self, header: data::TrackHeader, - mut stream: webtransport_quinn::RecvStream, - ) -> Result<(), SessionError> { + mut stream: S::RecvStream, + ) -> Result<(), SessionError> { log::trace!("received track: {:?}", header); let mut track = self.publisher.lock().unwrap().create_stream(header.send_order)?; @@ -165,8 +161,8 @@ impl SubscribeRecv { async fn recv_group( &mut self, header: data::GroupHeader, - mut stream: webtransport_quinn::RecvStream, - ) -> Result<(), SessionError> { + mut stream: S::RecvStream, + ) -> Result<(), SessionError> { log::trace!("received group: {:?}", header); let mut group = self.publisher.lock().unwrap().create_group(serve::Group { @@ -193,8 +189,8 @@ impl SubscribeRecv { async fn recv_object( &mut self, header: data::ObjectHeader, - mut stream: webtransport_quinn::RecvStream, - ) -> Result<(), SessionError> { + mut stream: S::RecvStream, + ) -> Result<(), SessionError> { log::trace!("received object: {:?}", header); // TODO avoid buffering the entire object to learn the size. @@ -220,7 +216,7 @@ impl SubscribeRecv { Ok(()) } - pub fn recv_datagram(&self, datagram: data::Datagram) -> Result<(), SessionError> { + pub fn recv_datagram(&self, datagram: data::Datagram) -> Result<(), SessionError> { log::trace!("received datagram: {:?}", datagram); self.publisher.lock().unwrap().write_datagram(serve::Datagram { diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index b716a227..48af20c5 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -1,5 +1,6 @@ use futures::stream::FuturesUnordered; use futures::{FutureExt, StreamExt}; +use tokio::io::AsyncWriteExt; use crate::serve::ServeError; use crate::util::{Watch, WatchWeak}; @@ -8,14 +9,14 @@ use crate::{data, message, serve}; use super::{Publisher, SessionError}; #[derive(Clone)] -pub struct Subscribed { - session: Publisher, - state: Watch, +pub struct Subscribed { + session: Publisher, + state: Watch>, msg: message::Subscribe, } -impl Subscribed { - pub(super) fn new(session: Publisher, msg: message::Subscribe) -> (Subscribed, SubscribedRecv) { +impl Subscribed { + pub(super) fn new(session: Publisher, msg: message::Subscribe) -> (Subscribed, SubscribedRecv) { let state = Watch::new(State::new(session.clone(), msg.id)); let recv = SubscribedRecv { state: state.downgrade(), @@ -33,7 +34,7 @@ impl Subscribed { self.msg.track_name.as_str() } - pub async fn serve(mut self, mut track: serve::TrackSubscriber) -> Result<(), SessionError> { + pub async fn serve(mut self, mut track: serve::TrackSubscriber) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); self.state.lock_mut().ok(track.latest())?; @@ -60,7 +61,7 @@ impl Subscribed { } } - async fn serve_track(mut self, mut track: serve::StreamSubscriber) -> Result<(), SessionError> { + async fn serve_track(mut self, mut track: serve::StreamSubscriber) -> Result<(), SessionError> { let mut stream = self.session.webtransport().open_uni().await?; let header: data::Header = data::TrackHeader { @@ -98,7 +99,7 @@ impl Subscribed { Ok(()) } - pub async fn serve_group(mut self, mut group: serve::GroupSubscriber) -> Result<(), SessionError> { + pub async fn serve_group(mut self, mut group: serve::GroupSubscriber) -> Result<(), SessionError> { let mut stream = self.session.webtransport().open_uni().await?; let header: data::Header = data::GroupHeader { @@ -136,7 +137,7 @@ impl Subscribed { Ok(()) } - pub async fn serve_object(mut self, mut object: serve::ObjectSubscriber) -> Result<(), SessionError> { + pub async fn serve_object(mut self, mut object: serve::ObjectSubscriber) -> Result<(), SessionError> { let mut stream = self.session.webtransport().open_uni().await?; let header: data::Header = data::ObjectHeader { @@ -163,7 +164,7 @@ impl Subscribed { Ok(()) } - pub async fn serve_datagram(&mut self, datagram: serve::Datagram) -> Result<(), SessionError> { + pub async fn serve_datagram(&mut self, datagram: serve::Datagram) -> Result<(), SessionError> { let datagram = data::Datagram { subscribe_id: self.msg.id, track_alias: self.msg.track_alias, @@ -205,11 +206,11 @@ impl Subscribed { } } -pub(super) struct SubscribedRecv { - state: WatchWeak, +pub(super) struct SubscribedRecv { + state: WatchWeak>, } -impl SubscribedRecv { +impl SubscribedRecv { pub fn recv_unsubscribe(&mut self) -> Result<(), ServeError> { if let Some(state) = self.state.upgrade() { state.lock_mut().close(ServeError::Done)?; @@ -218,8 +219,8 @@ impl SubscribedRecv { } } -struct State { - session: Publisher, +struct State { + session: Publisher, id: u64, ok: bool, @@ -227,8 +228,8 @@ struct State { closed: Result<(), ServeError>, } -impl State { - fn new(session: Publisher, id: u64) -> Self { +impl State { + fn new(session: Publisher, id: u64) -> Self { Self { session, id, @@ -239,7 +240,7 @@ impl State { } } -impl State { +impl State { fn ok(&mut self, latest: Option<(u64, u64)>) -> Result<(), ServeError> { self.ok = true; self.max = latest; @@ -295,7 +296,7 @@ impl State { } } -impl Drop for State { +impl Drop for State { fn drop(&mut self) { self.close(ServeError::Done).ok(); self.session.drop_subscribe(self.id); diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index c0f8aef8..bc370100 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -10,18 +10,18 @@ use super::{Announced, AnnouncedRecv, Session, SessionError, Subscribe, Subscrib // TODO remove Clone. #[derive(Clone)] -pub struct Subscriber { - announced: Arc>>, - announced_queue: Queue, +pub struct Subscriber { + announced: Arc>>>, + announced_queue: Queue, SessionError>, - subscribes: Arc>>, + subscribes: Arc>>>, subscribe_next: Arc, - outgoing: Queue, + outgoing: Queue>, } -impl Subscriber { - pub(super) fn new(outgoing: Queue) -> Self { +impl Subscriber { + pub(super) fn new(outgoing: Queue>) -> Self { Self { announced: Default::default(), announced_queue: Default::default(), @@ -31,17 +31,17 @@ impl Subscriber { } } - pub async fn accept(session: webtransport_quinn::Session) -> Result<(Session, Self), SessionError> { + pub async fn accept(session: S) -> Result<(Session, Self), SessionError> { let (session, _, subscriber) = Session::accept_role(session, setup::Role::Subscriber).await?; Ok((session, subscriber.unwrap())) } - pub async fn connect(session: webtransport_quinn::Session) -> Result<(Session, Self), SessionError> { + pub async fn connect(session: S) -> Result<(Session, Self), SessionError> { let (session, _, subscriber) = Session::connect_role(session, setup::Role::Subscriber).await?; Ok((session, subscriber.unwrap())) } - pub async fn announced(&mut self) -> Result { + pub async fn announced(&mut self) -> Result, SessionError> { self.announced_queue.pop().await } @@ -50,7 +50,7 @@ impl Subscriber { namespace: &str, name: &str, options: SubscribeOptions, - ) -> Result { + ) -> Result, SessionError> { let id = self.subscribe_next.fetch_add(1, atomic::Ordering::Relaxed); let msg = message::Subscribe { @@ -71,13 +71,13 @@ impl Subscriber { Ok(subscribe) } - pub(super) fn send_message>(&mut self, msg: M) -> Result<(), SessionError> { + pub(super) fn send_message>(&mut self, msg: M) -> Result<(), SessionError> { let msg = msg.into(); log::debug!("sending message: {:?}", msg); self.outgoing.push(msg.into()) } - pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { + pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { log::debug!("received message: {:?}", msg); match msg { @@ -89,7 +89,7 @@ impl Subscriber { } } - fn recv_announce(&mut self, msg: message::Announce) -> Result<(), SessionError> { + fn recv_announce(&mut self, msg: message::Announce) -> Result<(), SessionError> { let mut announces = self.announced.lock().unwrap(); let entry = match announces.entry(msg.namespace.clone()) { @@ -104,7 +104,7 @@ impl Subscriber { Ok(()) } - fn recv_unannounce(&mut self, msg: message::Unannounce) -> Result<(), SessionError> { + fn recv_unannounce(&mut self, msg: message::Unannounce) -> Result<(), SessionError> { if let Some(announce) = self.announced.lock().unwrap().get_mut(&msg.namespace) { announce.recv_unannounce().ok(); } @@ -112,7 +112,7 @@ impl Subscriber { Ok(()) } - fn recv_subscribe_ok(&mut self, msg: message::SubscribeOk) -> Result<(), SessionError> { + fn recv_subscribe_ok(&mut self, msg: message::SubscribeOk) -> Result<(), SessionError> { if let Some(sub) = self.subscribes.lock().unwrap().get_mut(&msg.id) { sub.recv_ok(msg).ok(); } @@ -120,7 +120,7 @@ impl Subscriber { Ok(()) } - fn recv_subscribe_error(&mut self, msg: message::SubscribeError) -> Result<(), SessionError> { + fn recv_subscribe_error(&mut self, msg: message::SubscribeError) -> Result<(), SessionError> { if let Some(subscriber) = self.subscribes.lock().unwrap().get_mut(&msg.id) { subscriber.recv_error(msg.code).ok(); } @@ -128,7 +128,7 @@ impl Subscriber { Ok(()) } - fn recv_subscribe_done(&mut self, msg: message::SubscribeDone) -> Result<(), SessionError> { + fn recv_subscribe_done(&mut self, msg: message::SubscribeDone) -> Result<(), SessionError> { if let Some(subscriber) = self.subscribes.lock().unwrap().get_mut(&msg.id) { subscriber.recv_done(msg.code).ok(); } @@ -144,7 +144,7 @@ impl Subscriber { self.announced.lock().unwrap().remove(namespace); } - pub(super) async fn recv_stream(self, mut stream: webtransport_quinn::RecvStream) -> Result<(), SessionError> { + pub(super) async fn recv_stream(self, mut stream: S::RecvStream) -> Result<(), SessionError> { let header = data::Header::decode(&mut stream).await?; let id = header.subscribe_id(); @@ -158,7 +158,7 @@ impl Subscriber { } // TODO should not be async - pub async fn recv_datagram(&mut self, datagram: bytes::Bytes) -> Result<(), SessionError> { + pub async fn recv_datagram(&mut self, datagram: bytes::Bytes) -> Result<(), SessionError> { let mut cursor = io::Cursor::new(datagram); let datagram = data::Datagram::decode(&mut cursor).await?; @@ -171,7 +171,7 @@ impl Subscriber { Ok(()) } - pub fn close(self, err: SessionError) { + pub fn close(self, err: SessionError) { self.outgoing.close(err.clone()).ok(); self.announced_queue.close(err).ok(); } From 7a95e88672a3266fc3efbf8b461c02e6989fd767 Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 13:16:15 -0700 Subject: [PATCH 2/7] Generic works..? --- Cargo.lock | 50 +++++++++-- moq-clock/Cargo.toml | 3 +- moq-pub/Cargo.toml | 4 +- moq-relay/Cargo.toml | 2 + moq-relay/src/connection.rs | 72 +++++++++++----- moq-relay/src/quic.rs | 4 +- moq-transport/Cargo.toml | 1 - moq-transport/src/coding/decode.rs | 29 +++---- moq-transport/src/coding/encode.rs | 28 +++--- moq-transport/src/coding/mod.rs | 4 + moq-transport/src/coding/params.rs | 44 ++++------ moq-transport/src/coding/reader.rs | 64 ++++++++++++++ moq-transport/src/coding/string.rs | 34 ++++---- moq-transport/src/coding/varint.rs | 73 +++++++--------- moq-transport/src/coding/writer.rs | 36 ++++++++ moq-transport/src/data/datagram.rs | 47 +++++----- moq-transport/src/data/group.rs | 49 ++++++----- moq-transport/src/data/header.rs | 20 +++-- moq-transport/src/data/object.rs | 29 ++++--- moq-transport/src/data/track.rs | 49 ++++++----- moq-transport/src/lib.rs | 3 +- moq-transport/src/message/announce.rs | 18 ++-- moq-transport/src/message/announce_cancel.rs | 22 ++--- moq-transport/src/message/announce_error.rs | 22 ++--- moq-transport/src/message/announce_ok.rs | 14 +-- moq-transport/src/message/go_away.rs | 14 +-- moq-transport/src/message/mod.rs | 20 +++-- moq-transport/src/message/subscribe.rs | 79 +++++++++-------- moq-transport/src/message/subscribe_done.rs | 44 ++++++---- moq-transport/src/message/subscribe_error.rs | 25 +++--- moq-transport/src/message/subscribe_ok.rs | 40 +++++---- moq-transport/src/message/unannounce.rs | 14 +-- moq-transport/src/message/unsubscribe.rs | 13 ++- moq-transport/src/session/error.rs | 28 +++++- moq-transport/src/session/mod.rs | 90 ++++++++++++-------- moq-transport/src/session/publisher.rs | 32 +++---- moq-transport/src/session/subscribe.rs | 71 ++++++++------- moq-transport/src/session/subscribed.rs | 60 ++++++++----- moq-transport/src/session/subscriber.rs | 53 ++++++------ moq-transport/src/setup/client.rs | 26 +++--- moq-transport/src/setup/mod.rs | 2 + moq-transport/src/setup/role.rs | 12 +-- moq-transport/src/setup/server.rs | 27 +++--- moq-transport/src/setup/version.rs | 29 +++---- 44 files changed, 828 insertions(+), 572 deletions(-) create mode 100644 moq-transport/src/coding/reader.rs create mode 100644 moq-transport/src/coding/writer.rs diff --git a/Cargo.lock b/Cargo.lock index 7b3e16c5..95b315da 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -875,7 +875,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webtransport-quinn", + "webtransport-quinn 0.7.0", ] [[package]] @@ -899,7 +899,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webtransport-quinn", + "webtransport-quinn 0.7.0", ] [[package]] @@ -916,6 +916,7 @@ dependencies = [ "log", "moq-api", "moq-transport", + "quictransport-quinn", "quinn", "ring 0.16.20", "rustls", @@ -928,7 +929,8 @@ dependencies = [ "tracing-subscriber", "url", "webpki", - "webtransport-quinn", + "webtransport-generic 0.5.0", + "webtransport-quinn 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] @@ -936,7 +938,6 @@ name = "moq-transport" version = "0.3.0" dependencies = [ "anyhow", - "async-trait", "bytes", "clap", "env_logger", @@ -1220,6 +1221,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quictransport-quinn" +version = "0.7.0" +dependencies = [ + "bytes", + "quinn", + "tokio", + "webtransport-generic 0.5.0", + "webtransport-proto 0.6.0", +] + [[package]] name = "quinn" version = "0.10.2" @@ -2184,6 +2196,7 @@ checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" name = "webtransport-generic" version = "0.5.0" dependencies = [ + "bytes", "tokio", ] @@ -2197,6 +2210,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "webtransport-proto" +version = "0.6.0" +dependencies = [ + "bytes", + "http", + "thiserror", + "url", +] + [[package]] name = "webtransport-proto" version = "0.6.0" @@ -2209,6 +2232,23 @@ dependencies = [ "url", ] +[[package]] +name = "webtransport-quinn" +version = "0.7.0" +dependencies = [ + "bytes", + "futures", + "http", + "log", + "quinn", + "quinn-proto", + "thiserror", + "tokio", + "url", + "webtransport-generic 0.5.0", + "webtransport-proto 0.6.0", +] + [[package]] name = "webtransport-quinn" version = "0.7.0" @@ -2225,7 +2265,7 @@ dependencies = [ "tokio", "url", "webtransport-generic 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "webtransport-proto", + "webtransport-proto 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", ] [[package]] diff --git a/moq-clock/Cargo.toml b/moq-clock/Cargo.toml index 483b7dbe..90dbebec 100644 --- a/moq-clock/Cargo.toml +++ b/moq-clock/Cargo.toml @@ -18,7 +18,8 @@ moq-transport = { path = "../moq-transport" } # QUIC quinn = "0.10" -webtransport-quinn = "0.7" +#webtransport-quinn = "0.7" +webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } url = "2" # Crypto diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index 648fbf63..95c90f68 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -18,8 +18,8 @@ moq-transport = { path = "../moq-transport" } # QUIC quinn = "0.10" -webtransport-quinn = "0.7" -#webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } +#webtransport-quinn = "0.7" +webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } url = "2" # Crypto diff --git a/moq-relay/Cargo.toml b/moq-relay/Cargo.toml index 869ad9f2..e405dd39 100644 --- a/moq-relay/Cargo.toml +++ b/moq-relay/Cargo.toml @@ -19,6 +19,8 @@ moq-api = { path = "../moq-api" } quinn = "0.10" webtransport-quinn = "0.7" #webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } +quictransport-quinn = { path = "../../webtransport-rs/quictransport-quinn" } +webtransport-generic = { path = "../../webtransport-rs/webtransport-generic" } url = "2" # Crypto diff --git a/moq-relay/src/connection.rs b/moq-relay/src/connection.rs index 708c25e8..a027ccfd 100644 --- a/moq-relay/src/connection.rs +++ b/moq-relay/src/connection.rs @@ -16,28 +16,54 @@ impl Connection { } pub async fn run(self, conn: quinn::Connecting) -> anyhow::Result<()> { - log::debug!("received QUIC handshake: ip={:?}", conn.remote_address()); + let handshake = conn + .handshake_data() + .await? + .downcast::()?; - // Wait for the QUIC connection to be established. - let conn = conn.await.context("failed to establish QUIC connection")?; - - log::debug!("established QUIC connection: ip={:?}", conn.remote_address(),); + let alpn = handshake.protocol.context("missing ALPN")?; - // Wait for the CONNECT request. - let request = webtransport_quinn::accept(conn) - .await - .context("failed to receive WebTransport request")?; + log::debug!( + "received QUIC handshake: ip={} alpn={} server={}", + conn.remote_address(), + alpn, + handshake.server_name + ); - // Strip any leading and trailing slashes to get the customer ID. - let path = request.url().path().trim_matches('/').to_string(); - - log::debug!("received WebTransport CONNECT: path={}", path); + // Wait for the QUIC connection to be established. + let conn = conn.await.context("failed to establish QUIC connection")?; - // Accept the CONNECT request. - let session = request - .ok() - .await - .context("failed to respond to WebTransport request")?; + log::debug!( + "established QUIC connection: id={} ip={} alpn={} server={}", + conn.stable_id(), + conn.remote_address(), + alpn, + handshake.server_name + ); + + let session = if alpn.as_slice() == webtransport_quinn::ALPN { + // Wait for the CONNECT request. + let request = webtransport_quinn::accept(conn) + .await + .context("failed to receive WebTransport request")?; + + // Accept the CONNECT request. + let session = request + .ok() + .await + .context("failed to respond to WebTransport request")?; + + let path = request.url().path().trim_matches('/').to_string(); + + log::debug!("received WebTransport CONNECT: path={}", path); + session + } else if alpn.as_slice() == moq_transport::setup::ALPN { + let session: quictransport_quinn::Session = conn.into(); + + session + } else { + anyhow::anyhow!("unsupported ALPN: alpn={:?}", alpn); + }; let (session, publisher, subscriber) = moq_transport::Session::accept(session).await?; @@ -58,7 +84,10 @@ impl Connection { Ok(()) } - async fn serve_publisher(mut publisher: Publisher, origin: Origin) -> Result<(), SessionError> { + async fn serve_publisher( + mut publisher: Publisher, + origin: Origin, + ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { @@ -77,7 +106,10 @@ impl Connection { } } - async fn serve_subscriber(mut subscriber: Subscriber, origin: Origin) -> Result<(), SessionError> { + async fn serve_subscriber( + mut subscriber: Subscriber, + origin: Origin, + ) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { diff --git a/moq-relay/src/quic.rs b/moq-relay/src/quic.rs index 0e76fb45..8727450e 100644 --- a/moq-relay/src/quic.rs +++ b/moq-relay/src/quic.rs @@ -21,8 +21,8 @@ impl Quic { pub async fn new(config: Config, tls: Tls) -> anyhow::Result { let mut client_config = tls.client.clone(); let mut server_config = tls.server.clone(); - client_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec()]; - server_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec()]; + client_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec(), moq_transport::setup::ALPN.to_vec()]; + server_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec(), moq_transport::setup::ALPN.to_vec()]; // Enable BBR congestion control // TODO validate the implementation diff --git a/moq-transport/Cargo.toml b/moq-transport/Cargo.toml index 01cf9ec6..a1d1554a 100644 --- a/moq-transport/Cargo.toml +++ b/moq-transport/Cargo.toml @@ -22,7 +22,6 @@ log = "0.4" webtransport-generic = { path = "../../webtransport-rs/webtransport-generic" } -async-trait = "0.1" paste = "1" futures = "0.3" diff --git a/moq-transport/src/coding/decode.rs b/moq-transport/src/coding/decode.rs index a4f5894e..a5f6e87e 100644 --- a/moq-transport/src/coding/decode.rs +++ b/moq-transport/src/coding/decode.rs @@ -1,28 +1,19 @@ use super::BoundsExceeded; -use std::{io, str}; - +use std::{io, string::FromUtf8Error, sync}; use thiserror::Error; -// I'm too lazy to add these trait bounds to every message type. -// TODO Use trait aliases when they're stable, or add these bounds to every method. -pub trait AsyncRead: tokio::io::AsyncRead + Unpin + Send {} -impl AsyncRead for T {} -//impl AsyncRead for tokio::io::Take<&mut T> where T: AsyncRead {} -//impl + Unpin + Send> AsyncRead for io::Cursor {} - -#[async_trait::async_trait] pub trait Decode: Sized { - async fn decode(r: &mut R) -> Result; + fn decode(buf: &mut B) -> Result; } /// A decode error. #[derive(Error, Debug, Clone)] pub enum DecodeError { - #[error("unexpected end of buffer")] - UnexpectedEnd, + #[error("fill buffer")] + More(usize), #[error("invalid string")] - InvalidString(#[from] str::Utf8Error), + InvalidString(#[from] FromUtf8Error), #[error("invalid message: {0:?}")] InvalidMessage(u64), @@ -49,6 +40,12 @@ pub enum DecodeError { #[error("invalid parameter")] InvalidParameter, - #[error("io error")] - IoError, + #[error("io error: {0}")] + Io(sync::Arc), +} + +impl From for DecodeError { + fn from(err: io::Error) -> Self { + Self::Io(sync::Arc::new(err)) + } } diff --git a/moq-transport/src/coding/encode.rs b/moq-transport/src/coding/encode.rs index 6455487d..e906c7d4 100644 --- a/moq-transport/src/coding/encode.rs +++ b/moq-transport/src/coding/encode.rs @@ -1,27 +1,29 @@ -use super::BoundsExceeded; - -use thiserror::Error; +use std::{io, sync}; -// I'm too lazy to add these trait bounds to every message type. -// TODO Use trait aliases when they're stable, or add these bounds to every method. -pub trait AsyncWrite: tokio::io::AsyncWrite + Unpin + Send {} -impl AsyncWrite for dyn webtransport_generic::SendStream {} -impl AsyncWrite for Vec {} +use super::BoundsExceeded; -#[async_trait::async_trait] pub trait Encode: Sized { - async fn encode(&self, w: &mut W) -> Result<(), EncodeError>; + fn encode(&self, w: &mut W) -> Result<(), EncodeError>; } /// An encode error. -#[derive(Error, Debug, Clone)] +#[derive(thiserror::Error, Debug, Clone)] pub enum EncodeError { + #[error("short buffer")] + More(usize), + #[error("varint too large")] BoundsExceeded(#[from] BoundsExceeded), #[error("invalid value")] InvalidValue, - #[error("i/o error")] - IoError, + #[error("i/o error: {0}")] + Io(sync::Arc), +} + +impl From for EncodeError { + fn from(err: io::Error) -> Self { + Self::Io(sync::Arc::new(err)) + } } diff --git a/moq-transport/src/coding/mod.rs b/moq-transport/src/coding/mod.rs index a3ff6f78..8753e205 100644 --- a/moq-transport/src/coding/mod.rs +++ b/moq-transport/src/coding/mod.rs @@ -1,10 +1,14 @@ mod decode; mod encode; mod params; +mod reader; mod string; mod varint; +mod writer; pub use decode::*; pub use encode::*; pub use params::*; +pub use reader::*; pub use varint::*; +pub use writer::*; diff --git a/moq-transport/src/coding/params.rs b/moq-transport/src/coding/params.rs index 0bcedbec..1768a758 100644 --- a/moq-transport/src/coding/params.rs +++ b/moq-transport/src/coding/params.rs @@ -1,53 +1,45 @@ +use std::collections::HashMap; use std::io::Cursor; -use std::{cmp::max, collections::HashMap}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -use crate::coding::{AsyncRead, AsyncWrite, Decode, Encode}; - -use crate::coding::{DecodeError, EncodeError}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError}; #[derive(Default, Debug, Clone)] pub struct Params(pub HashMap>); -#[async_trait::async_trait] impl Decode for Params { - async fn decode(mut r: &mut R) -> Result { + fn decode(mut r: &mut R) -> Result { let mut params = HashMap::new(); // I hate this encoding so much; let me encode my role and get on with my life. - let count = u64::decode(r).await?; + let count = u64::decode(r)?; for _ in 0..count { - let kind = u64::decode(r).await?; + let kind = u64::decode(r)?; if params.contains_key(&kind) { return Err(DecodeError::DupliateParameter); } - let size = u64::decode(r).await?; + let size = usize::decode(&mut r)?; // Don't allocate the entire requested size to avoid a possible attack // Instead, we allocate up to 1024 and keep appending as we read further. - let mut pr = r.take(size); - let mut buf = Vec::with_capacity(max(1024, pr.limit() as usize)); - pr.read_to_end(&mut buf).await.map_err(|_| DecodeError::IoError)?; - params.insert(kind, buf); + let mut buf = vec![0; size]; + r.copy_to_slice(&mut buf); - r = pr.into_inner(); + params.insert(kind, buf); } Ok(Params(params)) } } -#[async_trait::async_trait] impl Encode for Params { - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.0.len().encode(w).await?; + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.0.len().encode(w)?; for (kind, value) in self.0.iter() { - kind.encode(w).await?; - value.len().encode(w).await?; - w.write_all(value).await.map_err(|_| EncodeError::IoError)?; + kind.encode(w)?; + value.len().encode(w)?; + w.put_slice(value); } Ok(()) @@ -59,9 +51,9 @@ impl Params { Self::default() } - pub async fn set(&mut self, kind: u64, p: P) -> Result<(), EncodeError> { + pub fn set(&mut self, kind: u64, p: P) -> Result<(), EncodeError> { let mut value = Vec::new(); - p.encode(&mut value).await?; + p.encode(&mut value)?; self.0.insert(kind, value); Ok(()) @@ -71,10 +63,10 @@ impl Params { self.0.contains_key(&kind) } - pub async fn get(&mut self, kind: u64) -> Result, DecodeError> { + pub fn get(&mut self, kind: u64) -> Result, DecodeError> { if let Some(value) = self.0.remove(&kind) { let mut cursor = Cursor::new(value); - Ok(Some(P::decode(&mut cursor).await?)) + Ok(Some(P::decode(&mut cursor)?)) } else { Ok(None) } diff --git a/moq-transport/src/coding/reader.rs b/moq-transport/src/coding/reader.rs new file mode 100644 index 00000000..63d2df55 --- /dev/null +++ b/moq-transport/src/coding/reader.rs @@ -0,0 +1,64 @@ +use std::{cmp, io}; + +use bytes::Buf; +use tokio::io::{AsyncRead, AsyncReadExt}; + +use crate::coding::Decode; + +use super::DecodeError; + +pub struct Reader { + stream: S, + buffer: bytes::BytesMut, +} + +impl Reader { + pub fn new(stream: S) -> Self { + Self { + stream, + buffer: Default::default(), + } + } + + pub async fn decode(&mut self) -> Result { + loop { + let mut cursor = io::Cursor::new(&self.buffer); + + // Try to decode with the current buffer. + let mut remain = match T::decode(&mut cursor) { + Ok(msg) => { + self.buffer.advance(cursor.position() as usize); + return Ok(msg); + } + Err(DecodeError::More(remain)) => remain, // Try again with more data + Err(err) => return Err(err.into()), + }; + + // Append to the buffer + while remain > 0 { + remain -= self.stream.read_buf(&mut self.buffer).await?; + } + } + } + + pub async fn read(&mut self, max_size: usize) -> Result, io::Error> { + if self.buffer.is_empty() { + // TODO avoid making a copy by using Quinn's read_chunk + let size = self.stream.read_buf(&mut self.buffer).await?; + if size == 0 { + return Ok(None); + } + } + + let size = cmp::min(self.buffer.len(), max_size); + Ok(Some(self.buffer.split_to(size).freeze())) + } + + pub async fn done(&mut self) -> Result { + Ok(self.buffer.is_empty() && self.stream.read_buf(&mut self.buffer).await? == 0) + } + + pub fn into_inner(self) -> (bytes::BytesMut, S) { + (self.buffer, self.stream) + } +} diff --git a/moq-transport/src/coding/string.rs b/moq-transport/src/coding/string.rs index b24bdb56..7627bc5d 100644 --- a/moq-transport/src/coding/string.rs +++ b/moq-transport/src/coding/string.rs @@ -1,29 +1,29 @@ -use std::cmp::min; - -use crate::coding::{AsyncRead, AsyncWrite}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use super::{Decode, DecodeError, Encode, EncodeError}; -#[async_trait::async_trait] impl Encode for String { - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.len().encode(w).await?; - w.write_all(self.as_ref()).await.map_err(|_| EncodeError::IoError)?; + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.len().encode(w)?; + if w.remaining_mut() < self.len() { + return Err(EncodeError::More(self.len())); + } + + w.put(self.as_ref()); Ok(()) } } -#[async_trait::async_trait] impl Decode for String { /// Decode a string with a varint length prefix. - async fn decode(r: &mut R) -> Result { - let size = usize::decode(r).await?; - let mut str = String::with_capacity(min(1024, size)); - r.take(size as u64) - .read_to_string(&mut str) - .await - .map_err(|_| DecodeError::IoError)?; + fn decode(r: &mut R) -> Result { + let size = usize::decode(r)?; + if r.remaining() < size { + return Err(DecodeError::More(size)); + } + + let mut buf = vec![0; size]; + r.copy_to_slice(&mut buf); + let str = String::from_utf8(buf)?; + Ok(str) } } diff --git a/moq-transport/src/coding/varint.rs b/moq-transport/src/coding/varint.rs index e0e927d6..763af016 100644 --- a/moq-transport/src/coding/varint.rs +++ b/moq-transport/src/coding/varint.rs @@ -5,9 +5,7 @@ use std::convert::{TryFrom, TryInto}; use std::fmt; -use crate::coding::{AsyncRead, AsyncWrite}; use thiserror::Error; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; use super::{Decode, DecodeError, Encode, EncodeError}; @@ -165,19 +163,14 @@ impl fmt::Display for VarInt { } } -#[async_trait::async_trait] impl Decode for VarInt { /// Decode a varint from the given reader. - async fn decode(r: &mut R) -> Result { - let b = r.read_u8().await.map_err(|_| DecodeError::IoError)?; - Self::decode_byte(b, r).await - } -} + fn decode(r: &mut R) -> Result { + if r.remaining() < 1 { + return Err(DecodeError::More(1)); + } -impl VarInt { - /// Decode a varint given the first byte, reading the rest as needed. - /// This is silly but useful for determining if the stream has ended. - pub async fn decode_byte(b: u8, r: &mut R) -> Result { + let b = r.get_u8(); let tag = b >> 6; let mut buf = [0u8; 8]; @@ -186,21 +179,27 @@ impl VarInt { let x = match tag { 0b00 => u64::from(buf[0]), 0b01 => { - r.read_exact(buf[1..2].as_mut()) - .await - .map_err(|_| DecodeError::IoError)?; + if r.remaining() < 1 { + return Err(DecodeError::More(1)); + } + + r.copy_to_slice(buf[1..2].as_mut()); u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap())) } 0b10 => { - r.read_exact(buf[1..4].as_mut()) - .await - .map_err(|_| DecodeError::IoError)?; + if r.remaining() < 3 { + return Err(DecodeError::More(3)); + } + + r.copy_to_slice(buf[1..4].as_mut()); u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap())) } 0b11 => { - r.read_exact(buf[1..8].as_mut()) - .await - .map_err(|_| DecodeError::IoError)?; + if r.remaining() < 7 { + return Err(DecodeError::More(7)); + } + + r.copy_to_slice(buf[1..8].as_mut()); u64::from_be_bytes(buf) } _ => unreachable!(), @@ -210,56 +209,50 @@ impl VarInt { } } -#[async_trait::async_trait] impl Encode for VarInt { /// Encode a varint to the given writer. - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { let x = self.0; if x < 2u64.pow(6) { - w.write_u8(x as u8).await + w.put_u8(x as u8) } else if x < 2u64.pow(14) { - w.write_u16(0b01 << 14 | x as u16).await + w.put_u16(0b01 << 14 | x as u16) } else if x < 2u64.pow(30) { - w.write_u32(0b10 << 30 | x as u32).await + w.put_u32(0b10 << 30 | x as u32) } else if x < 2u64.pow(62) { - w.write_u64(0b11 << 62 | x).await + w.put_u64(0b11 << 62 | x) } else { return Err(BoundsExceeded.into()); } - .map_err(|_| EncodeError::IoError)?; Ok(()) } } -#[async_trait::async_trait] impl Encode for u64 { /// Encode a varint to the given writer. - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { let var = VarInt::try_from(*self)?; - var.encode(w).await + var.encode(w) } } -#[async_trait::async_trait] impl Decode for u64 { - async fn decode(r: &mut R) -> Result { - VarInt::decode(r).await.map(|v| v.into_inner()) + fn decode(r: &mut R) -> Result { + VarInt::decode(r).map(|v| v.into_inner()) } } -#[async_trait::async_trait] impl Encode for usize { /// Encode a varint to the given writer. - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { let var = VarInt::try_from(*self)?; - var.encode(w).await + var.encode(w) } } -#[async_trait::async_trait] impl Decode for usize { - async fn decode(r: &mut R) -> Result { - VarInt::decode(r).await.map(|v| v.into_inner() as usize) + fn decode(r: &mut R) -> Result { + VarInt::decode(r).map(|v| v.into_inner() as usize) } } diff --git a/moq-transport/src/coding/writer.rs b/moq-transport/src/coding/writer.rs new file mode 100644 index 00000000..25bb2a0e --- /dev/null +++ b/moq-transport/src/coding/writer.rs @@ -0,0 +1,36 @@ +use tokio::io::{AsyncWrite, AsyncWriteExt}; + +use crate::coding::Encode; + +use super::EncodeError; + +pub struct Writer { + stream: S, + buffer: bytes::BytesMut, +} + +impl Writer { + pub fn new(stream: S) -> Self { + Self { + stream, + buffer: Default::default(), + } + } + + pub async fn encode(&mut self, msg: &T) -> Result<(), EncodeError> { + self.buffer.clear(); + msg.encode(&mut self.buffer)?; + self.stream.write_all(&self.buffer).await?; + + Ok(()) + } + + pub async fn write(&mut self, buf: &[u8]) -> Result<(), EncodeError> { + self.stream.write_all(buf).await?; + Ok(()) + } + + pub fn into_inner(self) -> S { + self.stream + } +} diff --git a/moq-transport/src/data/datagram.rs b/moq-transport/src/data/datagram.rs index 5ebcc212..8bf62c63 100644 --- a/moq-transport/src/data/datagram.rs +++ b/moq-transport/src/data/datagram.rs @@ -1,7 +1,4 @@ -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - #[derive(Clone, Debug)] pub struct Datagram { // The subscribe ID. @@ -23,17 +20,14 @@ pub struct Datagram { pub payload: bytes::Bytes, } -impl Datagram { - pub async fn decode(r: &mut R) -> Result { - let subscribe_id = u64::decode(r).await?; - let track_alias = u64::decode(r).await?; - let group_id = u64::decode(r).await?; - let object_id = u64::decode(r).await?; - let send_order = u64::decode(r).await?; - - // TODO use with_capacity once we know the size of the datagram... - let mut payload = Vec::new(); - r.read_to_end(&mut payload).await.map_err(|_| DecodeError::IoError)?; +impl Decode for Datagram { + fn decode(r: &mut R) -> Result { + let subscribe_id = u64::decode(r)?; + let track_alias = u64::decode(r)?; + let group_id = u64::decode(r)?; + let object_id = u64::decode(r)?; + let send_order = u64::decode(r)?; + let payload = r.copy_to_bytes(r.remaining()); Ok(Self { subscribe_id, @@ -41,19 +35,24 @@ impl Datagram { group_id, object_id, send_order, - payload: payload.into(), + payload, }) } +} + +impl Encode for Datagram { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.subscribe_id.encode(w)?; + self.track_alias.encode(w)?; + self.group_id.encode(w)?; + self.object_id.encode(w)?; + self.send_order.encode(w)?; + + if w.remaining_mut() < self.payload.len() { + return Err(EncodeError::More(self.payload.len())); + } + w.put_slice(&self.payload); - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.subscribe_id.encode(w).await?; - self.track_alias.encode(w).await?; - self.group_id.encode(w).await?; - self.object_id.encode(w).await?; - self.send_order.encode(w).await?; - w.write_all(self.payload.as_ref()) - .await - .map_err(|_| EncodeError::IoError)?; Ok(()) } } diff --git a/moq-transport/src/data/group.rs b/moq-transport/src/data/group.rs index 8ac020e4..aee076a1 100644 --- a/moq-transport/src/data/group.rs +++ b/moq-transport/src/data/group.rs @@ -1,4 +1,4 @@ -use crate::coding::{AsyncRead, AsyncWrite, Decode, DecodeError, Encode, EncodeError}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError}; #[derive(Clone, Debug)] pub struct GroupHeader { @@ -15,21 +15,23 @@ pub struct GroupHeader { pub send_order: u64, } -impl GroupHeader { - pub async fn decode(r: &mut R) -> Result { +impl Decode for GroupHeader { + fn decode(r: &mut R) -> Result { Ok(Self { - subscribe_id: u64::decode(r).await?, - track_alias: u64::decode(r).await?, - group_id: u64::decode(r).await?, - send_order: u64::decode(r).await?, + subscribe_id: u64::decode(r)?, + track_alias: u64::decode(r)?, + group_id: u64::decode(r)?, + send_order: u64::decode(r)?, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.subscribe_id.encode(w).await?; - self.track_alias.encode(w).await?; - self.group_id.encode(w).await?; - self.send_order.encode(w).await?; +impl Encode for GroupHeader { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.subscribe_id.encode(w)?; + self.track_alias.encode(w)?; + self.group_id.encode(w)?; + self.send_order.encode(w)?; Ok(()) } @@ -41,22 +43,19 @@ pub struct GroupObject { pub size: usize, } -impl GroupObject { - pub async fn decode(r: &mut R) -> Result, DecodeError> { - let object_id = match u64::decode(r).await { - Ok(object_id) => object_id, - Err(DecodeError::UnexpectedEnd) => return Ok(None), - Err(err) => return Err(err), - }; - - let size = usize::decode(r).await?; +impl Decode for GroupObject { + fn decode(r: &mut R) -> Result { + let object_id = u64::decode(r)?; + let size = usize::decode(r)?; - Ok(Some(Self { object_id, size })) + Ok(Self { object_id, size }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.object_id.encode(w).await?; - self.size.encode(w).await?; +impl Encode for GroupObject { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.object_id.encode(w)?; + self.size.encode(w)?; Ok(()) } diff --git a/moq-transport/src/data/header.rs b/moq-transport/src/data/header.rs index 22c1ba2d..237bce7e 100644 --- a/moq-transport/src/data/header.rs +++ b/moq-transport/src/data/header.rs @@ -1,4 +1,4 @@ -use crate::coding::{AsyncRead, AsyncWrite, Decode, DecodeError, Encode, EncodeError}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use paste::paste; use std::fmt; @@ -14,28 +14,32 @@ macro_rules! header_types { $($name(paste! { [<$name Header>] })),* } - impl Header { - pub async fn decode(r: &mut R) -> Result { - let t = u64::decode(r).await?; + impl Decode for Header { + fn decode(r: &mut R) -> Result { + let t = u64::decode(r)?; match t { $($val => { - let msg = ] }>::decode(r).await?; + let msg = ] }>::decode(r)?; Ok(Self::$name(msg)) })* _ => Err(DecodeError::InvalidMessage(t)), } } + } - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + impl Encode for Header { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { match self { $(Self::$name(ref m) => { - self.id().encode(w).await?; - m.encode(w).await + self.id().encode(w)?; + m.encode(w) },)* } } + } + impl Header { pub fn id(&self) -> u64 { match self { $(Self::$name(_) => { diff --git a/moq-transport/src/data/object.rs b/moq-transport/src/data/object.rs index bc241c62..6b601e28 100644 --- a/moq-transport/src/data/object.rs +++ b/moq-transport/src/data/object.rs @@ -1,4 +1,3 @@ -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; #[derive(Clone, Debug)] @@ -19,23 +18,25 @@ pub struct ObjectHeader { pub send_order: u64, } -impl ObjectHeader { - pub async fn decode(r: &mut R) -> Result { +impl Decode for ObjectHeader { + fn decode(r: &mut R) -> Result { Ok(Self { - subscribe_id: u64::decode(r).await?, - track_alias: u64::decode(r).await?, - group_id: u64::decode(r).await?, - object_id: u64::decode(r).await?, - send_order: u64::decode(r).await?, + subscribe_id: u64::decode(r)?, + track_alias: u64::decode(r)?, + group_id: u64::decode(r)?, + object_id: u64::decode(r)?, + send_order: u64::decode(r)?, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.subscribe_id.encode(w).await?; - self.track_alias.encode(w).await?; - self.group_id.encode(w).await?; - self.object_id.encode(w).await?; - self.send_order.encode(w).await?; +impl Encode for ObjectHeader { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.subscribe_id.encode(w)?; + self.track_alias.encode(w)?; + self.group_id.encode(w)?; + self.object_id.encode(w)?; + self.send_order.encode(w)?; Ok(()) } diff --git a/moq-transport/src/data/track.rs b/moq-transport/src/data/track.rs index 17de5caf..70ce3a06 100644 --- a/moq-transport/src/data/track.rs +++ b/moq-transport/src/data/track.rs @@ -1,4 +1,3 @@ -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; #[derive(Clone, Debug)] @@ -13,19 +12,21 @@ pub struct TrackHeader { pub send_order: u64, } -impl TrackHeader { - pub async fn decode(r: &mut R) -> Result { +impl Decode for TrackHeader { + fn decode(r: &mut R) -> Result { Ok(Self { - subscribe_id: u64::decode(r).await?, - track_alias: u64::decode(r).await?, - send_order: u64::decode(r).await?, + subscribe_id: u64::decode(r)?, + track_alias: u64::decode(r)?, + send_order: u64::decode(r)?, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.subscribe_id.encode(w).await?; - self.track_alias.encode(w).await?; - self.send_order.encode(w).await?; +impl Encode for TrackHeader { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.subscribe_id.encode(w)?; + self.track_alias.encode(w)?; + self.send_order.encode(w)?; Ok(()) } @@ -38,28 +39,26 @@ pub struct TrackObject { pub size: usize, } -impl TrackObject { - pub async fn decode(r: &mut R) -> Result, DecodeError> { - let group_id = match u64::decode(r).await { - Ok(group_id) => group_id, - Err(DecodeError::UnexpectedEnd) => return Ok(None), - Err(err) => return Err(err), - }; +impl Decode for TrackObject { + fn decode(r: &mut R) -> Result { + let group_id = u64::decode(r)?; - let object_id = u64::decode(r).await?; - let size = usize::decode(r).await?; + let object_id = u64::decode(r)?; + let size = usize::decode(r)?; - Ok(Some(Self { + Ok(Self { group_id, object_id, size, - })) + }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.group_id.encode(w).await?; - self.object_id.encode(w).await?; - self.size.encode(w).await?; +impl Encode for TrackObject { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.group_id.encode(w)?; + self.object_id.encode(w)?; + self.size.encode(w)?; Ok(()) } diff --git a/moq-transport/src/lib.rs b/moq-transport/src/lib.rs index 62c7759f..2456c30f 100644 --- a/moq-transport/src/lib.rs +++ b/moq-transport/src/lib.rs @@ -4,8 +4,7 @@ //! While originally designed for live media, MoQ Transport is generic and can be used for other live applications. //! The specification is a work in progress and will change. //! See the [specification](https://datatracker.ietf.org/doc/draft-ietf-moq-transport/) and [github](https://github.com/moq-wg/moq-transport) for any updates. -mod coding; - +pub mod coding; pub mod data; pub mod error; pub mod message; diff --git a/moq-transport/src/message/announce.rs b/moq-transport/src/message/announce.rs index 709339d9..114fb5c0 100644 --- a/moq-transport/src/message/announce.rs +++ b/moq-transport/src/message/announce.rs @@ -1,7 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError, Params}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the publisher to announce the availability of a group of tracks. #[derive(Clone, Debug)] pub struct Announce { @@ -12,17 +10,19 @@ pub struct Announce { pub params: Params, } -impl Announce { - pub async fn decode(r: &mut R) -> Result { - let namespace = String::decode(r).await?; - let params = Params::decode(r).await?; +impl Decode for Announce { + fn decode(r: &mut R) -> Result { + let namespace = String::decode(r)?; + let params = Params::decode(r)?; Ok(Self { namespace, params }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.namespace.encode(w).await?; - self.params.encode(w).await?; +impl Encode for Announce { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.namespace.encode(w)?; + self.params.encode(w)?; Ok(()) } diff --git a/moq-transport/src/message/announce_cancel.rs b/moq-transport/src/message/announce_cancel.rs index 386d8a00..2a3379f2 100644 --- a/moq-transport/src/message/announce_cancel.rs +++ b/moq-transport/src/message/announce_cancel.rs @@ -1,7 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the subscriber to reject an Announce after ANNOUNCE_OK #[derive(Clone, Debug)] pub struct AnnounceCancel { @@ -14,11 +12,11 @@ pub struct AnnounceCancel { //pub reason: String, } -impl AnnounceCancel { - pub async fn decode(r: &mut R) -> Result { - let namespace = String::decode(r).await?; - //let code = u64::decode(r).await?; - //let reason = String::decode(r).await?; +impl Decode for AnnounceCancel { + fn decode(r: &mut R) -> Result { + let namespace = String::decode(r)?; + //let code = u64::decode(r)?; + //let reason = String::decode(r)?; Ok(Self { namespace, @@ -26,11 +24,13 @@ impl AnnounceCancel { //reason, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.namespace.encode(w).await?; - //self.code.encode(w).await?; - //self.reason.encode(w).await?; +impl Encode for AnnounceCancel { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.namespace.encode(w)?; + //self.code.encode(w)?; + //self.reason.encode(w)?; Ok(()) } diff --git a/moq-transport/src/message/announce_error.rs b/moq-transport/src/message/announce_error.rs index cb0e82ac..4c468a32 100644 --- a/moq-transport/src/message/announce_error.rs +++ b/moq-transport/src/message/announce_error.rs @@ -1,7 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the subscriber to reject an Announce. #[derive(Clone, Debug)] pub struct AnnounceError { @@ -15,11 +13,11 @@ pub struct AnnounceError { pub reason: String, } -impl AnnounceError { - pub async fn decode(r: &mut R) -> Result { - let namespace = String::decode(r).await?; - let code = u64::decode(r).await?; - let reason = String::decode(r).await?; +impl Decode for AnnounceError { + fn decode(r: &mut R) -> Result { + let namespace = String::decode(r)?; + let code = u64::decode(r)?; + let reason = String::decode(r)?; Ok(Self { namespace, @@ -27,11 +25,13 @@ impl AnnounceError { reason, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.namespace.encode(w).await?; - self.code.encode(w).await?; - self.reason.encode(w).await?; +impl Encode for AnnounceError { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.namespace.encode(w)?; + self.code.encode(w)?; + self.reason.encode(w)?; Ok(()) } diff --git a/moq-transport/src/message/announce_ok.rs b/moq-transport/src/message/announce_ok.rs index a5c47928..0178eb1c 100644 --- a/moq-transport/src/message/announce_ok.rs +++ b/moq-transport/src/message/announce_ok.rs @@ -1,4 +1,4 @@ -use crate::coding::{AsyncRead, AsyncWrite, Decode, DecodeError, Encode, EncodeError}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the subscriber to accept an Announce. #[derive(Clone, Debug)] @@ -8,13 +8,15 @@ pub struct AnnounceOk { pub namespace: String, } -impl AnnounceOk { - pub async fn decode(r: &mut R) -> Result { - let namespace = String::decode(r).await?; +impl Decode for AnnounceOk { + fn decode(r: &mut R) -> Result { + let namespace = String::decode(r)?; Ok(Self { namespace }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.namespace.encode(w).await +impl Encode for AnnounceOk { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.namespace.encode(w) } } diff --git a/moq-transport/src/message/go_away.rs b/moq-transport/src/message/go_away.rs index c86152ae..376b9057 100644 --- a/moq-transport/src/message/go_away.rs +++ b/moq-transport/src/message/go_away.rs @@ -1,20 +1,20 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the server to indicate that the client should connect to a different server. #[derive(Clone, Debug)] pub struct GoAway { pub url: String, } -impl GoAway { - pub async fn decode(r: &mut R) -> Result { - let url = String::decode(r).await?; +impl Decode for GoAway { + fn decode(r: &mut R) -> Result { + let url = String::decode(r)?; Ok(Self { url }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.url.encode(w).await +impl Encode for GoAway { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.url.encode(w) } } diff --git a/moq-transport/src/message/mod.rs b/moq-transport/src/message/mod.rs index 4ce54476..1812f769 100644 --- a/moq-transport/src/message/mod.rs +++ b/moq-transport/src/message/mod.rs @@ -59,7 +59,7 @@ pub use subscriber::*; pub use unannounce::*; pub use unsubscribe::*; -use crate::coding::{AsyncRead, AsyncWrite, Decode, DecodeError, Encode, EncodeError}; +use crate::coding::{Decode, DecodeError, Encode, EncodeError}; use std::fmt; // Use a macro to generate the message types rather than copy-paste. @@ -72,28 +72,32 @@ macro_rules! message_types { $($name($name)),* } - impl Message { - pub async fn decode(r: &mut R) -> Result { - let t = u64::decode(r).await?; + impl Decode for Message { + fn decode(r: &mut R) -> Result { + let t = u64::decode(r)?; match t { $($val => { - let msg = $name::decode(r).await?; + let msg = $name::decode(r)?; Ok(Self::$name(msg)) })* _ => Err(DecodeError::InvalidMessage(t)), } } + } - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + impl Encode for Message { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { match self { $(Self::$name(ref m) => { - self.id().encode(w).await?; - m.encode(w).await + self.id().encode(w)?; + m.encode(w) },)* } } + } + impl Message { pub fn id(&self) -> u64 { match self { $(Self::$name(_) => { diff --git a/moq-transport/src/message/subscribe.rs b/moq-transport/src/message/subscribe.rs index afdd96ee..30a8635f 100644 --- a/moq-transport/src/message/subscribe.rs +++ b/moq-transport/src/message/subscribe.rs @@ -1,4 +1,3 @@ -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError, Params}; /// Sent by the subscriber to request all future objects for the given track. @@ -22,15 +21,15 @@ pub struct Subscribe { pub params: Params, } -impl Subscribe { - pub async fn decode(r: &mut R) -> Result { - let id = u64::decode(r).await?; - let track_alias = u64::decode(r).await?; - let track_namespace = String::decode(r).await?; - let track_name = String::decode(r).await?; +impl Decode for Subscribe { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let track_alias = u64::decode(r)?; + let track_namespace = String::decode(r)?; + let track_name = String::decode(r)?; - let start = SubscribePair::decode(r).await?; - let end = SubscribePair::decode(r).await?; + let start = SubscribePair::decode(r)?; + let end = SubscribePair::decode(r)?; // You can't have a start object without a start group. if start.group == SubscribeLocation::None && start.object != SubscribeLocation::None { @@ -44,7 +43,7 @@ impl Subscribe { // NOTE: There's some more location restrictions in the draft, but they're enforced at a higher level. - let params = Params::decode(r).await?; + let params = Params::decode(r)?; Ok(Self { id, @@ -56,17 +55,19 @@ impl Subscribe { params, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w).await?; - self.track_alias.encode(w).await?; - self.track_namespace.encode(w).await?; - self.track_name.encode(w).await?; +impl Encode for Subscribe { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.track_alias.encode(w)?; + self.track_namespace.encode(w)?; + self.track_name.encode(w)?; - self.start.encode(w).await?; - self.end.encode(w).await?; + self.start.encode(w)?; + self.end.encode(w)?; - self.params.encode(w).await?; + self.params.encode(w)?; Ok(()) } @@ -78,17 +79,19 @@ pub struct SubscribePair { pub object: SubscribeLocation, } -impl SubscribePair { - pub async fn decode(r: &mut R) -> Result { +impl Decode for SubscribePair { + fn decode(r: &mut R) -> Result { Ok(Self { - group: SubscribeLocation::decode(r).await?, - object: SubscribeLocation::decode(r).await?, + group: SubscribeLocation::decode(r)?, + object: SubscribeLocation::decode(r)?, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.group.encode(w).await?; - self.object.encode(w).await?; +impl Encode for SubscribePair { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.group.encode(w)?; + self.object.encode(w)?; Ok(()) } } @@ -102,30 +105,34 @@ pub enum SubscribeLocation { Future(u64), } -impl SubscribeLocation { - pub async fn decode(r: &mut R) -> Result { - let kind = u64::decode(r).await?; +impl Decode for SubscribeLocation { + fn decode(r: &mut R) -> Result { + let kind = u64::decode(r)?; match kind { 0 => Ok(Self::None), - 1 => Ok(Self::Absolute(u64::decode(r).await?)), - 2 => Ok(Self::Latest(u64::decode(r).await?)), - 3 => Ok(Self::Future(u64::decode(r).await?)), + 1 => Ok(Self::Absolute(u64::decode(r)?)), + 2 => Ok(Self::Latest(u64::decode(r)?)), + 3 => Ok(Self::Future(u64::decode(r)?)), _ => Err(DecodeError::InvalidSubscribeLocation), } } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id().encode(w).await?; +impl Encode for SubscribeLocation { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id().encode(w)?; match self { Self::None => Ok(()), - Self::Absolute(val) => val.encode(w).await, - Self::Latest(val) => val.encode(w).await, - Self::Future(val) => val.encode(w).await, + Self::Absolute(val) => val.encode(w), + Self::Latest(val) => val.encode(w), + Self::Future(val) => val.encode(w), } } +} +impl SubscribeLocation { fn id(&self) -> u64 { match self { Self::None => 0, diff --git a/moq-transport/src/message/subscribe_done.rs b/moq-transport/src/message/subscribe_done.rs index 2b58575f..08dcab28 100644 --- a/moq-transport/src/message/subscribe_done.rs +++ b/moq-transport/src/message/subscribe_done.rs @@ -1,6 +1,3 @@ -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the publisher to cleanly terminate a Subscribe. @@ -19,31 +16,42 @@ pub struct SubscribeDone { pub last: Option<(u64, u64)>, } -impl SubscribeDone { - pub async fn decode(r: &mut R) -> Result { - let id = u64::decode(r).await?; - let code = u64::decode(r).await?; - let reason = String::decode(r).await?; - let last = match r.read_u8().await.map_err(|_| DecodeError::IoError)? { +impl Decode for SubscribeDone { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let code = u64::decode(r)?; + let reason = String::decode(r)?; + + if r.remaining() < 1 { + return Err(DecodeError::More(1)); + } + + let last = match r.get_u8() { 0 => None, - 1 => Some((u64::decode(r).await?, u64::decode(r).await?)), + 1 => Some((u64::decode(r)?, u64::decode(r)?)), _ => return Err(DecodeError::InvalidValue), }; Ok(Self { id, code, reason, last }) } +} + +impl Encode for SubscribeDone { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.code.encode(w)?; + self.reason.encode(w)?; - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w).await?; - self.code.encode(w).await?; - self.reason.encode(w).await?; + if w.remaining_mut() < 1 { + return Err(EncodeError::More(1)); + } if let Some((group, object)) = self.last { - w.write_u8(1).await.map_err(|_| EncodeError::IoError)?; - group.encode(w).await?; - object.encode(w).await?; + w.put_u8(1); + group.encode(w)?; + object.encode(w)?; } else { - w.write_u8(0).await.map_err(|_| EncodeError::IoError)?; + w.put_u8(0); } Ok(()) diff --git a/moq-transport/src/message/subscribe_error.rs b/moq-transport/src/message/subscribe_error.rs index 16494e7d..0fb6ca68 100644 --- a/moq-transport/src/message/subscribe_error.rs +++ b/moq-transport/src/message/subscribe_error.rs @@ -1,4 +1,3 @@ -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the publisher to reject a Subscribe. @@ -17,12 +16,12 @@ pub struct SubscribeError { pub alias: u64, } -impl SubscribeError { - pub async fn decode(r: &mut R) -> Result { - let id = u64::decode(r).await?; - let code = u64::decode(r).await?; - let reason = String::decode(r).await?; - let alias = u64::decode(r).await?; +impl Decode for SubscribeError { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let code = u64::decode(r)?; + let reason = String::decode(r)?; + let alias = u64::decode(r)?; Ok(Self { id, @@ -31,12 +30,14 @@ impl SubscribeError { alias, }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w).await?; - self.code.encode(w).await?; - self.reason.encode(w).await?; - self.alias.encode(w).await?; +impl Encode for SubscribeError { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.code.encode(w)?; + self.reason.encode(w)?; + self.alias.encode(w)?; Ok(()) } diff --git a/moq-transport/src/message/subscribe_ok.rs b/moq-transport/src/message/subscribe_ok.rs index c60f6557..a91bfdce 100644 --- a/moq-transport/src/message/subscribe_ok.rs +++ b/moq-transport/src/message/subscribe_ok.rs @@ -1,9 +1,5 @@ -use tokio::io::{AsyncReadExt, AsyncWriteExt}; - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the publisher to accept a Subscribe. #[derive(Clone, Debug)] pub struct SubscribeOk { @@ -17,17 +13,21 @@ pub struct SubscribeOk { pub latest: Option<(u64, u64)>, } -impl SubscribeOk { - pub async fn decode(r: &mut R) -> Result { - let id = u64::decode(r).await?; - let expires = match u64::decode(r).await? { +impl Decode for SubscribeOk { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; + let expires = match u64::decode(r)? { 0 => None, expires => Some(expires), }; - let latest = match r.read_u8().await.map_err(|_| DecodeError::IoError)? { + if !r.has_remaining() { + return Err(DecodeError::More(1)); + } + + let latest = match r.get_u8() { 0 => None, - 1 => Some((u64::decode(r).await?, u64::decode(r).await?)), + 1 => Some((u64::decode(r)?, u64::decode(r)?)), _ => return Err(DecodeError::InvalidValue), }; @@ -35,19 +35,23 @@ impl SubscribeOk { } } -impl SubscribeOk { - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w).await?; - self.expires.unwrap_or(0).encode(w).await?; +impl Encode for SubscribeOk { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; + self.expires.unwrap_or(0).encode(w)?; + + if !w.has_remaining_mut() { + return Err(EncodeError::More(1)); + } match self.latest { Some((group, object)) => { - w.write_u8(1).await.map_err(|_| EncodeError::IoError)?; - group.encode(w).await?; - object.encode(w).await?; + w.put_u8(1); + group.encode(w)?; + object.encode(w)?; } None => { - w.write_u8(0).await.map_err(|_| EncodeError::IoError)?; + w.put_u8(0); } } diff --git a/moq-transport/src/message/unannounce.rs b/moq-transport/src/message/unannounce.rs index e93188c2..e856bf0f 100644 --- a/moq-transport/src/message/unannounce.rs +++ b/moq-transport/src/message/unannounce.rs @@ -1,7 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the publisher to terminate an Announce. #[derive(Clone, Debug)] pub struct Unannounce { @@ -9,15 +7,17 @@ pub struct Unannounce { pub namespace: String, } -impl Unannounce { - pub async fn decode(r: &mut R) -> Result { - let namespace = String::decode(r).await?; +impl Decode for Unannounce { + fn decode(r: &mut R) -> Result { + let namespace = String::decode(r)?; Ok(Self { namespace }) } +} - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.namespace.encode(w).await?; +impl Encode for Unannounce { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.namespace.encode(w)?; Ok(()) } diff --git a/moq-transport/src/message/unsubscribe.rs b/moq-transport/src/message/unsubscribe.rs index 42cea508..6a592b08 100644 --- a/moq-transport/src/message/unsubscribe.rs +++ b/moq-transport/src/message/unsubscribe.rs @@ -1,4 +1,3 @@ -use crate::coding::{AsyncRead, AsyncWrite}; use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Sent by the subscriber to terminate a Subscribe. @@ -8,16 +7,16 @@ pub struct Unsubscribe { pub id: u64, } -impl Unsubscribe { - pub async fn decode(r: &mut R) -> Result { - let id = u64::decode(r).await?; +impl Decode for Unsubscribe { + fn decode(r: &mut R) -> Result { + let id = u64::decode(r)?; Ok(Self { id }) } } -impl Unsubscribe { - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.id.encode(w).await?; +impl Encode for Unsubscribe { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.id.encode(w)?; Ok(()) } } diff --git a/moq-transport/src/session/error.rs b/moq-transport/src/session/error.rs index 34a83b90..10702f4d 100644 --- a/moq-transport/src/session/error.rs +++ b/moq-transport/src/session/error.rs @@ -1,10 +1,15 @@ +use std::{io, sync}; + use crate::{coding, serve, setup}; #[derive(thiserror::Error, Debug, Clone)] -pub enum SessionError { - // We can't use #[from] here because it would conflict with +pub enum SessionError { #[error("webtransport error: {0}")] - WebTransport(S::Error), + WebTransport(sync::Arc), + + // This needs an Arc because it's not Clone. + #[error("io error: {0}")] + Io(sync::Arc), #[error("encode error: {0}")] Encode(#[from] coding::EncodeError), @@ -42,7 +47,21 @@ pub enum SessionError { WrongSize, } -impl SessionError { +/* +impl From for SessionError { + fn from(err: T) -> Self { + Self::WebTransport(sync::Arc::new(err)) + } +} +*/ + +impl From for SessionError { + fn from(err: io::Error) -> Self { + Self::Io(sync::Arc::new(err)) + } +} + +impl SessionError { /// An integer code that is sent over the wire. pub fn code(&self) -> u64 { match self { @@ -52,6 +71,7 @@ impl SessionError { Self::Version(..) => 406, Self::Decode(_) => 400, Self::Encode(_) => 500, + Self::Io(_) => 500, Self::BoundsExceeded(_) => 500, Self::Duplicate => 409, Self::Internal => 500, diff --git a/moq-transport/src/session/mod.rs b/moq-transport/src/session/mod.rs index a7a89c63..55f00ab6 100644 --- a/moq-transport/src/session/mod.rs +++ b/moq-transport/src/session/mod.rs @@ -6,6 +6,8 @@ mod subscribe; mod subscribed; mod subscriber; +use std::sync::Arc; + pub use announce::*; pub use announced::*; pub use error::*; @@ -17,24 +19,28 @@ pub use subscriber::*; use futures::FutureExt; use futures::{stream::FuturesUnordered, StreamExt}; +use crate::coding::{Reader, Writer}; use crate::{message, setup, util::Queue}; pub struct Session { webtransport: S, - control: (S::SendStream, S::RecvStream), + + sender: Writer, + recver: Reader, + outgoing: Queue, publisher: Option>, subscriber: Option>, - outgoing: Queue>, } impl Session { fn new( webtransport: S, - control: (S::SendStream, S::RecvStream), + sender: Writer, + recver: Reader, role: setup::Role, ) -> (Self, Option>, Option>) { - let outgoing = Default::default(); + let outgoing = Queue::default(); let publisher = role .is_publisher() @@ -43,7 +49,8 @@ impl Session { let session = Self { webtransport, - control, + sender, + recver, outgoing, publisher: publisher.clone(), subscriber: subscriber.clone(), @@ -54,15 +61,20 @@ impl Session { pub async fn connect( session: S, - ) -> Result<(Session, Option>, Option>), SessionError> { + ) -> Result<(Session, Option>, Option>), SessionError> { Self::connect_role(session, setup::Role::Both).await } pub async fn connect_role( session: S, role: setup::Role, - ) -> Result<(Session, Option>, Option>), SessionError> { - let mut control = session.open_bi().await?; + ) -> Result<(Session, Option>, Option>), SessionError> { + let control = session + .open_bi() + .await + .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + let mut sender = Writer::new(control.0); + let mut recver = Reader::new(control.1); let versions: setup::Versions = [setup::Version::DRAFT_03].into(); @@ -73,10 +85,9 @@ impl Session { }; log::debug!("sending client SETUP: {:?}", client); - client.encode(&mut control.0).await?; - - let server = setup::Server::decode(&mut control.1).await?; + sender.encode(&client).await?; + let server: setup::Server = recver.decode().await?; log::debug!("received server SETUP: {:?}", server); // Downgrade our role based on the server's role. @@ -94,23 +105,25 @@ impl Session { }, }; - Ok(Session::new(session, control, role)) + Ok(Session::new(session, sender, recver, role)) } - pub async fn accept( - session: S, - ) -> Result<(Session, Option>, Option>), SessionError> { + pub async fn accept(session: S) -> Result<(Session, Option>, Option>), SessionError> { Self::accept_role(session, setup::Role::Both).await } pub async fn accept_role( session: S, role: setup::Role, - ) -> Result<(Session, Option>, Option>), SessionError> { - let mut control = session.accept_bi().await?; - - let client = setup::Client::decode(&mut control.1).await?; - + ) -> Result<(Session, Option>, Option>), SessionError> { + let control = session + .accept_bi() + .await + .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + let mut sender = Writer::new(control.0); + let mut recver = Reader::new(control.1); + + let client: setup::Client = recver.decode().await?; log::debug!("received client SETUP: {:?}", client); if !client.versions.contains(&setup::Version::DRAFT_03) { @@ -142,16 +155,15 @@ impl Session { }; log::debug!("sending server SETUP: {:?}", server); + sender.encode(&server).await?; - server.encode(&mut control.0).await?; - - Ok(Session::new(session, control, role)) + Ok(Session::new(session, sender, recver, role)) } - pub async fn run(self) -> Result<(), SessionError> { + pub async fn run(self) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); - tasks.push(Self::run_send(self.outgoing, self.control.0).boxed()); - tasks.push(Self::run_recv(self.control.1, self.publisher, self.subscriber.clone()).boxed()); + tasks.push(Self::run_send(self.outgoing, self.sender).boxed()); + tasks.push(Self::run_recv(self.recver, self.publisher, self.subscriber.clone()).boxed()); if let Some(subscriber) = self.subscriber { tasks.push(Self::run_streams(self.webtransport.clone(), subscriber.clone()).boxed()); @@ -163,22 +175,22 @@ impl Session { } async fn run_send( - outgoing: Queue>, - mut stream: S::SendStream, - ) -> Result<(), SessionError> { + outgoing: Queue, + mut sender: Writer, + ) -> Result<(), SessionError> { loop { let msg = outgoing.pop().await?; - msg.encode(&mut stream).await?; + sender.encode(&msg).await?; } } async fn run_recv( - mut stream: S::RecvStream, + mut recver: Reader, mut publisher: Option>, mut subscriber: Option>, - ) -> Result<(), SessionError> { + ) -> Result<(), SessionError> { loop { - let msg = message::Message::decode(&mut stream).await?; + let msg: message::Message = recver.decode().await?; let msg = match TryInto::::try_into(msg) { Ok(msg) => { @@ -207,13 +219,13 @@ impl Session { } } - async fn run_streams(webtransport: S, subscriber: Subscriber) -> Result<(), SessionError> { + async fn run_streams(webtransport: S, subscriber: Subscriber) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); loop { tokio::select! { res = webtransport.accept_uni() => { - let stream = res?; + let stream = res.map_err(|e| SessionError::WebTransport(Arc::new(e)))?; tasks.push(Subscriber::recv_stream(subscriber.clone(), stream)); }, res = tasks.next(), if !tasks.is_empty() => res.unwrap()?, @@ -221,9 +233,13 @@ impl Session { } } - async fn run_datagrams(webtransport: S, mut subscriber: Subscriber) -> Result<(), SessionError> { + async fn run_datagrams(webtransport: S, mut subscriber: Subscriber) -> Result<(), SessionError> { loop { - let datagram = webtransport.read_datagram().await?; + let datagram = webtransport + .recv_datagram() + .await + .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + subscriber.recv_datagram(datagram).await?; } } diff --git a/moq-transport/src/session/publisher.rs b/moq-transport/src/session/publisher.rs index 154f55c3..8df995b9 100644 --- a/moq-transport/src/session/publisher.rs +++ b/moq-transport/src/session/publisher.rs @@ -21,13 +21,13 @@ pub struct Publisher { announces: Arc>>, subscribed: Arc>>>, - subscribed_queue: Queue, SessionError>, + subscribed_queue: Queue, SessionError>, - outgoing: Queue>, + outgoing: Queue, } impl Publisher { - pub(crate) fn new(webtransport: S, outgoing: Queue>) -> Self { + pub(crate) fn new(webtransport: S, outgoing: Queue) -> Self { Self { webtransport, announces: Default::default(), @@ -37,17 +37,17 @@ impl Publisher { } } - pub async fn accept(session: S) -> Result<(Session, Publisher), SessionError> { + pub async fn accept(session: S) -> Result<(Session, Publisher), SessionError> { let (session, publisher, _) = Session::accept_role(session, setup::Role::Publisher).await?; Ok((session, publisher.unwrap())) } - pub async fn connect(session: S) -> Result<(Session, Publisher), SessionError> { + pub async fn connect(session: S) -> Result<(Session, Publisher), SessionError> { let (session, publisher, _) = Session::connect_role(session, setup::Role::Publisher).await?; Ok((session, publisher.unwrap())) } - pub fn announce(&mut self, namespace: &str) -> Result, SessionError> { + pub fn announce(&mut self, namespace: &str) -> Result, SessionError> { let mut announces = self.announces.lock().unwrap(); // Insert the abort handle into the lookup table. @@ -68,13 +68,13 @@ impl Publisher { Ok(announce) } - pub async fn subscribed(&mut self) -> Result, SessionError> { + pub async fn subscribed(&mut self) -> Result, SessionError> { self.subscribed_queue.pop().await } // Helper to announce and serve any matching subscribers. // TODO this currently takes over the connection; definitely remove Clone - pub async fn serve(mut self, broadcast: serve::BroadcastSubscriber) -> Result<(), SessionError> { + pub async fn serve(mut self, broadcast: serve::BroadcastSubscriber) -> Result<(), SessionError> { log::info!("serving broadcast: {}", broadcast.namespace); let announce = self.announce(&broadcast.namespace)?; @@ -116,7 +116,7 @@ impl Publisher { broadcast.get_track(subscribe.name())?.ok_or(ServeError::NotFound) } - pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { + pub(crate) fn recv_message(&mut self, msg: message::Subscriber) -> Result<(), SessionError> { log::debug!("received message: {:?}", msg); match msg { @@ -128,13 +128,13 @@ impl Publisher { } } - fn recv_announce_ok(&mut self, _msg: message::AnnounceOk) -> Result<(), SessionError> { + fn recv_announce_ok(&mut self, _msg: message::AnnounceOk) -> Result<(), SessionError> { // Who cares // TODO make AnnouncePending so we're forced to care Ok(()) } - fn recv_announce_error(&mut self, msg: message::AnnounceError) -> Result<(), SessionError> { + fn recv_announce_error(&mut self, msg: message::AnnounceError) -> Result<(), SessionError> { if let Some(announce) = self.announces.lock().unwrap().get_mut(&msg.namespace) { announce.recv_error(ServeError::Closed(msg.code)).ok(); } @@ -142,11 +142,11 @@ impl Publisher { Ok(()) } - fn recv_announce_cancel(&mut self, _msg: message::AnnounceCancel) -> Result<(), SessionError> { + fn recv_announce_cancel(&mut self, _msg: message::AnnounceCancel) -> Result<(), SessionError> { unimplemented!("recv_announce_cancel") } - fn recv_subscribe(&mut self, msg: message::Subscribe) -> Result<(), SessionError> { + fn recv_subscribe(&mut self, msg: message::Subscribe) -> Result<(), SessionError> { let mut subscribes = self.subscribed.lock().unwrap(); // Insert the abort handle into the lookup table. @@ -160,7 +160,7 @@ impl Publisher { self.subscribed_queue.push(subscribe) } - fn recv_unsubscribe(&mut self, msg: message::Unsubscribe) -> Result<(), SessionError> { + fn recv_unsubscribe(&mut self, msg: message::Unsubscribe) -> Result<(), SessionError> { if let Some(subscribed) = self.subscribed.lock().unwrap().get_mut(&msg.id) { subscribed.recv_unsubscribe().ok(); } @@ -168,7 +168,7 @@ impl Publisher { Ok(()) } - pub fn send_message>(&self, msg: T) -> Result<(), SessionError> { + pub fn send_message>(&self, msg: T) -> Result<(), SessionError> { let msg = msg.into(); log::debug!("sending message: {:?}", msg); self.outgoing.push(msg.into()) @@ -186,7 +186,7 @@ impl Publisher { &mut self.webtransport } - pub fn close(self, err: SessionError) { + pub fn close(self, err: SessionError) { self.outgoing.close(err.clone()).ok(); self.subscribed_queue.close(err).ok(); } diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index f27bf8c9..418f655e 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -1,6 +1,9 @@ use std::sync::{Arc, Mutex}; +use tokio::io::AsyncRead; + use crate::{ + coding::Reader, data, message::{self, SubscribePair}, serve::{self, ServeError}, @@ -17,7 +20,7 @@ pub struct Subscribe { } impl Subscribe { - pub(super) fn new(session: Subscriber, msg: message::Subscribe) -> (Subscribe, SubscribeRecv) { + pub(super) fn new(session: Subscriber, msg: message::Subscribe) -> (Subscribe, SubscribeRecv) { let state = Watch::new(State::default()); let (publisher, subscriber) = serve::Track { @@ -88,12 +91,12 @@ impl Drop for Subscribe { } #[derive(Clone)] -pub(super) struct SubscribeRecv { +pub(super) struct SubscribeRecv { publisher: Arc>, state: Watch, } -impl SubscribeRecv { +impl SubscribeRecv { fn new(state: Watch, publisher: serve::TrackPublisher) -> Self { Self { publisher: Arc::new(Mutex::new(publisher)), @@ -117,32 +120,38 @@ impl SubscribeRecv { Ok(()) } - pub async fn recv_stream(&mut self, header: data::Header, stream: S::RecvStream) -> Result<(), SessionError> { + pub async fn recv_stream( + &mut self, + header: data::Header, + reader: Reader, + ) -> Result<(), SessionError> { match header { - data::Header::Track(track) => self.recv_track(track, stream).await, - data::Header::Group(group) => self.recv_group(group, stream).await, - data::Header::Object(object) => self.recv_object(object, stream).await, + data::Header::Track(track) => self.recv_track(track, reader).await, + data::Header::Group(group) => self.recv_group(group, reader).await, + data::Header::Object(object) => self.recv_object(object, reader).await, } } - async fn recv_track( + async fn recv_track( &mut self, header: data::TrackHeader, - mut stream: S::RecvStream, - ) -> Result<(), SessionError> { + mut reader: Reader, + ) -> Result<(), SessionError> { log::trace!("received track: {:?}", header); let mut track = self.publisher.lock().unwrap().create_stream(header.send_order)?; - while let Some(chunk) = data::TrackObject::decode(&mut stream).await? { + while !reader.done().await? { + let chunk: data::TrackObject = reader.decode().await?; + let mut remain = chunk.size; let mut chunks = vec![]; while remain > 0 { - let chunk = stream.read_chunk(remain, true).await?.ok_or(SessionError::WrongSize)?; - log::trace!("received track payload: {:?}", chunk.bytes.len()); - remain -= chunk.bytes.len(); - chunks.push(chunk.bytes); + let chunk = reader.read(remain).await?.ok_or(SessionError::WrongSize)?; + log::trace!("received track payload: {:?}", chunk.len()); + remain -= chunk.len(); + chunks.push(chunk); } let object = serve::StreamObject { @@ -158,11 +167,11 @@ impl SubscribeRecv { Ok(()) } - async fn recv_group( + async fn recv_group( &mut self, header: data::GroupHeader, - mut stream: S::RecvStream, - ) -> Result<(), SessionError> { + mut reader: Reader, + ) -> Result<(), SessionError> { log::trace!("received group: {:?}", header); let mut group = self.publisher.lock().unwrap().create_group(serve::Group { @@ -170,34 +179,36 @@ impl SubscribeRecv { send_order: header.send_order, })?; - while let Some(object) = data::GroupObject::decode(&mut stream).await? { + while !reader.done().await? { + let object: data::GroupObject = reader.decode().await?; + log::trace!("received group object: {:?}", object); let mut remain = object.size; let mut object = group.create_object(object.size)?; while remain > 0 { - let data = stream.read_chunk(remain, true).await?.ok_or(SessionError::WrongSize)?; - log::trace!("received group payload: {:?}", data.bytes.len()); - remain -= data.bytes.len(); - object.write(data.bytes)?; + let data = reader.read(remain).await?.ok_or(SessionError::WrongSize)?; + log::trace!("received group payload: {:?}", data.len()); + remain -= data.len(); + object.write(data)?; } } Ok(()) } - async fn recv_object( + async fn recv_object( &mut self, header: data::ObjectHeader, - mut stream: S::RecvStream, - ) -> Result<(), SessionError> { + mut reader: Reader, + ) -> Result<(), SessionError> { log::trace!("received object: {:?}", header); // TODO avoid buffering the entire object to learn the size. let mut chunks = vec![]; - while let Some(data) = stream.read_chunk(usize::MAX, true).await? { - log::trace!("received object payload: {:?}", data.bytes.len()); - chunks.push(data.bytes); + while let Some(data) = reader.read(usize::MAX).await? { + log::trace!("received object payload: {:?}", data.len()); + chunks.push(data); } let mut object = self.publisher.lock().unwrap().create_object(serve::ObjectHeader { @@ -216,7 +227,7 @@ impl SubscribeRecv { Ok(()) } - pub fn recv_datagram(&self, datagram: data::Datagram) -> Result<(), SessionError> { + pub fn recv_datagram(&self, datagram: data::Datagram) -> Result<(), SessionError> { log::trace!("received datagram: {:?}", datagram); self.publisher.lock().unwrap().write_datagram(serve::Datagram { diff --git a/moq-transport/src/session/subscribed.rs b/moq-transport/src/session/subscribed.rs index 48af20c5..cc4aaf18 100644 --- a/moq-transport/src/session/subscribed.rs +++ b/moq-transport/src/session/subscribed.rs @@ -1,7 +1,9 @@ +use std::sync::Arc; + use futures::stream::FuturesUnordered; use futures::{FutureExt, StreamExt}; -use tokio::io::AsyncWriteExt; +use crate::coding::{Encode, Writer}; use crate::serve::ServeError; use crate::util::{Watch, WatchWeak}; use crate::{data, message, serve}; @@ -34,7 +36,7 @@ impl Subscribed { self.msg.track_name.as_str() } - pub async fn serve(mut self, mut track: serve::TrackSubscriber) -> Result<(), SessionError> { + pub async fn serve(mut self, mut track: serve::TrackSubscriber) -> Result<(), SessionError> { let mut tasks = FuturesUnordered::new(); self.state.lock_mut().ok(track.latest())?; @@ -61,8 +63,15 @@ impl Subscribed { } } - async fn serve_track(mut self, mut track: serve::StreamSubscriber) -> Result<(), SessionError> { - let mut stream = self.session.webtransport().open_uni().await?; + async fn serve_track(mut self, mut track: serve::StreamSubscriber) -> Result<(), SessionError> { + let stream = self + .session + .webtransport() + .open_uni() + .await + .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + + let mut writer = Writer::new(stream); let header: data::Header = data::TrackHeader { subscribe_id: self.msg.id, @@ -71,7 +80,7 @@ impl Subscribed { } .into(); - header.encode(&mut stream).await?; + writer.encode(&header).await?; log::trace!("sent track header: {:?}", header); @@ -85,12 +94,12 @@ impl Subscribed { size: object.payload.len(), }; - header.encode(&mut stream).await?; + writer.encode(&header).await?; log::trace!("sent track object: {:?}", header); self.state.lock_mut().update_max(object.group_id, object.object_id)?; - stream.write_all(&object.payload).await?; + writer.write(&object.payload).await?; log::trace!("sent track payload: {:?}", object.payload.len()); log::trace!("sent track done"); @@ -99,8 +108,14 @@ impl Subscribed { Ok(()) } - pub async fn serve_group(mut self, mut group: serve::GroupSubscriber) -> Result<(), SessionError> { - let mut stream = self.session.webtransport().open_uni().await?; + pub async fn serve_group(mut self, mut group: serve::GroupSubscriber) -> Result<(), SessionError> { + let stream = self + .session + .webtransport() + .open_uni() + .await + .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + let mut writer = Writer::new(stream); let header: data::Header = data::GroupHeader { subscribe_id: self.msg.id, @@ -110,7 +125,7 @@ impl Subscribed { } .into(); - header.encode(&mut stream).await?; + writer.encode(&header).await?; log::trace!("sent group: {:?}", header); @@ -120,14 +135,14 @@ impl Subscribed { size: object.size, }; - self.state.lock_mut().update_max(group.id, object.object_id)?; + writer.encode(&header).await?; - header.encode(&mut stream).await?; + self.state.lock_mut().update_max(group.id, object.object_id)?; log::trace!("sent group object: {:?}", header); while let Some(chunk) = object.read().await? { - stream.write_all(&chunk).await?; + writer.write(&chunk).await?; log::trace!("sent group payload: {:?}", chunk.len()); } @@ -137,8 +152,14 @@ impl Subscribed { Ok(()) } - pub async fn serve_object(mut self, mut object: serve::ObjectSubscriber) -> Result<(), SessionError> { - let mut stream = self.session.webtransport().open_uni().await?; + pub async fn serve_object(mut self, mut object: serve::ObjectSubscriber) -> Result<(), SessionError> { + let stream = self + .session + .webtransport() + .open_uni() + .await + .map_err(|e| SessionError::WebTransport(Arc::new(e)))?; + let mut writer = Writer::new(stream); let header: data::Header = data::ObjectHeader { subscribe_id: self.msg.id, @@ -148,14 +169,15 @@ impl Subscribed { send_order: object.send_order, } .into(); - header.encode(&mut stream).await?; + + writer.encode(&header).await?; log::trace!("sent object: {:?}", header); self.state.lock_mut().update_max(object.group_id, object.object_id)?; while let Some(chunk) = object.read().await? { - stream.write_all(&chunk).await?; + writer.write(&chunk).await?; log::trace!("sent object payload: {:?}", chunk.len()); } @@ -164,7 +186,7 @@ impl Subscribed { Ok(()) } - pub async fn serve_datagram(&mut self, datagram: serve::Datagram) -> Result<(), SessionError> { + pub async fn serve_datagram(&mut self, datagram: serve::Datagram) -> Result<(), SessionError> { let datagram = data::Datagram { subscribe_id: self.msg.id, track_alias: self.msg.track_alias, @@ -175,7 +197,7 @@ impl Subscribed { }; let mut buffer = Vec::with_capacity(datagram.payload.len() + 100); - datagram.encode(&mut buffer).await?; // TODO Not actually async + datagram.encode(&mut buffer)?; log::trace!("sent datagram: {:?}", datagram); diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index bc370100..5d5568b0 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -4,7 +4,11 @@ use std::{ sync::{atomic, Arc, Mutex}, }; -use crate::{data, message, setup, util::Queue}; +use crate::{ + coding::{Decode, Reader}, + data, message, setup, + util::Queue, +}; use super::{Announced, AnnouncedRecv, Session, SessionError, Subscribe, SubscribeOptions, SubscribeRecv}; @@ -12,16 +16,16 @@ use super::{Announced, AnnouncedRecv, Session, SessionError, Subscribe, Subscrib #[derive(Clone)] pub struct Subscriber { announced: Arc>>>, - announced_queue: Queue, SessionError>, + announced_queue: Queue, SessionError>, - subscribes: Arc>>>, + subscribes: Arc>>, subscribe_next: Arc, - outgoing: Queue>, + outgoing: Queue, } impl Subscriber { - pub(super) fn new(outgoing: Queue>) -> Self { + pub(super) fn new(outgoing: Queue) -> Self { Self { announced: Default::default(), announced_queue: Default::default(), @@ -31,17 +35,17 @@ impl Subscriber { } } - pub async fn accept(session: S) -> Result<(Session, Self), SessionError> { + pub async fn accept(session: S) -> Result<(Session, Self), SessionError> { let (session, _, subscriber) = Session::accept_role(session, setup::Role::Subscriber).await?; Ok((session, subscriber.unwrap())) } - pub async fn connect(session: S) -> Result<(Session, Self), SessionError> { + pub async fn connect(session: S) -> Result<(Session, Self), SessionError> { let (session, _, subscriber) = Session::connect_role(session, setup::Role::Subscriber).await?; Ok((session, subscriber.unwrap())) } - pub async fn announced(&mut self) -> Result, SessionError> { + pub async fn announced(&mut self) -> Result, SessionError> { self.announced_queue.pop().await } @@ -50,7 +54,7 @@ impl Subscriber { namespace: &str, name: &str, options: SubscribeOptions, - ) -> Result, SessionError> { + ) -> Result, SessionError> { let id = self.subscribe_next.fetch_add(1, atomic::Ordering::Relaxed); let msg = message::Subscribe { @@ -66,18 +70,18 @@ impl Subscriber { self.send_message(msg.clone())?; let (publisher, subscribe) = Subscribe::new(self.clone(), msg); - self.subscribes.lock().unwrap().insert(id, publisher); + self.subscribes.lock().unwrap().insert(id, subscribe); - Ok(subscribe) + Ok(publisher) } - pub(super) fn send_message>(&mut self, msg: M) -> Result<(), SessionError> { + pub(super) fn send_message>(&mut self, msg: M) -> Result<(), SessionError> { let msg = msg.into(); log::debug!("sending message: {:?}", msg); self.outgoing.push(msg.into()) } - pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { + pub(super) fn recv_message(&mut self, msg: message::Publisher) -> Result<(), SessionError> { log::debug!("received message: {:?}", msg); match msg { @@ -89,7 +93,7 @@ impl Subscriber { } } - fn recv_announce(&mut self, msg: message::Announce) -> Result<(), SessionError> { + fn recv_announce(&mut self, msg: message::Announce) -> Result<(), SessionError> { let mut announces = self.announced.lock().unwrap(); let entry = match announces.entry(msg.namespace.clone()) { @@ -104,7 +108,7 @@ impl Subscriber { Ok(()) } - fn recv_unannounce(&mut self, msg: message::Unannounce) -> Result<(), SessionError> { + fn recv_unannounce(&mut self, msg: message::Unannounce) -> Result<(), SessionError> { if let Some(announce) = self.announced.lock().unwrap().get_mut(&msg.namespace) { announce.recv_unannounce().ok(); } @@ -112,7 +116,7 @@ impl Subscriber { Ok(()) } - fn recv_subscribe_ok(&mut self, msg: message::SubscribeOk) -> Result<(), SessionError> { + fn recv_subscribe_ok(&mut self, msg: message::SubscribeOk) -> Result<(), SessionError> { if let Some(sub) = self.subscribes.lock().unwrap().get_mut(&msg.id) { sub.recv_ok(msg).ok(); } @@ -120,7 +124,7 @@ impl Subscriber { Ok(()) } - fn recv_subscribe_error(&mut self, msg: message::SubscribeError) -> Result<(), SessionError> { + fn recv_subscribe_error(&mut self, msg: message::SubscribeError) -> Result<(), SessionError> { if let Some(subscriber) = self.subscribes.lock().unwrap().get_mut(&msg.id) { subscriber.recv_error(msg.code).ok(); } @@ -128,7 +132,7 @@ impl Subscriber { Ok(()) } - fn recv_subscribe_done(&mut self, msg: message::SubscribeDone) -> Result<(), SessionError> { + fn recv_subscribe_done(&mut self, msg: message::SubscribeDone) -> Result<(), SessionError> { if let Some(subscriber) = self.subscribes.lock().unwrap().get_mut(&msg.id) { subscriber.recv_done(msg.code).ok(); } @@ -144,23 +148,24 @@ impl Subscriber { self.announced.lock().unwrap().remove(namespace); } - pub(super) async fn recv_stream(self, mut stream: S::RecvStream) -> Result<(), SessionError> { - let header = data::Header::decode(&mut stream).await?; + pub(super) async fn recv_stream(self, stream: S::RecvStream) -> Result<(), SessionError> { + let mut reader = Reader::new(stream); + let header: data::Header = reader.decode().await?; let id = header.subscribe_id(); let subscribe = self.subscribes.lock().unwrap().get(&id).cloned(); if let Some(mut subscribe) = subscribe { - subscribe.recv_stream(header, stream).await? + subscribe.recv_stream(header, reader).await? } Ok(()) } // TODO should not be async - pub async fn recv_datagram(&mut self, datagram: bytes::Bytes) -> Result<(), SessionError> { + pub async fn recv_datagram(&mut self, datagram: bytes::Bytes) -> Result<(), SessionError> { let mut cursor = io::Cursor::new(datagram); - let datagram = data::Datagram::decode(&mut cursor).await?; + let datagram = data::Datagram::decode(&mut cursor)?; let subscribe = self.subscribes.lock().unwrap().get(&datagram.subscribe_id).cloned(); @@ -171,7 +176,7 @@ impl Subscriber { Ok(()) } - pub fn close(self, err: SessionError) { + pub fn close(self, err: SessionError) { self.outgoing.close(err.clone()).ok(); self.announced_queue.close(err).ok(); } diff --git a/moq-transport/src/setup/client.rs b/moq-transport/src/setup/client.rs index 7807a0b2..04a52103 100644 --- a/moq-transport/src/setup/client.rs +++ b/moq-transport/src/setup/client.rs @@ -1,8 +1,6 @@ use super::{Role, Versions}; use crate::coding::{Decode, DecodeError, Encode, EncodeError, Params}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the client to setup the session. // NOTE: This is not a message type, but rather the control stream header. // Proposal: https://github.com/moq-wg/moq-transport/issues/138 @@ -18,18 +16,18 @@ pub struct Client { pub params: Params, } -impl Client { +impl Decode for Client { /// Decode a client setup message. - pub async fn decode(r: &mut R) -> Result { - let typ = u64::decode(r).await?; + fn decode(r: &mut R) -> Result { + let typ = u64::decode(r)?; if typ != 0x40 { return Err(DecodeError::InvalidMessage(typ)); } - let versions = Versions::decode(r).await?; - let mut params = Params::decode(r).await?; + let versions = Versions::decode(r)?; + let mut params = Params::decode(r)?; - let role = params.get::(0).await?.ok_or(DecodeError::MissingParameter)?; + let role = params.get::(0)?.ok_or(DecodeError::MissingParameter)?; // Make sure the PATH parameter isn't used // TODO: This assumes WebTransport support only @@ -39,16 +37,18 @@ impl Client { Ok(Self { versions, role, params }) } +} +impl Encode for Client { /// Encode a server setup message. - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - 0x40_u64.encode(w).await?; - self.versions.encode(w).await?; + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + 0x40_u64.encode(w)?; + self.versions.encode(w)?; let mut params = self.params.clone(); - params.set(0, self.role).await?; + params.set(0, self.role)?; - params.encode(w).await?; + params.encode(w)?; Ok(()) } diff --git a/moq-transport/src/setup/mod.rs b/moq-transport/src/setup/mod.rs index e5c59c84..6925cd97 100644 --- a/moq-transport/src/setup/mod.rs +++ b/moq-transport/src/setup/mod.rs @@ -13,3 +13,5 @@ pub use client::*; pub use role::*; pub use server::*; pub use version::*; + +pub const ALPN: &[u8] = b"moq-00"; diff --git a/moq-transport/src/setup/role.rs b/moq-transport/src/setup/role.rs index 84daec7a..a7a90dc9 100644 --- a/moq-transport/src/setup/role.rs +++ b/moq-transport/src/setup/role.rs @@ -1,5 +1,3 @@ -use crate::coding::{AsyncRead, AsyncWrite}; - use crate::coding::{Decode, DecodeError, Encode, EncodeError}; /// Indicates the endpoint is a publisher, subscriber, or both. @@ -56,19 +54,17 @@ impl TryFrom for Role { } } -#[async_trait::async_trait] impl Decode for Role { /// Decode the role. - async fn decode(r: &mut R) -> Result { - let v = u64::decode(r).await?; + fn decode(r: &mut R) -> Result { + let v = u64::decode(r)?; v.try_into() } } -#[async_trait::async_trait] impl Encode for Role { /// Encode the role. - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - u64::from(*self).encode(w).await + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + u64::from(*self).encode(w) } } diff --git a/moq-transport/src/setup/server.rs b/moq-transport/src/setup/server.rs index ceed6d53..c1784d3b 100644 --- a/moq-transport/src/setup/server.rs +++ b/moq-transport/src/setup/server.rs @@ -1,8 +1,6 @@ use super::{Role, Version}; use crate::coding::{Decode, DecodeError, Encode, EncodeError, Params}; -use crate::coding::{AsyncRead, AsyncWrite}; - /// Sent by the server in response to a client setup. // NOTE: This is not a message type, but rather the control stream header. // Proposal: https://github.com/moq-wg/moq-transport/issues/138 @@ -19,18 +17,18 @@ pub struct Server { pub params: Params, } -impl Server { +impl Decode for Server { /// Decode the server setup. - pub async fn decode(r: &mut R) -> Result { - let typ = u64::decode(r).await?; + fn decode(r: &mut R) -> Result { + let typ = u64::decode(r)?; if typ != 0x41 { return Err(DecodeError::InvalidMessage(typ)); } - let version = Version::decode(r).await?; - let mut params = Params::decode(r).await?; + let version = Version::decode(r)?; + let mut params = Params::decode(r)?; - let role = params.get::(0).await?.ok_or(DecodeError::MissingParameter)?; + let role = params.get::(0)?.ok_or(DecodeError::MissingParameter)?; // Make sure the PATH parameter isn't used if params.has(1) { @@ -39,15 +37,16 @@ impl Server { Ok(Self { version, role, params }) } +} - /// Encode the server setup. - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - 0x41_u64.encode(w).await?; - self.version.encode(w).await?; +impl Encode for Server { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + 0x41_u64.encode(w)?; + self.version.encode(w)?; let mut params = self.params.clone(); - params.set(0, self.role).await?; - params.encode(w).await?; + params.set(0, self.role)?; + params.encode(w)?; Ok(()) } diff --git a/moq-transport/src/setup/version.rs b/moq-transport/src/setup/version.rs index beff9ce3..5d67f8c2 100644 --- a/moq-transport/src/setup/version.rs +++ b/moq-transport/src/setup/version.rs @@ -1,7 +1,5 @@ use crate::coding::{Decode, DecodeError, Encode, EncodeError}; -use crate::coding::{AsyncRead, AsyncWrite}; - use std::ops::Deref; /// A version number negotiated during the setup. @@ -34,16 +32,17 @@ impl From for u64 { } } -impl Version { +impl Decode for Version { /// Decode the version number. - pub async fn decode(r: &mut R) -> Result { - let v = u64::decode(r).await?; + fn decode(r: &mut R) -> Result { + let v = u64::decode(r)?; Ok(Self(v)) } +} - /// Encode the version number. - pub async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.0.encode(w).await?; +impl Encode for Version { + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.0.encode(w)?; Ok(()) } } @@ -52,15 +51,14 @@ impl Version { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Versions(Vec); -#[async_trait::async_trait] impl Decode for Versions { /// Decode the version list. - async fn decode(r: &mut R) -> Result { - let count = u64::decode(r).await?; + fn decode(r: &mut R) -> Result { + let count = u64::decode(r)?; let mut vs = Vec::new(); for _ in 0..count { - let v = Version::decode(r).await?; + let v = Version::decode(r)?; vs.push(v); } @@ -68,14 +66,13 @@ impl Decode for Versions { } } -#[async_trait::async_trait] impl Encode for Versions { /// Encode the version list. - async fn encode(&self, w: &mut W) -> Result<(), EncodeError> { - self.0.len().encode(w).await?; + fn encode(&self, w: &mut W) -> Result<(), EncodeError> { + self.0.len().encode(w)?; for v in &self.0 { - v.encode(w).await?; + v.encode(w)?; } Ok(()) From 7aafa4be238719bd7f45f8b51b70dfc0de3bb2d2 Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 16:26:30 -0700 Subject: [PATCH 3/7] Works? --- Cargo.lock | 65 ++---- moq-clock/Cargo.toml | 2 +- moq-clock/src/main.rs | 7 +- moq-pub/Cargo.toml | 3 +- moq-relay/Cargo.toml | 7 +- moq-relay/src/connection.rs | 88 +++++--- moq-relay/src/origin.rs | 259 +++++++++++++++++++----- moq-transport/Cargo.toml | 2 +- moq-transport/src/serve/track.rs | 7 + moq-transport/src/session/subscribe.rs | 142 +++---------- moq-transport/src/session/subscriber.rs | 28 ++- 11 files changed, 342 insertions(+), 268 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 95b315da..cc31efe2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -875,7 +875,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webtransport-quinn 0.7.0", + "webtransport-quinn", ] [[package]] @@ -899,7 +899,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webtransport-quinn 0.7.0", + "webtransport-quinn", ] [[package]] @@ -929,8 +929,8 @@ dependencies = [ "tracing-subscriber", "url", "webpki", - "webtransport-generic 0.5.0", - "webtransport-quinn 0.7.0 (registry+https://github.com/rust-lang/crates.io-index)", + "webtransport-generic", + "webtransport-quinn", ] [[package]] @@ -955,7 +955,7 @@ dependencies = [ "tracing", "tracing-subscriber", "url", - "webtransport-generic 0.5.0", + "webtransport-generic", ] [[package]] @@ -1223,13 +1223,13 @@ dependencies = [ [[package]] name = "quictransport-quinn" -version = "0.7.0" +version = "0.8.0" dependencies = [ "bytes", "quinn", "tokio", - "webtransport-generic 0.5.0", - "webtransport-proto 0.6.0", + "webtransport-generic", + "webtransport-proto", ] [[package]] @@ -2194,17 +2194,7 @@ checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" [[package]] name = "webtransport-generic" -version = "0.5.0" -dependencies = [ - "bytes", - "tokio", -] - -[[package]] -name = "webtransport-generic" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df712317d761312996f654739debeb3838eb02c6fd9146d9efdfd08a46674e45" +version = "0.8.0" dependencies = [ "bytes", "tokio", @@ -2220,40 +2210,9 @@ dependencies = [ "url", ] -[[package]] -name = "webtransport-proto" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebeada5037d6302980ae2e0ab8d840e329c1697c612c6c077172de2b7631a276" -dependencies = [ - "bytes", - "http", - "thiserror", - "url", -] - [[package]] name = "webtransport-quinn" -version = "0.7.0" -dependencies = [ - "bytes", - "futures", - "http", - "log", - "quinn", - "quinn-proto", - "thiserror", - "tokio", - "url", - "webtransport-generic 0.5.0", - "webtransport-proto 0.6.0", -] - -[[package]] -name = "webtransport-quinn" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58c942869a07920baa3cdabaf56279b027e5c87953541837dc5ff9dd5ad3d9b7" +version = "0.8.0" dependencies = [ "bytes", "futures", @@ -2264,8 +2223,8 @@ dependencies = [ "thiserror", "tokio", "url", - "webtransport-generic 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", - "webtransport-proto 0.6.0 (registry+https://github.com/rust-lang/crates.io-index)", + "webtransport-generic", + "webtransport-proto", ] [[package]] diff --git a/moq-clock/Cargo.toml b/moq-clock/Cargo.toml index 90dbebec..07635682 100644 --- a/moq-clock/Cargo.toml +++ b/moq-clock/Cargo.toml @@ -19,7 +19,7 @@ moq-transport = { path = "../moq-transport" } # QUIC quinn = "0.10" #webtransport-quinn = "0.7" -webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } +webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn", version = "0.8" } url = "2" # Crypto diff --git a/moq-clock/src/main.rs b/moq-clock/src/main.rs index 5994dfb8..13a4194b 100644 --- a/moq-clock/src/main.rs +++ b/moq-clock/src/main.rs @@ -94,11 +94,10 @@ async fn main() -> anyhow::Result<()> { .await .context("failed to create MoQ Transport session")?; - let subscriber = subscriber - .subscribe(&config.namespace, &config.track, Default::default()) - .context("failed to subscribe to track")?; + let (prod, sub) = serve::Track::new(&config.namespace, &config.track).produce(); + subscriber.subscribe(prod).context("failed to subscribe to track")?; - let clock = clock::Subscriber::new(subscriber.track()); + let clock = clock::Subscriber::new(sub); tokio::select! { res = session.run() => res.context("session error")?, diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index 95c90f68..b5005f8a 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -18,8 +18,7 @@ moq-transport = { path = "../moq-transport" } # QUIC quinn = "0.10" -#webtransport-quinn = "0.7" -webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } +webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn", version = "0.8" } url = "2" # Crypto diff --git a/moq-relay/Cargo.toml b/moq-relay/Cargo.toml index e405dd39..fb6fc6ad 100644 --- a/moq-relay/Cargo.toml +++ b/moq-relay/Cargo.toml @@ -17,10 +17,9 @@ moq-api = { path = "../moq-api" } # QUIC quinn = "0.10" -webtransport-quinn = "0.7" -#webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn" } -quictransport-quinn = { path = "../../webtransport-rs/quictransport-quinn" } -webtransport-generic = { path = "../../webtransport-rs/webtransport-generic" } +quictransport-quinn = { path = "../../webtransport-rs/quictransport-quinn", version = "0.8" } +webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn", version = "0.8" } +webtransport-generic = { path = "../../webtransport-rs/webtransport-generic", version = "0.8" } url = "2" # Crypto diff --git a/moq-relay/src/connection.rs b/moq-relay/src/connection.rs index a027ccfd..e7ab5f14 100644 --- a/moq-relay/src/connection.rs +++ b/moq-relay/src/connection.rs @@ -3,7 +3,7 @@ use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; use moq_transport::session::{Publisher, SessionError, Subscriber}; -use crate::Origin; +use crate::{Origin, OriginPublisher}; #[derive(Clone)] pub struct Connection { @@ -15,16 +15,17 @@ impl Connection { Self { origin } } - pub async fn run(self, conn: quinn::Connecting) -> anyhow::Result<()> { + pub async fn run(self, mut conn: quinn::Connecting) -> anyhow::Result<()> { let handshake = conn .handshake_data() .await? - .downcast::()?; + .downcast::() + .unwrap(); let alpn = handshake.protocol.context("missing ALPN")?; log::debug!( - "received QUIC handshake: ip={} alpn={} server={}", + "received QUIC handshake: ip={} alpn={:?} server={:?}", conn.remote_address(), alpn, handshake.server_name @@ -34,46 +35,65 @@ impl Connection { let conn = conn.await.context("failed to establish QUIC connection")?; log::debug!( - "established QUIC connection: id={} ip={} alpn={} server={}", + "established QUIC connection: id={} ip={} alpn={:?} server={:?}", conn.stable_id(), conn.remote_address(), alpn, handshake.server_name ); - let session = if alpn.as_slice() == webtransport_quinn::ALPN { - // Wait for the CONNECT request. - let request = webtransport_quinn::accept(conn) - .await - .context("failed to receive WebTransport request")?; + if alpn.as_slice() == webtransport_quinn::ALPN { + self.serve_webtransport(conn).await?; + } else { + self.serve_quic(conn).await?; + } - // Accept the CONNECT request. - let session = request - .ok() - .await - .context("failed to respond to WebTransport request")?; + Ok(()) + } - let path = request.url().path().trim_matches('/').to_string(); + async fn serve_webtransport(self, conn: quinn::Connection) -> anyhow::Result<()> { + // Wait for the CONNECT request. + let request = webtransport_quinn::accept(conn) + .await + .context("failed to receive WebTransport request")?; - log::debug!("received WebTransport CONNECT: path={}", path); - session - } else if alpn.as_slice() == moq_transport::setup::ALPN { - let session: quictransport_quinn::Session = conn.into(); - - session - } else { - anyhow::anyhow!("unsupported ALPN: alpn={:?}", alpn); - }; + // Accept the CONNECT request. + let session = request + .ok() + .await + .context("failed to respond to WebTransport request")?; let (session, publisher, subscriber) = moq_transport::Session::accept(session).await?; let mut tasks = FuturesUnordered::new(); + tasks.push(session.run().boxed()); + + if let Some(publisher) = publisher { + tasks.push(Self::serve_publisher(publisher, self.origin.clone()).boxed()); + } + + if let Some(subscriber) = subscriber { + tasks.push(Self::serve_subscriber(subscriber, self.origin).boxed()); + } + + // Return the first error + tasks.next().await.unwrap()?; + + Ok(()) + } + + async fn serve_quic(self, conn: quinn::Connection) -> anyhow::Result<()> { + let session: quictransport_quinn::Session = conn.into(); + let (session, publisher, subscriber) = moq_transport::Session::accept(session).await?; + + let mut tasks = FuturesUnordered::new(); tasks.push(session.run().boxed()); if let Some(publisher) = publisher { tasks.push(Self::serve_publisher(publisher, self.origin.clone()).boxed()); } + if let Some(subscriber) = subscriber { tasks.push(Self::serve_subscriber(subscriber, self.origin).boxed()); } @@ -100,7 +120,9 @@ impl Connection { res = publisher.subscribed() => { let subscribe = res?; log::info!("serving subscribe: namespace={} name={}", subscribe.namespace(), subscribe.name()); - tasks.push(origin.subscribe(subscribe).boxed()); + + let track = origin.subscribe(subscribe.namespace(), subscribe.name())?; + tasks.push(subscribe.serve(track).boxed()); } }; } @@ -122,9 +144,21 @@ impl Connection { res = subscriber.announced() => { let announce = res?; log::info!("serving announce: namespace={}", announce.namespace()); - tasks.push(origin.announce(announce, subscriber.clone())); + + let publisher = origin.announce(announce.namespace())?; + tasks.push(Self::serve_announce(subscriber.clone(), publisher)); } }; } } + + async fn serve_announce( + mut subscriber: Subscriber, + mut publisher: OriginPublisher, + ) -> Result<(), SessionError> { + loop { + let track = publisher.requested().await?; + subscriber.subscribe(track)?; + } + } } diff --git a/moq-relay/src/origin.rs b/moq-relay/src/origin.rs index 51c0df06..93b626c2 100644 --- a/moq-relay/src/origin.rs +++ b/moq-relay/src/origin.rs @@ -4,8 +4,11 @@ use std::{ sync::{Arc, Mutex}, }; -use moq_transport::serve::ServeError; -use moq_transport::session; +use std::collections::VecDeque; + +use moq_transport::serve::{self, ServeError, TrackPublisher, TrackSubscriber}; +use moq_transport::session::SessionError; +use moq_transport::util::Watch; use url::Url; #[derive(Clone)] @@ -21,7 +24,7 @@ pub struct Origin { _node: Option, // A map of active broadcasts by namespace. - local: Arc>>, + origins: Arc>>, // A QUIC endpoint we'll use to fetch from other origins. _quic: quinn::Endpoint, @@ -32,67 +35,59 @@ impl Origin { Self { _api, _node, - local: Default::default(), + origins: Default::default(), _quic, } } - pub async fn announce( - &self, - mut announce: session::Announced, - subscriber: session::Subscriber, - ) -> anyhow::Result<()> { - match self.local.lock().unwrap().entry(announce.namespace().to_string()) { - hash_map::Entry::Vacant(entry) => entry.insert(subscriber), + pub fn announce(&self, namespace: &str) -> Result { + let mut origins = self.origins.lock().unwrap(); + let entry = match origins.entry(namespace.to_string()) { + hash_map::Entry::Vacant(entry) => entry, hash_map::Entry::Occupied(_) => return Err(ServeError::Duplicate.into()), }; - announce.accept().ok(); - - let err = announce.closed().await; - self.local.lock().unwrap().remove(announce.namespace()); - err?; - - Ok(()) + let (publisher, subscriber) = self.produce(namespace); + entry.insert(subscriber); - /* - // Create a publisher that constantly updates itself as the origin in moq-api. - // It holds a reference to the subscriber to prevent dropping early. - let mut publisher = Publisher { - broadcast: publisher, - subscriber, - api: None, - }; - - // Insert the publisher into the database. - if let Some(api) = self.api.as_mut() { - // Make a URL for the broadcast. - let url = self.node.as_ref().ok_or(RelayError::MissingNode)?.clone().join(id)?; - let origin = moq_api::Origin { url }; - api.set_origin(id, &origin).await?; + Ok(publisher) + } - // Refresh every 5 minutes - publisher.api = Some((api.clone(), origin)); - } + /* + // Create a publisher that constantly updates itself as the origin in moq-api. + // It holds a reference to the subscriber to prevent dropping early. + let mut publisher = Publisher { + broadcast: publisher, + subscriber, + api: None, + }; + + // Insert the publisher into the database. + if let Some(api) = self.api.as_mut() { + // Make a URL for the broadcast. + let url = self.node.as_ref().ok_or(RelayError::MissingNode)?.clone().join(id)?; + let origin = moq_api::Origin { url }; + api.set_origin(id, &origin).await?; + + // Refresh every 5 minutes + publisher.api = Some((api.clone(), origin)); + } - Ok(()) - */ - } + Ok(()) + */ - pub async fn subscribe(&self, subscribe: session::Subscribed) -> anyhow::Result<()> { - let mut subscriber = self - .local + pub fn subscribe(&self, namespace: &str, name: &str) -> Result { + let mut origin = self + .origins .lock() .unwrap() - .get(subscribe.namespace()) + .get(namespace) .cloned() .ok_or(ServeError::NotFound)?; - let upstream = subscriber.subscribe(subscribe.namespace(), subscribe.name(), Default::default())?; - subscribe.serve(upstream.track()).await?; - - Ok(()) + let track = origin.request_track(name)?; + return Ok(track); /* let mut routes = self.local.lock().unwrap(); @@ -152,6 +147,176 @@ impl Origin { Ok(()) } */ + + /// Create a new broadcast. + fn produce(&self, namespace: &str) -> (OriginPublisher, OriginSubscriber) { + let state = Watch::new(State::new(namespace)); + + let publisher = OriginPublisher::new(state.clone()); + let subscriber = OriginSubscriber::new(state); + + (publisher, subscriber) + } +} + +#[derive(Debug)] +struct State { + namespace: String, + tracks: HashMap, + requested: VecDeque, + closed: Result<(), ServeError>, +} + +impl State { + pub fn new(namespace: &str) -> Self { + Self { + namespace: namespace.to_string(), + tracks: HashMap::new(), + requested: VecDeque::new(), + closed: Ok(()), + } + } + + pub fn get_track(&self, name: &str) -> Result, ServeError> { + // Insert the track into our Map so we deduplicate future requests. + if let Some(track) = self.tracks.get(name) { + return Ok(Some(track.clone())); + } + + self.closed.clone()?; + return Ok(None); + } + + pub fn request_track(&mut self, name: &str) -> Result { + // Insert the track into our Map so we deduplicate future requests. + let entry = match self.tracks.entry(name.to_string()) { + hash_map::Entry::Vacant(entry) => entry, + hash_map::Entry::Occupied(entry) => return Ok(entry.get().clone()), + }; + + self.closed.clone()?; + + // Create a new track. + let (publisher, subscriber) = serve::Track { + namespace: self.namespace.clone(), + name: name.to_string(), + } + .produce(); + + // Deduplicate with others + // TODO This should be weak + entry.insert(subscriber.clone()); + + // Send the track to the Publisher to handle. + self.requested.push_back(publisher); + + Ok(subscriber) + } + + pub fn close(&mut self, err: ServeError) -> Result<(), ServeError> { + self.closed.clone()?; + self.closed = Err(err); + Ok(()) + } +} + +impl Drop for State { + fn drop(&mut self) { + for mut track in self.requested.drain(..) { + track.close(ServeError::NotFound).ok(); + } + + self.closed = Err(ServeError::Done); + } +} + +/// Publish new tracks for a broadcast by name. +pub struct OriginPublisher { + state: Watch, +} + +impl OriginPublisher { + fn new(state: Watch) -> Self { + Self { state } + } + + /// Block until the next track requested by a subscriber. + pub async fn requested(&mut self) -> Result { + loop { + let notify = { + let state = self.state.lock(); + if !state.requested.is_empty() { + return Ok(state.into_mut().requested.pop_front().unwrap()); + } + + state.closed.clone()?; + state.changed() + }; + + notify.await; + } + } + + /// Close the broadcast with an error. + pub fn close(self, err: ServeError) -> Result<(), ServeError> { + self.state.lock_mut().close(err) + } +} + +/// Subscribe to a broadcast by requesting tracks. +/// +/// This can be cloned to create handles. +#[derive(Clone)] +pub struct OriginSubscriber { + state: Watch, + _dropped: Arc, +} + +impl OriginSubscriber { + fn new(state: Watch) -> Self { + let _dropped = Arc::new(Dropped::new(state.clone())); + Self { state, _dropped } + } + + pub fn get_track(&self, name: &str) -> Result, ServeError> { + self.state.lock_mut().get_track(name) + } + + pub fn request_track(&mut self, name: &str) -> Result { + self.state.lock_mut().request_track(name) + } + + /// Wait until if the broadcast is closed, either because the publisher was dropped or called [Publisher::close]. + pub async fn closed(&self) -> ServeError { + loop { + let notify = { + let state = self.state.lock(); + if let Some(err) = state.closed.as_ref().err() { + return err.clone(); + } + + state.changed() + }; + + notify.await; + } + } +} + +struct Dropped { + state: Watch, +} + +impl Dropped { + fn new(state: Watch) -> Self { + Self { state } + } +} + +impl Drop for Dropped { + fn drop(&mut self) { + self.state.lock_mut().close(ServeError::Done).ok(); + } } /* diff --git a/moq-transport/Cargo.toml b/moq-transport/Cargo.toml index a1d1554a..d5c2fff4 100644 --- a/moq-transport/Cargo.toml +++ b/moq-transport/Cargo.toml @@ -20,7 +20,7 @@ thiserror = "1" tokio = { version = "1", features = ["macros", "io-util", "sync"] } log = "0.4" -webtransport-generic = { path = "../../webtransport-rs/webtransport-generic" } +webtransport-generic = { path = "../../webtransport-rs/webtransport-generic", version = "0.8" } paste = "1" futures = "0.3" diff --git a/moq-transport/src/serve/track.rs b/moq-transport/src/serve/track.rs index 566a0566..a7425f76 100644 --- a/moq-transport/src/serve/track.rs +++ b/moq-transport/src/serve/track.rs @@ -28,6 +28,13 @@ pub struct Track { } impl Track { + pub fn new(namespace: &str, name: &str) -> Self { + Self { + namespace: namespace.to_string(), + name: name.to_string(), + } + } + pub fn produce(self) -> (TrackPublisher, TrackSubscriber) { let state = Watch::new(State::default()); let info = Arc::new(self); diff --git a/moq-transport/src/session/subscribe.rs b/moq-transport/src/session/subscribe.rs index 418f655e..667974ed 100644 --- a/moq-transport/src/session/subscribe.rs +++ b/moq-transport/src/session/subscribe.rs @@ -1,129 +1,48 @@ use std::sync::{Arc, Mutex}; -use tokio::io::AsyncRead; - use crate::{ coding::Reader, - data, - message::{self, SubscribePair}, + data, message, serve::{self, ServeError}, - util::Watch, }; use super::{SessionError, Subscriber}; +#[derive(Clone)] pub struct Subscribe { session: Subscriber, id: u64, - track: serve::TrackSubscriber, - state: Watch, + track: Arc>, } impl Subscribe { - pub(super) fn new(session: Subscriber, msg: message::Subscribe) -> (Subscribe, SubscribeRecv) { - let state = Watch::new(State::default()); - - let (publisher, subscriber) = serve::Track { - namespace: msg.track_namespace, - name: msg.track_name.clone(), - } - .produce(); - - // TODO apply start/end range - - let subscriber = Subscribe { - session, - id: msg.id, - track: subscriber, - state: state.clone(), - }; - - let publisher = SubscribeRecv::new(state, publisher); - - (subscriber, publisher) - } - - // Waits until an OK message is received. - pub async fn ok(&self) -> Result<(), ServeError> { - loop { - let notify = { - let state = self.state.lock(); - if state.ok.is_some() { - return Ok(()); - } - state.changed() - }; - - tokio::select! { - _ = notify => {}, - err = self.track.closed() => return err, - }; - } - } - - // Returns the maximum known group/object sequences. - pub fn max(&self) -> Option<(u64, u64)> { - let ok = self.state.lock().ok.as_ref().and_then(|ok| ok.latest); - let cache = self.track.latest(); - - // Return the max of both the OK message and the cache. - match ok { - Some(ok) => match cache { - Some(cache) => Some(cache.max(ok)), - None => Some(ok), - }, - None => cache, - } - } - - pub fn track(&self) -> serve::TrackSubscriber { - self.track.clone() - } -} - -impl Drop for Subscribe { - fn drop(&mut self) { - let msg = message::Unsubscribe { id: self.id }; - self.session.send_message(msg).ok(); - - self.session.drop_subscribe(self.id); - } -} - -#[derive(Clone)] -pub(super) struct SubscribeRecv { - publisher: Arc>, - state: Watch, -} - -impl SubscribeRecv { - fn new(state: Watch, publisher: serve::TrackPublisher) -> Self { + pub(super) fn new(session: Subscriber, id: u64, track: serve::TrackPublisher) -> Self { Self { - publisher: Arc::new(Mutex::new(publisher)), - state, + session, + id, + track: Arc::new(Mutex::new(track)), } } - pub fn recv_ok(&mut self, msg: message::SubscribeOk) -> Result<(), ServeError> { - let mut state = self.state.lock_mut(); - state.ok = Some(msg); + pub fn recv_ok(&mut self, _msg: message::SubscribeOk) -> Result<(), ServeError> { + // TODO Ok(()) } pub fn recv_error(&mut self, code: u64) -> Result<(), ServeError> { - self.publisher.lock().unwrap().close(ServeError::Closed(code))?; + self.track.lock().unwrap().close(ServeError::Closed(code))?; Ok(()) } pub fn recv_done(&mut self, code: u64) -> Result<(), ServeError> { - self.publisher.lock().unwrap().close(ServeError::Closed(code))?; + self.track.lock().unwrap().close(ServeError::Closed(code))?; Ok(()) } - pub async fn recv_stream( + pub async fn recv_stream( &mut self, header: data::Header, - reader: Reader, + reader: Reader, ) -> Result<(), SessionError> { match header { data::Header::Track(track) => self.recv_track(track, reader).await, @@ -132,14 +51,14 @@ impl SubscribeRecv { } } - async fn recv_track( + async fn recv_track( &mut self, header: data::TrackHeader, - mut reader: Reader, + mut reader: Reader, ) -> Result<(), SessionError> { log::trace!("received track: {:?}", header); - let mut track = self.publisher.lock().unwrap().create_stream(header.send_order)?; + let mut track = self.track.lock().unwrap().create_stream(header.send_order)?; while !reader.done().await? { let chunk: data::TrackObject = reader.decode().await?; @@ -167,14 +86,14 @@ impl SubscribeRecv { Ok(()) } - async fn recv_group( + async fn recv_group( &mut self, header: data::GroupHeader, - mut reader: Reader, + mut reader: Reader, ) -> Result<(), SessionError> { log::trace!("received group: {:?}", header); - let mut group = self.publisher.lock().unwrap().create_group(serve::Group { + let mut group = self.track.lock().unwrap().create_group(serve::Group { id: header.group_id, send_order: header.send_order, })?; @@ -197,10 +116,10 @@ impl SubscribeRecv { Ok(()) } - async fn recv_object( + async fn recv_object( &mut self, header: data::ObjectHeader, - mut reader: Reader, + mut reader: Reader, ) -> Result<(), SessionError> { log::trace!("received object: {:?}", header); @@ -211,7 +130,7 @@ impl SubscribeRecv { chunks.push(data); } - let mut object = self.publisher.lock().unwrap().create_object(serve::ObjectHeader { + let mut object = self.track.lock().unwrap().create_object(serve::ObjectHeader { group_id: header.group_id, object_id: header.object_id, send_order: header.send_order, @@ -230,7 +149,7 @@ impl SubscribeRecv { pub fn recv_datagram(&self, datagram: data::Datagram) -> Result<(), SessionError> { log::trace!("received datagram: {:?}", datagram); - self.publisher.lock().unwrap().write_datagram(serve::Datagram { + self.track.lock().unwrap().write_datagram(serve::Datagram { group_id: datagram.group_id, object_id: datagram.object_id, payload: datagram.payload, @@ -241,13 +160,10 @@ impl SubscribeRecv { } } -#[derive(Default)] -struct State { - ok: Option, -} - -#[derive(Default)] -pub struct SubscribeOptions { - pub start: SubscribePair, - pub end: SubscribePair, +impl Drop for Subscribe { + fn drop(&mut self) { + let msg = message::Unsubscribe { id: self.id }; + self.session.send_message(msg).ok(); + self.session.drop_subscribe(self.id); + } } diff --git a/moq-transport/src/session/subscriber.rs b/moq-transport/src/session/subscriber.rs index 5d5568b0..d150c33c 100644 --- a/moq-transport/src/session/subscriber.rs +++ b/moq-transport/src/session/subscriber.rs @@ -6,11 +6,11 @@ use std::{ use crate::{ coding::{Decode, Reader}, - data, message, setup, + data, message, serve, setup, util::Queue, }; -use super::{Announced, AnnouncedRecv, Session, SessionError, Subscribe, SubscribeOptions, SubscribeRecv}; +use super::{Announced, AnnouncedRecv, Session, SessionError, Subscribe}; // TODO remove Clone. #[derive(Clone)] @@ -18,7 +18,7 @@ pub struct Subscriber { announced: Arc>>>, announced_queue: Queue, SessionError>, - subscribes: Arc>>, + subscribes: Arc>>>, subscribe_next: Arc, outgoing: Queue, @@ -49,30 +49,26 @@ impl Subscriber { self.announced_queue.pop().await } - pub fn subscribe( - &mut self, - namespace: &str, - name: &str, - options: SubscribeOptions, - ) -> Result, SessionError> { + pub fn subscribe(&mut self, track: serve::TrackPublisher) -> Result<(), SessionError> { let id = self.subscribe_next.fetch_add(1, atomic::Ordering::Relaxed); let msg = message::Subscribe { id, track_alias: id, - track_namespace: namespace.to_string(), - track_name: name.to_string(), - start: options.start, - end: options.end, + track_namespace: track.namespace.to_string(), + track_name: track.name.to_string(), + // TODO add these to the publisher. + start: Default::default(), + end: Default::default(), params: Default::default(), }; self.send_message(msg.clone())?; - let (publisher, subscribe) = Subscribe::new(self.clone(), msg); - self.subscribes.lock().unwrap().insert(id, subscribe); + let publisher = Subscribe::new(self.clone(), msg.id, track); + self.subscribes.lock().unwrap().insert(id, publisher); - Ok(publisher) + Ok(()) } pub(super) fn send_message>(&mut self, msg: M) -> Result<(), SessionError> { From 877dabd56a5e21b1044677835f992e220ebdfa4b Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 19:50:10 -0700 Subject: [PATCH 4/7] panic down --- moq-transport/src/coding/reader.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/moq-transport/src/coding/reader.rs b/moq-transport/src/coding/reader.rs index 63d2df55..cfea80c3 100644 --- a/moq-transport/src/coding/reader.rs +++ b/moq-transport/src/coding/reader.rs @@ -36,7 +36,7 @@ impl Reader { // Append to the buffer while remain > 0 { - remain -= self.stream.read_buf(&mut self.buffer).await?; + remain = remain.saturating_sub(self.stream.read_buf(&mut self.buffer).await?); } } } From 60f4175b07486eaec08d44d27a3b5ef6d3c9636b Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 21:26:04 -0700 Subject: [PATCH 5/7] Seems to work. --- moq-relay/src/connection.rs | 9 +++++++-- moq-relay/src/origin.rs | 9 ++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/moq-relay/src/connection.rs b/moq-relay/src/connection.rs index e7ab5f14..9e076df0 100644 --- a/moq-relay/src/connection.rs +++ b/moq-relay/src/connection.rs @@ -1,7 +1,7 @@ use anyhow::Context; use futures::{stream::FuturesUnordered, FutureExt, StreamExt}; -use moq_transport::session::{Publisher, SessionError, Subscriber}; +use moq_transport::session::{Announced, Publisher, SessionError, Subscriber}; use crate::{Origin, OriginPublisher}; @@ -146,7 +146,7 @@ impl Connection { log::info!("serving announce: namespace={}", announce.namespace()); let publisher = origin.announce(announce.namespace())?; - tasks.push(Self::serve_announce(subscriber.clone(), publisher)); + tasks.push(Self::serve_announce(subscriber.clone(), publisher, announce)); } }; } @@ -155,7 +155,12 @@ impl Connection { async fn serve_announce( mut subscriber: Subscriber, mut publisher: OriginPublisher, + mut announce: Announced, ) -> Result<(), SessionError> { + // Send ANNOUNCE_OK + // We sent ANNOUNCE_CANCEL when the scope drops + announce.accept()?; + loop { let track = publisher.requested().await?; subscriber.subscribe(track)?; diff --git a/moq-relay/src/origin.rs b/moq-relay/src/origin.rs index 93b626c2..05ee6bae 100644 --- a/moq-relay/src/origin.rs +++ b/moq-relay/src/origin.rs @@ -225,8 +225,6 @@ impl Drop for State { for mut track in self.requested.drain(..) { track.close(ServeError::NotFound).ok(); } - - self.closed = Err(ServeError::Done); } } @@ -256,10 +254,11 @@ impl OriginPublisher { notify.await; } } +} - /// Close the broadcast with an error. - pub fn close(self, err: ServeError) -> Result<(), ServeError> { - self.state.lock_mut().close(err) +impl Drop for OriginPublisher { + fn drop(&mut self) { + self.state.lock_mut().close(ServeError::Done).ok(); } } From 86abd0a7a9ca55972089e14ebcc9f5a473d9712a Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 22:11:19 -0700 Subject: [PATCH 6/7] Add client support for native QUIC. --- Cargo.lock | 23 ++++++++++--- dev/clock | 3 +- dev/pub | 3 +- moq-clock/Cargo.toml | 7 ++-- moq-clock/src/cli.rs | 9 +---- moq-clock/src/main.rs | 38 +++++++++++++++----- moq-pub/Cargo.toml | 6 ++-- moq-pub/src/cli.rs | 9 +---- moq-pub/src/main.rs | 46 +++++++++++++++++++------ moq-pub/src/media.rs | 2 +- moq-relay/Cargo.toml | 8 ++--- moq-transport/Cargo.toml | 4 +-- moq-transport/src/message/publisher.rs | 12 ++++++- moq-transport/src/message/subscriber.rs | 12 ++++++- 14 files changed, 125 insertions(+), 57 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cc31efe2..be101d65 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -858,7 +858,7 @@ dependencies = [ [[package]] name = "moq-clock" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "chrono", @@ -867,6 +867,7 @@ dependencies = [ "env_logger", "log", "moq-transport", + "quictransport-quinn", "quinn", "rustls", "rustls-native-certs", @@ -875,12 +876,13 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "webtransport-generic", "webtransport-quinn", ] [[package]] name = "moq-pub" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "clap", @@ -889,6 +891,7 @@ dependencies = [ "log", "moq-transport", "mp4", + "quictransport-quinn", "quinn", "rfc6381-codec", "rustls", @@ -899,12 +902,13 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "webtransport-generic", "webtransport-quinn", ] [[package]] name = "moq-relay" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anyhow", "axum", @@ -935,7 +939,7 @@ dependencies = [ [[package]] name = "moq-transport" -version = "0.3.0" +version = "0.4.0" dependencies = [ "anyhow", "bytes", @@ -1224,12 +1228,15 @@ dependencies = [ [[package]] name = "quictransport-quinn" version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d65e68ccd0d35f48bbe9723a2924e6829b819cd165b19d7b2914f85ede01398" dependencies = [ "bytes", "quinn", + "thiserror", "tokio", + "url", "webtransport-generic", - "webtransport-proto", ] [[package]] @@ -2195,6 +2202,8 @@ checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc" [[package]] name = "webtransport-generic" version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc1fd0d5c7e24e485aa58040fba18d6a4204d4354eca19d34b14540ecd9147b8" dependencies = [ "bytes", "tokio", @@ -2203,6 +2212,8 @@ dependencies = [ [[package]] name = "webtransport-proto" version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebeada5037d6302980ae2e0ab8d840e329c1697c612c6c077172de2b7631a276" dependencies = [ "bytes", "http", @@ -2213,6 +2224,8 @@ dependencies = [ [[package]] name = "webtransport-quinn" version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27b0ad39e557756d066277901c3024e586aa3e026e4ee9b377f2d45d782e39ff" dependencies = [ "bytes", "futures", diff --git a/dev/clock b/dev/clock index fead5117..4e4e9da0 100755 --- a/dev/clock +++ b/dev/clock @@ -12,8 +12,9 @@ HOST="${HOST:-localhost}" PORT="${PORT:-4443}" ADDR="${ADDR:-$HOST:$PORT}" NAME="${NAME:-clock}" +SCHEME="${SCHEME:-https}" # Combine the host and name into a URL. -URL="${URL:-"https://$ADDR"}" +URL="${URL:-"$SCHEME://$ADDR"}" cargo run --bin moq-clock -- "$URL" --namespace "$NAME" "$@" diff --git a/dev/pub b/dev/pub index 435c29cf..f98b9c3f 100755 --- a/dev/pub +++ b/dev/pub @@ -11,6 +11,7 @@ export RUST_LOG="${RUST_LOG:-debug}" HOST="${HOST:-localhost}" PORT="${PORT:-4443}" ADDR="${ADDR:-$HOST:$PORT}" +SCHEME="${SCHEME:-https}" # Generate a random 16 character name by default. #NAME="${NAME:-$(head /dev/urandom | LC_ALL=C tr -dc 'a-zA-Z0-9' | head -c 16)}" @@ -20,7 +21,7 @@ ADDR="${ADDR:-$HOST:$PORT}" NAME="${NAME:-bbb}" # Combine the host into a URL. -URL="${URL:-"https://$ADDR"}" +URL="${URL:-"$SCHEME://$ADDR"}" # Default to a source video INPUT="${INPUT:-dev/source.mp4}" diff --git a/moq-clock/Cargo.toml b/moq-clock/Cargo.toml index 07635682..6ad61740 100644 --- a/moq-clock/Cargo.toml +++ b/moq-clock/Cargo.toml @@ -5,7 +5,7 @@ authors = ["Luke Curley"] repository = "https://github.com/kixelated/moq-rs" license = "MIT OR Apache-2.0" -version = "0.1.0" +version = "0.2.0" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -18,8 +18,9 @@ moq-transport = { path = "../moq-transport" } # QUIC quinn = "0.10" -#webtransport-quinn = "0.7" -webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn", version = "0.8" } +webtransport-quinn = "0.8" +webtransport-generic = "0.8" +quictransport-quinn = "0.8" url = "2" # Crypto diff --git a/moq-clock/src/cli.rs b/moq-clock/src/cli.rs index b19b33d8..ff411f34 100644 --- a/moq-clock/src/cli.rs +++ b/moq-clock/src/cli.rs @@ -39,12 +39,5 @@ pub struct Config { } fn moq_url(s: &str) -> Result { - let url = Url::try_from(s).map_err(|e| e.to_string())?; - - // Make sure the scheme is moq - if url.scheme() != "https" { - return Err("url scheme must be https:// for WebTransport".to_string()); - } - - Ok(url) + Url::try_from(s).map_err(|e| e.to_string()) } diff --git a/moq-clock/src/main.rs b/moq-clock/src/main.rs index 13a4194b..85e9a605 100644 --- a/moq-clock/src/main.rs +++ b/moq-clock/src/main.rs @@ -57,20 +57,40 @@ async fn main() -> anyhow::Result<()> { tls_config.dangerous().set_certificate_verifier(Arc::new(noop)); } - tls_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec()]; // this one is important + log::info!("connecting to server: url={}", config.url); - let arc_tls_config = std::sync::Arc::new(tls_config); - let quinn_client_config = quinn::ClientConfig::new(arc_tls_config); + match config.url.scheme() { + "https" => { + tls_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec()]; // this one is important + let client_config = quinn::ClientConfig::new(Arc::new(tls_config)); - let mut endpoint = quinn::Endpoint::client(config.bind)?; - endpoint.set_default_client_config(quinn_client_config); + let mut endpoint = quinn::Endpoint::client(config.bind)?; + endpoint.set_default_client_config(client_config); - log::info!("connecting to server: url={}", config.url); + let session = webtransport_quinn::connect(&endpoint, &config.url) + .await + .context("failed to create WebTransport session")?; + + run(session, config).await + } + "moqt" => { + tls_config.alpn_protocols = vec![moq_transport::setup::ALPN.to_vec()]; // this one is important + let client_config = quinn::ClientConfig::new(Arc::new(tls_config)); + + let mut endpoint = quinn::Endpoint::client(config.bind)?; + endpoint.set_default_client_config(client_config); - let session = webtransport_quinn::connect(&endpoint, &config.url) - .await - .context("failed to create WebTransport session")?; + let session = quictransport_quinn::connect(&endpoint, &config.url) + .await + .context("failed to create QUIC Transport session")?; + + run(session, config).await + } + _ => anyhow::bail!("unsupported scheme: {}", config.url.scheme()), + } +} +async fn run(session: S, config: cli::Config) -> anyhow::Result<()> { if config.publish { let (session, publisher) = moq_transport::Publisher::connect(session) .await diff --git a/moq-pub/Cargo.toml b/moq-pub/Cargo.toml index b5005f8a..0447f752 100644 --- a/moq-pub/Cargo.toml +++ b/moq-pub/Cargo.toml @@ -5,7 +5,7 @@ authors = ["Mike English", "Luke Curley"] repository = "https://github.com/kixelated/moq-rs" license = "MIT OR Apache-2.0" -version = "0.1.0" +version = "0.2.0" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -18,7 +18,9 @@ moq-transport = { path = "../moq-transport" } # QUIC quinn = "0.10" -webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn", version = "0.8" } +webtransport-quinn = "0.8" +quictransport-quinn = "0.8" +webtransport-generic = "0.8" url = "2" # Crypto diff --git a/moq-pub/src/cli.rs b/moq-pub/src/cli.rs index ff02891d..a39ed412 100644 --- a/moq-pub/src/cli.rs +++ b/moq-pub/src/cli.rs @@ -41,12 +41,5 @@ pub struct Config { } fn moq_url(s: &str) -> Result { - let url = Url::try_from(s).map_err(|e| e.to_string())?; - - // Make sure the scheme is moq - if url.scheme() != "https" { - return Err("url scheme must be https:// for WebTransport".to_string()); - } - - Ok(url) + Url::try_from(s).map_err(|e| e.to_string()) } diff --git a/moq-pub/src/main.rs b/moq-pub/src/main.rs index a9ddc881..bb42bbf8 100644 --- a/moq-pub/src/main.rs +++ b/moq-pub/src/main.rs @@ -8,6 +8,7 @@ use cli::*; use moq_pub::media::Media; use moq_transport::serve; +use tokio::io::AsyncRead; // TODO: clap complete @@ -25,7 +26,7 @@ async fn main() -> anyhow::Result<()> { let input = tokio::io::stdin(); let (publisher, broadcast) = serve::Broadcast::new(&config.name).produce(); - let mut media = Media::new(input, publisher).await?; + let media = Media::new(input, publisher).await?; // Create a list of acceptable root certificates. let mut roots = rustls::RootCertStore::empty(); @@ -62,25 +63,48 @@ async fn main() -> anyhow::Result<()> { tls_config.dangerous().set_certificate_verifier(Arc::new(noop)); } - tls_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec()]; // this one is important + log::info!("connecting to relay: url={}", config.url); - let arc_tls_config = std::sync::Arc::new(tls_config); - let quinn_client_config = quinn::ClientConfig::new(arc_tls_config); + match config.url.scheme() { + "https" => { + tls_config.alpn_protocols = vec![webtransport_quinn::ALPN.to_vec()]; + let client_config = quinn::ClientConfig::new(Arc::new(tls_config)); - let mut endpoint = quinn::Endpoint::client(config.bind)?; - endpoint.set_default_client_config(quinn_client_config); + let mut endpoint = quinn::Endpoint::client(config.bind)?; + endpoint.set_default_client_config(client_config); - log::info!("connecting to relay: url={}", config.url); + let session = webtransport_quinn::connect(&endpoint, &config.url) + .await + .context("failed to create WebTransport session")?; - let session = webtransport_quinn::connect(&endpoint, &config.url) - .await - .context("failed to create WebTransport session")?; + run(session, media, broadcast).await + } + "moqt" => { + tls_config.alpn_protocols = vec![moq_transport::setup::ALPN.to_vec()]; + let client_config = quinn::ClientConfig::new(Arc::new(tls_config)); + + let mut endpoint = quinn::Endpoint::client(config.bind)?; + endpoint.set_default_client_config(client_config); + + let session = quictransport_quinn::connect(&endpoint, &config.url) + .await + .context("failed to create QUIC Transport session")?; + + run(session, media, broadcast).await + } + _ => anyhow::bail!("url scheme must be 'https' or 'moqt'"), + } +} +async fn run( + session: T, + mut media: Media, + broadcast: serve::BroadcastSubscriber, +) -> anyhow::Result<()> { let (session, publisher) = moq_transport::Publisher::connect(session) .await .context("failed to create MoQ Transport publisher")?; - // TODO run a task that returns a 404 for all unknown subscriptions. tokio::select! { res = session.run() => res.context("session error")?, res = media.run() => res.context("media error")?, diff --git a/moq-pub/src/media.rs b/moq-pub/src/media.rs index 3a5ab0b9..eadf295d 100644 --- a/moq-pub/src/media.rs +++ b/moq-pub/src/media.rs @@ -14,7 +14,7 @@ pub struct Media { input: I, } -impl Media { +impl Media { pub async fn new(mut input: I, mut broadcast: BroadcastPublisher) -> anyhow::Result { let ftyp = read_atom(&mut input).await?; anyhow::ensure!(&ftyp[4..8] == b"ftyp", "expected ftyp atom"); diff --git a/moq-relay/Cargo.toml b/moq-relay/Cargo.toml index fb6fc6ad..992e873b 100644 --- a/moq-relay/Cargo.toml +++ b/moq-relay/Cargo.toml @@ -5,7 +5,7 @@ authors = ["Luke Curley"] repository = "https://github.com/kixelated/moq-rs" license = "MIT OR Apache-2.0" -version = "0.1.0" +version = "0.2.0" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -17,9 +17,9 @@ moq-api = { path = "../moq-api" } # QUIC quinn = "0.10" -quictransport-quinn = { path = "../../webtransport-rs/quictransport-quinn", version = "0.8" } -webtransport-quinn = { path = "../../webtransport-rs/webtransport-quinn", version = "0.8" } -webtransport-generic = { path = "../../webtransport-rs/webtransport-generic", version = "0.8" } +quictransport-quinn = "0.8" +webtransport-quinn = "0.8" +webtransport-generic = "0.8" url = "2" # Crypto diff --git a/moq-transport/Cargo.toml b/moq-transport/Cargo.toml index d5c2fff4..2c323eae 100644 --- a/moq-transport/Cargo.toml +++ b/moq-transport/Cargo.toml @@ -5,7 +5,7 @@ authors = ["Luke Curley"] repository = "https://github.com/kixelated/moq-rs" license = "MIT OR Apache-2.0" -version = "0.3.0" +version = "0.4.0" edition = "2021" keywords = ["quic", "http3", "webtransport", "media", "live"] @@ -20,7 +20,7 @@ thiserror = "1" tokio = { version = "1", features = ["macros", "io-util", "sync"] } log = "0.4" -webtransport-generic = { path = "../../webtransport-rs/webtransport-generic", version = "0.8" } +webtransport-generic = "0.8" paste = "1" futures = "0.3" diff --git a/moq-transport/src/message/publisher.rs b/moq-transport/src/message/publisher.rs index b7ca9e25..24ed96e1 100644 --- a/moq-transport/src/message/publisher.rs +++ b/moq-transport/src/message/publisher.rs @@ -1,8 +1,9 @@ use crate::message::{self, Message}; +use std::fmt; macro_rules! publisher_msgs { {$($name:ident,)*} => { - #[derive(Clone, Debug)] + #[derive(Clone)] pub enum Publisher { $($name(message::$name)),* } @@ -31,6 +32,15 @@ macro_rules! publisher_msgs { } } } + + impl fmt::Debug for Publisher { + // Delegate to the message formatter + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + $(Self::$name(ref m) => m.fmt(f),)* + } + } + } } } diff --git a/moq-transport/src/message/subscriber.rs b/moq-transport/src/message/subscriber.rs index 01cd141b..914ba3c9 100644 --- a/moq-transport/src/message/subscriber.rs +++ b/moq-transport/src/message/subscriber.rs @@ -1,8 +1,9 @@ use crate::message::{self, Message}; +use std::fmt; macro_rules! subscriber_msgs { {$($name:ident,)*} => { - #[derive(Clone, Debug)] + #[derive(Clone)] pub enum Subscriber { $($name(message::$name)),* } @@ -31,6 +32,15 @@ macro_rules! subscriber_msgs { } } } + + impl fmt::Debug for Subscriber { + // Delegate to the message formatter + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + $(Self::$name(ref m) => m.fmt(f),)* + } + } + } } } From caa9d294a052859df7b18ac33b39e36046a0747e Mon Sep 17 00:00:00 2001 From: Luke Curley Date: Thu, 21 Mar 2024 22:14:46 -0700 Subject: [PATCH 7/7] Clippy --- moq-relay/src/origin.rs | 4 ++-- moq-transport/src/coding/reader.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/moq-relay/src/origin.rs b/moq-relay/src/origin.rs index 05ee6bae..37afab77 100644 --- a/moq-relay/src/origin.rs +++ b/moq-relay/src/origin.rs @@ -87,7 +87,7 @@ impl Origin { .ok_or(ServeError::NotFound)?; let track = origin.request_track(name)?; - return Ok(track); + Ok(track) /* let mut routes = self.local.lock().unwrap(); @@ -184,7 +184,7 @@ impl State { } self.closed.clone()?; - return Ok(None); + Ok(None) } pub fn request_track(&mut self, name: &str) -> Result { diff --git a/moq-transport/src/coding/reader.rs b/moq-transport/src/coding/reader.rs index cfea80c3..b84ee43e 100644 --- a/moq-transport/src/coding/reader.rs +++ b/moq-transport/src/coding/reader.rs @@ -31,7 +31,7 @@ impl Reader { return Ok(msg); } Err(DecodeError::More(remain)) => remain, // Try again with more data - Err(err) => return Err(err.into()), + Err(err) => return Err(err), }; // Append to the buffer