From bff6af9aa38af363aa787b0facc3764268199718 Mon Sep 17 00:00:00 2001 From: Miguel Martin Date: Sat, 7 Feb 2026 00:17:52 +0000 Subject: [PATCH] AR --- scripts/convert_cosmos_to_diffusers.py | 24 +- .../models/controlnets/controlnet_cosmos.py | 7 +- .../cosmos/pipeline_cosmos2_5_transfer.py | 536 +++++++++++------- .../test_models_controlnet_cosmos.py | 20 + .../cosmos/test_cosmos2_5_transfer.py | 51 +- 5 files changed, 412 insertions(+), 226 deletions(-) diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py index ae66c9b8197c..3faadc030926 100644 --- a/scripts/convert_cosmos_to_diffusers.py +++ b/scripts/convert_cosmos_to_diffusers.py @@ -94,9 +94,15 @@ --transformer_type Cosmos-2.5-Transfer-General-2B \ --transformer_ckpt_path $transformer_ckpt_path \ --vae_type wan2.1 \ - --output_path converted/transfer/2b/general/depth \ + --output_path converted/transfer/2b/general/depth/pipeline \ --save_pipeline +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/depth/models + # edge transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt @@ -120,9 +126,15 @@ --transformer_type Cosmos-2.5-Transfer-General-2B \ --transformer_ckpt_path $transformer_ckpt_path \ --vae_type wan2.1 \ - --output_path converted/transfer/2b/general/blur \ + --output_path converted/transfer/2b/general/blur/pipeline \ --save_pipeline +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/blur/models + # seg transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt @@ -130,8 +142,14 @@ --transformer_type Cosmos-2.5-Transfer-General-2B \ --transformer_ckpt_path $transformer_ckpt_path \ --vae_type wan2.1 \ - --output_path converted/transfer/2b/general/seg \ + --output_path converted/transfer/2b/general/seg/pipeline \ --save_pipeline + +python scripts/convert_cosmos_to_diffusers.py \ + --transformer_type Cosmos-2.5-Transfer-General-2B \ + --transformer_ckpt_path $transformer_ckpt_path \ + --vae_type wan2.1 \ + --output_path converted/transfer/2b/general/seg/models ``` """ diff --git a/src/diffusers/models/controlnets/controlnet_cosmos.py b/src/diffusers/models/controlnets/controlnet_cosmos.py index 6ea7d629b816..69878ee0fa1b 100644 --- a/src/diffusers/models/controlnets/controlnet_cosmos.py +++ b/src/diffusers/models/controlnets/controlnet_cosmos.py @@ -191,7 +191,12 @@ def forward( dim=1, ) - control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1) + if condition_mask is not None: + control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1) + else: + control_hidden_states = torch.cat( + [control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1 + ) padding_mask_resized = transforms.functional.resize( padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py 
b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py index 13f583e8df8a..fc51c5c4a9b9 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_transfer.py @@ -54,11 +54,13 @@ def __init__(self, *args, **kwargs): logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def _maybe_pad_video(video: torch.Tensor, num_frames: int): +def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int): n_pad_frames = num_frames - video.shape[2] if n_pad_frames > 0: last_frame = video[:, :, -1:, :, :] video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2) + elif num_frames < video.shape[2]: + video = video[:, :, :num_frames, :, :] return video @@ -134,8 +136,8 @@ def retrieve_latents( >>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)] >>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30) + >>> # Controls-only inference (input video is optional when controls are provided). >>> video = pipe( - ... video=input_video[:num_frames], ... controls=controls, ... controls_conditioning_scale=1.0, ... prompt=prompt, @@ -149,7 +151,7 @@ def retrieve_latents( class Cosmos2_5_TransferPipeline(DiffusionPipeline): r""" - Pipeline for Cosmos Transfer2.5 base model. + Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference. This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving, running on a particular device, etc.). @@ -182,7 +184,7 @@ def __init__( vae: AutoencoderKLWan, scheduler: UniPCMultistepScheduler, controlnet: Optional[CosmosControlNetModel], - safety_checker: CosmosSafetyChecker = None, + safety_checker: Optional[CosmosSafetyChecker] = None, ): super().__init__() @@ -385,6 +387,7 @@ def prepare_latents( device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + num_cond_latent_frames: int = 0, ) -> torch.Tensor: if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -399,10 +402,14 @@ def prepare_latents( W = width // self.vae_scale_factor_spatial shape = (B, C, T, H, W) - if num_frames_in == 0: - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if latents is not None: + if latents.shape[1:] != shape[1:]: + raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.") + latents = latents.to(device=device, dtype=dtype) + else: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if num_frames_in == 0: cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device) cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device) @@ -432,16 +439,12 @@ def prepare_latents( latents_std = self.latents_std.to(device=device, dtype=dtype) cond_latents = (cond_latents - latents_mean) / latents_std - if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - else: - latents = latents.to(device=device, dtype=dtype) - padding_shape = (B, 1, T, H, W) ones_padding = latents.new_ones(padding_shape) zeros_padding = latents.new_zeros(padding_shape) - cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1) + cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1) + cond_indicator[:, :, 
0:num_cond_latent_frames, :, :] = 1.0 cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding return ( @@ -451,33 +454,6 @@ def prepare_latents( cond_indicator, ) - def _encode_controls( - self, - controls: Optional[torch.Tensor], - height: int, - width: int, - num_frames: int, - dtype: torch.dtype, - device: torch.device, - generator: Optional[Union[torch.Generator, List[torch.Generator]]], - ) -> Optional[torch.Tensor]: - if controls is None: - return None - - control_video = self.video_processor.preprocess_video(controls, height, width) - control_video = _maybe_pad_video(control_video, num_frames) - - control_video = control_video.to(device=device, dtype=self.vae.dtype) - control_latents = [ - retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video - ] - control_latents = torch.cat(control_latents, dim=0).to(dtype) - - latents_mean = self.latents_mean.to(device=device, dtype=dtype) - latents_std = self.latents_std.to(device=device, dtype=dtype) - control_latents = (control_latents - latents_mean) / latents_std - return control_latents - # Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs def check_inputs( self, @@ -535,18 +511,19 @@ def __call__( self, image: PipelineImageInput | None = None, video: List[PipelineImageInput] | None = None, + controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None, + controls_conditioning_scale: Union[float, List[float]] = 1.0, prompt: Union[str, List[str]] | None = None, negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT, height: int = 704, width: Optional[int] = None, - num_frames: int = 93, + num_frames: Optional[int] = None, + num_frames_per_chunk: int = 93, num_inference_steps: int = 36, guidance_scale: float = 3.0, - num_videos_per_prompt: Optional[int] = 1, + num_videos_per_prompt: int = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None, - controls_conditioning_scale: Union[float, List[float]] = 1.0, prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", @@ -557,24 +534,40 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 512, conditional_frame_timestep: float = 0.1, + num_conditional_frames: Optional[int] = 1, + num_latent_conditional_frames: Optional[int] = None, ): r""" - The call function to the pipeline for generation. Supports three modes: - - - **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip. - - **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame. - - **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip. - - Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the - above in "*2Image mode"). - - Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt). + The call function can be used in two modes: with or without controls. + + When controls are not provided (`controls is None`), inference works in the same manner as predict2.5 (see + `Cosmos2_5_PredictPipeline`). 
This mode strictly uses the base transformer (`self.transformer`) to perform
+        inference and accepts as input an optional `image` or `video` along with a `prompt` / `negative_prompt`, and
+        can be used in the following ways:
+          - **Text2World**: `image=None`, `video=None`, `prompt` provided.
+          - **Image2World**: `image` provided, `video=None`, `prompt` provided.
+          - **Video2World**: `video` provided, `image=None`, `prompt` provided.
+
+        When `controls` are provided and a ControlNet is attached, `controls` drive the conditioning and `video` and
+        `image` are ignored. Controls are assumed to be pre-processed, e.g. edge maps are pre-computed.
+
+        `num_frames` caps the total number of output frames; if it is `None` (the default), the number of output
+        frames matches the input `video`, `image`, or `controls`. Auto-regressive inference is supported: a sliding
+        window of `num_frames_per_chunk` frames is denoised per loop, and the last `num_latent_conditional_frames`
+        (or `num_conditional_frames`) frames of each chunk condition the next denoising loop.

         Args:
-            image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
-                Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
-            video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
-                Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
+            image (`PipelineImageInput`, *optional*):
+                Image input to condition the first frame. The remaining frames (if any) are initialized from blanks.
+            video (`List[PipelineImageInput]`, *optional*):
+                Video input to condition the output. Only used when `controls` is `None`.
+            controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
+                Control image or video input used by the ControlNet. If `None`, ControlNet is skipped. If
+                `self.controlnet` is `None`, `controls` is ignored.
+            controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
+                The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
             height (`int`, defaults to `704`):
@@ -582,8 +575,9 @@ def __call__(
             width (`int`, *optional*):
                 The width in pixels of the generated image. If not provided, this will be determined based on the
                 aspect ratio of the input and the provided height.
-            num_frames (`int`, defaults to `93`):
-                Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
+            num_frames (`int`, *optional*):
+                Number of output frames. Defaults to `None` to output the same number of frames as the input `video`,
+                `image`, or `controls`.
-            num_inference_steps (`int`, defaults to `35`):
+            num_inference_steps (`int`, defaults to `36`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -598,13 +592,9 @@ def __call__(
                 A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                 generation deterministic.
             latents (`torch.Tensor`, *optional*):
-                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor is generated by sampling using the supplied random `generator`.
-            controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
-                Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
-            controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
-                The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
+                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
+                tweak the same generation with different prompts. If not provided, a latents tensor is generated by
+                sampling using the supplied random `generator`.
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -627,6 +617,10 @@ def __call__(
             max_sequence_length (`int`, defaults to `512`):
                 The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
                 the prompt is shorter than this length, it will be padded.
+            num_conditional_frames (`int`, *optional*, defaults to `1`):
+                Number of conditional frames in video space. Only used if `num_latent_conditional_frames` is `None`.
+            num_latent_conditional_frames (`int`, *optional*):
+                Number of conditional frames in latent space. If provided, `num_conditional_frames` is ignored.

         Examples:

@@ -660,6 +654,34 @@ def __call__(
             else:
                 width = int((height + 16) * (frame.shape[2] / frame.shape[1]))  # NOTE: assuming C H W

+        if num_latent_conditional_frames is not None and num_conditional_frames is not None:
+            if num_conditional_frames != 1:
+                raise ValueError(
+                    "Provide only one of `num_conditional_frames` or `num_latent_conditional_frames`, not both."
+                )
+            logger.warning(
+                "Both `num_conditional_frames` and `num_latent_conditional_frames` were provided. "
+                "Ignoring `num_conditional_frames`."
+            )
+            num_conditional_frames = None
+
+        if num_latent_conditional_frames is None and num_conditional_frames is None:
+            logger.warning(
+                "Neither `num_conditional_frames` nor `num_latent_conditional_frames` was provided. "
+                "Setting `num_conditional_frames` to 1."
+            )
+            num_conditional_frames = 1
+
+        if num_latent_conditional_frames is not None:
+            if num_latent_conditional_frames < 0:
+                raise ValueError("`num_latent_conditional_frames` must be >= 0.")
+            num_cond_latent_frames = num_latent_conditional_frames
+            num_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
+        else:
+            if num_conditional_frames < 0:
+                raise ValueError("`num_conditional_frames` must be >= 0.")
+            num_cond_latent_frames = max(0, (num_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
+
         # Check inputs.
Raise error if not correct self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs) @@ -706,133 +728,189 @@ def __call__( vae_dtype = self.vae.dtype transformer_dtype = self.transformer.dtype - img_context = torch.zeros( - batch_size, - self.transformer.config.img_context_num_tokens, - self.transformer.config.img_context_dim_in, - device=prompt_embeds.device, - dtype=transformer_dtype, - ) - encoder_hidden_states = (prompt_embeds, img_context) - neg_encoder_hidden_states = (negative_prompt_embeds, img_context) - - num_frames_in = None - if image is not None: - if batch_size != 1: - raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") - - image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) - video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0) - video = video.unsqueeze(0) - num_frames_in = 1 - elif video is None: - video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8) - num_frames_in = 0 - else: - num_frames_in = len(video) - - if batch_size != 1: - raise ValueError(f"batch_size must be 1 for video input (given {batch_size})") - - assert video is not None - video = self.video_processor.preprocess_video(video, height, width) - - # pad with last frame (for video2world) - num_frames_out = num_frames - video = _maybe_pad_video(video, num_frames_out) - assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" - - video = video.to(device=device, dtype=vae_dtype) - - num_channels_latents = self.transformer.config.in_channels - 1 - latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( - video=video, - batch_size=batch_size * num_videos_per_prompt, - num_channels_latents=num_channels_latents, - height=height, - width=width, - num_frames_in=num_frames_in, - num_frames_out=num_frames, - do_classifier_free_guidance=self.do_classifier_free_guidance, - dtype=torch.float32, - device=device, - generator=generator, - latents=latents, - ) - cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep - cond_mask = cond_mask.to(transformer_dtype) - - controls_latents = None - if controls is not None: - controls_latents = self._encode_controls( - controls, - height=height, - width=width, - num_frames=num_frames, + if getattr(self.transformer.config, "img_context_dim_in", None): + img_context = torch.zeros( + batch_size, + self.transformer.config.img_context_num_tokens, + self.transformer.config.img_context_dim_in, + device=prompt_embeds.device, dtype=transformer_dtype, - device=device, - generator=generator, ) - padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + if num_videos_per_prompt > 1: + img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0) - # Denoising loop - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - self._num_timesteps = len(timesteps) - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + encoder_hidden_states = (prompt_embeds, img_context) + neg_encoder_hidden_states = (negative_prompt_embeds, img_context) + else: + encoder_hidden_states = prompt_embeds + neg_encoder_hidden_states = negative_prompt_embeds - gt_velocity = (latents - cond_latent) * cond_mask - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue + if controls is not None and self.controlnet is None: + 
logger.warning("`controls` was provided but `controlnet` is None; ignoring `controls`.") + controls = None - self._current_timestep = t.cpu().item() + control_video = None + if controls is None: + num_frames_in = None + if image is not None: + if batch_size != 1: + raise ValueError(f"batch_size must be 1 for image input (given {batch_size})") + + image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0) + video = torch.cat([image, torch.zeros_like(image).repeat(num_frames_per_chunk - 1, 1, 1, 1)], dim=0) + video = video.unsqueeze(0) + num_frames_in = 1 + num_frames_out = num_frames + elif video is None: + video = torch.zeros(batch_size, num_frames_per_chunk, 3, height, width, dtype=torch.uint8) + num_frames_in = 0 + num_frames_out = num_frames + else: + num_frames_in = len(video) + num_frames_out = num_frames or num_frames_in + + if num_frames_out is None or num_frames_out == 0: + num_frames_out = num_frames_per_chunk + + if num_frames is not None: + num_frames_out = min(num_frames_out, num_frames) - # NOTE: assumes sigma(t) \in [0, 1] - sigma_t = ( - torch.tensor(self.scheduler.sigmas[i].item()) - .unsqueeze(0) - .to(device=device, dtype=transformer_dtype) + assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})" + + video = self.video_processor.preprocess_video(video, height, width) + video = _maybe_pad_or_trim_video(video, num_frames_out) + else: + num_frames_out = len(controls) + if num_frames is not None: + num_frames_out = min(num_frames_out, num_frames) + control_video = self.video_processor.preprocess_video(controls, height, width) + + min_chunk_len = 1 if num_frames_out <= 1 else self.vae_scale_factor_temporal + 1 + if num_frames_per_chunk < min_chunk_len: + logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len") + num_frames_per_chunk = min_chunk_len + + max_frames_by_rope = None + if getattr(self.transformer.config, "max_size", None) is not None: + max_frames_by_rope = max( + size // patch + for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size) + ) + if num_frames_per_chunk > max_frames_by_rope: + logger.warning( + f"{num_frames_per_chunk=} is too large for RoPE setting to maximum ({max_frames_by_rope=})" ) + num_frames_per_chunk = max_frames_by_rope - in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents - in_latents = in_latents.to(transformer_dtype) - in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t - control_blocks = None - if controls_latents is not None and self.controlnet is not None: - control_output = self.controlnet( - controls_latents=controls_latents, - latents=in_latents, - timestep=in_timestep, - encoder_hidden_states=encoder_hidden_states, - condition_mask=cond_mask, - conditioning_scale=controls_conditioning_scale, - padding_mask=padding_mask, - return_dict=False, + if num_conditional_frames >= num_frames_per_chunk: + raise ValueError( + f"{num_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation." 
+ ) + + # chunk information + num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1 + chunk_stride = num_frames_per_chunk - num_conditional_frames + chunk_idxs = [ + (start_idx, min(start_idx + num_frames_per_chunk, num_frames_out)) + for start_idx in range(0, num_frames_out - num_conditional_frames, chunk_stride) + ] + + video_chunks = [] + latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device) + latents_std = self.latents_std.to(dtype=vae_dtype, device=device) + + def decode_latents(latents): + latents = latents * latents_std + latents_mean + video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0] + return video + + latents_arg = latents + initial_num_cond_latent_frames = 0 if video is None or controls is not None else num_cond_latent_frames + latent_chunks = [] + num_chunks = len(chunk_idxs) + total_steps = num_inference_steps * num_chunks + with self.progress_bar(total=total_steps) as progress_bar: + for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs): + if chunk_idx == 0 and controls is not None: + prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype) + prev_output = self.video_processor.preprocess_video(prev_output, height, width) + else: + prev_output = video_chunks[-1].clone() if chunk_idx != 0 else video.clone() + if num_conditional_frames > 0: + prev_output[:, :, :num_conditional_frames] = prev_output[:, :, -num_conditional_frames:] + prev_output[:, :, num_conditional_frames:] = -1 # -1 == 0 in processed video space + else: + prev_output.fill_(-1) + + chunk_video = prev_output.to(device=device, dtype=vae_dtype) + chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk) + latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents( + video=chunk_video, + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=self.transformer.config.in_channels - 1, + height=height, + width=width, + num_frames_in=chunk_video.shape[2], + num_frames_out=num_frames_per_chunk, + do_classifier_free_guidance=self.do_classifier_free_guidance, + dtype=torch.float32, + device=device, + generator=generator, + num_cond_latent_frames=initial_num_cond_latent_frames + if chunk_idx == 0 + else num_cond_latent_frames, + latents=latents_arg, + ) + cond_mask = cond_mask.to(transformer_dtype) + cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep + padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype) + + controls_latents = None + if controls is not None: + chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to( + device=device, dtype=self.vae.dtype + ) + chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk) + controls_latents = [ + retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) + for vid in chunk_control_video + ] + controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype) + + latents_mean = self.latents_mean.to(device=device, dtype=transformer_dtype) + latents_std = self.latents_std.to(device=device, dtype=transformer_dtype) + controls_latents = (controls_latents - latents_mean) / latents_std + + # Denoising loop + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + self._num_timesteps = len(timesteps) + + gt_velocity = (latents - cond_latent) * cond_mask + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + 
self._current_timestep = t.cpu().item() + + # NOTE: assumes sigma(t) \in [0, 1] + sigma_t = ( + torch.tensor(self.scheduler.sigmas[i].item()) + .unsqueeze(0) + .to(device=device, dtype=transformer_dtype) ) - control_blocks = control_output[0] - - noise_pred = self.transformer( - hidden_states=in_latents, - timestep=in_timestep, - encoder_hidden_states=encoder_hidden_states, - block_controlnet_hidden_states=control_blocks, - condition_mask=cond_mask, - padding_mask=padding_mask, - return_dict=False, - )[0] - noise_pred = gt_velocity + noise_pred * (1 - cond_mask) - - if self.do_classifier_free_guidance: + + in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents + in_latents = in_latents.to(transformer_dtype) + in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t control_blocks = None if controls_latents is not None and self.controlnet is not None: control_output = self.controlnet( controls_latents=controls_latents, latents=in_latents, timestep=in_timestep, - encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + encoder_hidden_states=encoder_hidden_states, condition_mask=cond_mask, conditioning_scale=controls_conditioning_scale, padding_mask=padding_mask, @@ -840,46 +918,76 @@ def __call__( ) control_blocks = control_output[0] - noise_pred_neg = self.transformer( + noise_pred = self.transformer( hidden_states=in_latents, timestep=in_timestep, - encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + encoder_hidden_states=encoder_hidden_states, block_controlnet_hidden_states=control_blocks, condition_mask=cond_mask, padding_mask=padding_mask, return_dict=False, )[0] - # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only - noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) - noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + noise_pred = gt_velocity + noise_pred * (1 - cond_mask) + + if self.do_classifier_free_guidance: + control_blocks = None + if controls_latents is not None and self.controlnet is not None: + control_output = self.controlnet( + controls_latents=controls_latents, + latents=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + condition_mask=cond_mask, + conditioning_scale=controls_conditioning_scale, + padding_mask=padding_mask, + return_dict=False, + ) + control_blocks = control_output[0] + + noise_pred_neg = self.transformer( + hidden_states=in_latents, + timestep=in_timestep, + encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt + block_controlnet_hidden_states=control_blocks, + condition_mask=cond_mask, + padding_mask=padding_mask, + return_dict=False, + )[0] + # NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only + noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask) + noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg) + + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] - latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, 
i, t, callback_kwargs)
+                        latents = callback_outputs.pop("latents", latents)
+                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

-                latents = callback_outputs.pop("latents", latents)
-                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    # call the callback, if provided
+                    if i == len(timesteps) - 1 or ((i + 1) % self.scheduler.order == 0):
+                        progress_bar.update()

-                # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                    progress_bar.update()
+                    if XLA_AVAILABLE:
+                        xm.mark_step()

-                if XLA_AVAILABLE:
-                    xm.mark_step()
+                video_chunks.append(decode_latents(latents).detach().cpu())
+                latent_chunks.append(latents.detach().cpu())

         self._current_timestep = None

         if not output_type == "latent":
-            latents_mean = self.latents_mean.to(latents.device, latents.dtype)
-            latents_std = self.latents_std.to(latents.device, latents.dtype)
-            latents = latents * latents_std + latents_mean
-            video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
-            video = self._match_num_frames(video, num_frames)
+            video_chunks = [
+                chunk[:, :, num_conditional_frames:, ...] if chunk_idx != 0 else chunk
+                for chunk_idx, chunk in enumerate(video_chunks)
+            ]
+            video = torch.cat(video_chunks, dim=2)
+            video = video[:, :, :num_frames_out, ...]

             assert self.safety_checker is not None
             self.safety_checker.to(device)
@@ -896,7 +1004,13 @@ def __call__(
             video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
-            video = latents
+            latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
+            latent_chunks = [
+                chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
+                for chunk_idx, chunk in enumerate(latent_chunks)
+            ]
+            video = torch.cat(latent_chunks, dim=2)
+            video = video[:, :, :latent_T, ...]
# Offload all models self.maybe_free_model_hooks() @@ -905,19 +1019,3 @@ def __call__( return (video,) return CosmosPipelineOutput(frames=video) - - def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor: - if target_num_frames <= 0 or video.shape[2] == target_num_frames: - return video - - frames_per_latent = max(self.vae_scale_factor_temporal, 1) - video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2) - - current_frames = video.shape[2] - if current_frames < target_num_frames: - pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1) - video = torch.cat([video, pad], dim=2) - elif current_frames > target_num_frames: - video = video[:, :, :target_num_frames] - - return video diff --git a/tests/models/controlnets/test_models_controlnet_cosmos.py b/tests/models/controlnets/test_models_controlnet_cosmos.py index bf879b11663b..e72c93b4d5e2 100644 --- a/tests/models/controlnets/test_models_controlnet_cosmos.py +++ b/tests/models/controlnets/test_models_controlnet_cosmos.py @@ -131,6 +131,26 @@ def test_output_list_format(self): self.assertIsInstance(output[0], list) self.assertEqual(len(output[0]), init_dict["n_controlnet_blocks"]) + def test_condition_mask_changes_output(self): + """Test that condition mask affects control outputs.""" + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + inputs_no_mask = dict(inputs_dict) + inputs_no_mask["condition_mask"] = torch.zeros_like(inputs_dict["condition_mask"]) + + with torch.no_grad(): + output_no_mask = model(**inputs_no_mask) + output_with_mask = model(**inputs_dict) + + self.assertEqual(len(output_no_mask.control_block_samples), len(output_with_mask.control_block_samples)) + for no_mask_tensor, with_mask_tensor in zip( + output_no_mask.control_block_samples, output_with_mask.control_block_samples + ): + self.assertFalse(torch.equal(no_mask_tensor, with_mask_tensor)) + def test_conditioning_scale_single(self): """Test that a single conditioning scale is broadcast to all blocks.""" init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/pipelines/cosmos/test_cosmos2_5_transfer.py b/tests/pipelines/cosmos/test_cosmos2_5_transfer.py index 932443bceea2..42757ed2ad2d 100644 --- a/tests/pipelines/cosmos/test_cosmos2_5_transfer.py +++ b/tests/pipelines/cosmos/test_cosmos2_5_transfer.py @@ -212,6 +212,51 @@ def test_inference(self): self.assertEqual(generated_video.shape, (3, 3, 32, 32)) self.assertTrue(torch.isfinite(generated_video).all()) + def test_inference_autoregressive_multi_chunk(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_frames"] = 5 + inputs["num_frames_per_chunk"] = 3 + inputs["num_conditional_frames"] = 1 + + video = pipe(**inputs).frames + generated_video = video[0] + self.assertEqual(generated_video.shape, (5, 3, 32, 32)) + self.assertTrue(torch.isfinite(generated_video).all()) + + def test_inference_autoregressive_multi_chunk_no_condition_video_ignored(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + inputs["num_frames"] = 5 + inputs["num_frames_per_chunk"] = 3 + 
inputs["num_conditional_frames"] = 0 + + video_a = torch.zeros(5, 3, 32, 32) + video_b = torch.ones(5, 3, 32, 32) + + inputs["video"] = video_a + inputs["generator"] = torch.Generator(device=device).manual_seed(0) + out_a = pipe(**inputs).frames + + inputs["video"] = video_b + inputs["generator"] = torch.Generator(device=device).manual_seed(0) + out_b = pipe(**inputs).frames + + # output should be independent of input video + self.assertTrue(torch.allclose(out_a, out_b)) + def test_inference_with_controls(self): """Test inference with control inputs (ControlNet).""" device = "cpu" @@ -222,13 +267,13 @@ def test_inference_with_controls(self): pipe.set_progress_bar_config(disable=None) inputs = self.get_dummy_inputs(device) - # Add control video input - should be a video tensor - inputs["controls"] = [torch.randn(3, 3, 32, 32)] # num_frames, channels, height, width + inputs["controls"] = [torch.randn(3, 32, 32) for _ in range(5)] # list of 5 frames (C, H, W) inputs["controls_conditioning_scale"] = 1.0 + inputs["num_frames"] = None video = pipe(**inputs).frames generated_video = video[0] - self.assertEqual(generated_video.shape, (3, 3, 32, 32)) + self.assertEqual(generated_video.shape, (5, 3, 32, 32)) self.assertTrue(torch.isfinite(generated_video).all()) def test_callback_inputs(self):