diff --git a/tests/vec_inf/client/test_slurm_script_generator.py b/tests/vec_inf/client/test_slurm_script_generator.py index 3d141af8..a81a962d 100644 --- a/tests/vec_inf/client/test_slurm_script_generator.py +++ b/tests/vec_inf/client/test_slurm_script_generator.py @@ -12,6 +12,14 @@ ) +@pytest.fixture(autouse=True) +def patch_model_weights_exists(monkeypatch): + """Ensure model weights directory existence checks default to True.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", lambda self: True + ) + + class TestSlurmScriptGenerator: """Tests for SlurmScriptGenerator class.""" @@ -168,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params): "module load " in setup ) # Remove module name since it's inconsistent between clusters + def test_generate_server_setup_singularity_no_weights( + self, singularity_params, monkeypatch + ): + """Test server setup when model weights don't exist.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", + lambda self: False, + ) + + generator = SlurmScriptGenerator(singularity_params) + setup = generator._generate_server_setup() + + assert "ray stop" in setup + assert "/path/to/model_weights/test-model" not in setup + def test_generate_launch_cmd_venv(self, basic_params): """Test launch command generation with virtual environment.""" generator = SlurmScriptGenerator(basic_params) @@ -187,6 +210,22 @@ def test_generate_launch_cmd_singularity(self, singularity_params): assert "apptainer exec --nv" in launch_cmd assert "source" not in launch_cmd + def test_generate_launch_cmd_singularity_no_local_weights( + self, singularity_params, monkeypatch + ): + """Test container launch when model weights directory is missing.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", + lambda self: False, + ) + + generator = SlurmScriptGenerator(singularity_params) + launch_cmd = generator._generate_launch_cmd() + + assert "exec --nv" in 
launch_cmd + assert "--bind /path/to/model_weights/test-model" not in launch_cmd + assert "vllm serve test-model" in launch_cmd + def test_generate_launch_cmd_boolean_args(self, basic_params): """Test launch command with boolean vLLM arguments.""" params = basic_params.copy() @@ -391,6 +430,24 @@ def test_generate_model_launch_script_singularity( mock_touch.assert_called_once() mock_write_text.assert_called_once() + @patch("pathlib.Path.touch") + @patch("pathlib.Path.write_text") + def test_generate_model_launch_script_singularity_no_weights( + self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch + ): + """Test batch model launch script when model weights don't exist.""" + monkeypatch.setattr( + "vec_inf.client._slurm_script_generator.Path.exists", + lambda self: False, + ) + + generator = BatchSlurmScriptGenerator(batch_singularity_params) + script_path = generator._generate_model_launch_script("model1") + + assert script_path.name == "launch_model1.sh" + call_args = mock_write_text.call_args[0][0] + assert "/path/to/model_weights/model1" not in call_args + @patch("vec_inf.client._slurm_script_generator.datetime") @patch("pathlib.Path.touch") @patch("pathlib.Path.write_text") diff --git a/vec_inf/client/_slurm_script_generator.py b/vec_inf/client/_slurm_script_generator.py index 15571715..433a8f27 100644 --- a/vec_inf/client/_slurm_script_generator.py +++ b/vec_inf/client/_slurm_script_generator.py @@ -15,6 +15,7 @@ SLURM_SCRIPT_TEMPLATE, ) from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME +from vec_inf.client._utils import check_and_warn_hf_cache class SlurmScriptGenerator: @@ -37,8 +38,20 @@ def __init__(self, params: dict[str, Any]): self.additional_binds = ( f",{self.params['bind']}" if self.params.get("bind") else "" ) - self.model_weights_path = str( - Path(self.params["model_weights_parent_dir"], self.params["model_name"]) + model_weights_path = Path( + self.params["model_weights_parent_dir"], self.params["model_name"] + ) + 
self.model_weights_exists = model_weights_path.exists() + self.model_weights_path = str(model_weights_path) + self.model_source = ( + self.model_weights_path + if self.model_weights_exists + else self.params["model_name"] + ) + check_and_warn_hf_cache( + self.model_weights_exists, + self.model_weights_path, + self.params.get("env", {}), ) self.env_str = self._generate_env_str() @@ -111,7 +124,9 @@ def _generate_server_setup(self) -> str: server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"])) server_script.append( SLURM_SCRIPT_TEMPLATE["bind_path"].format( - model_weights_path=self.model_weights_path, + model_weights_path=f",{self.model_weights_path}" + if self.model_weights_exists + else "", additional_binds=self.additional_binds, ) ) @@ -131,7 +146,6 @@ def _generate_server_setup(self) -> str: server_setup_str = server_setup_str.replace( "CONTAINER_PLACEHOLDER", SLURM_SCRIPT_TEMPLATE["container_command"].format( - model_weights_path=self.model_weights_path, env_str=self.env_str, ), ) @@ -165,22 +179,27 @@ def _generate_launch_cmd(self) -> str: Server launch command. 
""" launcher_script = ["\n"] + + vllm_args_copy = self.params["vllm_args"].copy() + model_source = self.model_source + if "--model" in vllm_args_copy: + model_source = vllm_args_copy.pop("--model") + if self.use_container: launcher_script.append( SLURM_SCRIPT_TEMPLATE["container_command"].format( - model_weights_path=self.model_weights_path, env_str=self.env_str, ) ) launcher_script.append( "\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format( - model_weights_path=self.model_weights_path, + model_source=model_source, model_name=self.params["model_name"], ) ) - for arg, value in self.params["vllm_args"].items(): + for arg, value in vllm_args_copy.items(): if isinstance(value, bool): launcher_script.append(f" {arg} \\") else: @@ -225,11 +244,26 @@ def __init__(self, params: dict[str, Any]): if self.params["models"][model_name].get("bind") else "" ) - self.params["models"][model_name]["model_weights_path"] = str( - Path( - self.params["models"][model_name]["model_weights_parent_dir"], - model_name, - ) + model_weights_path = Path( + self.params["models"][model_name]["model_weights_parent_dir"], + model_name, + ) + model_weights_exists = model_weights_path.exists() + model_weights_path_str = str(model_weights_path) + self.params["models"][model_name]["model_weights_path"] = ( + model_weights_path_str + ) + self.params["models"][model_name]["model_weights_exists"] = ( + model_weights_exists + ) + self.params["models"][model_name]["model_source"] = ( + model_weights_path_str if model_weights_exists else model_name + ) + check_and_warn_hf_cache( + model_weights_exists, + model_weights_path_str, + self.params["models"][model_name].get("env", {}), + model_name, ) def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path: @@ -266,7 +300,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path: script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"]) script_content.append( 
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format( - model_weights_path=model_params["model_weights_path"], + model_weights_path=model_params["model_weights_path"] + if model_params.get("model_weights_exists", True) + else "", additional_binds=model_params["additional_binds"], ) ) @@ -283,19 +319,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path: model_name=model_name, ) ) + vllm_args_copy = model_params["vllm_args"].copy() + model_source = model_params.get( + "model_source", model_params["model_weights_path"] + ) + if "--model" in vllm_args_copy: + model_source = vllm_args_copy.pop("--model") + if self.use_container: script_content.append( - BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format( - model_weights_path=model_params["model_weights_path"], - ) + BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format() ) script_content.append( "\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format( - model_weights_path=model_params["model_weights_path"], + model_source=model_source, model_name=model_name, ) ) - for arg, value in model_params["vllm_args"].items(): + + for arg, value in vllm_args_copy.items(): if isinstance(value, bool): script_content.append(f" {arg} \\") else: diff --git a/vec_inf/client/_slurm_templates.py b/vec_inf/client/_slurm_templates.py index 43d91f61..62697006 100644 --- a/vec_inf/client/_slurm_templates.py +++ b/vec_inf/client/_slurm_templates.py @@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict): f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop", ], "imports": "source {src_dir}/find_port.sh", - "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}", + "bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp{{model_weights_path}}{{additional_binds}}", "container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} 
--containall {IMAGE_PATH} \\", "activate_venv": "source {venv}/bin/activate", "server_setup": { @@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict): ' && mv temp.json "$json_path"', ], "launch_cmd": [ - "vllm serve {model_weights_path} \\", + "vllm serve {model_source} \\", " --served-model-name {model_name} \\", ' --host "0.0.0.0" \\', " --port $vllm_port_number \\", @@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict): ], "container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\", "launch_cmd": [ - "vllm serve {model_weights_path} \\", + "vllm serve {model_source} \\", " --served-model-name {model_name} \\", ' --host "0.0.0.0" \\', " --port $vllm_port_number \\", diff --git a/vec_inf/client/_utils.py b/vec_inf/client/_utils.py index 67d907ec..2f3b01d6 100644 --- a/vec_inf/client/_utils.py +++ b/vec_inf/client/_utils.py @@ -456,3 +456,45 @@ def check_required_fields(params: dict[str, Any]) -> dict[str, Any]: f"{arg} is required, please set it in the command arguments or environment variables" ) return env_overrides + + +def check_and_warn_hf_cache( + model_weights_exists: bool, + model_weights_path: str, + env_dict: dict[str, str], + model_name: str | None = None, +) -> None: + """Warn if model weights don't exist and HuggingFace cache directory is not set. + + Parameters + ---------- + model_weights_exists : bool + Whether the model weights exist at the expected path. + model_weights_path : str + The expected path to the model weights. + env_dict : dict[str, str] + Dictionary of environment variables to check (from --env parameter). + model_name : str | None, optional + Optional model name to include in the warning message (for batch mode). 
+ """ + if model_weights_exists: + return + + hf_cache_vars = ["HF_HOME", "HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"] + hf_cache_set = any( + os.environ.get(var) or env_dict.get(var) for var in hf_cache_vars + ) + + if not hf_cache_set: + model_prefix = ( + f"Model weights for '{model_name}' " if model_name else "Model weights " + ) + warnings.warn( + f"{model_prefix}not found at '{model_weights_path}' and no " + f"HuggingFace cache directory is set (HF_HOME, HF_HUB_CACHE, or " + f"HUGGINGFACE_HUB_CACHE). The model may be downloaded to your home " + f"directory, which could consume your storage quota. Consider setting " + f"one of these environment variables to a shared cache location.", + UserWarning, + stacklevel=4, + )