Skip to content

Commit 0728c86

Browse files
committed
feat(client): Set a limit on the size of the API responses passed back to the MCP client to control LLM context bloat
Issue: APPAI-176
1 parent f9b50a6 commit 0728c86

File tree

16 files changed

+3436
-49
lines changed

16 files changed

+3436
-49
lines changed

packages/gg_api_core/src/gg_api_core/client.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,21 @@ class ListResponse(TypedDict):
7575
has_more: bool
7676

7777

78+
# Default limit for paginate_all to prevent context bloat
79+
DEFAULT_PAGINATION_MAX_BYTES = 20_000
80+
81+
# Default HTTP timeout in seconds (to handle slow pagination)
82+
DEFAULT_HTTP_TIMEOUT = 20
83+
84+
85+
class PaginatedResult(TypedDict):
86+
"""Result from paginate_all with size limit protection."""
87+
88+
data: list[dict[str, Any]]
89+
cursor: str | None # Use this cursor to fetch more results if has_more is True
90+
has_more: bool # True if results were capped by size limit OR more pages exist
91+
92+
7893
def is_oauth_enabled() -> bool:
7994
"""
8095
Check if OAuth authentication is enabled via environment variable.
@@ -432,7 +447,7 @@ async def _request(self, method: str, endpoint: str, **kwargs) -> Any:
432447

433448
while retry_count <= max_retries:
434449
try:
435-
async with httpx.AsyncClient(follow_redirects=True) as client:
450+
async with httpx.AsyncClient(follow_redirects=True, timeout=DEFAULT_HTTP_TIMEOUT) as client:
436451
logger.debug(f"Sending {method} request to {url}")
437452
response = await client.request(method, url, headers=headers, **kwargs)
438453

@@ -652,7 +667,7 @@ async def _request_list(self, endpoint: str, **kwargs) -> ListResponse:
652667
}
653668
headers.update(kwargs.pop("headers", {}))
654669

655-
async with httpx.AsyncClient(follow_redirects=True) as client:
670+
async with httpx.AsyncClient(follow_redirects=True, timeout=DEFAULT_HTTP_TIMEOUT) as client:
656671
response = await client.get(url, headers=headers, **kwargs)
657672
response.raise_for_status()
658673

@@ -677,19 +692,31 @@ async def _request_list(self, endpoint: str, **kwargs) -> ListResponse:
677692
"has_more": cursor is not None,
678693
}
679694

680-
async def paginate_all(self, endpoint: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]:
695+
async def paginate_all(
696+
self,
697+
endpoint: str,
698+
params: dict[str, Any] | None = None,
699+
max_bytes: int = DEFAULT_PAGINATION_MAX_BYTES,
700+
) -> PaginatedResult:
681701
"""Fetch all pages of results using cursor-based pagination.
682702
703+
Pagination stops when either all data is fetched or the size limit is reached.
704+
683705
Args:
684706
endpoint: API endpoint path
685707
params: Query parameters to include in the request
708+
max_bytes: Maximum total bytes of JSON data to accumulate (default: 20KB).
709+
When this limit is reached, pagination stops and truncated=True is returned.
686710
687711
Returns:
688-
List of all items from all pages
712+
PaginatedResult with data, cursor, has_more, truncated info, and total_bytes
689713
"""
690714
params = params or {}
691-
all_items = []
692-
cursor = None
715+
all_items: list[dict[str, Any]] = []
716+
total_bytes = 0
717+
cursor: str | None = None
718+
truncated = False
719+
has_more = False
693720

694721
logger.debug(f"Starting pagination for endpoint '{endpoint}' with initial params: {params}")
695722

@@ -724,9 +751,26 @@ async def paginate_all(self, endpoint: str, params: dict[str, Any] | None = None
724751
logger.debug("Received empty response data, stopping pagination")
725752
break
726753

727-
logger.debug(f"Received page with {len(response['data'])} items")
754+
# Calculate size of this page's data
755+
page_bytes = len(json.dumps(response["data"]).encode("utf-8"))
756+
757+
# Check if adding this page would exceed the limit
758+
if total_bytes + page_bytes > max_bytes and all_items:
759+
# We already have some data, stop here to avoid exceeding limit
760+
logger.warning(
761+
f"Pagination stopped due to size limit: {total_bytes} bytes accumulated, "
762+
f"next page would add {page_bytes} bytes (limit: {max_bytes} bytes)"
763+
)
764+
truncated = True
765+
has_more = True
766+
# Keep the cursor so caller can continue if needed
767+
cursor = response["cursor"]
768+
break
769+
770+
logger.debug(f"Received page with {len(response['data'])} items ({page_bytes} bytes)")
728771
all_items.extend(response["data"])
729-
logger.debug(f"Total items collected so far: {len(all_items)}")
772+
total_bytes += page_bytes
773+
logger.debug(f"Total items collected so far: {len(all_items)} ({total_bytes} bytes)")
730774

731775
# Check for next cursor
732776
cursor = response["cursor"]
@@ -736,8 +780,15 @@ async def paginate_all(self, endpoint: str, params: dict[str, Any] | None = None
736780
logger.debug("No next cursor found, pagination complete")
737781
break
738782

739-
logger.info(f"Pagination complete for {endpoint}: collected {len(all_items)} total items")
740-
return all_items
783+
logger.info(
784+
f"Pagination complete for {endpoint}: collected {len(all_items)} items "
785+
f"({total_bytes} bytes, capped={truncated})"
786+
)
787+
return {
788+
"data": all_items,
789+
"cursor": cursor,
790+
"has_more": has_more or (cursor is not None),
791+
}
741792

742793
async def create_honeytoken(
743794
self, name: str, description: str = "", custom_tags: list | None = None
@@ -908,9 +959,8 @@ async def list_incidents(
908959
endpoint = "/incidents/secrets"
909960

910961
if get_all:
911-
# When get_all=True, return all items without cursor
912-
all_items = await self.paginate_all(endpoint, params)
913-
return {"data": all_items, "cursor": None, "has_more": False}
962+
# When get_all=True, return paginated result with truncation metadata
963+
return await self.paginate_all(endpoint, params)
914964

915965
query_string = "&".join([f"{k}={v}" for k, v in params.items()])
916966
if query_string:
@@ -1043,9 +1093,8 @@ async def list_honeytokens(
10431093
endpoint = "/honeytokens"
10441094

10451095
if get_all:
1046-
# When get_all=True, return all items without cursor
1047-
all_items = await self.paginate_all(endpoint, params)
1048-
return {"data": all_items, "cursor": None, "has_more": False}
1096+
# When get_all=True, return paginated result with truncation metadata
1097+
return await self.paginate_all(endpoint, params)
10491098

10501099
query_string = "&".join([f"{k}={v}" for k, v in params.items()])
10511100
if query_string:
@@ -1499,11 +1548,10 @@ async def list_occurrences(
14991548
if with_sources is not None:
15001549
params["with_sources"] = str(with_sources).lower()
15011550

1502-
# If get_all is True, use paginate_all to get all results
1551+
# If get_all is True, use paginate_all to get all results with truncation metadata
15031552
if get_all:
15041553
logger.info("Getting all occurrences using cursor-based pagination")
1505-
all_items = await self.paginate_all("occurrences/secrets", params)
1506-
return {"data": all_items, "cursor": None, "has_more": False}
1554+
return await self.paginate_all("occurrences/secrets", params)
15071555

15081556
# Otherwise, get a single page
15091557
logger.info(f"Getting occurrences with params: {params}")
@@ -1613,9 +1661,8 @@ async def list_sources(
16131661
endpoint = "/sources"
16141662

16151663
if get_all:
1616-
# When get_all=True, return all items without cursor
1617-
all_items = await self.paginate_all(endpoint, params)
1618-
return {"data": all_items, "cursor": None, "has_more": False}
1664+
# When get_all=True, return paginated result with truncation metadata
1665+
return await self.paginate_all(endpoint, params)
16191666

16201667
return await self._request_list(endpoint, params=params)
16211668

packages/gg_api_core/src/gg_api_core/tools/list_honeytokens.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastmcp.exceptions import ToolError
55
from pydantic import BaseModel, Field
66

7+
from gg_api_core.client import DEFAULT_PAGINATION_MAX_BYTES
78
from gg_api_core.utils import get_client
89

910
logger = logging.getLogger(__name__)
@@ -26,7 +27,10 @@ class ListHoneytokensParams(BaseModel):
2627
creator_api_token_id: str | int | None = Field(default=None, description="Filter by creator API token ID")
2728
per_page: int = Field(default=20, description="Number of results per page (default: 20, min: 1, max: 100)")
2829
cursor: str | None = Field(default=None, description="Pagination cursor from a previous response")
29-
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
30+
get_all: bool = Field(
31+
default=False,
32+
description=f"If True, fetch all pages (capped at ~{DEFAULT_PAGINATION_MAX_BYTES / 1000}KB; check 'has_more' and use cursor to continue)",
33+
)
3034

3135

3236
class ListHoneytokensResult(BaseModel):
@@ -36,6 +40,7 @@ class ListHoneytokensResult(BaseModel):
3640
next_cursor: str | None = Field(
3741
default=None, description="Cursor for fetching the next page (null if no more results)"
3842
)
43+
has_more: bool = Field(default=False, description="True if more results exist (use next_cursor to fetch)")
3944

4045

4146
async def list_honeytokens(params: ListHoneytokensParams) -> ListHoneytokensResult:
@@ -90,7 +95,11 @@ async def list_honeytokens(params: ListHoneytokensParams) -> ListHoneytokensResu
9095
next_cursor = response["cursor"]
9196

9297
logger.debug(f"Found {len(honeytokens_data)} honeytokens")
93-
return ListHoneytokensResult(honeytokens=honeytokens_data, next_cursor=next_cursor)
98+
return ListHoneytokensResult(
99+
honeytokens=honeytokens_data,
100+
next_cursor=next_cursor,
101+
has_more=response.get("has_more", False),
102+
)
94103
except Exception as e:
95104
logger.error(f"Error listing honeytokens: {str(e)}")
96105
raise ToolError(str(e))

packages/gg_api_core/src/gg_api_core/tools/list_incidents.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33

44
from pydantic import BaseModel, Field
55

6-
from gg_api_core.client import IncidentSeverity, IncidentStatus, IncidentValidity, TagNames
6+
from gg_api_core.client import (
7+
DEFAULT_PAGINATION_MAX_BYTES,
8+
IncidentSeverity,
9+
IncidentStatus,
10+
IncidentValidity,
11+
TagNames,
12+
)
713
from gg_api_core.tools.find_current_source_id import find_current_source_id
814
from gg_api_core.utils import get_client
915

@@ -115,7 +121,10 @@ class ListIncidentsParams(BaseModel):
115121
ordering: str | None = Field(default=None, description="Sort field (e.g., 'date', '-date' for descending)")
116122
per_page: int = Field(default=20, description="Number of results per page (default: 20, min: 1, max: 100)")
117123
cursor: str | None = Field(default=None, description="Pagination cursor for fetching next page of results")
118-
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
124+
get_all: bool = Field(
125+
default=False,
126+
description=f"If True, fetch all pages (capped at ~{DEFAULT_PAGINATION_MAX_BYTES / 1000}; check 'has_more' and use cursor to continue)",
127+
)
119128

120129
# Filters
121130
from_date: str | None = Field(
@@ -164,6 +173,7 @@ class ListIncidentsResult(BaseModel):
164173
next_cursor: str | None = Field(default=None, description="Pagination cursor for next page")
165174
applied_filters: dict[str, Any] = Field(default_factory=dict, description="Filters that were applied to the query")
166175
suggestion: str = Field(default="", description="Suggestions for interpreting or modifying the results")
176+
has_more: bool = Field(default=False, description="True if more results exist (use next_cursor to fetch)")
167177

168178

169179
class ListIncidentsError(BaseModel):
@@ -248,7 +258,7 @@ async def list_incidents(params: ListIncidentsParams) -> ListIncidentsResult | L
248258
if params.get_all:
249259
api_params["get_all"] = params.get_all
250260

251-
# Get incidents using list_incidents which returns ListResponse
261+
# Get incidents using list_incidents which returns ListResponse or PaginatedResult
252262
response = await client.list_incidents(**api_params)
253263
incidents_data = response["data"]
254264
next_cursor = response["cursor"]
@@ -261,6 +271,7 @@ async def list_incidents(params: ListIncidentsParams) -> ListIncidentsResult | L
261271
next_cursor=next_cursor,
262272
applied_filters=_build_filter_info(params),
263273
suggestion=_build_suggestion(params, count),
274+
has_more=response.get("has_more", False),
264275
)
265276

266277
except Exception as e:

packages/gg_api_core/src/gg_api_core/tools/list_repo_occurrences.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@
33

44
from pydantic import BaseModel, Field
55

6-
from gg_api_core.client import IncidentSeverity, IncidentStatus, IncidentValidity, TagNames
6+
from gg_api_core.client import (
7+
DEFAULT_PAGINATION_MAX_BYTES,
8+
IncidentSeverity,
9+
IncidentStatus,
10+
IncidentValidity,
11+
TagNames,
12+
)
713
from gg_api_core.utils import get_client
814

915
logger = logging.getLogger(__name__)
@@ -74,7 +80,10 @@ class ListRepoOccurrencesBaseParams(BaseModel):
7480
ordering: str | None = Field(default=None, description="Sort field (e.g., 'date', '-date' for descending)")
7581
per_page: int = Field(default=20, description="Number of results per page (default: 20, min: 1, max: 100)")
7682
cursor: str | None = Field(default=None, description="Pagination cursor for fetching next page of results")
77-
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
83+
get_all: bool = Field(
84+
default=False,
85+
description=f"If True, fetch all pages (capped at ~{DEFAULT_PAGINATION_MAX_BYTES / 1000}KB; check 'has_more' and use cursor to continue)",
86+
)
7887

7988

8089
class ListRepoOccurrencesParams(ListRepoOccurrencesFilters, ListRepoOccurrencesBaseParams):
@@ -88,7 +97,7 @@ class ListRepoOccurrencesResult(BaseModel):
8897
occurrences_count: int = Field(description="Number of occurrences returned")
8998
occurrences: list[dict[str, Any]] = Field(default_factory=list, description="List of occurrence objects")
9099
cursor: str | None = Field(default=None, description="Pagination cursor for next page")
91-
has_more: bool = Field(default=False, description="Whether more results are available")
100+
has_more: bool = Field(default=False, description="True if more results exist (use cursor to fetch)")
92101
applied_filters: dict[str, Any] = Field(default_factory=dict, description="Filters that were applied to the query")
93102
suggestion: str = Field(default="", description="Suggestions for interpreting or modifying the results")
94103

@@ -237,10 +246,10 @@ async def list_repo_occurrences(
237246
with_sources=False,
238247
)
239248

240-
# Extract data from ListResponse
249+
# Extract data from ListResponse or PaginatedResult
241250
occurrences_data = result["data"]
242251
next_cursor = result["cursor"]
243-
has_more = result["has_more"]
252+
has_more = result.get("has_more", False)
244253

245254
count = len(occurrences_data)
246255
return ListRepoOccurrencesResult(

packages/gg_api_core/src/gg_api_core/tools/list_users.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pydantic import BaseModel, Field
66

7+
from gg_api_core.client import DEFAULT_PAGINATION_MAX_BYTES
78
from gg_api_core.utils import get_client
89

910
logger = logging.getLogger(__name__)
@@ -27,7 +28,10 @@ class ListUsersParams(BaseModel):
2728
default=None,
2829
description="Sort results by field (created_at, -created_at, last_login, -last_login). Use '-' prefix for descending order",
2930
)
30-
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
31+
get_all: bool = Field(
32+
default=False,
33+
description=f"If True, fetch all pages (capped at ~{DEFAULT_PAGINATION_MAX_BYTES / 1000}KB; check 'has_more' and use cursor to continue)",
34+
)
3135

3236

3337
class ListUsersResult(BaseModel):
@@ -36,6 +40,7 @@ class ListUsersResult(BaseModel):
3640
members: list[dict[str, Any]] = Field(description="List of workspace member objects")
3741
total_count: int = Field(description="Total number of members returned")
3842
next_cursor: str | None = Field(default=None, description="Pagination cursor for next page (if applicable)")
43+
has_more: bool = Field(default=False, description="True if more results exist (use next_cursor to fetch)")
3944

4045

4146
async def list_users(params: ListUsersParams) -> ListUsersResult:
@@ -81,10 +86,15 @@ async def list_users(params: ListUsersParams) -> ListUsersResult:
8186
logger.debug(f"Query parameters: {json.dumps(query_params)}")
8287

8388
if params.get_all:
84-
# Use paginate_all for fetching all results
85-
members_list = await client.paginate_all("/members", query_params)
86-
logger.debug(f"Retrieved all {len(members_list)} members using pagination")
87-
return ListUsersResult(members=members_list, total_count=len(members_list), next_cursor=None)
89+
# Use paginate_all for fetching all results (capped at DEFAULT_PAGINATION_MAX_BYTES)
90+
result = await client.paginate_all("/members", query_params)
91+
logger.debug(f"Retrieved {len(result['data'])} members using pagination (has_more={result['has_more']})")
92+
return ListUsersResult(
93+
members=result["data"],
94+
total_count=len(result["data"]),
95+
next_cursor=result["cursor"],
96+
has_more=result["has_more"],
97+
)
8898
else:
8999
# Single page request
90100
result = await client.list_members(params=query_params)

packages/secops_mcp_server/src/secops_mcp_server/server.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from developer_mcp_server.add_health_check import add_health_check
77
from developer_mcp_server.register_tools import register_developer_tools
88
from fastmcp.exceptions import ToolError
9+
from gg_api_core.client import DEFAULT_PAGINATION_MAX_BYTES
910
from gg_api_core.mcp_server import get_mcp_server, register_common_tools
1011
from gg_api_core.scopes import set_secops_scopes
1112
from gg_api_core.tools.assign_incident import assign_incident
@@ -56,7 +57,10 @@ class ListIncidentsParams(BaseModel):
5657
description="Sort field and direction (prefix with '-' for descending order). If you need to get the latest incidents, use '-date'.",
5758
)
5859
per_page: int = Field(default=20, description="Number of results per page (1-100)")
59-
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
60+
get_all: bool = Field(
61+
default=False,
62+
description=f"If True, fetch all pages (capped at ~{DEFAULT_PAGINATION_MAX_BYTES / 1000}KB; check 'has_more' and use cursor to continue)",
63+
)
6064
mine: bool = Field(default=False, description="If True, fetch incidents assigned to the current user")
6165

6266

@@ -72,7 +76,10 @@ class ListHoneytokensParams(BaseModel):
7276
creator_id: str | int | None = Field(default=None, description="Filter by creator ID")
7377
creator_api_token_id: str | int | None = Field(default=None, description="Filter by creator API token ID")
7478
per_page: int = Field(default=20, description="Number of results per page (default: 20, min: 1, max: 100)")
75-
get_all: bool = Field(default=False, description="If True, fetch all results using cursor-based pagination")
79+
get_all: bool = Field(
80+
default=False,
81+
description=f"If True, fetch all pages (capped at ~{DEFAULT_PAGINATION_MAX_BYTES / 1000}KB; check 'has_more' and use cursor to continue)",
82+
)
7683
mine: bool = Field(default=False, description="If True, fetch honeytokens created by the current user")
7784

7885

0 commit comments

Comments
 (0)