Skip to content

Commit 05aa2b3

Browse files
authored
feat: use axum WS (#26)
* feat: use axum WS * lint: clippy * fix: dep spec better spec good now * feat: axum_ws tests * lint: clippy * nit: remove dead line * test: lower delay * chore: remove unused inports in examples * docs: expand them :) * fix: with_handle
1 parent 0499a49 commit 05aa2b3

File tree

11 files changed

+489
-77
lines changed

11 files changed

+489
-77
lines changed

Cargo.toml

+3-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ tracing = "0.1.41"
2626

2727
# axum
2828
axum = { version = "0.8.1", optional = true }
29-
mime = { version = "0.3.17", optional = true}
29+
mime = { version = "0.3.17", optional = true }
3030

3131
# pubsub
3232
tokio-stream = { version = "0.1.17", optional = true }
@@ -41,11 +41,12 @@ futures-util = { version = "0.3.31", optional = true }
4141
[dev-dependencies]
4242
tempfile = "3.15.0"
4343
tracing-subscriber = "0.3.19"
44+
axum = { version = "*", features = ["macros"] }
4445

4546
[features]
4647
default = ["axum", "ws", "ipc"]
4748
axum = ["dep:axum", "dep:mime"]
48-
pubsub = ["dep:tokio-stream"]
49+
pubsub = ["dep:tokio-stream", "axum?/ws"]
4950
ipc = ["pubsub", "dep:interprocess"]
5051
ws = ["pubsub", "dep:tokio-tungstenite", "dep:futures-util"]
5152

src/lib.rs

+30-4
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,36 @@
8686
//! # }}
8787
//! ```
8888
//!
89-
//! For WS and IPC connections, the `pubsub` module provides implementations of
90-
//! the `Connect` trait for [`std::net::SocketAddr`] to create simple WS
91-
//! servers, and [`interprocess::local_socket::ListenerOptions`] to create
92-
//! simple IPC servers.
89+
//! Routers can also be served over axum websockets. When both `axum` and
90+
//! `pubsub` features are enabled, the `pubsub` module provides
91+
//! [`pubsub::AxumWsCfg`] and the [`pubsub::ajj_websocket`] axum handler. This
92+
//! handler will serve the router over websockets at a specific route. The
93+
//! router is a property of the `AxumWsCfg` object, and is passed to the
94+
//! handler via axum's `State` extractor.
95+
//!
96+
//! ```no_run
97+
//! # #[cfg(all(feature = "axum", feature = "pubsub"))]
98+
//! # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
99+
//! # {
100+
//! # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
101+
//! // The config object contains the tokio runtime handle, and the
102+
//! // notification buffer size.
103+
//! let cfg = AxumWsCfg::new(router);
104+
//!
105+
//! axum
106+
//! .route("/ws", axum::routing::any(ajj_websocket))
107+
//! .with_state(cfg)
108+
//! # }}
109+
//! ```
110+
//!
111+
//! For IPC and non-axum WebSocket connections, the `pubsub` module provides
112+
//! implementations of the `Connect` trait for [`std::net::SocketAddr`] to
113+
//! create simple WS servers, and
114+
//! [`interprocess::local_socket::ListenerOptions`] to create simple IPC
115+
//! servers. We generally recommend using `axum` for WebSocket connections, as
116+
//! it provides a more complete and robust implementation, however, users
117+
//! needing additional control, or wanting to avoid the `axum` dependency
118+
//! can use the `pubsub` module directly.
93119
//!
94120
//! ```no_run
95121
//! # #[cfg(feature = "pubsub")]

src/pubsub/axum.rs

