1- # Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py
2-
3- import math
4- from typing import List , Optional , Tuple , Union
1+ from typing import List , Literal , Optional , Tuple , Union
52
63import numpy as np
74import torch
118from diffusers .utils import deprecate
129
1310
14- def _get_karras_sigmas (self , num_steps : int , sigma_max : float , sigma_min : float , rho : int , final_sigmas_type : str ):
11+ def _get_karras_sigmas (num_train_steps : int , num_steps : int , sigma_max : float , sigma_min : float , rho : int , final_sigmas_type : str ):
1512 sigmas = np .arange (num_steps + 1 , dtype = np .float32 ) / num_steps
1613 min_inv_rho = sigma_min ** (1 / rho )
1714 max_inv_rho = sigma_max ** (1 / rho )
1815 sigmas = (max_inv_rho + sigmas * (min_inv_rho - max_inv_rho )) ** rho
1916 sigmas = sigmas / (1 + sigmas )
2017
21- if self . config . final_sigmas_type == "zero" :
18+ if final_sigmas_type == "zero" :
2219 sigma_last = 0
20+ elif final_sigmas_type == "sigma_min" :
21+ sigma_last = sigmas [- 1 ]
2322 else :
2423 raise ValueError (
25- f"`final_sigmas_type` must be 'zero' but got { self . config . final_sigmas_type } "
24+ f"`final_sigmas_type` must be 'zero' or 'sigma_min' but got { final_sigmas_type } "
2625 )
2726
28- timesteps = torch .from_numpy (sigmas * self . config . num_train_timesteps ).to (torch .int64 )
29- sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 ) # pyright: ignore
27+ timesteps = torch .from_numpy (sigmas * num_train_steps ).to (torch .int64 )
28+ sigmas = np .concatenate ([sigmas , [sigma_last ]]).astype (np .float32 )
3029 sigmas = torch .from_numpy (sigmas ).to (dtype = torch .float32 )
3130 return sigmas , timesteps
3231
3332
3433class FlowUniPCMultistepScheduler (SchedulerMixin , ConfigMixin ):
3534 """
36- `FlowUniPCMultistepScheduler` is the UniPC algorithm [1] for flow matching [2], but strictly uses the Karras sigmas [3].
35+ `FlowUniPCMultistepScheduler` is the UniPC algorithm[1] for flow
36+ matching[2], but strictly uses the Karras sigmas [3] (i.e. it follows the EDMEulerScheduler).
37+
38+ Note this a simplified version of `UniPCMultistepScheduler`, as it:
39+ 1. Does not have variance preserving sigmas
40+ 2. Does not store betas and other variables used by `UniPCMultistepScheduler`
41+ 3. Assumes prediction_type == "flow_prediction" (this parameter is removed)
3742
38- Note this a simplified version of `UniPCMultistepScheduler`, as:
39- 1. it does not have variance preserving sigmas
40- 2. it does not store betas and other variables used by `UniPCMultistepScheduler`
41- 3. it assumes prediction_type == "flow_prediction" (this variable is removed from `FlowUniPCMultistepScheduler`)
42-
4343 References:
4444 [1] Wang, Chong, et al. "UniPC: A Unified Predictor-Corrector Framework for Fast Sampling of Diffusion Models" https://arxiv.org/abs/2302.04867
4545 [2] Lipman, Chen, et al. "Flow matching for generative modeling." https://arxiv.org/abs/2210.02747
@@ -97,7 +97,7 @@ def __init__(
9797 lower_order_final : bool = True ,
9898 disable_corrector : List [int ] = [],
9999 solver_p : SchedulerMixin = None ,
100- final_sigmas_type : Optional [ str ] = "zero" , # "zero", "sigma_min"
100+ final_sigmas_type : Literal [ "zero" , "sigma_min" ] = "zero" ,
101101 rho : int = 7 ,
102102 sigma_max : float = 200.0 ,
103103 sigma_min : float = 0.01 ,
@@ -109,21 +109,15 @@ def __init__(
109109 raise NotImplementedError (f"{ solver_type } is not implemented for { self .__class__ } " )
110110
111111 self .predict_x0 = predict_x0
112- self .num_inference_steps = None
113112 self .disable_corrector = disable_corrector
113+ self .solver_p = solver_p
114+ self .num_inference_steps = None
114115
115- self .sigmas , self .timesteps = _get_karras_sigmas (self , num_train_timesteps , sigma_max , sigma_min , rho , final_sigmas_type )
116+ self .sigmas , self .timesteps = _get_karras_sigmas (num_train_timesteps , num_train_timesteps , sigma_max , sigma_min , rho , final_sigmas_type )
116117 self .sigma_min = self .sigmas [- 1 ].item ()
117118 self .sigma_max = self .sigmas [0 ].item ()
118119
119- self .last_sample = None
120- self ._step_index = None
121- self ._begin_index = None
122- self .model_outputs = [None ] * self .config .solver_order
123- self .timestep_list = [None ] * self .config .solver_order
124- self .lower_order_nums = 0
125- self .solver_p = self .config .solver_p
126-
120+ self ._reset_state (solver_order )
127121
128122 @property
129123 def step_index (self ):
@@ -169,22 +163,29 @@ def set_timesteps(
169163 """
170164 assert sigmas is None , "sigmas are not supported for FlowUniPCMultistepScheduler"
171165
172- self .sigmas , self .timesteps = _get_karras_sigmas (self , num_inference_steps , self .config .sigma_max , self .config .sigma_min , self .config .rho , self .config .final_sigmas_type )
166+ self .sigmas , self .timesteps = _get_karras_sigmas (self . config . num_train_timesteps , num_inference_steps , self .config .sigma_max , self .config .sigma_min , self .config .rho , self .config .final_sigmas_type )
173167 self .num_inference_steps = len (self .timesteps )
174168
175169 self .sigma_min = self .sigmas [- 1 ].item ()
176170 self .sigma_max = self .sigmas [0 ].item ()
177171
172+ self .sigmas = self .sigmas .to (device )
173+ self .timesteps = self .timesteps .to (device )
174+ self ._reset_state ()
175+
176+ def _reset_state (self , solver_order : Optional [int ] = None ):
177+ """
178+ Resets the noise schedule & solver state variables
179+ """
180+ solver_order = solver_order or self .config .solver_order
181+ self .model_outputs = [None ] * solver_order
182+ self .timestep_list = [None ] * solver_order
183+ self .lower_order_nums = 0
178184 self .last_sample = None
179185 self ._step_index = None
180186 self ._begin_index = None
181- self .model_outputs = [None ] * self .config .solver_order
182- self .timestep_list = [None ] * self .config .solver_order
183- self .lower_order_nums = 0
184- self .solver_p = self .config .solver_p
185-
186- self .sigmas = self .sigmas .to (device )
187- self .timesteps = self .timesteps .to (device )
187+ if self .solver_p :
188+ self .solver_p .set_timesteps (self .num_inference_steps , device = device )
188189
189190 # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
190191 def _threshold_sample (self , sample : torch .Tensor ) -> torch .Tensor :
@@ -575,6 +576,7 @@ def _init_step_index(self, timestep):
575576 else :
576577 self ._step_index = self ._begin_index
577578
579+ # Modified from diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler.step
578580 def step (
579581 self ,
580582 model_output : torch .Tensor ,
@@ -638,7 +640,7 @@ def step(
638640 this_order = self .config .solver_order
639641
640642 self .this_order = min (this_order , self .lower_order_nums + 1 ) # warmup for multistep
641- assert self .this_order > 0
643+ assert self .this_order > 0 , "expected this_order > 0, this could be due to duplicate timesteps"
642644
643645 self .last_sample = sample
644646 prev_sample = self .multistep_uni_p_bh_update (
0 commit comments