Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
if (!HttpServerTransportOptions.Stateless)
{
sessionId = MakeNewSessionId();
transport = new()
transport = new(loggerFactory)
{
SessionId = sessionId,
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
Expand All @@ -252,7 +252,7 @@ private async ValueTask<StreamableHttpSession> StartNewSessionAsync(HttpContext
// If in the future we support resuming stateless requests, we should populate
// the event stream store and retry interval here as well.
sessionId = "";
transport = new()
transport = new(loggerFactory)
{
Stateless = true,
};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using System.Diagnostics;
using System.Net.ServerSentEvents;
Expand All @@ -10,7 +11,11 @@ namespace ModelContextProtocol.Server;
/// Handles processing the request/response body pairs for the Streamable HTTP transport.
/// This is typically used via <see cref="JsonRpcMessageContext.RelatedTransport"/>.
/// </summary>
internal sealed class StreamableHttpPostTransport(StreamableHttpServerTransport parentTransport, Stream responseStream, CancellationToken sessionCancellationToken) : ITransport
internal sealed partial class StreamableHttpPostTransport(
StreamableHttpServerTransport parentTransport,
Stream responseStream,
CancellationToken sessionCancellationToken,
ILogger logger) : ITransport
{
private readonly SemaphoreSlim _messageLock = new(1, 1);
private readonly TaskCompletionSource<bool> _httpResponseTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
Expand Down Expand Up @@ -199,7 +204,14 @@ async Task HandleStoreStreamDisposalAsync(Task streamTask)
{
using var _ = await _messageLock.LockAsync().ConfigureAwait(false);

await _storeSseWriter!.DisposeAsync().ConfigureAwait(false);
try
{
await _storeSseWriter!.DisposeAsync().ConfigureAwait(false);
}
catch (Exception ex)
{
LogStoreStreamDisposalFailed(ex);
}
}
}
}
Expand All @@ -222,4 +234,7 @@ public async ValueTask DisposeAsync()
// Don't dispose the event stream writer here, as we may continue to write to the event store
// after disposal if there are pending messages.
}

[LoggerMessage(Level = LogLevel.Warning, Message = "Failed to dispose SSE event stream writer.")]
private partial void LogStoreStreamDisposalFailed(Exception exception);
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using ModelContextProtocol.Protocol;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
Expand All @@ -21,7 +23,7 @@ namespace ModelContextProtocol.Server;
/// such as when streaming completion results or providing progress updates during long-running operations.
/// </para>
/// </remarks>
public sealed class StreamableHttpServerTransport : ITransport
public sealed partial class StreamableHttpServerTransport : ITransport
{
/// <summary>
/// The stream ID used for unsolicited messages sent via the standalone GET SSE stream.
Expand All @@ -35,13 +37,23 @@ public sealed class StreamableHttpServerTransport : ITransport
});
private readonly CancellationTokenSource _transportDisposedCts = new();
private readonly SemaphoreSlim _unsolicitedMessageLock = new(1, 1);
private readonly ILogger _logger;

private SseEventWriter? _httpSseWriter;
private ISseEventStreamWriter? _storeSseWriter;
private TaskCompletionSource<bool>? _httpResponseTcs;
private bool _getHttpRequestStarted;
private bool _getHttpResponseCompleted;

/// <summary>
/// Initializes a new instance of the <see cref="StreamableHttpServerTransport"/> class.
/// </summary>
/// <param name="loggerFactory">Optional logger factory used for logging employed by the transport.</param>
public StreamableHttpServerTransport(ILoggerFactory? loggerFactory = null)
{
_logger = loggerFactory?.CreateLogger<StreamableHttpServerTransport>() ?? NullLogger<StreamableHttpServerTransport>.Instance;
}

/// <inheritdoc/>
public string? SessionId { get; init; }

Expand Down Expand Up @@ -161,7 +173,7 @@ public async Task<bool> HandlePostRequestAsync(JsonRpcMessage message, Stream re
Throw.IfNull(message);
Throw.IfNull(responseStream);

var postTransport = new StreamableHttpPostTransport(this, responseStream, _transportDisposedCts.Token);
var postTransport = new StreamableHttpPostTransport(this, responseStream, _transportDisposedCts.Token, _logger);
using var postCts = CancellationTokenSource.CreateLinkedTokenSource(_transportDisposedCts.Token, cancellationToken);
await using (postTransport.ConfigureAwait(false))
{
Expand Down