60 changes: 52 additions & 8 deletions modelopt/torch/utils/plugins/megatron_generate.py
@@ -208,13 +208,48 @@ def _forward_step_func(data, model):
         # NOTE: we don't support traditional positional embedding. Only RoPE or YaRN are supported.
         position_ids = None
 
-        output_tensor = model(
-            data["tokens"],
-            position_ids,
-            attention_mask,
-            inference_context=inference_context,
-            runtime_gather_output=True,
-        )
+        # Check if this is a VLM model (has vision inputs)
+        _has_pixel_values = data.get("pixel_values") is not None
+        _has_image_grid_thw = data.get("image_grid_thw") is not None
+        _has_image_sizes = data.get("image_sizes") is not None
+        has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes
+
+        if has_vision_inputs:
+            # For VLM models:
+            # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions)
+            # - attention_mask: [batch, seq_len] (simple 2D boolean mask, not a 4D causal mask)
+            vlm_position_ids = (
+                torch.arange(seq_len, dtype=torch.long, device=device)
+                .unsqueeze(0)
+                .expand(batch_size, -1)
+            )
+            vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
+
+            forward_args = {
+                "input_ids": data["tokens"],
+                "position_ids": vlm_position_ids,
+                "attention_mask": vlm_attention_mask,
+                "inference_context": inference_context,
+                "runtime_gather_output": True,
+            }
+            # Add vision inputs
+            if _has_pixel_values:
+                forward_args["pixel_values"] = data["pixel_values"]
+            if _has_image_grid_thw:
+                forward_args["image_grid_thw"] = data["image_grid_thw"]
+            if _has_image_sizes:
+                forward_args["image_sizes"] = data["image_sizes"]
+
+            output_tensor = model(**forward_args)
+        else:
+            # For text-only LLM models
+            output_tensor = model(
+                data["tokens"],
+                position_ids,
+                attention_mask,
+                inference_context=inference_context,
+                runtime_gather_output=True,
+            )
         return output_tensor, _dummy_loss_func
 
     disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0
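
Note (illustrative sketch, not part of the diff): the VLM branch above builds batch-first position ids and a flat boolean attention mask instead of the text-only path's None position_ids and 4D causal mask. A minimal standalone illustration of those shapes, with arbitrary batch_size/seq_len values and CPU device chosen only for the example:

import torch

batch_size, seq_len = 2, 16
device = torch.device("cpu")  # example only; the patch uses the model's device

# [batch, seq_len] positions 0..seq_len-1, shared across the batch
vlm_position_ids = (
    torch.arange(seq_len, dtype=torch.long, device=device)
    .unsqueeze(0)
    .expand(batch_size, -1)
)
# [batch, seq_len] all-True boolean mask, rather than a [batch, 1, seq, seq] causal mask
vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)

assert vlm_position_ids.shape == (batch_size, seq_len)
assert vlm_attention_mask.shape == (batch_size, seq_len)
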
@@ -248,9 +283,18 @@ def _forward_step_func(data, model):
         else:
             tokens = input_ids
 
+        data_dict = {"tokens": tokens}
+        # Vision inputs should only be passed during prefill (step 0), not during decode steps
+        if pixel_values is not None:
+            data_dict["pixel_values"] = pixel_values
+        if image_grid_thw is not None:
+            data_dict["image_grid_thw"] = image_grid_thw
+        if image_sizes is not None:
+            data_dict["image_sizes"] = image_sizes
+
         list_of_logits = get_forward_backward_func()(
             forward_step_func=_forward_step_func,
-            data_iterator=[{"tokens": tokens}],
+            data_iterator=[data_dict],
             model=model,
             num_microbatches=1,
             seq_length=tokens.shape[-1],
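
Note (illustrative sketch, hypothetical helper, not part of the diff): the intent of the data_dict change is that vision tensors ride along only on the prefill call, while decode steps keep passing tokens alone (at that point pixel_values and friends are None). Assuming a step counter where 0 is prefill, the gating could be factored like this; build_data_dict and step are illustrative names, not from the patch:

def build_data_dict(tokens, step, pixel_values=None, image_grid_thw=None, image_sizes=None):
    """Hypothetical helper: attach vision inputs only on the prefill step (step == 0)."""
    data_dict = {"tokens": tokens}
    if step == 0:  # prefill; decode steps (step > 0) send tokens only
        if pixel_values is not None:
            data_dict["pixel_values"] = pixel_values
        if image_grid_thw is not None:
            data_dict["image_grid_thw"] = image_grid_thw
        if image_sizes is not None:
            data_dict["image_sizes"] = image_sizes
    return data_dict
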