Commit df085fa

simplify inference pipeline
1 parent c5aaa51

2 files changed: 21 additions & 23 deletions

src/diffusers/pipelines/cosmos/pipeline_cosmos25_predict.py

Lines changed: 16 additions & 17 deletions
@@ -506,7 +506,6 @@ def __call__(
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
         shift: float = 5.0,
-        timestep_scale: float = 0.001,
         conditional_frame_timestep: float = 0.1,
     ):
         r"""
@@ -592,7 +591,7 @@ def __call__(
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

-        # 1. Check inputs. Raise error if not correct
+        # Check inputs. Raise error if not correct
         self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)

         self._guidance_scale = guidance_scale
@@ -613,15 +612,15 @@ def __call__(
             )
             self.safety_checker.to("cpu")

-        # 2. Define call parameters
+        # Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
         else:
             batch_size = prompt_embeds.shape[0]

-        # 3. Encode input prompt
+        # Encode input prompt
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -639,10 +638,6 @@ def __call__(
         vae_dtype = self.vae.dtype
         transformer_dtype = self.transformer.dtype

-        # 4. Prepare timesteps
-        self.scheduler.set_timesteps(num_inference_steps, shift=shift, device=device, scale=timestep_scale)
-        timesteps = torch.tensor(self.scheduler.timesteps).to(transformer_dtype)
-
         num_frames_in = None
         if image is not None:
             # TODO: handle batch_size > 1
@@ -690,22 +685,26 @@ def __call__(

         padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)

-        # 6. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        # Denoising loop
+        self.scheduler.set_timesteps(num_inference_steps, shift=shift, device=device)
+        timesteps = self.scheduler.timesteps
         self._num_timesteps = len(timesteps)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

-        gt_velocity = latents - cond_latent
+        gt_velocity = (latents - cond_latent) * cond_mask
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t.cpu().item()
-                timestep = t.unsqueeze(0)
+
+                # NOTE: equivalent to t / 1000 for FlowUniPCMultistepScheduler (sigmas are in [0, 1], num_train_timesteps=1000)
+                sigma_t = torch.tensor(self.scheduler.sigmas[i]).unsqueeze(0).to(device=device, dtype=transformer_dtype)

                 in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents  # TODO: could use cond_indicator
                 in_latents = in_latents.to(transformer_dtype)
-                in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * timestep
+                in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
                 noise_pred = self.transformer(
                     hidden_states=in_latents,
                     condition_mask=cond_mask,
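
Note: the NOTE comment on sigma_t relies on FlowUniPCMultistepScheduler building its timesteps as sigmas * num_train_timesteps (with num_train_timesteps = 1000), so sigmas[i] equals timesteps[i] / 1000. A minimal sketch of that relation, using an illustrative linear schedule rather than the scheduler's actual shifted one:

    import numpy as np

    # Illustrative schedule only: the real scheduler applies a shift
    # transform, but the timesteps/sigmas relation is the same either way.
    num_train_timesteps = 1000
    sigmas = np.linspace(1.0, 1.0 / num_train_timesteps, num=8)
    timesteps = sigmas * num_train_timesteps  # what the denoising loop iterates over

    for i, t in enumerate(timesteps):
        assert np.isclose(sigmas[i], t / num_train_timesteps)  # sigma_t == t / 1000

This is what makes the removed timestep_scale=0.001 argument redundant: scaling t by 0.001 and reading self.scheduler.sigmas[i] produce the same conditioning value.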
@@ -714,8 +713,8 @@ def __call__(
                     padding_mask=padding_mask,
                     return_dict=False,
                 )[0]
-                # NOTE: force input video latents for noise_pred by correcting velocity
-                noise_pred = gt_velocity * cond_mask + noise_pred * (1 - cond_mask)
+                # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
+                noise_pred = gt_velocity + noise_pred * (1 - cond_mask)

                 if self.do_classifier_free_guidance:
                     noise_pred_neg = self.transformer(
@@ -726,8 +725,8 @@ def __call__(
                         padding_mask=padding_mask,
                         return_dict=False,
                     )[0]
-                    # NOTE: force input video latents for noise_pred by correcting velocity
-                    noise_pred_neg = gt_velocity * cond_mask + noise_pred_neg * (1 - cond_mask)
+                    # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
+                    noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
                 noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)

                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
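
Note: the gt_velocity change is algebraically a no-op; multiplying by cond_mask once before the loop replaces the per-step gt_velocity * cond_mask, which was previously computed twice per step under classifier-free guidance. A self-contained check, with hypothetical stand-in shapes for the video latents:

    import torch

    # Old form masked gt_velocity at every step; new form masks it once up front.
    torch.manual_seed(0)
    latents = torch.randn(1, 16, 8, 4, 4)
    cond_latent = torch.randn(1, 16, 8, 4, 4)
    cond_mask = (torch.rand(1, 1, 8, 1, 1) > 0.5).float()  # 1 on conditioning frames
    noise_pred = torch.randn(1, 16, 8, 4, 4)

    old = (latents - cond_latent) * cond_mask + noise_pred * (1 - cond_mask)

    gt_velocity = (latents - cond_latent) * cond_mask  # hoisted out of the loop
    new = gt_velocity + noise_pred * (1 - cond_mask)

    assert torch.allclose(old, new)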

src/diffusers/schedulers/scheduling_flow_unipc_multistep.py

Lines changed: 5 additions & 6 deletions
@@ -157,7 +157,6 @@ def set_timesteps(
         sigmas: Optional[List[float]] = None,
         mu: Optional[Union[float, None]] = None,
         shift: Optional[Union[float, None]] = None,
-        scale: float = 0.001,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -192,7 +191,10 @@ def set_timesteps(
                 shift = self.config.shift
             sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)  # pyright: ignore

-        if self.config.final_sigmas_type == "zero":
+        if self.config.final_sigmas_type == "sigma_min":
+            # TODO(migmartin): this raises an error, rewrite this class
+            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
+        elif self.config.final_sigmas_type == "zero":
             sigma_last = 0
         else:
             raise ValueError(
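
Note: the shift transform in the surrounding context warps the sigma schedule toward high noise while leaving the endpoints 0 and 1 fixed. A small numeric sketch:

    import numpy as np

    # Time-shift applied to sigmas: for shift > 1, values move toward 1,
    # so more of the inference steps are spent at high noise levels.
    def shift_sigmas(sigmas, shift=5.0):
        return shift * sigmas / (1 + (shift - 1) * sigmas)

    print(shift_sigmas(np.array([1.0, 0.75, 0.5, 0.25, 0.0])))
    # -> [1.     0.9375 0.8333 0.625  0.    ]

The new "sigma_min" branch mirrors other UniPC schedulers, but as the TODO says it currently raises, presumably because this flow-matching scheduler does not define self.alphas_cumprod.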
@@ -203,8 +205,7 @@ def set_timesteps(
         sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)  # pyright: ignore

         self.sigmas = torch.from_numpy(sigmas)
-        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
-        self.timesteps *= scale
+        self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)

         self.num_inference_steps = len(timesteps)
@@ -304,7 +305,6 @@ def convert_model_output(
         sigma = self.sigmas[self.step_index]
         alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

-        # print("sigma_t ==>", self.step_index, sigma, sigma_t, alpha_t, sample.shape, model_output.shape)
         if self.predict_x0:
             if self.config.prediction_type == "flow_prediction":
                 sigma_t = self.sigmas[self.step_index]
@@ -317,7 +317,6 @@ def convert_model_output(

         if self.config.thresholding:
             x0_pred = self._threshold_sample(x0_pred)
-            # print("self.config.thresholding", self.config.thresholding)
             return x0_pred
         else:
             if self.config.prediction_type == "flow_prediction":
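
Note: with the scale argument gone, the scheduler exposes ordinary integer timesteps again, and per-step noise levels come from sigmas, which is exactly how the pipeline above now consumes it. A usage sketch, assuming the scheduler can be constructed with its default config (num_train_timesteps=1000, final_sigmas_type="zero"):

    from diffusers.schedulers.scheduling_flow_unipc_multistep import FlowUniPCMultistepScheduler

    scheduler = FlowUniPCMultistepScheduler()  # assumed default config
    scheduler.set_timesteps(num_inference_steps=8, shift=5.0, device="cpu")

    print(scheduler.timesteps.dtype)  # torch.int64 after this change
    for i, t in enumerate(scheduler.timesteps):
        sigma_t = scheduler.sigmas[i]  # noise level in [0, 1] fed to the transformer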
