From d17a70e82ac4c636e25afc578f62690fe5501a84 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 17 Dec 2025 15:48:10 -0800 Subject: [PATCH 1/4] ENG-8540: avoid dataclasses.asdict in Lost+Found path Use the reflex serializers registry to serialize StateUpdate objects for Lost+Found usage. --- reflex/utils/token_manager.py | 17 ++++++++---- tests/units/utils/test_token_manager.py | 37 +++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index f3bccedf5fa..8522b35bc20 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -9,11 +9,12 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Callable, Coroutine from types import MappingProxyType -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, ClassVar from reflex.istate.manager.redis import StateManagerRedis from reflex.state import BaseState, StateUpdate from reflex.utils import console, prerequisites +from reflex.utils.format import json_dumps from reflex.utils.tasks import ensure_task if TYPE_CHECKING: @@ -42,7 +43,7 @@ class LostAndFoundRecord: """Record for a StateUpdate for a token with its socket on another instance.""" token: str - update: dict[str, Any] + update: StateUpdate class TokenManager(ABC): @@ -386,8 +387,12 @@ async def _subscribe_lost_and_found_updates( ) async for message in pubsub.listen(): if message["type"] == "pmessage": - record = LostAndFoundRecord(**json.loads(message["data"].decode())) - await emit_update(StateUpdate(**record.update), record.token) + record_dict = json.loads(message["data"].decode()) + record = LostAndFoundRecord( + token=record_dict["token"], + update=StateUpdate(**record_dict["update"]), + ) + await emit_update(record.update, record.token) def ensure_lost_and_found_task( self, @@ -454,11 +459,11 @@ async def emit_lost_and_found( owner_instance_id = await self._get_token_owner(token) if owner_instance_id is None: return False - record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update)) + record = LostAndFoundRecord(token=token, update=update) try: await self.redis.publish( f"channel:{self._get_lost_and_found_key(owner_instance_id)}", - json.dumps(dataclasses.asdict(record)), + json_dumps(record), ) except Exception as e: console.error(f"Redis error publishing lost and found delta: {e}") diff --git a/tests/units/utils/test_token_manager.py b/tests/units/utils/test_token_manager.py index f061574a7c9..56feb8af44d 100644 --- a/tests/units/utils/test_token_manager.py +++ b/tests/units/utils/test_token_manager.py @@ -11,7 +11,9 @@ from reflex import config from reflex.app import EventNamespace +from reflex.istate.data import RouterData from reflex.state import StateUpdate +from reflex.utils.format import json_dumps from reflex.utils.token_manager import ( LocalTokenManager, RedisTokenManager, @@ -670,3 +672,38 @@ async def test_redis_token_manager_lost_and_found( emit2_mock.assert_not_called() emit1_mock.assert_called_once() emit1_mock.reset_mock() + + +@pytest.mark.usefixtures("redis_url") +@pytest.mark.asyncio +async def test_redis_token_manager_lost_and_found_router_data( + event_namespace_factory: Callable[[], EventNamespace], +): + """Updates emitted for lost and found tokens should serialize properly. + + Args: + event_namespace_factory: Factory fixture for EventNamespace instances. + """ + event_namespace1 = event_namespace_factory() + emit1_mock: Mock = event_namespace1.emit # pyright: ignore[reportAssignmentType] + event_namespace2 = event_namespace_factory() + emit2_mock: Mock = event_namespace2.emit # pyright: ignore[reportAssignmentType] + + await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2")) + + router = RouterData.from_router_data( + {"headers": {"x-test": "value"}}, + ) + + await event_namespace2.emit_update( + StateUpdate(delta={"state": {"router": router}}), token="token1" + ) + await _wait_for_call_count_positive(emit1_mock) + emit2_mock.assert_not_called() + emit1_mock.assert_called_once() + assert isinstance(emit1_mock.call_args[0][1], StateUpdate) + assert emit1_mock.call_args[0][1].delta["state"]["router"] == json.loads( + json_dumps(router) + ) + emit1_mock.reset_mock() From 4633e2d8c1ed9e151f25fa18afad6a004c9481f9 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 17 Dec 2025 16:04:35 -0800 Subject: [PATCH 2/4] Use pickle instead of JSON for private records --- reflex/utils/token_manager.py | 20 +++++++------------- tests/units/utils/test_token_manager.py | 6 ++---- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index 8522b35bc20..514d641cea6 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -4,7 +4,7 @@ import asyncio import dataclasses -import json +import pickle import uuid from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Callable, Coroutine @@ -14,7 +14,6 @@ from reflex.istate.manager.redis import StateManagerRedis from reflex.state import BaseState, StateUpdate from reflex.utils import console, prerequisites -from reflex.utils.format import json_dumps from reflex.utils.tasks import ensure_task if TYPE_CHECKING: @@ -329,7 +328,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None: try: await self.redis.set( redis_key, - json.dumps(dataclasses.asdict(socket_record)), + pickle.dumps(socket_record), ex=self.token_expiration, ) except Exception as e: @@ -387,11 +386,7 @@ async def _subscribe_lost_and_found_updates( ) async for message in pubsub.listen(): if message["type"] == "pmessage": - record_dict = json.loads(message["data"].decode()) - record = LostAndFoundRecord( - token=record_dict["token"], - update=StateUpdate(**record_dict["update"]), - ) + record = pickle.loads(message["data"]) await emit_update(record.update, record.token) def ensure_lost_and_found_task( @@ -429,10 +424,9 @@ async def _get_token_owner(self, token: str, refresh: bool = False) -> str | Non redis_key = self._get_redis_key(token) try: - record_json = await self.redis.get(redis_key) - if record_json: - record_data = json.loads(record_json) - socket_record = SocketRecord(**record_data) + record_pkl = await self.redis.get(redis_key) + if record_pkl: + socket_record = pickle.loads(record_pkl) self.token_to_socket[token] = socket_record self.sid_to_token[socket_record.sid] = token return socket_record.instance_id @@ -463,7 +457,7 @@ async def emit_lost_and_found( try: await self.redis.publish( f"channel:{self._get_lost_and_found_key(owner_instance_id)}", - json_dumps(record), + pickle.dumps(record), ) except Exception as e: console.error(f"Redis error publishing lost and found delta: {e}") diff --git a/tests/units/utils/test_token_manager.py b/tests/units/utils/test_token_manager.py index 56feb8af44d..7a4c9c9fb08 100644 --- a/tests/units/utils/test_token_manager.py +++ b/tests/units/utils/test_token_manager.py @@ -13,7 +13,6 @@ from reflex.app import EventNamespace from reflex.istate.data import RouterData from reflex.state import StateUpdate -from reflex.utils.format import json_dumps from reflex.utils.token_manager import ( LocalTokenManager, RedisTokenManager, @@ -703,7 +702,6 @@ async def test_redis_token_manager_lost_and_found_router_data( emit2_mock.assert_not_called() emit1_mock.assert_called_once() assert isinstance(emit1_mock.call_args[0][1], StateUpdate) - assert emit1_mock.call_args[0][1].delta["state"]["router"] == json.loads( - json_dumps(router) - ) + assert isinstance(emit1_mock.call_args[0][1].delta["state"]["router"], RouterData) + assert emit1_mock.call_args[0][1].delta["state"]["router"] == router emit1_mock.reset_mock() From 8ad27dbd836ed6ae24864c354dacb4516bdd8a07 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 17 Dec 2025 16:12:48 -0800 Subject: [PATCH 3/4] oopsie --- tests/units/utils/test_token_manager.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/units/utils/test_token_manager.py b/tests/units/utils/test_token_manager.py index 7a4c9c9fb08..9f740a29f37 100644 --- a/tests/units/utils/test_token_manager.py +++ b/tests/units/utils/test_token_manager.py @@ -1,7 +1,7 @@ """Unit tests for TokenManager implementations.""" import asyncio -import json +import pickle import time from collections.abc import Callable, Generator from contextlib import asynccontextmanager @@ -301,7 +301,7 @@ async def test_link_token_to_sid_normal_case(self, manager, mock_redis): ) mock_redis.set.assert_called_once_with( f"token_manager_socket_record_{token}", - json.dumps({"instance_id": manager.instance_id, "sid": sid}), + pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)), ex=3600, ) assert manager.token_to_socket[token].sid == sid @@ -348,7 +348,7 @@ async def test_link_token_to_sid_duplicate_detected(self, manager, mock_redis): ) mock_redis.set.assert_called_once_with( f"token_manager_socket_record_{result}", - json.dumps({"instance_id": manager.instance_id, "sid": sid}), + pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)), ex=3600, ) assert manager.token_to_sid[result] == sid From 3ebdf3e0a2e94fb166484de6d482dd3fcc2ac18c Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 18 Dec 2025 10:28:39 -0800 Subject: [PATCH 4/4] Fix pickle test expectation for test_connection_banner --- tests/integration/test_connection_banner.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 885edae0ab1..b6dbb5e1bf5 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,5 +1,6 @@ """Test case for displaying the connection banner when the websocket drops.""" +import pickle from collections.abc import Generator import pytest @@ -10,7 +11,7 @@ from reflex.environment import environment from reflex.istate.manager.redis import StateManagerRedis from reflex.testing import AppHarness, WebDriver -from reflex.utils.token_manager import RedisTokenManager +from reflex.utils.token_manager import RedisTokenManager, SocketRecord from .utils import SessionStorage @@ -166,11 +167,10 @@ async def test_connection_banner(connection_banner: AppHarness): sid_before = app_token_manager.token_to_sid[token] if isinstance(connection_banner.state_manager, StateManagerRedis): assert isinstance(app_token_manager, RedisTokenManager) - assert ( - await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) - == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_before}"}}'.encode() + assert await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) == pickle.dumps( + SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before) ) delay_button = driver.find_element(By.ID, "delay") @@ -226,11 +226,10 @@ async def test_connection_banner(connection_banner: AppHarness): assert sid_before != sid_after if isinstance(connection_banner.state_manager, StateManagerRedis): assert isinstance(app_token_manager, RedisTokenManager) - assert ( - await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) - == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_after}"}}'.encode() + assert await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) == pickle.dumps( + SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after) ) # Count should have incremented after coming back up