Skip to content

Support body streams #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/mock_server/bare_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::mock_server::hyper::run_server;
use crate::mock_set::MockId;
use crate::mock_set::MountedMockSet;
use crate::request::BodyPrintLimit;
use crate::response_template::BodyStream;
use crate::{mock::Mock, verification::VerificationOutcome, Request};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::pin::pin;
Expand Down Expand Up @@ -37,7 +38,12 @@ impl MockServerState {
pub(super) async fn handle_request(
&mut self,
mut request: Request,
) -> (http_types::Response, Option<futures_timer::Delay>) {
) -> (
http_types::Response,
Option<BodyStream>,
Option<u64>,
Option<futures_timer::Delay>,
) {
request.body_print_limit = self.body_print_limit;
// If request recording is enabled, record the incoming request
// by adding it to the `received_requests` stack
Expand Down
16 changes: 12 additions & 4 deletions src/mock_server/hyper.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::mock_server::bare_server::MockServerState;
use crate::response_template::BodyStream;
use hyper::header::CONTENT_LENGTH;
use hyper::http;
use hyper::service::{make_service_fn, service_fn};
use std::net::TcpListener;
Expand All @@ -20,7 +22,7 @@ pub(super) async fn run_server(
let server_state = server_state.clone();
async move {
let wiremock_request = crate::Request::from_hyper(request).await;
let (response, delay) = server_state
let (response, body, length, delay) = server_state
.write()
.await
.handle_request(wiremock_request)
Expand All @@ -38,7 +40,9 @@ pub(super) async fn run_server(
delay.await;
}

Ok::<_, DynError>(http_types_response_to_hyper_response(response).await)
Ok::<_, DynError>(
http_types_response_to_hyper_response(response, body, length).await,
)
}
}))
}
Expand Down Expand Up @@ -76,16 +80,20 @@ where

async fn http_types_response_to_hyper_response(
mut response: http_types::Response,
body: Option<BodyStream>,
length: Option<u64>,
) -> hyper::Response<hyper::Body> {
let version = response.version().map(|v| v.into()).unwrap_or_default();
let mut builder = http::response::Builder::new()
.status(response.status() as u16)
.version(version);

headers_to_hyperium_headers(response.as_mut(), builder.headers_mut().unwrap());
if let Some(length) = length {
builder = builder.header(CONTENT_LENGTH, length);
}

let body_bytes = response.take_body().into_bytes().await.unwrap();
let body = hyper::Body::from(body_bytes);
let body = body.map_or_else(hyper::Body::empty, hyper::Body::wrap_stream);

builder.body(body).unwrap()
}
Expand Down
11 changes: 8 additions & 3 deletions src/mock_set.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{
mounted_mock::MountedMock,
response_template::BodyStream,
verification::{VerificationOutcome, VerificationReport},
};
use crate::{Mock, Request, ResponseTemplate};
Expand Down Expand Up @@ -49,7 +50,10 @@ impl MountedMockSet {
}
}

pub(crate) async fn handle_request(&mut self, request: Request) -> (Response, Option<Delay>) {
pub(crate) async fn handle_request(
&mut self,
request: Request,
) -> (Response, Option<BodyStream>, Option<u64>, Option<Delay>) {
debug!("Handling request.");
let mut response_template: Option<ResponseTemplate> = None;
self.mocks.sort_by_key(|(m, _)| m.specification.priority);
Expand All @@ -64,10 +68,11 @@ impl MountedMockSet {
}
if let Some(response_template) = response_template {
let delay = response_template.delay().map(|d| Delay::new(d.to_owned()));
(response_template.generate_response(), delay)
let (response, body, length) = response_template.generate_response();
(response, body, length, delay)
} else {
debug!("Got unexpected request:\n{}", request);
(Response::new(StatusCode::NotFound), None)
(Response::new(StatusCode::NotFound), None, None, None)
}
}

Expand Down
193 changes: 178 additions & 15 deletions src/response_template.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,46 @@
use futures::stream::{self, BoxStream};
use futures::{Stream, StreamExt};
use http_types::headers::{HeaderName, HeaderValue};
use http_types::{Response, StatusCode};
use serde::Serialize;
use std::collections::HashMap;
use std::convert::TryInto;
use std::convert::{TryFrom, TryInto};
use std::future;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

/// The blueprint for the response returned by a [`MockServer`] when a [`Mock`] matches on an incoming request.
///
/// [`Mock`]: crate::Mock
/// [`MockServer`]: crate::MockServer
#[derive(Clone, Debug)]
#[derive(Clone)]
pub struct ResponseTemplate {
mime: Option<http_types::Mime>,
status_code: StatusCode,
headers: HashMap<HeaderName, Vec<HeaderValue>>,
body: Option<Vec<u8>>,
body_fn: Option<Arc<BodyFn>>,
delay: Option<Duration>,
}

pub(crate) type BodyFn =
dyn Fn() -> (BodyStream, Option<u64>) + Send + Sync + RefUnwindSafe + UnwindSafe;
pub(crate) type BodyStream =
BoxStream<'static, Result<Vec<u8>, Box<dyn std::error::Error + Send + Sync>>>;

impl std::fmt::Debug for ResponseTemplate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ResponseTemplate")
.field("mime", &self.mime)
.field("status_code", &self.status_code)
.field("headers", &self.headers)
.field("body_fn", &self.body_fn.as_ref().map(|_| "[BodyFn]"))
.field("delay", &self.delay)
.finish()
}
}

