diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index f3bccedf5fa..514d641cea6 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -4,12 +4,12 @@ import asyncio import dataclasses -import json +import pickle import uuid from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Callable, Coroutine from types import MappingProxyType -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, ClassVar from reflex.istate.manager.redis import StateManagerRedis from reflex.state import BaseState, StateUpdate @@ -42,7 +42,7 @@ class LostAndFoundRecord: """Record for a StateUpdate for a token with its socket on another instance.""" token: str - update: dict[str, Any] + update: StateUpdate class TokenManager(ABC): @@ -328,7 +328,7 @@ async def link_token_to_sid(self, token: str, sid: str) -> str | None: try: await self.redis.set( redis_key, - json.dumps(dataclasses.asdict(socket_record)), + pickle.dumps(socket_record), ex=self.token_expiration, ) except Exception as e: @@ -386,8 +386,8 @@ async def _subscribe_lost_and_found_updates( ) async for message in pubsub.listen(): if message["type"] == "pmessage": - record = LostAndFoundRecord(**json.loads(message["data"].decode())) - await emit_update(StateUpdate(**record.update), record.token) + record = pickle.loads(message["data"]) + await emit_update(record.update, record.token) def ensure_lost_and_found_task( self, @@ -424,10 +424,9 @@ async def _get_token_owner(self, token: str, refresh: bool = False) -> str | Non redis_key = self._get_redis_key(token) try: - record_json = await self.redis.get(redis_key) - if record_json: - record_data = json.loads(record_json) - socket_record = SocketRecord(**record_data) + record_pkl = await self.redis.get(redis_key) + if record_pkl: + socket_record = pickle.loads(record_pkl) self.token_to_socket[token] = socket_record self.sid_to_token[socket_record.sid] = token return socket_record.instance_id @@ -454,11 +453,11 @@ async def emit_lost_and_found( owner_instance_id = await self._get_token_owner(token) if owner_instance_id is None: return False - record = LostAndFoundRecord(token=token, update=dataclasses.asdict(update)) + record = LostAndFoundRecord(token=token, update=update) try: await self.redis.publish( f"channel:{self._get_lost_and_found_key(owner_instance_id)}", - json.dumps(dataclasses.asdict(record)), + pickle.dumps(record), ) except Exception as e: console.error(f"Redis error publishing lost and found delta: {e}") diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 885edae0ab1..b6dbb5e1bf5 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,5 +1,6 @@ """Test case for displaying the connection banner when the websocket drops.""" +import pickle from collections.abc import Generator import pytest @@ -10,7 +11,7 @@ from reflex.environment import environment from reflex.istate.manager.redis import StateManagerRedis from reflex.testing import AppHarness, WebDriver -from reflex.utils.token_manager import RedisTokenManager +from reflex.utils.token_manager import RedisTokenManager, SocketRecord from .utils import SessionStorage @@ -166,11 +167,10 @@ async def test_connection_banner(connection_banner: AppHarness): sid_before = app_token_manager.token_to_sid[token] if isinstance(connection_banner.state_manager, StateManagerRedis): assert isinstance(app_token_manager, RedisTokenManager) - assert ( - await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) - == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_before}"}}'.encode() + assert await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) == pickle.dumps( + SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before) ) delay_button = driver.find_element(By.ID, "delay") @@ -226,11 +226,10 @@ async def test_connection_banner(connection_banner: AppHarness): assert sid_before != sid_after if isinstance(connection_banner.state_manager, StateManagerRedis): assert isinstance(app_token_manager, RedisTokenManager) - assert ( - await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) - == f'{{"instance_id": "{app_token_manager.instance_id}", "sid": "{sid_after}"}}'.encode() + assert await connection_banner.state_manager.redis.get( + app_token_manager._get_redis_key(token) + ) == pickle.dumps( + SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after) ) # Count should have incremented after coming back up diff --git a/tests/units/utils/test_token_manager.py b/tests/units/utils/test_token_manager.py index f061574a7c9..9f740a29f37 100644 --- a/tests/units/utils/test_token_manager.py +++ b/tests/units/utils/test_token_manager.py @@ -1,7 +1,7 @@ """Unit tests for TokenManager implementations.""" import asyncio -import json +import pickle import time from collections.abc import Callable, Generator from contextlib import asynccontextmanager @@ -11,6 +11,7 @@ from reflex import config from reflex.app import EventNamespace +from reflex.istate.data import RouterData from reflex.state import StateUpdate from reflex.utils.token_manager import ( LocalTokenManager, @@ -300,7 +301,7 @@ async def test_link_token_to_sid_normal_case(self, manager, mock_redis): ) mock_redis.set.assert_called_once_with( f"token_manager_socket_record_{token}", - json.dumps({"instance_id": manager.instance_id, "sid": sid}), + pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)), ex=3600, ) assert manager.token_to_socket[token].sid == sid @@ -347,7 +348,7 @@ async def test_link_token_to_sid_duplicate_detected(self, manager, mock_redis): ) mock_redis.set.assert_called_once_with( f"token_manager_socket_record_{result}", - json.dumps({"instance_id": manager.instance_id, "sid": sid}), + pickle.dumps(SocketRecord(instance_id=manager.instance_id, sid=sid)), ex=3600, ) assert manager.token_to_sid[result] == sid @@ -670,3 +671,37 @@ async def test_redis_token_manager_lost_and_found( emit2_mock.assert_not_called() emit1_mock.assert_called_once() emit1_mock.reset_mock() + + +@pytest.mark.usefixtures("redis_url") +@pytest.mark.asyncio +async def test_redis_token_manager_lost_and_found_router_data( + event_namespace_factory: Callable[[], EventNamespace], +): + """Updates emitted for lost and found tokens should serialize properly. + + Args: + event_namespace_factory: Factory fixture for EventNamespace instances. + """ + event_namespace1 = event_namespace_factory() + emit1_mock: Mock = event_namespace1.emit # pyright: ignore[reportAssignmentType] + event_namespace2 = event_namespace_factory() + emit2_mock: Mock = event_namespace2.emit # pyright: ignore[reportAssignmentType] + + await event_namespace1.on_connect(sid="sid1", environ=query_string_for("token1")) + await event_namespace2.on_connect(sid="sid2", environ=query_string_for("token2")) + + router = RouterData.from_router_data( + {"headers": {"x-test": "value"}}, + ) + + await event_namespace2.emit_update( + StateUpdate(delta={"state": {"router": router}}), token="token1" + ) + await _wait_for_call_count_positive(emit1_mock) + emit2_mock.assert_not_called() + emit1_mock.assert_called_once() + assert isinstance(emit1_mock.call_args[0][1], StateUpdate) + assert isinstance(emit1_mock.call_args[0][1].delta["state"]["router"], RouterData) + assert emit1_mock.call_args[0][1].delta["state"]["router"] == router + emit1_mock.reset_mock()