Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ def apply_history_length(task: Task, history_length: int | None) -> Task:
A new task object with limited history
"""
# Apply historyLength parameter if specified
if history_length is not None and history_length > 0 and task.history:
if history_length is not None and history_length >= 0:
# Limit history to the most recent N messages
limited_history = task.history[-history_length:]
if task.history and history_length > 0:
limited_history = task.history[-history_length:]
else:
limited_history = []
# Create a new task instance with limited history
return task.model_copy(update={'history': limited_history})

Expand Down
41 changes: 40 additions & 1 deletion tests/utils/test_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
import unittest
import uuid

Expand All @@ -6,7 +6,7 @@
import pytest

from a2a.types import Artifact, Message, Part, Role, TextPart
from a2a.utils.task import completed_task, new_task
from a2a.utils.task import apply_history_length, completed_task, new_task


class TestTask(unittest.TestCase):
Expand Down Expand Up @@ -188,6 +188,45 @@
history=[],
)

def test_apply_history_length_cases(self):
# Setup task with 3 messages
history = [
Message(role=Role.user, parts=[Part(root=TextPart(text='1'))], message_id='1'),
Message(role=Role.agent, parts=[Part(root=TextPart(text='2'))], message_id='2'),
Message(role=Role.user, parts=[Part(root=TextPart(text='3'))], message_id='3'),
]
task_id = str(uuid.uuid4())
context_id = str(uuid.uuid4())
task = completed_task(
task_id=task_id,
context_id=context_id,
artifacts=[Artifact(artifact_id='a', parts=[Part(root=TextPart(text='a'))])],
history=history
)

# historyLength = 0 -> empty
t0 = apply_history_length(task, 0)
self.assertEqual(len(t0.history), 0)

# historyLength = 1 -> last one
t1 = apply_history_length(task, 1)
self.assertEqual(len(t1.history), 1)
self.assertEqual(t1.history[0].message_id, '3')

# historyLength = 2 -> last two
t2 = apply_history_length(task, 2)
self.assertEqual(len(t2.history), 2)
self.assertEqual(t2.history[0].message_id, '2')
self.assertEqual(t2.history[1].message_id, '3')

# historyLength = None -> all
tn = apply_history_length(task, None)
self.assertEqual(len(tn.history), 3)

# historyLength = 10 -> all
t10 = apply_history_length(task, 10)
self.assertEqual(len(t10.history), 3)
Comment on lines +192 to +228
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

This test is great for covering the main scenarios. It could be made more robust and maintainable by parameterizing the test cases and using unittest.subTest. This approach has a few benefits:

  • Clarity: It separates the test data from the test logic, making it easier to see what's being tested.
  • Maintainability: Adding new test cases is as simple as adding a new entry to the test_cases dictionary.
  • Better Failure Reporting: subTest ensures that all cases are run, and it reports failures for each subtest individually, rather than stopping at the first failure.
  • Improved Assertions: The assertions can be made more consistent and thorough by checking the exact sequence of message IDs for every case.
        # Setup task with 3 messages
        history = [
            Message(role=Role.user, parts=[Part(root=TextPart(text='1'))], message_id='1'),
            Message(role=Role.agent, parts=[Part(root=TextPart(text='2'))], message_id='2'),
            Message(role=Role.user, parts=[Part(root=TextPart(text='3'))], message_id='3'),
        ]
        task = completed_task(
            task_id=str(uuid.uuid4()),
            context_id=str(uuid.uuid4()),
            artifacts=[Artifact(artifact_id='a', parts=[Part(root=TextPart(text='a'))])],
            history=history,
        )

        test_cases = {
            'zero_length': (0, []),
            'one_item': (1, ['3']),
            'two_items': (2, ['2', '3']),
            'none_length_is_full_history': (None, ['1', '2', '3']),
            'length_greater_than_history_is_full_history': (10, ['1', '2', '3']),
        }

        for name, (length, expected_ids) in test_cases.items():
            with self.subTest(msg=name):
                new_task = apply_history_length(task, length)
                actual_ids = [m.message_id for m in new_task.history]
                self.assertEqual(actual_ids, expected_ids)



if __name__ == '__main__':
unittest.main()
Loading