|
| 1 | +//! WebSocket connection manager for [`axum`] |
| 2 | +//! |
| 3 | +//! How this works: |
| 4 | +//! `axum` does not provide a connection pattern that allows us to iplement |
| 5 | +//! [`Listener`] or [`Connect`] directly. Instead, it uses a |
| 6 | +//! [`WebSocketUpgrade`] to upgrade a connection to a WebSocket. This means |
| 7 | +//! that we cannot use the [`Listener`] trait directly. Instead, we make a |
| 8 | +//! [`AxumWsCfg`] that will be the [`State`] for our handler. |
| 9 | +//! |
| 10 | +//! The [`ajj_websocket`] handler serves the role of the [`Listener`] in this |
| 11 | +//! case. |
| 12 | +//! |
| 13 | +//! [`Connect`]: crate::pubsub::Connect |
| 14 | +
|
| 15 | +use crate::{ |
| 16 | + pubsub::{shared::ConnectionManager, Listener}, |
| 17 | + Router, |
| 18 | +}; |
| 19 | +use axum::{ |
| 20 | + extract::{ |
| 21 | + ws::{Message, WebSocket}, |
| 22 | + State, WebSocketUpgrade, |
| 23 | + }, |
| 24 | + response::Response, |
| 25 | +}; |
| 26 | +use bytes::Bytes; |
| 27 | +use futures_util::{ |
| 28 | + stream::{SplitSink, SplitStream}, |
| 29 | + SinkExt, Stream, StreamExt, |
| 30 | +}; |
| 31 | +use serde_json::value::RawValue; |
| 32 | +use std::{ |
| 33 | + convert::Infallible, |
| 34 | + pin::Pin, |
| 35 | + sync::Arc, |
| 36 | + task::{ready, Context, Poll}, |
| 37 | +}; |
| 38 | +use tokio::runtime::Handle; |
| 39 | +use tracing::debug; |
| 40 | + |
| 41 | +pub(crate) type SendHalf = SplitSink<WebSocket, Message>; |
| 42 | +pub(crate) type RecvHalf = SplitStream<WebSocket>; |
| 43 | + |
| 44 | +struct AxumListener; |
| 45 | + |
| 46 | +impl Listener for AxumListener { |
| 47 | + type RespSink = SendHalf; |
| 48 | + |
| 49 | + type ReqStream = WsJsonStream; |
| 50 | + |
| 51 | + type Error = Infallible; |
| 52 | + |
| 53 | + async fn accept(&self) -> Result<(Self::RespSink, Self::ReqStream), Self::Error> { |
| 54 | + unreachable!() |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +/// Configuration details for WebSocket connections using [`axum::extract::ws`]. |
| 59 | +/// |
| 60 | +/// The main points of configuration are: |
| 61 | +/// - The runtime [`Handle`] on which to execute tasks, which can be set with |
| 62 | +/// [`Self::with_handle`]. This defaults to the current thread's runtime |
| 63 | +/// handle. |
| 64 | +/// - The notification buffer size per client, which can be set with |
| 65 | +/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`] |
| 66 | +/// module documentation for more details. |
| 67 | +/// |
| 68 | +/// This struct is used as the [`State`] for the [`ajj_websocket`] handler, and |
| 69 | +/// should be created from a fully-configured [`Router<()>`]. |
| 70 | +/// |
| 71 | +/// # Note |
| 72 | +/// |
| 73 | +/// If [`AxumWsCfg`] is NOT used within a `tokio` runtime, |
| 74 | +/// [`AxumWsCfg::with_handle`] MUST be called to set the runtime handle before |
| 75 | +/// any requests are routed. Attempting to execute a task without an active |
| 76 | +/// runtime will result in a panic. |
| 77 | +/// |
| 78 | +/// # Example |
| 79 | +/// |
| 80 | +/// ```no_run |
| 81 | +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] |
| 82 | +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; |
| 83 | +/// # { |
| 84 | +/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>, handle: tokio::runtime::Handle) { |
| 85 | +/// let cfg = AxumWsCfg::from(router) |
| 86 | +/// .with_handle(handle) |
| 87 | +/// .with_notification_buffer_per_client(10); |
| 88 | +/// # }} |
| 89 | +/// ``` |
| 90 | +#[derive(Clone)] |
| 91 | +pub struct AxumWsCfg { |
| 92 | + inner: Arc<ConnectionManager>, |
| 93 | +} |
| 94 | + |
| 95 | +impl core::fmt::Debug for AxumWsCfg { |
| 96 | + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
| 97 | + f.debug_struct("AxumWsCfg") |
| 98 | + .field( |
| 99 | + "notification_buffer_per_client", |
| 100 | + &self.inner.notification_buffer_per_task, |
| 101 | + ) |
| 102 | + .field("next_id", &self.inner.next_id) |
| 103 | + .finish() |
| 104 | + } |
| 105 | +} |
| 106 | + |
| 107 | +impl From<Router<()>> for AxumWsCfg { |
| 108 | + fn from(router: Router<()>) -> Self { |
| 109 | + Self::new(router) |
| 110 | + } |
| 111 | +} |
| 112 | + |
| 113 | +impl AxumWsCfg { |
| 114 | + /// Create a new [`AxumWsCfg`] with the given [`Router`]. |
| 115 | + pub fn new(router: Router<()>) -> Self { |
| 116 | + Self { |
| 117 | + inner: ConnectionManager::new(router).into(), |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + fn into_inner(self) -> ConnectionManager { |
| 122 | + match Arc::try_unwrap(self.inner) { |
| 123 | + Ok(inner) => inner, |
| 124 | + Err(arc) => ConnectionManager { |
| 125 | + root_tasks: arc.root_tasks.clone(), |
| 126 | + next_id: arc.next_id.clone(), |
| 127 | + router: arc.router.clone(), |
| 128 | + notification_buffer_per_task: arc.notification_buffer_per_task, |
| 129 | + }, |
| 130 | + } |
| 131 | + } |
| 132 | + |
| 133 | + /// Set the handle on which to execute tasks. |
| 134 | + pub fn with_handle(self, handle: Handle) -> Self { |
| 135 | + Self { |
| 136 | + inner: self.into_inner().with_handle(handle).into(), |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + /// Set the notification buffer size per client. See the [`crate::pubsub`] |
| 141 | + /// module documentation for more details. |
| 142 | + pub fn with_notification_buffer_per_client( |
| 143 | + self, |
| 144 | + notification_buffer_per_client: usize, |
| 145 | + ) -> Self { |
| 146 | + Self { |
| 147 | + inner: self |
| 148 | + .into_inner() |
| 149 | + .with_notification_buffer_per_client(notification_buffer_per_client) |
| 150 | + .into(), |
| 151 | + } |
| 152 | + } |
| 153 | +} |
| 154 | + |
| 155 | +/// Axum handler for WebSocket connections. |
| 156 | +/// |
| 157 | +/// Used to serve [`crate::Router`]s over WebSocket connections via [`axum`]'s |
| 158 | +/// built-in WebSocket support. This handler is used in conjunction with |
| 159 | +/// [`AxumWsCfg`], which is passed as the [`State`] to the handler. |
| 160 | +/// |
| 161 | +/// # Examples |
| 162 | +/// |
| 163 | +/// Basic usage: |
| 164 | +/// |
| 165 | +/// ```no_run |
| 166 | +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] |
| 167 | +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; |
| 168 | +/// # { |
| 169 | +/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{ |
| 170 | +/// // The config object contains the tokio runtime handle, and the |
| 171 | +/// // notification buffer size. |
| 172 | +/// let cfg = AxumWsCfg::new(router); |
| 173 | +/// |
| 174 | +/// axum |
| 175 | +/// .route("/ws", axum::routing::any(ajj_websocket)) |
| 176 | +/// .with_state(cfg) |
| 177 | +/// # }} |
| 178 | +/// ``` |
| 179 | +/// |
| 180 | +/// The [`Router`] is a property of the [`AxumWsCfg`]. This means it is not |
| 181 | +/// paramterized until the [`axum::Router::with_state`] method is called. This |
| 182 | +/// has two significant consequences: |
| 183 | +/// 1. You can easily register the same [`Router`] with multiple handlers. |
| 184 | +/// 2. In order to register a second [`Router`] you need a second [`AxumWsCfg`]. |
| 185 | +/// |
| 186 | +/// Registering the same [`Router`] with multiple handlers: |
| 187 | +/// |
| 188 | +/// ```no_run |
| 189 | +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] |
| 190 | +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; |
| 191 | +/// # { |
| 192 | +/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{ |
| 193 | +/// // The config object contains the tokio runtime handle, and the |
| 194 | +/// // notification buffer size. |
| 195 | +/// let cfg = AxumWsCfg::new(router); |
| 196 | +/// |
| 197 | +/// axum |
| 198 | +/// .route("/ws", axum::routing::any(ajj_websocket)) |
| 199 | +/// .route("/super-secret-ws", axum::routing::any(ajj_websocket)) |
| 200 | +/// .with_state(cfg) |
| 201 | +/// # }} |
| 202 | +/// ``` |
| 203 | +/// |
| 204 | +/// Registering a second [`Router`] at a different path: |
| 205 | +/// |
| 206 | +/// ```no_run |
| 207 | +/// # #[cfg(all(feature = "axum", feature = "pubsub"))] |
| 208 | +/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}}; |
| 209 | +/// # { |
| 210 | +/// # async fn _main(router: Router<()>, other_router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{ |
| 211 | +/// // The config object contains the tokio runtime handle, and the |
| 212 | +/// // notification buffer size. |
| 213 | +/// let cfg = AxumWsCfg::new(router); |
| 214 | +/// let other_cfg = AxumWsCfg::new(other_router); |
| 215 | +/// |
| 216 | +/// axum |
| 217 | +/// .route("/really-cool-ws-1", axum::routing::any(ajj_websocket)) |
| 218 | +/// .with_state(cfg) |
| 219 | +/// .route("/even-cooler-ws-2", axum::routing::any(ajj_websocket)) |
| 220 | +/// .with_state(other_cfg) |
| 221 | +/// # }} |
| 222 | +/// ``` |
| 223 | +pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State<AxumWsCfg>) -> Response { |
| 224 | + ws.on_upgrade(move |ws| { |
| 225 | + let (sink, stream) = ws.split(); |
| 226 | + |
| 227 | + state |
| 228 | + .inner |
| 229 | + .handle_new_connection::<AxumListener>(stream.into(), sink); |
| 230 | + |
| 231 | + async {} |
| 232 | + }) |
| 233 | +} |
| 234 | + |
| 235 | +/// Simple stream adapter for extracting text from a [`WebSocket`]. |
| 236 | +#[derive(Debug)] |
| 237 | +struct WsJsonStream { |
| 238 | + inner: RecvHalf, |
| 239 | + complete: bool, |
| 240 | +} |
| 241 | + |
| 242 | +impl From<RecvHalf> for WsJsonStream { |
| 243 | + fn from(inner: RecvHalf) -> Self { |
| 244 | + Self { |
| 245 | + inner, |
| 246 | + complete: false, |
| 247 | + } |
| 248 | + } |
| 249 | +} |
| 250 | + |
| 251 | +impl WsJsonStream { |
| 252 | + /// Handle an incoming [`Message`] |
| 253 | + fn handle(&self, message: Message) -> Result<Option<Bytes>, &'static str> { |
| 254 | + match message { |
| 255 | + Message::Text(text) => Ok(Some(text.into())), |
| 256 | + Message::Close(Some(frame)) => { |
| 257 | + let s = "Received close frame with data"; |
| 258 | + let reason = format!("{} ({})", frame.reason, frame.code); |
| 259 | + debug!(%reason, "{}", &s); |
| 260 | + Err(s) |
| 261 | + } |
| 262 | + Message::Close(None) => { |
| 263 | + let s = "WS client has gone away"; |
| 264 | + debug!("{}", &s); |
| 265 | + Err(s) |
| 266 | + } |
| 267 | + _ => Ok(None), |
| 268 | + } |
| 269 | + } |
| 270 | +} |
| 271 | + |
| 272 | +impl Stream for WsJsonStream { |
| 273 | + type Item = Bytes; |
| 274 | + |
| 275 | + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 276 | + loop { |
| 277 | + if self.complete { |
| 278 | + return Poll::Ready(None); |
| 279 | + } |
| 280 | + |
| 281 | + let Some(Ok(msg)) = ready!(self.inner.poll_next_unpin(cx)) else { |
| 282 | + self.complete = true; |
| 283 | + return Poll::Ready(None); |
| 284 | + }; |
| 285 | + |
| 286 | + match self.handle(msg) { |
| 287 | + Ok(Some(item)) => return Poll::Ready(Some(item)), |
| 288 | + Ok(None) => continue, |
| 289 | + Err(_) => self.complete = true, |
| 290 | + } |
| 291 | + } |
| 292 | + } |
| 293 | +} |
| 294 | + |
| 295 | +impl crate::pubsub::JsonSink for SendHalf { |
| 296 | + type Error = axum::Error; |
| 297 | + |
| 298 | + async fn send_json(&mut self, json: Box<RawValue>) -> Result<(), Self::Error> { |
| 299 | + self.send(Message::text(json.get())).await |
| 300 | + } |
| 301 | +} |
0 commit comments