+301
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
//! WebSocket connection manager for [`axum`]
2+
//!
3+
//! How this works:
4+
//! `axum` does not provide a connection pattern that allows us to iplement
5+
//! [`Listener`] or [`Connect`] directly. Instead, it uses a
6+
//! [`WebSocketUpgrade`] to upgrade a connection to a WebSocket. This means
7+
//! that we cannot use the [`Listener`] trait directly. Instead, we make a
8+
//! [`AxumWsCfg`] that will be the [`State`] for our handler.
9+
//!
10+
//! The [`ajj_websocket`] handler serves the role of the [`Listener`] in this
11+
//! case.
12+
//!
13+
//! [`Connect`]: crate::pubsub::Connect
14+
15+
use crate::{
16+
pubsub::{shared::ConnectionManager, Listener},
17+
Router,
18+
};
19+
use axum::{
20+
extract::{
21+
ws::{Message, WebSocket},
22+
State, WebSocketUpgrade,
23+
},
24+
response::Response,
25+
};
26+
use bytes::Bytes;
27+
use futures_util::{
28+
stream::{SplitSink, SplitStream},
29+
SinkExt, Stream, StreamExt,
30+
};
31+
use serde_json::value::RawValue;
32+
use std::{
33+
convert::Infallible,
34+
pin::Pin,
35+
sync::Arc,
36+
task::{ready, Context, Poll},
37+
};
38+
use tokio::runtime::Handle;
39+
use tracing::debug;
40+
41+
pub(crate) type SendHalf = SplitSink<WebSocket, Message>;
42+
pub(crate) type RecvHalf = SplitStream<WebSocket>;
43+
44+
struct AxumListener;
45+
46+
impl Listener for AxumListener {
47+
type RespSink = SendHalf;
48+
49+
type ReqStream = WsJsonStream;
50+
51+
type Error = Infallible;
52+
53+
async fn accept(&self) -> Result<(Self::RespSink, Self::ReqStream), Self::Error> {
54+
unreachable!()
55+
}
56+
}
57+
58+
/// Configuration details for WebSocket connections using [`axum::extract::ws`].
59+
///
60+
/// The main points of configuration are:
61+
/// - The runtime [`Handle`] on which to execute tasks, which can be set with
62+
/// [`Self::with_handle`]. This defaults to the current thread's runtime
63+
/// handle.
64+
/// - The notification buffer size per client, which can be set with
65+
/// [`Self::with_notification_buffer_per_client`]. See the [`crate::pubsub`]
66+
/// module documentation for more details.
67+
///
68+
/// This struct is used as the [`State`] for the [`ajj_websocket`] handler, and
69+
/// should be created from a fully-configured [`Router<()>`].
70+
///
71+
/// # Note
72+
///
73+
/// If [`AxumWsCfg`] is NOT used within a `tokio` runtime,
74+
/// [`AxumWsCfg::with_handle`] MUST be called to set the runtime handle before
75+
/// any requests are routed. Attempting to execute a task without an active
76+
/// runtime will result in a panic.
77+
///
78+
/// # Example
79+
///
80+
/// ```no_run
81+
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
82+
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
83+
/// # {
84+
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>, handle: tokio::runtime::Handle) {
85+
/// let cfg = AxumWsCfg::from(router)
86+
/// .with_handle(handle)
87+
/// .with_notification_buffer_per_client(10);
88+
/// # }}
89+
/// ```
90+
#[derive(Clone)]
91+
pub struct AxumWsCfg {
92+
inner: Arc<ConnectionManager>,
93+
}
94+
95+
impl core::fmt::Debug for AxumWsCfg {
96+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
97+
f.debug_struct("AxumWsCfg")
98+
.field(
99+
"notification_buffer_per_client",
100+
&self.inner.notification_buffer_per_task,
101+
)
102+
.field("next_id", &self.inner.next_id)
103+
.finish()
104+
}
105+
}
106+
107+
impl From<Router<()>> for AxumWsCfg {
108+
fn from(router: Router<()>) -> Self {
109+
Self::new(router)
110+
}
111+
}
112+
113+
impl AxumWsCfg {
114+
/// Create a new [`AxumWsCfg`] with the given [`Router`].
115+
pub fn new(router: Router<()>) -> Self {
116+
Self {
117+
inner: ConnectionManager::new(router).into(),
118+
}
119+
}
120+
121+
fn into_inner(self) -> ConnectionManager {
122+
match Arc::try_unwrap(self.inner) {
123+
Ok(inner) => inner,
124+
Err(arc) => ConnectionManager {
125+
root_tasks: arc.root_tasks.clone(),
126+
next_id: arc.next_id.clone(),
127+
router: arc.router.clone(),
128+
notification_buffer_per_task: arc.notification_buffer_per_task,
129+
},
130+
}
131+
}
132+
133+
/// Set the handle on which to execute tasks.
134+
pub fn with_handle(self, handle: Handle) -> Self {
135+
Self {
136+
inner: self.into_inner().with_handle(handle).into(),
137+
}
138+
}
139+
140+
/// Set the notification buffer size per client. See the [`crate::pubsub`]
141+
/// module documentation for more details.
142+
pub fn with_notification_buffer_per_client(
143+
self,
144+
notification_buffer_per_client: usize,
145+
) -> Self {
146+
Self {
147+
inner: self
148+
.into_inner()
149+
.with_notification_buffer_per_client(notification_buffer_per_client)
150+
.into(),
151+
}
152+
}
153+
}
154+
155+
/// Axum handler for WebSocket connections.
156+
///
157+
/// Used to serve [`crate::Router`]s over WebSocket connections via [`axum`]'s
158+
/// built-in WebSocket support. This handler is used in conjunction with
159+
/// [`AxumWsCfg`], which is passed as the [`State`] to the handler.
160+
///
161+
/// # Examples
162+
///
163+
/// Basic usage:
164+
///
165+
/// ```no_run
166+
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
167+
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
168+
/// # {
169+
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
170+
/// // The config object contains the tokio runtime handle, and the
171+
/// // notification buffer size.
172+
/// let cfg = AxumWsCfg::new(router);
173+
///
174+
/// axum
175+
/// .route("/ws", axum::routing::any(ajj_websocket))
176+
/// .with_state(cfg)
177+
/// # }}
178+
/// ```
179+
///
180+
/// The [`Router`] is a property of the [`AxumWsCfg`]. This means it is not
181+
/// paramterized until the [`axum::Router::with_state`] method is called. This
182+
/// has two significant consequences:
183+
/// 1. You can easily register the same [`Router`] with multiple handlers.
184+
/// 2. In order to register a second [`Router`] you need a second [`AxumWsCfg`].
185+
///
186+
/// Registering the same [`Router`] with multiple handlers:
187+
///
188+
/// ```no_run
189+
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
190+
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
191+
/// # {
192+
/// # async fn _main(router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
193+
/// // The config object contains the tokio runtime handle, and the
194+
/// // notification buffer size.
195+
/// let cfg = AxumWsCfg::new(router);
196+
///
197+
/// axum
198+
/// .route("/ws", axum::routing::any(ajj_websocket))
199+
/// .route("/super-secret-ws", axum::routing::any(ajj_websocket))
200+
/// .with_state(cfg)
201+
/// # }}
202+
/// ```
203+
///
204+
/// Registering a second [`Router`] at a different path:
205+
///
206+
/// ```no_run
207+
/// # #[cfg(all(feature = "axum", feature = "pubsub"))]
208+
/// # use ajj::{Router, pubsub::{ajj_websocket, AxumWsCfg}};
209+
/// # {
210+
/// # async fn _main(router: Router<()>, other_router: Router<()>, axum: axum::Router<AxumWsCfg>) -> axum::Router<()>{
211+
/// // The config object contains the tokio runtime handle, and the
212+
/// // notification buffer size.
213+
/// let cfg = AxumWsCfg::new(router);
214+
/// let other_cfg = AxumWsCfg::new(other_router);
215+
///
216+
/// axum
217+
/// .route("/really-cool-ws-1", axum::routing::any(ajj_websocket))
218+
/// .with_state(cfg)
219+
/// .route("/even-cooler-ws-2", axum::routing::any(ajj_websocket))
220+
/// .with_state(other_cfg)
221+
/// # }}
222+
/// ```
223+
pub async fn ajj_websocket(ws: WebSocketUpgrade, State(state): State<AxumWsCfg>) -> Response {
224+
ws.on_upgrade(move |ws| {
225+
let (sink, stream) = ws.split();
226+
227+
state
228+
.inner
229+
.handle_new_connection::<AxumListener>(stream.into(), sink);
230+
231+
async {}
232+
})
233+
}
234+
235+
/// Simple stream adapter for extracting text from a [`WebSocket`].
236+
#[derive(Debug)]
237+
struct WsJsonStream {
238+
inner: RecvHalf,
239+
complete: bool,
240+
}
241+
242+
impl From<RecvHalf> for WsJsonStream {
243+
fn from(inner: RecvHalf) -> Self {
244+
Self {
245+
inner,
246+
complete: false,
247+
}
248+
}
249+
}
250+
251+
impl WsJsonStream {
252+
/// Handle an incoming [`Message`]
253+
fn handle(&self, message: Message) -> Result<Option<Bytes>, &'static str> {
254+
match message {
255+
Message::Text(text) => Ok(Some(text.into())),
256+
Message::Close(Some(frame)) => {
257+
let s = "Received close frame with data";
258+
let reason = format!("{} ({})", frame.reason, frame.code);
259+
debug!(%reason, "{}", &s);
260+
Err(s)
261+
}
262+
Message::Close(None) => {
263+
let s = "WS client has gone away";
264+
debug!("{}", &s);
265+
Err(s)
266+
}
267+
_ => Ok(None),
268+
}
269+
}
270+
}
271+
272+
impl Stream for WsJsonStream {
273+
type Item = Bytes;
274+
275+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
276+
loop {
277+
if self.complete {
278+
return Poll::Ready(None);
279+
}
280+
281+
let Some(Ok(msg)) = ready!(self.inner.poll_next_unpin(cx)) else {
282+
self.complete = true;
283+
return Poll::Ready(None);
284+
};
285+
286+
match self.handle(msg) {
287+
Ok(Some(item)) => return Poll::Ready(Some(item)),
288+
Ok(None) => continue,
289+
Err(_) => self.complete = true,
290+
}
291+
}
292+
}
293+
}
294+
295+
impl crate::pubsub::JsonSink for SendHalf {
296+
type Error = axum::Error;
297+
298+
async fn send_json(&mut self, json: Box<RawValue>) -> Result<(), Self::Error> {
299+
self.send(Message::text(json.get())).await
300+
}
301+
}

src/pubsub/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,8 @@ pub use r#trait::{Connect, In, JsonReqStream, JsonSink, Listener, Out};
105105

106106
#[cfg(feature = "ws")]
107107
mod ws;
108+
109+
#[cfg(feature = "axum")]
110+
mod axum;
111+
#[cfg(feature = "axum")]
112+
pub use axum::{ajj_websocket, AxumWsCfg};

0 commit comments

Comments
 (0)