diff --git a/edge-http/src/io.rs b/edge-http/src/io.rs index 6f2be50..a6d15a1 100644 --- a/edge-http/src/io.rs +++ b/edge-http/src/io.rs @@ -9,7 +9,10 @@ use httparse::Status; use log::trace; use crate::ws::UpgradeError; -use crate::{BodyType, Headers, Method, RequestHeaders, ResponseHeaders}; +use crate::{ + BodyType, ConnectionType, Headers, HeadersMismatchError, Method, RequestHeaders, + ResponseHeaders, +}; pub mod client; pub mod server; @@ -27,6 +30,7 @@ pub enum Error { InvalidState, Timeout, ConnectionClosed, + HeadersMismatchError(HeadersMismatchError), WsUpgradeError(UpgradeError), Io(E), } @@ -45,6 +49,12 @@ impl From for Error { } } +impl From for Error { + fn from(e: HeadersMismatchError) -> Self { + Self::HeadersMismatchError(e) + } +} + impl From for Error { fn from(e: UpgradeError) -> Self { Self::WsUpgradeError(e) @@ -78,6 +88,7 @@ where Self::IncompleteBody => write!(f, "HTTP body is incomplete"), Self::InvalidState => write!(f, "Connection is not in requested state"), Self::Timeout => write!(f, "Timeout"), + Self::HeadersMismatchError(e) => write!(f, "Headers mismatch: {e}"), Self::WsUpgradeError(e) => write!(f, "WebSocket upgrade error: {e}"), Self::ConnectionClosed => write!(f, "Connection closed"), Self::Io(e) => write!(f, "{e}"), @@ -89,6 +100,7 @@ where impl std::error::Error for Error where E: std::error::Error {} impl<'b, const N: usize> RequestHeaders<'b, N> { + /// Parse the headers from the input stream pub async fn receive( &mut self, buf: &'b mut [u8], @@ -99,7 +111,7 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { R: Read, { let (read_len, headers_len) = - match read_reply_buf::(&mut input, buf, true, exact).await { + match raw::read_reply_buf::(&mut input, buf, true, exact).await { Ok(read_len) => read_len, Err(e) => return Err(e), }; @@ -139,25 +151,33 @@ impl<'b, const N: usize> RequestHeaders<'b, N> { } } - pub async fn send(&self, mut output: W) -> Result> + /// Resolve the connection type and body type from the headers + pub fn resolve(&self) -> Result<(ConnectionType, BodyType), Error> { + self.headers + .resolve::(None, true, self.http11.unwrap_or(false)) + } + + /// Send the headers to the output stream, returning the connection type and body type + pub async fn send( + &self, + chunked_if_unspecified: bool, + mut output: W, + ) -> Result<(ConnectionType, BodyType), Error> where W: Write, { - send_request( - self.http11.unwrap_or(false), - self.method, - self.path, - &mut output, - ) - .await?; - let body_type = self.headers.send(&mut output).await?; - send_headers_end(output).await?; + let http11 = self.http11.unwrap_or(false); + + send_request(http11, self.method, self.path, &mut output).await?; - Ok(body_type) + self.headers + .send(None, true, http11, chunked_if_unspecified, output) + .await } } impl<'b, const N: usize> ResponseHeaders<'b, N> { + /// Parse the headers from the input stream pub async fn receive( &mut self, buf: &'b mut [u8], @@ -167,7 +187,8 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { where R: Read, { - let (read_len, headers_len) = read_reply_buf::(&mut input, buf, false, exact).await?; + let (read_len, headers_len) = + raw::read_reply_buf::(&mut input, buf, false, exact).await?; let mut parser = httparse::Response::new(&mut self.headers.0); @@ -201,21 +222,41 @@ impl<'b, const N: usize> ResponseHeaders<'b, N> { } } - pub async fn send(&self, mut output: W) -> Result> - where - W: Write, - { - send_status( + /// Resolve the connection type and body type from the headers + pub fn resolve( + &self, + request_connection_type: ConnectionType, + ) -> Result<(ConnectionType, BodyType), Error> { + self.headers.resolve::( + Some(request_connection_type), + false, self.http11.unwrap_or(false), - self.code, - self.reason, - &mut output, ) - .await?; - let body_type = self.headers.send(&mut output).await?; - send_headers_end(output).await?; + } - Ok(body_type) + /// Send the headers to the output stream, returning the connection type and body type + pub async fn send( + &self, + request_connection_type: ConnectionType, + chunked_if_unspecified: bool, + mut output: W, + ) -> Result<(ConnectionType, BodyType), Error> + where + W: Write, + { + let http11 = self.http11.unwrap_or(false); + + send_status(http11, self.code, self.reason, &mut output).await?; + + self.headers + .send( + Some(request_connection_type), + false, + http11, + chunked_if_unspecified, + output, + ) + .await } } @@ -228,7 +269,7 @@ pub(crate) async fn send_request( where W: Write, { - send_status_line( + raw::send_status_line( true, http11, method.map(|method| method.as_str()), @@ -249,7 +290,7 @@ where { let status_str: Option> = status.map(|status| status.try_into().unwrap()); - send_status_line( + raw::send_status_line( false, http11, status_str.as_ref().map(|status| status.as_str()), @@ -261,65 +302,135 @@ where pub(crate) async fn send_headers<'a, H, W>( headers: H, - output: W, -) -> Result> + carry_over_connection_type: Option, + request: bool, + http11: bool, + chunked_if_unspecified: bool, + mut output: W, +) -> Result<(ConnectionType, BodyType), Error> where W: Write, H: IntoIterator, { - send_raw_headers( + let (headers_connection_type, headers_body_type) = raw::send_headers( headers .into_iter() .map(|(name, value)| (*name, value.as_bytes())), + &mut output, + ) + .await?; + + send_headers_end( + headers_connection_type, + headers_body_type, + carry_over_connection_type, + request, + http11, + chunked_if_unspecified, output, ) .await } -pub(crate) async fn send_raw_headers<'a, H, W>( - headers: H, +async fn send_headers_end( + headers_connection_type: Option, + headers_body_type: Option, + carry_over_connection_type: Option, + request: bool, + http11: bool, + chunked_if_unspecified: bool, mut output: W, -) -> Result> +) -> Result<(ConnectionType, BodyType), Error> where W: Write, - H: IntoIterator, { - let mut body = BodyType::Unknown; + let connection_type = + ConnectionType::resolve(headers_connection_type, carry_over_connection_type, http11)?; - for (name, value) in headers.into_iter() { - if body == BodyType::Unknown { - body = BodyType::from_header(name, unsafe { str::from_utf8_unchecked(value) }); - } + let body_type = BodyType::resolve( + headers_body_type, + connection_type, + request, + http11, + chunked_if_unspecified, + )?; - output.write_all(name.as_bytes()).await.map_err(Error::Io)?; - output.write_all(b": ").await.map_err(Error::Io)?; - output.write_all(value).await.map_err(Error::Io)?; - output.write_all(b"\r\n").await.map_err(Error::Io)?; + if headers_connection_type.is_none() { + // Send an explicit Connection-Type just in case + let (name, value) = connection_type.raw_header(); + + raw::send_header(name, value, &mut output).await?; } - Ok(body) -} + if headers_body_type.is_none() { + let mut buf = heapless::String::new(); -pub(crate) async fn send_headers_end(mut output: W) -> Result<(), Error> -where - W: Write, -{ - output.write_all(b"\r\n").await.map_err(Error::Io) + if let Some((name, value)) = body_type.raw_header(&mut buf) { + // Send explicit body type header just in case or if the body type was upgraded + raw::send_header(name, value, &mut output).await?; + } + } + + raw::send_headers_end(output).await?; + + Ok((connection_type, body_type)) } impl<'b, const N: usize> Headers<'b, N> { - pub(crate) async fn send(&self, output: W) -> Result> + fn resolve( + &self, + carry_over_connection_type: Option, + request: bool, + http11: bool, + ) -> Result<(ConnectionType, BodyType), Error> { + let headers_connection_type = ConnectionType::from_headers(self.iter()); + let headers_body_type = BodyType::from_headers(self.iter()); + + let connection_type = + ConnectionType::resolve(headers_connection_type, carry_over_connection_type, http11)?; + let body_type = + BodyType::resolve(headers_body_type, connection_type, request, http11, false)?; + + Ok((connection_type, body_type)) + } + + async fn send( + &self, + carry_over_connection_type: Option, + request: bool, + http11: bool, + chunked_if_unspecified: bool, + mut output: W, + ) -> Result<(ConnectionType, BodyType), Error> where W: Write, { - send_raw_headers(self.iter_raw(), output).await + let (headers_connection_type, headers_body_type) = + raw::send_headers(self.iter_raw(), &mut output).await?; + + send_headers_end( + headers_connection_type, + headers_body_type, + carry_over_connection_type, + request, + http11, + chunked_if_unspecified, + output, + ) + .await } } +/// Represents an incoming HTTP request stream body +/// +/// Implements the `Read` trait to read the body from the stream #[allow(private_interfaces)] pub enum Body<'b, R> { - Close(PartiallyRead<'b, R>), + /// The body is raw and should be read as is (only possible for HTTP responses with connection = Close) + Raw(PartiallyRead<'b, R>), + /// The body is of a known length (Content-Length) ContentLen(ContentLenRead>), + /// The body is chunked (Transfer-Encoding: chunked) Chunked(ChunkedRead<'b, PartiallyRead<'b, R>>), } @@ -327,6 +438,13 @@ impl<'b, R> Body<'b, R> where R: Read, { + /// Create a new body + /// + /// Parameters: + /// - `body_type`: The type of the body, as resolved using `BodyType::resolve` + /// - `buf`: The buffer to use for reading the body + /// - `read_len`: The length of the buffer that has already been read when processing the icoming headers + /// - `input`: The raw input stream pub fn new(body_type: BodyType, buf: &'b mut [u8], read_len: usize, input: R) -> Self { match body_type { BodyType::Chunked => Body::Chunked(ChunkedRead::new( @@ -338,33 +456,37 @@ where content_len, PartiallyRead::new(&buf[..read_len], input), )), - BodyType::Close => Body::Close(PartiallyRead::new(&buf[..read_len], input)), - BodyType::Unknown => Body::ContentLen(ContentLenRead::new( - 0, - PartiallyRead::new(&buf[..read_len], input), - )), + BodyType::Raw => Body::Raw(PartiallyRead::new(&buf[..read_len], input)), } } + /// Check if the body needs to be closed (i.e. the underlying input stream cannot be re-used for Keep-Alive connections) + pub fn needs_close(&self) -> bool { + !self.is_complete() || matches!(self, Self::Raw(_)) + } + + /// Check if the body has been completely read pub fn is_complete(&self) -> bool { match self { - Self::Close(_) => true, + Self::Raw(_) => true, Self::ContentLen(r) => r.is_complete(), Self::Chunked(r) => r.is_complete(), } } + /// Return a mutable reference to the underlying raw reader pub fn as_raw_reader(&mut self) -> &mut R { match self { - Self::Close(r) => &mut r.input, + Self::Raw(r) => &mut r.input, Self::ContentLen(r) => &mut r.input.input, Self::Chunked(r) => &mut r.input.input, } } + /// Release the body, returning the underlying raw reader pub fn release(self) -> R { match self { - Self::Close(r) => r.release(), + Self::Raw(r) => r.release(), Self::ContentLen(r) => r.release().release(), Self::Chunked(r) => r.release().release(), } @@ -384,7 +506,7 @@ where { async fn read(&mut self, buf: &mut [u8]) -> Result { match self { - Self::Close(read) => Ok(read.read(buf).await.map_err(Error::Io)?), + Self::Raw(read) => Ok(read.read(buf).await.map_err(Error::Io)?), Self::ContentLen(read) => Ok(read.read(buf).await?), Self::Chunked(read) => Ok(read.read(buf).await?), } @@ -722,10 +844,16 @@ where } } +/// Represents an outgoing HTTP request stream body +/// +/// Implements the `Write` trait to write the body to the stream #[allow(private_interfaces)] pub enum SendBody { - Close(W), + /// The body is raw and should be written as is (only possible for HTTP responses with connection = Close) + Raw(W), + /// The body is of a known length (Content-Length) ContentLen(ContentLenWrite), + /// The body is chunked (Transfer-Encoding: chunked) Chunked(ChunkedWrite), } @@ -733,17 +861,22 @@ impl SendBody where W: Write, { + /// Create a new body + /// + /// Parameters: + /// - `body_type`: The type of the body, as resolved using `BodyType::resolve` + /// - `output`: The raw output stream pub fn new(body_type: BodyType, output: W) -> SendBody { match body_type { BodyType::Chunked => SendBody::Chunked(ChunkedWrite::new(output)), BodyType::ContentLen(content_len) => { SendBody::ContentLen(ContentLenWrite::new(content_len, output)) } - BodyType::Close => SendBody::Close(output), - BodyType::Unknown => SendBody::ContentLen(ContentLenWrite::new(0, output)), + BodyType::Raw => SendBody::Raw(output), } } + /// Check if the body has been completely written to pub fn is_complete(&self) -> bool { match self { Self::ContentLen(w) => w.is_complete(), @@ -751,16 +884,18 @@ where } } + /// Check if the body needs to be closed (i.e. the underlying output stream cannot be re-used for Keep-Alive connections) pub fn needs_close(&self) -> bool { - !self.is_complete() || matches!(self, Self::Close(_)) + !self.is_complete() || matches!(self, Self::Raw(_)) } + /// Finish writing the body (necessary for chunked encoding) pub async fn finish(&mut self) -> Result<(), Error> where W: Write, { match self { - Self::Close(_) => (), + Self::Raw(_) => (), Self::ContentLen(w) => { if !w.is_complete() { return Err(Error::IncompleteBody); @@ -774,17 +909,19 @@ where Ok(()) } + /// Return a mutable reference to the underlying raw writer pub fn as_raw_writer(&mut self) -> &mut W { match self { - Self::Close(w) => w, + Self::Raw(w) => w, Self::ContentLen(w) => &mut w.output, Self::Chunked(w) => &mut w.output, } } + /// Release the body, returning the underlying raw writer pub fn release(self) -> W { match self { - Self::Close(w) => w, + Self::Raw(w) => w, Self::ContentLen(w) => w.release(), Self::Chunked(w) => w.release(), } @@ -804,7 +941,7 @@ where { async fn write(&mut self, buf: &[u8]) -> Result { match self { - Self::Close(w) => Ok(w.write(buf).await.map_err(Error::Io)?), + Self::Raw(w) => Ok(w.write(buf).await.map_err(Error::Io)?), Self::ContentLen(w) => Ok(w.write(buf).await?), Self::Chunked(w) => Ok(w.write(buf).await?), } @@ -812,7 +949,7 @@ where async fn flush(&mut self) -> Result<(), Self::Error> { match self { - Self::Close(w) => Ok(w.flush().await.map_err(Error::Io)?), + Self::Raw(w) => Ok(w.flush().await.map_err(Error::Io)?), Self::ContentLen(w) => Ok(w.flush().await?), Self::Chunked(w) => Ok(w.flush().await?), } @@ -941,37 +1078,93 @@ where } } -async fn read_reply_buf( - mut input: R, - buf: &mut [u8], - request: bool, - exact: bool, -) -> Result<(usize, usize), Error> -where - R: Read, -{ - if exact { - let raw_headers_len = read_headers(&mut input, buf).await?; +mod raw { + use core::str; + + use embedded_io_async::{Read, Write}; + + use log::warn; + + use crate::{BodyType, ConnectionType}; + + use super::Error; + + pub(crate) async fn read_reply_buf( + mut input: R, + buf: &mut [u8], + request: bool, + exact: bool, + ) -> Result<(usize, usize), Error> + where + R: Read, + { + if exact { + let raw_headers_len = read_headers(&mut input, buf).await?; - let mut headers = [httparse::EMPTY_HEADER; N]; + let mut headers = [httparse::EMPTY_HEADER; N]; - let status = if request { - httparse::Request::new(&mut headers).parse(&buf[..raw_headers_len])? + let status = if request { + httparse::Request::new(&mut headers).parse(&buf[..raw_headers_len])? + } else { + httparse::Response::new(&mut headers).parse(&buf[..raw_headers_len])? + }; + + if let httparse::Status::Complete(headers_len) = status { + return Ok((raw_headers_len, headers_len)); + } + + Err(Error::TooManyHeaders) } else { - httparse::Response::new(&mut headers).parse(&buf[..raw_headers_len])? - }; + let mut offset = 0; + let mut size = 0; + + while buf.len() > size { + let read = input.read(&mut buf[offset..]).await.map_err(Error::Io)?; + if read == 0 { + Err(if offset == 0 { + Error::ConnectionClosed + } else { + Error::IncompleteHeaders + })?; + } + + offset += read; + size += read; + + let mut headers = [httparse::EMPTY_HEADER; N]; + + let status = if request { + httparse::Request::new(&mut headers).parse(&buf[..size])? + } else { + httparse::Response::new(&mut headers).parse(&buf[..size])? + }; - if let httparse::Status::Complete(headers_len) = status { - return Ok((raw_headers_len, headers_len)); + if let httparse::Status::Complete(headers_len) = status { + return Ok((size, headers_len)); + } + } + + Err(Error::TooManyHeaders) } + } - Err(Error::TooManyHeaders) - } else { + pub(crate) async fn read_headers( + mut input: R, + buf: &mut [u8], + ) -> Result> + where + R: Read, + { let mut offset = 0; - let mut size = 0; + let mut byte = [0]; + + loop { + if offset == buf.len() { + Err(Error::TooLongHeaders)?; + } + + let read = input.read(&mut byte).await.map_err(Error::Io)?; - while buf.len() > size { - let read = input.read(&mut buf[offset..]).await.map_err(Error::Io)?; if read == 0 { Err(if offset == 0 { Error::ConnectionClosed @@ -980,122 +1173,146 @@ where })?; } - offset += read; - size += read; + buf[offset] = byte[0]; - let mut headers = [httparse::EMPTY_HEADER; N]; + offset += 1; - let status = if request { - httparse::Request::new(&mut headers).parse(&buf[..size])? - } else { - httparse::Response::new(&mut headers).parse(&buf[..size])? - }; - - if let httparse::Status::Complete(headers_len) = status { - return Ok((size, headers_len)); + if offset >= b"\r\n\r\n".len() && buf[offset - 4..offset] == *b"\r\n\r\n" { + break Ok(offset); } } - - Err(Error::TooManyHeaders) } -} -async fn read_headers(mut input: R, buf: &mut [u8]) -> Result> -where - R: Read, -{ - let mut offset = 0; - let mut byte = [0]; + pub(crate) async fn send_status_line( + request: bool, + http11: bool, + token: Option<&str>, + extra: Option<&str>, + mut output: W, + ) -> Result<(), Error> + where + W: Write, + { + let mut written = false; - loop { - if offset == buf.len() { - Err(Error::TooLongHeaders)?; + if !request { + send_version(&mut output, http11).await?; + written = true; } - let read = input.read(&mut byte).await.map_err(Error::Io)?; + if let Some(token) = token { + if written { + output.write_all(b" ").await.map_err(Error::Io)?; + } - if read == 0 { - Err(if offset == 0 { - Error::ConnectionClosed - } else { - Error::IncompleteHeaders - })?; + output + .write_all(token.as_bytes()) + .await + .map_err(Error::Io)?; + + written = true; } - buf[offset] = byte[0]; + if let Some(extra) = extra { + if written { + output.write_all(b" ").await.map_err(Error::Io)?; + } - offset += 1; + output + .write_all(extra.as_bytes()) + .await + .map_err(Error::Io)?; - if offset >= b"\r\n\r\n".len() && buf[offset - 4..offset] == *b"\r\n\r\n" { - break Ok(offset); + written = true; } - } -} -async fn send_status_line( - request: bool, - http11: bool, - token: Option<&str>, - extra: Option<&str>, - mut output: W, -) -> Result<(), Error> -where - W: Write, -{ - let mut written = false; - - if !request { - send_version(&mut output, http11).await?; - written = true; - } + if request { + if written { + output.write_all(b" ").await.map_err(Error::Io)?; + } - if let Some(token) = token { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; + send_version(&mut output, http11).await?; } - output - .write_all(token.as_bytes()) - .await - .map_err(Error::Io)?; + output.write_all(b"\r\n").await.map_err(Error::Io)?; - written = true; + Ok(()) } - if let Some(extra) = extra { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; - } - + pub(crate) async fn send_version(mut output: W, http11: bool) -> Result<(), Error> + where + W: Write, + { output - .write_all(extra.as_bytes()) + .write_all(if http11 { b"HTTP/1.1" } else { b"HTTP/1.0" }) .await - .map_err(Error::Io)?; - - written = true; + .map_err(Error::Io) } - if request { - if written { - output.write_all(b" ").await.map_err(Error::Io)?; + pub(crate) async fn send_headers<'a, H, W>( + headers: H, + mut output: W, + ) -> Result<(Option, Option), Error> + where + W: Write, + H: IntoIterator, + { + let mut connection = None; + let mut body = None; + + for (name, value) in headers.into_iter() { + let header_connection = + ConnectionType::from_header(name, unsafe { str::from_utf8_unchecked(value) }); + + if let Some(header_connection) = header_connection { + if let Some(connection) = connection { + warn!("Multiple Connection headers found. Current {connection} and new {header_connection}"); + } + + // The last connection header wins + connection = Some(header_connection); + } + + let header_body = + BodyType::from_header(name, unsafe { str::from_utf8_unchecked(value) }); + + if let Some(header_body) = header_body { + if let Some(body) = body { + warn!("Multiple body type headers found. Current {body} and new {header_body}"); + } + + // The last body header wins + body = Some(header_body); + } + + send_header(name, value, &mut output).await?; } - send_version(&mut output, http11).await?; + Ok((connection, body)) } - output.write_all(b"\r\n").await.map_err(Error::Io)?; + pub(crate) async fn send_header( + name: &str, + value: &[u8], + mut output: W, + ) -> Result<(), Error> + where + W: Write, + { + output.write_all(name.as_bytes()).await.map_err(Error::Io)?; + output.write_all(b": ").await.map_err(Error::Io)?; + output.write_all(value).await.map_err(Error::Io)?; + output.write_all(b"\r\n").await.map_err(Error::Io)?; - Ok(()) -} + Ok(()) + } -async fn send_version(mut output: W, http11: bool) -> Result<(), Error> -where - W: Write, -{ - output - .write_all(if http11 { b"HTTP/1.1" } else { b"HTTP/1.0" }) - .await - .map_err(Error::Io) + pub(crate) async fn send_headers_end(mut output: W) -> Result<(), Error> + where + W: Write, + { + output.write_all(b"\r\n").await.map_err(Error::Io) + } } #[cfg(test)] diff --git a/edge-http/src/io/client.rs b/edge-http/src/io/client.rs index b93f2aa..6578540 100644 --- a/edge-http/src/io/client.rs +++ b/edge-http/src/io/client.rs @@ -8,12 +8,10 @@ use edge_nal::TcpConnect; use crate::{ ws::{upgrade_request_headers, MAX_BASE64_KEY_LEN, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN}, - DEFAULT_MAX_HEADERS_COUNT, + ConnectionType, DEFAULT_MAX_HEADERS_COUNT, }; -use super::{ - send_headers, send_headers_end, send_request, Body, BodyType, Error, ResponseHeaders, SendBody, -}; +use super::{send_headers, send_request, Body, Error, ResponseHeaders, SendBody}; #[allow(unused_imports)] #[cfg(feature = "embedded-svc")] @@ -23,6 +21,7 @@ use super::Method; const COMPLETION_BUF_SIZE: usize = 64; +/// A client connection that can be used to send HTTP requests and receive responses. #[allow(private_interfaces)] pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT> where @@ -38,6 +37,12 @@ impl<'b, T, const N: usize> Connection<'b, T, N> where T: TcpConnect, { + /// Create a new client connection. + /// + /// Parameters: + /// - `buf`: A buffer to use for reading and writing data. + /// - `socket`: The TCP stack to use for the connection. + /// - `addr`: The address of the server to connect to. pub fn new(buf: &'b mut [u8], socket: &'b T, addr: SocketAddr) -> Self { Self::Unbound(UnboundState { buf, @@ -47,6 +52,7 @@ where }) } + /// Reinitialize the connection with a new address. pub async fn reinitialize(&mut self, addr: SocketAddr) -> Result<(), Error> { let _ = self.complete().await; self.unbound_mut().unwrap().addr = addr; @@ -54,6 +60,7 @@ where Ok(()) } + /// Initiate an HTTP request. pub async fn initiate_request( &mut self, http11: bool, @@ -64,18 +71,7 @@ where self.start_request(http11, method, uri, headers).await } - pub fn is_request_initiated(&self) -> bool { - matches!(self, Self::Request(_)) - } - - pub async fn initiate_response(&mut self) -> Result<(), Error> { - self.complete_request().await - } - - pub fn is_response_initiated(&self) -> bool { - matches!(self, Self::Response(_)) - } - + /// A utility method to initiate a WebSocket upgrade request. pub async fn initiate_ws_upgrade_request( &mut self, host: Option<&str>, @@ -91,6 +87,24 @@ where .await } + /// Return `true` if a request has been initiated. + pub fn is_request_initiated(&self) -> bool { + matches!(self, Self::Request(_)) + } + + /// Initiate an HTTP response. + /// + /// This should be called after a request has been initiated and the request body had been sent. + pub async fn initiate_response(&mut self) -> Result<(), Error> { + self.complete_request().await + } + + /// Return `true` if a response has been initiated. + pub fn is_response_initiated(&self) -> bool { + matches!(self, Self::Response(_)) + } + + /// Return `true` if the server accepted the WebSocket upgrade request. pub fn is_ws_upgrade_accepted( &self, nonce: &[u8; NONCE_LEN], @@ -99,6 +113,9 @@ where Ok(self.headers()?.is_ws_upgrade_accepted(nonce, buf)) } + /// Split the connection into its headers and body parts. + /// + /// The connection must be in response mode. #[allow(clippy::type_complexity)] pub fn split(&mut self) -> (&ResponseHeaders<'b, N>, &mut Body<'b, T::Socket<'b>>) { let response = self.response_mut().expect("Not in response mode"); @@ -106,16 +123,23 @@ where (&response.response, &mut response.io) } + /// Get the headers of the response. + /// + /// The connection must be in response mode. pub fn headers(&self) -> Result<&ResponseHeaders<'b, N>, Error> { let response = self.response_ref()?; Ok(&response.response) } + /// Get a mutable reference to the raw connection. + /// + /// This can be used to send raw data over the connection. pub fn raw_connection(&mut self) -> Result<&mut T::Socket<'b>, Error> { Ok(self.io_mut()) } + /// Release the connection, returning the raw connection and the buffer. pub fn release(mut self) -> (T::Socket<'b>, &'b mut [u8]) { let mut state = self.unbind(); @@ -162,19 +186,17 @@ where let io = state.io.as_mut().unwrap(); - let body_type = send_headers(headers, &mut *io).await?; - send_headers_end(io).await?; - - Ok(body_type) + send_headers(headers, None, true, http11, true, &mut *io).await } .await; match result { - Ok(body_type) => { + Ok((connection_type, body_type)) => { *self = Self::Request(RequestState { buf: state.buf, socket: state.socket, addr: state.addr, + connection_type, io: SendBody::new(body_type, state.io.unwrap()), }); @@ -189,34 +211,45 @@ where } } + /// Complete the request-response cycle + /// + /// If the request has not been initiated, this method will do nothing. + /// If the response has not been initiated, it will be initiated and will be consumed. pub async fn complete(&mut self) -> Result<(), Error> { let result = async { if self.request_mut().is_ok() { self.complete_request().await?; } - if self.response_mut().is_ok() { - self.complete_response().await?; - } + let needs_close = if self.response_mut().is_ok() { + self.complete_response().await? + } else { + true + }; - Result::<(), Error>::Ok(()) + Result::<_, Error>::Ok(needs_close) } .await; let mut state = self.unbind(); - if result.is_err() { - state.io = None; - } + match result { + Ok(true) | Err(_) => state.io = None, + _ => (), + }; *self = Self::Unbound(state); - result + result?; + + Ok(()) } async fn complete_request(&mut self) -> Result<(), Error> { self.request_mut()?.io.finish().await?; + let request_connection_type = self.request_mut()?.connection_type; + let mut state = self.unbind(); let buf_ptr: *mut [u8] = state.buf; @@ -227,18 +260,17 @@ where .await { Ok((buf, read_len)) => { - let io = Body::new( - BodyType::from_headers(response.headers.iter()), - buf, - read_len, - state.io.unwrap(), - ); + let (connection_type, body_type) = + response.resolve::(request_connection_type)?; + + let io = Body::new(body_type, buf, read_len, state.io.unwrap()); *self = Self::Response(ResponseState { buf: buf_ptr, response, socket: state.socket, addr: state.addr, + connection_type, io, }); @@ -255,7 +287,7 @@ where } } - async fn complete_response(&mut self) -> Result<(), Error> { + async fn complete_response(&mut self) -> Result> { if self.request_mut().is_ok() { self.complete_request().await?; } @@ -265,9 +297,19 @@ where let mut buf = [0; COMPLETION_BUF_SIZE]; while response.io.read(&mut buf).await? > 0 {} + let needs_close = response.needs_close(); + *self = Self::Unbound(self.unbind()); - Ok(()) + Ok(needs_close) + } + + /// Return `true` if the connection needs to be closed (i.e. the server has requested it or the connection is in an invalid state) + pub fn needs_close(&self) -> bool { + match self { + Self::Response(response) => response.needs_close(), + _ => true, + } } fn unbind(&mut self) -> UnboundState<'b, T, N> { @@ -391,6 +433,7 @@ where buf: &'b mut [u8], socket: &'b T, addr: SocketAddr, + connection_type: ConnectionType, io: SendBody>, } @@ -402,9 +445,19 @@ where response: ResponseHeaders<'b, N>, socket: &'b T, addr: SocketAddr, + connection_type: ConnectionType, io: Body<'b, T::Socket<'b>>, } +impl ResponseState<'_, T, N> +where + T: TcpConnect, +{ + fn needs_close(&self) -> bool { + matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close() + } +} + #[cfg(feature = "embedded-svc")] mod embedded_svc_compat { use super::*; diff --git a/edge-http/src/io/server.rs b/edge-http/src/io/server.rs index 2239f34..9476669 100644 --- a/edge-http/src/io/server.rs +++ b/edge-http/src/io/server.rs @@ -11,12 +11,10 @@ use embedded_io_async::{ErrorType, Read, Write}; use log::{debug, info, warn}; -use super::{ - send_headers, send_headers_end, send_status, Body, BodyType, Error, RequestHeaders, SendBody, -}; +use super::{send_headers, send_status, Body, Error, RequestHeaders, SendBody}; use crate::ws::{upgrade_response_headers, MAX_BASE64_KEY_RESPONSE_LEN}; -use crate::DEFAULT_MAX_HEADERS_COUNT; +use crate::{ConnectionType, DEFAULT_MAX_HEADERS_COUNT}; #[allow(unused_imports)] #[cfg(feature = "embedded-svc")] @@ -28,6 +26,7 @@ pub const DEFAULT_TIMEOUT_MS: u32 = 5000; const COMPLETION_BUF_SIZE: usize = 64; +/// A connection state machine for handling HTTP server requests-response cycles. #[allow(private_interfaces)] pub enum Connection<'b, T, const N: usize = DEFAULT_MAX_HEADERS_COUNT> { Transition(TransitionState), @@ -40,6 +39,12 @@ impl<'b, T, const N: usize> Connection<'b, T, N> where T: Read + Write, { + /// Create a new connection state machine for an incoming request + /// + /// Parameters: + /// - `buf`: A buffer to store the request headers + /// - `io`: A socket stream + /// - `timeout_ms`: An optional timeout in milliseconds to wait for a new incoming request pub async fn new( buf: &'b mut [u8], mut io: T, @@ -61,34 +66,47 @@ where }? }; - let io = Body::new( - BodyType::from_headers(request.headers.iter()), - buf, - read_len, - io, - ); + let (connection_type, body_type) = request.resolve::()?; - Ok(Self::Request(RequestState { request, io })) + let io = Body::new(body_type, buf, read_len, io); + + Ok(Self::Request(RequestState { + request, + io, + connection_type, + })) } + /// Return `true` of the connection is in request state (i.e. the initial state upon calling `new`) pub fn is_request_initiated(&self) -> bool { matches!(self, Self::Request(_)) } + /// Split the connection into request headers and body pub fn split(&mut self) -> (&RequestHeaders<'b, N>, &mut Body<'b, T>) { let req = self.request_mut().expect("Not in request mode"); (&req.request, &mut req.io) } + /// Return a reference to the request headers pub fn headers(&self) -> Result<&RequestHeaders<'b, N>, Error> { Ok(&self.request_ref()?.request) } + /// Return `true` if the request is a WebSocket upgrade request pub fn is_ws_upgrade_request(&self) -> Result> { Ok(self.headers()?.is_ws_upgrade_request()) } + /// Switch the connection into a response state + /// + /// Parameters: + /// - `status`: The HTTP status code + /// - `message`: An optional HTTP status message + /// - `headers`: An array of HTTP response headers. + /// Note that if no `Content-Length` or `Transfer-Encoding` headers are provided, + /// the body will be send with chunked encoding (for HTTP1.1 only and if the connection is not Close) pub async fn initiate_response( &mut self, status: u16, @@ -98,6 +116,7 @@ where self.complete_request(Some(status), message, headers).await } + /// A convenience method to initiate a WebSocket upgrade response pub async fn initiate_ws_upgrade_response( &mut self, buf: &mut [u8; MAX_BASE64_KEY_RESPONSE_LEN], @@ -107,10 +126,13 @@ where self.initiate_response(101, None, &headers).await } + /// Return `true` if the connection is in response state pub fn is_response_initiated(&self) -> bool { matches!(self, Self::Response(_)) } + /// Completes the response and switches the connection back to the unbound state + /// If the connection is still in a request state, and empty 200 OK response is sent pub async fn complete(&mut self) -> Result<(), Error> { if self.is_request_initiated() { self.complete_request(Some(200), Some("OK"), &[]).await?; @@ -123,6 +145,9 @@ where Ok(()) } + /// Completes the response with an error message and switches the connection back to the unbound state + /// + /// If the connection is still in a request state, an empty 500 Internal Error response is sent pub async fn complete_err(&mut self, err: &str) -> Result<(), Error> { let result = self.request_mut(); @@ -135,8 +160,8 @@ where let response = self.response_mut()?; - response.write_all(err.as_bytes()).await?; - response.finish().await?; + response.io.write_all(err.as_bytes()).await?; + response.io.finish().await?; Ok(()) } @@ -144,6 +169,9 @@ where } } + /// Return `true` if the connection needs to be closed + /// + /// This is determined by the connection type (i.e. `Connection: Close` header) pub fn needs_close(&self) -> bool { match self { Self::Response(response) => response.needs_close(), @@ -151,6 +179,9 @@ where } } + /// Switch the connection to unbound state, returning a mutable reference to the underlying socket stream + /// + /// NOTE: Use with care, and only if the connection is completed in the meantime pub fn unbind(&mut self) -> Result<&mut T, Error> { let io = self.unbind_mut(); *self = Self::Unbound(io); @@ -170,42 +201,33 @@ where while request.io.read(&mut buf).await? > 0 {} let http11 = request.request.http11.unwrap_or(false); + let request_connection_type = request.connection_type; let mut io = self.unbind_mut(); let result = async { send_status(http11, status, reason, &mut io).await?; - let mut body_type = send_headers( - headers.iter().filter(|(k, v)| { - http11 - || !k.eq_ignore_ascii_case("Transfer-Encoding") - || !v.eq_ignore_ascii_case("Chunked") - }), + + let (connection_type, body_type) = send_headers( + headers.iter(), + Some(request_connection_type), + false, + http11, + true, &mut io, ) .await?; - if matches!(body_type, BodyType::Unknown) { - if http11 { - send_headers(&[("Transfer-Encoding", "Chunked")], &mut io).await?; - body_type = BodyType::Chunked; - } else { - body_type = BodyType::Close; - } - }; - - send_headers_end(&mut io).await?; - - Ok(body_type) + Ok((connection_type, body_type)) } .await; match result { - Ok(body_type) => { - *self = Self::Response(SendBody::new( - if http11 { body_type } else { BodyType::Close }, - io, - )); + Ok((connection_type, body_type)) => { + *self = Self::Response(ResponseState { + io: SendBody::new(body_type, io), + connection_type, + }); Ok(()) } @@ -218,7 +240,7 @@ where } async fn complete_response(&mut self) -> Result<(), Error> { - self.response_mut()?.finish().await?; + self.response_mut()?.io.finish().await?; Ok(()) } @@ -228,7 +250,7 @@ where match state { Self::Request(request) => request.io.release(), - Self::Response(response) => response.release(), + Self::Response(response) => response.io.release(), Self::Unbound(io) => io, _ => unreachable!(), } @@ -250,7 +272,7 @@ where } } - fn response_mut(&mut self) -> Result<&mut SendBody, Error> { + fn response_mut(&mut self) -> Result<&mut ResponseState, Error> { if let Self::Response(response) = self { Ok(response) } else { @@ -261,7 +283,7 @@ where fn io_mut(&mut self) -> &mut T { match self { Self::Request(request) => request.io.as_raw_reader(), - Self::Response(response) => response.as_raw_writer(), + Self::Response(response) => response.io.as_raw_writer(), Self::Unbound(io) => io, _ => unreachable!(), } @@ -289,11 +311,11 @@ where T: Read + Write, { async fn write(&mut self, buf: &[u8]) -> Result { - self.response_mut()?.write(buf).await + self.response_mut()?.io.write(buf).await } async fn flush(&mut self) -> Result<(), Self::Error> { - self.response_mut()?.flush().await + self.response_mut()?.io.flush().await } } @@ -302,10 +324,24 @@ struct TransitionState(()); struct RequestState<'b, T, const N: usize> { request: RequestHeaders<'b, N>, io: Body<'b, T>, + connection_type: ConnectionType, +} + +struct ResponseState { + io: SendBody, + connection_type: ConnectionType, } -type ResponseState = SendBody; +impl ResponseState +where + T: Write, +{ + fn needs_close(&self) -> bool { + matches!(self.connection_type, ConnectionType::Close) || self.io.needs_close() + } +} +/// A trait (async callback) for handling incoming HTTP requests pub trait Handler<'b, T, const N: usize> where T: Read + Write, @@ -327,6 +363,10 @@ where } } +/// A trait (async callback) for handling a single HTTP request +/// +/// The only difference between this and `Handler` is that this trait has an additional `task_id` parameter, +/// which is used for logging purposes pub trait TaskHandler<'b, T, const N: usize> where T: Read + Write, @@ -356,9 +396,11 @@ where } } +/// A type that adapts a `Handler` into a `TaskHandler` pub struct TaskHandlerAdaptor(H); impl TaskHandlerAdaptor { + /// Create a new `TaskHandlerAdaptor` from a `Handler` pub const fn new(handler: H) -> Self { Self(handler) } @@ -386,6 +428,11 @@ where } } +/// A convenience function to handle multiple HTTP requests over a single socket stream, +/// using the specified handler. +/// +/// The socket stream will be closed only in case of error, or until the client explicitly requests that +/// either with a hard socket close, or with a `Connection: Close` header. pub async fn handle_connection( io: T, buf: &mut [u8], @@ -398,6 +445,11 @@ pub async fn handle_connection( handle_task_connection(io, buf, timeout_ms, 0, TaskHandlerAdaptor::new(handler)).await } +/// A convenience function to handle multiple HTTP requests over a single socket stream, +/// using the specified task handler. +/// +/// The socket stream will be closed only in case of error, or until the client explicitly requests that +/// either with a hard socket close, or with a `Connection: Close` header. pub async fn handle_task_connection( mut io: T, buf: &mut [u8], @@ -439,9 +491,12 @@ pub async fn handle_task_connection( } } +/// The error type for handling HTTP requests #[derive(Debug)] pub enum HandleRequestError { + /// A connection error (HTTP protocol error or a socket IO error) Connection(Error), + /// A handler error Handler(E), } @@ -472,6 +527,8 @@ where { } +/// A convenience function to handle a single HTTP request over a socket stream, +/// using the specified handler. pub async fn handle_request<'b, const N: usize, H, T>( buf: &'b mut [u8], io: T, @@ -485,6 +542,8 @@ where handle_task_request(buf, io, 0, timeout_ms, TaskHandlerAdaptor::new(handler)).await } +/// A convenience function to handle a single HTTP request over a socket stream, +/// using the specified task handler. pub async fn handle_task_request<'b, const N: usize, H, T>( buf: &'b mut [u8], io: T, @@ -511,11 +570,16 @@ where Ok(connection.needs_close()) } +/// A type alias for an HTTP server with default buffer sizes. pub type DefaultServer = Server<{ DEFAULT_HANDLER_TASKS_COUNT }, { DEFAULT_BUF_SIZE }, { DEFAULT_MAX_HEADERS_COUNT }>; +/// A type alias for the HTTP server buffers (essentially, arrays of `MaybeUninit`) pub type ServerBuffers = MaybeUninit<[[u8; B]; P]>; +/// An HTTP server that can handle multiple requests concurrently. +/// +/// The server needs an implementation of `edge_nal::TcpAccept` to accept incoming connections. #[repr(transparent)] pub struct Server< const P: usize = DEFAULT_HANDLER_TASKS_COUNT, @@ -524,11 +588,13 @@ pub struct Server< >(ServerBuffers); impl Server { + /// Create a new HTTP server #[inline(always)] pub const fn new() -> Self { Self(MaybeUninit::uninit()) } + /// Run the server with the specified acceptor and handler #[inline(never)] #[cold] pub async fn run( @@ -594,6 +660,7 @@ impl Server { result } + /// Run the server with the specified acceptor and task handler #[inline(never)] #[cold] pub async fn run_with_task_id( diff --git a/edge-http/src/lib.rs b/edge-http/src/lib.rs index 3c9dc01..bdb77c6 100644 --- a/edge-http/src/lib.rs +++ b/edge-http/src/lib.rs @@ -2,10 +2,11 @@ #![allow(async_fn_in_trait)] #![warn(clippy::large_futures)] -use core::fmt::Display; +use core::fmt::{self, Display}; use core::str; use httparse::{Header, EMPTY_HEADER}; +use log::{debug, warn}; use ws::{is_upgrade_accepted, is_upgrade_request, MAX_BASE64_KEY_RESPONSE_LEN, NONCE_LEN}; pub const DEFAULT_MAX_HEADERS_COUNT: usize = 64; @@ -13,6 +14,35 @@ pub const DEFAULT_MAX_HEADERS_COUNT: usize = 64; #[cfg(feature = "io")] pub mod io; +/// Errors related to invalid combinations of connection type +/// and body type (Content-Length, Transfer-Encoding) in the headers +#[derive(Debug)] +pub enum HeadersMismatchError { + /// Connection type mismatch: Keep-Alive connection type in the response, + /// while the request contained a Close connection type + ResponseConnectionTypeMismatchError, + /// Body type mismatch: the body type in the headers cannot be used with the specified connection type and HTTP protocol. + /// This is often a user-error, but might also come from the other peer not following the protocol. + /// I.e.: + /// - Chunked body with an HTTP1.0 connection + /// - Raw body with a Keep-Alive connection + /// - etc. + BodyTypeError(&'static str), +} + +impl Display for HeadersMismatchError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::ResponseConnectionTypeMismatchError => write!( + f, + "Response connection type is different from the request connection type" + ), + Self::BodyTypeError(s) => write!(f, "Body type mismatch: {s}"), + } + } +} + +/// Http methods #[derive(Copy, Clone, Debug, PartialEq, Eq)] #[cfg_attr(feature = "std", derive(Hash))] pub enum Method { @@ -164,58 +194,70 @@ impl Method { } impl Display for Method { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.as_str()) } } +/// HTTP headers #[derive(Debug)] pub struct Headers<'b, const N: usize = 64>([httparse::Header<'b>; N]); impl<'b, const N: usize> Headers<'b, N> { + /// Create a new Headers instance #[inline(always)] pub const fn new() -> Self { Self([httparse::EMPTY_HEADER; N]) } + /// Utility method to return the value of the `Content-Length` header, if present pub fn content_len(&self) -> Option { self.get("Content-Length") .map(|content_len_str| content_len_str.parse::().unwrap()) } + /// Utility method to return the value of the `Content-Type` header, if present pub fn content_type(&self) -> Option<&str> { self.get("Content-Type") } + /// Utility method to return the value of the `Content-Encoding` header, if present pub fn content_encoding(&self) -> Option<&str> { self.get("Content-Encoding") } + /// Utility method to return the value of the `Transfer-Encoding` header, if present pub fn transfer_encoding(&self) -> Option<&str> { self.get("Transfer-Encoding") } + /// Utility method to return the value of the `Host` header, if present pub fn host(&self) -> Option<&str> { self.get("Host") } + /// Utility method to return the value of the `Connection` header, if present pub fn connection(&self) -> Option<&str> { self.get("Connection") } + /// Utility method to return the value of the `Cache-Control` header, if present pub fn cache_control(&self) -> Option<&str> { self.get("Cache-Control") } + /// Utility method to return the value of the `Upgrade` header, if present pub fn upgrade(&self) -> Option<&str> { self.get("Upgrade") } + /// Iterate over all headers pub fn iter(&self) -> impl Iterator { self.iter_raw() .map(|(name, value)| (name, unsafe { str::from_utf8_unchecked(value) })) } + /// Iterate over all headers, returning the values as raw byte slices pub fn iter_raw(&self) -> impl Iterator { self.0 .iter() @@ -223,22 +265,26 @@ impl<'b, const N: usize> Headers<'b, N> { .map(|header| (header.name, header.value)) } + /// Get the value of a header by name pub fn get(&self, name: &str) -> Option<&str> { self.iter() .find(|(hname, _)| name.eq_ignore_ascii_case(hname)) .map(|(_, value)| value) } + /// Get the raw value of a header by name, returning the value as a raw byte slice pub fn get_raw(&self, name: &str) -> Option<&[u8]> { self.iter_raw() .find(|(hname, _)| name.eq_ignore_ascii_case(hname)) .map(|(_, value)| value) } + /// Set a header by name and value pub fn set(&mut self, name: &'b str, value: &'b str) -> &mut Self { self.set_raw(name, value.as_bytes()) } + /// Set a header by name and value, using a raw byte slice for the value pub fn set_raw(&mut self, name: &'b str, value: &'b [u8]) -> &mut Self { if !name.is_empty() { for header in &mut self.0 { @@ -254,6 +300,7 @@ impl<'b, const N: usize> Headers<'b, N> { } } + /// Remove a header by name pub fn remove(&mut self, name: &str) -> &mut Self { let index = self .0 @@ -274,6 +321,7 @@ impl<'b, const N: usize> Headers<'b, N> { self } + /// A utility method to set the `Content-Length` header pub fn set_content_len( &mut self, content_len: u64, @@ -284,58 +332,73 @@ impl<'b, const N: usize> Headers<'b, N> { self.set("Content-Length", buf.as_str()) } + /// A utility method to set the `Content-Type` header pub fn set_content_type(&mut self, content_type: &'b str) -> &mut Self { self.set("Content-Type", content_type) } + /// A utility method to set the `Content-Encoding` header pub fn set_content_encoding(&mut self, content_encoding: &'b str) -> &mut Self { self.set("Content-Encoding", content_encoding) } + /// A utility method to set the `Transfer-Encoding` header pub fn set_transfer_encoding(&mut self, transfer_encoding: &'b str) -> &mut Self { self.set("Transfer-Encoding", transfer_encoding) } + /// A utility method to set the `Transfer-Encoding: Chunked` header pub fn set_transfer_encoding_chunked(&mut self) -> &mut Self { self.set_transfer_encoding("Chunked") } + /// A utility method to set the `Host` header pub fn set_host(&mut self, host: &'b str) -> &mut Self { self.set("Host", host) } + /// A utility method to set the `Connection` header pub fn set_connection(&mut self, connection: &'b str) -> &mut Self { self.set("Connection", connection) } + /// A utility method to set the `Connection: Close` header pub fn set_connection_close(&mut self) -> &mut Self { self.set_connection("Close") } + /// A utility method to set the `Connection: Keep-Alive` header pub fn set_connection_keep_alive(&mut self) -> &mut Self { self.set_connection("Keep-Alive") } + /// A utility method to set the `Connection: Upgrade` header pub fn set_connection_upgrade(&mut self) -> &mut Self { self.set_connection("Upgrade") } + /// A utility method to set the `Cache-Control` header pub fn set_cache_control(&mut self, cache: &'b str) -> &mut Self { self.set("Cache-Control", cache) } + /// A utility method to set the `Cache-Control: No-Cache` header pub fn set_cache_control_no_cache(&mut self) -> &mut Self { self.set_cache_control("No-Cache") } + /// A utility method to set the `Upgrade` header pub fn set_upgrade(&mut self, upgrade: &'b str) -> &mut Self { self.set("Upgrade", upgrade) } + /// A utility method to set the `Upgrade: websocket` header pub fn set_upgrade_websocket(&mut self) -> &mut Self { self.set_upgrade("websocket") } + /// A utility method to set all Websocket upgrade request headers, + /// including the `Sec-WebSocket-Key` header with the base64-encoded nonce pub fn set_ws_upgrade_request_headers( &mut self, host: Option<&'b str>, @@ -351,6 +414,8 @@ impl<'b, const N: usize> Headers<'b, N> { self } + /// A utility method to set all Websocket upgrade response headers + /// including the `Sec-WebSocket-Accept` header with the base64-encoded response pub fn set_ws_upgrade_response_headers<'a, H>( &mut self, request_headers: H, @@ -374,54 +439,285 @@ impl Default for Headers<'_, N> { } } -#[derive(Copy, Clone, Eq, PartialEq, Debug)] +/// Connection type +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] +pub enum ConnectionType { + KeepAlive, + Close, +} + +impl ConnectionType { + /// Resolve the connection type + /// + /// Resolution is based on: + /// - The connection type found in the headers, if any + /// - (if the above is missing) based on the carry-over connection type, if any + /// - (if the above is missing) based on the HTTP version + /// + /// Parameters: + /// - `headers_connection_type`: The connection type found in the headers, if any + /// - `carry_over_connection_type`: The carry-over connection type + /// (i.e. if this is a response, the `carry_over_connection_type` is the connection type of the request) + /// - `http11`: Whether the HTTP protocol is 1.1 + pub fn resolve( + headers_connection_type: Option, + carry_over_connection_type: Option, + http11: bool, + ) -> Result { + match headers_connection_type { + Some(connection_type) => { + if let Some(carry_over_connection_type) = carry_over_connection_type { + if matches!(connection_type, ConnectionType::KeepAlive) + && matches!(carry_over_connection_type, ConnectionType::Close) + { + warn!("Cannot set a Keep-Alive connection when the peer requested Close"); + Err(HeadersMismatchError::ResponseConnectionTypeMismatchError)?; + } + } + + Ok(connection_type) + } + None => { + if let Some(carry_over_connection_type) = carry_over_connection_type { + Ok(carry_over_connection_type) + } else if http11 { + Ok(Self::KeepAlive) + } else { + Ok(Self::Close) + } + } + } + } + + /// Create a connection type from a header + /// + /// If the header is not a `Connection` header, this method returns `None` + pub fn from_header(name: &str, value: &str) -> Option { + if "Connection".eq_ignore_ascii_case(name) && value.eq_ignore_ascii_case("Close") { + Some(Self::Close) + } else if "Connection".eq_ignore_ascii_case(name) + && value.eq_ignore_ascii_case("Keep-Alive") + { + Some(Self::KeepAlive) + } else { + None + } + } + + /// Create a connection type from headers + /// + /// If multiple `Connection` headers are found, this method logs a warning and returns the last one + /// If no `Connection` headers are found, this method returns `None` + pub fn from_headers<'a, H>(headers: H) -> Option + where + H: IntoIterator, + { + let mut connection = None; + + for (name, value) in headers { + let header_connection = Self::from_header(name, value); + + if let Some(header_connection) = header_connection { + if let Some(connection) = connection { + warn!("Multiple Connection headers found. Current {connection} and new {header_connection}"); + } + + // The last connection header wins + connection = Some(header_connection); + } + } + + connection + } + + /// Create a raw header from the connection type + pub fn raw_header(&self) -> (&str, &[u8]) { + let connection = match self { + Self::KeepAlive => "Keep-Alive", + Self::Close => "Close", + }; + + ("Connection", connection.as_bytes()) + } +} + +impl Display for ConnectionType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::KeepAlive => write!(f, "Keep-Alive"), + Self::Close => write!(f, "Close"), + } + } +} + +/// Body type +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)] pub enum BodyType { + /// Chunked body (Transfer-Encoding: Chunked) Chunked, + /// Content-length body (Content-Length: {len}) ContentLen(u64), - Close, - Unknown, + /// Raw body - can only be used with responses, where the connection type is `Close` + Raw, } impl BodyType { - pub fn from_header(name: &str, value: &str) -> Self { + /// Resolve the body type + /// + /// Resolution is based on: + /// - The body type found in the headers (i.e. `Content-Length` and/or `Transfer-Encoding`), if any + /// - (if the above is missing) based on the resolved connection type, HTTP protocol and whether we are dealing with a request or a response + /// + /// Parameters: + /// - `headers_body_type`: The body type found in the headers, if any + /// - `connection_type`: The resolved connection type + /// - `request`: Whether we are dealing with a request or a response + /// - `http11`: Whether the HTTP protocol is 1.1 + /// - `chunked_if_unspecified`: (HTTP1.1 only) Upgrades the body type to Chunked if requested so and if no body was specified in the headers + pub fn resolve( + headers_body_type: Option, + connection_type: ConnectionType, + request: bool, + http11: bool, + chunked_if_unspecified: bool, + ) -> Result { + match headers_body_type { + Some(headers_body_type) => { + match headers_body_type { + BodyType::Raw => { + if request { + warn!("Raw body in a request. This is not allowed."); + Err(HeadersMismatchError::BodyTypeError( + "Raw body in a request. This is not allowed.", + ))?; + } else if !matches!(connection_type, ConnectionType::Close) { + warn!("Raw body response with a Keep-Alive connection. This is not allowed."); + Err(HeadersMismatchError::BodyTypeError("Raw body response with a Keep-Alive connection. This is not allowed."))?; + } + } + BodyType::Chunked => { + if !http11 { + warn!("Chunked body with an HTTP/1.0 connection. This is not allowed."); + Err(HeadersMismatchError::BodyTypeError( + "Chunked body with an HTTP/1.0 connection. This is not allowed.", + ))?; + } + } + _ => {} + } + + Ok(headers_body_type) + } + None => { + if request { + if chunked_if_unspecified && http11 { + // With HTTP1.1 we can safely upgrade the body to a chunked one + Ok(BodyType::Chunked) + } else { + debug!("Unknown body type in a request. Assuming Content-Length=0."); + Ok(BodyType::ContentLen(0)) + } + } else if matches!(connection_type, ConnectionType::Close) { + Ok(BodyType::Raw) + } else if chunked_if_unspecified && http11 { + // With HTTP1.1 we can safely upgrade the body to a chunked one + Ok(BodyType::Chunked) + } else { + warn!("Unknown body type in a response with a Keep-Alive connection. This is not allowed."); + Err(HeadersMismatchError::BodyTypeError("Unknown body type in a response with a Keep-Alive connection. This is not allowed.")) + } + } + } + } + + /// Create a body type from a header + /// + /// If the header is not a `Content-Length` or `Transfer-Encoding` header, this method returns `None` + pub fn from_header(name: &str, value: &str) -> Option { if "Transfer-Encoding".eq_ignore_ascii_case(name) { if value.eq_ignore_ascii_case("Chunked") { - return Self::Chunked; + return Some(Self::Chunked); } } else if "Content-Length".eq_ignore_ascii_case(name) { - return Self::ContentLen(value.parse::().unwrap()); // TODO - } else if "Connection".eq_ignore_ascii_case(name) && value.eq_ignore_ascii_case("Close") { - return Self::Close; + return Some(Self::ContentLen(value.parse::().unwrap())); // TODO } - Self::Unknown + None } - pub fn from_headers<'a, H>(headers: H) -> Self + /// Create a body type from headers + /// + /// If multiple body type headers are found, this method logs a warning and returns the last one + /// If no body type headers are found, this method returns `None` + pub fn from_headers<'a, H>(headers: H) -> Option where H: IntoIterator, { + let mut body = None; + for (name, value) in headers { - let body = Self::from_header(name, value); + let header_body = Self::from_header(name, value); + + if let Some(header_body) = header_body { + if let Some(body) = body { + warn!("Multiple body type headers found. Current {body} and new {header_body}"); + } - if body != Self::Unknown { - return body; + // The last body header wins + body = Some(header_body); } } - Self::Unknown + body + } + + /// Create a raw header from the body type + /// + /// If the body type is `Raw`, this method returns `None` as a raw body cannot be + /// represented in a header and is rather, a consequence of using connection type `Close` + /// with HTTP server responses + pub fn raw_header<'a>(&self, buf: &'a mut heapless::String<20>) -> Option<(&str, &'a [u8])> { + match self { + Self::Chunked => Some(("Transfer-Encoding", "Chunked".as_bytes())), + Self::ContentLen(len) => { + use core::fmt::Write; + + buf.clear(); + + write!(buf, "{}", len).unwrap(); + + Some(("Content-Length", buf.as_bytes())) + } + Self::Raw => None, + } } } +impl Display for BodyType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Chunked => write!(f, "Chunked"), + Self::ContentLen(len) => write!(f, "Content-Length: {len}"), + Self::Raw => write!(f, "Raw"), + } + } +} + +/// Request headers including the request line (method, path) #[derive(Default, Debug)] pub struct RequestHeaders<'b, const N: usize> { + /// Whether the request is HTTP/1.1, if present. If not present, HTTP/1.0 should be assumed pub http11: Option, + /// The HTTP method, if present pub method: Option, + /// The request path, if present pub path: Option<&'b str>, + /// The headers pub headers: Headers<'b, N>, } impl RequestHeaders<'_, N> { + /// Create a new RequestHeaders instance for HTTP/1.1 #[inline(always)] pub const fn new() -> Self { Self { @@ -432,13 +728,14 @@ impl RequestHeaders<'_, N> { } } + /// A utility method to check if the request is a Websocket upgrade request pub fn is_ws_upgrade_request(&self) -> bool { is_upgrade_request(self.method, self.headers.iter()) } } impl Display for RequestHeaders<'_, N> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(http11) = self.http11 { write!(f, "{} ", if http11 { "HTTP/1.1" } else { "HTTP/1.0" })?; } @@ -459,15 +756,21 @@ impl Display for RequestHeaders<'_, N> { } } +/// Response headers including the response line (HTTP version, status code, reason phrase) #[derive(Default, Debug)] pub struct ResponseHeaders<'b, const N: usize> { + /// Whether the response is HTTP/1.1, if present. If not present, HTTP/1.0 should be assumed pub http11: Option, + /// The status code, if present pub code: Option, + /// The reason phrase, if present pub reason: Option<&'b str>, + /// The headers pub headers: Headers<'b, N>, } impl ResponseHeaders<'_, N> { + /// Create a new ResponseHeaders instance for HTTP/1.1 #[inline(always)] pub const fn new() -> Self { Self { @@ -478,6 +781,8 @@ impl ResponseHeaders<'_, N> { } } + /// A utility method to check if the response is a Websocket upgrade response + /// and if the upgrade was accepted pub fn is_ws_upgrade_accepted( &self, nonce: &[u8; NONCE_LEN], @@ -488,7 +793,7 @@ impl ResponseHeaders<'_, N> { } impl Display for ResponseHeaders<'_, N> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if let Some(http11) = self.http11 { writeln!(f, "{} ", if http11 { "HTTP/1.1 " } else { "HTTP/1.0" })?; } @@ -509,6 +814,7 @@ impl Display for ResponseHeaders<'_, N> { } } +/// Websocket utilities pub mod ws { use core::fmt; @@ -523,6 +829,14 @@ pub mod ws { pub const UPGRADE_REQUEST_HEADERS_LEN: usize = 7; pub const UPGRADE_RESPONSE_HEADERS_LEN: usize = 4; + /// Return ready-to-use WS upgrade request headers + /// + /// Parameters: + /// - `host`: The `Host` header, if present + /// - `origin`: The `Origin` header, if present + /// - `version`: The `Sec-WebSocket-Version` header, if present; otherwise version "13" is assumed + /// - `nonce`: The nonce to use for the `Sec-WebSocket-Key` header + /// - `buf`: A buffer to use for base64 encoding the nonce pub fn upgrade_request_headers<'a>( host: Option<&'a str>, origin: Option<&'a str>, @@ -544,6 +858,7 @@ pub mod ws { ] } + /// Check if the request is a Websocket upgrade request pub fn is_upgrade_request<'a, H>(method: Option, request_headers: H) -> bool where H: IntoIterator, @@ -566,10 +881,14 @@ pub mod ws { connection && upgrade } + /// Websocket upgrade errors #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum UpgradeError { + /// No `Sec-WebSocket-Version` header NoVersion, + /// No `Sec-WebSocket-Key` header NoSecKey, + /// Unsupported `Sec-WebSocket-Version` UnsupportedVersion, } @@ -586,6 +905,12 @@ pub mod ws { #[cfg(feature = "std")] impl std::error::Error for UpgradeError {} + /// Return ready-to-use WS upgrade response headers + /// + /// Parameters: + /// - `request_headers`: The request headers + /// - `version`: The `Sec-WebSocket-Version` header, if present; otherwise version "13" is assumed + /// - `buf`: A buffer to use for base64 encoding bits and pieces of the response pub fn upgrade_response_headers<'a, 'b, H>( request_headers: H, version: Option<&'a str>, @@ -627,6 +952,13 @@ pub mod ws { } } + /// Check if the response is a Websocket upgrade response and if the upgrade was accepted + /// + /// Parameters: + /// - `code`: The status response code + /// - `response_headers`: The response headers + /// - `nonce`: The nonce used for the `Sec-WebSocket-Key` header in the WS upgrade request + /// - `buf`: A buffer to use when performing the check pub fn is_upgrade_accepted<'a, H>( code: Option, response_headers: H, @@ -670,6 +1002,7 @@ pub mod ws { unsafe { core::str::from_utf8_unchecked(&buf[..nonce_base64_len]) } } + /// Compute the response for a given `Sec-WebSocket-Key` pub fn sec_key_response<'a>( sec_key: &str, buf: &'a mut [u8; MAX_BASE64_KEY_RESPONSE_LEN], @@ -703,7 +1036,10 @@ pub mod ws { #[cfg(test)] mod test { - use crate::ws::{sec_key_response, MAX_BASE64_KEY_RESPONSE_LEN}; + use crate::{ + ws::{sec_key_response, MAX_BASE64_KEY_RESPONSE_LEN}, + BodyType, ConnectionType, + }; #[test] fn test_resp() { @@ -712,6 +1048,275 @@ mod test { assert_eq!(resp, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="); } + + #[test] + fn test_resolve_conn() { + // Default connection type resolution + assert_eq!( + ConnectionType::resolve(None, None, true).unwrap(), + ConnectionType::KeepAlive + ); + assert_eq!( + ConnectionType::resolve(None, None, false).unwrap(), + ConnectionType::Close + ); + + // Connection type resolution based on carry-over (for responses) + assert_eq!( + ConnectionType::resolve(None, Some(ConnectionType::KeepAlive), false).unwrap(), + ConnectionType::KeepAlive + ); + assert_eq!( + ConnectionType::resolve(None, Some(ConnectionType::KeepAlive), true).unwrap(), + ConnectionType::KeepAlive + ); + + // Connection type resoluton based on the header value + assert_eq!( + ConnectionType::resolve(Some(ConnectionType::Close), None, false).unwrap(), + ConnectionType::Close + ); + assert_eq!( + ConnectionType::resolve(Some(ConnectionType::KeepAlive), None, false).unwrap(), + ConnectionType::KeepAlive + ); + assert_eq!( + ConnectionType::resolve(Some(ConnectionType::Close), None, true).unwrap(), + ConnectionType::Close + ); + assert_eq!( + ConnectionType::resolve(Some(ConnectionType::KeepAlive), None, true).unwrap(), + ConnectionType::KeepAlive + ); + + // Connection type in the headers should aggree with the carry-over one + assert_eq!( + ConnectionType::resolve( + Some(ConnectionType::Close), + Some(ConnectionType::Close), + false + ) + .unwrap(), + ConnectionType::Close + ); + assert_eq!( + ConnectionType::resolve( + Some(ConnectionType::KeepAlive), + Some(ConnectionType::KeepAlive), + false + ) + .unwrap(), + ConnectionType::KeepAlive + ); + assert_eq!( + ConnectionType::resolve( + Some(ConnectionType::Close), + Some(ConnectionType::Close), + true + ) + .unwrap(), + ConnectionType::Close + ); + assert_eq!( + ConnectionType::resolve( + Some(ConnectionType::KeepAlive), + Some(ConnectionType::KeepAlive), + true + ) + .unwrap(), + ConnectionType::KeepAlive + ); + assert_eq!( + ConnectionType::resolve( + Some(ConnectionType::Close), + Some(ConnectionType::KeepAlive), + false + ) + .unwrap(), + ConnectionType::Close + ); + assert!(ConnectionType::resolve( + Some(ConnectionType::KeepAlive), + Some(ConnectionType::Close), + false + ) + .is_err()); + assert_eq!( + ConnectionType::resolve( + Some(ConnectionType::Close), + Some(ConnectionType::KeepAlive), + true + ) + .unwrap(), + ConnectionType::Close + ); + assert!(ConnectionType::resolve( + Some(ConnectionType::KeepAlive), + Some(ConnectionType::Close), + true + ) + .is_err()); + } + + #[test] + fn test_resolve_body() { + // Request with no body type specified means Content-Length=0 + assert_eq!( + BodyType::resolve(None, ConnectionType::KeepAlive, true, true, false).unwrap(), + BodyType::ContentLen(0) + ); + assert_eq!( + BodyType::resolve(None, ConnectionType::Close, true, true, false).unwrap(), + BodyType::ContentLen(0) + ); + assert_eq!( + BodyType::resolve(None, ConnectionType::KeepAlive, true, false, false).unwrap(), + BodyType::ContentLen(0) + ); + assert_eq!( + BodyType::resolve(None, ConnectionType::Close, true, false, false).unwrap(), + BodyType::ContentLen(0) + ); + + // Request or response with a chunked body type is invalid for HTTP1.0 + assert!(BodyType::resolve( + Some(BodyType::Chunked), + ConnectionType::Close, + true, + false, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Chunked), + ConnectionType::KeepAlive, + true, + false, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Chunked), + ConnectionType::Close, + false, + false, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Chunked), + ConnectionType::KeepAlive, + false, + false, + false + ) + .is_err()); + + // Raw body in a request is not allowed + assert!(BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::Close, + true, + true, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::KeepAlive, + true, + true, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::Close, + true, + false, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::KeepAlive, + true, + false, + false + ) + .is_err()); + + // Raw body in a response with a Keep-Alive connection is not allowed + assert!(BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::KeepAlive, + false, + true, + false + ) + .is_err()); + assert!(BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::KeepAlive, + false, + false, + false + ) + .is_err()); + + // The same, but with a Close connection IS allowed + assert_eq!( + BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::Close, + false, + true, + false + ) + .unwrap(), + BodyType::Raw + ); + assert_eq!( + BodyType::resolve( + Some(BodyType::Raw), + ConnectionType::Close, + false, + false, + false + ) + .unwrap(), + BodyType::Raw + ); + + // Request upgrades to chunked encoding should only work for HTTP1.1, and if there is no body type in the headers + assert_eq!( + BodyType::resolve(None, ConnectionType::Close, true, true, true).unwrap(), + BodyType::Chunked + ); + assert_eq!( + BodyType::resolve(None, ConnectionType::KeepAlive, true, true, true).unwrap(), + BodyType::Chunked + ); + assert_eq!( + BodyType::resolve(None, ConnectionType::Close, true, false, true).unwrap(), + BodyType::ContentLen(0) + ); + assert_eq!( + BodyType::resolve(None, ConnectionType::KeepAlive, true, false, true).unwrap(), + BodyType::ContentLen(0) + ); + + // Response upgrades to chunked encoding should only work for HTTP1.1, and if there is no body type in the headers, and if the connection is KeepAlive + assert_eq!( + BodyType::resolve(None, ConnectionType::KeepAlive, false, true, true).unwrap(), + BodyType::Chunked + ); + // Response upgrades should not be honored if the connection is Close + assert_eq!( + BodyType::resolve(None, ConnectionType::Close, false, true, true).unwrap(), + BodyType::Raw + ); + } } #[cfg(feature = "embedded-svc")]