From 1f7b3f171abe0f150ac49b5f25ad56bbc2ecab30 Mon Sep 17 00:00:00 2001
From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com>
Date: Thu, 12 Feb 2026 00:08:25 +0530
Subject: [PATCH] Fix CogVideoX scheduler prev_timestep for non-leading spacing

When timestep_spacing is set to 'linspace' or 'trailing', the formula
`timestep - num_train_timesteps // num_inference_steps` does not produce the
correct previous timestep because timesteps are not uniformly spaced by that
stride. This causes wrong alpha_prod_t_prev lookups and produces incorrect
denoising results.

Replace the hardcoded arithmetic with a `previous_timestep()` method that
looks up the actual next entry in `self.timesteps`, matching the approach
already used in DDPMScheduler.
---
 .../schedulers/scheduling_ddim_cogvideox.py   | 21 ++++++++++++++++++-
 .../schedulers/scheduling_dpm_cogvideox.py    | 21 ++++++++++++++++++-
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
index 1a77a652786d..2054cc291ec7 100644
--- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -383,7 +383,7 @@ def step(
         # - pred_prev_sample -> "x_t-1"
 
         # 1. get previous step value (=t-1)
-        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+        prev_timestep = self.previous_timestep(timestep)
 
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -500,5 +500,24 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         return velocity
 
+    def previous_timestep(self, timestep: int) -> torch.Tensor:
+        """
+        Find the previous timestep in the scheduler's timestep schedule.
+
+        Args:
+            timestep (`int`):
+                The current timestep.
+
+        Returns:
+            `torch.Tensor`:
+                The previous timestep. Returns -1 if the current timestep is the last one.
+        """
+        index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+        if index == self.timesteps.shape[0] - 1:
+            prev_t = torch.tensor(-1)
+        else:
+            prev_t = self.timesteps[index + 1]
+        return prev_t
+
     def __len__(self) -> int:
         return self.config.num_train_timesteps
diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
index 0c576467a19a..b939fe9a120e 100644
--- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
+++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py
@@ -467,7 +467,7 @@ def step(
         # - pred_prev_sample -> "x_t-1"
 
         # 1. get previous step value (=t-1)
-        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+        prev_timestep = self.previous_timestep(timestep)
 
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -599,5 +599,24 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor
         velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         return velocity
 
+    def previous_timestep(self, timestep: int) -> torch.Tensor:
+        """
+        Find the previous timestep in the scheduler's timestep schedule.
+
+        Args:
+            timestep (`int`):
+                The current timestep.
+
+        Returns:
+            `torch.Tensor`:
+                The previous timestep. Returns -1 if the current timestep is the last one.
+        """
+        index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
+        if index == self.timesteps.shape[0] - 1:
+            prev_t = torch.tensor(-1)
+        else:
+            prev_t = self.timesteps[index + 1]
+        return prev_t
+
     def __len__(self) -> int:
         return self.config.num_train_timesteps