diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index fdbef3afb71..ad97d59162a 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -758,7 +758,7 @@ def _mark_dirty( Raises: ImmutableStateError: if the StateProxy is not mutable. """ - if not self._self_state._is_mutable(): + if not self._self_state._is_mutable(): # pyright: ignore[reportAttributeAccessIssue] msg = ( "Background task StateProxy is immutable outside of a context " "manager. Use `async with self` to modify state." diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 1ddc893e674..01bf154d913 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -190,6 +190,9 @@ async def _link_to(self, token: str) -> Self: if not token: msg = "Cannot link shared state to empty token." raise ReflexRuntimeError(msg) + if not isinstance(self, SharedState): + msg = "Can only link SharedState instances." + raise RuntimeError(msg) if self._linked_to == token: return self # already linked to this token if self._linked_to and self._linked_to != token: @@ -215,6 +218,10 @@ async def _unlink(self): """ from reflex.istate.manager import get_state_manager + if not isinstance(self, SharedState): + msg = "Can only unlink SharedState instances." + raise ReflexRuntimeError(msg) + state_name = self.get_full_name() if ( not self._reflex_internal_links @@ -272,6 +279,9 @@ async def _internal_patch_linked_state( _substate_key(token, type(self)) ) linked_state = await linked_root_state.get_state(type(self)) + if not isinstance(linked_state, SharedState): + msg = f"Linked state for token {token} is not a SharedState." + raise ReflexRuntimeError(msg) # Avoid unnecessary dirtiness of shared state when there are no changes. if type(self) not in self._held_locks[token]: self._held_locks[token][type(self)] = linked_state diff --git a/reflex/state.py b/reflex/state.py index 182addcad59..8ffe0abf3b7 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1350,7 +1350,7 @@ def _check_overwritten_dynamic_args(cls, args: list[str]): for substate in cls.get_substates(): substate._check_overwritten_dynamic_args(args) - def __getattribute__(self, name: str) -> Any: + def _get_attribute(self, name: str) -> Any: """Get the state var. If the var is inherited, get the var from the parent state. @@ -1408,6 +1408,9 @@ def __getattribute__(self, name: str) -> Any: return value + if not TYPE_CHECKING: + __getattribute__ = _get_attribute + def __setattr__(self, name: str, value: Any): """Set the attribute. diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 88d71804b09..481e0c62d4f 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -669,6 +669,7 @@ def set_sub_sub(var: str, value: str): # Ensure the state is gone (not hydrated) async def poll_for_not_hydrated(): state = await client_side.get_state(_substate_key(token or "", state_name)) + assert isinstance(state, State) return not state.is_hydrated assert await AppHarness._poll_for_async(poll_for_not_hydrated) @@ -723,30 +724,30 @@ async def get_sub_state(): async def poll_for_c1_set(): sub_state = await get_sub_state() - return sub_state.c1 == "c1 post expire" + return sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue] assert await AppHarness._poll_for_async(poll_for_c1_set) sub_state = await get_sub_state() - assert sub_state.c1 == "c1 post expire" - assert sub_state.c2 == "c2 value" - assert sub_state.c3 == "" - assert sub_state.c4 == "c4 value" - assert sub_state.c5 == "c5 value" - assert sub_state.c6 == "c6 value" - assert sub_state.c7 == "c7 value" - assert sub_state.l1 == "l1 value" - assert sub_state.l2 == "l2 value" - assert sub_state.l3 == "l3 value" - assert sub_state.l4 == "l4 value" - assert sub_state.s1 == "s1 value" - assert sub_state.s2 == "s2 value" - assert sub_state.s3 == "s3 value" + assert sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.c2 == "c2 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.c3 == "" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.c4 == "c4 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.c5 == "c5 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.c6 == "c6 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.c7 == "c7 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.l1 == "l1 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.l2 == "l2 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.l3 == "l3 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.l4 == "l4 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.s1 == "s1 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.s2 == "s2 value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_state.s3 == "s3 value" # pyright: ignore[reportAttributeAccessIssue] sub_sub_state = sub_state.substates[ client_side.get_state_name("_client_side_sub_sub_state") ] - assert sub_sub_state.c1s == "c1s value" - assert sub_sub_state.l1s == "l1s value" - assert sub_sub_state.s1s == "s1s value" + assert sub_sub_state.c1s == "c1s value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_sub_state.l1s == "l1s value" # pyright: ignore[reportAttributeAccessIssue] + assert sub_sub_state.s1s == "s1s value" # pyright: ignore[reportAttributeAccessIssue] # clear the cookie jar and local storage, ensure state reset to default driver.delete_all_cookies() diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index d1078b8eb07..0230fa129e4 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -167,8 +167,8 @@ async def test_component_state_app(component_state_app: AppHarness): a_state = root_state.substates[a_state_name] b_state = root_state.substates[b_state_name] assert a_state._backend_vars != a_state.backend_vars - assert a_state._be == a_state._backend_vars["_be"] == 3 - assert b_state._be is None + assert a_state._be == a_state._backend_vars["_be"] == 3 # pyright: ignore[reportAttributeAccessIssue] + assert b_state._be is None # pyright: ignore[reportAttributeAccessIssue] assert b_state._backend_vars["_be"] is None assert count_b.text == "0" @@ -183,7 +183,7 @@ async def test_component_state_app(component_state_app: AppHarness): a_state = root_state.substates[a_state_name] b_state = root_state.substates[b_state_name] assert b_state._backend_vars != b_state.backend_vars - assert b_state._be == b_state._backend_vars["_be"] == 2 + assert b_state._be == b_state._backend_vars["_be"] == 2 # pyright: ignore[reportAttributeAccessIssue] # Check locally-defined substate style count_c = driver.find_element(By.ID, "count-c") diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index 3b929ab149c..f4fb7a8d5f2 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -203,8 +203,8 @@ async def test_computed_vars( token = f"{token}_{full_state_name}" state = (await computed_vars.get_state(token)).substates[state_name] assert state is not None - assert state.count1_backend == 0 - assert state._count1_backend == 0 + assert state.count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue] + assert state._count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue] # test that backend var is not rendered count1_backend = driver.find_element(By.ID, "count1_backend") @@ -259,9 +259,9 @@ async def test_computed_vars( ) state = (await computed_vars.get_state(token)).substates[state_name] assert state is not None - assert state.count1_backend == 1 + assert state.count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue] assert count1_backend.text == "" - assert state._count1_backend == 1 + assert state._count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue] assert count1_backend_.text == "" mark_dirty.click() diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py index 31858f86ecb..2f573f192c5 100644 --- a/tests/integration/test_dynamic_routes.py +++ b/tests/integration/test_dynamic_routes.py @@ -23,7 +23,7 @@ class DynamicState(rx.State): @rx.event def on_load(self): - page_data = f"{self.router.page.path}-{self.page_id or 'no page id'}" + page_data = f"{self.router.page.path}-{self.page_id or 'no page id'}" # pyright: ignore[reportAttributeAccessIssue] print(f"on_load: {page_data}") self.order.append(page_data) @@ -43,7 +43,7 @@ def on_load_static(self): @rx.var def next_page(self) -> str: try: - return str(int(self.page_id) + 1) + return str(int(self.page_id) + 1) # pyright: ignore[reportAttributeAccessIssue] except ValueError: return "0" @@ -81,7 +81,7 @@ class ArgState(rx.State): @rx.var(cache=False) def arg(self) -> int: - return int(self.arg_str or 0) + return int(self.arg_str or 0) # pyright: ignore[reportAttributeAccessIssue] class ArgSubState(ArgState): @rx.var @@ -90,7 +90,7 @@ def cached_arg(self) -> int: @rx.var def cached_arg_str(self) -> str: - return self.arg_str + return self.arg_str # pyright: ignore[reportAttributeAccessIssue] @rx.page(route="/arg/[arg_str]") def arg() -> rx.Component: @@ -238,11 +238,11 @@ async def _backend_state(): async def _check(): return (await _backend_state()).substates[ dynamic_state_name - ].order == exp_order + ].order == exp_order # pyright: ignore[reportAttributeAccessIssue] await AppHarness._poll_for_async(_check, timeout=10) assert ( - list((await _backend_state()).substates[dynamic_state_name].order) + list((await _backend_state()).substates[dynamic_state_name].order) # pyright: ignore[reportAttributeAccessIssue] == exp_order ) diff --git a/tests/integration/test_event_actions.py b/tests/integration/test_event_actions.py index eb8470ea3fb..e8f73587945 100644 --- a/tests/integration/test_event_actions.py +++ b/tests/integration/test_event_actions.py @@ -265,10 +265,10 @@ def poll_for_order( async def _poll_for_order(exp_order: list[str]): async def _check(): - return (await _backend_state(event_action, token)).order == exp_order + return (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue] await AppHarness._poll_for_async(_check) - assert (await _backend_state(event_action, token)).order == exp_order + assert (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue] return _poll_for_order @@ -358,12 +358,12 @@ async def test_event_actions_throttle_debounce( # Wait until the debounce event shows up async def _debounce_received(): state = await _backend_state(event_action, token) - return state.order and state.order[-1] == "on_click_debounce" + return state.order and state.order[-1] == "on_click_debounce" # pyright: ignore[reportAttributeAccessIssue] await AppHarness._poll_for_async(_debounce_received) # This test is inherently racy, so ensure the `on_click_throttle` event is fired approximately the expected number of times. - final_event_order = (await _backend_state(event_action, token)).order + final_event_order = (await _backend_state(event_action, token)).order # pyright: ignore[reportAttributeAccessIssue] n_on_click_throttle_received = final_event_order.count("on_click_throttle") print( f"Expected ~{exp_events} on_click_throttle events, received {n_on_click_throttle_received}" diff --git a/tests/integration/test_event_chain.py b/tests/integration/test_event_chain.py index cb48df0f1d5..289f12cbe80 100644 --- a/tests/integration/test_event_chain.py +++ b/tests/integration/test_event_chain.py @@ -462,11 +462,11 @@ async def test_event_chain_click( async def _has_all_events(): return len( - (await event_chain.get_state(token)).substates[state_name].event_order + (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates[state_name].event_order + event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] assert event_order == exp_event_order @@ -515,13 +515,13 @@ async def test_event_chain_on_load( async def _has_all_events(): return len( - (await event_chain.get_state(token)).substates[state_name].event_order + (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) backend_state = (await event_chain.get_state(token)).substates[state_name] - assert backend_state.event_order == exp_event_order - assert backend_state.is_hydrated is True + assert backend_state.event_order == exp_event_order # pyright: ignore[reportAttributeAccessIssue] + assert backend_state.is_hydrated is True # pyright: ignore[reportAttributeAccessIssue] @pytest.mark.parametrize( @@ -582,11 +582,11 @@ async def test_event_chain_on_mount( async def _has_all_events(): return len( - (await event_chain.get_state(token)).substates[state_name].event_order + (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] ) == len(exp_event_order) await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates[state_name].event_order + event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] assert list(event_order) == exp_event_order diff --git a/tests/integration/test_form_submit.py b/tests/integration/test_form_submit.py index f3cc77333ae..a064c1842fb 100644 --- a/tests/integration/test_form_submit.py +++ b/tests/integration/test_form_submit.py @@ -232,7 +232,7 @@ async def get_form_data(): return ( (await form_submit.get_state(f"{token}_{full_state_name}")) .substates[state_name] - .form_data + .form_data # pyright: ignore[reportAttributeAccessIssue] ) # wait for the form data to arrive at the backend diff --git a/tests/integration/test_input.py b/tests/integration/test_input.py index 2684db8f93e..125f5259154 100644 --- a/tests/integration/test_input.py +++ b/tests/integration/test_input.py @@ -96,7 +96,7 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): async def get_state_text(): state = await fully_controlled_input.get_state(f"{token}_{full_state_name}") - return state.substates[state_name].text + return state.substates[state_name].text # pyright: ignore[reportAttributeAccessIssue] # ensure defaults are set correctly assert ( diff --git a/tests/integration/test_linked_state.py b/tests/integration/test_linked_state.py index c892c48d7aa..4f3d5b14453 100644 --- a/tests/integration/test_linked_state.py +++ b/tests/integration/test_linked_state.py @@ -50,9 +50,9 @@ async def unlink(self): @rx.event async def on_load_link_default(self): - linked_state = await self._link_to(self.room or "default") - if self.room: - assert linked_state._linked_to == self.room + linked_state = await self._link_to(self.room or "default") # pyright: ignore[reportAttributeAccessIssue] + if self.room: # pyright: ignore[reportAttributeAccessIssue] + assert linked_state._linked_to == self.room # pyright: ignore[reportAttributeAccessIssue] else: assert linked_state._linked_to == "default" diff --git a/tests/integration/test_login_flow.py b/tests/integration/test_login_flow.py index 5765d80d8c1..e06008ecfee 100644 --- a/tests/integration/test_login_flow.py +++ b/tests/integration/test_login_flow.py @@ -22,6 +22,10 @@ def LoginSample(): class State(rx.State): auth_token: str = rx.LocalStorage("") + @rx.event + def set_auth_token(self, token: str): + self.auth_token = token + @rx.event def logout(self): self.set_auth_token("") diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index 9bfafe16e73..ed4a5456cd1 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -295,15 +295,15 @@ async def test_upload_file( state = await upload_file.get_state(substate_token) # only the secondary form tracks progress and chain events - assert state.substates[state_name].event_order.count("upload_progress") == 1 - assert state.substates[state_name].event_order.count("chain_event") == 1 + assert state.substates[state_name].event_order.count("upload_progress") == 1 # pyright: ignore[reportAttributeAccessIssue] + assert state.substates[state_name].event_order.count("chain_event") == 1 # pyright: ignore[reportAttributeAccessIssue] # look up the backend state and assert on uploaded contents async def get_file_data(): return ( (await upload_file.get_state(substate_token)) .substates[state_name] - ._file_data + ._file_data # pyright: ignore[reportAttributeAccessIssue] ) file_data = await AppHarness._poll_for_async(get_file_data) @@ -358,7 +358,7 @@ async def get_file_data(): return ( (await upload_file.get_state(substate_token)) .substates[state_name] - ._file_data + ._file_data # pyright: ignore[reportAttributeAccessIssue] ) file_data = await AppHarness._poll_for_async(get_file_data) @@ -469,7 +469,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # Get interim progress dicts saved in the on_upload_progress handler. async def _progress_dicts(): state = await upload_file.get_state(substate_token) - return state.substates[state_name].progress_dicts + return state.substates[state_name].progress_dicts # pyright: ignore[reportAttributeAccessIssue] # We should have _some_ progress assert await AppHarness._poll_for_async(_progress_dicts) @@ -479,7 +479,7 @@ async def _progress_dicts(): assert p["progress"] != 1 state = await upload_file.get_state(substate_token) - file_data = state.substates[state_name]._file_data + file_data = state.substates[state_name]._file_data # pyright: ignore[reportAttributeAccessIssue] assert isinstance(file_data, dict) normalized_file_data = {Path(k).name: v for k, v in file_data.items()} assert Path(exp_name).name not in normalized_file_data @@ -575,11 +575,11 @@ async def test_on_drop( async def exp_name_in_quaternary(): state = await upload_file.get_state(substate_token) - return exp_name in state.substates[state_name].quaternary_names + return exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue] # Poll until the file names appear in the display await AppHarness._poll_for_async(exp_name_in_quaternary) # Verify through state that the file names were captured correctly state = await upload_file.get_state(substate_token) - assert exp_name in state.substates[state_name].quaternary_names + assert exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue] diff --git a/tests/units/components/core/test_cond.py b/tests/units/components/core/test_cond.py index ec651d00671..0e1df51d067 100644 --- a/tests/units/components/core/test_cond.py +++ b/tests/units/components/core/test_cond.py @@ -42,7 +42,7 @@ def test_validate_cond(cond_state: BaseState): cond_state: A fixture. """ cond_component = cond( - cond_state.value, + cond_state.value, # pyright: ignore[reportAttributeAccessIssue] Text.create("cond is True"), Text.create("cond is False"), ) @@ -50,7 +50,7 @@ def test_validate_cond(cond_state: BaseState): assert cond_dict["name"] == "Fragment" [condition] = cond_dict["children"] - assert condition["cond_state"] == str(cond_state.value.bool()) + assert condition["cond_state"] == str(cond_state.value.bool()) # pyright: ignore[reportAttributeAccessIssue] # true value true_value = condition["true_value"] diff --git a/tests/units/components/datadisplay/test_datatable.py b/tests/units/components/datadisplay/test_datatable.py index 968fd58cb83..8142870e465 100644 --- a/tests/units/components/datadisplay/test_datatable.py +++ b/tests/units/components/datadisplay/test_datatable.py @@ -34,12 +34,13 @@ def test_validate_data_table(data_table_state: rx.State, expected): expected: expected var name. """ - if not types.is_dataframe(data_table_state.data._var_type): + if not types.is_dataframe(data_table_state.data._var_type): # pyright: ignore[reportAttributeAccessIssue] data_table_component = DataTable.create( - data=data_table_state.data, columns=data_table_state.columns + data=data_table_state.data, # pyright: ignore[reportAttributeAccessIssue] + columns=data_table_state.columns, # pyright: ignore[reportAttributeAccessIssue] ) else: - data_table_component = DataTable.create(data=data_table_state.data) + data_table_component = DataTable.create(data=data_table_state.data) # pyright: ignore[reportAttributeAccessIssue] data_table_dict = data_table_component.render() diff --git a/tests/units/components/test_component.py b/tests/units/components/test_component.py index 060cc46f5fe..7ae569f5540 100644 --- a/tests/units/components/test_component.py +++ b/tests/units/components/test_component.py @@ -41,33 +41,44 @@ from reflex.vars.object import ObjectVar -@pytest.fixture -def test_state(): - class TestState(BaseState): - num: int +class TestState(BaseState): + """A test state with various methods for event handling.""" - def do_something(self): - pass + num: int - def do_something_arg(self, arg): - pass + @rx.event + def do_something(self): + """A method with no arguments.""" - def do_something_with_bool(self, arg: bool): - pass + @rx.event + def do_something_arg(self, arg): + """A method with one unspecfied argument.""" - def do_something_with_int(self, arg: int): - pass + @rx.event + def do_something_with_bool(self, arg: bool): + """A method with a boolean argument.""" - def do_something_with_list_int(self, arg: list[int]): - pass + @rx.event + def do_something_with_int(self, arg: int): + """A method with an integer argument.""" - def do_something_with_list_str(self, arg: list[str]): - pass + @rx.event + def do_something_with_list_int(self, arg: list[int]): + """A method with a list of integers argument.""" - def do_something_required_optional( - self, required_arg: int, optional_arg: int | None = None - ): - pass + @rx.event + def do_something_with_list_str(self, arg: list[str]): + """A method with a list of strings argument.""" + + @rx.event + def do_something_required_optional( + self, required_arg: int, optional_arg: int | None = None + ): + """A method with one required and one optional argument.""" + + +@pytest.fixture +def test_state(): return TestState @@ -577,7 +588,7 @@ def test_invalid_prop_type(component1, text: str, number: int): component1.create(text=text, number=number) -def test_var_props(component1, test_state): +def test_var_props(component1, test_state: type[TestState]): """Test that we can set a Var prop. Args: @@ -863,7 +874,7 @@ def my_component(width: Var[int], color: Var[str]): assert isinstance(component, Box) -def test_invalid_event_handler_args(component2, test_state): +def test_invalid_event_handler_args(component2, test_state: type[TestState]): """Test that an invalid event handler raises an error. Args: @@ -943,7 +954,7 @@ def test_invalid_event_handler_args(component2, test_state): ) -def test_valid_event_handler_args(component2, test_state): +def test_valid_event_handler_args(component2, test_state: type[TestState]): """Test that an valid event handler args do not raise exception. Args: @@ -1144,7 +1155,7 @@ def test_format_component(component, rendered): assert str(component) == rendered -def test_stateful_component(test_state): +def test_stateful_component(test_state: type[TestState]): """Test that a stateful component is created correctly. Args: @@ -1162,7 +1173,7 @@ def test_stateful_component(test_state): assert sc2.references == 2 -def test_stateful_component_memoize_event_trigger(test_state): +def test_stateful_component_memoize_event_trigger(test_state: type[TestState]): """Test that a stateful component is created correctly with events. Args: @@ -2100,7 +2111,7 @@ def add_hooks(self): assert ImportVar(tag="useEffect") in imports["react"] -def test_add_style_embedded_vars(test_state: BaseState): +def test_add_style_embedded_vars(test_state: type[TestState]): """Test that add_style works with embedded vars when returning a plain dict. Args: diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index 35450b122c8..076268ca9c6 100644 --- a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -15,18 +15,22 @@ from tests.units.mock_redis import mock_redis, real_redis +class RedisTestState(BaseState): + """A test state for redis state manager tests.""" + + foo: str = "bar" + count: int = 0 + + @pytest.fixture -def root_state() -> type[BaseState]: - class RedisTestState(BaseState): - foo: str = "bar" - count: int = 0 +def root_state() -> type[RedisTestState]: return RedisTestState @pytest_asyncio.fixture(loop_scope="function", scope="function") async def state_manager_redis( - root_state: type[BaseState], + root_state: type[RedisTestState], ) -> AsyncGenerator[StateManagerRedis]: """Get a StateManagerRedis with a real or mocked redis client. @@ -64,7 +68,7 @@ def event_log(state_manager_redis: StateManagerRedis) -> list[dict[str, Any]]: @pytest.mark.asyncio async def test_basic_get_set( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], ): """Test basic operations of StateManagerRedis. @@ -84,7 +88,7 @@ async def test_basic_get_set( async def test_modify( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], ): """Test modifying state with StateManagerRedis. @@ -106,16 +110,18 @@ async def test_modify( async with state_manager_redis.modify_state( _substate_key(token, root_state) ) as new_state: + assert isinstance(new_state, root_state) assert new_state.count == 1 new_state.count += 2 final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + assert isinstance(final_state, root_state) assert final_state.count == 3 async def test_modify_oplock( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], event_log: list[dict[str, Any]], ): """Test modifying state with StateManagerRedis with optimistic locking. @@ -244,7 +250,7 @@ async def test_modify_oplock( async def test_oplock_contention_queue( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], event_log: list[dict[str, Any]], ): """Test the oplock contention queue. @@ -275,6 +281,7 @@ async def modify_1(): async with state_manager_redis.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 modify_started.set() await modify_1_continue.wait() @@ -285,6 +292,7 @@ async def modify_2(): async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 await modify_2_continue.wait() @@ -294,6 +302,7 @@ async def modify_3(): async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 await modify_2_continue.wait() @@ -316,11 +325,13 @@ async def modify_3(): interim_state = await state_manager_redis.get_state( _substate_key(token, root_state) ) + assert isinstance(interim_state, root_state) assert interim_state.count == 1 await state_manager_2.close() final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + assert isinstance(final_state, root_state) assert final_state.count == 3 # There should only be two lock acquisitions @@ -334,7 +345,7 @@ async def modify_3(): async def test_oplock_contention_no_lease( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], event_log: list[dict[str, Any]], ): """Test the oplock contention queue, when no waiters can share. @@ -371,6 +382,7 @@ async def modify_1(): async with state_manager_redis.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 modify_started.set() await modify_1_continue.wait() @@ -381,6 +393,7 @@ async def modify_2(): async with state_manager_2.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 await modify_2_continue.wait() @@ -390,6 +403,7 @@ async def modify_3(): async with state_manager_3.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 await modify_2_continue.wait() @@ -424,6 +438,7 @@ async def modify_3(): await state_manager_3.close() final_state = await state_manager_2.get_state(_substate_key(token, root_state)) + assert isinstance(final_state, root_state) assert final_state.count == 3 # There should be three lock acquisitions @@ -439,7 +454,7 @@ async def modify_3(): @pytest.mark.asyncio async def test_oplock_contention_racers( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], racer_delay: float | None, ): """Test the oplock contention queue with racers. @@ -468,6 +483,7 @@ async def modify_1(): _substate_key(token, root_state), ) as new_state: lease_1 = await state_manager_redis._get_local_lease(token) + assert isinstance(new_state, root_state) new_state.count += 1 async def modify_2(): @@ -478,6 +494,7 @@ async def modify_2(): _substate_key(token, root_state), ) as new_state: lease_2 = await state_manager_2._get_local_lease(token) + assert isinstance(new_state, root_state) new_state.count += 1 await asyncio.gather( @@ -500,7 +517,7 @@ async def modify_2(): @pytest.mark.asyncio async def test_oplock_immediate_cancel( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], event_log: list[dict[str, Any]], ): """Test that immediate cancellation of modify releases oplock. @@ -526,6 +543,7 @@ async def canceller(): _substate_key(token, root_state), ) as new_state: assert await state_manager_redis._get_local_lease(token) is None + assert isinstance(new_state, root_state) new_state.count += 1 await task @@ -534,7 +552,7 @@ async def canceller(): @pytest.mark.asyncio async def test_oplock_fetch_substate( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], event_log: list[dict[str, Any]], ): """Test fetching substate with oplock enabled and partial state is cached. @@ -607,7 +625,7 @@ def short_lock_expiration( @pytest.mark.asyncio async def test_oplock_hold_oplock_after_cancel( state_manager_redis: StateManagerRedis, - root_state: type[BaseState], + root_state: type[RedisTestState], event_log: list[dict[str, Any]], short_lock_expiration: int, ): @@ -633,6 +651,7 @@ async def modify(): _substate_key(token, root_state), ) as new_state: modify_started.set() + assert isinstance(new_state, root_state) new_state.count += 1 await modify_continue.wait() modify_ended.set() @@ -667,6 +686,7 @@ async def modify(): async with state_manager_redis.modify_state( _substate_key(token, root_state), ) as new_state: + assert isinstance(new_state, root_state) new_state.count += 1 # There should have been two redis lock acquisitions. @@ -681,4 +701,5 @@ async def modify(): # Both increments should be present. final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + assert isinstance(final_state, root_state) assert final_state.count == 2 diff --git a/tests/units/test_app.py b/tests/units/test_app.py index f7300cba907..6efd006f1fa 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -456,7 +456,7 @@ async def test_initialize_with_state(test_state: type[ATestState], token: str): @pytest.mark.asyncio -async def test_set_and_get_state(test_state): +async def test_set_and_get_state(test_state: type[ATestState]): """Test setting and getting the state of an app with different tokens. Args: @@ -471,6 +471,8 @@ async def test_set_and_get_state(test_state): # Get the default state for each token. state1 = await app.state_manager.get_state(token1) state2 = await app.state_manager.get_state(token2) + assert isinstance(state1, test_state) + assert isinstance(state2, test_state) assert state1.var == 0 assert state2.var == 0 @@ -483,6 +485,8 @@ async def test_set_and_get_state(test_state): # Get the states again and check the values. state1 = await app.state_manager.get_state(token1) state2 = await app.state_manager.get_state(token2) + assert isinstance(state1, test_state) + assert isinstance(state2, test_state) assert state1.var == 1 assert state2.var == 2 @@ -1132,7 +1136,7 @@ def comp_dynamic(self) -> str: Returns: same as self.dynamic """ - return self.dynamic + return self.dynamic # pyright: ignore[reportAttributeAccessIssue] on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess] @@ -1220,7 +1224,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( client_ip = "127.0.0.1" async with app.state_manager.modify_state(substate_token) as state: state.router_data = {"simulate": "hydrated"} - assert state.dynamic == "" + assert state.dynamic == "" # pyright: ignore[reportAttributeAccessIssue] exp_vals = ["foo", "foobar", "baz"] def _event(name, val, **kwargs): @@ -1294,7 +1298,7 @@ def _dynamic_state_event(name, val, **kwargs): if isinstance(app.state_manager, StateManagerRedis): # When redis is used, the state is not updated until the processing is complete state = await app.state_manager.get_state(substate_token) - assert state.dynamic == prev_exp_val + assert state.dynamic == prev_exp_val # pyright: ignore[reportAttributeAccessIssue] # complete the processing with pytest.raises(StopAsyncIteration): @@ -1305,7 +1309,7 @@ def _dynamic_state_event(name, val, **kwargs): # check that router data was written to the state_manager store state = await app.state_manager.get_state(substate_token) - assert state.dynamic == exp_val + assert state.dynamic == exp_val # pyright: ignore[reportAttributeAccessIssue] process_coro = process( app, @@ -1374,6 +1378,7 @@ def _dynamic_state_event(name, val, **kwargs): if environment.REFLEX_OPLOCK_ENABLED.get(): await app.state_manager.close() state = await app.state_manager.get_state(substate_token) + assert isinstance(state, DynamicState) assert state.loaded == len(exp_vals) assert state.counter == len(exp_vals) @@ -1416,7 +1421,9 @@ async def test_process_events(mocker: MockerFixture, token: str): if environment.REFLEX_OPLOCK_ENABLED.get(): await app.state_manager.close() - assert (await app.state_manager.get_state(event.substate_token)).value == 5 + gen_state = await app.state_manager.get_state(event.substate_token) + assert isinstance(gen_state, GenState) + assert gen_state.value == 5 assert app._postprocess.call_count == 6 # pyright: ignore [reportAttributeAccessIssue] await app.state_manager.close() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 573024d56f7..8c733ee2e85 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -985,21 +985,20 @@ class DynamicState(BaseState): # Existing instances get the BaseVar assert ds1.dynamic_int.equals(DynamicState.dynamic_int) # pyright: ignore [reportAttributeAccessIssue] # New instances get an actual value with the default - assert DynamicState().dynamic_int == 42 + assert DynamicState().dynamic_int == 42 # pyright: ignore[reportAttributeAccessIssue] ds1.add_var("dynamic_list", list[int], [5, 10]) assert ds1.dynamic_list.equals(DynamicState.dynamic_list) # pyright: ignore [reportAttributeAccessIssue] ds2 = DynamicState() - assert ds2.dynamic_list == [5, 10] - ds2.dynamic_list.append(15) - assert ds2.dynamic_list == [5, 10, 15] - assert DynamicState().dynamic_list == [5, 10] + assert ds2.dynamic_list == [5, 10] # pyright: ignore[reportAttributeAccessIssue] + ds2.dynamic_list.append(15) # pyright: ignore[reportAttributeAccessIssue] + assert ds2.dynamic_list == [5, 10, 15] # pyright: ignore[reportAttributeAccessIssue] + assert DynamicState().dynamic_list == [5, 10] # pyright: ignore[reportAttributeAccessIssue] ds1.add_var("dynamic_dict", dict[str, int], {"k1": 5, "k2": 10}) assert ds1.dynamic_dict.equals(DynamicState.dynamic_dict) # pyright: ignore [reportAttributeAccessIssue] assert ds2.dynamic_dict.equals(DynamicState.dynamic_dict) # pyright: ignore [reportAttributeAccessIssue] - assert DynamicState().dynamic_dict == {"k1": 5, "k2": 10} - assert DynamicState().dynamic_dict == {"k1": 5, "k2": 10} + assert DynamicState().dynamic_dict == {"k1": 5, "k2": 10} # pyright: ignore[reportAttributeAccessIssue] def test_add_var_default_handlers(test_state): @@ -1192,6 +1191,7 @@ def rendered_var(self) -> int: ms = MainState() cs = ms.substates[ChildState.get_name()] assert ms.v == 2 + assert isinstance(cs, ChildState) assert cs.v == 2 assert cs.rendered_var == 2 @@ -1270,11 +1270,15 @@ def set_v4(self, v: int): assert ms.v == 1 # ensure handler can be called from substate - ms.substates[SubState.get_name()].set_v3(2) + sub_state = ms.substates[SubState.get_name()] + assert isinstance(sub_state, SubState) + sub_state.set_v3(2) assert ms.v == 2 # ensure handler can be called from substate (referencing grandparent handler) - ms.get_substate(tuple(SubSubState.get_full_name().split("."))).set_v4(3) + sub_sub_state = ms.get_substate(tuple(SubSubState.get_full_name().split("."))) + assert isinstance(sub_sub_state, SubSubState) + sub_sub_state.set_v4(3) assert ms.v == 3 @@ -1741,6 +1745,7 @@ async def test_state_manager_modify_state( assert token in state_manager._states_locks assert state_manager._states_locks[token].locked() # Should be able to write proxy objects inside mutables + assert isinstance(state, TestState) complex_1 = state.complex[1] assert isinstance(complex_1, MutableProxy) state.complex[3] = complex_1 @@ -1784,6 +1789,7 @@ async def test_state_manager_contend( async def _coro(): async with state_manager.modify_state(substate_token) as state: await asyncio.sleep(0.01) + assert isinstance(state, TestState) state.num1 += 1 tasks = [asyncio.create_task(_coro()) for _ in range(n_coroutines)] @@ -1794,7 +1800,9 @@ async def _coro(): if environment.REFLEX_OPLOCK_ENABLED.get(): await state_manager.close() - assert (await state_manager.get_state(substate_token)).num1 == exp_num1 + test_state = await state_manager.get_state(substate_token) + assert isinstance(test_state, TestState) + assert test_state.num1 == exp_num1 if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None @@ -1939,15 +1947,17 @@ async def _coro_waiter(): with pytest.raises(LockExpiredError): raise loop_exception # In oplock mode, the blocker block's both updates - assert (await state_manager_redis.get_state(substate_token_redis)).num1 == 0 + test_state = await state_manager_redis.get_state(substate_token_redis) + assert isinstance(test_state, TestState) + assert test_state.num1 == 0 else: with pytest.raises(LockExpiredError): await tasks[0] await tasks[1] assert loop_exception is None - assert ( - await state_manager_redis.get_state(substate_token_redis) - ).num1 == exp_num1 + test_state = await state_manager_redis.get_state(substate_token_redis) + assert isinstance(test_state, TestState) + assert test_state.num1 == exp_num1 assert order == ["blocker", "waiter"] @@ -2202,6 +2212,7 @@ async def test_state_proxy( assert gotten_state is not parent_state gotten_grandchild_state = gotten_state.get_substate(sp._self_substate_path) assert gotten_grandchild_state is not None + assert isinstance(gotten_grandchild_state, GrandchildState) assert gotten_grandchild_state.value2 == "42" # ensure state update was emitted @@ -2412,12 +2423,11 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "private", ] - assert ( - await mock_app.state_manager.get_state( - _substate_key(token, BackgroundTaskState) - ) - ).order == exp_order - + background_task_state = await mock_app.state_manager.get_state( + _substate_key(token, BackgroundTaskState) + ) + assert isinstance(background_task_state, BackgroundTaskState) + assert background_task_state.order == exp_order assert mock_app.event_namespace is not None emit_mock = mock_app.event_namespace.emit @@ -2503,13 +2513,11 @@ async def test_background_task_reset(mock_app: rx.App, token: str): if environment.REFLEX_OPLOCK_ENABLED.get(): await mock_app.state_manager.close() - assert ( - await mock_app.state_manager.get_state( - _substate_key(token, BackgroundTaskState) - ) - ).order == [ - "reset", - ] + background_task_state = await mock_app.state_manager.get_state( + _substate_key(token, BackgroundTaskState) + ) + assert isinstance(background_task_state, BackgroundTaskState) + assert background_task_state.order == ["reset"] @pytest.mark.asyncio @@ -3193,12 +3201,11 @@ async def test_get_state(mock_app: rx.App, token: str): ) # Because ChildState3 has a computed var, it is always dirty, and always populated. - assert ( - test_state.substates[ChildState3.get_name()] - .substates[GrandchildState3.get_name()] - .computed - == "" - ) + grandchild_state3 = test_state.substates[ChildState3.get_name()].substates[ + GrandchildState3.get_name() + ] + assert isinstance(grandchild_state3, GrandchildState3) + assert grandchild_state3.computed == "" # Get the child_state2 directly. child_state2_direct = test_state.get_substate([ChildState2.get_name()]) @@ -3444,6 +3451,7 @@ async def test_setvar(mock_app: rx.App, token: str): token: A token. """ state = await mock_app.state_manager.get_state(_substate_key(token, TestState)) + assert isinstance(state, TestState) # Set Var in same state (with Var type casting) for event in rx.event.fix_events( @@ -3826,7 +3834,10 @@ class DillState(BaseState): pk = state._serialize() unpickled_state = BaseState._deserialize(pk) + assert isinstance(unpickled_state, DillState) + assert unpickled_state._f is not None assert unpickled_state._f() == 420 + assert unpickled_state._o is not None assert unpickled_state._o._f() == 42 # Threading locks are unpicklable normally, and raise TypeError instead of PicklingError. @@ -3834,6 +3845,7 @@ class DillState(BaseState): state2._g = threading.Lock() pk2 = state2._serialize() unpickled_state2 = BaseState._deserialize(pk2) + assert isinstance(unpickled_state2, DillState) assert isinstance(unpickled_state2._g, type(threading.Lock())) # Some object, like generator, are still unpicklable with dill. @@ -4255,6 +4267,7 @@ async def v(self) -> int: # Get the unconnected sibling state, which will be used to `get_state` other instances. child = root.get_substate(Child.get_full_name().split(".")) + assert isinstance(child, Child) # Get an uncached child state. child2 = await child.get_state(Child2) @@ -4427,6 +4440,7 @@ async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: async with mock_app.state_manager.modify_state( _substate_key(token, MutableProxyState) ) as state: + assert isinstance(state, MutableProxyState) assert state.data["a"] == [2, 3] if isinstance(mock_app.state_manager, StateManagerRedis): # In redis mode, the object identity does not persist across async with self calls.