Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/strands/event_loop/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def handle_content_block_start(event: ContentBlockStartEvent) -> dict[str, Any]:
current_tool_use["toolUseId"] = tool_use_data["toolUseId"]
current_tool_use["name"] = tool_use_data["name"]
current_tool_use["input"] = ""
# Preserve type field for server-side tools (e.g., "server_tool_use" for nova_grounding)
if "type" in tool_use_data:
current_tool_use["type"] = tool_use_data["type"]

return current_tool_use

Expand Down Expand Up @@ -280,11 +283,15 @@ def handle_content_block_stop(state: dict[str, Any]) -> dict[str, Any]:
tool_use_id = current_tool_use["toolUseId"]
tool_use_name = current_tool_use["name"]

tool_use = ToolUse(
toolUseId=tool_use_id,
name=tool_use_name,
input=current_tool_use["input"],
)
tool_use: ToolUse = {
"toolUseId": tool_use_id,
"name": tool_use_name,
"input": current_tool_use["input"],
}
# Preserve type field for server-side tools (e.g., "server_tool_use" for nova_grounding)
if "type" in current_tool_use:
tool_use["type"] = current_tool_use["type"]

content.append({"toolUse": tool_use})
state["current_tool_use"] = {}

Expand Down
127 changes: 111 additions & 16 deletions src/strands/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@

DEFAULT_READ_TIMEOUT = 120

# Keys in additional_args that may conflict with built-in request construction
_CONFLICTING_ADDITIONAL_ARGS_KEYS = frozenset(
{
"toolConfig",
"inferenceConfig",
"guardrailConfig",
"system",
"messages",
"modelId",
}
)


class BedrockModel(Model):
"""AWS Bedrock model provider implementation.
Expand Down Expand Up @@ -88,6 +100,9 @@ class BedrockConfig(TypedDict, total=False):
True includes status, False removes status, "auto" determines based on model_id. Defaults to "auto".
stop_sequences: List of sequences that will stop generation when encountered
streaming: Flag to enable/disable streaming. Defaults to True.
system_tools: List of Bedrock system tool definitions (e.g., nova_grounding).
These are server-side tools merged with agent tools in toolConfig.
Example: [{"systemTool": {"name": "nova_grounding"}}]
temperature: Controls randomness in generation (higher = more random)
top_p: Controls diversity via nucleus sampling (alternative to temperature)
"""
Expand All @@ -110,6 +125,7 @@ class BedrockConfig(TypedDict, total=False):
include_tool_result_status: Optional[Literal["auto"] | bool]
stop_sequences: Optional[list[str]]
streaming: Optional[bool]
system_tools: Optional[list[dict[str, Any]]]
temperature: Optional[float]
top_p: Optional[float]

Expand Down Expand Up @@ -187,6 +203,21 @@ def get_config(self) -> BedrockConfig:
"""
return self.config

def _warn_on_conflicting_additional_args(self) -> None:
"""Warn if additional_args contains keys that conflict with built-in parameters."""
additional_args = self.config.get("additional_args")
if not additional_args:
return

for key in _CONFLICTING_ADDITIONAL_ARGS_KEYS:
if key in additional_args:
warnings.warn(
f"additional_args contains '{key}' which may conflict with built-in request parameters. "
f"Values in additional_args are merged last and may overwrite built-in values.",
UserWarning,
stacklevel=3,
)

def _format_request(
self,
messages: Messages,
Expand All @@ -206,6 +237,8 @@ def _format_request(
Returns:
A Bedrock converse stream request.
"""
self._warn_on_conflicting_additional_args()

