From 5e6edac7f4245d5ec6e8aabe2678725f1ecc5903 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Fri, 23 Aug 2024 23:56:16 -0700 Subject: [PATCH] wip: polish --- core/codegen/src/attribute/route/parse.rs | 1 + core/codegen/tests/route.rs | 2 +- core/lib/Cargo.toml | 1 - core/lib/fuzz/targets/collision-matching.rs | 15 +- core/lib/src/phase.rs | 6 +- core/lib/src/rocket.rs | 15 +- core/lib/src/route/route.rs | 8 +- core/lib/src/router/router.rs | 245 ++++++++++---------- core/lib/src/trace/traceable.rs | 2 +- core/lib/tests/form_method-issue-45.rs | 4 +- examples/tls/src/redirector.rs | 6 +- 11 files changed, 156 insertions(+), 149 deletions(-) diff --git a/core/codegen/src/attribute/route/parse.rs b/core/codegen/src/attribute/route/parse.rs index 8b9613c9e3..782f342096 100644 --- a/core/codegen/src/attribute/route/parse.rs +++ b/core/codegen/src/attribute/route/parse.rs @@ -42,6 +42,7 @@ pub struct Arguments { /// The parsed `#[route(..)]` attribute. #[derive(Debug, FromMeta)] pub struct Attribute { + #[meta(naked)] pub uri: RouteUri, pub method: Option>, pub data: Option>, diff --git a/core/codegen/tests/route.rs b/core/codegen/tests/route.rs index a46fca57eb..f9a1bf66e9 100644 --- a/core/codegen/tests/route.rs +++ b/core/codegen/tests/route.rs @@ -54,8 +54,8 @@ fn post1( } #[route( + "///name/?sky=blue&&", method = POST, - uri = "///name/?sky=blue&&", format = "json", data = "", rank = 138 diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index d907e183e1..10eb5f17bd 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -74,7 +74,6 @@ tokio-stream = { version = "0.1.6", features = ["signal", "time"] } cookie = { version = "0.18", features = ["percent-encode"] } futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" -rustc-hash = "1.1" # tracing tracing = { version = "0.1.40", default-features = false, features = ["std", "attributes"] } diff --git a/core/lib/fuzz/targets/collision-matching.rs b/core/lib/fuzz/targets/collision-matching.rs index 1fb035efa2..44ffd8afc6 100644 --- a/core/lib/fuzz/targets/collision-matching.rs +++ b/core/lib/fuzz/targets/collision-matching.rs @@ -16,7 +16,7 @@ struct ArbitraryRequestData<'a> { #[derive(Arbitrary)] struct ArbitraryRouteData<'a> { - method: ArbitraryMethod, + method: Option, uri: ArbitraryRouteUri<'a>, format: Option, } @@ -24,7 +24,7 @@ struct ArbitraryRouteData<'a> { impl std::fmt::Debug for ArbitraryRouteData<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ArbitraryRouteData") - .field("method", &self.method.0) + .field("method", &self.method.map(|v| v.0)) .field("base", &self.uri.0.base()) .field("unmounted", &self.uri.0.unmounted().to_string()) .field("uri", &self.uri.0.to_string()) @@ -59,12 +59,14 @@ impl<'c, 'a: 'c> ArbitraryRequestData<'a> { impl<'a> ArbitraryRouteData<'a> { fn into_route(self) -> Route { - let mut r = Route::ranked(0, self.method.0, &self.uri.0.to_string(), dummy_handler); + let method = self.method.map(|m| m.0); + let mut r = Route::ranked(0, method, &self.uri.0.to_string(), dummy_handler); r.format = self.format.map(|f| f.0); r } } +#[derive(Clone, Copy)] struct ArbitraryMethod(Method); struct ArbitraryOrigin<'a>(Origin<'a>); @@ -79,12 +81,7 @@ struct ArbitraryRouteUri<'a>(RouteUri<'a>); impl<'a> Arbitrary<'a> for ArbitraryMethod { fn arbitrary(u: &mut Unstructured<'a>) -> Result { - let all_methods = &[ - Method::Get, Method::Put, Method::Post, Method::Delete, Method::Options, - Method::Head, Method::Trace, Method::Connect, Method::Patch - ]; - - Ok(ArbitraryMethod(*u.choose(all_methods)?)) + Ok(ArbitraryMethod(*u.choose(Method::ALL_VARIANTS)?)) } fn size_hint(_: usize) -> (usize, Option) { diff --git a/core/lib/src/phase.rs b/core/lib/src/phase.rs index 19986382ed..8e05411bc6 100644 --- a/core/lib/src/phase.rs +++ b/core/lib/src/phase.rs @@ -4,7 +4,7 @@ use figment::Figment; use crate::listener::Endpoint; use crate::shutdown::Stages; use crate::{Catcher, Config, Rocket, Route}; -use crate::router::Router; +use crate::router::{Router, Finalized}; use crate::fairing::Fairings; mod private { @@ -100,7 +100,7 @@ phases! { /// represents a fully built and finalized application server ready for /// launch into orbit. See [`Rocket#ignite`] for full details. Ignite (#[derive(Debug)] Igniting) { - pub(crate) router: Router, + pub(crate) router: Router, pub(crate) fairings: Fairings, pub(crate) figment: Figment, pub(crate) config: Config, @@ -114,7 +114,7 @@ phases! { /// An instance of `Rocket` in this phase is typed as [`Rocket`] and /// represents a running application. Orbit (#[derive(Debug)] Orbiting) { - pub(crate) router: Router, + pub(crate) router: Router, pub(crate) fairings: Fairings, pub(crate) figment: Figment, pub(crate) config: Config, diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index aabd76dcc8..5762ed40f3 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -557,9 +557,10 @@ impl Rocket { // Initialize the router; check for collisions. let mut router = Router::new(); - self.routes.clone().into_iter().for_each(|r| router.add_route(r)); - self.catchers.clone().into_iter().for_each(|c| router.add_catcher(c)); - router.finalize().map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?; + self.routes.clone().into_iter().for_each(|r| router.routes.push(r)); + self.catchers.clone().into_iter().for_each(|c| router.catchers.push(c)); + let router = router.finalize() + .map_err(|(r, c)| ErrorKind::Collisions { routes: r, catchers: c, })?; // Finally, freeze managed state for faster access later. self.state.freeze(); @@ -840,8 +841,8 @@ impl Rocket

{ pub fn routes(&self) -> impl Iterator { match self.0.as_ref() { StateRef::Build(p) => Either::Left(p.routes.iter()), - StateRef::Ignite(p) => Either::Right(p.router.routes()), - StateRef::Orbit(p) => Either::Right(p.router.routes()), + StateRef::Ignite(p) => Either::Right(p.router.routes.iter()), + StateRef::Orbit(p) => Either::Right(p.router.routes.iter()), } } @@ -871,8 +872,8 @@ impl Rocket

{ pub fn catchers(&self) -> impl Iterator { match self.0.as_ref() { StateRef::Build(p) => Either::Left(p.catchers.iter()), - StateRef::Ignite(p) => Either::Right(p.router.catchers()), - StateRef::Orbit(p) => Either::Right(p.router.catchers()), + StateRef::Ignite(p) => Either::Right(p.router.catchers.iter()), + StateRef::Orbit(p) => Either::Right(p.router.catchers.iter()), } } diff --git a/core/lib/src/route/route.rs b/core/lib/src/route/route.rs index 3b31864bb3..8bf99dc0f7 100644 --- a/core/lib/src/route/route.rs +++ b/core/lib/src/route/route.rs @@ -22,7 +22,7 @@ use crate::sentinel::Sentry; /// /// let route = routes![route_name].remove(0); /// assert_eq!(route.name.unwrap(), "route_name"); -/// assert_eq!(route.method, Method::Get); +/// assert_eq!(route.method, Some(Method::Get)); /// assert_eq!(route.uri, "/route/?query"); /// assert_eq!(route.rank, 2); /// assert_eq!(route.format.unwrap(), MediaType::JSON); @@ -203,7 +203,7 @@ impl Route { /// // this is a route matching requests to `GET /` /// let index = Route::new(Method::Get, "/", handler); /// assert_eq!(index.rank, -9); - /// assert_eq!(index.method, Method::Get); + /// assert_eq!(index.method, Some(Method::Get)); /// assert_eq!(index.uri, "/"); /// ``` #[track_caller] @@ -233,12 +233,12 @@ impl Route { /// /// let foo = Route::ranked(1, Method::Post, "/foo?bar", handler); /// assert_eq!(foo.rank, 1); - /// assert_eq!(foo.method, Method::Post); + /// assert_eq!(foo.method, Some(Method::Post)); /// assert_eq!(foo.uri, "/foo?bar"); /// /// let foo = Route::ranked(None, Method::Post, "/foo?bar", handler); /// assert_eq!(foo.rank, -12); - /// assert_eq!(foo.method, Method::Post); + /// assert_eq!(foo.method, Some(Method::Post)); /// assert_eq!(foo.uri, "/foo?bar"); /// ``` #[track_caller] diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 9217907497..017486a6e6 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -1,65 +1,120 @@ -use rustc_hash::FxHashMap; +use std::ops::{Deref, DerefMut}; +use std::collections::HashMap; use crate::request::Request; use crate::http::{Method, Status}; use crate::{Route, Catcher}; use crate::router::Collide; +#[derive(Debug)] +pub(crate) struct Router(T); + +#[derive(Debug, Default)] +pub struct Pending { + pub routes: Vec, + pub catchers: Vec, +} + #[derive(Debug, Default)] -pub(crate) struct Router { - routes: FxHashMap, Vec>, - final_routes: FxHashMap>, - catchers: FxHashMap, Vec>, +pub struct Finalized { + pub routes: Vec, + pub catchers: Vec, + route_map: HashMap>, + catcher_map: HashMap, Vec>, } -pub type Collisions = Vec<(T, T)>; +pub type Pair = (T, T); -impl Router { +pub type Collisions = (Vec>, Vec>); + +pub type Result = std::result::Result; + +impl Router { pub fn new() -> Self { - Self::default() - } + Router(Pending::default()) + } + + pub fn finalize(self) -> Result, Collisions> { + fn collisions<'a, T>(items: &'a [T]) -> impl Iterator + 'a + where T: Collide + Clone + 'a, + { + items.iter() + .enumerate() + .flat_map(move |(i, a)| { + items.iter() + .skip(i + 1) + .filter(move |b| a.collides_with(b)) + .map(move |b| (a.clone(), b.clone())) + }) + } - pub fn add_route(&mut self, route: Route) { - let routes = self.routes.entry(route.method).or_default(); - routes.push(route); - routes.sort_by_key(|r| r.rank); - } + let route_collisions: Vec<_> = collisions(&self.routes).collect(); + let catcher_collisions: Vec<_> = collisions(&self.catchers).collect(); - pub fn add_catcher(&mut self, catcher: Catcher) { - let catchers = self.catchers.entry(catcher.code).or_default(); - catchers.push(catcher); - catchers.sort_by_key(|c| c.rank); - } + if !route_collisions.is_empty() || !catcher_collisions.is_empty() { + return Err((route_collisions, catcher_collisions)) + } - #[inline] - pub fn routes(&self) -> impl Iterator + Clone { - self.routes.values().flat_map(|v| v.iter()) - } + // create the route map + let mut route_map: HashMap> = HashMap::new(); + for (i, route) in self.routes.iter().enumerate() { + match route.method { + Some(method) => route_map.entry(method).or_default().push(i), + None => for method in Method::ALL_VARIANTS { + route_map.entry(*method).or_default().push(i); + } + } + } + + // create the catcher map + let mut catcher_map: HashMap, Vec> = HashMap::new(); + for (i, catcher) in self.catchers.iter().enumerate() { + catcher_map.entry(catcher.code).or_default().push(i); + } + + // sort routes by rank + for routes in route_map.values_mut() { + routes.sort_by_key(|&i| &self.routes[i].rank); + } + + // sort catchers by rank + for catchers in catcher_map.values_mut() { + catchers.sort_by_key(|&i| &self.catchers[i].rank); + } - #[inline] - pub fn catchers(&self) -> impl Iterator + Clone { - self.catchers.values().flat_map(|v| v.iter()) + Ok(Router(Finalized { + routes: self.0.routes, + catchers: self.0.catchers, + route_map, catcher_map + })) } +} +impl Router { + #[track_caller] pub fn route<'r, 'a: 'r>( &'a self, req: &'r Request<'r> ) -> impl Iterator + 'r { // Note that routes are presorted by ascending rank on each `add` and // that all routes with `None` methods have been cloned into all methods. - self.final_routes.get(&req.method()) + self.route_map.get(&req.method()) .into_iter() - .flat_map(move |routes| routes.iter().filter(move |r| r.matches(req))) + .flat_map(move |routes| routes.iter().map(move |&i| &self.routes[i])) + .filter(move |r| r.matches(req)) } // For many catchers, using aho-corasick or similar should be much faster. + #[track_caller] pub fn catch<'r>(&self, status: Status, req: &'r Request<'r>) -> Option<&Catcher> { // Note that catchers are presorted by descending base length. - let explicit = self.catchers.get(&Some(status.code)) - .and_then(|c| c.iter().find(|c| c.matches(status, req))); + let explicit = self.catcher_map.get(&Some(status.code)) + .map(|catchers| catchers.iter().map(|&i| &self.catchers[i])) + .and_then(|mut catchers| catchers.find(|c| c.matches(status, req))); - let default = self.catchers.get(&None) - .and_then(|c| c.iter().find(|c| c.matches(status, req))); + let default = self.catcher_map.get(&None) + .map(|catchers| catchers.iter().map(|&i| &self.catchers[i])) + .and_then(|mut catchers| catchers.find(|c| c.matches(status, req))); match (explicit, default) { (None, None) => None, @@ -68,49 +123,19 @@ impl Router { (Some(_), Some(b)) => Some(b), } } +} - fn collisions<'a, I, T>(&self, items: I) -> impl Iterator + 'a - where I: Iterator + Clone + 'a, T: Collide + Clone + 'a, - { - items.clone().enumerate() - .flat_map(move |(i, a)| { - items.clone() - .skip(i + 1) - .filter(move |b| a.collides_with(b)) - .map(move |b| (a.clone(), b.clone())) - }) - } +impl Deref for Router { + type Target = T; - fn _add_route(map: &mut FxHashMap>, method: Method, route: Route) { - let routes = map.entry(method).or_default(); - routes.push(route); - routes.sort_by_key(|r| r.rank); + fn deref(&self) -> &Self::Target { + &self.0 } +} - pub fn finalize(&mut self) -> Result<(), (Collisions, Collisions)> { - let routes: Vec<_> = self.collisions(self.routes()).collect(); - let catchers: Vec<_> = self.collisions(self.catchers()).collect(); - - if !routes.is_empty() || !catchers.is_empty() { - return Err((routes, catchers)) - } - - let all_routes = self.routes.iter() - .flat_map(|(method, routes)| routes.iter().map(|r| (*method, r))) - .map(|(method, route)| (method, route.clone())); - - let mut final_routes = FxHashMap::default(); - for (method, route) in all_routes { - match method { - Some(method) => Self::_add_route(&mut final_routes, method, route), - None => for method in Method::ALL_VARIANTS { - Self::_add_route(&mut final_routes, *method, route.clone()); - } - } - } - - self.final_routes = final_routes; - Ok(()) +impl DerefMut for Router { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 } } @@ -122,50 +147,32 @@ mod test { use crate::local::blocking::Client; use crate::http::{Method::*, uri::Origin}; - impl Router { - fn has_collisions(&mut self) -> bool { - self.finalize().is_err() - } - } - - fn router_with_routes(routes: &[&'static str]) -> Router { - let mut router = Router::new(); - for route in routes { - let route = Route::new(Get, route, dummy_handler); - router.add_route(route); - } - - router - } - - fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router { + fn make_router(routes: I) -> Result, Collisions> + where I: Iterator, &'static str)> + { let mut router = Router::new(); - for &(rank, route) in routes { + for (rank, route) in routes { let route = Route::ranked(rank, Get, route, dummy_handler); - router.add_route(route); + router.routes.push(route); } - router + router.finalize() } - fn router_with_rankless_routes(routes: &[&'static str]) -> Router { - let mut router = Router::new(); - for route in routes { - let route = Route::ranked(0, Get, route, dummy_handler); - router.add_route(route); - } + fn router_with_routes(routes: &[&'static str]) -> Router { + make_router(routes.iter().map(|r| (None, *r))).unwrap() + } - router + fn router_with_ranked_routes(routes: &[(isize, &'static str)]) -> Router { + make_router(routes.iter().map(|r| (Some(r.0), r.1))).unwrap() } fn rankless_route_collisions(routes: &[&'static str]) -> bool { - let mut router = router_with_rankless_routes(routes); - router.has_collisions() + make_router(routes.iter().map(|r| (Some(0), *r))).is_err() } fn default_rank_route_collisions(routes: &[&'static str]) -> bool { - let mut router = router_with_routes(routes); - router.has_collisions() + make_router(routes.iter().map(|r| (None, *r))).is_err() } #[test] @@ -302,13 +309,15 @@ mod test { assert!(!default_rank_route_collisions(&["/?a=b", "/?c=d&"])); } - fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { + #[track_caller] + fn matches<'a>(router: &'a Router, method: Method, uri: &'a str) -> Vec<&'a Route> { let client = Client::debug_with(vec![]).expect("client"); let request = client.req(method, Origin::parse(uri).unwrap()); router.route(&request).collect() } - fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { + #[track_caller] + fn route<'a>(router: &'a Router, method: Method, uri: &'a str) -> Option<&'a Route> { matches(router, method, uri).into_iter().next() } @@ -331,9 +340,10 @@ mod test { assert!(route(&router, Get, "/a/").is_some()); let mut router = Router::new(); - router.add_route(Route::new(Put, "/hello", dummy_handler)); - router.add_route(Route::new(Post, "/hello", dummy_handler)); - router.add_route(Route::new(Delete, "/hello", dummy_handler)); + router.routes.push(Route::new(Put, "/hello", dummy_handler)); + router.routes.push(Route::new(Post, "/hello", dummy_handler)); + router.routes.push(Route::new(Delete, "/hello", dummy_handler)); + let router = router.finalize().unwrap(); assert!(route(&router, Put, "/hello").is_some()); assert!(route(&router, Post, "/hello").is_some()); assert!(route(&router, Delete, "/hello").is_some()); @@ -389,8 +399,7 @@ mod test { /// Asserts that `$to` routes to `$want` given `$routes` are present. macro_rules! assert_ranked_match { ($routes:expr, $to:expr => $want:expr) => ({ - let mut router = router_with_routes($routes); - assert!(!router.has_collisions()); + let router = router_with_routes($routes); let route_path = route(&router, Get, $to).unwrap().uri.to_string(); assert_eq!(route_path, $want.to_string(), "\nmatched {} with {}, wanted {} in {:#?}", $to, route_path, $want, router); @@ -423,8 +432,7 @@ mod test { } fn ranked_collisions(routes: &[(isize, &'static str)]) -> bool { - let mut router = router_with_ranked_routes(routes); - router.has_collisions() + make_router(routes.iter().map(|r| (Some(r.0), r.1))).is_err() } #[test] @@ -451,7 +459,7 @@ mod test { let router = router_with_ranked_routes(&$routes); let routed_to = matches(&router, Get, $to); let expected = &[$($want),+]; - assert!(routed_to.len() == expected.len()); + assert_eq!(routed_to.len(), expected.len()); for (got, expected) in routed_to.iter().zip(expected.iter()) { assert_eq!(got.rank, expected.0); assert_eq!(got.uri.to_string(), expected.1.to_string()); @@ -567,20 +575,21 @@ mod test { ); } - fn router_with_catchers(catchers: &[(Option, &str)]) -> Router { + fn router_with_catchers(catchers: &[(Option, &str)]) -> Result> { let mut router = Router::new(); for (code, base) in catchers { let catcher = Catcher::new(*code, crate::catcher::dummy_handler); - router.add_catcher(catcher.map_base(|_| base.to_string()).unwrap()); + router.catchers.push(catcher.map_base(|_| base.to_string()).unwrap()); } - router + router.finalize() } - fn catcher<'a>(router: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { + #[track_caller] + fn catcher<'a>(r: &'a Router, status: Status, uri: &str) -> Option<&'a Catcher> { let client = Client::debug_with(vec![]).expect("client"); let request = client.get(Origin::parse(uri).unwrap()); - router.catch(status, &request) + r.catch(status, &request) } macro_rules! assert_catcher_routing { @@ -593,7 +602,7 @@ mod test { let requests = vec![$($r),+]; let expected = vec![$(($ecode.into(), $euri)),+]; - let router = router_with_catchers(&catchers); + let router = router_with_catchers(&catchers).expect("valid router"); for (req, expected) in requests.iter().zip(expected.iter()) { let req_status = Status::from_code(req.0).expect("valid status"); let catcher = catcher(&router, req_status, req.1).expect("some catcher"); diff --git a/core/lib/src/trace/traceable.rs b/core/lib/src/trace/traceable.rs index 1b3d4a2436..9ef1d4282f 100644 --- a/core/lib/src/trace/traceable.rs +++ b/core/lib/src/trace/traceable.rs @@ -144,7 +144,7 @@ impl Trace for Route { rank = self.rank, method = %Formatter(|f| match self.method { Some(method) => write!(f, "{}", method), - None => write!(f, "*"), + None => write!(f, "[any]"), }), uri = %self.uri, uri.base = %self.uri.base(), diff --git a/core/lib/tests/form_method-issue-45.rs b/core/lib/tests/form_method-issue-45.rs index 69f486bd2b..4e57a48b80 100644 --- a/core/lib/tests/form_method-issue-45.rs +++ b/core/lib/tests/form_method-issue-45.rs @@ -13,13 +13,13 @@ fn patch(form_data: Form) -> &'static str { "PATCH OK" } -#[route(method = UPDATEREDIRECTREF, uri = "/", data = "")] +#[route("/", method = UPDATEREDIRECTREF, data = "")] fn urr(form_data: Form) -> &'static str { assert_eq!("Form data", form_data.into_inner().form_data); "UPDATEREDIRECTREF OK" } -#[route(method = "VERSION-CONTROL", uri = "/", data = "")] +#[route("/", method = "VERSION-CONTROL", data = "")] fn vc(form_data: Form) -> &'static str { assert_eq!("Form data", form_data.into_inner().form_data); "VERSION-CONTROL OK" diff --git a/examples/tls/src/redirector.rs b/examples/tls/src/redirector.rs index 1481d8f8b9..b8155317e5 100644 --- a/examples/tls/src/redirector.rs +++ b/examples/tls/src/redirector.rs @@ -2,9 +2,9 @@ use std::net::SocketAddr; -use rocket::http::Status; +use rocket::http::uri::{Origin, Host}; use rocket::tracing::{self, Instrument}; -use rocket::{route, Error, Request, Data, Route, Orbit, Rocket, Ignite}; +use rocket::{Rocket, Ignite, Orbit, State, Error}; use rocket::fairing::{Fairing, Info, Kind}; use rocket::response::Redirect; use rocket::listener::tcp::TcpListener; @@ -19,7 +19,7 @@ pub struct Config { tls_addr: SocketAddr, } -#[route(uri = "/<_..>")] +#[route("/<_..>")] fn redirect(config: &State, uri: &Origin<'_>, host: &Host<'_>) -> Redirect { // FIXME: Check the host against a whitelist! let domain = host.domain();