@@ -506,7 +506,6 @@ def __call__(
506506 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
507507 max_sequence_length : int = 512 ,
508508 shift : float = 5.0 ,
509- timestep_scale : float = 0.001 ,
510509 conditional_frame_timestep : float = 0.1 ,
511510 ):
512511 r"""
@@ -592,7 +591,7 @@ def __call__(
592591 if isinstance (callback_on_step_end , (PipelineCallback , MultiPipelineCallbacks )):
593592 callback_on_step_end_tensor_inputs = callback_on_step_end .tensor_inputs
594593
595- # 1. Check inputs. Raise error if not correct
594+ # Check inputs. Raise error if not correct
596595 self .check_inputs (prompt , height , width , prompt_embeds , callback_on_step_end_tensor_inputs )
597596
598597 self ._guidance_scale = guidance_scale
@@ -613,15 +612,15 @@ def __call__(
613612 )
614613 self .safety_checker .to ("cpu" )
615614
616- # 2. Define call parameters
615+ # Define call parameters
617616 if prompt is not None and isinstance (prompt , str ):
618617 batch_size = 1
619618 elif prompt is not None and isinstance (prompt , list ):
620619 batch_size = len (prompt )
621620 else :
622621 batch_size = prompt_embeds .shape [0 ]
623622
624- # 3. Encode input prompt
623+ # Encode input prompt
625624 (
626625 prompt_embeds ,
627626 negative_prompt_embeds ,
@@ -639,10 +638,6 @@ def __call__(
639638 vae_dtype = self .vae .dtype
640639 transformer_dtype = self .transformer .dtype
641640
642- # 4. Prepare timesteps
643- self .scheduler .set_timesteps (num_inference_steps , shift = shift , device = device , scale = timestep_scale )
644- timesteps = torch .tensor (self .scheduler .timesteps ).to (transformer_dtype )
645-
646641 num_frames_in = None
647642 if image is not None :
648643 # TODO: handle batch_size > 1
@@ -690,22 +685,26 @@ def __call__(
690685
691686 padding_mask = latents .new_zeros (1 , 1 , height , width , dtype = transformer_dtype )
692687
693- # 6. Denoising loop
694- num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
688+ # Denoising loop
689+ self .scheduler .set_timesteps (num_inference_steps , shift = shift , device = device )
690+ timesteps = self .scheduler .timesteps
695691 self ._num_timesteps = len (timesteps )
692+ num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
696693
697- gt_velocity = latents - cond_latent
694+ gt_velocity = ( latents - cond_latent ) * cond_mask
698695 with self .progress_bar (total = num_inference_steps ) as progress_bar :
699696 for i , t in enumerate (timesteps ):
700697 if self .interrupt :
701698 continue
702699
703700 self ._current_timestep = t .cpu ().item ()
704- timestep = t .unsqueeze (0 )
701+
702+ # NOTE: equivalent to t / 1000 for FlowUniPCMultistepScheduler (sigmas are in [0, 1], num_train_timesteps=1000)
703+ sigma_t = torch .tensor (self .scheduler .sigmas [i ]).unsqueeze (0 ).to (device = device , dtype = transformer_dtype )
705704
706705 in_latents = cond_mask * cond_latent + (1 - cond_mask ) * latents # TODO: could use cond_indicator
707706 in_latents = in_latents .to (transformer_dtype )
708- in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator ) * timestep
707+ in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator ) * sigma_t
709708 noise_pred = self .transformer (
710709 hidden_states = in_latents ,
711710 condition_mask = cond_mask ,
@@ -714,8 +713,8 @@ def __call__(
714713 padding_mask = padding_mask ,
715714 return_dict = False ,
716715 )[0 ]
717- # NOTE: force input video latents for noise_pred by correcting velocity
718- noise_pred = gt_velocity * cond_mask + noise_pred * (1 - cond_mask )
716+ # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
717+ noise_pred = gt_velocity + noise_pred * (1 - cond_mask )
719718
720719 if self .do_classifier_free_guidance :
721720 noise_pred_neg = self .transformer (
@@ -726,8 +725,8 @@ def __call__(
726725 padding_mask = padding_mask ,
727726 return_dict = False ,
728727 )[0 ]
729- # NOTE: force input video latents for noise_pred by correcting velocity
730- noise_pred_neg = gt_velocity * cond_mask + noise_pred_neg * (1 - cond_mask )
728+ # NOTE: replace velocity (noise_pred) with gt_velocity for conditioning inputs only
729+ noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask )
731730 noise_pred = noise_pred + self .guidance_scale * (noise_pred - noise_pred_neg )
732731
733732 latents = self .scheduler .step (noise_pred , t , latents , return_dict = False )[0 ]
0 commit comments