Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/client/src/simpleStreamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ async function connect(url?: string): Promise<void> {
attempts++;
console.log(`\nPlease provide the following information (attempt ${attempts}/${maxAttempts}):`);

const content: Record<string, unknown> = {};
const content: Record<string, string | number | boolean | string[]> = {};
let inputCancelled = false;

// Collect input for each field
Expand Down Expand Up @@ -357,7 +357,7 @@ async function connect(url?: string): Promise<void> {
// Parse and validate the input
try {
if (answer === '' && field.default !== undefined) {
content[fieldName] = field.default;
content[fieldName] = field.default as string | number | boolean | string[];
} else if (answer === '' && !isRequired) {
// Skip optional empty fields
continue;
Expand Down Expand Up @@ -401,7 +401,7 @@ async function connect(url?: string): Promise<void> {
}
}

content[fieldName] = parsedValue;
content[fieldName] = parsedValue as string | number | boolean | string[];
}
} catch (error) {
console.log(`❌ Error: ${error}`);
Expand Down
62 changes: 10 additions & 52 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,13 @@ import type {
ListToolsRequest,
LoggingLevel,
MessageExtraInfo,
Notification,
NotificationMethod,
ProtocolOptions,
ReadResourceRequest,
Request,
RequestMethod,
RequestOptions,
RequestTypeMap,
Result,
ResultTypeMap,
ServerCapabilities,
SubscribeRequest,
Tool,
Expand Down Expand Up @@ -193,36 +191,8 @@ export type ClientOptions = ProtocolOptions & {
*
* The client will automatically begin the initialization flow with the server when connect() is called.
*
* To use with custom types, extend the base Request/Notification/Result types and pass them as type parameters:
*
* ```typescript
* // Custom schemas
* const CustomRequestSchema = RequestSchema.extend({...})
* const CustomNotificationSchema = NotificationSchema.extend({...})
* const CustomResultSchema = ResultSchema.extend({...})
*
* // Type aliases
* type CustomRequest = z.infer<typeof CustomRequestSchema>
* type CustomNotification = z.infer<typeof CustomNotificationSchema>
* type CustomResult = z.infer<typeof CustomResultSchema>
*
* // Create typed client
* const client = new Client<CustomRequest, CustomNotification, CustomResult>({
* name: "CustomClient",
* version: "1.0.0"
* })
* ```
*/
export class Client<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result
> extends Protocol<
ClientRequest | RequestT,
ClientNotification | NotificationT,
ClientResult | ResultT,
ClientContext<ClientRequest | RequestT, ClientNotification | NotificationT>
> {
export class Client extends Protocol<ClientContext> {
private _serverCapabilities?: ServerCapabilities;
private _serverVersion?: Implementation;
private _capabilities: ClientCapabilities;
Expand All @@ -231,7 +201,7 @@ export class Client<
private _cachedToolOutputValidators: Map<string, JsonSchemaValidator<unknown>> = new Map();
private _cachedKnownTaskTools: Set<string> = new Set();
private _cachedRequiredTaskTools: Set<string> = new Set();
private _experimental?: { tasks: ExperimentalClientTasks<RequestT, NotificationT, ResultT> };
private _experimental?: { tasks: ExperimentalClientTasks };
private _listChangedDebounceTimers: Map<string, ReturnType<typeof setTimeout>> = new Map();
private _pendingListChangedConfig?: ListChangedHandlers;
private _enforceStrictCapabilities: boolean;
Expand All @@ -254,10 +224,7 @@ export class Client<
}
}

protected override buildContext(
ctx: BaseContext<ClientRequest | RequestT, ClientNotification | NotificationT>,
_transportInfo?: MessageExtraInfo
): ClientContext<ClientRequest | RequestT, ClientNotification | NotificationT> {
protected override buildContext(ctx: BaseContext, _transportInfo?: MessageExtraInfo): ClientContext {
return ctx;
}

Expand Down Expand Up @@ -297,7 +264,7 @@ export class Client<
*
* @experimental
*/
get experimental(): { tasks: ExperimentalClientTasks<RequestT, NotificationT, ResultT> } {
get experimental(): { tasks: ExperimentalClientTasks } {
if (!this._experimental) {
this._experimental = {
tasks: new ExperimentalClientTasks(this)
Expand All @@ -324,16 +291,10 @@ export class Client<
*/
public override setRequestHandler<M extends RequestMethod>(
method: M,
handler: (
request: RequestTypeMap[M],
ctx: ClientContext<ClientRequest | RequestT, ClientNotification | NotificationT>
) => ClientResult | ResultT | Promise<ClientResult | ResultT>
handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
if (method === 'elicitation/create') {
const wrappedHandler = async (
request: RequestTypeMap[M],
ctx: ClientContext<ClientRequest | RequestT, ClientNotification | NotificationT>
): Promise<ClientResult | ResultT> => {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
const validatedRequest = parseSchema(ElicitRequestSchema, request);
if (!validatedRequest.success) {
// Type guard: if success is false, error is guaranteed to exist
Expand Down Expand Up @@ -403,10 +364,7 @@ export class Client<
}

if (method === 'sampling/createMessage') {
const wrappedHandler = async (
request: RequestTypeMap[M],
ctx: ClientContext<ClientRequest | RequestT, ClientNotification | NotificationT>
): Promise<ClientResult | ResultT> => {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
const validatedRequest = parseSchema(CreateMessageRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
Expand Down Expand Up @@ -533,7 +491,7 @@ export class Client<
return this._instructions;
}

protected assertCapabilityForMethod(method: RequestT['method']): void {
protected assertCapabilityForMethod(method: RequestMethod): void {
switch (method as ClientRequest['method']) {
case 'logging/setLevel': {
if (!this._serverCapabilities?.logging) {
Expand Down Expand Up @@ -596,7 +554,7 @@ export class Client<
}
}

protected assertNotificationCapability(method: NotificationT['method']): void {
protected assertNotificationCapability(method: NotificationMethod): void {
switch (method as ClientNotification['method']) {
case 'notifications/roots/list_changed': {
if (!this._capabilities.roots?.listChanged) {
Expand Down
21 changes: 7 additions & 14 deletions packages/client/src/experimental/tasks/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,11 @@ import type {
AnyObjectSchema,
CallToolRequest,
CancelTaskResult,
ClientRequest,
GetTaskResult,
ListTasksResult,
Notification,
Request,
RequestOptions,
ResponseMessage,
Result,
SchemaOutput
} from '@modelcontextprotocol/core';
import { CallToolResultSchema, ProtocolError, ProtocolErrorCode } from '@modelcontextprotocol/core';
Expand All @@ -27,9 +24,9 @@ import type { Client } from '../../client/client.js';
* Internal interface for accessing Client's private methods.
* @internal
*/
interface ClientInternal<RequestT extends Request> {
interface ClientInternal {
requestStream<T extends AnyObjectSchema>(
request: ClientRequest | RequestT,
request: Request,
resultSchema: T,
options?: RequestOptions
): AsyncGenerator<ResponseMessage<SchemaOutput<T>>, void, void>;
Expand All @@ -48,12 +45,8 @@ interface ClientInternal<RequestT extends Request> {
*
* @experimental
*/
export class ExperimentalClientTasks<
RequestT extends Request = Request,
NotificationT extends Notification = Notification,
ResultT extends Result = Result
> {
constructor(private readonly _client: Client<RequestT, NotificationT, ResultT>) {}
export class ExperimentalClientTasks {
constructor(private readonly _client: Client) {}

/**
* Calls a tool and returns an AsyncGenerator that yields response messages.
Expand Down Expand Up @@ -96,7 +89,7 @@ export class ExperimentalClientTasks<
options?: RequestOptions
): AsyncGenerator<ResponseMessage<SchemaOutput<typeof CallToolResultSchema>>, void, void> {
// Access Client's internal methods
const clientInternal = this._client as unknown as ClientInternal<RequestT>;
const clientInternal = this._client as unknown as ClientInternal;

// Add task creation parameters if server supports it and not explicitly provided
const optionsWithTask = {
Expand Down Expand Up @@ -255,14 +248,14 @@ export class ExperimentalClientTasks<
* @experimental
*/
requestStream<T extends AnyObjectSchema>(
request: ClientRequest | RequestT,
request: Request,
resultSchema: T,
options?: RequestOptions
): AsyncGenerator<ResponseMessage<SchemaOutput<T>>, void, void> {
// Delegate to the client's underlying Protocol method
type ClientWithRequestStream = {
requestStream<U extends AnyObjectSchema>(
request: ClientRequest | RequestT,
request: Request,
resultSchema: U,
options?: RequestOptions
): AsyncGenerator<ResponseMessage<SchemaOutput<U>>, void, void>;
Expand Down
Loading
Loading