From ee2740a22738be05c60409b76519541a95725a4d Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sun, 30 Nov 2025 00:24:48 -0700 Subject: [PATCH 01/18] init (pipeline, transformer, nerf, etc) --- src/diffusers/__init__.py | 5 +- src/diffusers/loaders/peft.py | 1 + src/diffusers/loaders/single_file_model.py | 5 + src/diffusers/loaders/single_file_utils.py | 207 +++- src/diffusers/models/__init__.py | 3 +- .../models/transformers/transformer_chroma.py | 423 ++++++++ src/diffusers/pipelines/__init__.py | 2 +- src/diffusers/pipelines/auto_pipeline.py | 3 +- src/diffusers/pipelines/chroma/__init__.py | 2 + .../chroma/pipeline_chroma_radiance.py | 906 ++++++++++++++++++ src/diffusers/utils/dummy_pt_objects.py | 28 +- .../dummy_torch_and_transformers_objects.py | 13 + 12 files changed, 1571 insertions(+), 27 deletions(-) create mode 100644 src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f2d1840da222..aae7aae47dfc 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -203,6 +203,7 @@ "BriaTransformer2DModel", "CacheMixin", "ChromaTransformer2DModel", + "ChromaRadianceTransformer2DModel", "ChronoEditTransformer3DModel", "CogVideoXTransformer3DModel", "CogView3PlusTransformer2DModel", @@ -441,6 +442,7 @@ "BriaFiboPipeline", "BriaPipeline", "ChromaImg2ImgPipeline", + "ChromaRadiancePipeline", "ChromaPipeline", "ChronoEditPipeline", "CLIPImageProjection", @@ -922,6 +924,7 @@ BriaTransformer2DModel, CacheMixin, ChromaTransformer2DModel, + ChromaRadianceTransformer2DModel, ChronoEditTransformer3DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, @@ -991,7 +994,6 @@ WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, - ZImageTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1130,6 +1132,7 @@ BriaFiboPipeline, BriaPipeline, 
ChromaImg2ImgPipeline, + ChromaRadiancePipeline, ChromaPipeline, ChronoEditPipeline, CLIPImageProjection, diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py index b759e04cbf2d..450a513fc7cd 100644 --- a/src/diffusers/loaders/peft.py +++ b/src/diffusers/loaders/peft.py @@ -61,6 +61,7 @@ "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights, "WanVACETransformer3DModel": lambda model_cls, weights: weights, "ChromaTransformer2DModel": lambda model_cls, weights: weights, + "ChromaRadianceTransformer2DModel": lambda model_cls, weights: weights, "QwenImageTransformer2DModel": lambda model_cls, weights: weights, "Flux2Transformer2DModel": lambda model_cls, weights: weights, } diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 7b581ac3eb9c..f512a5e2499e 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -32,6 +32,7 @@ convert_auraflow_transformer_checkpoint_to_diffusers, convert_autoencoder_dc_checkpoint_to_diffusers, convert_chroma_transformer_checkpoint_to_diffusers, + convert_chroma_radiance_transformer_checkpoint_to_diffusers, convert_controlnet_checkpoint, convert_cosmos_transformer_checkpoint_to_diffusers, convert_flux2_transformer_checkpoint_to_diffusers, @@ -110,6 +111,10 @@ "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", }, + "ChromaRadianceTransformer2DModel": { + "checkpoint_mapping_fn": convert_chroma_radiance_transformer_checkpoint_to_diffusers, + "default_subfolder": "transformer", + }, "LTXVideoTransformer3DModel": { "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers, "default_subfolder": "transformer", diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 1e9404933402..4a1326387d2e 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ 
b/src/diffusers/loaders/single_file_utils.py @@ -389,14 +389,6 @@ def is_valid_url(url): return False -def _is_single_file_path_or_url(pretrained_model_name_or_path): - if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path): - return False - - repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path) - return bool(repo_id and weight_name) - - def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): if not is_valid_url(pretrained_model_name_or_path): raise ValueError("Invalid `pretrained_model_name_or_path` provided. Please set it to a valid URL.") @@ -408,6 +400,7 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") match = re.match(pattern, pretrained_model_name_or_path) if not match: + logger.warning("Unable to identify the repo_id and weights_name from the provided URL.") return repo_id, weights_name repo_id = f"{match.group(1)}/{match.group(2)}" @@ -3547,6 +3540,204 @@ def swap_scale_shift(weight): return converted_state_dict +def convert_chroma_radiance_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): + converted_state_dict = {} + keys = list(checkpoint.keys()) + + for k in keys: + if "model.diffusion_model." in k: + checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k) + + num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401 + num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401 + num_guidance_layers = ( + list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401 + ) + num_nerf_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "nerf_blocks." 
in k))[-1] + 1 # noqa: C401 + mlp_ratio = 4.0 + inner_dim = 3072 + + # in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale; + # while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + # guidance + converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop( + "distilled_guidance_layer.in_proj.bias" + ) + converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop( + "distilled_guidance_layer.in_proj.weight" + ) + converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop( + "distilled_guidance_layer.out_proj.bias" + ) + converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop( + "distilled_guidance_layer.out_proj.weight" + ) + for i in range(num_guidance_layers): + block_prefix = f"distilled_guidance_layer.layers.{i}." 
+ converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop( + f"distilled_guidance_layer.layers.{i}.in_layer.bias" + ) + converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop( + f"distilled_guidance_layer.layers.{i}.in_layer.weight" + ) + converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop( + f"distilled_guidance_layer.layers.{i}.out_layer.bias" + ) + converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop( + f"distilled_guidance_layer.layers.{i}.out_layer.weight" + ) + converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop( + f"distilled_guidance_layer.norms.{i}.scale" + ) + + # context_embedder + converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") + converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") + + # x_embedder + converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") + converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") + + # double transformer blocks + for i in range(num_layers): + block_prefix = f"transformer_blocks.{i}." 
+ # Q, K, V + sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0) + context_q, context_k, context_v = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0 + ) + sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0 + ) + context_q_bias, context_k_bias, context_v_bias = torch.chunk( + checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q]) + converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k]) + converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v]) + converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias]) + # qk_norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.norm.key_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.query_norm.scale" + ) + 
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.norm.key_norm.scale" + ) + # ff img_mlp + converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias") + converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight") + converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias") + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.0.bias" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.weight" + ) + converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_mlp.2.bias" + ) + # output projections. + converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop( + f"double_blocks.{i}.img_attn.proj.bias" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.weight" + ) + converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop( + f"double_blocks.{i}.txt_attn.proj.bias" + ) + + # single transformer blocks + for i in range(num_single_layers): + block_prefix = f"single_transformer_blocks.{i}." 
+ # Q, K, V, mlp + mlp_hidden_dim = int(inner_dim * mlp_ratio) + split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim) + q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0) + q_bias, k_bias, v_bias, mlp_bias = torch.split( + checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0 + ) + converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q]) + converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias]) + converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k]) + converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias]) + converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v]) + converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias]) + converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp]) + converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias]) + # qk norm + converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.query_norm.scale" + ) + converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop( + f"single_blocks.{i}.norm.key_norm.scale" + ) + # output projections. 
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") + converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") + + converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") + converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") + + # nerf + + converted_state_dict["nerf.nerf_embedder.embedder.bias"] = checkpoint.pop( + "nerf_image_embedder.embedder.0.bias" + ) + converted_state_dict["nerf.nerf_embedder.embedder.weight"] = checkpoint.pop( + "nerf_image_embedder.embedder.0.weight" + ) + converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop( + "nerf_final_layer_conv.conv.bias" + ) + converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop( + "nerf_final_layer_conv.conv.weight" + ) + converted_state_dict["nerf.final_layer.norm.scale"] = checkpoint.pop( + "nerf_final_layer_conv.norm.scale" + ) + + for i in range(num_nerf_layers): + block_prefix = f"nerf.blocks.{i}." 
+ converted_state_dict[f"{block_prefix}norm.scale"] = checkpoint.pop( + f"nerf_blocks.{i}.norm.scale" + ) + converted_state_dict[f"{block_prefix}param_generator.bias"] = checkpoint.pop( + f"nerf_blocks.{i}.param_generator.bias" + ) + converted_state_dict[f"{block_prefix}param_generator.weight"] = checkpoint.pop( + f"nerf_blocks.{i}.param_generator.weight" + ) + + return converted_state_dict def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 09b2b731b5c4..f1042649e871 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -86,7 +86,7 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] - _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"] + _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel", "ChromaRadianceTransformer2DModel"] _import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] @@ -185,6 +185,7 @@ BriaFiboTransformer2DModel, BriaTransformer2DModel, ChromaTransformer2DModel, + ChromaRadianceTransformer2DModel, ChronoEditTransformer3DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 2ef3643dafbd..a9986f67fb76 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ 
b/src/diffusers/models/transformers/transformer_chroma.py @@ -35,6 +35,135 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class Nerf(nn.Module): + def __init__( + self, + in_channels: int, + num_layers: int, + nerf_hidden_size: int, + transformer_hidden_size: int, + max_freqs: int, + mlp_ratio: int, + eps = 1e-6, + ): + self.nerf_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size=nerf_hidden_size, + max_freqs=max_freqs, + ) + self.blocks = nn.ModuleList( + [ + NerfGLUBlock( + transformer_hidden_size=transformer_hidden_size, + nerf_hidden_size=nerf_hidden_size, + mlp_ratio=mlp_ratio, + eps=eps, + ) + for _ in range(num_layers) + ] + ) + self.final_layer = NerfFinalLayer( + hidden_size=nerf_hidden_size, + out_channels=in_channels, + ) + self.transformer_hidden_size = transformer_hidden_size + def __call__( + self, + latents, + patch_size, + num_patches, + ): + batch_size, channels, height, width = latents.shape + pixels = nn.functional.unfold(latents, kernel_size=patch_size, stride=patch_size) + pixels = pixels.transpose(1, 2) + hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size) + pixels = pixels.reshape(batch_size * num_patches, channels, patch_size**2).transpose(1, 2) + + # Get pixel embeddings + latents_dct = self.nerf_embedder(pixels) + + # Pass through blocks + for block in self.blocks: + latents_dct = block(latents_dct) + + latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, -1).transpose(1, 2) + latents_dct = nn.functional.fold( + latents_dct, + output_size=(height, width), + kernel_size=patch_size, + stride=patch_size, + ) + return self.final_layer(latents_dct) + +class NerfEmbedder(nn.Module): + def __init__( + self, + in_channels: int, + hidden_size_input: int, + max_freqs: int, + ): + super().__init__() + self.max_freqs = max_freqs + self.hidden_size = hidden_size + self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size)) + def fetch_pos(self, 
patch_size) -> torch.Tensor: + pos_x = torch.linspace(0, 1, patch_size) + pos_y = torch.linspace(0, 1, patch_size) + pos_y, pos_x = torch.meshgrid(pos_y, pos_x, indexing="ij") + pos_x = pos_x.reshape(-1, 1, 1) + pos_y = pos_y.reshape(-1, 1, 1) + freqs = torch.linspace(0, self.max_freqs - 1, self.max_freqs) + freqs_x = freqs[None, :, None] + freqs_y = freqs[None, None, :] + coeffs = (1 + freqs_x * freqs_y) ** -1 + dct_x = torch.cos(pos_x * freqs_x * torch.pi) + dct_y = torch.cos(pos_y * freqs_y * torch.pi) + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + return dct + def __call__(self, inputs: torch.Tensor) -> torch.Tensor: + batch, pixels, channels = inputs.shape + patch_size = int(pixels**0.5) + input_dtype = inputs.dtype + inputs = inputs.to(dtype=self.embedder[0].weight.dtype) + dct = self.fetch_pos(patch_size, inputs.device, self.dtype) + dct = dct.repeat(batch, 1, 1) + inputs = torch.cat((inputs, dct), dim=-1) + return self.embedder(inputs).to(dtype=input_dtype) + +class NerfGLUBlock(nn.Module): + def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps): + super().__init__() + total_params = 3 * nerf_hidden_size**2 * mlp_ratio + self.param_generator = nn.Linear(transformer_hidden_size, total_params) + self.norm = RMSNorm(nerf_hidden_size, eps=eps) + self.mlp_ratio = mlp_ratio + def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: + batch_size, num_x, hidden_size_x = x.shape + mlp_params = self.param_generator(s) + fc1_gate_params, fc1_value_params, fc2_params = mlp_params.chunk(3, dim=-1) + fc1_gate = fc1_gate_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc1_value = fc1_value_params.view(batch_size, hidden_size_x, hidden_size_x * self.mlp_ratio) + fc2 = fc2_params.view(batch_size, hidden_size_x * self.mlp_ratio, hidden_size_x) + fc1_gate = torch.nn.functional.normalize(fc1_gate, dim=-2) + fc1_value = torch.nn.functional.normalize(fc1_value, dim=-2) + fc2 = 
torch.nn.functional.normalize(fc2, dim=-2) + res_x = x + x = self.norm(x) + x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) + return x + res_x + +class NerfFinalLayer(nn.Module): + def __init__(self, hidden_size: int, out_channels: int): + super().__init__() + self.norm = RMSNorm(hidden_size) + self.conv = nn.Conv2d( + in_channels=hidden_size, + out_channels=out_channels, + kernel_size=3, + padding=1, + ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) class ChromaAdaLayerNormZeroPruned(nn.Module): r""" @@ -639,3 +768,297 @@ def forward( return (output,) return Transformer2DModelOutput(sample=output) + +class ChromaRadianceTransformer2DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, + AttentionMixin, +): + """ + The Transformer model introduced in Flux, modified for Chroma Radiance. + + Reference: https://huggingface.co/lodestones/Chroma1-Radiance + + Args: + patch_size (`int`, defaults to `1`): + Patch size to turn the input data into small patches. + in_channels (`int`, defaults to `64`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `None`): + The number of channels in the output. If not specified, it defaults to `in_channels`. + num_layers (`int`, defaults to `19`): + The number of layers of dual stream DiT blocks to use. + num_single_layers (`int`, defaults to `38`): + The number of layers of single stream DiT blocks to use. + attention_head_dim (`int`, defaults to `128`): + The number of dimensions to use for each attention head. + num_attention_heads (`int`, defaults to `24`): + The number of attention heads to use. + joint_attention_dim (`int`, defaults to `4096`): + The number of dimensions to use for the joint attention (embedding/channel dimension of + `encoder_hidden_states`). 
+ axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): + The dimensions to use for the rotary positional embeddings. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"] + _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] + + @register_to_config + def __init__( + self, + patch_size: int = 16, + in_channels: int = 64, + out_channels: Optional[int] = None, + num_layers: int = 19, + num_single_layers: int = 38, + attention_head_dim: int = 128, + num_attention_heads: int = 24, + joint_attention_dim: int = 4096, + axes_dims_rope: Tuple[int, ...] = (16, 56, 56), + approximator_num_channels: int = 64, + approximator_hidden_dim: int = 5120, + approximator_layers: int = 5, + nerf_layers: int = 4, + nerf_hidden_dim: int = 64, + nerf_max_freqs: int = 8, + nerf_mlp_ratio: int = 4, + ): + super().__init__() + self.out_channels = out_channels or in_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + self.time_text_embed = ChromaCombinedTimestepTextProjEmbeddings( + num_channels=approximator_num_channels // 4, + out_dim=3 * num_single_layers + 2 * 6 * num_layers + 2, + ) + self.distilled_guidance_layer = ChromaApproximator( + in_dim=approximator_num_channels, + out_dim=self.inner_dim, + hidden_dim=approximator_hidden_dim, + n_layers=approximator_layers, + ) + + self.nerf = Nerf( + in_channels, + nerf_layers, + nerf_hidden_dim, + self.inner_dim, + nerf_max_freqs, + nerf_mlp_ratio, + ) + + self.img_in_patch = nn.Conv2d( + in_channels, + self.inner_dim, + kernel_size=patch_size, + stride=patch_size, + bias=True, + ) + + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + ChromaTransformerBlock( + 
dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_layers) + ] + ) + + self.single_transformer_blocks = nn.ModuleList( + [ + ChromaSingleTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + ) + for _ in range(num_single_layers) + ] + ) + + self.norm_out = ChromaAdaLayerNormContinuousPruned( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + attention_mask: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. 
+ joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + + input_vec = self.time_text_embed(timestep) + pooled_temb = self.distilled_guidance_layer(input_vec) + + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + img_offset = 3 * len(self.single_transformer_blocks) + txt_offset = img_offset + 6 * len(self.transformer_blocks) + img_modulation = img_offset + 6 * index_block + text_modulation = txt_offset + 6 * index_block + temb = torch.cat( + ( + pooled_temb[:, img_modulation : img_modulation + 6], + pooled_temb[:, text_modulation : text_modulation + 6], + ), + dim=1, + ) + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. 
+ if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + start_idx = 3 * index_block + temb = pooled_temb[:, start_idx : start_idx + 3] + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + temb, + image_rotary_emb, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
+ + temb = pooled_temb[:, -2:] + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b4043cd146b4..00888fa3f896 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -154,7 +154,7 @@ "AudioLDM2UNet2DConditionModel", ] _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"] - _import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"] + _import_structure["chroma"] = ["ChromaPipeline", "ChromaRadiancePipeline", "ChromaImg2ImgPipeline"] _import_structure["cogvideo"] = [ "CogVideoXPipeline", "CogVideoXImageToVideoPipeline", diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index 044d854390e4..d8194cbae89e 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -21,7 +21,7 @@ from ..models.controlnets import ControlNetUnionModel from ..utils import is_sentencepiece_available from .aura_flow import AuraFlowPipeline -from .chroma import ChromaPipeline +from .chroma import ChromaPipeline, ChromaRadiancePipeline from .cogview3 import CogView3PlusPipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline from .controlnet import ( @@ -157,6 +157,7 @@ ("lumina", LuminaPipeline), ("lumina2", Lumina2Pipeline), ("chroma", ChromaPipeline), + ("chroma-radiance", ChromaRadiancePipeline), ("cogview3", CogView3PlusPipeline), ("cogview4", CogView4Pipeline), ("cogview4-control", CogView4ControlPipeline), diff --git a/src/diffusers/pipelines/chroma/__init__.py b/src/diffusers/pipelines/chroma/__init__.py index d9238b735c41..1017a1702bae 100644 --- 
a/src/diffusers/pipelines/chroma/__init__.py +++ b/src/diffusers/pipelines/chroma/__init__.py @@ -23,6 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_chroma"] = ["ChromaPipeline"] + _import_structure["pipeline_chroma_radiance"] = ["ChromaRadiancePipeline"] _import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -32,6 +33,7 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_chroma import ChromaPipeline + from .pipeline_chroma_radiance import ChromaRadiancePipeline from .pipeline_chroma_img2img import ChromaImg2ImgPipeline else: import sys diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py new file mode 100644 index 000000000000..2488865b1f5a --- /dev/null +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -0,0 +1,906 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5EncoderModel, T5TokenizerFast
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ChromaTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import ChromaPipelineOutput
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import ChromaRadiancePipeline, ChromaRadianceTransformer2DModel
+
+        >>> model_id = "lodestones/Chroma1-Radiance"
+        >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-Radiance/blob/main/Chroma1-Radiance.safetensors"
+        >>> transformer = ChromaRadianceTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
+        >>> pipe = ChromaRadiancePipeline.from_pretrained(
+        ...     model_id,
+        ...     transformer=transformer,
+        ...     torch_dtype=torch.bfloat16,
+        ... )
+        >>> pipe.enable_model_cpu_offload()
+        >>> prompt = [
+        ...     "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
+        ... ]
+        >>> negative_prompt = [
+        ...     
"low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors" + ... ] + >>> image = pipe(prompt, negative_prompt=negative_prompt).images[0] + >>> image.save("chroma.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. 
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +class ChromaRadiancePipeline( + DiffusionPipeline, + FluxLoraLoaderMixin, + FromSingleFileMixin, + TextualInversionLoaderMixin, + FluxIPAdapterMixin, +): + r""" + The Chroma Radiance pipeline for text-to-image generation. + + Reference: https://huggingface.co/lodestones/Chroma1-Radiance/ + + Args: + transformer ([`ChromaRadianceTransformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. 
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + text_encoder ([`T5EncoderModel`]): + [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically + the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant. + tokenizer (`T5TokenizerFast`): + Tokenizer of class + [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->transformer" + _optional_components = ["image_encoder", "feature_extractor"] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + text_encoder: T5EncoderModel, + tokenizer: T5TokenizerFast, + transformer: ChromaTransformer2DModel, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + ): + super().__init__() + + self.register_modules( + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.default_sample_size = 1024 + + def _get_t5_prompt_embeds( + self, + prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_length=False, + return_overflowing_tokens=False, + return_tensors="pt", + ) 
+ text_input_ids = text_inputs.input_ids + tokenizer_mask = text_inputs.attention_mask + + tokenizer_mask_device = tokenizer_mask.to(device) + + # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + output_hidden_states=False, + attention_mask=tokenizer_mask_device, + )[0] + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + # for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer + seq_lengths = tokenizer_mask_device.sum(dim=1) + mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1) + attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + attention_mask = attention_mask.repeat(1, num_images_per_prompt) + attention_mask = attention_mask.view(batch_size * num_images_per_prompt, seq_len) + + return prompt_embeds, attention_mask + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Union[str, List[str]] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + do_classifier_free_guidance: bool = True, + max_sequence_length: int = 512, + lora_scale: Optional[float] = None, + ): + r""" + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, 
*optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + negative_text_ids = None + + if do_classifier_free_guidance: + if negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt = ( + batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + ) + + if prompt is not None and type(prompt) is not 
type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_text_ids = torch.zeros(negative_prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) + + if self.text_encoder is not None: + if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + image_embeds = [] + if ip_adapter_image_embeds is None: + if not 
isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_ip_adapter_image in ip_adapter_image: + single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) + image_embeds.append(single_image_embeds[None, :]) + else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + + for single_image_embeds in ip_adapter_image_embeds: + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for single_image_embeds in image_embeds: + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt=None, + prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_embeds=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in 
self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Cannot provide `prompt_embeds` without also providing `prompt_attention_mask") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError( + "Cannot provide `negative_prompt_embeds` without also providing `negative_prompt_attention_mask" + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + @staticmethod + def _prepare_latent_image_ids(batch_size, height, width, device, dtype): + latent_image_ids = torch.zeros(height, width, 3) + latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] + latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return 
latent_image_ids.to(device=device, dtype=dtype) + + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) + + return latents + + @staticmethod + def _unpack_latents(latents, height, width): + batch_size, num_patches, channels = latents.shape + + latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape(batch_size, channels // (2 * 2), height, width) + + return latents + + # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + + shape = (batch_size, num_channels_latents, height, width) + + if latents is not None: + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + return latents.to(device=device, dtype=dtype), latent_image_ids + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + + return latents, latent_image_ids + + def _prepare_attention_mask( + self, + batch_size, + sequence_length, + dtype, + attention_mask=None, + ): + if attention_mask is None: + return attention_mask + + # Extend the prompt attention mask to account for image tokens in the final sequence + attention_mask = torch.cat( + [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)], + dim=1, + ) + + return attention_mask + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def joint_attention_kwargs(self): + return self._joint_attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def current_timestep(self): + return self._current_timestep + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 35, + sigmas: Optional[List[float]] = None, + guidance_scale: float = 5.0, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, + negative_ip_adapter_image: Optional[PipelineImageInput] = None, + negative_ip_adapter_image_embeds: 
Optional[List[torch.Tensor]] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + not greater than `1`). + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 3.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). 
`guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + negative_ip_adapter_image: + (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. 
Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            prompt_attention_mask (torch.Tensor, *optional*):
+                Attention mask for the prompt embeddings. Used to mask out padding tokens in the prompt sequence.
+                Chroma requires a single padding token remain unmasked. Please refer to
+                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+            negative_prompt_attention_mask (torch.Tensor, *optional*):
+                Attention mask for the negative prompt embeddings. Used to mask out padding tokens in the negative
+                prompt sequence. Chroma requires a single padding token remain unmasked. Please refer to
+                https://huggingface.co/lodestones/Chroma#tldr-masking-t5-padding-tokens-enhanced-fidelity-and-increased-stability-during-training
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.chroma.ChromaPipelineOutput`] instead of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            callback_on_step_end (`Callable`, *optional*):
+                A function that calls at the end of each denoising steps during the inference. 
The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.chroma.ChromaPipelineOutput`] or `tuple`: [`~pipelines.chroma.ChromaPipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size + width = width or self.default_sample_size + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_embeds=negative_prompt_embeds, + negative_prompt_attention_mask=negative_prompt_attention_mask, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + max_sequence_length=max_sequence_length, + ) + + self._guidance_scale = guidance_scale + self._joint_attention_kwargs = joint_attention_kwargs + self._current_timestep = None + self._interrupt = False + + # 2. 
Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + lora_scale = ( + self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None + ) + ( + prompt_embeds, + text_ids, + prompt_attention_mask, + negative_prompt_embeds, + negative_text_ids, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + do_classifier_free_guidance=self.do_classifier_free_guidance, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + lora_scale=lora_scale, + ) + + # 4. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + latents, latent_image_ids = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + latents = self.transformer.img_in_patch(latents) + latents = latents.flatten(2).transpose(1, 2) + num_patches = latents.shape[1] + + # 5. 
Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.15), + ) + + attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=prompt_attention_mask, + ) + negative_attention_mask = self._prepare_attention_mask( + batch_size=latents.shape[0], + sequence_length=image_seq_len, + dtype=latents.dtype, + attention_mask=negative_prompt_attention_mask, + ) + + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( + negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None + ): + negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( + negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None + ): + ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) + ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {} + + image_embeds = None + negative_image_embeds = None + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = 
self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: + negative_image_embeds = self.prepare_ip_adapter_image_embeds( + negative_ip_adapter_image, + negative_ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + ) + + # 6. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + self._current_timestep = t + if image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + attention_mask=attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + attention_mask=negative_attention_mask, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + self._current_timestep = None + + # 7. + + if output_type == "latent": + image = latents + else: + image = self.nerf(latents, self.patch_size, num_patches) + image = self._unpack_latents(image, height, width) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ChromaPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 73854b38190e..69d26744b216 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -662,6 +662,19 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ChromaRadianceTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) class ChronoEditTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1698,21 +1711,6 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, 
["torch"]) -class ZImageTransformer2DModel(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - def attention_backend(*args, **kwargs): requires_backends(attention_backend, ["torch"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index e6cf26a12544..17ab236849b9 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -556,6 +556,19 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class ChromaRadiancePipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) class ChronoEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From e0707bda6e47058bf0191ab162e4f5decb4363a0 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sun, 30 Nov 2025 00:34:14 -0700 Subject: [PATCH 02/18] undo regressions --- src/diffusers/__init__.py | 1 + src/diffusers/loaders/single_file_utils.py | 9 ++++++++- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index aae7aae47dfc..1e4786f181c2 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -994,6 +994,7 @@ 
WanAnimateTransformer3DModel, WanTransformer3DModel, WanVACETransformer3DModel, + ZImageTransformer2DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 4a1326387d2e..cb1e3bee109d 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -389,6 +389,14 @@ def is_valid_url(url): return False +def _is_single_file_path_or_url(pretrained_model_name_or_path): + if not os.path.isfile(pretrained_model_name_or_path) or not is_valid_url(pretrained_model_name_or_path): + return False + + repo_id, weight_name = _extract_repo_id_and_weights_name(pretrained_model_name_or_path) + return bool(repo_id and weight_name) + + def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): if not is_valid_url(pretrained_model_name_or_path): raise ValueError("Invalid `pretrained_model_name_or_path` provided. 
Please set it to a valid URL.") @@ -400,7 +408,6 @@ def _extract_repo_id_and_weights_name(pretrained_model_name_or_path): pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "") match = re.match(pattern, pretrained_model_name_or_path) if not match: - logger.warning("Unable to identify the repo_id and weights_name from the provided URL.") return repo_id, weights_name repo_id = f"{match.group(1)}/{match.group(2)}" diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 69d26744b216..97121e50d851 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1711,6 +1711,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class ZImageTransformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def attention_backend(*args, **kwargs): requires_backends(attention_backend, ["torch"]) From 07b368ea2f969d1e309dfc65f956b0f8ffc23f62 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sun, 30 Nov 2025 01:10:33 -0700 Subject: [PATCH 03/18] fixes --- src/diffusers/loaders/single_file_utils.py | 7 ---- .../models/transformers/transformer_chroma.py | 33 +++++++++++-------- .../chroma/pipeline_chroma_radiance.py | 5 +-- 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index cb1e3bee109d..2ad02126126c 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -3606,10 +3606,6 @@ def swap_scale_shift(weight): converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight") 
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias") - # x_embedder - converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight") - converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias") - # double transformer blocks for i in range(num_layers): block_prefix = f"transformer_blocks.{i}." @@ -3710,9 +3706,6 @@ def swap_scale_shift(weight): # output projections. converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") - - converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias") # nerf diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index a9986f67fb76..f6bd7a47890d 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -46,6 +46,7 @@ def __init__( mlp_ratio: int, eps = 1e-6, ): + super().__init__() self.nerf_embedder = NerfEmbedder( in_channels=in_channels, hidden_size=nerf_hidden_size, @@ -65,17 +66,17 @@ def __init__( self.final_layer = NerfFinalLayer( hidden_size=nerf_hidden_size, out_channels=in_channels, + eps=eps, ) self.transformer_hidden_size = transformer_hidden_size def __call__( self, + pixels, latents, patch_size, num_patches, ): batch_size, channels, height, width = latents.shape - pixels = nn.functional.unfold(latents, kernel_size=patch_size, stride=patch_size) - pixels = pixels.transpose(1, 2) hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size) pixels = pixels.reshape(batch_size * num_patches, channels, patch_size**2).transpose(1, 2) @@ -99,7 +100,7 @@ class NerfEmbedder(nn.Module): def __init__( self, in_channels: int, - hidden_size_input: int, + 
hidden_size: int, max_freqs: int, ): super().__init__() @@ -153,9 +154,9 @@ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: return x + res_x class NerfFinalLayer(nn.Module): - def __init__(self, hidden_size: int, out_channels: int): + def __init__(self, hidden_size: int, out_channels: int, eps): super().__init__() - self.norm = RMSNorm(hidden_size) + self.norm = RMSNorm(hidden_size, eps=eps) self.conv = nn.Conv2d( in_channels=hidden_size, out_channels=out_channels, @@ -654,8 +655,11 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - - hidden_states = self.x_embedder(hidden_states) + + pixels = nn.functional.unfold(hidden_states, kernel_size=self.patch_size, stride=self.patch_size) + pixels = pixels.transpose(1, 2) + hidden_states = self.transformer.img_in_patch(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -889,10 +893,10 @@ def __init__( ] ) - self.norm_out = ChromaAdaLayerNormContinuousPruned( - self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 - ) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + #self.norm_out = ChromaAdaLayerNormContinuousPruned( + # self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + #) + #self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) self.gradient_checkpointing = False @@ -1050,9 +1054,10 @@ def forward( hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
- temb = pooled_temb[:, -2:] - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) + #temb = pooled_temb[:, -2:] + #hidden_states = self.norm_out(hidden_states, temb) + #output = self.proj_out(hidden_states) + output = self.nerf(hidden_states, self.transformer.patch_size, num_patches) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 2488865b1f5a..33b68be163f6 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -752,9 +752,7 @@ def __call__( latents, ) - latents = self.transformer.img_in_patch(latents) - latents = latents.flatten(2).transpose(1, 2) - num_patches = latents.shape[1] + num_patches = (height * width) // self.transformer.patch_size # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas @@ -893,7 +891,6 @@ def __call__( if output_type == "latent": image = latents else: - image = self.nerf(latents, self.patch_size, num_patches) image = self._unpack_latents(image, height, width) image = self.image_processor.postprocess(image, output_type=output_type) From d3bf917aceee2277b7e808394ff4948860a225b8 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sun, 30 Nov 2025 01:41:05 -0700 Subject: [PATCH 04/18] random tweaks --- src/diffusers/models/transformers/transformer_chroma.py | 5 ++--- src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index f6bd7a47890d..00a0c129b902 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -790,7 +790,7 @@ class 
ChromaRadianceTransformer2DModel( Args: patch_size (`int`, defaults to `1`): Patch size to turn the input data into small patches. - in_channels (`int`, defaults to `64`): + in_channels (`int`, defaults to `3`): The number of channels in the input. out_channels (`int`, *optional*, defaults to `None`): The number of channels in the output. If not specified, it defaults to `in_channels`. @@ -818,7 +818,7 @@ class ChromaRadianceTransformer2DModel( def __init__( self, patch_size: int = 16, - in_channels: int = 64, + in_channels: int = 3, out_channels: Optional[int] = None, num_layers: int = 19, num_single_layers: int = 38, @@ -859,7 +859,6 @@ def __init__( nerf_max_freqs, nerf_mlp_ratio, ) - self.img_in_patch = nn.Conv2d( in_channels, self.inner_dim, diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 33b68be163f6..5928fe34375c 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -181,7 +181,7 @@ def __init__( scheduler: FlowMatchEulerDiscreteScheduler, text_encoder: T5EncoderModel, tokenizer: T5TokenizerFast, - transformer: ChromaTransformer2DModel, + transformer: ChromaRadianceTransformer2DModel, image_encoder: CLIPVisionModelWithProjection = None, feature_extractor: CLIPImageProcessor = None, ): From 1a20f5320f680cb9869596ac733a899475a649ff Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sun, 30 Nov 2025 02:12:33 -0700 Subject: [PATCH 05/18] make saveable --- src/diffusers/loaders/single_file_utils.py | 26 ++++++++++++++++--- .../models/transformers/transformer_chroma.py | 22 +++++----------- .../chroma/pipeline_chroma_radiance.py | 2 +- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 2ad02126126c..6752aa95a56c 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ 
b/src/diffusers/loaders/single_file_utils.py @@ -191,6 +191,7 @@ "flux-depth": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-Depth-dev"}, "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, "flux-2-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-dev"}, + "chroma": {"pretrained_model_name_or_path": "lodestones/Chroma1-HD"}, "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"}, "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"}, "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"}, @@ -663,6 +664,14 @@ def infer_diffusers_model_type(checkpoint): elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): if any( + c in checkpoint for c in ["distilled_guidance_layer.in_proj.bias"] + ): + # Should be updated once a repo exists + # if any(h in checkpoint for h in ["nerf_blocks.0.param_generator.bias"]): + # model_type = "chroma-radiance" + # else: + model_type = "chroma" + elif any( g in checkpoint for g in ["guidance_in.in_layer.bias", "model.diffusion_model.guidance_in.in_layer.bias"] ): if "model.diffusion_model.img_in.weight" in checkpoint: @@ -3709,10 +3718,10 @@ def swap_scale_shift(weight): # nerf - converted_state_dict["nerf.nerf_embedder.embedder.bias"] = checkpoint.pop( + converted_state_dict["nerf.nerf_embedder.embedder.0.bias"] = checkpoint.pop( "nerf_image_embedder.embedder.0.bias" ) - converted_state_dict["nerf.nerf_embedder.embedder.weight"] = checkpoint.pop( + converted_state_dict["nerf.nerf_embedder.embedder.0.weight"] = checkpoint.pop( "nerf_image_embedder.embedder.0.weight" ) converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop( @@ -3721,13 +3730,13 @@ def swap_scale_shift(weight): converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop( "nerf_final_layer_conv.conv.weight" ) - converted_state_dict["nerf.final_layer.norm.scale"] = checkpoint.pop( + 
converted_state_dict["nerf.final_layer.norm.weight"] = checkpoint.pop( "nerf_final_layer_conv.norm.scale" ) for i in range(num_nerf_layers): block_prefix = f"nerf.blocks.{i}." - converted_state_dict[f"{block_prefix}norm.scale"] = checkpoint.pop( + converted_state_dict[f"{block_prefix}norm.weight"] = checkpoint.pop( f"nerf_blocks.{i}.norm.scale" ) converted_state_dict[f"{block_prefix}param_generator.bias"] = checkpoint.pop( @@ -3737,6 +3746,15 @@ def swap_scale_shift(weight): f"nerf_blocks.{i}.param_generator.weight" ) + # patch + + converted_state_dict["img_in_patch.bias"] = checkpoint.pop( + "img_in_patch.bias" + ) + converted_state_dict["img_in_patch.weight"] = checkpoint.pop( + "img_in_patch.weight" + ) + return converted_state_dict def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 00a0c129b902..608cbbea6d64 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -656,10 +656,7 @@ def forward( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
) - pixels = nn.functional.unfold(hidden_states, kernel_size=self.patch_size, stride=self.patch_size) - pixels = pixels.transpose(1, 2) - hidden_states = self.transformer.img_in_patch(hidden_states) - hidden_states = hidden_states.flatten(2).transpose(1, 2) + hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -859,6 +856,7 @@ def __init__( nerf_max_freqs, nerf_mlp_ratio, ) + self.img_in_patch = nn.Conv2d( in_channels, self.inner_dim, @@ -868,7 +866,6 @@ def __init__( ) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = nn.Linear(in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ @@ -892,11 +889,6 @@ def __init__( ] ) - #self.norm_out = ChromaAdaLayerNormContinuousPruned( - # self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 - #) - #self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) - self.gradient_checkpointing = False def forward( @@ -951,8 +943,11 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - - hidden_states = self.x_embedder(hidden_states) + + pixels = nn.functional.unfold(hidden_states, kernel_size=self.patch_size, stride=self.patch_size) + pixels = pixels.transpose(1, 2) + hidden_states = self.transformer.img_in_patch(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -1053,9 +1048,6 @@ def forward( hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
- #temb = pooled_temb[:, -2:] - #hidden_states = self.norm_out(hidden_states, temb) - #output = self.proj_out(hidden_states) output = self.nerf(hidden_states, self.transformer.patch_size, num_patches) if USE_PEFT_BACKEND: diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 5928fe34375c..791f79eefd86 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -21,7 +21,7 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ChromaTransformer2DModel +from ...models import AutoencoderKL, ChromaTransformer2DModel, ChromaRadianceTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, From ed746c5b0f050669bdec64cb95583f420410347a Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sun, 30 Nov 2025 18:13:28 -0700 Subject: [PATCH 06/18] update --- .../models/transformers/transformer_chroma.py | 20 ++++++++++++------- .../chroma/pipeline_chroma_radiance.py | 18 +++++++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 608cbbea6d64..b1ce5c6dfa04 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -74,9 +74,13 @@ def __call__( pixels, latents, patch_size, - num_patches, ): - batch_size, channels, height, width = latents.shape + batch_size, channels, height, width = pixels.shape + num_patches = latents.shape[1] + + pixels = nn.functional.unfold(pixels, kernel_size=self.patch_size, stride=self.patch_size) + pixels = pixels.transpose(1, 2) + hidden = 
latents.reshape(batch_size * num_patches, self.transformer_hidden_size) pixels = pixels.reshape(batch_size * num_patches, channels, patch_size**2).transpose(1, 2) @@ -85,7 +89,7 @@ def __call__( # Pass through blocks for block in self.blocks: - latents_dct = block(latents_dct) + latents_dct = block(latents_dct, hidden) latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, -1).transpose(1, 2) latents_dct = nn.functional.fold( @@ -929,6 +933,7 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ + print(hidden_states.shape) if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -944,10 +949,10 @@ def forward( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - pixels = nn.functional.unfold(hidden_states, kernel_size=self.patch_size, stride=self.patch_size) - pixels = pixels.transpose(1, 2) - hidden_states = self.transformer.img_in_patch(hidden_states) + pixels = hidden_states + hidden_states = self.img_in_patch(hidden_states) hidden_states = hidden_states.flatten(2).transpose(1, 2) + print(hidden_states.shape) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -955,6 +960,7 @@ def forward( pooled_temb = self.distilled_guidance_layer(input_vec) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + print(encoder_hidden_states.shape) if txt_ids.ndim == 3: logger.warning( @@ -1048,7 +1054,7 @@ def forward( hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
- output = self.nerf(hidden_states, self.transformer.patch_size, num_patches) + output = self.nerf(pixels, hidden_states, self.transformer.patch_size) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 791f79eefd86..e0d490f3295c 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -503,6 +503,7 @@ def prepare_latents( ): shape = (batch_size, num_channels_latents, height, width) + print(shape) if latents is not None: latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -515,7 +516,7 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + #latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -740,7 +741,7 @@ def __call__( ) # 4. Prepare latent variables - num_channels_latents = self.transformer.config.in_channels // 4 + num_channels_latents = self.transformer.config.in_channels latents, latent_image_ids = self.prepare_latents( batch_size * num_images_per_prompt, num_channels_latents, @@ -751,8 +752,6 @@ def __call__( generator, latents, ) - - num_patches = (height * width) // self.transformer.patch_size # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas @@ -820,6 +819,7 @@ def __call__( batch_size * num_images_per_prompt, ) + # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -891,8 +891,14 @@ def __call__( if output_type == "latent": image = latents else: - image = self._unpack_latents(image, height, width) - image = self.image_processor.postprocess(image, output_type=output_type) + #image = self.transformer.nerf( + # pixels, + # latents, + # self.transformer.config.patch_size, + # num_patches, + #) + #image = self._unpack_latents(image, height, width) + image = self.image_processor.postprocess(latents, output_type=output_type) # Offload all models self.maybe_free_model_hooks() From 0a19b7e14cf3c36e0195ee74d314c296fa2937ed Mon Sep 17 00:00:00 2001 From: BuildTools Date: Mon, 1 Dec 2025 21:29:19 -0700 Subject: [PATCH 07/18] fix img_ids --- src/diffusers/loaders/single_file_utils.py | 4 ++-- src/diffusers/models/embeddings.py | 2 ++ .../models/transformers/transformer_chroma.py | 18 ++++++++++++------ .../models/transformers/transformer_flux.py | 2 ++ .../chroma/pipeline_chroma_radiance.py | 4 +++- 5 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index 6752aa95a56c..b51af1e89267 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -3748,10 +3748,10 @@ def swap_scale_shift(weight): # patch - converted_state_dict["img_in_patch.bias"] = checkpoint.pop( + converted_state_dict["x_embedder_patch.bias"] = checkpoint.pop( "img_in_patch.bias" ) - converted_state_dict["img_in_patch.weight"] = checkpoint.pop( + converted_state_dict["x_embedder_patch.weight"] = checkpoint.pop( "img_in_patch.weight" ) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 37fc412adcc3..07d47e83e1b8 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1229,6 +1229,8 @@ def apply_rotary_emb( else: raise 
ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + print(x.shape, cos.shape) + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index b1ce5c6dfa04..04065e4e55b9 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -74,9 +74,9 @@ def __call__( pixels, latents, patch_size, + num_patches, ): batch_size, channels, height, width = pixels.shape - num_patches = latents.shape[1] pixels = nn.functional.unfold(pixels, kernel_size=self.patch_size, stride=self.patch_size) pixels = pixels.transpose(1, 2) @@ -861,7 +861,7 @@ def __init__( nerf_mlp_ratio, ) - self.img_in_patch = nn.Conv2d( + self.x_embedder_patch = nn.Conv2d( in_channels, self.inner_dim, kernel_size=patch_size, @@ -933,7 +933,7 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - print(hidden_states.shape) + print("states", hidden_states.shape, encoder_hidden_states.shape) if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -949,8 +949,12 @@ def forward( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 
) - pixels = hidden_states - hidden_states = self.img_in_patch(hidden_states) + pixels = nn.functional.unfold(hidden_states, kernel_size=self.config.patch_size, stride=self.config.patch_size) + pixels = pixels.transpose(1, 2) + print("pixels", pixels.shape) + hidden_states = self.x_embedder_patch(hidden_states) + print("img_patch:", hidden_states.shape) + num_patches = hidden_states.shape[2] * hidden_states.shape[3] hidden_states = hidden_states.flatten(2).transpose(1, 2) print(hidden_states.shape) @@ -975,8 +979,10 @@ def forward( ) img_ids = img_ids[0] + print("txt", txt_ids.shape, "img", img_ids.shape) ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) + print("ids:", ids.shape, "emb:", image_rotary_emb[0].shape, image_rotary_emb[1].shape) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") @@ -1054,7 +1060,7 @@ def forward( hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
- output = self.nerf(pixels, hidden_states, self.transformer.patch_size) + output = self.nerf(pixels, hidden_states, self.config.patch_size, num_patches) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 16c526f437f2..ccd4a40d0bd5 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -717,11 +717,13 @@ def forward( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) + print("ids", ids.shape) if is_torch_npu_available(): freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) else: image_rotary_emb = self.pos_embed(ids) + print([x.shape for x in image_rotary_emb]) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index e0d490f3295c..559d15cb7203 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -500,6 +500,7 @@ def prepare_latents( device, generator, latents=None, + patch_size=2, ): shape = (batch_size, num_channels_latents, height, width) @@ -518,7 +519,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) #latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) + latent_image_ids = self._prepare_latent_image_ids(batch_size, height // patch_size, width // patch_size, device, dtype) return latents, latent_image_ids @@ -751,6 +752,7 @@ def __call__( device, generator, 
latents, + patch_size=self.transformer.config.patch_size, ) # 5. Prepare timesteps From df3ab44161545d490dc27dfde2919e05693f6bb3 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Mon, 1 Dec 2025 21:34:01 -0700 Subject: [PATCH 08/18] remove some prints --- src/diffusers/models/embeddings.py | 2 -- src/diffusers/models/transformers/transformer_flux.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 07d47e83e1b8..37fc412adcc3 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1229,8 +1229,6 @@ def apply_rotary_emb( else: raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") - print(x.shape, cos.shape) - out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index ccd4a40d0bd5..16c526f437f2 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -717,13 +717,11 @@ def forward( img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) - print("ids", ids.shape) if is_torch_npu_available(): freqs_cos, freqs_sin = self.pos_embed(ids.cpu()) image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu()) else: image_rotary_emb = self.pos_embed(ids) - print([x.shape for x in image_rotary_emb]) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") From 1a041721c3e34818b0257d80c43f0ef481a8db46 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Mon, 1 Dec 2025 22:12:09 -0700 Subject: [PATCH 09/18] updates --- src/diffusers/models/transformers/transformer_flux.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/models/transformers/transformer_flux.py 
b/src/diffusers/models/transformers/transformer_flux.py index 16c526f437f2..3ff86854626d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -115,6 +115,8 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) + print("query", query.shape, "key", key.shape, "value", value.shape) + hidden_states = dispatch_attention_fn( query, key, From a82990903e512c4874e37ac156eaf862aede3a9c Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sat, 13 Dec 2025 05:50:49 -0700 Subject: [PATCH 10/18] get closer to fixing attn mask --- src/diffusers/models/transformers/transformer_chroma.py | 2 -- src/diffusers/models/transformers/transformer_flux.py | 1 + .../pipelines/chroma/pipeline_chroma_radiance.py | 8 +++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 04065e4e55b9..a7d9571832e1 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -979,10 +979,8 @@ def forward( ) img_ids = img_ids[0] - print("txt", txt_ids.shape, "img", img_ids.shape) ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) - print("ids:", ids.shape, "emb:", image_rotary_emb[0].shape, image_rotary_emb[1].shape) if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3ff86854626d..a7126fe74af2 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -351,6 +351,7 @@ def forward( 
f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." ) kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + print("processor", attention_mask.shape) return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 559d15cb7203..ffea4e6bbbc1 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -757,7 +757,7 @@ def __call__( # 5. Prepare timesteps sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas - image_seq_len = latents.shape[1] + image_seq_len = latent_image_ids.shape[0] mu = calculate_shift( image_seq_len, self.scheduler.config.get("base_image_seq_len", 256), @@ -831,6 +831,8 @@ def __call__( self._current_timestep = t if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds + if attention_mask is not None: + self._joint_attention_kwargs["attention_mask"] = attention_mask.to(self.transformer.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -841,7 +843,6 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, - attention_mask=attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -849,13 +850,14 @@ def __call__( if self.do_classifier_free_guidance: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + if attention_mask is not None: + self._joint_attention_kwargs["attention_mask"] = negative_attention_mask.to(self.transformer.dtype) neg_noise_pred = self.transformer( hidden_states=latents, 
timestep=timestep / 1000, encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, - attention_mask=negative_attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 54f3bbf1ffbe77e06ae44a5df377886c93f0d50c Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sat, 13 Dec 2025 05:54:03 -0700 Subject: [PATCH 11/18] fix attention --- src/diffusers/models/transformers/transformer_flux.py | 1 - src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py | 4 ---- 2 files changed, 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index a7126fe74af2..3ff86854626d 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -351,7 +351,6 @@ def forward( f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
) kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - print("processor", attention_mask.shape) return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index ffea4e6bbbc1..879f6a0892de 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -831,8 +831,6 @@ def __call__( self._current_timestep = t if image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds - if attention_mask is not None: - self._joint_attention_kwargs["attention_mask"] = attention_mask.to(self.transformer.dtype) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -850,8 +848,6 @@ def __call__( if self.do_classifier_free_guidance: if negative_image_embeds is not None: self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - if attention_mask is not None: - self._joint_attention_kwargs["attention_mask"] = negative_attention_mask.to(self.transformer.dtype) neg_noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, From 655b8e67fb15980fe77d67ea4677c5dedde90c4a Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sat, 13 Dec 2025 06:43:59 -0700 Subject: [PATCH 12/18] make inference run --- .../models/transformers/transformer_chroma.py | 21 +++++++------------ .../chroma/pipeline_chroma_radiance.py | 7 +------ 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index a7d9571832e1..50618a49e562 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -78,7 
+78,7 @@ def __call__( ): batch_size, channels, height, width = pixels.shape - pixels = nn.functional.unfold(pixels, kernel_size=self.patch_size, stride=self.patch_size) + pixels = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size) pixels = pixels.transpose(1, 2) hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size) @@ -91,7 +91,7 @@ def __call__( for block in self.blocks: latents_dct = block(latents_dct, hidden) - latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, -1).transpose(1, 2) + latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2) latents_dct = nn.functional.fold( latents_dct, output_size=(height, width), @@ -129,11 +129,10 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor: batch, pixels, channels = inputs.shape patch_size = int(pixels**0.5) input_dtype = inputs.dtype - inputs = inputs.to(dtype=self.embedder[0].weight.dtype) - dct = self.fetch_pos(patch_size, inputs.device, self.dtype) - dct = dct.repeat(batch, 1, 1) + dct = self.fetch_pos(patch_size) + dct = dct.repeat(batch, 1, 1).to(dtype=input_dtype, device=inputs.device) inputs = torch.cat((inputs, dct), dim=-1) - return self.embedder(inputs).to(dtype=input_dtype) + return self.embedder(inputs) class NerfGLUBlock(nn.Module): def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps): @@ -933,7 +932,8 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. 
""" - print("states", hidden_states.shape, encoder_hidden_states.shape) + print(self.device) + pixels = hidden_states.to(self.device) if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) @@ -948,15 +948,9 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - - pixels = nn.functional.unfold(hidden_states, kernel_size=self.config.patch_size, stride=self.config.patch_size) - pixels = pixels.transpose(1, 2) - print("pixels", pixels.shape) hidden_states = self.x_embedder_patch(hidden_states) - print("img_patch:", hidden_states.shape) num_patches = hidden_states.shape[2] * hidden_states.shape[3] hidden_states = hidden_states.flatten(2).transpose(1, 2) - print(hidden_states.shape) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -964,7 +958,6 @@ def forward( pooled_temb = self.distilled_guidance_layer(input_vec) encoder_hidden_states = self.context_embedder(encoder_hidden_states) - print(encoder_hidden_states.shape) if txt_ids.ndim == 3: logger.warning( diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 879f6a0892de..59f0025f5e7b 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -195,6 +195,7 @@ def __init__( image_encoder=image_encoder, feature_extractor=feature_extractor, ) + self.image_processor = VaeImageProcessor() self.default_sample_size = 1024 def _get_t5_prompt_embeds( @@ -891,12 +892,6 @@ def __call__( if output_type == "latent": image = latents else: - #image = self.transformer.nerf( - # pixels, - # latents, - # self.transformer.config.patch_size, - # num_patches, - #) #image = self._unpack_latents(image, height, width) image = self.image_processor.postprocess(latents, output_type=output_type) From 
0191cfde6f052d3d478535269cbb9d6cf0a29f50 Mon Sep 17 00:00:00 2001 From: BuildTools Date: Sat, 13 Dec 2025 07:01:26 -0700 Subject: [PATCH 13/18] cleanup --- src/diffusers/__init__.py | 4 +- src/diffusers/loaders/single_file_model.py | 2 +- src/diffusers/loaders/single_file_utils.py | 40 ++++++------------- src/diffusers/models/__init__.py | 7 +++- .../models/transformers/transformer_chroma.py | 32 ++++++++++----- .../models/transformers/transformer_flux.py | 2 - src/diffusers/pipelines/chroma/__init__.py | 2 +- .../chroma/pipeline_chroma_radiance.py | 18 ++++----- src/diffusers/utils/dummy_pt_objects.py | 2 + .../dummy_torch_and_transformers_objects.py | 2 + 10 files changed, 55 insertions(+), 56 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 1e4786f181c2..a22c2236fdaa 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -923,8 +923,8 @@ BriaFiboTransformer2DModel, BriaTransformer2DModel, CacheMixin, - ChromaTransformer2DModel, ChromaRadianceTransformer2DModel, + ChromaTransformer2DModel, ChronoEditTransformer3DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, @@ -1133,8 +1133,8 @@ BriaFiboPipeline, BriaPipeline, ChromaImg2ImgPipeline, - ChromaRadiancePipeline, ChromaPipeline, + ChromaRadiancePipeline, ChronoEditPipeline, CLIPImageProjection, CogVideoXFunControlPipeline, diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index f512a5e2499e..42ccecb442ef 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -31,8 +31,8 @@ convert_animatediff_checkpoint_to_diffusers, convert_auraflow_transformer_checkpoint_to_diffusers, convert_autoencoder_dc_checkpoint_to_diffusers, - convert_chroma_transformer_checkpoint_to_diffusers, convert_chroma_radiance_transformer_checkpoint_to_diffusers, + convert_chroma_transformer_checkpoint_to_diffusers, convert_controlnet_checkpoint, 
convert_cosmos_transformer_checkpoint_to_diffusers, convert_flux2_transformer_checkpoint_to_diffusers, diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index b51af1e89267..1058c2329872 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -663,9 +663,7 @@ def infer_diffusers_model_type(checkpoint): model_type = "flux-2-dev" elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["flux"]): - if any( - c in checkpoint for c in ["distilled_guidance_layer.in_proj.bias"] - ): + if any(c in checkpoint for c in ["distilled_guidance_layer.in_proj.bias"]): # Should be updated once a repo exists # if any(h in checkpoint for h in ["nerf_blocks.0.param_generator.bias"]): # model_type = "chroma-radiance" @@ -3556,6 +3554,7 @@ def swap_scale_shift(weight): return converted_state_dict + def convert_chroma_radiance_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {} keys = list(checkpoint.keys()) @@ -3715,30 +3714,20 @@ def swap_scale_shift(weight): # output projections. 
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight") converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias") - + # nerf - - converted_state_dict["nerf.nerf_embedder.embedder.0.bias"] = checkpoint.pop( - "nerf_image_embedder.embedder.0.bias" - ) + + converted_state_dict["nerf.nerf_embedder.embedder.0.bias"] = checkpoint.pop("nerf_image_embedder.embedder.0.bias") converted_state_dict["nerf.nerf_embedder.embedder.0.weight"] = checkpoint.pop( "nerf_image_embedder.embedder.0.weight" ) - converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop( - "nerf_final_layer_conv.conv.bias" - ) - converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop( - "nerf_final_layer_conv.conv.weight" - ) - converted_state_dict["nerf.final_layer.norm.weight"] = checkpoint.pop( - "nerf_final_layer_conv.norm.scale" - ) + converted_state_dict["nerf.final_layer.conv.bias"] = checkpoint.pop("nerf_final_layer_conv.conv.bias") + converted_state_dict["nerf.final_layer.conv.weight"] = checkpoint.pop("nerf_final_layer_conv.conv.weight") + converted_state_dict["nerf.final_layer.norm.weight"] = checkpoint.pop("nerf_final_layer_conv.norm.scale") for i in range(num_nerf_layers): block_prefix = f"nerf.blocks.{i}." 
- converted_state_dict[f"{block_prefix}norm.weight"] = checkpoint.pop( - f"nerf_blocks.{i}.norm.scale" - ) + converted_state_dict[f"{block_prefix}norm.weight"] = checkpoint.pop(f"nerf_blocks.{i}.norm.scale") converted_state_dict[f"{block_prefix}param_generator.bias"] = checkpoint.pop( f"nerf_blocks.{i}.param_generator.bias" ) @@ -3747,16 +3736,13 @@ def swap_scale_shift(weight): ) # patch - - converted_state_dict["x_embedder_patch.bias"] = checkpoint.pop( - "img_in_patch.bias" - ) - converted_state_dict["x_embedder_patch.weight"] = checkpoint.pop( - "img_in_patch.weight" - ) + + converted_state_dict["x_embedder_patch.bias"] = checkpoint.pop("img_in_patch.bias") + converted_state_dict["x_embedder_patch.weight"] = checkpoint.pop("img_in_patch.weight") return converted_state_dict + def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs): converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())} diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f1042649e871..572f6aeeac5b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -86,7 +86,10 @@ _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"] _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"] - _import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel", "ChromaRadianceTransformer2DModel"] + _import_structure["transformers.transformer_chroma"] = [ + "ChromaTransformer2DModel", + "ChromaRadianceTransformer2DModel", + ] _import_structure["transformers.transformer_chronoedit"] = ["ChronoEditTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] @@ -184,8 +187,8 @@ 
AuraFlowTransformer2DModel, BriaFiboTransformer2DModel, BriaTransformer2DModel, - ChromaTransformer2DModel, ChromaRadianceTransformer2DModel, + ChromaTransformer2DModel, ChronoEditTransformer3DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 50618a49e562..48ec1a26cfc8 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -35,6 +35,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name + class Nerf(nn.Module): def __init__( self, @@ -44,7 +45,7 @@ def __init__( transformer_hidden_size: int, max_freqs: int, mlp_ratio: int, - eps = 1e-6, + eps=1e-6, ): super().__init__() self.nerf_embedder = NerfEmbedder( @@ -69,6 +70,7 @@ def __init__( eps=eps, ) self.transformer_hidden_size = transformer_hidden_size + def __call__( self, pixels, @@ -77,20 +79,20 @@ def __call__( num_patches, ): batch_size, channels, height, width = pixels.shape - + pixels = nn.functional.unfold(pixels, kernel_size=patch_size, stride=patch_size) pixels = pixels.transpose(1, 2) - + hidden = latents.reshape(batch_size * num_patches, self.transformer_hidden_size) pixels = pixels.reshape(batch_size * num_patches, channels, patch_size**2).transpose(1, 2) - + # Get pixel embeddings latents_dct = self.nerf_embedder(pixels) - + # Pass through blocks for block in self.blocks: latents_dct = block(latents_dct, hidden) - + latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2) latents_dct = nn.functional.fold( latents_dct, @@ -100,6 +102,7 @@ def __call__( ) return self.final_layer(latents_dct) + class NerfEmbedder(nn.Module): def __init__( self, @@ -111,6 +114,7 @@ def __init__( self.max_freqs = max_freqs self.hidden_size = hidden_size self.embedder = nn.Sequential(nn.Linear(in_channels + max_freqs**2, hidden_size)) + def 
fetch_pos(self, patch_size) -> torch.Tensor: pos_x = torch.linspace(0, 1, patch_size) pos_y = torch.linspace(0, 1, patch_size) @@ -123,8 +127,9 @@ def fetch_pos(self, patch_size) -> torch.Tensor: coeffs = (1 + freqs_x * freqs_y) ** -1 dct_x = torch.cos(pos_x * freqs_x * torch.pi) dct_y = torch.cos(pos_y * freqs_y * torch.pi) - dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs ** 2) + dct = (dct_x * dct_y * coeffs).view(1, -1, self.max_freqs**2) return dct + def __call__(self, inputs: torch.Tensor) -> torch.Tensor: batch, pixels, channels = inputs.shape patch_size = int(pixels**0.5) @@ -134,6 +139,7 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor: inputs = torch.cat((inputs, dct), dim=-1) return self.embedder(inputs) + class NerfGLUBlock(nn.Module): def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_ratio, eps): super().__init__() @@ -141,6 +147,7 @@ def __init__(self, transformer_hidden_size: int, nerf_hidden_size: int, mlp_rati self.param_generator = nn.Linear(transformer_hidden_size, total_params) self.norm = RMSNorm(nerf_hidden_size, eps=eps) self.mlp_ratio = mlp_ratio + def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: batch_size, num_x, hidden_size_x = x.shape mlp_params = self.param_generator(s) @@ -156,6 +163,7 @@ def forward(self, x: torch.Tensor, s: torch.Tensor) -> torch.Tensor: x = torch.bmm(torch.nn.functional.silu(torch.bmm(x, fc1_gate)) * torch.bmm(x, fc1_value), fc2) return x + res_x + class NerfFinalLayer(nn.Module): def __init__(self, hidden_size: int, out_channels: int, eps): super().__init__() @@ -166,9 +174,11 @@ def __init__(self, hidden_size: int, out_channels: int, eps): kernel_size=3, padding=1, ) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(self.norm(x.movedim(1, -1)).movedim(-1, 1)) + class ChromaAdaLayerNormZeroPruned(nn.Module): r""" Norm layer adaptive layer norm zero (adaLN-Zero). 
@@ -658,7 +668,7 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) - + hidden_states = self.x_embedder(hidden_states) timestep = timestep.to(hidden_states.dtype) * 1000 @@ -773,6 +783,7 @@ def forward( return Transformer2DModelOutput(sample=output) + class ChromaRadianceTransformer2DModel( ModelMixin, ConfigMixin, @@ -850,7 +861,7 @@ def __init__( hidden_dim=approximator_hidden_dim, n_layers=approximator_layers, ) - + self.nerf = Nerf( in_channels, nerf_layers, @@ -859,7 +870,7 @@ def __init__( nerf_max_freqs, nerf_mlp_ratio, ) - + self.x_embedder_patch = nn.Conv2d( in_channels, self.inner_dim, @@ -932,7 +943,6 @@ def forward( If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ - print(self.device) pixels = hidden_states.to(self.device) if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3ff86854626d..16c526f437f2 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -115,8 +115,6 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - print("query", query.shape, "key", key.shape, "value", value.shape) - hidden_states = dispatch_attention_fn( query, key, diff --git a/src/diffusers/pipelines/chroma/__init__.py b/src/diffusers/pipelines/chroma/__init__.py index 1017a1702bae..2f179d36b380 100644 --- a/src/diffusers/pipelines/chroma/__init__.py +++ b/src/diffusers/pipelines/chroma/__init__.py @@ -33,8 +33,8 @@ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_chroma import ChromaPipeline - from 
.pipeline_chroma_radiance import ChromaRadiancePipeline from .pipeline_chroma_img2img import ChromaImg2ImgPipeline + from .pipeline_chroma_radiance import ChromaRadiancePipeline else: import sys diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 59f0025f5e7b..cab655d3e045 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -21,11 +21,10 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin -from ...models import AutoencoderKL, ChromaTransformer2DModel, ChromaRadianceTransformer2DModel +from ...models import ChromaRadianceTransformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( USE_PEFT_BACKEND, - deprecate, is_torch_xla_available, logging, replace_example_docstring, @@ -147,6 +146,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + class ChromaRadiancePipeline( DiffusionPipeline, FluxLoraLoaderMixin, @@ -420,7 +420,6 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): @@ -503,9 +502,7 @@ def prepare_latents( latents=None, patch_size=2, ): - shape = (batch_size, num_channels_latents, height, width) - print(shape) if latents is not None: latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -518,9 +515,11 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - #latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width) + # latents = self._pack_latents(latents, 
batch_size, num_channels_latents, height, width) - latent_image_ids = self._prepare_latent_image_ids(batch_size, height // patch_size, width // patch_size, device, dtype) + latent_image_ids = self._prepare_latent_image_ids( + batch_size, height // patch_size, width // patch_size, device, dtype + ) return latents, latent_image_ids @@ -822,7 +821,6 @@ def __call__( batch_size * num_images_per_prompt, ) - # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -887,12 +885,12 @@ def __call__( self._current_timestep = None - # 7. + # 7. if output_type == "latent": image = latents else: - #image = self._unpack_latents(image, height, width) + # image = self._unpack_latents(image, height, width) image = self.image_processor.postprocess(latents, output_type=output_type) # Offload all models diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 97121e50d851..30fac59b58ed 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -662,6 +662,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) + class ChromaRadianceTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -676,6 +677,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) + class ChronoEditTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 17ab236849b9..687c220af305 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -556,6 +556,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + class 
ChromaRadiancePipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -570,6 +571,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + class ChronoEditPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 72748ec3c00579894c3bab600c459605b5c13039 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 16 Dec 2025 16:53:45 -0700 Subject: [PATCH 14/18] add x0, change final layer some --- .../models/transformers/transformer_chroma.py | 22 ++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 48ec1a26cfc8..88b479b279b4 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ b/src/diffusers/models/transformers/transformer_chroma.py @@ -93,6 +93,8 @@ def __call__( for block in self.blocks: latents_dct = block(latents_dct, hidden) + latents_dct = self.final_layer.norm(latents_dct) + latents_dct = latents_dct.transpose(1, 2).reshape(batch_size, num_patches, -1).transpose(1, 2) latents_dct = nn.functional.fold( latents_dct, @@ -100,7 +102,7 @@ def __call__( kernel_size=patch_size, stride=patch_size, ) - return self.final_layer(latents_dct) + return self.final_layer.conv(latents_dct) class NerfEmbedder(nn.Module): @@ -134,10 +136,11 @@ def __call__(self, inputs: torch.Tensor) -> torch.Tensor: batch, pixels, channels = inputs.shape patch_size = int(pixels**0.5) input_dtype = inputs.dtype + inputs = inputs.float() dct = self.fetch_pos(patch_size) - dct = dct.repeat(batch, 1, 1).to(dtype=input_dtype, device=inputs.device) + dct = dct.repeat(batch, 1, 1).to(dtype=torch.float32, device=inputs.device) inputs = torch.cat((inputs, dct), dim=-1) - return self.embedder(inputs) + return self.embedder.float()(inputs).to(dtype=input_dtype) class NerfGLUBlock(nn.Module): @@ -818,6 +821,8 @@ class 
ChromaRadianceTransformer2DModel( `encoder_hidden_states`). axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`): The dimensions to use for the rotary positional embeddings. + x0 (`bool`, defaults to `True`): + Whether or not to use x0 prediction """ _supports_gradient_checkpointing = True @@ -844,6 +849,7 @@ def __init__( nerf_hidden_dim: int = 64, nerf_max_freqs: int = 8, nerf_mlp_ratio: int = 4, + x0: bool = True, ): super().__init__() self.out_channels = out_channels or in_channels @@ -904,6 +910,7 @@ def __init__( ) self.gradient_checkpointing = False + self.x0 = x0 def forward( self, @@ -958,6 +965,10 @@ def forward( logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) + + if self.config.patch_size == 32: + _, _, height, width = hidden_states.shape + hidden_states = nn.functional.interpolate(hidden_states, size=(height//2, width//2), mode="nearest") hidden_states = self.x_embedder_patch(hidden_states) num_patches = hidden_states.shape[2] * hidden_states.shape[3] hidden_states = hidden_states.flatten(2).transpose(1, 2) @@ -1063,6 +1074,11 @@ def forward( output = self.nerf(pixels, hidden_states, self.config.patch_size, num_patches) + # using x0 prediction + + if self.x0: + output = (pixels - output) / (timestep / 1000).view(-1,1,1,1) + if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) From 5a4312a956bab8fccbbc59b9958a7a56ba594340 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 16 Dec 2025 17:07:28 -0700 Subject: [PATCH 15/18] make style make quality --- src/diffusers/models/transformers/transformer_chroma.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py index 88b479b279b4..515c3df14a96 100644 --- a/src/diffusers/models/transformers/transformer_chroma.py +++ 
b/src/diffusers/models/transformers/transformer_chroma.py @@ -968,7 +968,7 @@ def forward( if self.config.patch_size == 32: _, _, height, width = hidden_states.shape - hidden_states = nn.functional.interpolate(hidden_states, size=(height//2, width//2), mode="nearest") + hidden_states = nn.functional.interpolate(hidden_states, size=(height // 2, width // 2), mode="nearest") hidden_states = self.x_embedder_patch(hidden_states) num_patches = hidden_states.shape[2] * hidden_states.shape[3] hidden_states = hidden_states.flatten(2).transpose(1, 2) @@ -1077,7 +1077,7 @@ def forward( # using x0 prediction if self.x0: - output = (pixels - output) / (timestep / 1000).view(-1,1,1,1) + output = (pixels - output) / (timestep / 1000).view(-1, 1, 1, 1) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer From 414c63ee202d58245f3bfa067e5a99b324b62a33 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 16 Dec 2025 17:12:49 -0700 Subject: [PATCH 16/18] removed copied from (modified) --- src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index cab655d3e045..589bc60b87a4 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -489,7 +489,6 @@ def _unpack_latents(latents, height, width): return latents - # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_latents def prepare_latents( self, batch_size, From e3ee921641f1eface84c3af73ba2ea80c80ec2ca Mon Sep 17 00:00:00 2001 From: user Date: Tue, 16 Dec 2025 22:37:32 -0700 Subject: [PATCH 17/18] fix attention mask --- src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 
589bc60b87a4..61cfa2082ef6 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -839,6 +839,7 @@ def __call__( encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, + attention_mask=attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] @@ -852,6 +853,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=latent_image_ids, + attention_mask=negative_attention_mask, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] From 3814a6adedd5230664f5c729efc91fc18ea4e6d5 Mon Sep 17 00:00:00 2001 From: user Date: Tue, 16 Dec 2025 23:23:15 -0700 Subject: [PATCH 18/18] updates --- .../chroma/pipeline_chroma_radiance.py | 28 +++---------------- 1 file changed, 4 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py index 61cfa2082ef6..dba6d9619681 100644 --- a/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py +++ b/src/diffusers/pipelines/chroma/pipeline_chroma_radiance.py @@ -50,12 +50,12 @@ Examples: ```py >>> import torch - >>> from diffusers import ChromaPipeline + >>> from diffusers import ChromaRadiancePipeline - >>> model_id = "lodestones/Chroma1-HD" - >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors" + >>> model_id = "lodestones/Chroma1-Radiance" + >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/latest_x0.safetensors" >>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16) - >>> pipe = ChromaPipeline.from_pretrained( + >>> pipe = ChromaRadiancePipeline.from_pretrained( ... model_id, ... transformer=transformer, ... 
torch_dtype=torch.bfloat16, @@ -470,25 +470,6 @@ def _prepare_latent_image_ids(batch_size, height, width, device, dtype): return latent_image_ids.to(device=device, dtype=dtype) - @staticmethod - def _pack_latents(latents, batch_size, num_channels_latents, height, width): - latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) - latents = latents.permute(0, 2, 4, 1, 3, 5) - latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4) - - return latents - - @staticmethod - def _unpack_latents(latents, height, width): - batch_size, num_patches, channels = latents.shape - - latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) - latents = latents.permute(0, 3, 1, 4, 2, 5) - - latents = latents.reshape(batch_size, channels // (2 * 2), height, width) - - return latents - def prepare_latents( self, batch_size, @@ -891,7 +872,6 @@ def __call__( if output_type == "latent": image = latents else: - # image = self._unpack_latents(image, height, width) image = self.image_processor.postprocess(latents, output_type=output_type) # Offload all models