Skip to content

Commit 72e7533

Browse files
authored
test(logging): Add tests for logging (#96)
* test(logging): implement basic logging functionality * test(logging): add comprehensive server transport tests
1 parent 57f2ba2 commit 72e7533

File tree

3 files changed

+336
-5
lines changed

3 files changed

+336
-5
lines changed

crates/rmcp/Cargo.toml

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
31
[package]
42
name = "rmcp"
53
license = { workspace = true }
@@ -25,7 +23,6 @@ tracing = { version = "0.1" }
2523
tokio-util = { version = "0.7" }
2624
pin-project-lite = "0.2"
2725
paste = { version = "1", optional = true }
28-
2926
# for auto generate schema
3027
schemars = { version = "0.8", optional = true }
3128

@@ -103,3 +100,8 @@ name = "test_notification"
103100
required-features = ["server", "client"]
104101
path = "tests/test_notification.rs"
105102

103+
[[test]]
104+
name = "test_logging"
105+
required-features = ["server", "client"]
106+
path = "tests/test_logging.rs"
107+

crates/rmcp/src/model.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -610,8 +610,8 @@ pub type PromptListChangedNotification = NotificationNoParam<PromptListChangedNo
610610
const_string!(ToolListChangedNotificationMethod = "notifications/tools/list_changed");
611611
pub type ToolListChangedNotification = NotificationNoParam<ToolListChangedNotificationMethod>;
612612
// 日志相关
613-
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
614-
#[serde(rename_all = "camelCase")]
613+
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Copy)]
614+
#[serde(rename_all = "lowercase")] //match spec
615615
pub enum LoggingLevel {
616616
Debug,
617617
Info,

crates/rmcp/tests/test_logging.rs

+329
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
// cargo test --features "server client" --package rmcp test_logging
2+
use std::{
3+
future::Future,
4+
sync::{Arc, Mutex},
5+
};
6+
7+
use rmcp::{
8+
ClientHandler, Error as McpError, Peer, RoleClient, RoleServer, ServerHandler, ServiceExt,
9+
model::{
10+
LoggingLevel, LoggingMessageNotificationParam, ServerCapabilities, ServerInfo,
11+
SetLevelRequestParam,
12+
},
13+
service::RequestContext,
14+
};
15+
use tokio::sync::Notify;
16+
17+
pub struct LoggingClient {
18+
receive_signal: Arc<Notify>,
19+
received_messages: Arc<Mutex<Vec<LoggingMessageNotificationParam>>>,
20+
peer: Option<Peer<RoleClient>>,
21+
}
22+
23+
impl ClientHandler for LoggingClient {
24+
async fn on_logging_message(&self, params: LoggingMessageNotificationParam) {
25+
println!("Client: Received log message: {:?}", params);
26+
let mut messages = self.received_messages.lock().unwrap();
27+
messages.push(params);
28+
self.receive_signal.notify_one();
29+
}
30+
31+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
32+
self.peer.replace(peer);
33+
}
34+
35+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
36+
self.peer.clone()
37+
}
38+
}
39+
40+
pub struct TestServer {}
41+
42+
impl TestServer {
43+
fn new() -> Self {
44+
Self {}
45+
}
46+
}
47+
48+
impl ServerHandler for TestServer {
49+
fn get_info(&self) -> ServerInfo {
50+
ServerInfo {
51+
capabilities: ServerCapabilities::builder().enable_logging().build(),
52+
..Default::default()
53+
}
54+
}
55+
56+
fn set_level(
57+
&self,
58+
request: SetLevelRequestParam,
59+
context: RequestContext<RoleServer>,
60+
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
61+
let peer = context.peer;
62+
async move {
63+
let (data, logger) = match request.level {
64+
LoggingLevel::Error => (
65+
serde_json::json!({
66+
"message": "Failed to process request",
67+
"error_code": "E1001",
68+
"error_details": "Connection timeout",
69+
"timestamp": chrono::Utc::now().to_rfc3339(),
70+
}),
71+
Some("error_handler".to_string()),
72+
),
73+
LoggingLevel::Debug => (
74+
serde_json::json!({
75+
"message": "Processing request",
76+
"function": "handle_request",
77+
"line": 42,
78+
"context": {
79+
"request_id": "req-123",
80+
"user_id": "user-456"
81+
},
82+
"timestamp": chrono::Utc::now().to_rfc3339(),
83+
}),
84+
Some("debug_logger".to_string()),
85+
),
86+
LoggingLevel::Info => (
87+
serde_json::json!({
88+
"message": "System status update",
89+
"status": "healthy",
90+
"metrics": {
91+
"requests_per_second": 150,
92+
"average_latency_ms": 45,
93+
"error_rate": 0.01
94+
},
95+
"timestamp": chrono::Utc::now().to_rfc3339(),
96+
}),
97+
Some("monitoring".to_string()),
98+
),
99+
_ => (
100+
serde_json::json!({
101+
"message": format!("Message at level {:?}", request.level),
102+
"timestamp": chrono::Utc::now().to_rfc3339(),
103+
}),
104+
None,
105+
),
106+
};
107+
108+
if let Err(e) = peer
109+
.notify_logging_message(LoggingMessageNotificationParam {
110+
level: request.level,
111+
data,
112+
logger,
113+
})
114+
.await
115+
{
116+
panic!("Failed to send notification: {}", e);
117+
}
118+
Ok(())
119+
}
120+
}
121+
}
122+
123+
#[tokio::test]
124+
async fn test_logging_spec_compliance() -> anyhow::Result<()> {
125+
let (server_transport, client_transport) = tokio::io::duplex(4096);
126+
let receive_signal = Arc::new(Notify::new());
127+
let received_messages = Arc::new(Mutex::new(Vec::new()));
128+
129+
// Start server
130+
tokio::spawn(async move {
131+
let server = TestServer::new().serve(server_transport).await?;
132+
133+
// Test server can send messages before level is set
134+
server
135+
.peer()
136+
.notify_logging_message(LoggingMessageNotificationParam {
137+
level: LoggingLevel::Info,
138+
data: serde_json::json!({
139+
"message": "Server initiated message",
140+
"timestamp": chrono::Utc::now().to_rfc3339(),
141+
}),
142+
logger: Some("test_server".to_string()),
143+
})
144+
.await?;
145+
146+
server.waiting().await?;
147+
anyhow::Ok(())
148+
});
149+
150+
let client = LoggingClient {
151+
receive_signal: receive_signal.clone(),
152+
received_messages: received_messages.clone(),
153+
peer: None,
154+
}
155+
.serve(client_transport)
156+
.await?;
157+
158+
// Verify server-initiated message
159+
receive_signal.notified().await;
160+
{
161+
let mut messages = received_messages.lock().unwrap();
162+
assert_eq!(messages.len(), 1, "Should receive server-initiated message");
163+
messages.clear();
164+
}
165+
166+
// Test level filtering and message format
167+
for level in [
168+
LoggingLevel::Emergency,
169+
LoggingLevel::Warning,
170+
LoggingLevel::Debug,
171+
] {
172+
client
173+
.peer()
174+
.set_level(SetLevelRequestParam { level })
175+
.await?;
176+
receive_signal.notified().await;
177+
178+
let mut messages = received_messages.lock().unwrap();
179+
let msg = messages.last().unwrap();
180+
181+
// Verify required fields
182+
assert_eq!(msg.level, level);
183+
assert!(msg.data.is_object());
184+
185+
// Verify data format
186+
let data = msg.data.as_object().unwrap();
187+
assert!(data.contains_key("message"));
188+
assert!(data.contains_key("timestamp"));
189+
190+
// Verify timestamp
191+
let timestamp = data["timestamp"].as_str().unwrap();
192+
chrono::DateTime::parse_from_rfc3339(timestamp).expect("RFC3339 timestamp");
193+
194+
messages.clear();
195+
}
196+
197+
client.cancel().await?;
198+
Ok(())
199+
}
200+
201+
#[tokio::test]
202+
async fn test_logging_user_scenarios() -> anyhow::Result<()> {
203+
let (server_transport, client_transport) = tokio::io::duplex(4096);
204+
let receive_signal = Arc::new(Notify::new());
205+
let received_messages = Arc::new(Mutex::new(Vec::new()));
206+
207+
// Start server
208+
tokio::spawn(async move {
209+
let server = TestServer::new().serve(server_transport).await?;
210+
server.waiting().await?;
211+
anyhow::Ok(())
212+
});
213+
214+
let client = LoggingClient {
215+
receive_signal: receive_signal.clone(),
216+
received_messages: received_messages.clone(),
217+
peer: None,
218+
}
219+
.serve(client_transport)
220+
.await?;
221+
222+
// Test 1: Error reporting scenario
223+
// User should see detailed error information
224+
client
225+
.peer()
226+
.set_level(SetLevelRequestParam {
227+
level: LoggingLevel::Error,
228+
})
229+
.await?;
230+
receive_signal.notified().await;
231+
{
232+
let messages = received_messages.lock().unwrap();
233+
let msg = &messages[0];
234+
let data = msg.data.as_object().unwrap();
235+
assert!(
236+
data.contains_key("error_code"),
237+
"Error should have an error code"
238+
);
239+
assert!(
240+
data.contains_key("error_details"),
241+
"Error should have details"
242+
);
243+
assert!(
244+
data.contains_key("timestamp"),
245+
"Should know when error occurred"
246+
);
247+
}
248+
249+
// Test 2: Debug scenario
250+
// User debugging their application should see detailed information
251+
client
252+
.peer()
253+
.set_level(SetLevelRequestParam {
254+
level: LoggingLevel::Debug,
255+
})
256+
.await?;
257+
receive_signal.notified().await;
258+
{
259+
let messages = received_messages.lock().unwrap();
260+
let msg = messages.last().unwrap();
261+
let data = msg.data.as_object().unwrap();
262+
assert!(
263+
data.contains_key("function"),
264+
"Debug should show function name"
265+
);
266+
assert!(data.contains_key("line"), "Debug should show line number");
267+
assert!(
268+
data.contains_key("context"),
269+
"Debug should show execution context"
270+
);
271+
}
272+
273+
// Test 3: Production monitoring scenario
274+
// User monitoring production should see important status updates
275+
client
276+
.peer()
277+
.set_level(SetLevelRequestParam {
278+
level: LoggingLevel::Info,
279+
})
280+
.await?;
281+
receive_signal.notified().await;
282+
{
283+
let messages = received_messages.lock().unwrap();
284+
let msg = messages.last().unwrap();
285+
let data = msg.data.as_object().unwrap();
286+
assert!(data.contains_key("status"), "Should show system status");
287+
assert!(data.contains_key("metrics"), "Should include metrics");
288+
}
289+
290+
client.cancel().await?;
291+
Ok(())
292+
}
293+
294+
#[test]
295+
fn test_logging_level_serialization() {
296+
// Test all levels match spec exactly
297+
let test_cases = [
298+
(LoggingLevel::Alert, "alert"),
299+
(LoggingLevel::Critical, "critical"),
300+
(LoggingLevel::Debug, "debug"),
301+
(LoggingLevel::Emergency, "emergency"),
302+
(LoggingLevel::Error, "error"),
303+
(LoggingLevel::Info, "info"),
304+
(LoggingLevel::Notice, "notice"),
305+
(LoggingLevel::Warning, "warning"),
306+
];
307+
308+
for (level, expected) in test_cases {
309+
let serialized = serde_json::to_string(&level).unwrap();
310+
// Remove quotes from serialized string
311+
let serialized = serialized.trim_matches('"');
312+
assert_eq!(
313+
serialized, expected,
314+
"LoggingLevel::{:?} should serialize to \"{}\"",
315+
level, expected
316+
);
317+
}
318+
319+
// Test deserialization from spec strings
320+
for (level, spec_string) in test_cases {
321+
let deserialized: LoggingLevel =
322+
serde_json::from_str(&format!("\"{}\"", spec_string)).unwrap();
323+
assert_eq!(
324+
deserialized, level,
325+
"\"{}\" should deserialize to LoggingLevel::{:?}",
326+
spec_string, level
327+
);
328+
}
329+
}

0 commit comments

Comments
 (0)