Skip to content

feat: provide more context information #236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions crates/rmcp/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,39 @@ let service = client.serve(transport).await?;



## Access with peer interface when handling message

You can get the [`Peer`](crate::service::Peer) struct from [`NotificationContext`](crate::service::NotificationContext) and [`RequestContext`](crate::service::RequestContext).

```rust, ignore
# use rmcp::{
# ServerHandler,
# model::{LoggingLevel, LoggingMessageNotificationParam, ProgressNotificationParam},
# service::{NotificationContext, RoleServer},
# };
# pub struct Handler;

impl ServerHandler for Handler {
async fn on_progress(
&self,
notification: ProgressNotificationParam,
context: NotificationContext<RoleServer>,
) {
let peer = context.peer;
let _ = peer
.notify_logging_message(LoggingMessageNotificationParam {
level: LoggingLevel::Info,
logger: None,
data: serde_json::json!({
"message": format!("Progress: {}", notification.progress),
}),
})
.await;
}
}
```


## Manage Multi Services

For many cases you need to manage several service in a collection, you can call `into_dyn` to convert services into the same type.
Expand Down
36 changes: 25 additions & 11 deletions crates/rmcp/src/handler/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
error::Error as McpError,
model::*,
service::{RequestContext, RoleClient, Service, ServiceRole},
service::{NotificationContext, RequestContext, RoleClient, Service, ServiceRole},
};

