Skip to content

Commit c5aaa51

Browse files
scheduler cleanup
1 parent f6519ec commit c5aaa51

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos25_predict.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,7 @@ def __call__(
506506
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
507507
max_sequence_length: int = 512,
508508
shift: float = 5.0,
509+
timestep_scale: float = 0.001,
509510
conditional_frame_timestep: float = 0.1,
510511
):
511512
r"""
@@ -635,14 +636,13 @@ def __call__(
635636
max_sequence_length=max_sequence_length,
636637
)
637638

638-
# 4. Prepare timesteps
639-
self.scheduler.set_timesteps(num_inference_steps, shift=shift, device=device)
640-
timesteps = self.scheduler.timesteps
641-
642-
# 5. Prepare latent variables
643639
vae_dtype = self.vae.dtype
644640
transformer_dtype = self.transformer.dtype
645641

642+
# 4. Prepare timesteps
643+
self.scheduler.set_timesteps(num_inference_steps, shift=shift, device=device, scale=timestep_scale)
644+
timesteps = torch.tensor(self.scheduler.timesteps).to(transformer_dtype)
645+
646646
num_frames_in = None
647647
if image is not None:
648648
# TODO: handle batch_size > 1
@@ -700,12 +700,8 @@ def __call__(
700700
if self.interrupt:
701701
continue
702702

703-
self._current_timestep = t
704-
705-
timestep = torch.stack([t]).to(torch.float32)
706-
# TODO: make scheduler scale this instead
707-
timestep *= 0.001 # NOTE: timestep scale
708-
timestep = timestep.to(transformer_dtype)
703+
self._current_timestep = t.cpu().item()
704+
timestep = t.unsqueeze(0)
709705

710706
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents # TODO: could use cond_indicator
711707
in_latents = in_latents.to(transformer_dtype)

src/diffusers/schedulers/scheduling_flow_unipc_multistep.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def set_timesteps(
157157
sigmas: Optional[List[float]] = None,
158158
mu: Optional[Union[float, None]] = None,
159159
shift: Optional[Union[float, None]] = None,
160+
scale: float = 0.001,
160161
):
161162
"""
162163
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -191,10 +192,7 @@ def set_timesteps(
191192
shift = self.config.shift
192193
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas) # pyright: ignore
193194

194-
if self.config.final_sigmas_type == "sigma_min":
195-
# TODO(migmartin): this raises an error, rewrite this class
196-
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
197-
elif self.config.final_sigmas_type == "zero":
195+
if self.config.final_sigmas_type == "zero":
198196
sigma_last = 0
199197
else:
200198
raise ValueError(
@@ -205,7 +203,8 @@ def set_timesteps(
205203
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32) # pyright: ignore
206204

207205
self.sigmas = torch.from_numpy(sigmas)
208-
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
206+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.float32)
207+
self.timesteps *= scale
209208

210209
self.num_inference_steps = len(timesteps)
211210

0 commit comments

Comments
 (0)