diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs index c50e51388..22c861326 100644 --- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs +++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs @@ -238,7 +238,7 @@ private async ValueTask StartNewSessionAsync(HttpContext if (!HttpServerTransportOptions.Stateless) { sessionId = MakeNewSessionId(); - transport = new() + transport = new(loggerFactory) { SessionId = sessionId, FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext, @@ -252,7 +252,7 @@ private async ValueTask StartNewSessionAsync(HttpContext // If in the future we support resuming stateless requests, we should populate // the event stream store and retry interval here as well. sessionId = ""; - transport = new() + transport = new(loggerFactory) { Stateless = true, }; diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs index f0f94c270..b6f6b663e 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpPostTransport.cs @@ -1,3 +1,4 @@ +using Microsoft.Extensions.Logging; using ModelContextProtocol.Protocol; using System.Diagnostics; using System.Net.ServerSentEvents; @@ -10,7 +11,11 @@ namespace ModelContextProtocol.Server; /// Handles processing the request/response body pairs for the Streamable HTTP transport. /// This is typically used via . /// -internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream, CancellationToken sessionCancellationToken) : ITransport +internal sealed partial class StreamableHttpPostTransport( + StreamableHttpServerTransport parentTransport, + Stream responseStream, + CancellationToken sessionCancellationToken, + ILogger logger) : ITransport { private readonly SemaphoreSlim _messageLock = new(1, 1); private readonly TaskCompletionSource _httpResponseTcs = new(TaskCreationOptions.RunContinuationsAsynchronously); @@ -199,7 +204,14 @@ async Task HandleStoreStreamDisposalAsync(Task streamTask) { using var _ = await _messageLock.LockAsync().ConfigureAwait(false); - await _storeSseWriter!.DisposeAsync().ConfigureAwait(false); + try + { + await _storeSseWriter!.DisposeAsync().ConfigureAwait(false); + } + catch (Exception ex) + { + LogStoreStreamDisposalFailed(ex); + } } } } @@ -222,4 +234,7 @@ public async ValueTask DisposeAsync() // Don't dispose the event stream writer here, as we may continue to write to the event store // after disposal if there are pending messages. } + + [LoggerMessage(Level = LogLevel.Warning, Message = "Failed to dispose SSE event stream writer.")] + private partial void LogStoreStreamDisposalFailed(Exception exception); } diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs index 55b6650b1..58227757b 100644 --- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs +++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs @@ -1,3 +1,5 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using ModelContextProtocol.Protocol; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; @@ -21,7 +23,7 @@ namespace ModelContextProtocol.Server; /// such as when streaming completion results or providing progress updates during long-running operations. /// /// -public sealed class StreamableHttpServerTransport : ITransport +public sealed partial class StreamableHttpServerTransport : ITransport { /// /// The stream ID used for unsolicited messages sent via the standalone GET SSE stream. @@ -35,6 +37,7 @@ public sealed class StreamableHttpServerTransport : ITransport }); private readonly CancellationTokenSource _transportDisposedCts = new(); private readonly SemaphoreSlim _unsolicitedMessageLock = new(1, 1); + private readonly ILogger _logger; private SseEventWriter? _httpSseWriter; private ISseEventStreamWriter? _storeSseWriter; @@ -42,6 +45,15 @@ public sealed class StreamableHttpServerTransport : ITransport private bool _getHttpRequestStarted; private bool _getHttpResponseCompleted; + /// + /// Initializes a new instance of the class. + /// + /// Optional logger factory used for logging employed by the transport. + public StreamableHttpServerTransport(ILoggerFactory? loggerFactory = null) + { + _logger = loggerFactory?.CreateLogger() ?? NullLogger.Instance; + } + /// public string? SessionId { get; init; } @@ -161,7 +173,7 @@ public async Task HandlePostRequestAsync(JsonRpcMessage message, Stream re Throw.IfNull(message); Throw.IfNull(responseStream); - var postTransport = new StreamableHttpPostTransport(this, responseStream, _transportDisposedCts.Token); + var postTransport = new StreamableHttpPostTransport(this, responseStream, _transportDisposedCts.Token, _logger); using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_transportDisposedCts.Token, cancellationToken); await using (postTransport.ConfigureAwait(false)) {