From 85e71167ac0da01d6d94aef4bb2bcd8090cf944f Mon Sep 17 00:00:00 2001 From: Daniel Hatton Date: Mon, 2 Dec 2024 11:01:24 +0000 Subject: [PATCH 1/3] Add functionality to check what controllers the instrument server has running to allow reconnections to sessions following instrument server restarts --- src/murfey/client/multigrid_control.py | 32 ++++++++++++++++++-------- src/murfey/client/tui/screens.py | 6 ++++- src/murfey/instrument_server/api.py | 18 +++++++++++++-- src/murfey/server/api/instrument.py | 22 ++++++++++++++++++ src/murfey/server/demo_api.py | 1 + src/murfey/util/instrument_models.py | 3 +++ src/murfey/util/models.py | 2 ++ 7 files changed, 71 insertions(+), 13 deletions(-) diff --git a/src/murfey/client/multigrid_control.py b/src/murfey/client/multigrid_control.py index 4216f4d11..5c70a38b4 100644 --- a/src/murfey/client/multigrid_control.py +++ b/src/murfey/client/multigrid_control.py @@ -36,6 +36,7 @@ class MultigridController: do_transfer: bool = True dummy_dc: bool = False force_mdoc_metadata: bool = True + rsync_restarts: List[str] = field(default_factory=lambda: []) rsync_processes: Dict[Path, RSyncer] = field(default_factory=lambda: {}) analysers: Dict[Path, Analyser] = field(default_factory=lambda: {}) data_collection_parameters: dict = field(default_factory=lambda: {}) @@ -103,7 +104,10 @@ def _start_rsyncer_multigrid( f"{self._environment.url.geturl()}/instruments/{self.instrument_name}/machine" ).json() if destination_overrides.get(source): - destination = destination_overrides[source] + f"/{extra_directory}" + if str(source) in self.rsync_restarts: + destination = destination_overrides[source] + else: + destination = destination_overrides[source] + f"/{extra_directory}" else: for k, v in destination_overrides.items(): if Path(v).name in source.parts: @@ -134,6 +138,7 @@ def _start_rsyncer_multigrid( tag=tag, limited=limited, transfer=machine_data.get("data_transfer_enabled", True), + restarted=str(source) in self.rsync_restarts, ) self.ws.send(json.dumps({"message": "refresh"})) @@ -175,6 +180,7 @@ def _start_rsyncer( tag: str = "", limited: bool = False, transfer: bool = True, + restarted: bool = False, ): log.info(f"starting rsyncer: {source}") if self._environment: @@ -238,15 +244,21 @@ def rsync_result(update: RSyncerUpdate): ), secondary=True, ) - url = f"{str(self._environment.url.geturl())}/sessions/{str(self._environment.murfey_session)}/rsyncer" - rsyncer_data = { - "source": str(source), - "destination": destination, - "session_id": self.session_id, - "transferring": self.do_transfer or self._environment.demo, - "tag": tag, - } - requests.post(url, json=rsyncer_data) + if restarted: + restarted_url = ( + f"{self.murfey_url}/sessions/{self.session_id}/rsyncer_started" + ) + capture_post(restarted_url, json={"source": str(source)}) + else: + url = f"{str(self._environment.url.geturl())}/sessions/{str(self._environment.murfey_session)}/rsyncer" + rsyncer_data = { + "source": str(source), + "destination": destination, + "session_id": self.session_id, + "transferring": self.do_transfer or self._environment.demo, + "tag": tag, + } + requests.post(url, json=rsyncer_data) self._environment.watchers[source] = DirWatcher(source, settling_time=30) if not self.analysers.get(source) and analyse: diff --git a/src/murfey/client/tui/screens.py b/src/murfey/client/tui/screens.py index 54d3b6e8e..5d7a5281f 100644 --- a/src/murfey/client/tui/screens.py +++ b/src/murfey/client/tui/screens.py @@ -138,7 +138,11 @@ def determine_default_destination( _default = "" else: _default = destination + f"/{visit}" - return _default + f"/{extra_directory}" + return ( + _default + f"/{extra_directory}" + if not _default.endswith("/") + else _default + f"{extra_directory}" + ) class InputResponse(NamedTuple): diff --git a/src/murfey/instrument_server/api.py b/src/murfey/instrument_server/api.py index b66d9d7e3..c2bee0516 100644 --- a/src/murfey/instrument_server/api.py +++ b/src/murfey/instrument_server/api.py @@ -3,6 +3,7 @@ import secrets import time from datetime import datetime +from functools import partial from logging import getLogger from pathlib import Path from typing import Annotated, Dict, List, Optional, Union @@ -28,7 +29,7 @@ watchers: Dict[Union[str, int], MultigridDirWatcher] = {} rsyncers: Dict[str, RSyncer] = {} -controllers = {} +controllers: Dict[int, MultigridController] = {} data_collection_parameters: dict = {} tokens = {} @@ -131,10 +132,17 @@ async def token_handshake_for_session(session_id: int, token: Token): ) +@router.get("/sessions/{session_id}/check_token") +def check_token(session_id: MurfeySessionID): + return {"token_valid": True} + + @router.post("/sessions/{session_id}/multigrid_watcher") def start_multigrid_watcher( session_id: MurfeySessionID, watcher_spec: MultigridWatcherSpec ): + if controllers.get(session_id) is not None: + return {"success": True} label = watcher_spec.label controllers[session_id] = MultigridController( [], @@ -148,6 +156,7 @@ def start_multigrid_watcher( _machine_config=watcher_spec.configuration.dict(), token=tokens.get(session_id, "token"), data_collection_parameters=data_collection_parameters.get(label, {}), + rsync_restarts=watcher_spec.rsync_restarts, ) watcher_spec.source.mkdir(exist_ok=True) machine_config = requests.get( @@ -161,7 +170,12 @@ def start_multigrid_watcher( watcher_spec.configuration.dict(), skip_existing_processing=watcher_spec.skip_existing_processing, ) - watchers[session_id].subscribe(controllers[session_id]._start_rsyncer_multigrid) + watchers[session_id].subscribe( + partial( + controllers[session_id]._start_rsyncer_multigrid, + destination_overrides=watcher_spec.destination_overrides, + ) + ) watchers[session_id].start() return {"success": True} diff --git a/src/murfey/server/api/instrument.py b/src/murfey/server/api/instrument.py index fb0fa37a5..2eedb8c32 100644 --- a/src/murfey/server/api/instrument.py +++ b/src/murfey/server/api/instrument.py @@ -75,6 +75,24 @@ async def activate_instrument_server_for_session( return success +@router.get("/instruments/{instrument_name}/sessions/{session_id}/active") +async def check_if_session_is_active(instrument_name: str, session_id: int): + if instrument_server_tokens.get(session_id) is None: + return {"active": False} + async with lock: + async with aiohttp.ClientSession() as session: + machine_config = get_machine_config(instrument_name=instrument_name)[ + instrument_name + ] + async with session.get( + f"{machine_config.instrument_server_url}/sessions/{int(sanitise(str(session_id)))}/check_token", + headers={ + "Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}" + }, + ) as response: + return {"active": response.status == 200} + + @router.post("/sessions/{session_id}/multigrid_watcher") async def start_multigrid_watcher( session_id: MurfeySessionID, watcher_spec: MultigridWatcherSetup, db=murfey_db @@ -109,6 +127,10 @@ async def start_multigrid_watcher( "label": visit, "instrument_name": instrument_name, "skip_existing_processing": watcher_spec.skip_existing_processing, + "destination_overrides": { + str(k): v for k, v in watcher_spec.destination_overrides.items() + }, + "rsync_restarts": watcher_spec.rsync_restarts, }, headers={ "Authorization": f"Bearer {instrument_server_tokens[session_id]['access_token']}" diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index eab41858a..99140ee0d 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -316,6 +316,7 @@ async def get_session(session_id: MurfeySessionID, db=murfey_db) -> SessionClien def increment_rsync_file_count( visit_name: str, rsyncer_info: RsyncerInfo, db=murfey_db ): + print(rsyncer_info.source, rsyncer_info.destination, rsyncer_info.session_id) rsync_instance = db.exec( select(RsyncInstance).where( RsyncInstance.source == rsyncer_info.source, diff --git a/src/murfey/util/instrument_models.py b/src/murfey/util/instrument_models.py index 02a459479..3da23c57c 100644 --- a/src/murfey/util/instrument_models.py +++ b/src/murfey/util/instrument_models.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Dict, List from pydantic import BaseModel @@ -12,3 +13,5 @@ class MultigridWatcherSpec(BaseModel): visit: str instrument_name: str skip_existing_processing: bool = False + destination_overrides: Dict[Path, str] = {} + rsync_restarts: List[str] = [] diff --git a/src/murfey/util/models.py b/src/murfey/util/models.py index 7af87f754..781e4723d 100644 --- a/src/murfey/util/models.py +++ b/src/murfey/util/models.py @@ -334,6 +334,8 @@ class PostInfo(BaseModel): class MultigridWatcherSetup(BaseModel): source: Path skip_existing_processing: bool = False + destination_overrides: Dict[Path, str] = {} + rsync_restarts: List[str] = [] class CurrentGainRef(BaseModel): From 64261de8db8eaa75b0ec0365796a8b70be975843 Mon Sep 17 00:00:00 2001 From: Daniel Hatton Date: Mon, 2 Dec 2024 13:13:11 +0000 Subject: [PATCH 2/3] Remove debugging print statement --- src/murfey/server/demo_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/murfey/server/demo_api.py b/src/murfey/server/demo_api.py index 99140ee0d..eab41858a 100644 --- a/src/murfey/server/demo_api.py +++ b/src/murfey/server/demo_api.py @@ -316,7 +316,6 @@ async def get_session(session_id: MurfeySessionID, db=murfey_db) -> SessionClien def increment_rsync_file_count( visit_name: str, rsyncer_info: RsyncerInfo, db=murfey_db ): - print(rsyncer_info.source, rsyncer_info.destination, rsyncer_info.session_id) rsync_instance = db.exec( select(RsyncInstance).where( RsyncInstance.source == rsyncer_info.source, From 2ea537eddfa86b16ceca3ab9cd5ab50473ff59fc Mon Sep 17 00:00:00 2001 From: Daniel Hatton Date: Mon, 2 Dec 2024 13:17:06 +0000 Subject: [PATCH 3/3] A bit less code --- src/murfey/client/multigrid_control.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/murfey/client/multigrid_control.py b/src/murfey/client/multigrid_control.py index 5c70a38b4..3e6f640d9 100644 --- a/src/murfey/client/multigrid_control.py +++ b/src/murfey/client/multigrid_control.py @@ -104,10 +104,11 @@ def _start_rsyncer_multigrid( f"{self._environment.url.geturl()}/instruments/{self.instrument_name}/machine" ).json() if destination_overrides.get(source): - if str(source) in self.rsync_restarts: - destination = destination_overrides[source] - else: - destination = destination_overrides[source] + f"/{extra_directory}" + destination = ( + destination_overrides[source] + if str(source) in self.rsync_restarts + else destination_overrides[source] + f"/{extra_directory}" + ) else: for k, v in destination_overrides.items(): if Path(v).name in source.parts: