Skip to content

Commit 53e425d

Browse files
Basic tests for flow unipc
1 parent a17c98d commit 53e425d

File tree

4 files changed

+454
-109
lines changed

4 files changed

+454
-109
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos25_predict.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def retrieve_latents(
118118
... ).frames[0]
119119
>>> export_to_video(video, "video2world.mp4", fps=16)
120120
121-
>>> # To produce a single-frame image instead of a world clip, set num_frames=1 and
121+
>>> # To produce a single-frame image instead of a world (video) clip, set num_frames=1 and
122122
>>> # save the first frame: pipe(..., num_frames=1).frames[0][0].
123123
```
124124
"""
@@ -201,7 +201,6 @@ def _get_prompt_embeds(
201201
dtype = dtype or self.text_encoder.dtype
202202
prompt = [prompt] if isinstance(prompt, str) else prompt
203203

204-
# Tokenize prompts
205204
input_ids_batch = []
206205

207206
for sample_idx in range(len(prompt)):
@@ -257,7 +256,7 @@ def _get_prompt_embeds(
257256

258257
return prompt_embeds
259258

260-
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt with num_videos_per_prompt->num_videos_per_prompt
259+
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.encode_prompt
261260
def encode_prompt(
262261
self,
263262
prompt: Union[str, List[str]],
@@ -505,7 +504,6 @@ def __call__(
505504
] = None,
506505
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
507506
max_sequence_length: int = 512,
508-
shift: float = 5.0,
509507
conditional_frame_timestep: float = 0.1,
510508
):
511509
r"""
@@ -656,13 +654,16 @@ def __call__(
656654
video = self.video_processor.preprocess_video(video, height, width)
657655

658656
# pad with last frame (for video2world)
659-
if video.shape[2] < num_frames:
657+
num_frames_out = num_frames
658+
if video.shape[2] < num_frames_out:
660659
assert batch_size == 1, "batch_size must be 1 for padding frames"
661-
n_pad_frames = num_frames - num_frames_in
660+
n_pad_frames = num_frames_out - num_frames_in
662661
last_frame = video[0, :, -1:, :, :] # [C, T==1, H, W]
663662
pad_frames = last_frame.repeat(1, 1, n_pad_frames, 1, 1) # [B, C, T, H, W]
664663
video = torch.cat((video, pad_frames), dim=2)
665664

665+
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
666+
666667
video = video.to(device=device, dtype=vae_dtype)
667668

668669
num_channels_latents = self.transformer.config.in_channels - 1
@@ -686,7 +687,7 @@ def __call__(
686687
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
687688

688689
# Denoising loop
689-
self.scheduler.set_timesteps(num_inference_steps, shift=shift, device=device)
690+
self.scheduler.set_timesteps(num_inference_steps, device=device)
690691
timesteps = self.scheduler.timesteps
691692
self._num_timesteps = len(timesteps)
692693
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -699,10 +700,10 @@ def __call__(
699700

700701
self._current_timestep = t.cpu().item()
701702

702-
# NOTE: sigmas are in [0, 1] in FlowUniPCMultistepScheduler
703-
sigma_t = torch.tensor(self.scheduler.sigmas[i]).unsqueeze(0).to(device=device, dtype=transformer_dtype)
703+
# NOTE: assumes sigma(t) \in [0, 1]
704+
sigma_t = torch.tensor(self.scheduler.sigmas[i].item()).unsqueeze(0).to(device=device, dtype=transformer_dtype)
704705

705-
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents # TODO: could use cond_indicator
706+
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
706707
in_latents = in_latents.to(transformer_dtype)
707708
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
708709
noise_pred = self.transformer(
@@ -725,7 +726,7 @@ def __call__(
725726
padding_mask=padding_mask,
726727
return_dict=False,
727728
)[0]
728-
# NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
729+
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
729730
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
730731
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
731732

src/diffusers/schedulers/scheduling_flow_unipc_multistep.py

Lines changed: 37 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2-
3-
import math
4-
from typing import List, Optional, Tuple, Union
1+
from typing import List, Literal, Optional, Tuple, Union
52

63
import numpy as np
74
import torch
@@ -11,35 +8,38 @@
118
from diffusers.utils import deprecate
129

1310

14-
def _get_karras_sigmas(self, num_steps: int, sigma_max: float, sigma_min: float, rho: int, final_sigmas_type: str):
11+
def _get_karras_sigmas(num_train_steps: int, num_steps: int, sigma_max: float, sigma_min: float, rho: int, final_sigmas_type: str):
1512
sigmas = np.arange(num_steps + 1, dtype=np.float32) / num_steps
1613
min_inv_rho = sigma_min ** (1 / rho)
1714
max_inv_rho = sigma_max ** (1 / rho)
1815
sigmas = (max_inv_rho + sigmas * (min_inv_rho - max_inv_rho)) ** rho
1916
sigmas = sigmas / (1 + sigmas)
2017

21-
if self.config.final_sigmas_type == "zero":
18+
if final_sigmas_type == "zero":
2219
sigma_last = 0
20+
elif final_sigmas_type == "sigma_min":
21+
sigma_last = sigmas[-1]
2322
else:
2423
raise ValueError(
25-
f"`final_sigmas_type` must be 'zero' but got {self.config.final_sigmas_type}"
24+
f"`final_sigmas_type` must be 'zero' or 'sigma_min' but got {final_sigmas_type}"
2625
)
2726

28-
timesteps = torch.from_numpy(sigmas * self.config.num_train_timesteps).to(torch.int64)
29-
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
27+
timesteps = torch.from_numpy(sigmas * num_train_steps).to(torch.int64)
28+
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
3029
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
3130
return sigmas, timesteps
3231

3332

3433
class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
3534
"""
36-
`FlowUniPCMultistepScheduler` is the UniPC algorithm [1] for flow matching [2], but strictly uses the Karras sigmas [3].
35+
`FlowUniPCMultistepScheduler` is the UniPC algorithm[1] for flow
36+
matching[2], but strictly uses the Karras sigmas [3] (i.e. it follows the EDMEulerScheduler).
37+
38+
Note this is a simplified version of `UniPCMultistepScheduler`, as it:
39+
1. Does not have variance preserving sigmas
40+
2. Does not store betas and other variables used by `UniPCMultistepScheduler`
41+
3. Assumes prediction_type == "flow_prediction" (this parameter is removed)
3742
38-
Note this a simplified version of `UniPCMultistepScheduler`, as:
39-
1. it does not have variance preserving sigmas
40-
2. it does not store betas and other variables used by `UniPCMultistepScheduler`
41-
3. it assumes prediction_type == "flow_prediction" (this variable is removed from `FlowUniPCMultistepScheduler`)
42-
4343
References:
4444
[1] Wang, Chong, et al. "UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models" https://arxiv.org/abs/2302.04867
4545
[2] Lipman, Chen, et al. "Flow matching for generative modeling." https://arxiv.org/abs/2210.02747
@@ -97,7 +97,7 @@ def __init__(
9797
lower_order_final: bool = True,
9898
disable_corrector: List[int] = [],
9999
solver_p: SchedulerMixin = None,
100-
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
100+
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
101101
rho: int = 7,
102102
sigma_max: float = 200.0,
103103
sigma_min: float = 0.01,
@@ -109,21 +109,15 @@ def __init__(
109109
raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
110110

111111
self.predict_x0 = predict_x0
112-
self.num_inference_steps = None
113112
self.disable_corrector = disable_corrector
113+
self.solver_p = solver_p
114+
self.num_inference_steps = None
114115

115-
self.sigmas, self.timesteps = _get_karras_sigmas(self, num_train_timesteps, sigma_max, sigma_min, rho, final_sigmas_type)
116+
self.sigmas, self.timesteps = _get_karras_sigmas(num_train_timesteps, num_train_timesteps, sigma_max, sigma_min, rho, final_sigmas_type)
116117
self.sigma_min = self.sigmas[-1].item()
117118
self.sigma_max = self.sigmas[0].item()
118119

119-
self.last_sample = None
120-
self._step_index = None
121-
self._begin_index = None
122-
self.model_outputs = [None] * self.config.solver_order
123-
self.timestep_list = [None] * self.config.solver_order
124-
self.lower_order_nums = 0
125-
self.solver_p = self.config.solver_p
126-
120+
self._reset_state(solver_order)
127121

128122
@property
129123
def step_index(self):
@@ -169,22 +163,29 @@ def set_timesteps(
169163
"""
170164
assert sigmas is None, "sigmas are not supported for FlowUniPCMultistepScheduler"
171165

172-
self.sigmas, self.timesteps = _get_karras_sigmas(self, num_inference_steps, self.config.sigma_max, self.config.sigma_min, self.config.rho, self.config.final_sigmas_type)
166+
self.sigmas, self.timesteps = _get_karras_sigmas(self.config.num_train_timesteps, num_inference_steps, self.config.sigma_max, self.config.sigma_min, self.config.rho, self.config.final_sigmas_type)
173167
self.num_inference_steps = len(self.timesteps)
174168

175169
self.sigma_min = self.sigmas[-1].item()
176170
self.sigma_max = self.sigmas[0].item()
177171

172+
self.sigmas = self.sigmas.to(device)
173+
self.timesteps = self.timesteps.to(device)
174+
self._reset_state()
175+
176+
def _reset_state(self, solver_order: Optional[int] = None):
177+
"""
178+
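Resets the solver state variables (model-output and timestep history, lower-order counter, last sample, step/begin indices) and, when `solver_p` is set, re-initializes its timesteps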
179+
"""
180+
solver_order = solver_order or self.config.solver_order
181+
self.model_outputs = [None] * solver_order
182+
self.timestep_list = [None] * solver_order
183+
self.lower_order_nums = 0
178184
self.last_sample = None
179185
self._step_index = None
180186
self._begin_index = None
181-
self.model_outputs = [None] * self.config.solver_order
182-
self.timestep_list = [None] * self.config.solver_order
183-
self.lower_order_nums = 0
184-
self.solver_p = self.config.solver_p
185-
186-
self.sigmas = self.sigmas.to(device)
187-
self.timesteps = self.timesteps.to(device)
187+
if self.solver_p:
188+
self.solver_p.set_timesteps(self.num_inference_steps, device=device)
188189

189190
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
190191
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
@@ -575,6 +576,7 @@ def _init_step_index(self, timestep):
575576
else:
576577
self._step_index = self._begin_index
577578

579+
# Modified from diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler.step
578580
def step(
579581
self,
580582
model_output: torch.Tensor,
@@ -638,7 +640,7 @@ def step(
638640
this_order = self.config.solver_order
639641

640642
self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
641-
assert self.this_order > 0
643+
assert self.this_order > 0, "expected this_order > 0, this could be due to duplicate timesteps"
642644

643645
self.last_sample = sample
644646
prev_sample = self.multistep_uni_p_bh_update(

0 commit comments

Comments
 (0)