Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions src/diffusers/models/controlnets/controlnet_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,15 +517,75 @@ def __init__(

@classmethod
def from_transformer(cls, controlnet, transformer):
    """Give ``controlnet`` its own copies of the components it shares with ``transformer``.

    Every component is rebuilt from ``transformer.config`` (or cloned) instead of being
    assigned by reference, so later in-place changes to the controlnet's weights do not
    leak back into the transformer.

    Args:
        controlnet: The controlnet instance to populate. It is mutated in place.
        transformer: The source transformer. Its ``config`` supplies the architecture
            hyper-parameters and its sub-modules supply the weights.

    Returns:
        The same ``controlnet`` object that was passed in.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when a rebuilt module does not
            match the corresponding transformer module's state dict.
    """
    # NOTE(review): the rebuilt modules are created with the default dtype/device and
    # ``load_state_dict`` copies values into them — confirm callers move/cast the
    # controlnet afterwards if the transformer lives on another device or dtype.
    config = transformer.config

    # Scalar value — immutable, direct assignment is safe
    controlnet.t_scale = config.t_scale

    # nn.Modules — instantiate matching architecture, then load weights via state_dict.
    # This follows the same pattern used by FluxControlNetModel.from_transformer and
    # QwenImageControlNetModel.from_transformer.
    controlnet.t_embedder = TimestepEmbedder(min(config.dim, ADALN_EMBED_DIM), mid_size=1024)
    controlnet.t_embedder.load_state_dict(transformer.t_embedder.state_dict())

    # One linear patch embedder per (patch_size, f_patch_size) pair, keyed "<patch>-<f_patch>".
    controlnet.all_x_embedder = nn.ModuleDict(
        {
            f"{patch_size}-{f_patch_size}": nn.Linear(
                f_patch_size * patch_size * patch_size * config.in_channels, config.dim, bias=True
            )
            for patch_size, f_patch_size in zip(config.all_patch_size, config.all_f_patch_size)
        }
    )
    controlnet.all_x_embedder.load_state_dict(transformer.all_x_embedder.state_dict())

    controlnet.cap_embedder = nn.Sequential(
        RMSNorm(config.cap_feat_dim, eps=config.norm_eps),
        nn.Linear(config.cap_feat_dim, config.dim, bias=True),
    )
    controlnet.cap_embedder.load_state_dict(transformer.cap_embedder.state_dict())

    # Noise refiner blocks use layer ids offset by 1000 and are built with modulation=True.
    controlnet.noise_refiner = nn.ModuleList(
        [
            ZImageTransformerBlock(
                1000 + layer_id,
                config.dim,
                config.n_heads,
                config.n_kv_heads,
                config.norm_eps,
                config.qk_norm,
                modulation=True,
            )
            for layer_id in range(config.n_refiner_layers)
        ]
    )
    controlnet.noise_refiner.load_state_dict(transformer.noise_refiner.state_dict())

    # Context refiner blocks use plain layer ids and are built with modulation=False.
    controlnet.context_refiner = nn.ModuleList(
        [
            ZImageTransformerBlock(
                layer_id,
                config.dim,
                config.n_heads,
                config.n_kv_heads,
                config.norm_eps,
                config.qk_norm,
                modulation=False,
            )
            for layer_id in range(config.n_refiner_layers)
        ]
    )
    controlnet.context_refiner.load_state_dict(transformer.context_refiner.state_dict())

    # nn.Parameters — clone to create independent copies
    controlnet.x_pad_token = nn.Parameter(transformer.x_pad_token.data.clone())
    controlnet.cap_pad_token = nn.Parameter(transformer.cap_pad_token.data.clone())

    # RopeEmbedder — not an nn.Module, has no learnable weights.
    # Create a fresh instance with the same config.
    controlnet.rope_embedder = RopeEmbedder(
        theta=config.rope_theta,
        axes_dims=list(config.axes_dims),
        axes_lens=list(config.axes_lens),
    )

    return controlnet

@staticmethod
Expand Down
197 changes: 197 additions & 0 deletions tests/models/controlnets/test_models_controlnet_z_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers.models.controlnets.controlnet_z_image import ZImageControlNetModel
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel


class ZImageControlNetFromTransformerTests(unittest.TestCase):
    """Tests for ZImageControlNetModel.from_transformer weight independence.

    Verifies that from_transformer creates independent copies of weights,
    so modifying the controlnet does not affect the original transformer.
    Regression test for https://github.com/huggingface/diffusers/issues/13077
    """

    def get_transformer_config(self):
        """Return keyword arguments for a small ZImageTransformer2DModel."""
        return {
            "all_patch_size": (2,),
            "all_f_patch_size": (1,),
            "in_channels": 16,
            "dim": 256,
            "n_layers": 2,
            "n_refiner_layers": 2,
            "n_heads": 8,
            "n_kv_heads": 8,
            "cap_feat_dim": 256,
            "axes_dims": [8, 12, 12],
            "axes_lens": [64, 64, 64],
        }

    def get_controlnet_config(self):
        """Return keyword arguments for a small ZImageControlNetModel."""
        return {
            "control_layers_places": [0, 1],
            "control_refiner_layers_places": [0, 1],
            "add_control_noise_refiner": "control_noise_refiner",
            "control_in_dim": 16,
            "dim": 256,
            "n_refiner_layers": 2,
            "n_heads": 8,
            "n_kv_heads": 8,
        }

    def setUp(self):
        """Build a fresh transformer/controlnet pair for every test.

        Each test mutates the controlnet in place, so the pair must not be shared
        between tests.
        """
        super().setUp()
        self.transformer = ZImageTransformer2DModel(**self.get_transformer_config())
        controlnet = ZImageControlNetModel(**self.get_controlnet_config())
        self.controlnet = ZImageControlNetModel.from_transformer(
            controlnet=controlnet, transformer=self.transformer
        )

    def test_t_embedder_independence(self):
        """Modifying controlnet.t_embedder should not affect transformer.t_embedder."""
        original = self.transformer.t_embedder.mlp[0].weight.clone()
        torch.nn.init.constant_(self.controlnet.t_embedder.mlp[0].weight, 42.0)

        self.assertTrue(
            torch.equal(self.transformer.t_embedder.mlp[0].weight, original),
            "Transformer t_embedder weights were corrupted by controlnet modification",
        )

    def test_cap_embedder_independence(self):
        """Modifying controlnet.cap_embedder should not affect transformer.cap_embedder."""
        original = self.transformer.cap_embedder[1].weight.clone()
        torch.nn.init.constant_(self.controlnet.cap_embedder[1].weight, 42.0)

        self.assertTrue(
            torch.equal(self.transformer.cap_embedder[1].weight, original),
            "Transformer cap_embedder weights were corrupted by controlnet modification",
        )

    def test_all_x_embedder_independence(self):
        """Modifying controlnet.all_x_embedder should not affect transformer.all_x_embedder."""
        first_key = next(iter(self.transformer.all_x_embedder.keys()))
        original = self.transformer.all_x_embedder[first_key].weight.clone()
        torch.nn.init.constant_(self.controlnet.all_x_embedder[first_key].weight, 42.0)

        self.assertTrue(
            torch.equal(self.transformer.all_x_embedder[first_key].weight, original),
            "Transformer all_x_embedder weights were corrupted by controlnet modification",
        )

    def test_noise_refiner_independence(self):
        """Modifying controlnet.noise_refiner should not affect transformer.noise_refiner."""
        original = next(self.transformer.noise_refiner.parameters()).clone()
        torch.nn.init.constant_(next(self.controlnet.noise_refiner.parameters()), 42.0)

        self.assertTrue(
            torch.equal(next(self.transformer.noise_refiner.parameters()), original),
            "Transformer noise_refiner weights were corrupted by controlnet modification",
        )

    def test_context_refiner_independence(self):
        """Modifying controlnet.context_refiner should not affect transformer.context_refiner."""
        original = next(self.transformer.context_refiner.parameters()).clone()
        torch.nn.init.constant_(next(self.controlnet.context_refiner.parameters()), 42.0)

        self.assertTrue(
            torch.equal(next(self.transformer.context_refiner.parameters()), original),
            "Transformer context_refiner weights were corrupted by controlnet modification",
        )

    def test_x_pad_token_independence(self):
        """Modifying controlnet.x_pad_token should not affect transformer.x_pad_token."""
        original = self.transformer.x_pad_token.data.clone()
        self.controlnet.x_pad_token.data.fill_(99.0)

        self.assertTrue(
            torch.equal(self.transformer.x_pad_token.data, original),
            "Transformer x_pad_token was corrupted by controlnet modification",
        )

    def test_cap_pad_token_independence(self):
        """Modifying controlnet.cap_pad_token should not affect transformer.cap_pad_token."""
        original = self.transformer.cap_pad_token.data.clone()
        self.controlnet.cap_pad_token.data.fill_(99.0)

        self.assertTrue(
            torch.equal(self.transformer.cap_pad_token.data, original),
            "Transformer cap_pad_token was corrupted by controlnet modification",
        )

    def test_rope_embedder_independence(self):
        """Controlnet.rope_embedder should be a different instance from transformer.rope_embedder."""
        self.assertIsNot(
            self.controlnet.rope_embedder,
            self.transformer.rope_embedder,
            "Controlnet and transformer share the same rope_embedder instance",
        )

    def test_weights_correctly_copied(self):
        """Verify that weights are correctly copied from transformer to controlnet."""
        modules_to_check = ["t_embedder", "all_x_embedder", "cap_embedder", "noise_refiner", "context_refiner"]

        for name in modules_to_check:
            t_sd = getattr(self.transformer, name).state_dict()
            c_sd = getattr(self.controlnet, name).state_dict()
            # Compare key sets first so extra controlnet entries are not silently ignored.
            self.assertEqual(
                set(t_sd),
                set(c_sd),
                f"State dict keys differ for {name}",
            )
            for key in t_sd:
                self.assertTrue(
                    torch.equal(t_sd[key], c_sd[key]),
                    f"Weights not correctly copied for {name}.{key}",
                )

    def test_t_scale_correctly_copied(self):
        """Verify that t_scale is correctly copied from transformer config."""
        self.assertEqual(
            self.controlnet.t_scale,
            self.transformer.config.t_scale,
            "t_scale not correctly copied from transformer config",
        )