diff --git a/src/mock_server/bare_server.rs b/src/mock_server/bare_server.rs index c13a82b..6a0ef9a 100644 --- a/src/mock_server/bare_server.rs +++ b/src/mock_server/bare_server.rs @@ -2,6 +2,7 @@ use crate::mock_server::hyper::run_server; use crate::mock_set::MockId; use crate::mock_set::MountedMockSet; use crate::request::BodyPrintLimit; +use crate::response_template::BodyStream; use crate::{mock::Mock, verification::VerificationOutcome, Request}; use std::net::{SocketAddr, TcpListener, TcpStream}; use std::pin::pin; @@ -37,7 +38,12 @@ impl MockServerState { pub(super) async fn handle_request( &mut self, mut request: Request, - ) -> (http_types::Response, Option) { + ) -> ( + http_types::Response, + Option, + Option, + Option, + ) { request.body_print_limit = self.body_print_limit; // If request recording is enabled, record the incoming request // by adding it to the `received_requests` stack diff --git a/src/mock_server/hyper.rs b/src/mock_server/hyper.rs index 8638d6a..8cb184b 100644 --- a/src/mock_server/hyper.rs +++ b/src/mock_server/hyper.rs @@ -1,4 +1,6 @@ use crate::mock_server::bare_server::MockServerState; +use crate::response_template::BodyStream; +use hyper::header::CONTENT_LENGTH; use hyper::http; use hyper::service::{make_service_fn, service_fn}; use std::net::TcpListener; @@ -20,7 +22,7 @@ pub(super) async fn run_server( let server_state = server_state.clone(); async move { let wiremock_request = crate::Request::from_hyper(request).await; - let (response, delay) = server_state + let (response, body, length, delay) = server_state .write() .await .handle_request(wiremock_request) @@ -38,7 +40,9 @@ pub(super) async fn run_server( delay.await; } - Ok::<_, DynError>(http_types_response_to_hyper_response(response).await) + Ok::<_, DynError>( + http_types_response_to_hyper_response(response, body, length).await, + ) } })) } @@ -76,6 +80,8 @@ where async fn http_types_response_to_hyper_response( mut response: http_types::Response, + body: Option, + length: Option, ) -> hyper::Response { let version = response.version().map(|v| v.into()).unwrap_or_default(); let mut builder = http::response::Builder::new() @@ -83,9 +89,11 @@ async fn http_types_response_to_hyper_response( .version(version); headers_to_hyperium_headers(response.as_mut(), builder.headers_mut().unwrap()); + if let Some(length) = length { + builder = builder.header(CONTENT_LENGTH, length); + } - let body_bytes = response.take_body().into_bytes().await.unwrap(); - let body = hyper::Body::from(body_bytes); + let body = body.map_or_else(hyper::Body::empty, hyper::Body::wrap_stream); builder.body(body).unwrap() } diff --git a/src/mock_set.rs b/src/mock_set.rs index 996d6f0..e3d2284 100644 --- a/src/mock_set.rs +++ b/src/mock_set.rs @@ -1,5 +1,6 @@ use crate::{ mounted_mock::MountedMock, + response_template::BodyStream, verification::{VerificationOutcome, VerificationReport}, }; use crate::{Mock, Request, ResponseTemplate}; @@ -49,7 +50,10 @@ impl MountedMockSet { } } - pub(crate) async fn handle_request(&mut self, request: Request) -> (Response, Option) { + pub(crate) async fn handle_request( + &mut self, + request: Request, + ) -> (Response, Option, Option, Option) { debug!("Handling request."); let mut response_template: Option = None; self.mocks.sort_by_key(|(m, _)| m.specification.priority); @@ -64,10 +68,11 @@ impl MountedMockSet { } if let Some(response_template) = response_template { let delay = response_template.delay().map(|d| Delay::new(d.to_owned())); - (response_template.generate_response(), delay) + let (response, body, length) = response_template.generate_response(); + (response, body, length, delay) } else { debug!("Got unexpected request:\n{}", request); - (Response::new(StatusCode::NotFound), None) + (Response::new(StatusCode::NotFound), None, None, None) } } diff --git a/src/response_template.rs b/src/response_template.rs index b914dea..60f5074 100644 --- a/src/response_template.rs +++ b/src/response_template.rs @@ -1,24 +1,46 @@ +use futures::stream::{self, BoxStream}; +use futures::{Stream, StreamExt}; use http_types::headers::{HeaderName, HeaderValue}; use http_types::{Response, StatusCode}; use serde::Serialize; use std::collections::HashMap; -use std::convert::TryInto; +use std::convert::{TryFrom, TryInto}; +use std::future; +use std::panic::{RefUnwindSafe, UnwindSafe}; use std::str::FromStr; +use std::sync::Arc; use std::time::Duration; /// The blueprint for the response returned by a [`MockServer`] when a [`Mock`] matches on an incoming request. /// /// [`Mock`]: crate::Mock /// [`MockServer`]: crate::MockServer -#[derive(Clone, Debug)] +#[derive(Clone)] pub struct ResponseTemplate { mime: Option, status_code: StatusCode, headers: HashMap>, - body: Option>, + body_fn: Option>, delay: Option, } +pub(crate) type BodyFn = + dyn Fn() -> (BodyStream, Option) + Send + Sync + RefUnwindSafe + UnwindSafe; +pub(crate) type BodyStream = + BoxStream<'static, Result, Box>>; + +impl std::fmt::Debug for ResponseTemplate { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResponseTemplate") + .field("mime", &self.mime) + .field("status_code", &self.status_code) + .field("headers", &self.headers) + .field("body_fn", &self.body_fn.as_ref().map(|_| "[BodyFn]")) + .field("delay", &self.delay) + .finish() + } +} + // `wiremock` is a crate meant for testing - failures are most likely not handled/temporary mistakes. // Hence we prefer to panic and provide an easier API than to use `Result`s thus pushing // the burden of "correctness" (and conversions) on the user. @@ -39,7 +61,7 @@ impl ResponseTemplate { status_code, headers: HashMap::new(), mime: None, - body: None, + body_fn: None, delay: None, } } @@ -134,7 +156,33 @@ impl ResponseTemplate { >>::Error: std::fmt::Debug, { let body = body.try_into().expect("Failed to convert into body."); - self.body = Some(body); + self.body_fn = Some(wrap_body_in_arc_fn(body)); + self + } + + /// Set the response body with a stream of bytes. + /// + /// It sets "Content-Type" to "application/octet-stream". + /// + /// If the `length` is not set, "Transfer-Encoding" is set to "chunked". + /// + /// To set a body with a stream of bytes but a different "Content-Type" + /// [`set_body_stream_raw`](#method.set_body_stream_raw) can be used. + pub fn set_body_bytes_stream(mut self, body_stream_fn: F, length: Option) -> Self + where + F: Fn() -> B + Send + Sync + UnwindSafe + RefUnwindSafe + 'static, + B: Stream + Send + Sync + 'static, + T: TryInto>, + >>::Error: std::error::Error + Send + Sync + 'static, + { + self.body_fn = Some(Arc::new(move || { + let stream = Box::pin(body_stream_fn().map(|item| { + item.try_into().map_err(|err| { + Box::new(err) as Box + }) + })); + (stream, length) + })); self } @@ -144,7 +192,7 @@ impl ResponseTemplate { pub fn set_body_json(mut self, body: B) -> Self { let body = serde_json::to_vec(&body).expect("Failed to convert into body."); - self.body = Some(body); + self.body_fn = Some(wrap_body_in_arc_fn(body)); self.mime = Some( http_types::Mime::from_str("application/json") .expect("Failed to convert into Mime header"), @@ -162,7 +210,34 @@ impl ResponseTemplate { { let body = body.try_into().expect("Failed to convert into body."); - self.body = Some(body.into_bytes()); + self.body_fn = Some(wrap_body_in_arc_fn(body.into_bytes())); + self.mime = Some( + http_types::Mime::from_str("text/plain").expect("Failed to convert into Mime header"), + ); + self + } + + /// Set the response body to a stream of strings. + /// + /// It sets "Content-Type" to "text/plain". + /// + /// If the `length` is not set, "Transfer-Encoding" is set to "chunked". + #[must_use] + pub fn set_body_string_stream(mut self, body_stream_fn: F, length: Option) -> Self + where + F: Fn() -> B + Send + Sync + UnwindSafe + RefUnwindSafe + 'static, + B: Stream + Send + Sync + 'static, + T: TryInto, + >::Error: std::error::Error + Send + Sync + 'static, + { + self.body_fn = Some(Arc::new(move || { + let stream = Box::pin(body_stream_fn().map(|item| { + item.try_into().map(String::into_bytes).map_err(|err| { + Box::new(err) as Box + }) + })); + (stream, length) + })); self.mime = Some( http_types::Mime::from_str("text/plain").expect("Failed to convert into Mime header"), ); @@ -219,12 +294,86 @@ impl ResponseTemplate { >>::Error: std::fmt::Debug, { let body = body.try_into().expect("Failed to convert into body."); - self.body = Some(body); + self.body_fn = Some(wrap_body_in_arc_fn(body)); self.mime = Some(http_types::Mime::from_str(mime).expect("Failed to convert into Mime header")); self } + /// Set a raw response body using a stream of data. The mime type needs to be set because the + /// raw body could be of any type. + /// + /// If the `length` is not set, "Transfer-Encoding" is set to "chunked". + /// + /// It the `mime` parameter is `None`, the "Content-Type" header is set to + /// "application/octet-stream + /// + /// ### Example: + /// ```rust + /// use surf::http::mime; + /// use wiremock::matchers::method; + /// use wiremock::{Mock, MockServer, ResponseTemplate}; + /// + /// mod external { + /// use futures::{future, stream, Stream, StreamExt}; + /// + /// use std::convert::Infallible; + /// + /// // This could be a method of a struct that is implemented in another crate and a stream + /// // of strings is returned. + /// pub fn body() -> impl Stream> { + /// stream::once(future::ok(r#"{"hello": "#)) + /// .chain(stream::once(future::ok(r#""world"}"#))) + /// } + /// } + /// + /// #[async_std::main] + /// async fn main() { + /// // Arrange + /// let mock_server = MockServer::start().await; + /// let template = + /// ResponseTemplate::new(200).set_body_raw_stream(external::body, Some(18), Some("application/json")); + /// Mock::given(method("GET")) + /// .respond_with(template) + /// .mount(&mock_server) + /// .await; + /// + /// // Act + /// let mut res = surf::get(&mock_server.uri()).await.unwrap(); + /// let body = res.body_string().await.unwrap(); + /// + /// // Assert + /// assert_eq!(body, r#"{"hello": "world"}"#); + /// assert_eq!(res.content_type(), Some(mime::JSON)); + /// } + /// ``` + #[must_use] + pub fn set_body_raw_stream( + mut self, + body_stream_fn: F, + length: Option, + mime: Option<&str>, + ) -> Self + where + F: Fn() -> B + Send + Sync + UnwindSafe + RefUnwindSafe + 'static, + B: Stream> + Send + Sync + 'static, + T: TryInto>, + >>::Error: Into>, + E: Into>, + { + self.body_fn = Some(Arc::new(move || { + let stream = Box::pin(body_stream_fn().map(|item| { + item.map_err(Into::into) + .and_then(|item| item.try_into().map_err(Into::into)) + })); + (stream, length) + })); + self.mime = mime.map(|mime| { + http_types::Mime::from_str(mime).expect("Failed to convert into Mime header") + }); + self + } + /// By default the [`MockServer`] tries to fulfill incoming requests as fast as possible. /// /// You can use `set_delay` to introduce an artificial delay to simulate the behaviour of @@ -272,7 +421,7 @@ impl ResponseTemplate { } /// Generate a response from the template. - pub(crate) fn generate_response(&self) -> Response { + pub(crate) fn generate_response(&self) -> (Response, Option, Option) { let mut response = Response::new(self.status_code); // Add headers @@ -280,17 +429,19 @@ impl ResponseTemplate { response.insert_header(header_name.clone(), header_values.as_slice()); } - // Add body, if specified - if let Some(body) = &self.body { - response.set_body(body.clone()); - } - // Set content-type, if needed if let Some(mime) = &self.mime { response.set_content_type(mime.to_owned()); } - response + // Get body stream, if specified + // TODO: use Option::unzip, bumping MSRV to 1.66 + let (body, length) = match self.body_fn.as_deref().map(|f| f()) { + Some((body, length)) => (Some(body), length), + None => (None, None), + }; + + (response, body, length) } /// Retrieve the response delay. @@ -298,3 +449,15 @@ impl ResponseTemplate { &self.delay } } + +#[inline] +fn wrap_body_in_arc_fn(body: Vec) -> Arc { + Arc::new(move || { + let length = body.len(); + let stream = Box::pin(stream::once(future::ready(Ok(body.clone())))); + ( + stream, + Some(u64::try_from(length).expect("Length of body is too big")), + ) + }) +} diff --git a/tests/mocks.rs b/tests/mocks.rs index be0fb5d..8f187b9 100644 --- a/tests/mocks.rs +++ b/tests/mocks.rs @@ -1,4 +1,4 @@ -use futures::FutureExt; +use futures::{future, stream, FutureExt, StreamExt}; use http_types::StatusCode; use serde::Serialize; use serde_json::json; @@ -127,6 +127,61 @@ async fn simple_route_mock() { assert_eq!(response.body_string().await.unwrap(), "world"); } +#[async_std::test] +async fn simple_route_mock_with_bytes_stream() { + // Arrange + let mock_server = MockServer::start().await; + let stream_fn = + || stream::once(future::ready("hello ")).chain(stream::once(future::ready("world"))); + let response = ResponseTemplate::new(200).set_body_bytes_stream(stream_fn, None); + let mock = Mock::given(method("GET")) + .and(PathExactMatcher::new("path")) + .respond_with(response); + mock_server.register(mock).await; + + // Act + let mut response = surf::get(format!("{}/path", &mock_server.uri())) + .await + .unwrap(); + + // Assert + assert_eq!(response.status(), 200); + assert_eq!(response.header("Transfer-encoding").unwrap(), "chunked"); + assert_eq!(response.body_string().await.unwrap(), "hello world"); +} + +#[async_std::test] +async fn simple_route_mock_with_failing_stream() { + // Arrange + let mock_server = MockServer::start().await; + let stream_fn = || { + stream::once(future::ok("hello ")).chain(stream::once(async move { + // This is needed to make the body fail instead of the response + async_std::task::sleep(Duration::from_millis(1)).await; + Err(std::io::Error::new( + std::io::ErrorKind::Other, + "testing error", + )) + })) + }; + let response = ResponseTemplate::new(200).set_body_raw_stream(stream_fn, Some(10), None); + let mock = Mock::given(method("GET")) + .and(PathExactMatcher::new("path")) + .respond_with(response); + mock_server.register(mock).await; + + // Act + let mut response = surf::get(format!("{}/path", &mock_server.uri())) + .await + .unwrap(); + + // Assert + assert_eq!(response.status(), 200); + assert!(response.header("Transfer-encoding").is_none()); + // TODO: improve error recognition, error from stream is not bubbled up + response.body_string().await.unwrap_err(); +} + #[async_std::test] async fn two_route_mocks() { // Arrange