Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tests/vec_inf/client/test_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
)


@pytest.fixture(autouse=True)
def patch_model_weights_exists(monkeypatch):
    """Make every ``Path.exists`` check in the generator module return True.

    Applied automatically to all tests so the generators behave as if the
    model weights directory is present unless a test overrides the patch.
    """
    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        lambda _self: True,
    )


class TestSlurmScriptGenerator:
"""Tests for SlurmScriptGenerator class."""

Expand Down Expand Up @@ -168,6 +176,21 @@ def test_generate_server_setup_singularity(self, singularity_params):
"module load " in setup
) # Remove module name since it's inconsistent between clusters

def test_generate_server_setup_singularity_no_weights(
    self, singularity_params, monkeypatch
):
    """Server setup should omit the weights path when it does not exist."""
    # Override the autouse patch: pretend the weights directory is missing.
    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        lambda _self: False,
    )

    server_setup = SlurmScriptGenerator(singularity_params)._generate_server_setup()

    # The ray teardown is emitted regardless of local weights availability.
    assert "ray stop" in server_setup
    # The non-existent local path must not leak into the generated script.
    assert "/path/to/model_weights/test-model" not in server_setup

def test_generate_launch_cmd_venv(self, basic_params):
"""Test launch command generation with virtual environment."""
generator = SlurmScriptGenerator(basic_params)
Expand All @@ -187,6 +210,22 @@ def test_generate_launch_cmd_singularity(self, singularity_params):
assert "apptainer exec --nv" in launch_cmd
assert "source" not in launch_cmd

def test_generate_launch_cmd_singularity_no_local_weights(
    self, singularity_params, monkeypatch
):
    """Launch command should fall back to the model name when weights are absent."""
    # Override the autouse patch: pretend the weights directory is missing.
    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        lambda _self: False,
    )

    cmd = SlurmScriptGenerator(singularity_params)._generate_launch_cmd()

    # Container invocation is still present...
    assert "exec --nv" in cmd
    # ...but no bind mount of a directory that does not exist...
    assert "--bind /path/to/model_weights/test-model" not in cmd
    # ...and vllm is pointed at the model name (HF download path) instead.
    assert "vllm serve test-model" in cmd

def test_generate_launch_cmd_boolean_args(self, basic_params):
"""Test launch command with boolean vLLM arguments."""
params = basic_params.copy()
Expand Down Expand Up @@ -391,6 +430,24 @@ def test_generate_model_launch_script_singularity(
mock_touch.assert_called_once()
mock_write_text.assert_called_once()

@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
def test_generate_model_launch_script_singularity_no_weights(
    self, mock_write_text, mock_touch, batch_singularity_params, monkeypatch
):
    """Batch launch script should not reference a missing weights directory."""
    # Override the autouse patch: pretend the weights directory is missing.
    monkeypatch.setattr(
        "vec_inf.client._slurm_script_generator.Path.exists",
        lambda _self: False,
    )

    gen = BatchSlurmScriptGenerator(batch_singularity_params)
    script_path = gen._generate_model_launch_script("model1")

    assert script_path.name == "launch_model1.sh"
    # Inspect the rendered script handed to Path.write_text.
    rendered = mock_write_text.call_args[0][0]
    assert "/path/to/model_weights/model1" not in rendered

@patch("vec_inf.client._slurm_script_generator.datetime")
@patch("pathlib.Path.touch")
@patch("pathlib.Path.write_text")
Expand Down
78 changes: 60 additions & 18 deletions vec_inf/client/_slurm_script_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SLURM_SCRIPT_TEMPLATE,
)
from vec_inf.client._slurm_vars import CONTAINER_MODULE_NAME
from vec_inf.client._utils import check_and_warn_hf_cache