impl<H: ClientHandler> Service<RoleClient> for H {
Expand All @@ -26,28 +26,29 @@ impl<H: ClientHandler> Service<RoleClient> for H {
async fn handle_notification(
&self,
notification: <RoleClient as ServiceRole>::PeerNot,
context: NotificationContext<RoleClient>,
) -> Result<(), McpError> {
match notification {
ServerNotification::CancelledNotification(notification) => {
self.on_cancelled(notification.params).await
self.on_cancelled(notification.params, context).await
}
ServerNotification::ProgressNotification(notification) => {
self.on_progress(notification.params).await
self.on_progress(notification.params, context).await
}
ServerNotification::LoggingMessageNotification(notification) => {
self.on_logging_message(notification.params).await
self.on_logging_message(notification.params, context).await
}
ServerNotification::ResourceUpdatedNotification(notification) => {
self.on_resource_updated(notification.params).await
self.on_resource_updated(notification.params, context).await
}
ServerNotification::ResourceListChangedNotification(_notification_no_param) => {
self.on_resource_list_changed().await
self.on_resource_list_changed(context).await
}
ServerNotification::ToolListChangedNotification(_notification_no_param) => {
self.on_tool_list_changed().await
self.on_tool_list_changed(context).await
}
ServerNotification::PromptListChangedNotification(_notification_no_param) => {
self.on_prompt_list_changed().await
self.on_prompt_list_changed(context).await
}
};
Ok(())
Expand Down Expand Up @@ -87,34 +88,47 @@ pub trait ClientHandler: Sized + Send + Sync + 'static {
fn on_cancelled(
&self,
params: CancelledNotificationParam,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_progress(
&self,
params: ProgressNotificationParam,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_logging_message(
&self,
params: LoggingMessageNotificationParam,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_resource_updated(
&self,
params: ResourceUpdatedNotificationParam,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_resource_list_changed(&self) -> impl Future<Output = ()> + Send + '_ {
fn on_resource_list_changed(
&self,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_tool_list_changed(&self) -> impl Future<Output = ()> + Send + '_ {
fn on_tool_list_changed(
&self,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_prompt_list_changed(&self) -> impl Future<Output = ()> + Send + '_ {
fn on_prompt_list_changed(
&self,
context: NotificationContext<RoleClient>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}

Expand Down
23 changes: 16 additions & 7 deletions crates/rmcp/src/handler/server.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
error::Error as McpError,
model::*,
service::{RequestContext, RoleServer, Service, ServiceRole},
service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole},
};

mod resource;
Expand Down Expand Up @@ -71,19 +71,20 @@ impl<H: ServerHandler> Service<RoleServer> for H {
async fn handle_notification(
&self,
notification: <RoleServer as ServiceRole>::PeerNot,
context: NotificationContext<RoleServer>,
) -> Result<(), McpError> {
match notification {
ClientNotification::CancelledNotification(notification) => {
self.on_cancelled(notification.params).await
self.on_cancelled(notification.params, context).await
}
ClientNotification::ProgressNotification(notification) => {
self.on_progress(notification.params).await
self.on_progress(notification.params, context).await
}
ClientNotification::InitializedNotification(_notification) => {
self.on_initialized().await
self.on_initialized(context).await
}
ClientNotification::RootsListChangedNotification(_notification) => {
self.on_roots_list_changed().await
self.on_roots_list_changed(context).await
}
};
Ok(())
Expand Down Expand Up @@ -196,20 +197,28 @@ pub trait ServerHandler: Sized + Send + Sync + 'static {
fn on_cancelled(
&self,
notification: CancelledNotificationParam,
context: NotificationContext<RoleServer>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_progress(
&self,
notification: ProgressNotificationParam,
context: NotificationContext<RoleServer>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}
fn on_initialized(&self) -> impl Future<Output = ()> + Send + '_ {
fn on_initialized(
&self,
context: NotificationContext<RoleServer>,
) -> impl Future<Output = ()> + Send + '_ {
tracing::info!("client initialized");
std::future::ready(())
}
fn on_roots_list_changed(&self) -> impl Future<Output = ()> + Send + '_ {
fn on_roots_list_changed(
&self,
context: NotificationContext<RoleServer>,
) -> impl Future<Output = ()> + Send + '_ {
std::future::ready(())
}

Expand Down
36 changes: 36 additions & 0 deletions crates/rmcp/src/handler/server/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,42 @@ where
}
}

impl<'a, S> FromToolCallContextPart<'a, S> for crate::Peer<RoleServer> {
fn from_tool_call_context_part(
context: ToolCallContext<'a, S>,
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
let peer = context.request_context.peer.clone();
Ok((peer, context))
}
}

impl<'a, S> FromToolCallContextPart<'a, S> for crate::model::Meta {
fn from_tool_call_context_part(
mut context: ToolCallContext<'a, S>,
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
let mut meta = crate::model::Meta::default();
std::mem::swap(&mut meta, &mut context.request_context.meta);
Ok((meta, context))
}
}

pub struct RequestId(pub crate::model::RequestId);
impl<'a, S> FromToolCallContextPart<'a, S> for RequestId {
fn from_tool_call_context_part(
context: ToolCallContext<'a, S>,
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
Ok((RequestId(context.request_context.id.clone()), context))
}
}

impl<'a, S> FromToolCallContextPart<'a, S> for RequestContext<RoleServer> {
fn from_tool_call_context_part(
context: ToolCallContext<'a, S>,
) -> Result<(Self, ToolCallContext<'a, S>), crate::Error> {
Ok((context.request_context.clone(), context))
}
}

impl<'s, S> ToolCallContext<'s, S> {
pub fn invoke<H, A>(self, h: H) -> H::Fut
where
Expand Down
2 changes: 1 addition & 1 deletion crates/rmcp/src/model/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct Extensions {
impl Extensions {
/// Create an empty `Extensions`.
#[inline]
pub fn new() -> Extensions {
pub const fn new() -> Extensions {
Extensions { map: None }
}

Expand Down
58 changes: 48 additions & 10 deletions crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ pub trait ServiceRole: std::fmt::Debug + Send + Sync + 'static + Copy + Clone {
type PeerResp: TransferObject;
type PeerNot: TryInto<CancelledNotification, Error = Self::PeerNot>
+ From<CancelledNotification>
+ TransferObject;
+ TransferObject
+ GetMeta
+ GetExtensions;
type InitializeError<E>;
const IS_CLIENT: bool;
type Info: TransferObject;
Expand All @@ -100,6 +102,7 @@ pub trait Service<R: ServiceRole>: Send + Sync + 'static {
fn handle_notification(
&self,
notification: R::PeerNot,
context: NotificationContext<R>,
) -> impl Future<Output = Result<(), McpError>> + Send + '_;
fn get_info(&self) -> R::Info;
}
Expand Down Expand Up @@ -145,8 +148,9 @@ impl<R: ServiceRole> Service<R> for Box<dyn DynService<R>> {
fn handle_notification(
&self,
notification: R::PeerNot,
context: NotificationContext<R>,
) -> impl Future<Output = Result<(), McpError>> + Send + '_ {
DynService::handle_notification(self.as_ref(), notification)
DynService::handle_notification(self.as_ref(), notification, context)
}

fn get_info(&self) -> R::Info {
Expand All @@ -160,7 +164,11 @@ pub trait DynService<R: ServiceRole>: Send + Sync {
request: R::PeerReq,
context: RequestContext<R>,
) -> BoxFuture<Result<R::Resp, McpError>>;
fn handle_notification(&self, notification: R::PeerNot) -> BoxFuture<Result<(), McpError>>;
fn handle_notification(
&self,
notification: R::PeerNot,
context: NotificationContext<R>,
) -> BoxFuture<Result<(), McpError>>;
fn get_info(&self) -> R::Info;
}

Expand All @@ -172,8 +180,12 @@ impl<R: ServiceRole, S: Service<R>> DynService<R> for S {
) -> BoxFuture<Result<R::Resp, McpError>> {
Box::pin(self.handle_request(request, context))
}
fn handle_notification(&self, notification: R::PeerNot) -> BoxFuture<Result<(), McpError>> {
Box::pin(self.handle_notification(notification))
fn handle_notification(
&self,
notification: R::PeerNot,
context: NotificationContext<R>,
) -> BoxFuture<Result<(), McpError>> {
Box::pin(self.handle_notification(notification, context))
}
fn get_info(&self) -> R::Info {
self.get_info()
Expand Down Expand Up @@ -487,6 +499,15 @@ pub struct RequestContext<R: ServiceRole> {
pub peer: Peer<R>,
}

/// Request execution context
#[derive(Debug, Clone)]
pub struct NotificationContext<R: ServiceRole> {
pub meta: Meta,
pub extensions: Extensions,
/// An interface to fetch the remote client or server
pub peer: Peer<R>,
}

/// Use this function to skip initialization process
pub fn serve_directly<R, S, T, E, A>(
service: S,
Expand Down Expand Up @@ -710,7 +731,9 @@ where
}));
}
Event::PeerMessage(JsonRpcMessage::Request(JsonRpcRequest {
id, request, ..
id,
mut request,
..
})) => {
tracing::debug!(%id, ?request, "received request");
{
Expand All @@ -719,12 +742,17 @@ where
let request_ct = serve_loop_ct.child_token();
let context_ct = request_ct.child_token();
local_ct_pool.insert(id.clone(), request_ct);
let mut extensions = Extensions::new();
let mut meta = Meta::new();
// avoid clone
std::mem::swap(&mut extensions, request.extensions_mut());
std::mem::swap(&mut meta, request.get_meta_mut());
let context = RequestContext {
ct: context_ct,
id: id.clone(),
peer: peer.clone(),
meta: request.get_meta().clone(),
extensions: request.extensions().clone(),
meta,
extensions,
};
tokio::spawn(async move {
let result = service.handle_request(request, context).await;
Expand All @@ -748,7 +776,7 @@ where
})) => {
tracing::info!(?notification, "received notification");
// catch cancelled notification
let notification = match notification.try_into() {
let mut notification = match notification.try_into() {
Ok::<CancelledNotification, _>(cancelled) => {
if let Some(ct) = local_ct_pool.remove(&cancelled.params.request_id) {
tracing::info!(id = %cancelled.params.request_id, reason = cancelled.params.reason, "cancelled");
Expand All @@ -760,8 +788,18 @@ where
};
{
let service = shared_service.clone();
let mut extensions = Extensions::new();
let mut meta = Meta::new();
// avoid clone
std::mem::swap(&mut extensions, notification.extensions_mut());
std::mem::swap(&mut meta, notification.get_meta_mut());
let context = NotificationContext {
peer: peer.clone(),
meta,
extensions,
};
tokio::spawn(async move {
let result = service.handle_notification(notification).await;
let result = service.handle_notification(notification, context).await;
if let Err(error) = result {
tracing::warn!(%error, "Error sending notification");
}
Expand Down
Loading