// `wiremock` is a crate meant for testing - failures are most likely not handled/temporary mistakes.
// Hence we prefer to panic and provide an easier API than to use `Result`s thus pushing
// the burden of "correctness" (and conversions) on the user.
Expand All @@ -39,7 +61,7 @@ impl ResponseTemplate {
status_code,
headers: HashMap::new(),
mime: None,
body: None,
body_fn: None,
delay: None,
}
}
Expand Down Expand Up @@ -134,7 +156,33 @@ impl ResponseTemplate {
<B as TryInto<Vec<u8>>>::Error: std::fmt::Debug,
{
let body = body.try_into().expect("Failed to convert into body.");
self.body = Some(body);
self.body_fn = Some(wrap_body_in_arc_fn(body));
self
}

/// Set the response body with a stream of bytes.
///
/// It sets "Content-Type" to "application/octet-stream".
///
/// If the `length` is not set, "Transfer-Encoding" is set to "chunked".
///
/// To set a body with a stream of bytes but a different "Content-Type"
/// [`set_body_stream_raw`](#method.set_body_stream_raw) can be used.
pub fn set_body_bytes_stream<F, B, T>(mut self, body_stream_fn: F, length: Option<u64>) -> Self
where
F: Fn() -> B + Send + Sync + UnwindSafe + RefUnwindSafe + 'static,
B: Stream<Item = T> + Send + Sync + 'static,
T: TryInto<Vec<u8>>,
<T as TryInto<Vec<u8>>>::Error: std::error::Error + Send + Sync + 'static,
{
self.body_fn = Some(Arc::new(move || {
let stream = Box::pin(body_stream_fn().map(|item| {
item.try_into().map_err(|err| {
Box::new(err) as Box<dyn std::error::Error + Send + Sync + 'static>
})
}));
(stream, length)
}));
self
}

Expand All @@ -144,7 +192,7 @@ impl ResponseTemplate {
pub fn set_body_json<B: Serialize>(mut self, body: B) -> Self {
let body = serde_json::to_vec(&body).expect("Failed to convert into body.");

self.body = Some(body);
self.body_fn = Some(wrap_body_in_arc_fn(body));
self.mime = Some(
http_types::Mime::from_str("application/json")
.expect("Failed to convert into Mime header"),
Expand All @@ -162,7 +210,34 @@ impl ResponseTemplate {
{
let body = body.try_into().expect("Failed to convert into body.");

self.body = Some(body.into_bytes());
self.body_fn = Some(wrap_body_in_arc_fn(body.into_bytes()));
self.mime = Some(
http_types::Mime::from_str("text/plain").expect("Failed to convert into Mime header"),
);
self
}

/// Set the response body to a stream of strings.
///
/// It sets "Content-Type" to "text/plain".
///
/// If the `length` is not set, "Transfer-Encoding" is set to "chunked".
#[must_use]
pub fn set_body_string_stream<F, B, T>(mut self, body_stream_fn: F, length: Option<u64>) -> Self
where
F: Fn() -> B + Send + Sync + UnwindSafe + RefUnwindSafe + 'static,
B: Stream<Item = T> + Send + Sync + 'static,
T: TryInto<String>,
<T as TryInto<String>>::Error: std::error::Error + Send + Sync + 'static,
{
self.body_fn = Some(Arc::new(move || {
let stream = Box::pin(body_stream_fn().map(|item| {
item.try_into().map(String::into_bytes).map_err(|err| {
Box::new(err) as Box<dyn std::error::Error + Send + Sync + 'static>
})
}));
(stream, length)
}));
self.mime = Some(
http_types::Mime::from_str("text/plain").expect("Failed to convert into Mime header"),
);
Expand Down Expand Up @@ -219,12 +294,86 @@ impl ResponseTemplate {
<B as TryInto<Vec<u8>>>::Error: std::fmt::Debug,
{
let body = body.try_into().expect("Failed to convert into body.");
self.body = Some(body);
self.body_fn = Some(wrap_body_in_arc_fn(body));
self.mime =
Some(http_types::Mime::from_str(mime).expect("Failed to convert into Mime header"));
self
}

/// Set a raw response body using a stream of data. The mime type needs to be set because the
/// raw body could be of any type.
///
/// If the `length` is not set, "Transfer-Encoding" is set to "chunked".
///
/// It the `mime` parameter is `None`, the "Content-Type" header is set to
/// "application/octet-stream
///
/// ### Example:
/// ```rust
/// use surf::http::mime;
/// use wiremock::matchers::method;
/// use wiremock::{Mock, MockServer, ResponseTemplate};
///
/// mod external {
/// use futures::{future, stream, Stream, StreamExt};
///
/// use std::convert::Infallible;
///
/// // This could be a method of a struct that is implemented in another crate and a stream
/// // of strings is returned.
/// pub fn body() -> impl Stream<Item = Result<&'static str, Infallible>> {
/// stream::once(future::ok(r#"{"hello": "#))
/// .chain(stream::once(future::ok(r#""world"}"#)))
/// }
/// }
///
/// #[async_std::main]
/// async fn main() {
/// // Arrange
/// let mock_server = MockServer::start().await;
/// let template =
/// ResponseTemplate::new(200).set_body_raw_stream(external::body, Some(18), Some("application/json"));
/// Mock::given(method("GET"))
/// .respond_with(template)
/// .mount(&mock_server)
/// .await;
///
/// // Act
/// let mut res = surf::get(&mock_server.uri()).await.unwrap();
/// let body = res.body_string().await.unwrap();
///
/// // Assert
/// assert_eq!(body, r#"{"hello": "world"}"#);
/// assert_eq!(res.content_type(), Some(mime::JSON));
/// }
/// ```
#[must_use]
pub fn set_body_raw_stream<F, B, T, E>(
mut self,
body_stream_fn: F,
length: Option<u64>,
mime: Option<&str>,
) -> Self
where
F: Fn() -> B + Send + Sync + UnwindSafe + RefUnwindSafe + 'static,
B: Stream<Item = Result<T, E>> + Send + Sync + 'static,
T: TryInto<Vec<u8>>,
<T as TryInto<Vec<u8>>>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
self.body_fn = Some(Arc::new(move || {
let stream = Box::pin(body_stream_fn().map(|item| {
item.map_err(Into::into)
.and_then(|item| item.try_into().map_err(Into::into))
}));
(stream, length)
}));
self.mime = mime.map(|mime| {
http_types::Mime::from_str(mime).expect("Failed to convert into Mime header")
});
self
}

/// By default the [`MockServer`] tries to fulfill incoming requests as fast as possible.
///
/// You can use `set_delay` to introduce an artificial delay to simulate the behaviour of
Expand Down Expand Up @@ -272,29 +421,43 @@ impl ResponseTemplate {
}

/// Generate a response from the template.
pub(crate) fn generate_response(&self) -> Response {
pub(crate) fn generate_response(&self) -> (Response, Option<BodyStream>, Option<u64>) {
let mut response = Response::new(self.status_code);

// Add headers
for (header_name, header_values) in &self.headers {
response.insert_header(header_name.clone(), header_values.as_slice());
}

// Add body, if specified
if let Some(body) = &self.body {
response.set_body(body.clone());
}

// Set content-type, if needed
if let Some(mime) = &self.mime {
response.set_content_type(mime.to_owned());
}

response
// Get body stream, if specified
// TODO: use Option::unzip, bumping MSRV to 1.66
let (body, length) = match self.body_fn.as_deref().map(|f| f()) {
Some((body, length)) => (Some(body), length),
None => (None, None),
};

(response, body, length)
}

/// Retrieve the response delay.
pub(crate) fn delay(&self) -> &Option<Duration> {
&self.delay
}
}

#[inline]
fn wrap_body_in_arc_fn(body: Vec<u8>) -> Arc<BodyFn> {
Arc::new(move || {
let length = body.len();
let stream = Box::pin(stream::once(future::ready(Ok(body.clone()))));
(
stream,
Some(u64::try_from(length).expect("Length of body is too big")),
)
})
}
Loading