class SlurmScriptGenerator:
Expand All @@ -37,8 +38,20 @@ def __init__(self, params: dict[str, Any]):
self.additional_binds = (
f",{self.params['bind']}" if self.params.get("bind") else ""
)
self.model_weights_path = str(
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
model_weights_path = Path(
self.params["model_weights_parent_dir"], self.params["model_name"]
)
self.model_weights_exists = model_weights_path.exists()
self.model_weights_path = str(model_weights_path)
self.model_source = (
self.model_weights_path
if self.model_weights_exists
else self.params["model_name"]
)
check_and_warn_hf_cache(
self.model_weights_exists,
self.model_weights_path,
self.params.get("env", {}),
)
self.env_str = self._generate_env_str()

Expand Down Expand Up @@ -111,7 +124,9 @@ def _generate_server_setup(self) -> str:
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
server_script.append(
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=self.model_weights_path,
model_weights_path=self.model_weights_path
if self.model_weights_exists
else "",
additional_binds=self.additional_binds,
)
)
Expand All @@ -131,7 +146,6 @@ def _generate_server_setup(self) -> str:
server_setup_str = server_setup_str.replace(
"CONTAINER_PLACEHOLDER",
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
),
)
Expand Down Expand Up @@ -165,22 +179,27 @@ def _generate_launch_cmd(self) -> str:
Server launch command.
"""
launcher_script = ["\n"]

vllm_args_copy = self.params["vllm_args"].copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is necessary, as the model name should be passed with the launch command, not as part of --vllm-args

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, model_name is the short name used for config lookup, log directories, and job naming (e.g., llama-3). The --model in vllm_args would allow users to specify the full HF path when downloading from HuggingFace.

I'm open to alternative approaches if you have a preference, like:
Dedicated CLI option (e.g., --hf-model) - keeps model_name as the short identifier, adds explicit option for full HF path
Reuse existing model_name - allow full HF paths directly, but adjust config lookups, log directory structure, etc. to handle paths with /

model_source = self.model_source
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
launcher_script.append(
SLURM_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=self.model_weights_path,
env_str=self.env_str,
)
)

launcher_script.append(
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=self.model_weights_path,
model_source=model_source,
model_name=self.params["model_name"],
)
)

for arg, value in self.params["vllm_args"].items():
for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
launcher_script.append(f" {arg} \\")
else:
Expand Down Expand Up @@ -225,11 +244,26 @@ def __init__(self, params: dict[str, Any]):
if self.params["models"][model_name].get("bind")
else ""
)
self.params["models"][model_name]["model_weights_path"] = str(
Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_path = Path(
self.params["models"][model_name]["model_weights_parent_dir"],
model_name,
)
model_weights_exists = model_weights_path.exists()
model_weights_path_str = str(model_weights_path)
self.params["models"][model_name]["model_weights_path"] = (
model_weights_path_str
)
self.params["models"][model_name]["model_weights_exists"] = (
model_weights_exists
)
self.params["models"][model_name]["model_source"] = (
model_weights_path_str if model_weights_exists else model_name
)
check_and_warn_hf_cache(
model_weights_exists,
model_weights_path_str,
self.params["models"][model_name].get("env", {}),
model_name,
)

def _write_to_log_dir(self, script_content: list[str], script_name: str) -> Path:
Expand Down Expand Up @@ -266,7 +300,9 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
model_weights_path=model_params["model_weights_path"],
model_weights_path=model_params["model_weights_path"]
if model_params.get("model_weights_exists", True)
else "",
additional_binds=model_params["additional_binds"],
)
)
Expand All @@ -283,19 +319,25 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
model_name=model_name,
)
)
vllm_args_copy = model_params["vllm_args"].copy()
model_source = model_params.get(
"model_source", model_params["model_weights_path"]
)
if "--model" in vllm_args_copy:
model_source = vllm_args_copy.pop("--model")

if self.use_container:
script_content.append(
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
model_weights_path=model_params["model_weights_path"],
)
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format()
)
script_content.append(
"\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"]).format(
model_weights_path=model_params["model_weights_path"],
model_source=model_source,
model_name=model_name,
)
)
for arg, value in model_params["vllm_args"].items():

for arg, value in vllm_args_copy.items():
if isinstance(value, bool):
script_content.append(f" {arg} \\")
else:
Expand Down
6 changes: 3 additions & 3 deletions vec_inf/client/_slurm_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class SlurmScriptTemplate(TypedDict):
f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
],
"imports": "source {src_dir}/find_port.sh",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp{{model_weights_path}}{{additional_binds}}",
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\",
"activate_venv": "source {venv}/bin/activate",
"server_setup": {
Expand Down Expand Up @@ -164,7 +164,7 @@ class SlurmScriptTemplate(TypedDict):
' && mv temp.json "$json_path"',
],
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
Expand Down Expand Up @@ -255,7 +255,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
],
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\",
"launch_cmd": [
"vllm serve {model_weights_path} \\",
"vllm serve {model_source} \\",
" --served-model-name {model_name} \\",
' --host "0.0.0.0" \\',
" --port $vllm_port_number \\",
Expand Down
42 changes: 42 additions & 0 deletions vec_inf/client/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,3 +456,45 @@ def check_required_fields(params: dict[str, Any]) -> dict[str, Any]:
f"{arg} is required, please set it in the command arguments or environment variables"
)
return env_overrides


def check_and_warn_hf_cache(
model_weights_exists: bool,
model_weights_path: str,
env_dict: dict[str, str],
model_name: str | None = None,
) -> None:
"""Warn if model weights don't exist and HuggingFace cache directory is not set.

Parameters
----------
model_weights_exists : bool
Whether the model weights exist at the expected path.
model_weights_path : str
The expected path to the model weights.
env_dict : dict[str, str]
Dictionary of environment variables to check (from --env parameter).
model_name : str | None, optional
Optional model name to include in the warning message (for batch mode).
"""
if model_weights_exists:
return

hf_cache_vars = ["HF_HOME", "HF_HUB_CACHE", "HUGGINGFACE_HUB_CACHE"]
hf_cache_set = any(
os.environ.get(var) or env_dict.get(var) for var in hf_cache_vars
)

if not hf_cache_set:
model_prefix = (
f"Model weights for '{model_name}' " if model_name else "Model weights "
)
warnings.warn(
f"{model_prefix}not found at '{model_weights_path}' and no "
f"HuggingFace cache directory is set (HF_HOME, HF_HUB_CACHE, or "
f"HUGGINGFACE_HUB_CACHE). The model may be downloaded to your home "
f"directory, which could consume your storage quota. Consider setting "
f"one of these environment variables to a shared cache location.",
UserWarning,
stacklevel=4,
)