diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 92c3e20013dd..a58f416fd4a2 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -442,7 +442,7 @@ def step( # - pred_prev_sample -> "x_t-1" # 1. get previous step value (=t-1) - prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + prev_timestep = self.previous_timestep(timestep) # 2. compute alphas, betas alpha_prod_t = self.alphas_cumprod[timestep] @@ -591,5 +591,13 @@ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: tor velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample return velocity + def previous_timestep(self, timestep: int): + index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0] + if index == self.timesteps.shape[0] - 1: + prev_t = torch.tensor(-1) + else: + prev_t = self.timesteps[index + 1] + return prev_t + def __len__(self) -> int: return self.config.num_train_timesteps