diff --git a/docs/concepts/filters.md b/docs/concepts/filters.md index 91b11d6af..fbf7b6d6b 100644 --- a/docs/concepts/filters.md +++ b/docs/concepts/filters.md @@ -7,9 +7,14 @@ uid: filters # MCP Server Handler Filters -For each handler type in the MCP Server, there are corresponding `AddXXXFilter` methods in `McpServerBuilderExtensions.cs` that allow you to add filters to the handler pipeline. The filters are stored in `McpServerOptions.Filters` and applied during server configuration. +The MCP Server provides two levels of filters for intercepting and modifying request processing: -## Available Filter Methods +1. **Message Filters** - Low-level filters (`AddIncomingMessageFilter`, `AddOutgoingMessageFilter`) that intercept all JSON-RPC messages before routing +2. **Request-Specific Filters** - Handler-level filters (e.g., `AddListToolsFilter`, `AddCallToolFilter`) that target specific MCP operations + +The filters are stored in `McpServerOptions.Filters` and applied during server configuration. + +## Available Request-Specific Filter Methods The following filter methods are available: @@ -25,6 +30,212 @@ The following filter methods are available: - `AddUnsubscribeFromResourcesFilter` - Filter for resource unsubscription handlers - `AddSetLoggingLevelFilter` - Filter for logging level handlers +## Message Filters + +In addition to the request-specific filters above, there are low-level message filters that intercept all JSON-RPC messages before they are routed to specific handlers: + +- `AddIncomingMessageFilter` - Filter for all incoming JSON-RPC messages (requests and notifications) +- `AddOutgoingMessageFilter` - Filter for all outgoing JSON-RPC messages (responses and notifications) + +### When to Use Message Filters + +Message filters operate at a lower level than request-specific filters and are useful when you need to: + +- Intercept all messages regardless of type +- Implement custom protocol extensions or handle custom JSON-RPC methods +- Log or monitor all traffic between client and server +- Modify or skip messages before they reach handlers +- Send additional messages in response to specific events + +### Incoming Message Filter + +`AddIncomingMessageFilter` intercepts all incoming JSON-RPC messages before they are dispatched to request-specific handlers: + +```csharp +services.AddMcpServer() + .AddIncomingMessageFilter(next => async (context, cancellationToken) => + { + var logger = context.Services?.GetService>(); + + // Access the raw JSON-RPC message + if (context.JsonRpcMessage is JsonRpcRequest request) + { + logger?.LogInformation($"Incoming request: {request.Method}"); + } + + // Call next to continue processing + await next(context, cancellationToken); + }) + .WithTools(); +``` + +#### MessageContext Properties + +Inside an incoming message filter, you have access to: + +- `context.JsonRpcMessage` - The incoming `JsonRpcMessage` (can be `JsonRpcRequest` or `JsonRpcNotification`) +- `context.Server` - The `McpServer` instance for sending responses or notifications +- `context.Services` - The request's service provider +- `context.Items` - A dictionary for passing data between filters + +#### Skipping Default Handlers + +You can skip the default handler by not calling `next`. This is useful for implementing custom protocol methods: + +```csharp +.AddIncomingMessageFilter(next => async (context, cancellationToken) => +{ + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == "custom/myMethod") + { + // Handle the custom method directly + var response = new JsonRpcResponse + { + Id = request.Id, + Result = JsonSerializer.SerializeToNode(new { message = "Custom response" }) + }; + await context.Server.SendMessageAsync(response, cancellationToken); + return; // Don't call next - we handled it + } + + await next(context, cancellationToken); +}) +``` + +### Outgoing Message Filter + +`AddOutgoingMessageFilter` intercepts all outgoing JSON-RPC messages before they are sent to the client: + +```csharp +services.AddMcpServer() + .AddOutgoingMessageFilter(next => async (context, cancellationToken) => + { + var logger = context.Services?.GetService>(); + + // Inspect outgoing messages + switch (context.JsonRpcMessage) + { + case JsonRpcResponse response: + logger?.LogInformation($"Sending response for request {response.Id}"); + break; + case JsonRpcNotification notification: + logger?.LogInformation($"Sending notification: {notification.Method}"); + break; + } + + await next(context, cancellationToken); + }) + .WithTools(); +``` + +#### Skipping Outgoing Messages + +You can suppress outgoing messages by not calling `next`: + +```csharp +.AddOutgoingMessageFilter(next => async (context, cancellationToken) => +{ + // Suppress specific notifications + if (context.JsonRpcMessage is JsonRpcNotification notification && + notification.Method == "notifications/progress") + { + return; // Don't send this notification + } + + await next(context, cancellationToken); +}) +``` + +#### Sending Additional Messages + +Outgoing message filters can send additional messages by calling `next` with a new `MessageContext`: + +```csharp +.AddOutgoingMessageFilter(next => async (context, cancellationToken) => +{ + // Send an extra notification before certain responses + if (context.JsonRpcMessage is JsonRpcResponse response && + response.Result is JsonObject result && + result.ContainsKey("tools")) + { + var notification = new JsonRpcNotification + { + Method = "custom/toolsListed", + Params = new JsonObject { ["timestamp"] = DateTime.UtcNow.ToString("O") }, + Context = new JsonRpcMessageContext + { + RelatedTransport = context.JsonRpcMessage.Context?.RelatedTransport + } + }; + await next(new MessageContext(context.Server, notification), cancellationToken); + } + + await next(context, cancellationToken); +}) +``` + +### Message Filter Execution Order + +Message filters execute in registration order, with the first registered filter being the outermost: + +```csharp +services.AddMcpServer() + .AddIncomingMessageFilter(incomingFilter1) // Incoming: executes first (outermost) + .AddIncomingMessageFilter(incomingFilter2) // Incoming: executes second + .AddOutgoingMessageFilter(outgoingFilter1) // Outgoing: executes first (outermost) + .AddOutgoingMessageFilter(outgoingFilter2) // Outgoing: executes second + .AddListToolsFilter(toolsFilter) // Request-specific filter + .WithTools(); +``` + +**Important**: Incoming message filters always run before request-specific filters, and outgoing message filters run when responses or notifications are sent. The complete execution flow for a request/response cycle is: + +``` +Request arrives + ↓ +IncomingFilter1 (before next) + ↓ +IncomingFilter2 (before next) + ↓ +Request Routing → ListToolsFilter → Handler + ↓ +IncomingFilter2 (after next) + ↓ +IncomingFilter1 (after next) + ↓ +Response sent via OutgoingFilter1 (before next) + ↓ +OutgoingFilter2 (before next) + ↓ +Transport sends message + ↓ +OutgoingFilter2 (after next) + ↓ +OutgoingFilter1 (after next) +``` + +### Passing Data Between Filters + +The `Items` dictionary allows you to pass data between filters processing the same message: + +```csharp +.AddIncomingMessageFilter(next => async (context, cancellationToken) => +{ + context.Items["requestStartTime"] = DateTime.UtcNow; + await next(context, cancellationToken); +}) +.AddIncomingMessageFilter(next => async (context, cancellationToken) => +{ + await next(context, cancellationToken); + + if (context.Items.TryGetValue("requestStartTime", out var startTime)) + { + var elapsed = DateTime.UtcNow - (DateTime)startTime; + var logger = context.Services?.GetService>(); + logger?.LogInformation($"Request processed in {elapsed.TotalMilliseconds}ms"); + } +}) +``` + ## Usage Filters are functions that take a handler and return a new handler, allowing you to wrap the original handler with additional functionality: diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 6b45e446b..30854dcdf 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -260,9 +260,6 @@ private async Task GetAccessTokenAsync(HttpResponseMessage response, boo // Get auth server metadata var authServerMetadata = await GetAuthServerMetadataAsync(selectedAuthServer, cancellationToken).ConfigureAwait(false); - // Store auth server metadata for future refresh operations - _authServerMetadata = authServerMetadata; - // The existing access token must be invalid to have resulted in a 401 response, but refresh might still work. var resourceUri = GetRequiredResourceUri(protectedResourceMetadata); @@ -296,6 +293,9 @@ await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { R } } + // Store auth server metadata for future refresh operations + _authServerMetadata = authServerMetadata; + // Perform the OAuth flow return await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); } diff --git a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs index 0b4622c68..e7728cd9d 100644 --- a/src/ModelContextProtocol.Core/Client/McpClientImpl.cs +++ b/src/ModelContextProtocol.Core/Client/McpClientImpl.cs @@ -58,7 +58,15 @@ internal McpClientImpl(ITransport transport, string endpointName, McpClientOptio RegisterHandlers(options, notificationHandlers, requestHandlers); - _sessionHandler = new McpSessionHandler(isServer: false, transport, endpointName, requestHandlers, notificationHandlers, _logger); + _sessionHandler = new McpSessionHandler( + isServer: false, + transport, + endpointName, + requestHandlers, + notificationHandlers, + incomingMessageFilter: null, + outgoingMessageFilter: null, + _logger); } private void RegisterHandlers(McpClientOptions options, NotificationHandlers notificationHandlers, RequestHandlers requestHandlers) diff --git a/src/ModelContextProtocol.Core/JsonRpcMessageFilter.cs b/src/ModelContextProtocol.Core/JsonRpcMessageFilter.cs new file mode 100644 index 000000000..0113e1abf --- /dev/null +++ b/src/ModelContextProtocol.Core/JsonRpcMessageFilter.cs @@ -0,0 +1,10 @@ +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol; + +/// +/// Represents a filter that wraps the processing of incoming JSON-RPC messages. +/// +/// The next handler in the pipeline. +/// A wrapped handler that processes messages and optionally delegates to the next handler. +internal delegate Func JsonRpcMessageFilter(Func next); diff --git a/src/ModelContextProtocol.Core/McpSessionHandler.cs b/src/ModelContextProtocol.Core/McpSessionHandler.cs index 5aa1ff3b6..9137218eb 100644 --- a/src/ModelContextProtocol.Core/McpSessionHandler.cs +++ b/src/ModelContextProtocol.Core/McpSessionHandler.cs @@ -66,6 +66,8 @@ internal static bool SupportsPrimingEvent(string? protocolVersion) private readonly ITransport _transport; private readonly RequestHandlers _requestHandlers; private readonly NotificationHandlers _notificationHandlers; + private readonly JsonRpcMessageFilter _incomingMessageFilter; + private readonly JsonRpcMessageFilter _outgoingMessageFilter; private readonly long _sessionStartingTimestamp = Stopwatch.GetTimestamp(); private readonly DistributedContextPropagator _propagator = DistributedContextPropagator.Current; @@ -95,6 +97,8 @@ internal static bool SupportsPrimingEvent(string? protocolVersion) /// The name of the endpoint for logging and debug purposes. /// A collection of request handlers. /// A collection of notification handlers. + /// A filter that wraps incoming message processing. Takes the next handler and returns a wrapped handler. If null, a passthrough filter is used. + /// A filter that wraps outgoing message processing. Takes the next handler and returns a wrapped handler. If null, a passthrough filter is used. /// The logger. public McpSessionHandler( bool isServer, @@ -102,6 +106,8 @@ public McpSessionHandler( string endpointName, RequestHandlers requestHandlers, NotificationHandlers notificationHandlers, + JsonRpcMessageFilter? incomingMessageFilter, + JsonRpcMessageFilter? outgoingMessageFilter, ILogger logger) { Throw.IfNull(transport); @@ -120,7 +126,9 @@ public McpSessionHandler( EndpointName = endpointName; _requestHandlers = requestHandlers; _notificationHandlers = notificationHandlers; - _logger = logger ?? NullLogger.Instance; + _incomingMessageFilter = incomingMessageFilter ?? (next => next); + _outgoingMessageFilter = outgoingMessageFilter ?? (next => next); + _logger = logger; LogSessionCreated(EndpointName, _sessionId, _transportKind); } @@ -309,36 +317,14 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken AddTags(ref tags, activity, message, method, target); } - switch (message) + await _incomingMessageFilter(async (msg, ct) => { - case JsonRpcRequest request: - LogRequestHandlerCalled(EndpointName, request.Method); - long requestStartingTimestamp = Stopwatch.GetTimestamp(); - try - { - var result = await HandleRequest(request, cancellationToken).ConfigureAwait(false); - LogRequestHandlerCompleted(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds); - AddResponseTags(ref tags, activity, result, method); - } - catch (Exception ex) - { - LogRequestHandlerException(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds, ex); - throw; - } - break; - - case JsonRpcNotification notification: - await HandleNotification(notification, cancellationToken).ConfigureAwait(false); - break; - - case JsonRpcMessageWithId messageWithId: - HandleMessageWithId(message, messageWithId); - break; - - default: - LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); - break; - } + var result = await HandleMessageCoreAsync(msg, ct).ConfigureAwait(false); + if (addTags && result is not null) + { + AddResponseTags(ref tags, activity, result, method); + } + })(message, cancellationToken).ConfigureAwait(false); } catch (Exception e) when (addTags) { @@ -351,7 +337,40 @@ private async Task HandleMessageAsync(JsonRpcMessage message, CancellationToken } } - private async Task HandleNotification(JsonRpcNotification notification, CancellationToken cancellationToken) + private async Task HandleMessageCoreAsync(JsonRpcMessage message, CancellationToken cancellationToken) + { + switch (message) + { + case JsonRpcRequest request: + LogRequestHandlerCalled(EndpointName, request.Method); + long requestStartingTimestamp = Stopwatch.GetTimestamp(); + try + { + var result = await HandleRequestAsync(request, cancellationToken).ConfigureAwait(false); + LogRequestHandlerCompleted(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds); + return result; + } + catch (Exception ex) + { + LogRequestHandlerException(EndpointName, request.Method, GetElapsed(requestStartingTimestamp).TotalMilliseconds, ex); + throw; + } + + case JsonRpcNotification notification: + await HandleNotificationAsync(notification, cancellationToken).ConfigureAwait(false); + return null; + + case JsonRpcMessageWithId messageWithId: + HandleMessageWithId(message, messageWithId); + return null; + + default: + LogEndpointHandlerUnexpectedMessageType(EndpointName, message.GetType().Name); + return null; + } + } + + private async Task HandleNotificationAsync(JsonRpcNotification notification, CancellationToken cancellationToken) { // Special-case cancellation to cancel a pending operation. (We'll still subsequently invoke a user-specified handler if one exists.) if (notification.Method == NotificationMethods.CancelledNotification) @@ -387,7 +406,7 @@ private void HandleMessageWithId(JsonRpcMessage message, JsonRpcMessageWithId me } } - private async Task HandleRequest(JsonRpcRequest request, CancellationToken cancellationToken) + private async Task HandleRequestAsync(JsonRpcRequest request, CancellationToken cancellationToken) { if (!_requestHandlers.TryGetValue(request.Method, out var handler)) { @@ -586,26 +605,29 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken can AddTags(ref tags, activity, message, method, target); } - if (_logger.IsEnabled(LogLevel.Trace)) + await _outgoingMessageFilter(async (msg, ct) => { - LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(message, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); - } - else - { - LogSendingMessage(EndpointName); - } + if (_logger.IsEnabled(LogLevel.Trace)) + { + LogSendingMessageSensitive(EndpointName, JsonSerializer.Serialize(msg, McpJsonUtilities.JsonContext.Default.JsonRpcMessage)); + } + else + { + LogSendingMessage(EndpointName); + } - await SendToRelatedTransportAsync(message, cancellationToken).ConfigureAwait(false); + await SendToRelatedTransportAsync(msg, ct).ConfigureAwait(false); - // If the sent notification was a cancellation notification, cancel the pending request's await, as either the - // server won't be sending a response, or per the specification, the response should be ignored. There are inherent - // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. - if (message is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && - GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && - _pendingRequests.TryRemove(cn.RequestId, out var tcs)) - { - tcs.TrySetCanceled(default); - } + // If the sent notification was a cancellation notification, cancel the pending request's await, as either the + // server won't be sending a response, or per the specification, the response should be ignored. There are inherent + // race conditions here, so it's possible and allowed for the operation to complete before we get to this point. + if (msg is JsonRpcNotification { Method: NotificationMethods.CancelledNotification } notification && + GetCancelledNotificationParams(notification.Params) is CancelledNotificationParams cn && + _pendingRequests.TryRemove(cn.RequestId, out var tcs)) + { + tcs.TrySetCanceled(default); + } + })(message, cancellationToken).ConfigureAwait(false); } catch (Exception ex) when (addTags) { diff --git a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs index 0b9cd0416..38d798dc4 100644 --- a/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs +++ b/src/ModelContextProtocol.Core/Protocol/JsonRpcMessageContext.cs @@ -58,4 +58,20 @@ public class JsonRpcMessageContext /// /// public ClaimsPrincipal? User { get; set; } + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this message. + /// + /// + /// + /// This property allows data to be flowed throughout the message processing pipeline, + /// including from incoming message filters to request-specific filters and handlers. + /// + /// + /// When creating a or for server-side + /// processing, the Items dictionary from this context will be used, ensuring data set in message filters + /// is available in request filters and handlers. + /// + /// + public IDictionary? Items { get; set; } } diff --git a/src/ModelContextProtocol.Core/Server/McpMessageFilter.cs b/src/ModelContextProtocol.Core/Server/McpMessageFilter.cs new file mode 100644 index 000000000..f6ff7dfe3 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpMessageFilter.cs @@ -0,0 +1,30 @@ +namespace ModelContextProtocol.Server; + +/// +/// Delegate type for applying filters to JSON-RPC messages. +/// +/// The next message handler in the pipeline. +/// The next message handler wrapped with the filter. +/// +/// +/// Message filters allow you to intercept and process JSON-RPC messages before they reach +/// their respective handlers (incoming) or before they are sent (outgoing). This is useful for implementing +/// cross-cutting concerns that need to apply to all message types, such as logging, authentication, rate limiting, +/// redaction, or request tracing. +/// +/// +/// Filters are applied in the order they are registered, with the first registered filter being the outermost. +/// Each filter receives the next handler in the pipeline and can choose to: +/// +/// Call the next handler to continue processing (await next(context, cancellationToken)) +/// Skip the default handlers entirely by not calling next +/// Perform operations before and/or after calling next +/// Catch and handle exceptions from inner handlers +/// +/// +/// +/// For request-specific filters, use instead. +/// +/// +public delegate McpMessageHandler McpMessageFilter( + McpMessageHandler next); diff --git a/src/ModelContextProtocol.Core/Server/McpMessageHandler.cs b/src/ModelContextProtocol.Core/Server/McpMessageHandler.cs new file mode 100644 index 000000000..4164588d9 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/McpMessageHandler.cs @@ -0,0 +1,21 @@ +namespace ModelContextProtocol.Server; + +/// +/// Delegate type for handling incoming JSON-RPC messages. +/// +/// The message context containing the JSON-RPC message and other metadata. +/// A cancellation token to cancel the operation. +/// A task representing the asynchronous operation. +/// +/// +/// This delegate can handle any type of JSON-RPC message, including requests, notifications, responses, and errors. +/// Use this for implementing cross-cutting concerns that need to intercept all message types, +/// such as logging, authentication, rate limiting, or request tracing. +/// +/// +/// For request-specific handling, use instead. +/// +/// +public delegate Task McpMessageHandler( + MessageContext context, + CancellationToken cancellationToken); diff --git a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs index 334e1323d..f2fe320f1 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerFilters.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerFilters.cs @@ -11,6 +11,39 @@ namespace ModelContextProtocol.Server; /// public sealed class McpServerFilters { + /// + /// Gets the filters for all incoming JSON-RPC messages. + /// + /// + /// + /// These filters intercept all incoming JSON-RPC messages before they are processed by the server, + /// including requests, notifications, responses, and errors. The filters can perform logging, + /// authentication, rate limiting, or other cross-cutting concerns that apply to all message types. + /// + /// + /// Message filters are applied before request-specific filters. If a message filter does not call + /// the next handler in the pipeline, the default handlers will not be executed. + /// + /// + public List IncomingMessageFilters { get; } = []; + + /// + /// Gets the filters for all outgoing JSON-RPC messages. + /// + /// + /// + /// These filters intercept all outgoing JSON-RPC messages before they are sent to the client, + /// including responses, notifications, and errors. The filters can perform logging, + /// redaction, auditing, or other cross-cutting concerns that apply to all message types. + /// + /// + /// If a message filter does not call the next handler in the pipeline, the message will not be sent. + /// Filters may also call the next handler multiple times with different messages to emit additional + /// server-to-client messages. + /// + /// + public List OutgoingMessageFilters { get; } = []; + /// /// Gets the filters for the list-tools handler pipeline. /// diff --git a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs index d33ee917c..e58857dab 100644 --- a/src/ModelContextProtocol.Core/Server/McpServerImpl.cs +++ b/src/ModelContextProtocol.Core/Server/McpServerImpl.cs @@ -118,7 +118,17 @@ void Register(McpServerPrimitiveCollection? collection, } // And initialize the session. - _sessionHandler = new McpSessionHandler(isServer: true, _sessionTransport, _endpointName!, _requestHandlers, _notificationHandlers, _logger); + var incomingMessageFilter = BuildMessageFilterPipeline(options.Filters.IncomingMessageFilters); + var outgoingMessageFilter = BuildMessageFilterPipeline(options.Filters.OutgoingMessageFilters); + _sessionHandler = new McpSessionHandler( + isServer: true, + _sessionTransport, + _endpointName!, + _requestHandlers, + _notificationHandlers, + incomingMessageFilter, + outgoingMessageFilter, + _logger); } /// @@ -875,6 +885,39 @@ private static McpRequestHandler BuildFilterPipeline filters) + { + if (filters.Count == 0) + { + return next => next; + } + + return next => + { + // Build the handler chain from the filters. + // The innermost handler calls the provided 'next' delegate with the message from the context. + McpMessageHandler baseHandler = async (context, cancellationToken) => + { + await next(context.JsonRpcMessage, cancellationToken).ConfigureAwait(false); + }; + + var current = baseHandler; + for (int i = filters.Count - 1; i >= 0; i--) + { + current = filters[i](current); + } + + // Return the handler that creates a MessageContext and invokes the pipeline. + return async (message, cancellationToken) => + { + // Ensure message has a Context so Items can be shared through the pipeline + message.Context ??= new(); + var context = new MessageContext(new DestinationBoundMcpServer(this, message.Context.RelatedTransport), message); + await current(context, cancellationToken).ConfigureAwait(false); + }; + }; + } + private void UpdateEndpointNameWithClientInfo() { if (ClientInfo is null) diff --git a/src/ModelContextProtocol.Core/Server/MessageContext.cs b/src/ModelContextProtocol.Core/Server/MessageContext.cs new file mode 100644 index 000000000..9426532f5 --- /dev/null +++ b/src/ModelContextProtocol.Core/Server/MessageContext.cs @@ -0,0 +1,108 @@ +using System.Security.Claims; +using ModelContextProtocol.Protocol; + +namespace ModelContextProtocol.Server; + +/// +/// Provides a context container that provides access to the server and resources for processing a JSON-RPC message. +/// +/// +/// +/// The encapsulates contextual information for handling any JSON-RPC message, +/// including requests, responses, notifications, and errors. This is the base class for +/// , which adds request-specific properties. +/// +/// +/// This type is typically received as a parameter in message filter delegates registered via +/// or . +/// +/// +public class MessageContext +{ + /// + /// Initializes a new instance of the class with the specified server and JSON-RPC message. + /// + /// The server with which this instance is associated. + /// The JSON-RPC message associated with this context. + public MessageContext(McpServer server, JsonRpcMessage jsonRpcMessage) + { + Throw.IfNull(server); + Throw.IfNull(jsonRpcMessage); + + Server = server; + JsonRpcMessage = jsonRpcMessage; + Services = server.Services; + } + + /// Gets or sets the server with which this instance is associated. + public McpServer Server + { + get => field; + set + { + Throw.IfNull(value); + field = value; + } + } + + /// + /// Gets or sets a key/value collection that can be used to share data within the scope of this message. + /// + /// + /// + /// This dictionary is shared with the property + /// on the underlying , ensuring that data set in message filters + /// flows through to request-specific filters and handlers. + /// + /// + public IDictionary Items + { + get + { + JsonRpcMessage.Context ??= new(); + return JsonRpcMessage.Context.Items ??= new Dictionary(); + } + set + { + JsonRpcMessage.Context ??= new(); + JsonRpcMessage.Context.Items = value; + } + } + + /// Gets or sets the services associated with this message. + /// + /// This provider might not be the same instance stored in + /// if was true, in which case this + /// might be a scoped derived from the server's + /// . + /// + public IServiceProvider? Services { get; set; } + + /// Gets or sets the user associated with this message. + /// + /// + /// This property is backed by the property + /// on the underlying , ensuring that user information set in message filters + /// flows through to request-specific filters and handlers. + /// + /// + public ClaimsPrincipal? User + { + get => JsonRpcMessage.Context?.User; + set + { + JsonRpcMessage.Context ??= new(); + JsonRpcMessage.Context.User = value; + } + } + + /// + /// Gets the JSON-RPC message associated with this context. + /// + /// + /// This property provides access to the complete JSON-RPC message, + /// including the method name (for requests/notifications), request ID (for requests/responses), + /// and associated transport and user information. + /// + public JsonRpcMessage JsonRpcMessage { get; set; } +} diff --git a/src/ModelContextProtocol.Core/Server/RequestContext.cs b/src/ModelContextProtocol.Core/Server/RequestContext.cs index d27da571a..959bb07da 100644 --- a/src/ModelContextProtocol.Core/Server/RequestContext.cs +++ b/src/ModelContextProtocol.Core/Server/RequestContext.cs @@ -1,4 +1,3 @@ -using System.Security.Claims; using ModelContextProtocol.Protocol; namespace ModelContextProtocol.Server; @@ -12,11 +11,8 @@ namespace ModelContextProtocol.Server; /// This type is typically received as a parameter in handler delegates registered with IMcpServerBuilder, /// and can be injected as parameters into s. /// -public sealed class RequestContext +public sealed class RequestContext : MessageContext { - /// The server with which this instance is associated. - private McpServer _server; - /// /// Initializes a new instance of the class with the specified server and JSON-RPC request. /// @@ -24,48 +20,10 @@ public sealed class RequestContext /// The JSON-RPC request associated with this context. /// or is . public RequestContext(McpServer server, JsonRpcRequest jsonRpcRequest) + : base(server, jsonRpcRequest) { - Throw.IfNull(server); - Throw.IfNull(jsonRpcRequest); - - _server = server; - JsonRpcRequest = jsonRpcRequest; - Services = server.Services; - User = jsonRpcRequest.Context?.User; - } - - /// Gets or sets the server with which this instance is associated. - public McpServer Server - { - get => _server; - set - { - Throw.IfNull(value); - _server = value; - } - } - - /// - /// Gets or sets a key/value collection that can be used to share data within the scope of this request. - /// - public IDictionary Items - { - get => field ??= new Dictionary(); - set => field = value; } - /// Gets or sets the services associated with this request. - /// - /// This provider might not be the same instance stored in - /// if was true, in which case this - /// might be a scoped derived from the server's - /// . - /// - public IServiceProvider? Services { get; set; } - - /// Gets or sets the user associated with this request. - public ClaimsPrincipal? User { get; set; } - /// Gets or sets the parameters associated with this request. public TParams? Params { get; set; } @@ -81,7 +39,11 @@ public McpServer Server /// This property provides access to the complete JSON-RPC request that initiated this handler invocation, /// including the method name, parameters, request ID, and associated transport and user information. /// - public JsonRpcRequest JsonRpcRequest { get; } + public JsonRpcRequest JsonRpcRequest + { + get => (JsonRpcRequest)JsonRpcMessage; + set => JsonRpcMessage = value; + } /// /// Ends the current response and enables polling for updates from the server. diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index f0f94c270..44d6e3cef 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -36,9 +36,12 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio { Debug.Assert(_pendingRequest.Id is null); + message.Context ??= new JsonRpcMessageContext(); + if (message is JsonRpcRequest request) { _pendingRequest = request.Id; + message.Context.RelatedTransport = this; // Invoke the initialize request handler if applicable. if (request.Method == RequestMethods.Initialize) @@ -48,9 +51,6 @@ public async ValueTask HandlePostAsync(JsonRpcMessage message, Cancellatio } } - message.Context ??= new JsonRpcMessageContext(); - message.Context.RelatedTransport = this; - if (parentTransport.FlowExecutionContextFromRequests) { message.Context.ExecutionContext = ExecutionContext.Capture(); diff --git a/src/ModelContextProtocol/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/McpServerBuilderExtensions.cs index e990cfcb9..b77533574 100644 --- a/src/ModelContextProtocol/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/McpServerBuilderExtensions.cs @@ -1105,6 +1105,70 @@ public static IMcpServerBuilder AddSetLoggingLevelFilter(this IMcpServerBuilder builder.Services.Configure(options => options.Filters.SetLoggingLevelFilters.Add(filter)); return builder; } + + /// + /// Adds a filter to intercept all incoming JSON-RPC messages. + /// + /// The builder instance. + /// The filter function that wraps the message handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter intercepts all incoming JSON-RPC messages before they are processed by the server, + /// including requests, notifications, responses, and errors. The filter can perform logging, + /// authentication, rate limiting, or other cross-cutting concerns that apply to all message types. + /// + /// + /// Message filters are applied before request-specific filters. If a message filter does not call + /// the next handler in the pipeline, the default handlers will not be executed. + /// + /// + /// Filters are applied in the order they are registered, with the first registered filter being the outermost. + /// Each filter receives the next handler in the pipeline and can choose to: + /// + /// Call the next handler to continue processing: await next(context, cancellationToken) + /// Skip the default handlers entirely by not calling next + /// Perform operations before and/or after calling next + /// Catch and handle exceptions from inner handlers + /// + /// + /// + public static IMcpServerBuilder AddIncomingMessageFilter(this IMcpServerBuilder builder, McpMessageFilter filter) + { + Throw.IfNull(builder); + Throw.IfNull(filter); + + builder.Services.Configure(options => options.Filters.IncomingMessageFilters.Add(filter)); + return builder; + } + + /// + /// Adds a filter to intercept all outgoing JSON-RPC messages. + /// + /// The builder instance. + /// The filter function that wraps the message handler. + /// The builder provided in . + /// is . + /// + /// + /// This filter intercepts all outgoing JSON-RPC messages before they are sent to the client, + /// including responses, notifications, and errors. The filter can perform logging, redaction, + /// auditing, or other cross-cutting concerns that apply to all message types. + /// + /// + /// If a message filter does not call the next handler in the pipeline, the message will not be sent. + /// Filters may also call the next handler multiple times with different messages to emit additional + /// server-to-client messages. + /// + /// + public static IMcpServerBuilder AddOutgoingMessageFilter(this IMcpServerBuilder builder, McpMessageFilter filter) + { + Throw.IfNull(builder); + + builder.Services.Configure(options => options.Filters.OutgoingMessageFilters.Add(filter)); + return builder; + } #endregion #region Transports diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs index 0bbff49b4..36a005795 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/Utils/KestrelInMemoryTest.cs @@ -7,7 +7,7 @@ namespace ModelContextProtocol.AspNetCore.Tests.Utils; -public class KestrelInMemoryTest : LoggedTest +public abstract class KestrelInMemoryTest : LoggedTest { public KestrelInMemoryTest(ITestOutputHelper testOutputHelper) : base(testOutputHelper) diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs new file mode 100644 index 000000000..8c3b64bde --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsMessageFilterTests.cs @@ -0,0 +1,700 @@ +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Utils; +using System.Security.Claims; +using System.Text.Json.Nodes; + +namespace ModelContextProtocol.Tests.Configuration; + +public class McpServerBuilderExtensionsMessageFilterTests(ITestOutputHelper testOutputHelper) : ClientServerTestBase(testOutputHelper, startServer: false) +{ + private static ILogger GetLogger(IServiceProvider? services, string categoryName) + { + var loggerFactory = services?.GetRequiredService() ?? throw new InvalidOperationException("LoggerFactory not available"); + return loggerFactory.CreateLogger(categoryName); + } + + [Fact] + public async Task AddIncomingMessageFilter_Logs_For_Request() + { + List messageTypes = []; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + var logger = GetLogger(context.Services, "MessageFilter1"); + logger.LogInformation("MessageFilter1 before"); + + var messageTypeName = context.JsonRpcMessage.GetType().Name; + messageTypes.Add(messageTypeName); + + await next(context, cancellationToken); + + logger.LogInformation("MessageFilter1 after"); + }) + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + var logger = GetLogger(context.Services, "MessageFilter2"); + logger.LogInformation("MessageFilter2 before"); + await next(context, cancellationToken); + logger.LogInformation("MessageFilter2 after"); + }) + .WithTools() + .WithPrompts() + .WithResources(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var beforeMessages = MockLoggerProvider.LogMessages.Where(m => m.Message == "MessageFilter1 before").ToList(); + Assert.True(beforeMessages.Count > 0); + Assert.Equal(LogLevel.Information, beforeMessages[0].LogLevel); + Assert.Equal("MessageFilter1", beforeMessages[0].Category); + + var afterMessages = MockLoggerProvider.LogMessages.Where(m => m.Message == "MessageFilter1 after").ToList(); + Assert.True(afterMessages.Count > 0); + Assert.Equal(LogLevel.Information, afterMessages[0].LogLevel); + Assert.Equal("MessageFilter1", afterMessages[0].Category); + } + + [Fact] + public async Task AddIncomingMessageFilter_Intercepts_Request_Messages() + { + List messageTypes = []; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + var messageTypeName = context.JsonRpcMessage.GetType().Name; + messageTypes.Add(messageTypeName); + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // The message filter should intercept JsonRpcRequest messages + Assert.Contains("JsonRpcRequest", messageTypes); + } + + [Fact] + public async Task AddIncomingMessageFilter_Multiple_Filters_Execute_In_Order() + { + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + var logger = GetLogger(context.Services, "MessageFilter1"); + logger.LogInformation("MessageFilter1 before"); + await next(context, cancellationToken); + logger.LogInformation("MessageFilter1 after"); + }) + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + var logger = GetLogger(context.Services, "MessageFilter2"); + logger.LogInformation("MessageFilter2 before"); + await next(context, cancellationToken); + logger.LogInformation("MessageFilter2 after"); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var logMessages = MockLoggerProvider.LogMessages + .Where(m => m.Category.StartsWith("MessageFilter")) + .Select(m => m.Message) + .ToList(); + + // First filter registered is outermost + // We should see this pattern for each message: MessageFilter1 before -> MessageFilter2 before -> MessageFilter2 after -> MessageFilter1 after + int idx1Before = logMessages.IndexOf("MessageFilter1 before"); + int idx2Before = logMessages.IndexOf("MessageFilter2 before"); + int idx2After = logMessages.IndexOf("MessageFilter2 after"); + int idx1After = logMessages.IndexOf("MessageFilter1 after"); + + Assert.True(idx1Before >= 0); + Assert.True(idx2Before >= 0); + Assert.True(idx2After >= 0); + Assert.True(idx1After >= 0); + + // Verify ordering within a single request + Assert.True(idx1Before < idx2Before); + Assert.True(idx2Before < idx2After); + Assert.True(idx2After < idx1After); + } + + [Fact] + public async Task AddIncomingMessageFilter_Has_Access_To_Server() + { + McpServer? capturedServer = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + capturedServer = context.Server; + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // The captured server is a per-destination wrapper that provides the same functionality + Assert.NotNull(capturedServer); + Assert.NotNull(capturedServer.ServerOptions); + } + + [Fact] + public async Task AddIncomingMessageFilter_Items_Dictionary_Can_Be_Used() + { + string? capturedValue = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + context.Items["testKey"] = "testValue"; + await next(context, cancellationToken); + }) + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + if (context.Items.TryGetValue("testKey", out var value)) + { + capturedValue = value as string; + } + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("testValue", capturedValue); + } + + [Fact] + public async Task AddIncomingMessageFilter_Can_Access_JsonRpcMessage_Details() + { + string? capturedMethod = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + capturedMethod = request.Method; + } + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(RequestMethods.ToolsList, capturedMethod); + } + + [Fact] + public async Task AddIncomingMessageFilter_Exception_Propagates_Properly() + { + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // Only throw for tools/list, not for initialize/initialized + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + throw new InvalidOperationException("Filter exception"); + } + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + var exception = await Assert.ThrowsAsync(async () => + { + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + }); + + Assert.Contains("error", exception.Message, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task AddIncomingMessageFilter_Runs_Before_Request_Specific_Filters() + { + var executionOrder = new List(); + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + executionOrder.Add("MessageFilter"); + } + await next(context, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + executionOrder.Add("ListToolsFilter"); + return await next(request, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + // Message filter should run before the request-specific filter + Assert.Equal(2, executionOrder.Count); + Assert.Equal("MessageFilter", executionOrder[0]); + Assert.Equal("ListToolsFilter", executionOrder[1]); + } + + [Fact] + public async Task AddIncomingMessageFilter_Can_Skip_Default_Handlers() + { + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // Skip calling next for tools/list + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + // Don't call next - this will skip the default handler + return; + } + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + // When default handlers are skipped, the request should time out + // because no response will be sent + using var requestCts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + await Assert.ThrowsAnyAsync(async () => + { + await client.ListToolsAsync(cancellationToken: requestCts.Token); + }); + } + + [Fact] + public async Task AddOutgoingMessageFilter_Sees_Initialize_Progress_And_Response() + { + var observedMessages = new List(); + + McpServerBuilder + .AddOutgoingMessageFilter((next) => async (context, cancellationToken) => + { + switch (context.JsonRpcMessage) + { + case JsonRpcResponse response when response.Result is JsonObject result: + if (result.ContainsKey("protocolVersion")) + { + observedMessages.Add("initialize"); + } + else if (result.ContainsKey("content")) + { + observedMessages.Add("response"); + } + break; + case JsonRpcNotification notification when notification.Method == NotificationMethods.ProgressNotification: + observedMessages.Add("progress"); + break; + } + + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + IProgress progress = new Progress(_ => { }); + await client.CallToolAsync("progress-tool", progress: progress, cancellationToken: TestContext.Current.CancellationToken); + + int initializeIndex = observedMessages.IndexOf("initialize"); + int progressIndex = observedMessages.IndexOf("progress"); + int responseIndex = observedMessages.LastIndexOf("response"); + + Assert.True(initializeIndex >= 0); + Assert.True(progressIndex > initializeIndex); + Assert.True(responseIndex > progressIndex); + } + + [Fact] + public async Task AddOutgoingMessageFilter_Can_Skip_Sending_Messages() + { + McpServerBuilder + .AddOutgoingMessageFilter((next) => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcResponse response && response.Result is JsonObject result && result.ContainsKey("tools")) + { + return; + } + + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + using var requestCts = new CancellationTokenSource(TimeSpan.FromSeconds(2)); + await Assert.ThrowsAnyAsync(async () => + { + await client.ListToolsAsync(cancellationToken: requestCts.Token); + }); + } + + [Fact] + public async Task AddOutgoingMessageFilter_Can_Send_Additional_Messages() + { + McpServerBuilder + .AddOutgoingMessageFilter((next) => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcResponse response && response.Result is JsonObject result && result.ContainsKey("tools")) + { + var extraNotification = new JsonRpcNotification + { + Method = "test/extra", + Params = new JsonObject { ["message"] = "extra" }, + Context = new JsonRpcMessageContext { RelatedTransport = context.JsonRpcMessage.Context?.RelatedTransport }, + }; + + await next(new MessageContext(context.Server, extraNotification), cancellationToken); + } + + await next(context, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + var extraNotificationReceived = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await using var registration = client.RegisterNotificationHandler("test/extra", (notification, _) => + { + extraNotificationReceived.TrySetResult(notification.Params?["message"]?.GetValue()); + return default; + }); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + var extraMessage = await extraNotificationReceived.Task.WaitAsync(TimeSpan.FromSeconds(2), TestContext.Current.CancellationToken); + + Assert.Equal("extra", extraMessage); + } + + [Fact] + public async Task AddIncomingMessageFilter_Items_Flow_To_Request_Filters() + { + string? capturedValue = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // Set an item in the message filter + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + context.Items["messageFilterKey"] = "messageFilterValue"; + } + await next(context, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + // Read the item in the request-specific filter + if (request.Items.TryGetValue("messageFilterKey", out var value)) + { + capturedValue = value as string; + } + return await next(request, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal("messageFilterValue", capturedValue); + } + + [Fact] + public async Task AddIncomingMessageFilter_Items_Flow_To_CallTool_Handler() + { + object? capturedValue = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // Set an item in the message filter for CallTool requests + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsCall) + { + context.Items["toolContextKey"] = 42; + } + await next(context, cancellationToken); + }) + .AddCallToolFilter((next) => async (request, cancellationToken) => + { + // Read the item in the call tool filter + if (request.Items.TryGetValue("toolContextKey", out var value)) + { + capturedValue = value; + } + return await next(request, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.CallToolAsync("simple-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(42, capturedValue); + } + + [Fact] + public async Task AddIncomingMessageFilter_User_Flows_To_CallTool_Handler() + { + ClaimsPrincipal? capturedUser = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // Set a custom user in the message filter for CallTool requests + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsCall) + { + var claims = new[] { new Claim(ClaimTypes.Name, "TestUser"), new Claim(ClaimTypes.Role, "Admin") }; + var identity = new ClaimsIdentity(claims, "TestAuth"); + context.User = new ClaimsPrincipal(identity); + } + await next(context, cancellationToken); + }) + .AddCallToolFilter((next) => async (request, cancellationToken) => + { + // Read the user in the call tool filter + capturedUser = request.User; + return await next(request, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.CallToolAsync("simple-tool", cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(capturedUser); + Assert.Equal("TestUser", capturedUser.Identity?.Name); + Assert.True(capturedUser.IsInRole("Admin")); + } + + [Fact] + public async Task AddIncomingMessageFilter_Items_Preserved_When_Context_Replaced() + { + object? firstFilterValue = null; + object? secondFilterValue = null; + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // First filter sets an item + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + context.Items["firstFilterKey"] = "firstFilterValue"; + } + await next(context, cancellationToken); + }) + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + // Second filter creates a new context with a new JsonRpcRequest and adds an item + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + var newRequest = new JsonRpcRequest + { + Id = request.Id, + Method = RequestMethods.ToolsList, + Params = request.Params, + Context = new JsonRpcMessageContext { RelatedTransport = request.Context?.RelatedTransport }, + }; + + var newContext = new MessageContext(context.Server, newRequest); + newContext.Items["secondFilterKey"] = "secondFilterValue"; + + await next(newContext, cancellationToken); + return; + } + await next(context, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + // Request filter should see items from message filters + request.Items.TryGetValue("firstFilterKey", out firstFilterValue); + request.Items.TryGetValue("secondFilterKey", out secondFilterValue); + return await next(request, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Null(firstFilterValue); + Assert.Equal("secondFilterValue", secondFilterValue); + } + + [Fact] + public async Task AddIncomingMessageFilter_Items_Flow_Through_Multiple_Request_Filters() + { + var observedValues = new List(); + + McpServerBuilder + .AddIncomingMessageFilter((next) => async (context, cancellationToken) => + { + if (context.JsonRpcMessage is JsonRpcRequest request && request.Method == RequestMethods.ToolsList) + { + context.Items["sharedKey"] = "fromMessageFilter"; + } + await next(context, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + // First request filter reads and modifies + if (request.Items.TryGetValue("sharedKey", out var value)) + { + observedValues.Add((string)value!); + request.Items["sharedKey"] = "modifiedByFilter1"; + } + return await next(request, cancellationToken); + }) + .AddListToolsFilter((next) => async (request, cancellationToken) => + { + // Second request filter should see modified value + if (request.Items.TryGetValue("sharedKey", out var value)) + { + observedValues.Add((string)value!); + } + return await next(request, cancellationToken); + }) + .WithTools(); + + StartServer(); + + await using McpClient client = await CreateMcpClientForServer(); + + await client.ListToolsAsync(cancellationToken: TestContext.Current.CancellationToken); + + Assert.Equal(2, observedValues.Count); + Assert.Equal("fromMessageFilter", observedValues[0]); + Assert.Equal("modifiedByFilter1", observedValues[1]); + } + + [McpServerToolType] + public sealed class TestTool + { + [McpServerTool] + public static string TestToolMethod() + { + return "test result"; + } + } + + [McpServerPromptType] + public sealed class TestPrompt + { + [McpServerPrompt] + public static Task TestPromptMethod() + { + return Task.FromResult(new GetPromptResult + { + Description = "Test prompt", + Messages = [new() { Role = Role.User, Content = new TextContentBlock { Text = "Test" } }] + }); + } + } + + [McpServerResourceType] + public sealed class TestResource + { + [McpServerResource(UriTemplate = "test://resource/{id}")] + public static string TestResourceMethod(string id) + { + return $"Test resource for ID: {id}"; + } + } + + [McpServerToolType] + public sealed class ProgressTool + { + [McpServerTool(Name = "progress-tool")] + public static async Task ReportProgress( + McpServer server, + RequestContext context, + CancellationToken cancellationToken) + { + if (context.Params?.ProgressToken is { } token) + { + await server.NotifyProgressAsync(token, new ProgressNotificationValue + { + Progress = 0, + Total = 2, + Message = "starting", + }, cancellationToken: cancellationToken); + + await server.NotifyProgressAsync(token, new ProgressNotificationValue + { + Progress = 1, + Total = 2, + Message = "running", + }, cancellationToken: cancellationToken); + } + + return "done"; + } + } + + [McpServerToolType] + public sealed class SimpleTool + { + [McpServerTool(Name = "simple-tool")] + public static string Execute() + { + return "success"; + } + } +}