diff --git a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py index 1a77a652786d..2054cc291ec7 100644 --- a/src/diffusers/schedulers/scheduling_ddim_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_ddim_cogvideox.py @@ -383,7 +383,7 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + prev_timestep = self.previous_timestep(timestep) # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] @@ -500,5 +500,24 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity + def previous_timestep(self, timestep: int) -> torch.Tensor: + """ + Find the previous timestep in the scheduler's timestep schedule. + + Args: + timestep (`int`): + The current timestep. + + Returns: + `torch.Tensor`: + The previous timestep. Returns -1 if the current timestep is the last one. + """ + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + return prev_t + def __len__(self) -> int: return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py index 0c576467a19a..b939fe9a120e 100644 --- a/src/diffusers/schedulers/scheduling_dpm_cogvideox.py +++ b/src/diffusers/schedulers/scheduling_dpm_cogvideox.py @@ -467,7 +467,7 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + prev_timestep = self.previous_timestep(timestep) # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] @@ -599,5 +599,24 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity + def previous_timestep(self, timestep: int) -> torch.Tensor: + """ + Find the previous timestep in the scheduler's timestep schedule. + + Args: + timestep (`int`): + The current timestep. + + Returns: + `torch.Tensor`: + The previous timestep. Returns -1 if the current timestep is the last one. + """ + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + return prev_t + def __len__(self) -> int: return self.config.num_train_timesteps