diff --git a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java
index 80c25610d..aeb34a1e9 100644
--- a/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java
+++ b/contrib/langchain4j/src/main/java/com/google/adk/models/langchain4j/LangChain4j.java
@@ -29,6 +29,7 @@
 import com.google.genai.types.FunctionDeclaration;
 import com.google.genai.types.FunctionResponse;
 import com.google.genai.types.GenerateContentConfig;
+import com.google.genai.types.GenerateContentResponseUsageMetadata;
 import com.google.genai.types.Part;
 import com.google.genai.types.Schema;
 import com.google.genai.types.ToolConfig;
@@ -51,6 +52,7 @@
 import dev.langchain4j.data.pdf.PdfFile;
 import dev.langchain4j.data.video.Video;
 import dev.langchain4j.exception.UnsupportedFeatureException;
+import dev.langchain4j.model.TokenCountEstimator;
 import dev.langchain4j.model.chat.ChatModel;
 import dev.langchain4j.model.chat.StreamingChatModel;
 import dev.langchain4j.model.chat.request.ChatRequest;
@@ -64,6 +66,7 @@
 import dev.langchain4j.model.chat.request.json.JsonStringSchema;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.model.output.TokenUsage;
 import io.reactivex.rxjava3.core.BackpressureStrategy;
 import io.reactivex.rxjava3.core.Flowable;
 import java.util.ArrayList;
@@ -83,24 +86,109 @@ public class LangChain4j extends BaseLlm {
   private final ChatModel chatModel;
   private final StreamingChatModel streamingChatModel;
   private final ObjectMapper objectMapper;
+  private final TokenCountEstimator tokenCountEstimator;
+
+  public static Builder builder() {
+    return new Builder();
+  }
+
+  public static class Builder {
+    private ChatModel chatModel;
+    private StreamingChatModel streamingChatModel;
+    private String modelName;
+    private TokenCountEstimator tokenCountEstimator;
+
+    private Builder() {}
+
+    public Builder chatModel(ChatModel chatModel) {
+      this.chatModel = chatModel;
+      return this;
+    }
+
+    public Builder streamingChatModel(StreamingChatModel streamingChatModel) {
+      this.streamingChatModel = streamingChatModel;
+      return this;
+    }
+
+    public Builder modelName(String modelName) {
+      this.modelName = modelName;
+      return this;
+    }
+
+    public Builder tokenCountEstimator(TokenCountEstimator tokenCountEstimator) {
+      this.tokenCountEstimator = tokenCountEstimator;
+      return this;
+    }
+
+    public LangChain4j build() {
+      if (chatModel == null && streamingChatModel == null) {
+        throw new IllegalStateException(
+            "At least one of chatModel or streamingChatModel must be provided");
+      }
+
+      String effectiveModelName = modelName;
+      if (effectiveModelName == null) {
+        if (chatModel != null) {
+          effectiveModelName = chatModel.defaultRequestParameters().modelName();
+        } else {
+          effectiveModelName = streamingChatModel.defaultRequestParameters().modelName();
+        }
+      }
+
+      if (effectiveModelName == null) {
+        throw new IllegalStateException("Model name cannot be null");
+      }
+
+      return new LangChain4j(
+          chatModel, streamingChatModel, effectiveModelName, tokenCountEstimator);
+    }
+  }
+
+  private LangChain4j(
+      ChatModel chatModel,
+      StreamingChatModel streamingChatModel,
+      String modelName,
+      TokenCountEstimator tokenCountEstimator) {
+    super(Objects.requireNonNull(modelName, "model name cannot be null"));
+    this.chatModel = chatModel;
+    this.streamingChatModel = streamingChatModel;
+    this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
+  }
 
   public LangChain4j(ChatModel chatModel) {
+    this(chatModel, (TokenCountEstimator) null);
+  }
+
+  public LangChain4j(ChatModel chatModel, TokenCountEstimator tokenCountEstimator) {
     super(
         Objects.requireNonNull(
             chatModel.defaultRequestParameters().modelName(), "chat model name cannot be null"));
     this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
     this.streamingChatModel = null;
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }
 
   public LangChain4j(ChatModel chatModel, String modelName) {
+    this(chatModel, modelName, (TokenCountEstimator) null);
+  }
+
+  public LangChain4j(
+      ChatModel chatModel, String modelName, TokenCountEstimator tokenCountEstimator) {
     super(Objects.requireNonNull(modelName, "chat model name cannot be null"));
     this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
     this.streamingChatModel = null;
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }
 
   public LangChain4j(StreamingChatModel streamingChatModel) {
+    this(streamingChatModel, (TokenCountEstimator) null);
+  }
+
+  public LangChain4j(
+      StreamingChatModel streamingChatModel, TokenCountEstimator tokenCountEstimator) {
     super(
         Objects.requireNonNull(
             streamingChatModel.defaultRequestParameters().modelName(),
@@ -109,22 +197,23 @@ public LangChain4j(StreamingChatModel streamingChatModel) {
     this.streamingChatModel =
         Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }
 
   public LangChain4j(StreamingChatModel streamingChatModel, String modelName) {
-    super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
-    this.chatModel = null;
-    this.streamingChatModel =
-        Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
-    this.objectMapper = new ObjectMapper();
+    this(streamingChatModel, modelName, (TokenCountEstimator) null);
   }
 
-  public LangChain4j(ChatModel chatModel, StreamingChatModel streamingChatModel, String modelName) {
-    super(Objects.requireNonNull(modelName, "model name cannot be null"));
-    this.chatModel = Objects.requireNonNull(chatModel, "chatModel cannot be null");
+  public LangChain4j(
+      StreamingChatModel streamingChatModel,
+      String modelName,
+      TokenCountEstimator tokenCountEstimator) {
+    super(Objects.requireNonNull(modelName, "streaming chat model name cannot be null"));
+    this.chatModel = null;
     this.streamingChatModel =
         Objects.requireNonNull(streamingChatModel, "streamingChatModel cannot be null");
     this.objectMapper = new ObjectMapper();
+    this.tokenCountEstimator = tokenCountEstimator;
   }
 
   @Override
@@ -185,7 +274,7 @@ public void onError(Throwable throwable) {
     ChatRequest chatRequest = toChatRequest(llmRequest);
     ChatResponse chatResponse = chatModel.chat(chatRequest);
 
-    LlmResponse llmResponse = toLlmResponse(chatResponse);
+    LlmResponse llmResponse = toLlmResponse(chatResponse, chatRequest);
     return Flowable.just(llmResponse);
   }
 
@@ -496,11 +585,38 @@ private JsonSchemaElement toJsonSchemaElement(Schema schema) {
     }
   }
 
-  private LlmResponse toLlmResponse(ChatResponse chatResponse) {
+  private LlmResponse toLlmResponse(ChatResponse chatResponse, ChatRequest chatRequest) {
     Content content =
         Content.builder().role("model").parts(toParts(chatResponse.aiMessage())).build();
-    return LlmResponse.builder().content(content).build();
+    LlmResponse.Builder builder = LlmResponse.builder().content(content);
+    TokenUsage tokenUsage = chatResponse.tokenUsage();
+    if (tokenCountEstimator != null) {
+      try {
+        int estimatedInput =
+            tokenCountEstimator.estimateTokenCountInMessages(chatRequest.messages());
+        int estimatedOutput =
+            tokenCountEstimator.estimateTokenCountInText(chatResponse.aiMessage().text());
+        int estimatedTotal = estimatedInput + estimatedOutput;
+        builder.usageMetadata(
+            GenerateContentResponseUsageMetadata.builder()
+                .promptTokenCount(estimatedInput)
+                .candidatesTokenCount(estimatedOutput)
+                .totalTokenCount(estimatedTotal)
+                .build());
+      } catch (Exception e) {
+        e.printStackTrace();
+      }
+    } else if (tokenUsage != null) {
+      builder.usageMetadata(
+          GenerateContentResponseUsageMetadata.builder()
+              .promptTokenCount(tokenUsage.inputTokenCount())
+              .candidatesTokenCount(tokenUsage.outputTokenCount())
+              .totalTokenCount(tokenUsage.totalTokenCount())
+              .build());
+    }
+
+    return builder.build();
   }
 
   private List<Part> toParts(AiMessage aiMessage) {
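Aside (not part of the diff): the new builder path is easiest to see end-to-end with a stub estimator. The sketch below relies only on the API introduced above — `LangChain4j.builder()` and the two `TokenCountEstimator` methods this change calls; the single-message override, the `"my-model"` name, and the four-characters-per-token heuristic are illustrative assumptions, not a real tokenizer.

```java
import com.google.adk.models.langchain4j.LangChain4j;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.model.TokenCountEstimator;
import dev.langchain4j.model.chat.ChatModel;

/** Crude stand-in estimator: roughly four characters per token. */
class NaiveTokenCountEstimator implements TokenCountEstimator {

  @Override
  public int estimateTokenCountInText(String text) {
    return text == null ? 0 : Math.max(1, text.length() / 4);
  }

  @Override
  public int estimateTokenCountInMessage(ChatMessage message) {
    // Assumption: estimating from the message's string form is good enough here.
    return estimateTokenCountInText(message.toString());
  }

  @Override
  public int estimateTokenCountInMessages(Iterable<ChatMessage> messages) {
    int total = 0;
    for (ChatMessage message : messages) {
      total += estimateTokenCountInMessage(message);
    }
    return total;
  }
}

class TokenEstimationExample {

  /** Wires the estimator in through the builder added by this change. */
  static LangChain4j withEstimator(ChatModel chatModel) {
    return LangChain4j.builder()
        .chatModel(chatModel)
        .modelName("my-model") // optional; falls back to the model's default name
        .tokenCountEstimator(new NaiveTokenCountEstimator())
        .build();
  }
}
```

Note the precedence encoded in `toLlmResponse`: a configured estimator wins even when the provider reports real `TokenUsage`, so an estimator is best reserved for models that report no usage at all. The tests below pin down both behaviors.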
diff --git a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java
index 428a5660c..67f6bee0a 100644
--- a/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java
+++ b/contrib/langchain4j/src/test/java/com/google/adk/models/langchain4j/LangChain4jTest.java
@@ -26,6 +26,7 @@
 import dev.langchain4j.agent.tool.ToolExecutionRequest;
 import dev.langchain4j.data.message.AiMessage;
 import dev.langchain4j.data.message.UserMessage;
+import dev.langchain4j.model.TokenCountEstimator;
 import dev.langchain4j.model.chat.ChatModel;
 import dev.langchain4j.model.chat.StreamingChatModel;
 import dev.langchain4j.model.chat.request.ChatRequest;
@@ -33,6 +34,7 @@
 import dev.langchain4j.model.chat.request.json.JsonStringSchema;
 import dev.langchain4j.model.chat.response.ChatResponse;
 import dev.langchain4j.model.chat.response.StreamingChatResponseHandler;
+import dev.langchain4j.model.output.TokenUsage;
 import io.reactivex.rxjava3.core.Flowable;
 import java.util.ArrayList;
 import java.util.List;
@@ -688,4 +690,140 @@ void testGenerateContentWithStructuredResponseJsonSchema() {
     final UserMessage userMessage = (UserMessage) capturedRequest.messages().get(0);
     assertThat(userMessage.singleText()).isEqualTo("Give me information about John Doe");
   }
+
+  @Test
+  @DisplayName(
+      "Should use TokenCountEstimator to estimate token usage when TokenUsage is not available")
+  void testTokenCountEstimatorFallback() {
+    // Given
+    // Create a mock TokenCountEstimator
+    final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
+    when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(50); // Input tokens
+    when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(20); // Output tokens
+
+    // Create LangChain4j with the TokenCountEstimator using the Builder
+    final LangChain4j langChain4jWithEstimator =
+        LangChain4j.builder()
+            .chatModel(chatModel)
+            .modelName(MODEL_NAME)
+            .tokenCountEstimator(tokenCountEstimator)
+            .build();
+
+    // Create an LlmRequest
+    final LlmRequest llmRequest =
+        LlmRequest.builder()
+            .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
+            .build();
+
+    // Mock a ChatResponse WITHOUT TokenUsage (simulating an LLM that doesn't report token counts)
+    final ChatResponse chatResponse = mock(ChatResponse.class);
+    final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
+    when(chatResponse.aiMessage()).thenReturn(aiMessage);
+    when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from the LLM
+    when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
+
+    // When
+    final LlmResponse response =
+        langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();
+
+    // Then
+    // Verify the response has usage metadata estimated by the TokenCountEstimator
+    assertThat(response).isNotNull();
+    assertThat(response.content()).isPresent();
+    assertThat(response.content().get().text()).isEqualTo("The weather is sunny today.");
+
+    // IMPORTANT: verify that token usage was estimated via the TokenCountEstimator
+    assertThat(response.usageMetadata()).isPresent();
+    final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
+    assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(50)); // From estimator
+    assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(20)); // From estimator
+    assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(70)); // 50 + 20
+
+    // Verify the estimator was actually called
+    verify(tokenCountEstimator).estimateTokenCountInMessages(any());
+    verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
+  }
+
+  @Test
+  @DisplayName("Should prioritize TokenCountEstimator over TokenUsage when estimator is provided")
+  void testTokenCountEstimatorPriority() {
+    // Given
+    // Create a mock TokenCountEstimator
+    final TokenCountEstimator tokenCountEstimator = mock(TokenCountEstimator.class);
+    when(tokenCountEstimator.estimateTokenCountInMessages(any())).thenReturn(100); // From estimator
+    when(tokenCountEstimator.estimateTokenCountInText(any())).thenReturn(50); // From estimator
+
+    // Create LangChain4j with the TokenCountEstimator using the Builder
+    final LangChain4j langChain4jWithEstimator =
+        LangChain4j.builder()
+            .chatModel(chatModel)
+            .modelName(MODEL_NAME)
+            .tokenCountEstimator(tokenCountEstimator)
+            .build();
+
+    // Create an LlmRequest
+    final LlmRequest llmRequest =
+        LlmRequest.builder()
+            .contents(List.of(Content.fromParts(Part.fromText("What is the weather today?"))))
+            .build();
+
+    // Mock a ChatResponse WITH actual TokenUsage from the LLM
+    final ChatResponse chatResponse = mock(ChatResponse.class);
+    final AiMessage aiMessage = AiMessage.from("The weather is sunny today.");
+    final TokenUsage actualTokenUsage = new TokenUsage(30, 15, 45); // Actual token counts from LLM
+    when(chatResponse.aiMessage()).thenReturn(aiMessage);
+    when(chatResponse.tokenUsage()).thenReturn(actualTokenUsage); // LLM provides token usage
+    when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
+
+    // When
+    final LlmResponse response =
+        langChain4jWithEstimator.generateContent(llmRequest, false).blockingFirst();
+
+    // Then
+    // IMPORTANT: when a TokenCountEstimator is present, it takes priority over TokenUsage
+    assertThat(response).isNotNull();
+    assertThat(response.usageMetadata()).isPresent();
+    final GenerateContentResponseUsageMetadata usageMetadata = response.usageMetadata().get();
+    assertThat(usageMetadata.promptTokenCount()).isEqualTo(Optional.of(100)); // From estimator
+    assertThat(usageMetadata.candidatesTokenCount()).isEqualTo(Optional.of(50)); // From estimator
+    assertThat(usageMetadata.totalTokenCount()).isEqualTo(Optional.of(150)); // 100 + 50
+
+    // Verify the estimator was called (it takes priority)
+    verify(tokenCountEstimator).estimateTokenCountInMessages(any());
+    verify(tokenCountEstimator).estimateTokenCountInText("The weather is sunny today.");
+  }
+
+  @Test
+  @DisplayName("Should not include usageMetadata when TokenUsage is null and no estimator provided")
+  void testNoUsageMetadataWithoutEstimator() {
+    // Given
+    // Create LangChain4j WITHOUT a TokenCountEstimator (default behavior)
+    final LangChain4j langChain4jNoEstimator = new LangChain4j(chatModel, MODEL_NAME);
+
+    // Create an LlmRequest
+    final LlmRequest llmRequest =
+        LlmRequest.builder()
+            .contents(List.of(Content.fromParts(Part.fromText("Hello, world!"))))
+            .build();
+
+    // Mock a ChatResponse WITHOUT TokenUsage
+    final ChatResponse chatResponse = mock(ChatResponse.class);
+    final AiMessage aiMessage = AiMessage.from("Hello! How can I help you?");
+    when(chatResponse.aiMessage()).thenReturn(aiMessage);
+    when(chatResponse.tokenUsage()).thenReturn(null); // No token usage from the LLM
+    when(chatModel.chat(any(ChatRequest.class))).thenReturn(chatResponse);
+
+    // When
+    final LlmResponse response =
+        langChain4jNoEstimator.generateContent(llmRequest, false).blockingFirst();
+
+    // Then
+    // Verify the response does NOT have usage metadata
+    assertThat(response).isNotNull();
+    assertThat(response.content()).isPresent();
+    assertThat(response.content().get().text()).isEqualTo("Hello! How can I help you?");
+
+    // IMPORTANT: usageMetadata should be empty when there is no TokenUsage and no estimator
+    assertThat(response.usageMetadata()).isEmpty();
+  }
 }
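As the assertions above show, the counts surface on `LlmResponse` as `Optional`-valued getters. A minimal consumer sketch (again not part of the diff; it assumes the usual `com.google.adk.models.LlmResponse` import):

```java
import com.google.adk.models.LlmResponse;

class UsageLogging {

  /** Prints the usage metadata, if any, that toLlmResponse attached. */
  static void logUsage(LlmResponse response) {
    response
        .usageMetadata()
        .ifPresent(
            usage ->
                System.out.printf(
                    "prompt=%d candidates=%d total=%d%n",
                    usage.promptTokenCount().orElse(0),
                    usage.candidatesTokenCount().orElse(0),
                    usage.totalTokenCount().orElse(0)));
  }
}
```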
diff --git a/contrib/samples/a2a_basic/bin/.project b/contrib/samples/a2a_basic/bin/.project
new file mode 100644
index 000000000..a4b459152
--- /dev/null
+++ b/contrib/samples/a2a_basic/bin/.project
@@ -0,0 +1,23 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<projectDescription>
+  <name>google-adk-sample-a2a-basic</name>
+  <comment></comment>
+  <projects>
+  </projects>
+  <buildSpec>
+    <buildCommand>
+      <name>org.eclipse.jdt.core.javabuilder</name>
+      <arguments>
+      </arguments>
+    </buildCommand>
+    <buildCommand>
+      <name>org.eclipse.m2e.core.maven2Builder</name>
+      <arguments>
+      </arguments>
+    </buildCommand>
+  </buildSpec>
+  <natures>
+    <nature>org.eclipse.jdt.core.javanature</nature>
+    <nature>org.eclipse.m2e.core.maven2Nature</nature>
+  </natures>
+</projectDescription>
diff --git a/contrib/samples/a2a_basic/bin/.settings/org.eclipse.core.resources.prefs b/contrib/samples/a2a_basic/bin/.settings/org.eclipse.core.resources.prefs
new file mode 100644
index 000000000..99f26c020
--- /dev/null
+++ b/contrib/samples/a2a_basic/bin/.settings/org.eclipse.core.resources.prefs
@@ -0,0 +1,2 @@
+eclipse.preferences.version=1
+encoding/<project>=UTF-8
diff --git a/contrib/samples/a2a_basic/bin/.settings/org.eclipse.m2e.core.prefs b/contrib/samples/a2a_basic/bin/.settings/org.eclipse.m2e.core.prefs
new file mode 100644
index 000000000..f897a7f1c
--- /dev/null
+++ b/contrib/samples/a2a_basic/bin/.settings/org.eclipse.m2e.core.prefs
@@ -0,0 +1,4 @@
+activeProfiles=
+eclipse.preferences.version=1
+resolveWorkspaceProjects=true
+version=1
diff --git a/contrib/samples/a2a_basic/bin/README.md b/contrib/samples/a2a_basic/bin/README.md
new file mode 100644
index 000000000..4d2c793ae
--- /dev/null
+++ b/contrib/samples/a2a_basic/bin/README.md
@@ -0,0 +1,49 @@
+# A2A Basic Sample
+
+This sample shows how to invoke an A2A-compliant HTTP endpoint from the Google
+ADK runtime using the reusable `google-adk-a2a` module. It wires a
+`RemoteA2AAgent` to the production `JdkA2AHttpClient`, so you can exercise a
+running service (for example the Spring Boot webservice in
+`a2a/webservice`).
+
+## Prerequisites
+
+1. Start the Spring service (or point to any other A2A-compliant endpoint):
+
+   ```bash
+   cd /google_adk
+   ./mvnw -f a2a/webservice/pom.xml spring-boot:run \
+     -Dspring-boot.run.arguments=--server.port=8081
+   ```
+
+## Build and run
+
+```bash
+cd google_adk
+./mvnw -f contrib/samples/a2a_basic/pom.xml exec:java \
+  -Dexec.args="http://localhost:8081/a2a/remote"
+```
+
+You should see the client log each turn, including the remote agent response
+(e.g. `4 is not a prime number.`).
+
+To run the client in the background and capture logs:
+
+```bash
+nohup env GOOGLE_GENAI_USE_VERTEXAI=FALSE \
+  GOOGLE_API_KEY=your_api_key \
+  ./mvnw -f contrib/samples/a2a_basic/pom.xml exec:java \
+  -Dexec.args="http://localhost:8081/a2a/remote" \
+  > /tmp/a2a_basic.log 2>&1 & echo $!
+```
+
+Tail `/tmp/a2a_basic.log` to inspect the conversation.
+
+## Key files
+
+- `A2AAgent.java` – builds a root agent with a local dice-rolling tool and a
+  remote prime-checking sub-agent.
+- `A2AAgentRun.java` – minimal driver that executes a single
+  `SendMessage` turn to demonstrate the remote call.
+- `pom.xml` – standalone Maven configuration for building and running the
+  sample.
diff --git a/contrib/samples/a2a_basic/bin/pom.xml b/contrib/samples/a2a_basic/bin/pom.xml
new file mode 100644
index 000000000..fbcfe8d06
--- /dev/null
+++ b/contrib/samples/a2a_basic/bin/pom.xml
@@ -0,0 +1,98 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <parent>
+    <groupId>com.google.adk</groupId>
+    <artifactId>google-adk-samples</artifactId>
+    <version>0.4.1-SNAPSHOT</version>
+    <relativePath>..</relativePath>
+  </parent>
+
+  <artifactId>google-adk-sample-a2a-basic</artifactId>
+  <packaging>jar</packaging>
+
+  <name>Google ADK - Sample - A2A Basic Client</name>
+  <description>Demonstrates sending A2A REST requests using the google-adk-a2a module.</description>
+
+  <properties>
+    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    <java.version>17</java.version>
+    <google-adk.version>${project.version}</google-adk.version>
+    <google-adk-a2a.version>${project.version}</google-adk-a2a.version>
+    <slf4j.version>2.0.16</slf4j.version>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>com.google.adk</groupId>
+      <artifactId>google-adk</artifactId>
+      <version>${google-adk.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>com.google.adk</groupId>
+      <artifactId>google-adk-a2a</artifactId>
+      <version>${google-adk-a2a.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-simple</artifactId>
+      <version>${slf4j.version}</version>
+      <scope>runtime</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-compiler-plugin</artifactId>
+        <version>3.13.0</version>
+        <configuration>
+          <release>${java.version}</release>
+          <parameters>true</parameters>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <version>3.6.0</version>
+        <executions>
+          <execution>
+            <id>add-source</id>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>.</source>
+              </sources>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-source-plugin</artifactId>
+        <configuration>
+          <excludes>
+            <exclude>**/*.jar</exclude>
+            <exclude>target/**</exclude>
+          </excludes>
+        </configuration>
+      </plugin>
+      <plugin>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>exec-maven-plugin</artifactId>
+        <version>3.2.0</version>
+        <configuration>
+          <mainClass>com.example.a2a_basic.A2AAgentRun</mainClass>
+          <classpathScope>runtime</classpathScope>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+</project>