From 3962a0bbbcbd15ca562093fb29f19bd5ce0cc7e4 Mon Sep 17 00:00:00 2001 From: Mike Kistler Date: Sun, 25 Jan 2026 12:22:27 -0600 Subject: [PATCH 1/7] Add test to make sure all client conformance tests are covered --- .../ClientConformanceTests.cs | 81 +++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs index 7a92cd67c..d81b3ea70 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs @@ -21,9 +21,82 @@ public ClientConformanceTests(ITestOutputHelper output) _output = output; } + // Expected scenarios based on InlineData attributes below + private static readonly string[] ExpectedScenarios = [ + "initialize", + "tools_call", + "elicitation-sep1034-client-defaults", + "sse-retry", + "auth/metadata-default", + "auth/metadata-var1", + "auth/metadata-var2", + "auth/metadata-var3", + "auth/basic-cimd", + "auth/2025-03-26-oauth-metadata-backcompat", // Expected but not required to pass + "auth/2025-03-26-oauth-endpoint-fallback", // Expected but not required to pass + "auth/scope-from-www-authenticate", + "auth/scope-from-scopes-supported", + "auth/scope-omitted-when-undefined", + "auth/scope-step-up", + "auth/scope-retry-limit", + "auth/token-endpoint-auth-basic", + "auth/token-endpoint-auth-post", + "auth/token-endpoint-auth-none", + "auth/resource-mismatch", + "auth/pre-registration", + "auth/client-credentials-jwt", + "auth/client-credentials-basic" + ]; + + [Fact(Skip = "npx is not installed. Skipping client conformance tests.", SkipUnless = nameof(IsNpxInstalled))] + public async Task VerifyAllConformanceTestsAreListed() + { + // Get the list of available conformance tests from the suite + var startInfo = NodeHelpers.NpxStartInfo("-y @modelcontextprotocol/conformance list --client"); + + var outputBuilder = new StringBuilder(); + var process = new Process { StartInfo = startInfo }; + + process.OutputDataReceived += (sender, e) => + { + if (e.Data != null) + { + outputBuilder.AppendLine(e.Data); + } + }; + + process.Start(); + process.BeginOutputReadLine(); + await process.WaitForExitAsync(TestContext.Current.CancellationToken); + + Assert.True(process.ExitCode == 0, "Failed to list conformance tests"); + + var output = outputBuilder.ToString(); + var availableScenarios = output + .Split('\n', StringSplitOptions.RemoveEmptyEntries) + .Select(line => line.Trim()) + .Where(line => line.StartsWith("- ")) + .Select(line => line.Substring(2).Trim()) + .ToHashSet(); + + // Verify all expected scenarios are available + var missingScenarios = ExpectedScenarios.Except(availableScenarios).ToList(); + Assert.Empty(missingScenarios); + + // Verify we haven't missed any new scenarios + var newScenarios = availableScenarios.Except(ExpectedScenarios).ToList(); + if (newScenarios.Any()) + { + var newScenariosMessage = string.Join("\r\n - ", newScenarios); + Assert.Fail($"New conformance scenarios detected. Add these to ExpectedScenarios and the Theory:\r\n - {newScenariosMessage}"); + } + } + [Theory(Skip = "npx is not installed. Skipping client conformance tests.", SkipUnless = nameof(IsNpxInstalled))] [InlineData("initialize")] [InlineData("tools_call")] + [InlineData("elicitation-sep1034-client-defaults")] + [InlineData("sse-retry")] [InlineData("auth/metadata-default")] [InlineData("auth/metadata-var1")] [InlineData("auth/metadata-var2")] @@ -35,6 +108,14 @@ public ClientConformanceTests(ITestOutputHelper output) [InlineData("auth/scope-from-scopes-supported")] [InlineData("auth/scope-omitted-when-undefined")] [InlineData("auth/scope-step-up")] + [InlineData("auth/scope-retry-limit")] + [InlineData("auth/token-endpoint-auth-basic")] + [InlineData("auth/token-endpoint-auth-post")] + [InlineData("auth/token-endpoint-auth-none")] + [InlineData("auth/resource-mismatch")] + [InlineData("auth/pre-registration")] + [InlineData("auth/client-credentials-jwt")] + [InlineData("auth/client-credentials-basic")] public async Task RunConformanceTest(string scenario) { // Run the conformance test suite From 84b7fb85a9de993294349c48fbc32d14771bf9bb Mon Sep 17 00:00:00 2001 From: Mike Kistler Date: Sun, 25 Jan 2026 16:11:28 -0600 Subject: [PATCH 2/7] Fix client conformance test elicitation-sep1034-client-defaults --- .../Program.cs | 87 ++++++++++++++++--- 1 file changed, 77 insertions(+), 10 deletions(-) diff --git a/tests/ModelContextProtocol.ConformanceClient/Program.cs b/tests/ModelContextProtocol.ConformanceClient/Program.cs index e2f09e88f..d75f64682 100644 --- a/tests/ModelContextProtocol.ConformanceClient/Program.cs +++ b/tests/ModelContextProtocol.ConformanceClient/Program.cs @@ -1,9 +1,10 @@ using System.Net; using System.Net.Sockets; -using System.Text; +using System.Text.Json; using System.Web; using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol; // This program expects the following command-line arguments: // 1. The client conformance test scenario to run (e.g., "tools_call") @@ -18,15 +19,6 @@ var scenario = args[0]; var endpoint = args[1]; -McpClientOptions options = new() -{ - ClientInfo = new() - { - Name = "ConformanceClient", - Version = "1.0.0" - } -}; - var consoleLoggerFactory = LoggerFactory.Create(builder => { builder.AddConsole(); @@ -67,6 +59,31 @@ } }, loggerFactory: consoleLoggerFactory); +// Wrapper delegate pattern: allows setting elicitation handler after client creation +// This allows the actual handler to be set dynamically based on scenario +Func>? elicitationHandler = null; + +McpClientOptions options = new() +{ + ClientInfo = new() + { + Name = "ConformanceClient", + Version = "1.0.0" + }, + Handlers = new() + { + ElicitationHandler = (request, cancellationToken) => + { + if (elicitationHandler is not null) + { + return elicitationHandler(request, cancellationToken); + } + Console.WriteLine("No elicitation handler set, rejecting by default"); + return ValueTask.FromResult(new ElicitResult()); // default - reject + } + } +}; + await using var mcpClient = await McpClient.CreateAsync(clientTransport, options, loggerFactory: consoleLoggerFactory); bool success = true; @@ -105,6 +122,56 @@ success &= !(result.IsError == true); break; } + case "elicitation-sep1034-client-defaults": + { + // In this test scenario, an elicitation request will be made that includes default values in the schema. + // The client should apply these defaults to demonstrate that it received and processed them correctly. + + // Set the elicitation handler dynamically for this scenario + elicitationHandler = (request, cancellationToken) => + { + Console.WriteLine($"Received elicitation request: {request?.Message}"); + + // Apply default values from the schema + var content = new Dictionary(); + + if (request?.RequestedSchema?.Properties is not null) + { + foreach (var (key, schema) in request.RequestedSchema.Properties) + { + switch (schema) + { + case ElicitRequestParams.StringSchema stringSchema when stringSchema.Default is not null: + content[key] = JsonSerializer.SerializeToElement(stringSchema.Default); + break; + case ElicitRequestParams.NumberSchema numberSchema when numberSchema.Default.HasValue: + content[key] = JsonSerializer.SerializeToElement(numberSchema.Default.Value); + break; + case ElicitRequestParams.BooleanSchema booleanSchema when booleanSchema.Default.HasValue: + content[key] = JsonSerializer.SerializeToElement(booleanSchema.Default.Value); + break; + case ElicitRequestParams.UntitledSingleSelectEnumSchema enumSchema when enumSchema.Default is not null: + content[key] = JsonSerializer.SerializeToElement(enumSchema.Default); + break; + case ElicitRequestParams.TitledSingleSelectEnumSchema titledEnumSchema when titledEnumSchema.Default is not null: + content[key] = JsonSerializer.SerializeToElement(titledEnumSchema.Default); + break; + } + } + } + + return new ValueTask(new ElicitResult { Action = "accept", Content = content }); + }; + + // Call the test_client_elicitation_defaults tool + var testToolName = "test_client_elicitation_defaults"; + Console.WriteLine($"Calling tool: {testToolName}"); + var result = await mcpClient.CallToolAsync(toolName: testToolName, arguments: new Dictionary()); + Console.WriteLine($"Tool result: {result}"); + success &= !(result.IsError == true); + + break; + } default: // No extra processing for other scenarios break; From 595942d2416abc493609705ea6d27515eda19bc8 Mon Sep 17 00:00:00 2001 From: Mike Kistler Date: Mon, 26 Jan 2026 10:33:07 -0600 Subject: [PATCH 3/7] Disable auth/resource-mismatch pending fix in conformance suite --- .../ClientConformanceTests.cs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs index d81b3ea70..eb45df84c 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs @@ -112,7 +112,8 @@ public async Task VerifyAllConformanceTestsAreListed() [InlineData("auth/token-endpoint-auth-basic")] [InlineData("auth/token-endpoint-auth-post")] [InlineData("auth/token-endpoint-auth-none")] - [InlineData("auth/resource-mismatch")] + // auth/resource-mismatch disabled for now pending https://github.com/modelcontextprotocol/conformance/pull/123 + // [InlineData("auth/resource-mismatch")] [InlineData("auth/pre-registration")] [InlineData("auth/client-credentials-jwt")] [InlineData("auth/client-credentials-basic")] From 008d07ca731a6985a62229280449d9c74b69bb62 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 27 Jan 2026 23:26:29 -0500 Subject: [PATCH 4/7] Update to 0.1.11 conformance tests --- Directory.Packages.props | 2 +- .../Polyfills/System/IO/StreamExtensions.cs | 1 - .../AuthorizationFilterSetup.cs | 1 - .../DnsRebindingProtectionMiddleware.cs | 113 +++++ .../McpApplicationBuilderExtensions.cs | 47 +++ .../Authentication/ClientOAuthOptions.cs | 19 + .../Authentication/ClientOAuthProvider.cs | 388 ++++++++++++++++-- .../Client/SseClientSessionTransport.cs | 148 +++++-- .../ClientConformanceTests.cs | 79 +++- .../DnsRebindingProtectionMiddlewareTests.cs | 112 +++++ .../Program.cs | 133 +++++- ...elContextProtocol.ConformanceServer.csproj | 1 + .../Program.cs | 2 + .../Tools/ConformanceTools.cs | 55 ++- 14 files changed, 1007 insertions(+), 94 deletions(-) create mode 100644 src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs create mode 100644 src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs diff --git a/Directory.Packages.props b/Directory.Packages.props index b5f599099..d919103ca 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -7,7 +7,7 @@ 10.2.0 - 0.1.10 + 0.1.11 diff --git a/src/Common/Polyfills/System/IO/StreamExtensions.cs b/src/Common/Polyfills/System/IO/StreamExtensions.cs index 452b80321..321f4f766 100644 --- a/src/Common/Polyfills/System/IO/StreamExtensions.cs +++ b/src/Common/Polyfills/System/IO/StreamExtensions.cs @@ -1,7 +1,6 @@ using ModelContextProtocol; using System.Buffers; using System.Runtime.InteropServices; -using System.Text; #if !NET namespace System.IO; diff --git a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs index ae5e42dd8..4795fb88f 100644 --- a/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs +++ b/src/ModelContextProtocol.AspNetCore/AuthorizationFilterSetup.cs @@ -3,7 +3,6 @@ using Microsoft.AspNetCore.Authorization; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Options; -using ModelContextProtocol.Protocol; using ModelContextProtocol.Server; namespace ModelContextProtocol.AspNetCore; diff --git a/src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs b/src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs new file mode 100644 index 000000000..7850dbb95 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs @@ -0,0 +1,113 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Net.Http.Headers; +using System.Net; +using System.Text.Json; + +namespace ModelContextProtocol.AspNetCore; + +/// +/// Middleware that provides DNS rebinding protection for MCP servers by validating +/// Host and Origin headers on requests to localhost servers. +/// +/// +/// +/// DNS rebinding attacks can allow malicious websites to bypass browser same-origin policy +/// and make requests to localhost services. This middleware helps protect against such attacks +/// by validating that Host and Origin headers match expected localhost values. +/// +/// +/// Use to enable this middleware. +/// +/// +/// +/// Initializes a new instance of the class. +/// +internal sealed partial class DnsRebindingProtectionMiddleware( + RequestDelegate next, + ILogger logger) +{ + private readonly RequestDelegate _next = next; + private readonly ILogger _logger = logger; + + /// + /// Processes the HTTP request and validates Host and Origin headers for localhost servers. + /// + public async Task InvokeAsync(HttpContext context) + { + // Only apply protection to localhost servers + var localEndpoint = context.Connection.LocalIpAddress; + bool isLocalhostServer = localEndpoint is null || + IPAddress.IsLoopback(localEndpoint) || + localEndpoint.Equals(IPAddress.IPv6Loopback); + + if (isLocalhostServer) + { + // Validate Host header + var host = context.Request.Host.Host; + if (!IsLocalhost(host)) + { + LogInvalidHostHeader(host); + await WriteJsonRpcErrorResponseAsync(context, $"Forbidden: Invalid Host header '{host}' for localhost server"); + return; + } + + // Validate Origin header if present + if (context.Request.Headers.TryGetValue(HeaderNames.Origin, out var originValues) && + originValues.FirstOrDefault() is string origin && + Uri.TryCreate(origin, UriKind.Absolute, out var originUri) && + !IsLocalhost(originUri.Host)) + { + LogInvalidOriginHeader(origin); + await WriteJsonRpcErrorResponseAsync(context, $"Forbidden: Invalid Origin header '{origin}' for localhost server"); + return; + } + } + + await _next(context).ConfigureAwait(false); + } + + private static bool IsLocalhost(string host) + { + if (!string.IsNullOrWhiteSpace(host)) + { + if (host.Equals("localhost", StringComparison.OrdinalIgnoreCase) || + host.Equals("[::1]") || + host.Equals("127.0.0.1")) + { + return true; + } + + if (IPAddress.TryParse(host, out var ip)) + { + return IPAddress.IsLoopback(ip); + } + } + + return false; + } + + private static Task WriteJsonRpcErrorResponseAsync(HttpContext context, string message) + { + context.Response.StatusCode = StatusCodes.Status403Forbidden; + context.Response.ContentType = "application/json"; + return context.Response.WriteAsync($$""" + { + "jsonrpc": "2.0", + "error": + { + "code": -32000, + "message": "{{JsonEncodedText.Encode(message)}}" + }, + "id": null + } + """); + } + + [LoggerMessage(Level = LogLevel.Warning, Message = "Rejected request with invalid Host header '{Host}' for localhost server. This may indicate a DNS rebinding attack.")] + private partial void LogInvalidHostHeader(string? host); + + [LoggerMessage(Level = LogLevel.Warning, Message = "Rejected request with invalid Origin header '{Origin}' for localhost server. This may indicate a DNS rebinding attack.")] + private partial void LogInvalidOriginHeader(string origin); +} diff --git a/src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs new file mode 100644 index 000000000..e3c7fa6d9 --- /dev/null +++ b/src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs @@ -0,0 +1,47 @@ +using ModelContextProtocol.AspNetCore; + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Extension methods for adding MCP middleware to an . +/// +public static class McpApplicationBuilderExtensions +{ + /// + /// Adds DNS rebinding protection middleware for MCP servers running on localhost. + /// + /// The . + /// The for chaining. + /// + /// + /// This method provides protection against DNS rebinding attacks by validating that both + /// Host and Origin headers (when present) resolve to localhost addresses. + /// + /// + /// DNS rebinding attacks can allow malicious websites to bypass browser same-origin policy and make requests + /// to localhost services. This protection is recommended for any MCP server that binds to localhost. + /// + /// + /// For more information, see the MCP SDK security advisory. + /// + /// + /// + /// + /// var builder = WebApplication.CreateBuilder(args); + /// builder.Services.AddMcpServer().WithHttpTransport(); + /// + /// var app = builder.Build(); + /// app.UseMcpDnsRebindingProtection(); // Add before MapMcp() + /// app.MapMcp(); + /// app.Run(); + /// + /// + public static IApplicationBuilder UseMcpDnsRebindingProtection(this IApplicationBuilder app) + { + ArgumentNullException.ThrowIfNull(app); + + app.UseMiddleware(); + + return app; + } +} diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs index 483e3643e..e9fc19135 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthOptions.cs @@ -23,6 +23,25 @@ public sealed class ClientOAuthOptions /// public string? ClientSecret { get; set; } + /// + /// Gets or sets the private key in PEM format for JWT client assertion (private_key_jwt). + /// + /// + /// When provided along with , the client will use JWT client + /// assertion (private_key_jwt) for token endpoint authentication instead of client_secret. + /// This is typically used for machine-to-machine authentication with client_credentials grant. + /// + public string? JwtPrivateKeyPem { get; set; } + + /// + /// Gets or sets the signing algorithm for JWT client assertion. + /// + /// + /// Common values include "RS256", "RS384", "RS512", "ES256", "ES384", "ES512". + /// This property is only used when is provided. + /// + public string? JwtSigningAlgorithm { get; set; } + /// /// Gets or sets the HTTPS URL pointing to this client's metadata document. /// diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 75126556b..7dff362bd 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -39,6 +39,10 @@ internal sealed partial class ClientOAuthProvider : McpHttpClient private readonly string? _dcrInitialAccessToken; private readonly Func? _dcrResponseDelegate; + // JWT client assertion support (private_key_jwt) + private readonly string? _jwtPrivateKeyPem; + private readonly string? _jwtSigningAlgorithm; + private readonly HttpClient _httpClient; private readonly ILogger _logger; @@ -46,6 +50,12 @@ internal sealed partial class ClientOAuthProvider : McpHttpClient private string? _clientSecret; private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; + private int _scopeStepUpCount; + + /// + /// Maximum number of scope step-up retries before failing. + /// + private const int MaxScopeStepUpRetries = 3; /// /// Initializes a new instance of the class using the specified options. @@ -89,8 +99,26 @@ public ClientOAuthProvider( _dcrInitialAccessToken = options.DynamicClientRegistration?.InitialAccessToken; _dcrResponseDelegate = options.DynamicClientRegistration?.ResponseDelegate; _tokenCache = options.TokenCache ?? new InMemoryTokenCache(); + + // JWT client assertion support + _jwtPrivateKeyPem = options.JwtPrivateKeyPem; + _jwtSigningAlgorithm = options.JwtSigningAlgorithm; + + // Validate JWT signing algorithm if provided + if (_jwtSigningAlgorithm is not null && + !s_jwtSigningAlgorithms.Contains(_jwtSigningAlgorithm)) + { + throw new ArgumentException($"JWT signing algorithm '{_jwtSigningAlgorithm}' is not supported.", nameof(options)); + } } + private static readonly HashSet s_jwtSigningAlgorithms = new(StringComparer.OrdinalIgnoreCase) + { + "ES256", "ES384", "ES512", + "RS256", "RS384", "RS512", + "PS256", "PS384", "PS512" + }; + /// /// Default authorization server selection strategy that selects the first available server. /// @@ -132,7 +160,8 @@ internal override async Task SendAsync(HttpRequestMessage r var response = await base.SendAsync(request, message, cancellationToken).ConfigureAwait(false); - if (ShouldRetryWithNewAccessToken(response)) + if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized || + HasInsufficientScopeError(response)) { return await HandleUnauthorizedResponseAsync(request, message, response, attemptedRefresh, cancellationToken).ConfigureAwait(false); } @@ -161,14 +190,12 @@ internal override async Task SendAsync(HttpRequestMessage r return (null, false); } - private static bool ShouldRetryWithNewAccessToken(HttpResponseMessage response) + /// + /// Checks if the response contains an insufficient_scope error (403 Forbidden with error=insufficient_scope). + /// + private static bool HasInsufficientScopeError(HttpResponseMessage response) { - if (response.StatusCode == System.Net.HttpStatusCode.Unauthorized) - { - return true; - } - - // Only retry 403 Forbidden if it contains an insufficient_scope error as described in Section 10.1.1 of the MCP specification + // Only 403 Forbidden responses can have insufficient_scope error // https://modelcontextprotocol.io/specification/2025-11-25/basic/authorization#runtime-insufficient-scope-errors if (response.StatusCode != System.Net.HttpStatusCode.Forbidden) { @@ -222,7 +249,10 @@ private async Task HandleUnauthorizedResponseAsync( } retryRequest.Headers.Authorization = new AuthenticationHeaderValue(BearerScheme, accessToken); - return await base.SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false); + + // Use SendAsync (not base.SendAsync) to enable retry logic for scope step-up scenarios + // where the server may respond with 403 (insufficient_scope) multiple times + return await SendAsync(retryRequest, originalJsonRpcMessage, cancellationToken).ConfigureAwait(false); } /// @@ -233,6 +263,20 @@ private async Task HandleUnauthorizedResponseAsync( /// The to monitor for cancellation requests. private async Task GetAccessTokenAsync(HttpResponseMessage response, bool attemptedRefresh, CancellationToken cancellationToken) { + // Check if this is a scope step-up retry (403 with insufficient_scope) + if (HasInsufficientScopeError(response)) + { + if (++_scopeStepUpCount >= MaxScopeStepUpRetries) + { + ThrowFailedToHandleUnauthorizedResponse($"Maximum scope step-up retry limit ({MaxScopeStepUpRetries}) exceeded."); + } + } + else + { + // Reset the counter for non-scope-step-up requests (e.g., initial auth or token expiry) + _scopeStepUpCount = 0; + } + // Get available authorization servers from the 401 or 403 response var protectedResourceMetadata = await ExtractProtectedResourceMetadata(response, cancellationToken).ConfigureAwait(false); var availableAuthorizationServers = protectedResourceMetadata.AuthorizationServers; @@ -282,7 +326,7 @@ await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { R } } - // Assign a client ID if necessary + // Skip dynamic registration if we have pre-registered credentials (ClientId + ClientSecret) if (string.IsNullOrEmpty(_clientId)) { // Try using a client metadata document before falling back to dynamic client registration @@ -296,10 +340,246 @@ await _tokenCache.GetTokensAsync(cancellationToken).ConfigureAwait(false) is { R } } - // Perform the OAuth flow + // Check if client_credentials grant type should be used. + // Use client_credentials when: + // 1. The server supports client_credentials grant type. + // 2. We have a client secret (confidential client). + // 3. No AuthorizationRedirectDelegate was explicitly provided (machine-to-machine flow). + if (ShouldUseClientCredentialsGrant(authServerMetadata)) + { + return await InitiateClientCredentialsFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); + } + + // Perform the OAuth authorization code flow return await InitiateAuthorizationCodeFlowAsync(protectedResourceMetadata, authServerMetadata, cancellationToken).ConfigureAwait(false); } + /// + /// Determines whether to use the client_credentials grant type. + /// + private bool ShouldUseClientCredentialsGrant(AuthorizationServerMetadata authServerMetadata) + { + // Must have either client secret or JWT private key for client_credentials. + if (string.IsNullOrEmpty(_clientSecret) && string.IsNullOrEmpty(_jwtPrivateKeyPem)) + { + return false; + } + + // Server must support client_credentials grant type. + if (authServerMetadata.GrantTypesSupported?.Contains("client_credentials") != true) + { + return false; + } + + // If an authorization redirect delegate was explicitly configured, use authorization code flow + // Default delegate is fine to override with client_credentials. + if (_authorizationRedirectDelegate != DefaultAuthorizationUrlHandler) + { + return false; + } + + return true; + } + + /// + /// Initiates the OAuth client_credentials flow for machine-to-machine authentication. + /// + private async Task InitiateClientCredentialsFlowAsync( + ProtectedResourceMetadata protectedResourceMetadata, + AuthorizationServerMetadata authServerMetadata, + CancellationToken cancellationToken) + { + var resourceUri = GetRequiredResourceUri(protectedResourceMetadata); + + var formParams = new Dictionary + { + ["grant_type"] = "client_credentials", + ["resource"] = resourceUri.ToString(), + }; + + var scope = GetScopeParameter(protectedResourceMetadata); + if (!string.IsNullOrEmpty(scope)) + { + formParams["scope"] = scope!; + } + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint); + + // Add client authentication based on available credentials and server support + AddClientAuthentication(request, formParams, authServerMetadata); + + request.Content = new FormUrlEncodedContent(formParams); + + using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); + httpResponse.EnsureSuccessStatusCode(); + + var tokens = await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false); + LogOAuthClientCredentialsCompleted(); + return tokens.AccessToken; + } + + /// + /// Adds client authentication to the token request based on available credentials. + /// + private void AddClientAuthentication( + HttpRequestMessage request, + Dictionary formParams, + AuthorizationServerMetadata authServerMetadata) + { + // If JWT private key is configured, use private_key_jwt. + if (!string.IsNullOrEmpty(_jwtPrivateKeyPem) && !string.IsNullOrEmpty(_jwtSigningAlgorithm)) + { + // Use the issuer as the audience if available, otherwise fall back to token endpoint + var audience = authServerMetadata.Issuer ?? authServerMetadata.TokenEndpoint!; + var assertion = CreateClientAssertion(audience); + formParams["client_id"] = GetClientIdOrThrow(); + formParams["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; + formParams["client_assertion"] = assertion; + return; + } + + // Otherwise use client_secret authentication. + var tokenEndpointAuthMethod = GetTokenEndpointAuthMethod(authServerMetadata); + + if (tokenEndpointAuthMethod == "client_secret_basic") + { + // Use HTTP Basic authentication + var credentials = $"{Uri.EscapeDataString(GetClientIdOrThrow())}:{Uri.EscapeDataString(_clientSecret ?? string.Empty)}"; + var encodedCredentials = Convert.ToBase64String(Encoding.UTF8.GetBytes(credentials)); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", encodedCredentials); + } + else + { + // Use client_secret_post (credentials in body) + formParams["client_id"] = GetClientIdOrThrow(); + formParams["client_secret"] = _clientSecret ?? string.Empty; + } + } + + /// + /// Creates a JWT client assertion for private_key_jwt authentication. + /// + private string CreateClientAssertion(Uri audience) + { + // JWT claims (payload) + var now = DateTimeOffset.UtcNow; + var clientId = GetClientIdOrThrow(); + var jti = Guid.NewGuid().ToString(); + var iat = now.ToUnixTimeSeconds(); + var exp = now.AddMinutes(5).ToUnixTimeSeconds(); + + // Manually construct JSON to avoid AOT/trimming issues with Dictionary + // Algorithm is validated in constructor to be one of the known safe values, so no escaping needed + var headerJson = $@"{{""alg"":""{_jwtSigningAlgorithm!.ToUpperInvariant()}"",""typ"":""JWT""}}"; + var claimsJson = $@"{{""iss"":""{JsonEncodedString(clientId)}"",""sub"":""{JsonEncodedString(clientId)}"",""aud"":""{JsonEncodedString(audience.ToString())}"",""jti"":""{jti}"",""iat"":{iat},""exp"":{exp}}}"; + + var headerBase64 = Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson)); + var claimsBase64 = Base64UrlEncode(Encoding.UTF8.GetBytes(claimsJson)); + + var signingInput = $"{headerBase64}.{claimsBase64}"; + var signature = SignJwt(signingInput); + + return $"{signingInput}.{signature}"; + } + + /// + /// Escapes a string for JSON encoding. + /// + private static string JsonEncodedString(string value) => JsonEncodedText.Encode(value).ToString(); + + /// + /// Signs the JWT using the configured private key and algorithm. + /// + private string SignJwt(string input) + { +#if NETSTANDARD2_0 + throw new NotSupportedException( + "JWT client assertion (private_key_jwt) is not supported on .NET Standard 2.0. " + + "Use .NET 5.0 or later for this feature."); +#else + var data = Encoding.UTF8.GetBytes(input); + + var pemContent = _jwtPrivateKeyPem!; + using AsymmetricAlgorithm key = _jwtSigningAlgorithm!.StartsWith("ES", StringComparison.OrdinalIgnoreCase) ? + LoadKeyWithDisposal(ECDsa.Create, ecdsa => ecdsa.ImportFromPem(pemContent)) : + LoadKeyWithDisposal(RSA.Create, rsa => rsa.ImportFromPem(pemContent)); + + byte[] signature; + + if (_jwtSigningAlgorithm!.StartsWith("ES", StringComparison.OrdinalIgnoreCase)) + { + // ECDSA signature - JWT requires IEEE P1363 format (R||S concatenation), not DER + var ecdsa = key as ECDsa ?? throw new InvalidOperationException("Private key is not an EC key, but ES* algorithm was specified."); + var hashAlgorithm = GetHashAlgorithmName(_jwtSigningAlgorithm); + signature = ecdsa.SignData(data, hashAlgorithm, DSASignatureFormat.IeeeP1363FixedFieldConcatenation); + } + else if (_jwtSigningAlgorithm.StartsWith("RS", StringComparison.OrdinalIgnoreCase) || + _jwtSigningAlgorithm.StartsWith("PS", StringComparison.OrdinalIgnoreCase)) + { + // RSA signature + var rsa = key as RSA ?? throw new InvalidOperationException("Private key is not an RSA key, but RS*/PS* algorithm was specified."); + var hashAlgorithm = GetHashAlgorithmName(_jwtSigningAlgorithm); + var padding = _jwtSigningAlgorithm.StartsWith("PS", StringComparison.OrdinalIgnoreCase) + ? RSASignaturePadding.Pss + : RSASignaturePadding.Pkcs1; + signature = rsa.SignData(data, hashAlgorithm, padding); + } + else + { + throw new NotSupportedException($"JWT signing algorithm '{_jwtSigningAlgorithm}' is not supported."); + } + + return Base64UrlEncode(signature); +#endif + } + + private static TAlgorithm LoadKeyWithDisposal( + Func createAlgorithm, + Action importAction) + where TAlgorithm : AsymmetricAlgorithm + { + var algorithm = createAlgorithm(); + try + { + importAction(algorithm); + return algorithm; + } + catch + { + algorithm.Dispose(); + throw; + } + } + + private static HashAlgorithmName GetHashAlgorithmName(string algorithm) => + s_signingAlgorithms.TryGetValue(algorithm, out HashAlgorithmName alg) ? alg : + throw new NotSupportedException($"JWT signing algorithm '{algorithm}' is not supported."); + + private static readonly Dictionary s_signingAlgorithms = new(StringComparer.OrdinalIgnoreCase) + { + ["ES256"] = HashAlgorithmName.SHA256, + ["RS256"] = HashAlgorithmName.SHA256, + ["PS256"] = HashAlgorithmName.SHA256, + + ["ES384"] = HashAlgorithmName.SHA384, + ["RS384"] = HashAlgorithmName.SHA384, + ["PS384"] = HashAlgorithmName.SHA384, + + ["ES512"] = HashAlgorithmName.SHA512, + ["RS512"] = HashAlgorithmName.SHA512, + ["PS512"] = HashAlgorithmName.SHA512, + }; + + /// + /// Base64url encodes data per RFC 7515. + /// + private static string Base64UrlEncode(byte[] data) => +#if NET9_0_OR_GREATER + Base64Url.EncodeToString(data); +#else + Convert.ToBase64String(data).TrimEnd('=').Replace('+', '-').Replace('/', '_'); +#endif + private void ApplyClientIdMetadataDocument(Uri metadataUri) { if (!IsValidClientMetadataDocumentUri(metadataUri)) @@ -311,10 +591,10 @@ private void ApplyClientIdMetadataDocument(Uri metadataUri) _clientId = metadataUri.AbsoluteUri; // See: https://datatracker.ietf.org/doc/html/draft-ietf-oauth-client-id-metadata-document-00#section-3 - static bool IsValidClientMetadataDocumentUri(Uri uri) - => uri.IsAbsoluteUri - && string.Equals(uri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase) - && uri.AbsolutePath.Length > 1; // AbsolutePath always starts with "/" + static bool IsValidClientMetadataDocumentUri(Uri uri) => + uri.IsAbsoluteUri && + string.Equals(uri.Scheme, Uri.UriSchemeHttps, StringComparison.OrdinalIgnoreCase) && + uri.AbsolutePath.Length > 1; // AbsolutePath always starts with "/" } private async Task GetAuthServerMetadataAsync(Uri authServerUri, CancellationToken cancellationToken) @@ -385,19 +665,33 @@ private static IEnumerable GetWellKnownAuthorizationServerMetadataUris(Uri private async Task RefreshTokensAsync(string refreshToken, Uri resourceUri, AuthorizationServerMetadata authServerMetadata, CancellationToken cancellationToken) { - var requestContent = new FormUrlEncodedContent(new Dictionary + var formParams = new Dictionary { ["grant_type"] = "refresh_token", ["refresh_token"] = refreshToken, - ["client_id"] = GetClientIdOrThrow(), - ["client_secret"] = _clientSecret ?? string.Empty, ["resource"] = resourceUri.ToString(), - }); + }; - using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + // Add client credentials based on token endpoint auth method + var tokenEndpointAuthMethod = GetTokenEndpointAuthMethod(authServerMetadata); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint); + + if (tokenEndpointAuthMethod == "client_secret_basic") { - Content = requestContent - }; + // Use HTTP Basic authentication + var credentials = $"{Uri.EscapeDataString(GetClientIdOrThrow())}:{Uri.EscapeDataString(_clientSecret ?? string.Empty)}"; + var encodedCredentials = Convert.ToBase64String(Encoding.UTF8.GetBytes(credentials)); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", encodedCredentials); + } + else + { + // Use client_secret_post (credentials in body) + formParams["client_id"] = GetClientIdOrThrow(); + formParams["client_secret"] = _clientSecret ?? string.Empty; + } + + request.Content = new FormUrlEncodedContent(formParams); using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); @@ -482,21 +776,35 @@ private async Task ExchangeCodeForTokenAsync( { var resourceUri = GetRequiredResourceUri(protectedResourceMetadata); - var requestContent = new FormUrlEncodedContent(new Dictionary + var formParams = new Dictionary { ["grant_type"] = "authorization_code", ["code"] = authorizationCode, ["redirect_uri"] = _redirectUri.ToString(), - ["client_id"] = GetClientIdOrThrow(), ["code_verifier"] = codeVerifier, - ["client_secret"] = _clientSecret ?? string.Empty, ["resource"] = resourceUri.ToString(), - }); + }; + + // Add client credentials based on token endpoint auth method + var tokenEndpointAuthMethod = GetTokenEndpointAuthMethod(authServerMetadata); + + using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint); - using var request = new HttpRequestMessage(HttpMethod.Post, authServerMetadata.TokenEndpoint) + if (tokenEndpointAuthMethod == "client_secret_basic") { - Content = requestContent - }; + // Use HTTP Basic authentication + var credentials = $"{Uri.EscapeDataString(GetClientIdOrThrow())}:{Uri.EscapeDataString(_clientSecret ?? string.Empty)}"; + var encodedCredentials = Convert.ToBase64String(Encoding.UTF8.GetBytes(credentials)); + request.Headers.Authorization = new AuthenticationHeaderValue("Basic", encodedCredentials); + } + else + { + // Use client_secret_post (credentials in body) + formParams["client_id"] = GetClientIdOrThrow(); + formParams["client_secret"] = _clientSecret ?? string.Empty; + } + + request.Content = new FormUrlEncodedContent(formParams); using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false); httpResponse.EnsureSuccessStatusCode(); @@ -870,6 +1178,25 @@ private static string ToBase64UrlString(byte[] bytes) private string GetClientIdOrThrow() => _clientId ?? throw new InvalidOperationException("Client ID is not available. This may indicate an issue with dynamic client registration."); + /// + /// Determines the token endpoint authentication method to use based on server metadata. + /// + /// The authorization server metadata. + /// The authentication method to use (client_secret_basic or client_secret_post). + private static string GetTokenEndpointAuthMethod(AuthorizationServerMetadata authServerMetadata) + { + var supportedMethods = authServerMetadata.TokenEndpointAuthMethodsSupported; + + // If client_secret_basic is supported, prefer it + if (supportedMethods?.Contains("client_secret_basic") == true) + { + return "client_secret_basic"; + } + + // Otherwise use client_secret_post (default per RFC) + return "client_secret_post"; + } + [DoesNotReturn] private static void ThrowFailedToHandleUnauthorizedResponse(string message) => throw new McpException($"Failed to handle unauthorized response with 'Bearer' scheme. {message}"); @@ -880,6 +1207,9 @@ private static void ThrowFailedToHandleUnauthorizedResponse(string message) => [LoggerMessage(Level = LogLevel.Information, Message = "OAuth authorization completed successfully")] partial void LogOAuthAuthorizationCompleted(); + [LoggerMessage(Level = LogLevel.Information, Message = "OAuth client_credentials flow completed successfully")] + partial void LogOAuthClientCredentialsCompleted(); + [LoggerMessage(Level = LogLevel.Information, Message = "OAuth token refresh completed successfully")] partial void LogOAuthTokenRefreshCompleted(); diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 60950dfa5..3a5d96b9c 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.Diagnostics; +using System.Net; using System.Net.Http.Headers; using System.Net.ServerSentEvents; using System.Text.Json; @@ -22,6 +23,8 @@ internal sealed partial class SseClientSessionTransport : TransportBase private Task? _receiveTask; private readonly ILogger _logger; private readonly TaskCompletionSource _connectionEstablished; + private string? _lastEventId; + private TimeSpan? _retryInterval; /// /// SSE transport for a single session. Unlike stdio it does not launch a process, but connects to an existing server. @@ -138,53 +141,140 @@ public override async ValueTask DisposeAsync() private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { - try + int attempt = 0; + + while (attempt < _options.MaxReconnectionAttempts && !cancellationToken.IsCancellationRequested) { - using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); - request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); + try + { + // Delay before reconnection attempts. + if (attempt > 0) + { + await Task.Delay(_retryInterval ?? _options.DefaultReconnectionInterval, cancellationToken).ConfigureAwait(false); + } - using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); + using var request = new HttpRequestMessage(HttpMethod.Get, _sseEndpoint); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); + StreamableHttpClientSessionTransport.CopyAdditionalHeaders(request.Headers, _options.AdditionalHeaders, sessionId: null, protocolVersion: null); - response.EnsureSuccessStatusCode(); + // Include Last-Event-ID header for reconnection. + if (_lastEventId is not null) + { + request.Headers.Add("Last-Event-ID", _lastEventId); + LogSseReconnectWithLastEventId(Name, _lastEventId); + } - using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + HttpResponseMessage response; + try + { + response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false); + } + catch (HttpRequestException) + { + // Network error - retry + attempt++; + continue; + } - await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) - { - switch (sseEvent.EventType) + using var _ = response; + + if (response.StatusCode >= HttpStatusCode.InternalServerError) { - case "endpoint": - HandleEndpointEvent(sseEvent.Data); - break; + // Server error - retry + attempt++; + continue; + } - case "message": - await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false); - break; + response.EnsureSuccessStatusCode(); + + using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); + + bool hadNetworkError = await ProcessSseStreamAsync(stream, cancellationToken).ConfigureAwait(false); + + if (!hadNetworkError || _lastEventId is null) + { + // Stream ended either gracefully or without resumability support. + return; } + + // Only retry if the server didn't close gracefully and we have something to retry. + attempt++; + continue; } - } - catch (Exception ex) - { - if (cancellationToken.IsCancellationRequested) + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { - // Normal shutdown + // Normal shutdown via cancellation LogTransportReadMessagesCancelled(Name); _connectionEstablished.TrySetCanceled(cancellationToken); + return; } - else + catch (Exception ex) { LogTransportReadMessagesFailed(Name, ex); _connectionEstablished.TrySetException(ex); throw; } } - finally + + if (cancellationToken.IsCancellationRequested) + { + LogTransportReadMessagesCancelled(Name); + _connectionEstablished.TrySetCanceled(cancellationToken); + } + else { SetDisconnected(); } } + /// + /// Processes the SSE stream, handling events until the stream ends. + /// + /// True if the stream ended due to a network error (should retry), false if it ended normally. + private async Task ProcessSseStreamAsync(Stream stream, CancellationToken cancellationToken) + { + try + { + await foreach (SseItem sseEvent in SseParser.Create(stream).EnumerateAsync(cancellationToken).ConfigureAwait(false)) + { + // Track event ID and retry interval for resumability + var eventId = sseEvent.EventId; + + if (!string.IsNullOrEmpty(eventId)) + { + _lastEventId = eventId; + LogSseEventIdReceived(Name, eventId!); + } + + if (sseEvent.ReconnectionInterval.HasValue) + { + _retryInterval = sseEvent.ReconnectionInterval.Value; + LogSseRetryIntervalReceived(Name, sseEvent.ReconnectionInterval.Value.TotalMilliseconds); + } + + switch (sseEvent.EventType) + { + case "endpoint": + HandleEndpointEvent(sseEvent.Data); + break; + + case "message": + await ProcessSseMessage(sseEvent.Data, cancellationToken).ConfigureAwait(false); + break; + } + } + + // Stream ended normally (server closed connection gracefully) + return false; + } + catch (Exception ex) when (ex is IOException or HttpRequestException) + { + // Network error during streaming - should retry + LogSseStreamNetworkError(Name, ex); + return true; + } + } + private async Task ProcessSseMessage(string data, CancellationToken cancellationToken) { if (!IsConnected) @@ -243,4 +333,16 @@ private void HandleEndpointEvent(string data) [LoggerMessage(Level = LogLevel.Trace, Message = "{EndpointName} rejected SSE transport POST for message ID '{MessageId}'. Server response: '{responseContent}'.")] private partial void LogRejectedPostSensitive(string endpointName, string messageId, string responseContent); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} SSE reconnection with Last-Event-ID: '{LastEventId}'.")] + private partial void LogSseReconnectWithLastEventId(string endpointName, string lastEventId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} SSE received event ID: '{EventId}'.")] + private partial void LogSseEventIdReceived(string endpointName, string eventId); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} SSE received retry interval: {RetryIntervalMs}ms.")] + private partial void LogSseRetryIntervalReceived(string endpointName, double retryIntervalMs); + + [LoggerMessage(Level = LogLevel.Debug, Message = "{EndpointName} SSE stream network error during streaming, will attempt reconnection.")] + private partial void LogSseStreamNetworkError(string endpointName, Exception ex); } \ No newline at end of file diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs index eb45df84c..037802d0f 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs @@ -1,5 +1,6 @@ using System.Diagnostics; using System.Text; +using System.Text.RegularExpressions; using ModelContextProtocol.Tests.Utils; namespace ModelContextProtocol.ConformanceTests; @@ -21,19 +22,23 @@ public ClientConformanceTests(ITestOutputHelper output) _output = output; } - // Expected scenarios based on InlineData attributes below - private static readonly string[] ExpectedScenarios = [ + // Expected scenarios based on InlineData attributes below. + // All scenarios from the conformance suite must be listed here to ensure VerifyAllConformanceTestsAreListed + // detects any new scenarios added to the suite. Scenarios may be disabled (not in InlineData) but still + // listed here - see comments on the Theory for explanations of why specific scenarios are disabled. + private static readonly string[] ExpectedScenarios = + [ "initialize", "tools_call", "elicitation-sep1034-client-defaults", - "sse-retry", + "sse-retry", // Disabled - tests pure SSE reconnection, not MCP behavior (see comment on Theory) "auth/metadata-default", "auth/metadata-var1", "auth/metadata-var2", "auth/metadata-var3", "auth/basic-cimd", - "auth/2025-03-26-oauth-metadata-backcompat", // Expected but not required to pass - "auth/2025-03-26-oauth-endpoint-fallback", // Expected but not required to pass + "auth/2025-03-26-oauth-metadata-backcompat", // Disabled - tests deprecated 2025-03-26 spec (see comment on Theory) + "auth/2025-03-26-oauth-endpoint-fallback", // Disabled - tests deprecated 2025-03-26 spec (see comment on Theory) "auth/scope-from-www-authenticate", "auth/scope-from-scopes-supported", "auth/scope-omitted-when-undefined", @@ -96,14 +101,49 @@ public async Task VerifyAllConformanceTestsAreListed() [InlineData("initialize")] [InlineData("tools_call")] [InlineData("elicitation-sep1034-client-defaults")] - [InlineData("sse-retry")] + + // The sse-retry test is disabled because it tests pure SSE reconnection behavior, + // not MCP-specific behavior. The test expects the client to: + // 1. Connect via SSE GET + // 2. Receive a priming event with retry interval and event ID + // 3. Gracefully handle stream closure + // 4. Reconnect with Last-Event-ID header (per SSE spec) + // + // The MCP SDK's SSE transport waits for an "endpoint" MCP event before considering + // the connection established (required for MCP message routing). Without this event, + // the connection establishment times out after 30 seconds. + // + // When run, the test shows: + // - [client-sse-graceful-reconnect] SUCCESS - Core SSE reconnection works + // - [client-sse-retry-timing] WARNING - "Client MUST respect the retry field timing" + // - [client-sse-last-event-id] WARNING - "Client SHOULD send Last-Event-ID header" + // + // Per SSE specification (https://html.spec.whatwg.org/multipage/server-sent-events.html): + // - Reconnecting after stream close is MUST behavior (works) + // - Sending Last-Event-ID is SHOULD behavior for resumability + // - Respecting retry timing is SHOULD behavior + // + // The test fails due to client timeout, not actual SSE behavior issues. + // Supporting pure SSE (non-MCP) would require architectural changes to the transport. + // [InlineData("sse-retry")] + [InlineData("auth/metadata-default")] [InlineData("auth/metadata-var1")] [InlineData("auth/metadata-var2")] [InlineData("auth/metadata-var3")] [InlineData("auth/basic-cimd")] + + // The following two tests are for backward compatibility with the deprecated 2025-03-26 MCP spec. + // They test legacy OAuth discovery behavior that the SDK intentionally does not support: + // - auth/2025-03-26-oauth-metadata-backcompat: Tests OAuth flow without Protected Resource Metadata (PRM), + // expecting OAuth metadata at the server root. The current SDK requires PRM per the 2025-11-25 spec. + // - auth/2025-03-26-oauth-endpoint-fallback: Tests fallback to standard OAuth endpoints (/authorize, /token, + // /register) at the server root when no metadata endpoints exist. The SDK doesn't implement this fallback. + // These are listed in ExpectedScenarios to ensure VerifyAllConformanceTestsAreListed passes, but they are + // not required for Tier 1 SDK compliance as they test deprecated spec behavior. // [InlineData("auth/2025-03-26-oauth-metadata-backcompat")] // [InlineData("auth/2025-03-26-oauth-endpoint-fallback")] + [InlineData("auth/scope-from-www-authenticate")] [InlineData("auth/scope-from-scopes-supported")] [InlineData("auth/scope-omitted-when-undefined")] @@ -112,8 +152,7 @@ public async Task VerifyAllConformanceTestsAreListed() [InlineData("auth/token-endpoint-auth-basic")] [InlineData("auth/token-endpoint-auth-post")] [InlineData("auth/token-endpoint-auth-none")] - // auth/resource-mismatch disabled for now pending https://github.com/modelcontextprotocol/conformance/pull/123 - // [InlineData("auth/resource-mismatch")] + [InlineData("auth/resource-mismatch")] [InlineData("auth/pre-registration")] [InlineData("auth/client-credentials-jwt")] [InlineData("auth/client-credentials-basic")] @@ -172,10 +211,24 @@ public async Task RunConformanceTest(string scenario) await process.WaitForExitAsync(); - return ( - Success: process.ExitCode == 0, - Output: outputBuilder.ToString(), - Error: errorBuilder.ToString() - ); + var error = errorBuilder.ToString(); + var combinedOutput = outputBuilder.ToString() + error; + + // Strip ANSI escape codes for reliable pattern matching (ESC [ ... m) + var strippedOutput = Regex.Replace(combinedOutput, @"\u001b\[[0-9;]*m|\x1b\[[0-9;]*m", "", RegexOptions.IgnoreCase); + + // Check for success based on the conformance test output, not just exit code. + // Some tests (like auth/resource-mismatch) expect the client to exit with an error + // after correctly detecting a security issue. The conformance harness reports these + // as "CLIENT EXITED WITH ERROR" but if all actual checks passed (indicated by + // "Passed: X/X, 0 failed"), we should treat this as success. + bool checksPass = + strippedOutput.Contains("OVERALL: PASSED", StringComparison.OrdinalIgnoreCase) || + (strippedOutput.Contains(", 0 failed,", StringComparison.OrdinalIgnoreCase) && + strippedOutput.Contains("Passed:", StringComparison.OrdinalIgnoreCase)); + + return (Success: process.ExitCode == 0 || checksPass, + Output: strippedOutput, + Error: error); } } diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs new file mode 100644 index 000000000..45cc3e25d --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs @@ -0,0 +1,112 @@ +using Microsoft.AspNetCore.Builder; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.AspNetCore.Tests.Utils; +using System.Net; + +namespace ModelContextProtocol.AspNetCore.Tests; + +public class DnsRebindingProtectionMiddlewareTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) +{ + [Theory] + [InlineData("localhost", HttpStatusCode.OK)] + [InlineData("localhost:5000", HttpStatusCode.OK)] + [InlineData("127.0.0.1", HttpStatusCode.OK)] + [InlineData("127.0.0.1:5000", HttpStatusCode.OK)] + [InlineData("[::1]", HttpStatusCode.OK)] + [InlineData("[::1]:5000", HttpStatusCode.OK)] + [InlineData("evil.com", HttpStatusCode.Forbidden)] + [InlineData("evil.localhost", HttpStatusCode.Forbidden)] + [InlineData("localhost.evil.com", HttpStatusCode.Forbidden)] + public async Task ValidatesHostHeader(string hostHeader, HttpStatusCode expectedStatusCode) + { + Builder.Services.AddMcpServer().WithHttpTransport(); + await using var app = Builder.Build(); + + app.UseMcpDnsRebindingProtection(); + app.MapGet("/test", () => "OK"); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var request = new HttpRequestMessage(HttpMethod.Get, "/test"); + request.Headers.Host = hostHeader; + + var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + + Assert.Equal(expectedStatusCode, response.StatusCode); + } + + [Theory] + [InlineData("http://localhost", HttpStatusCode.OK)] + [InlineData("http://localhost:5000", HttpStatusCode.OK)] + [InlineData("http://127.0.0.1", HttpStatusCode.OK)] + [InlineData("http://127.0.0.1:5000", HttpStatusCode.OK)] + [InlineData("http://[::1]", HttpStatusCode.OK)] + [InlineData("http://[::1]:5000", HttpStatusCode.OK)] + [InlineData("http://evil.com", HttpStatusCode.Forbidden)] + [InlineData("http://evil.localhost", HttpStatusCode.Forbidden)] + [InlineData("https://malicious.site", HttpStatusCode.Forbidden)] + public async Task ValidatesOriginHeader(string originHeader, HttpStatusCode expectedStatusCode) + { + Builder.Services.AddMcpServer().WithHttpTransport(); + await using var app = Builder.Build(); + + app.UseMcpDnsRebindingProtection(); + app.MapGet("/test", () => "OK"); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var request = new HttpRequestMessage(HttpMethod.Get, "/test"); + request.Headers.Host = "localhost"; // Valid host + request.Headers.Add("Origin", originHeader); + + var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + + Assert.Equal(expectedStatusCode, response.StatusCode); + } + + [Fact] + public async Task AllowsRequestsWithNoOriginHeader() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + await using var app = Builder.Build(); + + app.UseMcpDnsRebindingProtection(); + app.MapGet("/test", () => "OK"); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var request = new HttpRequestMessage(HttpMethod.Get, "/test"); + request.Headers.Host = "localhost"; + // No Origin header + + var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + } + + [Fact] + public async Task ReturnsForbiddenWithJsonRpcErrorFormat() + { + Builder.Services.AddMcpServer().WithHttpTransport(); + await using var app = Builder.Build(); + + app.UseMcpDnsRebindingProtection(); + app.MapGet("/test", () => "OK"); + + await app.StartAsync(TestContext.Current.CancellationToken); + + var request = new HttpRequestMessage(HttpMethod.Get, "/test"); + request.Headers.Host = "localhost"; + request.Headers.Add("Origin", "http://evil.com"); + + var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); + + Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); + Assert.Equal("application/json", response.Content.Headers.ContentType?.MediaType); + + var content = await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken); + Assert.Contains("jsonrpc", content); + Assert.Contains("error", content); + Assert.Contains("-32000", content); + } +} diff --git a/tests/ModelContextProtocol.ConformanceClient/Program.cs b/tests/ModelContextProtocol.ConformanceClient/Program.cs index d75f64682..545422919 100644 --- a/tests/ModelContextProtocol.ConformanceClient/Program.cs +++ b/tests/ModelContextProtocol.ConformanceClient/Program.cs @@ -3,6 +3,8 @@ using System.Text.Json; using System.Web; using Microsoft.Extensions.Logging; +using ModelContextProtocol; +using ModelContextProtocol.Authentication; using ModelContextProtocol.Client; using ModelContextProtocol.Protocol; @@ -24,6 +26,45 @@ builder.AddConsole(); }); +// Parse MCP_CONFORMANCE_CONTEXT environment variable for test context +// This may contain client_id, client_secret, private_key_pem, signing_algorithm for pre-registration tests +string? contextClientId = null; +string? contextClientSecret = null; +string? contextPrivateKeyPem = null; +string? contextSigningAlgorithm = null; +var conformanceContext = Environment.GetEnvironmentVariable("MCP_CONFORMANCE_CONTEXT"); +if (!string.IsNullOrEmpty(conformanceContext)) +{ + try + { + using var contextJson = JsonDocument.Parse(conformanceContext); + + if (contextJson.RootElement.TryGetProperty("client_id", out var clientIdProp)) + { + contextClientId = clientIdProp.GetString(); + } + + if (contextJson.RootElement.TryGetProperty("client_secret", out var clientSecretProp)) + { + contextClientSecret = clientSecretProp.GetString(); + } + + if (contextJson.RootElement.TryGetProperty("private_key_pem", out var privateKeyProp)) + { + contextPrivateKeyPem = privateKeyProp.GetString(); + } + + if (contextJson.RootElement.TryGetProperty("signing_algorithm", out var signingAlgProp)) + { + contextSigningAlgorithm = signingAlgProp.GetString(); + } + } + catch (JsonException) + { + // Ignore malformed context + } +} + // Configure OAuth callback port via environment or pick an ephemeral port. var callbackPortEnv = Environment.GetEnvironmentVariable("OAUTH_CALLBACK_PORT"); int callbackPort = 0; @@ -42,21 +83,53 @@ var clientRedirectUri = new Uri($"http://localhost:{callbackPort}/callback"); +// Build OAuth options. +// For client_credentials tests, don't set a redirect handler to trigger machine-to-machine flow. +var isClientCredentialsTest = scenario.StartsWith("auth/client-credentials-"); + +var oauthOptions = new ClientOAuthOptions +{ + RedirectUri = clientRedirectUri, + // Configure the metadata document URI for CIMD. + ClientMetadataDocumentUri = new Uri("https://conformance-test.local/client-metadata.json"), + DynamicClientRegistration = new() + { + ClientName = "ProtectedMcpClient", + }, +}; + +// Only set authorization redirect handler for tests that need authorization code flow. +// Client credentials tests should NOT have a redirect handler to trigger machine-to-machine flow. +if (!isClientCredentialsTest) +{ + oauthOptions.AuthorizationRedirectDelegate = (authUrl, redirectUri, ct) => HandleAuthorizationUrlAsync(authUrl, redirectUri, ct); +} + +// If pre-registered credentials are provided via context, use them. +// This allows the OAuth provider to skip dynamic client registration and +// potentially use client_credentials grant type if the server supports it. +if (!string.IsNullOrEmpty(contextClientId)) +{ + oauthOptions.ClientId = contextClientId; + oauthOptions.ClientSecret = contextClientSecret; +} + +// If JWT private key is provided (for private_key_jwt authentication), use it. +if (!string.IsNullOrEmpty(contextPrivateKeyPem) && !string.IsNullOrEmpty(contextSigningAlgorithm)) +{ + oauthOptions.JwtPrivateKeyPem = contextPrivateKeyPem; + oauthOptions.JwtSigningAlgorithm = contextSigningAlgorithm; +} + +// Select transport mode based on scenario. +// sse-retry test requires SSE transport mode to test SSE-specific reconnection behavior. +var transportMode = scenario == "sse-retry" ? HttpTransportMode.Sse : HttpTransportMode.StreamableHttp; + var clientTransport = new HttpClientTransport(new() { Endpoint = new Uri(endpoint), - TransportMode = HttpTransportMode.StreamableHttp, - OAuth = new() - { - RedirectUri = clientRedirectUri, - // Configure the metadata document URI for CIMD. - ClientMetadataDocumentUri = new Uri("https://conformance-test.local/client-metadata.json"), - AuthorizationRedirectDelegate = (authUrl, redirectUri, ct) => HandleAuthorizationUrlAsync(authUrl, redirectUri, ct), - DynamicClientRegistration = new() - { - ClientName = "ProtectedMcpClient", - }, - } + TransportMode = transportMode, + OAuth = oauthOptions }, loggerFactory: consoleLoggerFactory); // Wrapper delegate pattern: allows setting elicitation handler after client creation @@ -122,6 +195,32 @@ success &= !(result.IsError == true); break; } + case "auth/scope-retry-limit": + { + // For scope-retry-limit, the server will keep returning 403 with insufficient_scope + // until the client gives up (tests the max retry limit). + // The client should catch McpException when the retry limit is exceeded. + try + { + var tools = await mcpClient.ListToolsAsync(); + Console.WriteLine($"Available tools: {string.Join(", ", tools.Select(t => t.Name))}"); + + // Call the "test_tool" tool + var toolName = tools.FirstOrDefault()?.Name ?? "test-tool"; + Console.WriteLine($"Calling tool: {toolName}"); + var result = await mcpClient.CallToolAsync(toolName: toolName, arguments: new Dictionary + { + { "foo", "bar" }, + }); + success &= !(result.IsError == true); + } + catch (McpException ex) when (ex.Message.Contains("retry limit")) + { + // Expected - the client correctly limited scope step-up retries + Console.WriteLine($"Scope step-up retry limit reached (expected): {ex.Message}"); + } + break; + } case "elicitation-sep1034-client-defaults": { // In this test scenario, an elicitation request will be made that includes default values in the schema. @@ -142,19 +241,19 @@ switch (schema) { case ElicitRequestParams.StringSchema stringSchema when stringSchema.Default is not null: - content[key] = JsonSerializer.SerializeToElement(stringSchema.Default); + content[key] = JsonSerializer.SerializeToElement(stringSchema.Default, McpJsonUtilities.DefaultOptions); break; case ElicitRequestParams.NumberSchema numberSchema when numberSchema.Default.HasValue: - content[key] = JsonSerializer.SerializeToElement(numberSchema.Default.Value); + content[key] = JsonSerializer.SerializeToElement(numberSchema.Default.Value, McpJsonUtilities.DefaultOptions); break; case ElicitRequestParams.BooleanSchema booleanSchema when booleanSchema.Default.HasValue: - content[key] = JsonSerializer.SerializeToElement(booleanSchema.Default.Value); + content[key] = JsonSerializer.SerializeToElement(booleanSchema.Default.Value, McpJsonUtilities.DefaultOptions); break; case ElicitRequestParams.UntitledSingleSelectEnumSchema enumSchema when enumSchema.Default is not null: - content[key] = JsonSerializer.SerializeToElement(enumSchema.Default); + content[key] = JsonSerializer.SerializeToElement(enumSchema.Default, McpJsonUtilities.DefaultOptions); break; case ElicitRequestParams.TitledSingleSelectEnumSchema titledEnumSchema when titledEnumSchema.Default is not null: - content[key] = JsonSerializer.SerializeToElement(titledEnumSchema.Default); + content[key] = JsonSerializer.SerializeToElement(titledEnumSchema.Default, McpJsonUtilities.DefaultOptions); break; } } diff --git a/tests/ModelContextProtocol.ConformanceServer/ModelContextProtocol.ConformanceServer.csproj b/tests/ModelContextProtocol.ConformanceServer/ModelContextProtocol.ConformanceServer.csproj index 15b2c87f2..c3fed6b6c 100644 --- a/tests/ModelContextProtocol.ConformanceServer/ModelContextProtocol.ConformanceServer.csproj +++ b/tests/ModelContextProtocol.ConformanceServer/ModelContextProtocol.ConformanceServer.csproj @@ -5,6 +5,7 @@ enable enable Exe + $(NoWarn);MCP9001 diff --git a/tests/ModelContextProtocol.ConformanceServer/Program.cs b/tests/ModelContextProtocol.ConformanceServer/Program.cs index 00e557e7f..ddced4fc9 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Program.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Program.cs @@ -92,6 +92,8 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide var app = builder.Build(); + app.UseMcpDnsRebindingProtection(); + app.MapMcp(); app.MapGet("/health", () => "Healthy"); diff --git a/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs b/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs index 177de5c60..2280f8055 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Tools/ConformanceTools.cs @@ -332,16 +332,54 @@ public static async Task ElicitationSep1330Enums( { Properties = { - ["color"] = new ElicitRequestParams.UntitledSingleSelectEnumSchema() + // 1. Untitled single-select: { type: "string", enum: ["option1", "option2", "option3"] } + ["untitledSingle"] = new ElicitRequestParams.UntitledSingleSelectEnumSchema() { - Description = "Choose a color", - Enum = ["red", "green", "blue"] + Description = "Untitled single-select enum", + Enum = ["option1", "option2", "option3"] }, - ["size"] = new ElicitRequestParams.UntitledSingleSelectEnumSchema() + + // 2. Titled single-select: { type: "string", oneOf: [{ const: "value1", title: "First Option" }, ...] } + ["titledSingle"] = new ElicitRequestParams.TitledSingleSelectEnumSchema() + { + Description = "Titled single-select enum", + OneOf = [ + new ElicitRequestParams.EnumSchemaOption { Const = "value1", Title = "First Option" }, + new ElicitRequestParams.EnumSchemaOption { Const = "value2", Title = "Second Option" }, + new ElicitRequestParams.EnumSchemaOption { Const = "value3", Title = "Third Option" } + ] + }, + + // 3. Legacy titled (deprecated): { type: "string", enum: ["opt1", "opt2", "opt3"], enumNames: ["Option One", "Option Two", "Option Three"] } + ["legacyEnum"] = new ElicitRequestParams.LegacyTitledEnumSchema() + { + Description = "Legacy titled enum (deprecated)", + Enum = ["opt1", "opt2", "opt3"], + EnumNames = ["Option One", "Option Two", "Option Three"] + }, + + // 4. Untitled multi-select: { type: "array", items: { type: "string", enum: ["option1", "option2", "option3"] } } + ["untitledMulti"] = new ElicitRequestParams.UntitledMultiSelectEnumSchema() + { + Description = "Untitled multi-select enum", + Items = new ElicitRequestParams.UntitledEnumItemsSchema + { + Enum = ["option1", "option2", "option3"] + } + }, + + // 5. Titled multi-select: { type: "array", items: { anyOf: [{ const: "value1", title: "First Choice" }, ...] } } + ["titledMulti"] = new ElicitRequestParams.TitledMultiSelectEnumSchema() { - Description = "Choose a size", - Enum = ["small", "medium", "large"], - Default = "medium" + Description = "Titled multi-select enum", + Items = new ElicitRequestParams.TitledEnumItemsSchema + { + AnyOf = [ + new ElicitRequestParams.EnumSchemaOption { Const = "value1", Title = "First Choice" }, + new ElicitRequestParams.EnumSchemaOption { Const = "value2", Title = "Second Choice" }, + new ElicitRequestParams.EnumSchemaOption { Const = "value3", Title = "Third Choice" } + ] + } } } }; @@ -354,8 +392,7 @@ public static async Task ElicitationSep1330Enums( if (result.Action == "accept" && result.Content != null) { - return $"Accepted with values: color={result.Content["color"].GetString()}, " + - $"size={result.Content["size"].GetString()}"; + return $"Elicitation completed: action={result.Action}, content={result.Content}"; } else { From 9deb100d75fb13fbeefc0bf91c518eb675cdc0d1 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 28 Jan 2026 21:13:54 -0500 Subject: [PATCH 5/7] Update CI conformance tests from 0.1.10 to 0.1.11 --- .github/workflows/ci-build-test.yml | 2 +- .../ClientConformanceTests.cs | 11 +++++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-build-test.yml b/.github/workflows/ci-build-test.yml index 838dce4cf..59c7b7e2d 100644 --- a/.github/workflows/ci-build-test.yml +++ b/.github/workflows/ci-build-test.yml @@ -69,7 +69,7 @@ jobs: # Keep version in sync with McpConformanceVersion in Directory.Packages.props - name: 📦 Install conformance test runner - run: npm install @modelcontextprotocol/conformance@0.1.10 + run: npm install @modelcontextprotocol/conformance@0.1.11 - name: 🏗️ Build run: make build CONFIGURATION=${{ matrix.configuration }} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs index 037802d0f..3d42cfa94 100644 --- a/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs +++ b/tests/ModelContextProtocol.AspNetCore.Tests/ClientConformanceTests.cs @@ -1,4 +1,5 @@ using System.Diagnostics; +using System.Reflection; using System.Text; using System.Text.RegularExpressions; using ModelContextProtocol.Tests.Utils; @@ -53,11 +54,16 @@ public ClientConformanceTests(ITestOutputHelper output) "auth/client-credentials-basic" ]; + private static string GetConformanceVersion() => + typeof(ClientConformanceTests).Assembly.GetCustomAttributes().FirstOrDefault(a => a.Key is "McpConformanceVersion")?.Value ?? + throw new InvalidOperationException("McpConformanceVersion not found in assembly metadata"); + [Fact(Skip = "npx is not installed. Skipping client conformance tests.", SkipUnless = nameof(IsNpxInstalled))] public async Task VerifyAllConformanceTestsAreListed() { // Get the list of available conformance tests from the suite - var startInfo = NodeHelpers.NpxStartInfo("-y @modelcontextprotocol/conformance list --client"); + // Version is configured in Directory.Packages.props for central management + var startInfo = NodeHelpers.NpxStartInfo($"-y @modelcontextprotocol/conformance@{GetConformanceVersion()} list --client"); var outputBuilder = new StringBuilder(); var process = new Process { StartInfo = startInfo }; @@ -180,7 +186,8 @@ public async Task RunConformanceTest(string scenario) $"ConformanceClient executable not found at: {conformanceClientPath}"); } - var startInfo = NodeHelpers.NpxStartInfo($"-y @modelcontextprotocol/conformance client --scenario {scenario} --command \"{conformanceClientPath} {scenario}\""); + // Version is configured in Directory.Packages.props for central management + var startInfo = NodeHelpers.NpxStartInfo($"-y @modelcontextprotocol/conformance@{GetConformanceVersion()} client --scenario {scenario} --command \"{conformanceClientPath} {scenario}\""); var outputBuilder = new StringBuilder(); var errorBuilder = new StringBuilder(); From 919ff3e3e98352097897f038cafbabfc62371d37 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 29 Jan 2026 21:00:57 +0000 Subject: [PATCH 6/7] Address PR feedback: improve auth failure tracking, update DNS rebinding protection, add client credentials tests - Replace _scopeStepUpCount with _repeatedAuthFailureCount to track all auth failures - Reset counter only on successful requests to prevent infinite loops - Replace custom DNS rebinding middleware with AllowedHosts configuration - Add client_credentials grant support to TestOAuthServer with basic auth and post methods - Add ClientCredentialsTests with infinite loop prevention test - Fix off-by-one error in MaxReconnectionAttempts Co-authored-by: halter73 <54385+halter73@users.noreply.github.com> --- .../Authentication/ClientOAuthProvider.cs | 24 +-- .../Client/SseClientSessionTransport.cs | 2 +- .../OAuth/ClientCredentialsTests.cs | 109 +++++++++++ .../Program.cs | 2 - .../appsettings.json | 2 +- .../ClientInfo.cs | 21 +++ .../Program.cs | 172 +++++++++++++++++- 7 files changed, 309 insertions(+), 23 deletions(-) create mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/OAuth/ClientCredentialsTests.cs diff --git a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs index 7dff362bd..1ca08423f 100644 --- a/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs +++ b/src/ModelContextProtocol.Core/Authentication/ClientOAuthProvider.cs @@ -50,12 +50,13 @@ internal sealed partial class ClientOAuthProvider : McpHttpClient private string? _clientSecret; private ITokenCache _tokenCache; private AuthorizationServerMetadata? _authServerMetadata; - private int _scopeStepUpCount; + private int _repeatedAuthFailureCount; /// - /// Maximum number of scope step-up retries before failing. + /// Maximum number of repeated auth failure retries before failing. + /// This prevents infinite loops when tokens are never accepted by the server. /// - private const int MaxScopeStepUpRetries = 3; + private const int MaxRepeatedAuthFailures = 3; /// /// Initializes a new instance of the class using the specified options. @@ -166,6 +167,8 @@ internal override async Task SendAsync(HttpRequestMessage r return await HandleUnauthorizedResponseAsync(request, message, response, attemptedRefresh, cancellationToken).ConfigureAwait(false); } + // Reset the auth failure counter on successful request + _repeatedAuthFailureCount = 0; return response; } @@ -263,18 +266,11 @@ private async Task HandleUnauthorizedResponseAsync( /// The to monitor for cancellation requests. private async Task GetAccessTokenAsync(HttpResponseMessage response, bool attemptedRefresh, CancellationToken cancellationToken) { - // Check if this is a scope step-up retry (403 with insufficient_scope) - if (HasInsufficientScopeError(response)) + // Track all auth failures to prevent infinite redirect loops. + // This counter is only reset when a request succeeds (in SendAsync). + if (++_repeatedAuthFailureCount > MaxRepeatedAuthFailures) { - if (++_scopeStepUpCount >= MaxScopeStepUpRetries) - { - ThrowFailedToHandleUnauthorizedResponse($"Maximum scope step-up retry limit ({MaxScopeStepUpRetries}) exceeded."); - } - } - else - { - // Reset the counter for non-scope-step-up requests (e.g., initial auth or token expiry) - _scopeStepUpCount = 0; + ThrowFailedToHandleUnauthorizedResponse($"Maximum repeated authentication failure limit ({MaxRepeatedAuthFailures}) exceeded. The server may be rejecting all tokens."); } // Get available authorization servers from the 401 or 403 response diff --git a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs index 9783d14be..72173b9ee 100644 --- a/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol.Core/Client/SseClientSessionTransport.cs @@ -145,7 +145,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken) { int attempt = 0; - while (attempt < _options.MaxReconnectionAttempts && !cancellationToken.IsCancellationRequested) + while (attempt <= _options.MaxReconnectionAttempts && !cancellationToken.IsCancellationRequested) { try { diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/ClientCredentialsTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/ClientCredentialsTests.cs new file mode 100644 index 000000000..72cbc9c3d --- /dev/null +++ b/tests/ModelContextProtocol.AspNetCore.Tests/OAuth/ClientCredentialsTests.cs @@ -0,0 +1,109 @@ +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; +using ModelContextProtocol.AspNetCore; +using ModelContextProtocol.AspNetCore.Authentication; +using ModelContextProtocol.Authentication; +using ModelContextProtocol.Client; + +namespace ModelContextProtocol.AspNetCore.Tests.OAuth; + +/// +/// Tests for client_credentials OAuth flow with various authentication methods. +/// +public class ClientCredentialsTests : OAuthTestBase +{ + public ClientCredentialsTests(ITestOutputHelper outputHelper) + : base(outputHelper) + { + } + + [Fact] + public async Task CanAuthenticate_WithClientCredentials_ClientSecretPost() + { + await using var app = await StartMcpServerAsync(); + + // Use client_credentials flow with client_secret_post authentication + // Note: No AuthorizationRedirectDelegate means machine-to-machine flow will be attempted + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new ClientOAuthOptions + { + ClientId = "client-credentials-post", + ClientSecret = "cc-secret-post", + RedirectUri = new Uri("http://localhost:1179/callback"), + // No AuthorizationRedirectDelegate - triggers client_credentials flow + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(client); + } + + [Fact] + public async Task CanAuthenticate_WithClientCredentials_ClientSecretBasic() + { + await using var app = await StartMcpServerAsync(); + + // Use client_credentials flow with client_secret_basic authentication + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new ClientOAuthOptions + { + ClientId = "client-credentials-basic", + ClientSecret = "cc-secret-basic", + RedirectUri = new Uri("http://localhost:1179/callback"), + // No AuthorizationRedirectDelegate - triggers client_credentials flow + }, + }, HttpClient, LoggerFactory); + + await using var client = await McpClient.CreateAsync( + transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken); + + Assert.NotNull(client); + } + + [Fact] + public async Task DoesNotLoopIndefinitely_WhenTokensAlwaysRejected() + { + // Set up a server that always returns 401 even after authentication + // This simulates a buggy MCP server that never accepts tokens + var app = Builder.Build(); + + // Add middleware that always returns 401 with the MCP auth challenge + app.Use((HttpContext context, RequestDelegate next) => + { + context.Response.StatusCode = StatusCodes.Status401Unauthorized; + context.Response.Headers.WWWAuthenticate = $"Bearer realm=\"{OAuthServerUrl}\" resource_metadata=\"{McpServerUrl}/.well-known/oauth-protected-resource\""; + return context.Response.WriteAsync("Unauthorized"); + }); + + await app.StartAsync(TestContext.Current.CancellationToken); + await using var _ = app; + + await using var transport = new HttpClientTransport(new() + { + Endpoint = new(McpServerUrl), + OAuth = new ClientOAuthOptions + { + ClientId = "client-credentials-post", + ClientSecret = "cc-secret-post", + RedirectUri = new Uri("http://localhost:1179/callback"), + // No AuthorizationRedirectDelegate - triggers client_credentials flow + }, + }, HttpClient, LoggerFactory); + + // Should throw McpException after max retries, not loop indefinitely + var ex = await Assert.ThrowsAsync(async () => + await McpClient.CreateAsync(transport, loggerFactory: LoggerFactory, cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains("Maximum repeated authentication failure limit", ex.Message); + } +} diff --git a/tests/ModelContextProtocol.ConformanceServer/Program.cs b/tests/ModelContextProtocol.ConformanceServer/Program.cs index ddced4fc9..00e557e7f 100644 --- a/tests/ModelContextProtocol.ConformanceServer/Program.cs +++ b/tests/ModelContextProtocol.ConformanceServer/Program.cs @@ -92,8 +92,6 @@ public static async Task MainAsync(string[] args, ILoggerProvider? loggerProvide var app = builder.Build(); - app.UseMcpDnsRebindingProtection(); - app.MapMcp(); app.MapGet("/health", () => "Healthy"); diff --git a/tests/ModelContextProtocol.ConformanceServer/appsettings.json b/tests/ModelContextProtocol.ConformanceServer/appsettings.json index 10f68b8c8..757d8426e 100644 --- a/tests/ModelContextProtocol.ConformanceServer/appsettings.json +++ b/tests/ModelContextProtocol.ConformanceServer/appsettings.json @@ -5,5 +5,5 @@ "Microsoft.AspNetCore": "Warning" } }, - "AllowedHosts": "*" + "AllowedHosts": "localhost;127.0.0.1;[::1]" } diff --git a/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs b/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs index 500142b6b..ffe43c9dd 100644 --- a/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs +++ b/tests/ModelContextProtocol.TestOAuthServer/ClientInfo.cs @@ -26,4 +26,25 @@ internal sealed class ClientInfo /// Gets or sets the list of redirect URIs allowed for this client. /// public List RedirectUris { get; init; } = []; + + /// + /// Gets or sets the token endpoint auth method for this client. + /// Supported values: "client_secret_post", "client_secret_basic", "private_key_jwt", "none" + /// + public string TokenEndpointAuthMethod { get; init; } = "client_secret_post"; + + /// + /// Gets or sets the allowed grant types for this client. + /// + public List AllowedGrantTypes { get; init; } = ["authorization_code", "refresh_token"]; + + /// + /// Gets or sets the client's JWKS URI for JWT client assertion verification. + /// + public string? JwksUri { get; init; } + + /// + /// Gets or sets the client's public key PEM for JWT client assertion verification (inline, no JWKS fetch). + /// + public string? PublicKeyPem { get; init; } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.TestOAuthServer/Program.cs b/tests/ModelContextProtocol.TestOAuthServer/Program.cs index e13c731de..12392db92 100644 --- a/tests/ModelContextProtocol.TestOAuthServer/Program.cs +++ b/tests/ModelContextProtocol.TestOAuthServer/Program.cs @@ -140,6 +140,38 @@ public async Task RunServerAsync(string[]? args = null, CancellationToken cancel RedirectUris = ["http://localhost:1179/callback"], }; + // Client for testing client_credentials grant with client_secret_post (default) + _clients["client-credentials-post"] = new ClientInfo + { + ClientId = "client-credentials-post", + ClientSecret = "cc-secret-post", + RequiresClientSecret = true, + RedirectUris = [], + TokenEndpointAuthMethod = "client_secret_post", + AllowedGrantTypes = ["client_credentials"], + }; + + // Client for testing client_credentials grant with client_secret_basic + _clients["client-credentials-basic"] = new ClientInfo + { + ClientId = "client-credentials-basic", + ClientSecret = "cc-secret-basic", + RequiresClientSecret = true, + RedirectUris = [], + TokenEndpointAuthMethod = "client_secret_basic", + AllowedGrantTypes = ["client_credentials"], + }; + + // Client for testing client_credentials grant with private_key_jwt + _clients["client-credentials-jwt"] = new ClientInfo + { + ClientId = "client-credentials-jwt", + RequiresClientSecret = false, // JWT assertion is used instead + RedirectUris = [], + TokenEndpointAuthMethod = "private_key_jwt", + AllowedGrantTypes = ["client_credentials"], + }; + // The MCP spec tells the client to use /.well-known/oauth-authorization-server but AddJwtBearer looks for // /.well-known/openid-configuration by default. // @@ -171,10 +203,11 @@ IResult HandleMetadataRequest(HttpContext context, string? issuerPath = null) SubjectTypesSupported = ["public"], IdTokenSigningAlgValuesSupported = ["RS256"], ScopesSupported = ["openid", "profile", "email", "mcp:tools"], - TokenEndpointAuthMethodsSupported = ["client_secret_post"], + TokenEndpointAuthMethodsSupported = ["client_secret_post", "client_secret_basic", "private_key_jwt", "none"], + TokenEndpointAuthSigningAlgValuesSupported = ["RS256"], ClaimsSupported = ["sub", "iss", "name", "email", "aud"], CodeChallengeMethodsSupported = ["S256"], - GrantTypesSupported = ["authorization_code", "refresh_token"], + GrantTypesSupported = ["authorization_code", "refresh_token", "client_credentials"], IntrospectionEndpoint = $"{_url}/introspect", RegistrationEndpoint = $"{_url}/register", ClientIdMetadataDocumentSupported = ClientIdMetadataDocumentSupported, @@ -417,6 +450,26 @@ IResult HandleMetadataRequest(HttpContext context, string? issuerPath = null) HasRefreshedToken = true; return Results.Ok(response); } + else if (grant_type == "client_credentials") + { + // Client credentials flow - machine-to-machine authentication + var scope = form["scope"].ToString(); + var requestedScopes = string.IsNullOrEmpty(scope) ? [] : scope.Split(' ').ToList(); + + // Verify client is allowed to use client_credentials grant + if (!client.AllowedGrantTypes.Contains("client_credentials")) + { + return Results.BadRequest(new OAuthErrorResponse + { + Error = "unauthorized_client", + ErrorDescription = "Client is not authorized to use client_credentials grant" + }); + } + + // Generate token response for client credentials + var response = GenerateJwtTokenResponse(client.ClientId, requestedScopes, new Uri(resource)); + return Results.Ok(response); + } else { return Results.BadRequest(new OAuthErrorResponse @@ -546,21 +599,70 @@ IResult HandleMetadataRequest(HttpContext context, string? issuerPath = null) /// /// Authenticates a client based on client credentials in the request. + /// Supports client_secret_post, client_secret_basic, private_key_jwt, and none. /// /// The HTTP context. /// The form collection containing client credentials. /// The client info if authentication succeeds, null otherwise. private ClientInfo? AuthenticateClient(HttpContext context, IFormCollection form) { - var clientId = form["client_id"].ToString(); - var clientSecret = form["client_secret"].ToString(); + string? clientId = null; + string? clientSecret = null; + + // Try client_secret_basic (HTTP Basic Auth) + var authHeader = context.Request.Headers.Authorization.FirstOrDefault(); + if (!string.IsNullOrEmpty(authHeader) && authHeader.StartsWith("Basic ", StringComparison.OrdinalIgnoreCase)) + { + var encodedCredentials = authHeader["Basic ".Length..]; + var credentialBytes = Convert.FromBase64String(encodedCredentials); + var credentials = Encoding.UTF8.GetString(credentialBytes).Split(':', 2); + if (credentials.Length == 2) + { + clientId = Uri.UnescapeDataString(credentials[0]); + clientSecret = Uri.UnescapeDataString(credentials[1]); + } + } + + // Fallback to client_secret_post (form parameters) + if (string.IsNullOrEmpty(clientId)) + { + clientId = form["client_id"].ToString(); + clientSecret = form["client_secret"].ToString(); + } if (string.IsNullOrEmpty(clientId) || !_clients.TryGetValue(clientId, out var client)) { return null; } - if (client.RequiresClientSecret && client.ClientSecret != clientSecret) + // Check for JWT client assertion (private_key_jwt) + var clientAssertionType = form["client_assertion_type"].ToString(); + var clientAssertion = form["client_assertion"].ToString(); + + if (!string.IsNullOrEmpty(clientAssertionType) && clientAssertionType == "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + { + if (string.IsNullOrEmpty(clientAssertion)) + { + return null; + } + + // Verify JWT client assertion + if (!VerifyClientAssertion(client, clientAssertion)) + { + return null; + } + + return client; + } + + // For clients that don't require a secret (e.g., CIMD clients or none auth method) + if (!client.RequiresClientSecret) + { + return client; + } + + // Verify client secret + if (client.ClientSecret != clientSecret) { return null; } @@ -568,6 +670,66 @@ IResult HandleMetadataRequest(HttpContext context, string? issuerPath = null) return client; } + /// + /// Verifies a JWT client assertion for private_key_jwt authentication. + /// + private bool VerifyClientAssertion(ClientInfo client, string assertion) + { + // For simplicity, we just check that the JWT has three parts and the client has a public key configured + // In a real implementation, we would verify the signature using the client's public key + var parts = assertion.Split('.'); + if (parts.Length != 3) + { + return false; + } + + // Parse the payload to verify claims + try + { + var payloadJson = Encoding.UTF8.GetString(WebEncoders.Base64UrlDecode(parts[1])); + var payload = JsonSerializer.Deserialize>(payloadJson); + + if (payload == null) + { + return false; + } + + // Verify required claims + if (!payload.TryGetValue("iss", out var issElement) || issElement.GetString() != client.ClientId) + { + return false; + } + + if (!payload.TryGetValue("sub", out var subElement) || subElement.GetString() != client.ClientId) + { + return false; + } + + if (!payload.TryGetValue("aud", out var audElement) || audElement.GetString() != $"{_url}/token") + { + return false; + } + + // Verify expiration + if (payload.TryGetValue("exp", out var expElement)) + { + var exp = expElement.GetInt64(); + if (DateTimeOffset.UtcNow.ToUnixTimeSeconds() > exp) + { + return false; + } + } + + // If client has a public key configured, we would verify the signature here + // For testing purposes, we accept any properly structured JWT + return true; + } + catch + { + return false; + } + } + /// /// Generates a JWT token response. /// From f8d07f36fa1b0b5232a39cb21d4027ecd9a0a8cf Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 29 Jan 2026 21:24:02 +0000 Subject: [PATCH 7/7] Delete UseMcpDnsRebindingProtection API entirely Remove the DNS rebinding protection middleware, extension methods, and tests as requested. The ConformanceServer now uses the built-in AllowedHosts configuration instead. Co-authored-by: halter73 <54385+halter73@users.noreply.github.com> --- .../DnsRebindingProtectionMiddleware.cs | 113 ------------------ .../McpApplicationBuilderExtensions.cs | 47 -------- .../DnsRebindingProtectionMiddlewareTests.cs | 112 ----------------- 3 files changed, 272 deletions(-) delete mode 100644 src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs delete mode 100644 src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs delete mode 100644 tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs diff --git a/src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs b/src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs deleted file mode 100644 index 7850dbb95..000000000 --- a/src/ModelContextProtocol.AspNetCore/DnsRebindingProtectionMiddleware.cs +++ /dev/null @@ -1,113 +0,0 @@ -using Microsoft.AspNetCore.Builder; -using Microsoft.AspNetCore.Http; -using Microsoft.Extensions.Logging; -using Microsoft.Net.Http.Headers; -using System.Net; -using System.Text.Json; - -namespace ModelContextProtocol.AspNetCore; - -/// -/// Middleware that provides DNS rebinding protection for MCP servers by validating -/// Host and Origin headers on requests to localhost servers. -/// -/// -/// -/// DNS rebinding attacks can allow malicious websites to bypass browser same-origin policy -/// and make requests to localhost services. This middleware helps protect against such attacks -/// by validating that Host and Origin headers match expected localhost values. -/// -/// -/// Use to enable this middleware. -/// -/// -/// -/// Initializes a new instance of the class. -/// -internal sealed partial class DnsRebindingProtectionMiddleware( - RequestDelegate next, - ILogger logger) -{ - private readonly RequestDelegate _next = next; - private readonly ILogger _logger = logger; - - /// - /// Processes the HTTP request and validates Host and Origin headers for localhost servers. - /// - public async Task InvokeAsync(HttpContext context) - { - // Only apply protection to localhost servers - var localEndpoint = context.Connection.LocalIpAddress; - bool isLocalhostServer = localEndpoint is null || - IPAddress.IsLoopback(localEndpoint) || - localEndpoint.Equals(IPAddress.IPv6Loopback); - - if (isLocalhostServer) - { - // Validate Host header - var host = context.Request.Host.Host; - if (!IsLocalhost(host)) - { - LogInvalidHostHeader(host); - await WriteJsonRpcErrorResponseAsync(context, $"Forbidden: Invalid Host header '{host}' for localhost server"); - return; - } - - // Validate Origin header if present - if (context.Request.Headers.TryGetValue(HeaderNames.Origin, out var originValues) && - originValues.FirstOrDefault() is string origin && - Uri.TryCreate(origin, UriKind.Absolute, out var originUri) && - !IsLocalhost(originUri.Host)) - { - LogInvalidOriginHeader(origin); - await WriteJsonRpcErrorResponseAsync(context, $"Forbidden: Invalid Origin header '{origin}' for localhost server"); - return; - } - } - - await _next(context).ConfigureAwait(false); - } - - private static bool IsLocalhost(string host) - { - if (!string.IsNullOrWhiteSpace(host)) - { - if (host.Equals("localhost", StringComparison.OrdinalIgnoreCase) || - host.Equals("[::1]") || - host.Equals("127.0.0.1")) - { - return true; - } - - if (IPAddress.TryParse(host, out var ip)) - { - return IPAddress.IsLoopback(ip); - } - } - - return false; - } - - private static Task WriteJsonRpcErrorResponseAsync(HttpContext context, string message) - { - context.Response.StatusCode = StatusCodes.Status403Forbidden; - context.Response.ContentType = "application/json"; - return context.Response.WriteAsync($$""" - { - "jsonrpc": "2.0", - "error": - { - "code": -32000, - "message": "{{JsonEncodedText.Encode(message)}}" - }, - "id": null - } - """); - } - - [LoggerMessage(Level = LogLevel.Warning, Message = "Rejected request with invalid Host header '{Host}' for localhost server. This may indicate a DNS rebinding attack.")] - private partial void LogInvalidHostHeader(string? host); - - [LoggerMessage(Level = LogLevel.Warning, Message = "Rejected request with invalid Origin header '{Origin}' for localhost server. This may indicate a DNS rebinding attack.")] - private partial void LogInvalidOriginHeader(string origin); -} diff --git a/src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs deleted file mode 100644 index e3c7fa6d9..000000000 --- a/src/ModelContextProtocol.AspNetCore/McpApplicationBuilderExtensions.cs +++ /dev/null @@ -1,47 +0,0 @@ -using ModelContextProtocol.AspNetCore; - -namespace Microsoft.AspNetCore.Builder; - -/// -/// Extension methods for adding MCP middleware to an . -/// -public static class McpApplicationBuilderExtensions -{ - /// - /// Adds DNS rebinding protection middleware for MCP servers running on localhost. - /// - /// The . - /// The for chaining. - /// - /// - /// This method provides protection against DNS rebinding attacks by validating that both - /// Host and Origin headers (when present) resolve to localhost addresses. - /// - /// - /// DNS rebinding attacks can allow malicious websites to bypass browser same-origin policy and make requests - /// to localhost services. This protection is recommended for any MCP server that binds to localhost. - /// - /// - /// For more information, see the MCP SDK security advisory. - /// - /// - /// - /// - /// var builder = WebApplication.CreateBuilder(args); - /// builder.Services.AddMcpServer().WithHttpTransport(); - /// - /// var app = builder.Build(); - /// app.UseMcpDnsRebindingProtection(); // Add before MapMcp() - /// app.MapMcp(); - /// app.Run(); - /// - /// - public static IApplicationBuilder UseMcpDnsRebindingProtection(this IApplicationBuilder app) - { - ArgumentNullException.ThrowIfNull(app); - - app.UseMiddleware(); - - return app; - } -} diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs deleted file mode 100644 index 45cc3e25d..000000000 --- a/tests/ModelContextProtocol.AspNetCore.Tests/DnsRebindingProtectionMiddlewareTests.cs +++ /dev/null @@ -1,112 +0,0 @@ -using Microsoft.AspNetCore.Builder; -using Microsoft.Extensions.DependencyInjection; -using ModelContextProtocol.AspNetCore.Tests.Utils; -using System.Net; - -namespace ModelContextProtocol.AspNetCore.Tests; - -public class DnsRebindingProtectionMiddlewareTests(ITestOutputHelper testOutputHelper) : KestrelInMemoryTest(testOutputHelper) -{ - [Theory] - [InlineData("localhost", HttpStatusCode.OK)] - [InlineData("localhost:5000", HttpStatusCode.OK)] - [InlineData("127.0.0.1", HttpStatusCode.OK)] - [InlineData("127.0.0.1:5000", HttpStatusCode.OK)] - [InlineData("[::1]", HttpStatusCode.OK)] - [InlineData("[::1]:5000", HttpStatusCode.OK)] - [InlineData("evil.com", HttpStatusCode.Forbidden)] - [InlineData("evil.localhost", HttpStatusCode.Forbidden)] - [InlineData("localhost.evil.com", HttpStatusCode.Forbidden)] - public async Task ValidatesHostHeader(string hostHeader, HttpStatusCode expectedStatusCode) - { - Builder.Services.AddMcpServer().WithHttpTransport(); - await using var app = Builder.Build(); - - app.UseMcpDnsRebindingProtection(); - app.MapGet("/test", () => "OK"); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var request = new HttpRequestMessage(HttpMethod.Get, "/test"); - request.Headers.Host = hostHeader; - - var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); - - Assert.Equal(expectedStatusCode, response.StatusCode); - } - - [Theory] - [InlineData("http://localhost", HttpStatusCode.OK)] - [InlineData("http://localhost:5000", HttpStatusCode.OK)] - [InlineData("http://127.0.0.1", HttpStatusCode.OK)] - [InlineData("http://127.0.0.1:5000", HttpStatusCode.OK)] - [InlineData("http://[::1]", HttpStatusCode.OK)] - [InlineData("http://[::1]:5000", HttpStatusCode.OK)] - [InlineData("http://evil.com", HttpStatusCode.Forbidden)] - [InlineData("http://evil.localhost", HttpStatusCode.Forbidden)] - [InlineData("https://malicious.site", HttpStatusCode.Forbidden)] - public async Task ValidatesOriginHeader(string originHeader, HttpStatusCode expectedStatusCode) - { - Builder.Services.AddMcpServer().WithHttpTransport(); - await using var app = Builder.Build(); - - app.UseMcpDnsRebindingProtection(); - app.MapGet("/test", () => "OK"); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var request = new HttpRequestMessage(HttpMethod.Get, "/test"); - request.Headers.Host = "localhost"; // Valid host - request.Headers.Add("Origin", originHeader); - - var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); - - Assert.Equal(expectedStatusCode, response.StatusCode); - } - - [Fact] - public async Task AllowsRequestsWithNoOriginHeader() - { - Builder.Services.AddMcpServer().WithHttpTransport(); - await using var app = Builder.Build(); - - app.UseMcpDnsRebindingProtection(); - app.MapGet("/test", () => "OK"); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var request = new HttpRequestMessage(HttpMethod.Get, "/test"); - request.Headers.Host = "localhost"; - // No Origin header - - var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); - - Assert.Equal(HttpStatusCode.OK, response.StatusCode); - } - - [Fact] - public async Task ReturnsForbiddenWithJsonRpcErrorFormat() - { - Builder.Services.AddMcpServer().WithHttpTransport(); - await using var app = Builder.Build(); - - app.UseMcpDnsRebindingProtection(); - app.MapGet("/test", () => "OK"); - - await app.StartAsync(TestContext.Current.CancellationToken); - - var request = new HttpRequestMessage(HttpMethod.Get, "/test"); - request.Headers.Host = "localhost"; - request.Headers.Add("Origin", "http://evil.com"); - - var response = await HttpClient.SendAsync(request, TestContext.Current.CancellationToken); - - Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); - Assert.Equal("application/json", response.Content.Headers.ContentType?.MediaType); - - var content = await response.Content.ReadAsStringAsync(TestContext.Current.CancellationToken); - Assert.Contains("jsonrpc", content); - Assert.Contains("error", content); - Assert.Contains("-32000", content); - } -}