189 changes: 166 additions & 23 deletions src/strands/multiagent/graph.py
@@ -34,11 +34,13 @@
MultiAgentInitializedEvent,
)
from ..hooks import HookProvider, HookRegistry
from ..interrupt import Interrupt, _InterruptState
from ..session import SessionManager
from ..telemetry import get_tracer
from ..types._events import (
MultiAgentHandoffEvent,
MultiAgentNodeCancelEvent,
MultiAgentNodeInterruptEvent,
MultiAgentNodeStartEvent,
MultiAgentNodeStopEvent,
MultiAgentNodeStreamEvent,
@@ -162,6 +164,7 @@ class GraphNode:
execution_status: Status = Status.PENDING
result: NodeResult | None = None
execution_time: int = 0
graph: Optional["Graph"] = None
_initial_messages: Messages = field(default_factory=list, init=False)
_initial_state: AgentState = field(default_factory=AgentState, init=False)

@@ -177,9 +180,18 @@ def __post_init__(self) -> None:
def reset_executor_state(self) -> None:
"""Reset GraphNode executor state to initial state when graph was created.

This is useful when nodes are executed multiple times and need to start
fresh on each execution, providing stateless behavior.
If the Graph is resuming from an interrupt, the executor state is restored from the saved interrupt context instead.
"""
if self.graph and self.graph._interrupt_state.activated and self.node_id in self.graph._interrupt_state.context:
context = self.graph._interrupt_state.context[self.node_id]
if hasattr(self.executor, "messages"):
self.executor.messages = context["messages"]
if hasattr(self.executor, "state"):
self.executor.state = AgentState(context["state"])
if hasattr(self.executor, "_interrupt_state"):
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
return

if hasattr(self.executor, "messages"):
self.executor.messages = copy.deepcopy(self._initial_messages)

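For reference, the per-node context entry this restoration path reads is the one _activate_interrupt stores later in this diff; a minimal sketch of its shape, with illustrative values only:

    interrupt_context_entry = {
        "activated": True,      # whether the agent node was itself mid-interrupt
        "interrupt_state": {},  # node_executor._interrupt_state.to_dict()
        "state": {},            # node_executor.state.get()
        "messages": [],         # node_executor.messages captured at interrupt time
    }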
@@ -440,11 +452,16 @@ def __init__(
self.nodes = nodes
self.edges = edges
self.entry_points = entry_points

# Set graph reference on all nodes for interrupt state restoration
for node in self.nodes.values():
node.graph = self
self.max_node_executions = max_node_executions
self.execution_timeout = execution_timeout
self.node_timeout = node_timeout
self.reset_on_revisit = reset_on_revisit
self.state = GraphState()
self._interrupt_state = _InterruptState()
self.tracer = get_tracer()
self.trace_attributes: dict[str, AttributeValue] = self._parse_trace_attributes(trace_attributes)
self.session_manager = session_manager
@@ -519,6 +536,8 @@ async def stream_async(
- multi_agent_node_stop: When a node stops execution
- result: Final graph result
"""
self._interrupt_state.resume(task)

if invocation_state is None:
invocation_state = {}

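As a rough caller-side sketch of the resume flow this enables (not taken from this PR's tests): graph(...) stands in for the Graph's usual synchronous entry point around stream_async, and the exact fields inside the interruptResponse block are assumptions; only the block key itself appears in the input filtering further down in this diff.

    result = graph("Draft and review the migration plan")  # hypothetical task
    if result.status == Status.INTERRUPTED:
        # Answer each pending interrupt, then invoke the graph again; the new task
        # is what _interrupt_state.resume() above receives.
        responses = [
            {"interruptResponse": {"id": interrupt.id, "response": "approved"}}  # assumed block shape
            for interrupt in result.interrupts
        ]
        result = graph(responses)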
@@ -544,6 +563,8 @@

span = self.tracer.start_multiagent_span(task, "graph", custom_trace_attributes=self.trace_attributes)
with trace_api.use_span(span, end_on_exit=True):
interrupts = []

try:
logger.debug(
"max_node_executions=<%s>, execution_timeout=<%s>s, node_timeout=<%s>s | graph execution config",
@@ -553,6 +574,9 @@
)

async for event in self._execute_graph(invocation_state):
if isinstance(event, MultiAgentNodeInterruptEvent):
interrupts = event.interrupts

yield event.as_dict()

# Set final status based on execution results
Expand All @@ -564,7 +588,7 @@ async def stream_async(
logger.debug("status=<%s> | graph execution completed", self.state.status)

# Yield final result (consistent with Agent's AgentResultEvent format)
result = self._build_result()
result = self._build_result(interrupts)

# Use the same event format as Agent for consistency
yield MultiAgentResultEvent(result=result).as_dict()
@@ -576,8 +600,11 @@
finally:
self.state.execution_time = round((time.time() - start_time) * 1000)
await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self))
self._resume_from_session = False
self._resume_next_nodes.clear()
# Don't clear resume flags here - they should only be cleared when consumed
# at the start of _execute_graph, or when status is COMPLETED/FAILED
if self.state.status in (Status.COMPLETED, Status.FAILED):
self._resume_from_session = False
self._resume_next_nodes.clear()

def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:
"""Validate graph nodes for duplicate instances."""
@@ -593,7 +620,22 @@ def _validate_graph(self, nodes: dict[str, GraphNode]) -> None:

async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterator[Any]:
"""Execute graph and yield TypedEvent objects."""
ready_nodes = self._resume_next_nodes if self._resume_from_session else list(self.entry_points)
# Make a copy to avoid clearing the list we're about to use
ready_nodes = self._resume_next_nodes.copy() if self._resume_from_session else list(self.entry_points)

logger.debug(
"resume_from_session=<%s>, resume_next_nodes=<%s>, entry_points=<%s>, ready_nodes=<%s> | "
"starting execution",
self._resume_from_session,
[n.node_id for n in self._resume_next_nodes] if self._resume_next_nodes else [],
[n.node_id for n in self.entry_points],
[n.node_id for n in ready_nodes],
)

# Clear resume flags after consuming them once
if self._resume_from_session:
self._resume_from_session = False
self._resume_next_nodes.clear()

while ready_nodes:
# Check execution limits before continuing
@@ -609,15 +651,31 @@ async def _execute_graph(self, invocation_state: dict[str, Any]) -> AsyncIterato
current_batch = ready_nodes.copy()
ready_nodes.clear()

# Track if interrupt occurred
interrupt_detected = False

# Execute current batch
async for event in self._execute_nodes_parallel(current_batch, invocation_state):
# Check for interrupt event
if isinstance(event, MultiAgentNodeInterruptEvent):
interrupt_detected = True
yield event

# Stop execution if interrupted
if interrupt_detected:
break

# Find newly ready nodes after batch execution
# All nodes in the current batch are passed as the completed batch,
# because a failure would have raised an exception before reaching this point
newly_ready = self._find_newly_ready_nodes(current_batch)

logger.debug(
"completed_batch=<%s>, newly_ready=<%s> | finding next nodes",
[n.node_id for n in current_batch],
[n.node_id for n in newly_ready],
)

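# Illustration with hypothetical node ids: given edges a -> b and a -> c,
# completing batch [a] makes newly_ready [b, c], and both run in parallel next.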
# Emit handoff event for batch transition if there are nodes to transition to
if newly_ready:
handoff_event = MultiAgentHandoffEvent(
@@ -797,10 +855,14 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
)
yield start_event

before_event, _ = await self.hooks.invoke_callbacks_async(
before_event, interrupts = await self.hooks.invoke_callbacks_async(
BeforeNodeCallEvent(self, node.node_id, invocation_state)
)

if interrupts:
yield self._activate_interrupt(node, interrupts)
return

start_time = time.time()
try:
if before_event.cancel_node:
@@ -811,8 +873,22 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
yield MultiAgentNodeCancelEvent(node.node_id, cancel_message)
raise RuntimeError(cancel_message)

# Build node input from satisfied dependencies
node_input = self._build_node_input(node)
# Check if we are resuming from an interrupt with a response
# For hook interrupts: activated=False at agent level → use original task input
# For agent interrupts: activated=True at agent level → use interrupt responses
if (
self._interrupt_state.activated
and node.node_id in self._interrupt_state.context
and self._interrupt_state.context[node.node_id].get("activated")
):
# Agent was interrupted - use interrupt response as input
node_input = self._interrupt_state.context.get("responses", [])

# Restore node state from interrupt context (this restores the agent's interrupt state)
node.reset_executor_state()
else:
# Normal execution or hook interrupt - build normal input
node_input = self._build_node_input(node)

# Execute and stream events (timeout handled at task level)
if isinstance(node.executor, MultiAgentBase):
@@ -830,13 +906,15 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if multi_agent_result is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

# Check the actual status from multi_agent_result
node_result = NodeResult(
result=multi_agent_result,
execution_time=multi_agent_result.execution_time,
status=Status.COMPLETED,
status=multi_agent_result.status,
accumulated_usage=multi_agent_result.accumulated_usage,
accumulated_metrics=multi_agent_result.accumulated_metrics,
execution_count=multi_agent_result.execution_count,
interrupts=multi_agent_result.interrupts or [],
)

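# Note: if the nested executor (e.g. a sub-Graph or Swarm) was itself interrupted, its
# INTERRUPTED status and interrupts propagate through this NodeResult and are handled
# by the interrupt status check further below.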
elif isinstance(node.executor, Agent):
Expand All @@ -854,31 +932,37 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any])
if agent_response is None:
raise ValueError(f"Node '{node.node_id}' did not produce a result event")

# Check for interrupt (from main branch)
if agent_response.stop_reason == "interrupt":
node.executor.messages.pop() # remove interrupted tool use message
node.executor._interrupt_state.deactivate()

raise RuntimeError("user raised interrupt from agent | interrupts are not yet supported in graphs")

# Extract metrics with defaults
response_metrics = getattr(agent_response, "metrics", None)
usage = getattr(
response_metrics, "accumulated_usage", Usage(inputTokens=0, outputTokens=0, totalTokens=0)
)
metrics = getattr(response_metrics, "accumulated_metrics", Metrics(latencyMs=0))

# Check for interrupt and set appropriate status
execution_time = round((time.time() - start_time) * 1000)
status = Status.INTERRUPTED if agent_response.stop_reason == "interrupt" else Status.COMPLETED

node_result = NodeResult(
result=agent_response,
execution_time=round((time.time() - start_time) * 1000),
status=Status.COMPLETED,
execution_time=execution_time,
status=status,
accumulated_usage=usage,
accumulated_metrics=metrics,
execution_count=1,
interrupts=agent_response.interrupts or [],
)
else:
raise ValueError(f"Node '{node.node_id}' of type '{type(node.executor)}' is not supported")

# Check if node was interrupted
if node_result.status == Status.INTERRUPTED:
yield self._activate_interrupt(node, node_result.interrupts)
return

# Deactivate interrupt state after successful execution
self._interrupt_state.deactivate()

# Mark as completed
node.execution_status = Status.COMPLETED
node.result = node_result
@@ -979,7 +1063,13 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
if isinstance(self.state.task, str):
return [ContentBlock(text=self.state.task)]
else:
return cast(list[ContentBlock], self.state.task)
# Filter out interruptResponse dicts from task
task_list = cast(list[ContentBlock], self.state.task)
filtered_task = []
for item in task_list:
if not (isinstance(item, dict) and "interruptResponse" in item):
filtered_task.append(item)
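# Illustration (hypothetical content blocks): a resumed task such as
#   [{"text": "original question"}, {"interruptResponse": {...}}]
# is reduced here to just [{"text": "original question"}].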
return filtered_task

# Combine task with dependency outputs
node_input = []
@@ -988,9 +1078,12 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:
if isinstance(self.state.task, str):
node_input.append(ContentBlock(text=f"Original Task: {self.state.task}"))
else:
# Add task content blocks with a prefix
# Add task content blocks with a prefix (filter out interruptResponse dicts)
node_input.append(ContentBlock(text="Original Task:"))
node_input.extend(cast(list[ContentBlock], self.state.task))
task_list = cast(list[ContentBlock], self.state.task)
for item in task_list:
if not (isinstance(item, dict) and "interruptResponse" in item):
node_input.append(item)

# Add dependency outputs
node_input.append(ContentBlock(text="\nInputs from previous nodes:"))
@@ -1006,8 +1099,49 @@ def _build_node_input(self, node: GraphNode) -> list[ContentBlock]:

return node_input

def _build_result(self) -> GraphResult:
def _activate_interrupt(self, node: GraphNode, interrupts: list[Interrupt]) -> MultiAgentNodeInterruptEvent:
"""Activate the interrupt state.

A Graph may be interrupted either from a BeforeNodeCallEvent hook or from within an agent node.
In either case, we must manage the interrupt state of both the Graph and the individual agent nodes.

Args:
node: The interrupted node.
interrupts: The interrupts raised by the user.

Returns:
    The MultiAgentNodeInterruptEvent for the interrupted node.
"""
logger.debug("node=<%s> | node interrupted", node.node_id)
self.state.status = Status.INTERRUPTED

# Save node and graph state for resumption
node_executor = node.executor
self._interrupt_state.context[node.node_id] = {
"activated": node_executor._interrupt_state.activated
if hasattr(node_executor, "_interrupt_state")
else False,
"interrupt_state": node_executor._interrupt_state.to_dict()
if hasattr(node_executor, "_interrupt_state")
else {},
"state": node_executor.state.get() if hasattr(node_executor, "state") else {},
"messages": node_executor.messages if hasattr(node_executor, "messages") else [],
}

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()

# Set resume state so graph continues from this node instead of entry points
self._resume_next_nodes = [node]
self._resume_from_session = True

return MultiAgentNodeInterruptEvent(node.node_id, interrupts)

def _build_result(self, interrupts: list[Interrupt] | None = None) -> GraphResult:
"""Build graph result from current state."""
if interrupts is None:
interrupts = []

return GraphResult(
status=self.state.status,
results=self.state.results,
@@ -1021,6 +1155,7 @@ def _build_result(self) -> GraphResult:
execution_order=self.state.execution_order,
edges=self.state.edges,
entry_points=self.state.entry_points,
interrupts=interrupts,
)

def serialize_state(self) -> dict[str, Any]:
@@ -1037,6 +1172,9 @@ def serialize_state(self) -> dict[str, Any]:
"next_nodes_to_execute": next_nodes,
"current_task": self.state.task,
"execution_order": [n.node_id for n in self.state.execution_order],
"_internal_state": {
"interrupt_state": self._interrupt_state.to_dict(),
},
}

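A sketch of the resulting payload, with illustrative node ids; fields from the truncated upper part of serialize_state are omitted:

    {
        "next_nodes_to_execute": ["review"],
        "current_task": "original task text",
        "execution_order": ["research"],
        "_internal_state": {
            "interrupt_state": {},  # output of _InterruptState.to_dict()
        },
    }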
def deserialize_state(self, payload: dict[str, Any]) -> None:
@@ -1052,6 +1190,11 @@ def deserialize_state(self, payload: dict[str, Any]) -> None:
payload: Dictionary containing persisted state data including status,
completed nodes, results, and next nodes to execute.
"""
# Restore interrupt state if present
if "_internal_state" in payload:
internal_state = payload["_internal_state"]
self._interrupt_state = _InterruptState.from_dict(internal_state.get("interrupt_state", {}))

if not payload.get("next_nodes_to_execute"):
# Reset all nodes
for node in self.nodes.values():