Skip to content
2 changes: 2 additions & 0 deletions src/google/adk/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .llm_agent import Agent
from .llm_agent import LlmAgent
from .loop_agent import LoopAgent
from .map_agent import MapAgent
from .mcp_instruction_provider import McpInstructionProvider
from .parallel_agent import ParallelAgent
from .run_config import RunConfig
Expand All @@ -29,6 +30,7 @@
'BaseAgent',
'LlmAgent',
'LoopAgent',
'MapAgent',
'McpInstructionProvider',
'ParallelAgent',
'SequentialAgent',
Expand Down
175 changes: 175 additions & 0 deletions src/google/adk/agents/map_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from __future__ import annotations

import sys
from typing import Annotated
from typing import AsyncGenerator

from annotated_types import Len
from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.parallel_agent import _create_branch_ctx_for_sub_agent
from google.adk.agents.parallel_agent import _merge_agent_run
from google.adk.agents.parallel_agent import _merge_agent_run_pre_3_11
from google.adk.events import Event
from google.adk.flows.llm_flows.contents import _should_include_event_in_context
from google.genai import types
from pydantic import Field
from pydantic import RootModel
from typing_extensions import override

from ..utils.context_utils import Aclosing


class MapAgent(BaseAgent):
sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field(
min_length=1,
max_length=1,
default_factory=list,
description=(
"A single base agent that will be copied and invoked for each prompt"
),
)

@override
async def _run_async_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Core logic of this workflow agent.

Args:
invocation_context: InvocationContext, provides access to the input prompts.

Yields:
Event: the events generated by the sub-agent for each input prompt.
"""

# Create a branch string if it doesn't exist, to ensure parallel invocations don't interfere with each other
prompts, invoker = self._extract_input_prompts(invocation_context)

# for agent naming - e.g. if there are 100-999 prompts, sub-agent copies are named 001, 002, 003 and so on
number_field_width = len(str(len(prompts)))

# Clone the sub-agent's tree for each prompt
sub_agents = [
self._branch_agent_tree(self.sub_agents[0], i, number_field_width)
for i, _ in enumerate(prompts)
]

# Set the map agent as the parent of the clones
self.clone(update={"sub_agents": sub_agents})

# Create a separate invocation context for each prompt, each with a numbered copy of the sub-agent.
contexts = [
self._branch_context(
invocation_context,
agent=agent,
invoker=invoker,
prompt=prompt,
)
for prompt, agent in zip(prompts, sub_agents)
]

agent_runs = [ctx.agent.run_async(ctx) for ctx in contexts]

merge_func = (
_merge_agent_run
if sys.version_info >= (3, 11)
else _merge_agent_run_pre_3_11
)
async with Aclosing(merge_func(agent_runs)) as agen:
async for event in agen:
yield event

def _extract_input_prompts(
self, ctx: InvocationContext
) -> tuple[list[str], str]:
"""
The input to the map agent is a list of strings.
We extract the text content from the latest event, and assume it is a list of strings serialized as a json string.
"""
for i in range(len(ctx.session.events) - 1, -1, -1):
event = ctx.session.events[i]
if _should_include_event_in_context(ctx.branch, event):
break
else:
return [], "user"

invoker: str = event.author
input_message: str = (
(event.content or types.Content()).parts or [types.Part()]
)[0].text or ""

# Remove the event which has the prompt list, so that a sub agent does not
# see the prompts of its siblings, which may confuse it.
# The event is removed only for this invocation.
ctx.session.events.pop(i)

agent_input = RootModel[list[str]].model_validate_json(input_message).root

return agent_input, invoker

@staticmethod
def _get_unique_name(name: str, idx: int, width: int) -> str:
"""e.g. my_sub_agent_046"""
return f"{name}_{idx:0{width}d}"

def _branch_context(
self,
ctx: InvocationContext,
*,
agent: BaseAgent,
invoker: str,
prompt: str,
) -> InvocationContext:
"""Creates a an invocation context for invoking a sub-agent clone with a single prompt.

Args:
ctx: The current invocation context of the map agent. To be copied and edited for the sub-agent copy.
agent: the sub-agent clone to be invoked in the returned context.
invoker: the invoker of the map agent in this invocation.
prompt: the prompt on which the sub-agent copy should be invoked

Returns:
InvocationContext: A new invocation context ready to run with the unique sub-agent copy and the prompt
"""
prompt_part = [types.Part(text=prompt)]

# Add the prompt to the user_content of this branch to easily access agent input in callbacks
user_content = types.Content(
role="user",
parts=((ctx.user_content or types.Content()).parts or []) + prompt_part,
)

new_ctx = _create_branch_ctx_for_sub_agent(self, agent, ctx).model_copy(
update=dict(agent=agent, user_content=user_content)
)

# Add the prompt as a temporary event of this branch in place of the prompt list as the natural input of the sub-agent.
prompt_content = types.Content(
role="user" if invoker == "user" else "model", parts=prompt_part
)
new_ctx.session.events.append(
Event(author=invoker, branch=new_ctx.branch, content=prompt_content)
)

return new_ctx

def _branch_agent_tree(
self, agent: BaseAgent, idx: int, width: int
) -> BaseAgent:
"""
Clone and rename an agent and its sub-tree to create a thread-safe branch.
Args:
agent: the root of the current sub-agent tree - in the first call it is the main sub-agent of the map agent
idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name
width: number of digits in the total number of prompts, to ensure naming is consistent in field width
(e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101)
"""
new_name = self._get_unique_name(agent.name, idx=idx, width=width)
new_sub_agents = [
self._branch_agent_tree(a, idx, width) for a in agent.sub_agents
]
new_agent = agent.clone(
update={"name": new_name, "sub_agents": new_sub_agents}
)
return new_agent
Loading