diff --git a/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java b/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java new file mode 100644 index 000000000..978a80dab --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java @@ -0,0 +1,46 @@ +package com.google.adk.tools.mcp; + +import com.google.common.collect.ImmutableMap; +import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; +import io.modelcontextprotocol.client.transport.ServerParameters; +import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.spec.McpClientTransport; +import java.util.Collection; +import java.util.Optional; + +/** + * The default builder for creating MCP client transports. Supports StdioClientTransport based on + * {@link ServerParameters} and the standard HttpClientSseClientTransport based on {@link + * SseServerParameters}. + */ +public class DefaultMcpTransportBuilder implements McpTransportBuilder { + + @Override + public McpClientTransport build(Object connectionParams) { + if (connectionParams instanceof ServerParameters serverParameters) { + return new StdioClientTransport(serverParameters); + } else if (connectionParams instanceof SseServerParameters sseServerParams) { + return HttpClientSseClientTransport.builder(sseServerParams.url()) + .sseEndpoint("sse") + .customizeRequest( + builder -> + Optional.ofNullable(sseServerParams.headers()) + .map(ImmutableMap::entrySet) + .stream() + .flatMap(Collection::stream) + .forEach( + entry -> + builder.header( + entry.getKey(), + Optional.ofNullable(entry.getValue()) + .map(Object::toString) + .orElse("")))) + .build(); + } else { + throw new IllegalArgumentException( + "DefaultMcpTransportBuilder supports only ServerParameters or SseServerParameters, but" + + " got " + + connectionParams.getClass().getName()); + } + } +} diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java b/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java index 0e86282ed..1c6e292ac 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpSessionManager.java @@ -16,27 +16,20 @@ package com.google.adk.tools.mcp; -import com.google.common.collect.ImmutableMap; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.McpSyncClient; -import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; -import io.modelcontextprotocol.client.transport.ServerParameters; -import io.modelcontextprotocol.client.transport.StdioClientTransport; import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpSchema.ClientCapabilities; import io.modelcontextprotocol.spec.McpSchema.InitializeResult; +import java.time.Duration; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.time.Duration; -import java.util.Collection; -import java.util.Optional; - /** * Manages MCP client sessions. * *

This class provides methods for creating and initializing MCP client sessions, handling - * different connection parameters (Stdio and SSE). + * different connection parameters and transport builders. */ // TODO(b/413489523): Implement this class. public class McpSessionManager { @@ -53,31 +46,21 @@ public McpSyncClient createSession() { } public static McpSyncClient initializeSession(Object connectionParams) { - McpClientTransport transport; - if (connectionParams instanceof ServerParameters serverParameters) { - transport = new StdioClientTransport(serverParameters); - } else if (connectionParams instanceof SseServerParameters sseServerParams) { - transport = HttpClientSseClientTransport.builder(sseServerParams.url()) - .sseEndpoint("sse") - .customizeRequest(builder -> - Optional.ofNullable(sseServerParams.headers()) - .map(ImmutableMap::entrySet) - .stream().flatMap(Collection::stream) - .forEach(entry -> - builder.header( - entry.getKey(), - Optional.ofNullable(entry.getValue()) - .map(Object::toString) - .orElse("") - ) - ) - ) - .build(); - } else { - throw new IllegalArgumentException( - "Connection parameters must be either ServerParameters or SseServerParameters, but got " - + connectionParams.getClass().getName()); - } + McpClientTransport transport = new DefaultMcpTransportBuilder().build(connectionParams); + McpSyncClient client = + McpClient.sync(transport) + .requestTimeout(Duration.ofSeconds(10)) + .capabilities(ClientCapabilities.builder().build()) + .build(); + InitializeResult initResult = client.initialize(); + logger.debug("Initialize Client Result: {}", initResult); + return client; + } + + public static McpSyncClient initializeSession( + Object connectionParams, McpTransportBuilder transportBuilder) { + McpClientTransport transport = transportBuilder.build(connectionParams); + McpSyncClient client = McpClient.sync(transport) .requestTimeout(Duration.ofSeconds(10)) diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpTransportBuilder.java b/core/src/main/java/com/google/adk/tools/mcp/McpTransportBuilder.java new file mode 100644 index 000000000..c44e779c1 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/mcp/McpTransportBuilder.java @@ -0,0 +1,20 @@ +package com.google.adk.tools.mcp; + +import io.modelcontextprotocol.spec.McpClientTransport; + +/** + * Interface for building McpClientTransport instances. Implementations of this interface are + * responsible for constructing concrete McpClientTransport objects based on the provided connection + * parameters. + */ +public interface McpTransportBuilder { + /** + * Builds an McpClientTransport based on the provided connection parameters. + * + * @param connectionParams The parameters required to configure the transport. The type of this + * object determines the type of transport built. + * @return An instance of McpClientTransport. + * @throws IllegalArgumentException if the connectionParams are not supported or invalid. + */ + McpClientTransport build(Object connectionParams); +}