if not tool_specs:
has_tool_content = any(
any("toolUse" in block or "toolResult" in block for block in msg.get("content", [])) for msg in messages
Expand Down Expand Up @@ -238,18 +271,19 @@ def _format_request(
"inputSchema": tool_spec["inputSchema"],
}
}
for tool_spec in tool_specs
for tool_spec in (tool_specs or [])
],
*(
[{"cachePoint": {"type": self.config["cache_tools"]}}]
if self.config.get("cache_tools")
else []
),
*(self.config.get("system_tools") or []),
],
**({"toolChoice": tool_choice if tool_choice else {"auto": {}}}),
}
}
if tool_specs
if tool_specs or self.config.get("system_tools")
else {}
),
**(
Expand Down Expand Up @@ -672,8 +706,12 @@ def _stream(
logger.debug("got response from model")
if streaming:
response = self.client.converse_stream(**request)
# Track tool use events to fix stopReason for streaming responses
has_tool_use = False
# Track tool use/result events to fix stopReason for streaming responses
# We need to distinguish server-side tools (already executed) from client-side tools
tool_use_info: dict[str, str] = {} # toolUseId -> type (e.g., "server_tool_use")
tool_result_ids: set[str] = set() # IDs of tools with results
has_client_tools = False

for chunk in response["stream"]:
if (
"metadata" in chunk
Expand All @@ -685,22 +723,40 @@ def _stream(
for event in self._generate_redaction_events():
callback(event)

# Track if we see tool use events
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
has_tool_use = True
# Track tool use events with their types
if "contentBlockStart" in chunk:
tool_use_start = chunk["contentBlockStart"].get("start", {}).get("toolUse")
if tool_use_start:
tool_use_id = tool_use_start.get("toolUseId", "")
tool_type = tool_use_start.get("type", "")
tool_use_info[tool_use_id] = tool_type
# Check if it's a client-side tool (not server_tool_use)
if tool_type != "server_tool_use":
has_client_tools = True

# Track tool result events (for server-side tools that were already executed)
tool_result_start = chunk["contentBlockStart"].get("start", {}).get("toolResult")
if tool_result_start:
tool_result_ids.add(tool_result_start.get("toolUseId", ""))

# Fix stopReason for streaming responses that contain tool use
# BUT: Only override if there are client-side tools without results
if (
has_tool_use
and "messageStop" in chunk
"messageStop" in chunk
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
):
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
# Check if we have client-side tools that need execution
needs_execution = has_client_tools and not set(tool_use_info.keys()).issubset(tool_result_ids)

if needs_execution:
# Create corrected chunk with tool_use stopReason
modified_chunk = chunk.copy()
modified_chunk["messageStop"] = message_stop.copy()
modified_chunk["messageStop"]["stopReason"] = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")
callback(modified_chunk)
else:
callback(chunk)
else:
callback(chunk)

Expand Down Expand Up @@ -762,6 +818,43 @@ def _stream(
callback()
logger.debug("finished streaming response from model")

def _has_client_side_tools_to_execute(self, message_content: list[dict[str, Any]]) -> bool:
"""Check if message contains client-side tools that need execution.

Server-side tools (like nova_grounding) are executed by Bedrock and include
toolResult blocks in the response. We should NOT override stopReason to
"tool_use" for these tools.

Args:
message_content: The content array from Bedrock response.

Returns:
True if there are client-side tools without results, False otherwise.
"""
tool_use_ids = set()
tool_result_ids = set()
has_client_tools = False

for content in message_content:
if "toolUse" in content:
tool_use = content["toolUse"]
tool_use_ids.add(tool_use["toolUseId"])

# Check if it's a server-side tool (Bedrock executes these)
if tool_use.get("type") != "server_tool_use":
has_client_tools = True

elif "toolResult" in content:
# Track which tools already have results
tool_result_ids.add(content["toolResult"]["toolUseId"])

# Only return True if there are client-side tools without results
if not has_client_tools:
return False

# Check if all tool uses have corresponding results
return not tool_use_ids.issubset(tool_result_ids)

def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
"""Convert a non-streaming response to the streaming format.

Expand Down Expand Up @@ -842,10 +935,12 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera

# Yield messageStop event
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
# BUT: Don't override for server-side tools (like nova_grounding) that are already executed
current_stop_reason = response["stopReason"]
if current_stop_reason == "end_turn":
message_content = response["output"]["message"]["content"]
if any("toolUse" in content for content in message_content):
# Only override if there are client-side tools that need execution
if self._has_client_side_tools_to_execute(message_content):
current_stop_reason = "tool_use"
logger.warning("Override stop reason from end_turn to tool_use")

Expand Down
28 changes: 21 additions & 7 deletions src/strands/tools/executors/_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,24 +180,38 @@ async def _stream(
invocation_state = before_event.invocation_state

if not selected_tool:
if tool_func == selected_tool:
# Server-side tools (e.g., nova_grounding) are executed by the model provider
# and their results are already included in the response. We provide a
# placeholder result to satisfy the tool result requirement.
if tool_use.get("type") == "server_tool_use":
logger.debug("tool_name=<%s> | server-side tool executed by model provider", tool_name)
result: ToolResult = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "success",
"content": [{"text": f"Server-side tool '{tool_name}' executed by model provider"}],
}
elif tool_func == selected_tool:
logger.error(
"tool_name=<%s>, available_tools=<%s> | tool not found in registry",
tool_name,
list(agent.tool_registry.registry.keys()),
)
result = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": f"Unknown tool: {tool_name}"}],
}
else:
logger.debug(
"tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call",
tool_name,
str(tool_use.get("toolUseId")),
)

result: ToolResult = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": f"Unknown tool: {tool_name}"}],
}
result = {
"toolUseId": str(tool_use.get("toolUseId")),
"status": "error",
"content": [{"text": f"Unknown tool: {tool_name}"}],
}

after_event, _ = await ToolExecutor._invoke_after_tool_call_hook(
agent, selected_tool, tool_use, invocation_state, result
Expand Down
4 changes: 3 additions & 1 deletion src/strands/types/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Dict, List, Literal, Optional

from typing_extensions import TypedDict
from typing_extensions import NotRequired, TypedDict

from .citations import CitationsContentBlock
from .media import DocumentContent, ImageContent, VideoContent
Expand Down Expand Up @@ -129,10 +129,12 @@ class ContentBlockStartToolUse(TypedDict):
Attributes:
name: The name of the tool that the model is requesting to use.
toolUseId: The ID for the tool request.
type: Optional type identifier (e.g., "server_tool_use" for server-side tools).
"""

name: str
toolUseId: str
type: NotRequired[str]


class ContentBlockStart(TypedDict, total=False):
Expand Down
2 changes: 2 additions & 0 deletions src/strands/types/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,13 @@ class ToolUse(TypedDict):
Can be any JSON-serializable type.
name: The name of the tool to invoke.
toolUseId: A unique identifier for this specific tool use request.
type: Optional type identifier for the tool use (e.g., "server_tool_use" for server-side tools).
"""

input: Any
name: str
toolUseId: str
type: NotRequired[str]


class ToolResultContent(TypedDict, total=False):
Expand Down
38 changes: 38 additions & 0 deletions tests/strands/event_loop/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ def test_handle_message_start():
{"start": {"toolUse": {"toolUseId": "test", "name": "test"}}},
{"toolUseId": "test", "name": "test", "input": ""},
),
# Server-side tool with type field (e.g., nova_grounding)
(
{"start": {"toolUse": {"toolUseId": "server-1", "name": "nova_grounding", "type": "server_tool_use"}}},
{"toolUseId": "server-1", "name": "nova_grounding", "input": "", "type": "server_tool_use"},
),
],
)
def test_handle_content_block_start(chunk: ContentBlockStartEvent, exp_tool_use):
Expand Down Expand Up @@ -328,6 +333,39 @@ def test_handle_content_block_delta(event: ContentBlockDeltaEvent, event_type, s
"redactedContent": b"",
},
),
# Server-side tool with type field (e.g., nova_grounding)
(
{
"content": [],
"current_tool_use": {
"toolUseId": "server-1",
"name": "nova_grounding",
"input": "{}",
"type": "server_tool_use",
},
"text": "",
"reasoningText": "",
"citationsContent": [],
"redactedContent": b"",
},
{
"content": [
{
"toolUse": {
"toolUseId": "server-1",
"name": "nova_grounding",
"input": {},
"type": "server_tool_use",
}
}
],
"current_tool_use": {},
"text": "",
"reasoningText": "",
"citationsContent": [],
"redactedContent": b"",
},
),
# Text
(
{
Expand Down
Loading