Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SSE::Connected event to be able to retrieve status and headers in case of success #78

Closed
wants to merge 12 commits into from
17 changes: 14 additions & 3 deletions contract-tests/src/bin/sse-test-api/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,25 @@ struct Config {
#[derive(Serialize, Debug)]
#[serde(tag = "kind", rename_all = "camelCase")]
enum EventType {
Event { event: Event },
Comment { comment: String },
Error { error: String },
Connected {
status: u16,
headers: HashMap<String, String>,
},
Event {
event: Event,
},
Comment {
comment: String,
},
Error {
error: String,
},
}

impl From<es::SSE> for EventType {
fn from(event: es::SSE) -> Self {
match event {
es::SSE::Connected((status, headers)) => Self::Connected { status, headers },
es::SSE::Event(evt) => Self::Event {
event: Event {
event_type: evt.event_type,
Expand Down
3 changes: 3 additions & 0 deletions eventsource-client/examples/tail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ fn tail_events(client: impl es::Client) -> impl Stream<Item = Result<(), ()>> {
client
.stream()
.map_ok(|event| match event {
es::SSE::Connected((status, _)) => {
println!("got connected: \nstatus={}", status)
}
es::SSE::Event(ev) => {
println!("got an event: {}\n{}", ev.event_type, ev.data)
}
Expand Down
25 changes: 22 additions & 3 deletions eventsource-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use log::{debug, info, trace, warn};
use pin_project::pin_project;
use std::{
boxed,
collections::HashMap,
fmt::{self, Debug, Display, Formatter},
future::Future,
io::ErrorKind,
Expand All @@ -27,8 +28,8 @@ use tokio::{
time::Sleep,
};

use crate::config::ReconnectOptions;
use crate::error::{Error, Result};
use crate::{config::ReconnectOptions, ResponseWrapper};

use hyper::client::HttpConnector;
use hyper_timeout::TimeoutConnector;
Expand Down Expand Up @@ -393,6 +394,7 @@ where
let this = self.as_mut().project();
if let Some(event) = this.event_parser.get_event() {
return match event {
SSE::Connected(_) => Poll::Ready(Some(Ok(event))),
SSE::Event(ref evt) => {
*this.last_event_id = evt.id.clone();

Expand Down Expand Up @@ -438,11 +440,25 @@ where
if resp.status().is_success() {
self.as_mut().project().retry_strategy.reset(Instant::now());
self.as_mut().reset_redirects();

let headers = resp.headers();
let mut map = HashMap::new();
for (key, value) in headers.iter() {
let key = key.to_string();
let value = match value.to_str() {
Ok(value) => value.to_string(),
Err(_) => String::from(""),
};
map.insert(key, value);
}
let status = resp.status().as_u16();

self.as_mut()
.project()
.state
.set(State::Connected(resp.into_body()));
continue;

return Poll::Ready(Some(Ok(SSE::Connected((status, map)))));
}

if resp.status() == 301 || resp.status() == 307 {
Expand All @@ -467,7 +483,10 @@ where

self.as_mut().reset_redirects();
self.as_mut().project().state.set(State::New);
return Poll::Ready(Some(Err(Error::UnexpectedResponse(resp.status()))));

return Poll::Ready(Some(Err(Error::UnexpectedResponse(
ResponseWrapper::new(resp),
))));
}
Err(e) => {
// This seems basically impossible. AFAIK we can only get this way if we
Expand Down
82 changes: 79 additions & 3 deletions eventsource-client/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,77 @@
use hyper::StatusCode;
use std::collections::HashMap;

use hyper::{body::Buf, Body, Response};

pub struct ResponseWrapper {
response: Response<Body>,
}

impl ResponseWrapper {
pub fn new(response: Response<Body>) -> Self {
Self { response }
}
pub fn status(&self) -> u16 {
self.response.status().as_u16()
}
pub fn headers(&self) -> std::result::Result<HashMap<&str, &str>, HeaderError> {
let headers = self.response.headers();
let mut map = HashMap::new();
for (key, value) in headers.iter() {
let key = key.as_str();
let value = match value.to_str() {
Ok(value) => value,
Err(err) => return Err(HeaderError::new(Box::new(err))),
};
map.insert(key, value);
}
Ok(map)
}

pub async fn body_bytes(self) -> Result<Vec<u8>> {
let body = self.response.into_body();

let buf = match hyper::body::aggregate(body).await {
Ok(buf) => buf,
Err(err) => return Err(Error::HttpStream(Box::new(err))),
};

Ok(buf.chunk().to_vec())
}
}

impl std::fmt::Debug for ResponseWrapper {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponseWrapper")
.field("status", &self.status())
.finish()
}
}

/// Error type for invalid response headers encountered in ResponseWrapper.
#[derive(Debug)]
pub struct HeaderError {
/// Wrapped inner error providing details about the header issue.
inner_error: Box<dyn std::error::Error + Send + Sync + 'static>,
}

impl HeaderError {
/// Constructs a new `HeaderError` wrapping an existing error.
pub fn new(err: Box<dyn std::error::Error + Send + Sync + 'static>) -> Self {
HeaderError { inner_error: err }
}
}

impl std::fmt::Display for HeaderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Invalid response header: {}", self.inner_error)
}
}

impl std::error::Error for HeaderError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(self.inner_error.as_ref())
}
}

/// Error type returned from this library's functions.
#[derive(Debug)]
Expand All @@ -8,7 +81,7 @@ pub enum Error {
/// An invalid request parameter
InvalidParameter(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The HTTP response could not be handled.
UnexpectedResponse(StatusCode),
UnexpectedResponse(ResponseWrapper),
/// An error reading from the HTTP response body.
HttpStream(Box<dyn std::error::Error + Send + Sync + 'static>),
/// The HTTP response stream ended
Expand All @@ -32,7 +105,10 @@ impl std::fmt::Display for Error {
TimedOut => write!(f, "timed out"),
StreamClosed => write!(f, "stream closed"),
InvalidParameter(err) => write!(f, "invalid parameter: {err}"),
UnexpectedResponse(status_code) => write!(f, "unexpected response: {status_code}"),
UnexpectedResponse(r) => {
let status = r.status();
write!(f, "unexpected response: {status}")
}
HttpStream(err) => write!(f, "http error: {err}"),
Eof => write!(f, "eof"),
UnexpectedEof => write!(f, "unexpected eof"),
Expand Down
7 changes: 6 additions & 1 deletion eventsource-client/src/event_parser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::{collections::VecDeque, convert::TryFrom, str::from_utf8};
use std::{
collections::{HashMap, VecDeque},
convert::TryFrom,
str::from_utf8,
};

use hyper::body::Bytes;
use log::{debug, log_enabled, trace};
Expand Down Expand Up @@ -32,6 +36,7 @@ impl EventData {

#[derive(Debug, Eq, PartialEq)]
pub enum SSE {
Connected((u16, HashMap<String, String>)),
Event(Event),
Comment(String),
}
Expand Down
3 changes: 2 additions & 1 deletion eventsource-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
//! let mut stream = Box::pin(client.stream())
//! .map_ok(|event| match event {
//! SSE::Comment(comment) => println!("got a comment event: {:?}", comment),
//! SSE::Event(evt) => println!("got an event: {}", evt.event_type)
//! SSE::Event(evt) => println!("got an event: {}", evt.event_type),
//! SSE::Connected(_) => println!("got connected")
//! })
//! .map_err(|e| println!("error streaming events: {:?}", e));
//! # while let Ok(Some(_)) = stream.try_next().await {}
Expand Down
Loading