39 changes: 39 additions & 0 deletions conditioner.hpp
@@ -34,6 +34,7 @@ struct Conditioner {
virtual void free_params_buffer() = 0;
virtual void get_param_tensors(std::map<std::string, struct ggml_tensor*>& tensors) = 0;
virtual size_t get_params_buffer_size() = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {}
virtual std::tuple<SDCondition, std::vector<bool>> get_learned_condition_with_trigger(ggml_context* work_ctx,
int n_threads,
@@ -115,6 +116,13 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
return buffer_size;
}

void set_flash_attention_enabled(bool enabled) override {
text_model->set_flash_attention_enabled(enabled);
if (sd_version_is_sdxl(version)) {
text_model2->set_flash_attention_enabled(enabled);
}
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
text_model->set_weight_adapter(adapter);
if (sd_version_is_sdxl(version)) {
@@ -783,6 +791,18 @@ struct SD3CLIPEmbedder : public Conditioner {
return buffer_size;
}

void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (clip_g) {
clip_g->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@@ -1191,6 +1211,15 @@ struct FluxCLIPEmbedder : public Conditioner {
return buffer_size;
}

void set_flash_attention_enabled(bool enabled) override {
if (clip_l) {
clip_l->set_flash_attention_enabled(enabled);
}
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) {
if (clip_l) {
clip_l->set_weight_adapter(adapter);
@@ -1440,6 +1469,12 @@ struct T5CLIPEmbedder : public Conditioner {
return buffer_size;
}

void set_flash_attention_enabled(bool enabled) override {
if (t5) {
t5->set_flash_attention_enabled(enabled);
}
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (t5) {
t5->set_weight_adapter(adapter);
@@ -1650,6 +1685,10 @@ struct LLMEmbedder : public Conditioner {
return buffer_size;
}

void set_flash_attention_enabled(bool enabled) override {
llm->set_flash_attention_enabled(enabled);
}

void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter) override {
if (llm) {
llm->set_weight_adapter(adapter);
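Each conditioner override forwards the new toggle to the text encoders it owns, checking optional members (via null checks or the SDXL version check) before forwarding. The snippet below is a minimal, self-contained sketch of that forwarding pattern only; `TextEncoderStub` and `ExampleConditioner` are illustrative stand-ins, not the project's real classes.

```cpp
#include <memory>

// Illustrative stand-in for a text encoder (CLIP, T5, ...); not the real class.
struct TextEncoderStub {
    bool flash_attn_enabled = false;
    void set_flash_attention_enabled(bool enabled) { flash_attn_enabled = enabled; }
};

// Sketch of the forwarding pattern used by the overrides above: propagate the
// flag to every owned sub-model, skipping encoders that were never loaded.
struct ExampleConditioner {
    std::shared_ptr<TextEncoderStub> clip_l;
    std::shared_ptr<TextEncoderStub> t5;  // optional, may be null

    void set_flash_attention_enabled(bool enabled) {
        if (clip_l) {
            clip_l->set_flash_attention_enabled(enabled);
        }
        if (t5) {
            t5->set_flash_attention_enabled(enabled);
        }
    }
};
```

Keeping this as a plain runtime setter rather than a constructor argument lets flash attention be switched on after the encoders are constructed, which matches how stable-diffusion.cpp drives it further down in this diff.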
14 changes: 7 additions & 7 deletions diffusion_model.hpp
@@ -38,7 +38,7 @@ struct DiffusionModel {
virtual size_t get_params_buffer_size() = 0;
virtual void set_weight_adapter(const std::shared_ptr<WeightAdapter>& adapter){};
virtual int64_t get_adm_in_channels() = 0;
virtual void set_flash_attn_enabled(bool enabled) = 0;
virtual void set_flash_attention_enabled(bool enabled) = 0;
virtual void set_circular_axes(bool circular_x, bool circular_y) = 0;
};

@@ -84,7 +84,7 @@ struct UNetModel : public DiffusionModel {
return unet.unet.adm_in_channels;
}

void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
unet.set_flash_attention_enabled(enabled);
}

@@ -149,7 +149,7 @@ struct MMDiTModel : public DiffusionModel {
return 768 + 1280;
}

void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
mmdit.set_flash_attention_enabled(enabled);
}

@@ -215,7 +215,7 @@ struct FluxModel : public DiffusionModel {
return 768;
}

void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
flux.set_flash_attention_enabled(enabled);
}

@@ -286,7 +286,7 @@ struct WanModel : public DiffusionModel {
return 768;
}

void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
wan.set_flash_attention_enabled(enabled);
}

@@ -357,7 +357,7 @@ struct QwenImageModel : public DiffusionModel {
return 768;
}

void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
qwen_image.set_flash_attention_enabled(enabled);
}

@@ -424,7 +424,7 @@ struct ZImageModel : public DiffusionModel {
return 768;
}

void set_flash_attn_enabled(bool enabled) {
void set_flash_attention_enabled(bool enabled) {
z_image.set_flash_attention_enabled(enabled);
}

3 changes: 2 additions & 1 deletion examples/cli/README.md
@@ -52,7 +52,8 @@ Context Options:
--control-net-cpu keep controlnet in cpu (for low vram)
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--diffusion-fa use flash attention in the diffusion model
--fa use flash attention
--diffusion-fa use flash attention in the diffusion model only
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions
9 changes: 8 additions & 1 deletion examples/common/common.hpp
@@ -457,6 +457,7 @@ struct SDContextParams {
bool control_net_cpu = false;
bool clip_on_cpu = false;
bool vae_on_cpu = false;
bool flash_attn = false;
bool diffusion_flash_attn = false;
bool diffusion_conv_direct = false;
bool vae_conv_direct = false;
@@ -615,9 +616,13 @@ struct SDContextParams {
"--vae-on-cpu",
"keep vae in cpu (for low vram)",
true, &vae_on_cpu},
{"",
"--fa",
"use flash attention",
true, &flash_attn},
{"",
"--diffusion-fa",
"use flash attention in the diffusion model",
"use flash attention in the diffusion model only",
true, &diffusion_flash_attn},
{"",
"--diffusion-conv-direct",
@@ -904,6 +909,7 @@ struct SDContextParams {
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
<< " flash_attn: " << (flash_attn ? "true" : "false") << ",\n"
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -968,6 +974,7 @@ struct SDContextParams {
clip_on_cpu,
control_net_cpu,
vae_on_cpu,
flash_attn,
diffusion_flash_attn,
taesd_preview,
diffusion_conv_direct,
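The `--fa` flag is added as one more entry in the same table-driven option list that already backs `--diffusion-fa` and the other boolean context options: a flag string, a help line, and a pointer to the `bool` it sets. The sketch below illustrates that general pattern in isolation; `BoolOption` and `parse_bool_flags` are hypothetical names, not the actual machinery in common.hpp.

```cpp
#include <cstring>
#include <vector>

// Hypothetical flag descriptor: long option name, help text, destination bool.
struct BoolOption {
    const char* name;
    const char* help;
    bool* dest;
};

// Hypothetical parser: set *dest whenever the matching flag appears on the command line.
inline void parse_bool_flags(int argc, char** argv, const std::vector<BoolOption>& options) {
    for (int i = 1; i < argc; i++) {
        for (const BoolOption& opt : options) {
            if (std::strcmp(argv[i], opt.name) == 0) {
                *opt.dest = true;
            }
        }
    }
}

// Usage mirroring the diff: --fa sets flash_attn, --diffusion-fa sets diffusion_flash_attn.
// bool flash_attn = false, diffusion_flash_attn = false;
// parse_bool_flags(argc, argv, {
//     {"--fa", "use flash attention", &flash_attn},
//     {"--diffusion-fa", "use flash attention in the diffusion model only", &diffusion_flash_attn},
// });
```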
3 changes: 2 additions & 1 deletion examples/server/README.md
@@ -44,7 +44,8 @@ Context Options:
--clip-on-cpu keep clip in cpu (for low vram)
--vae-on-cpu keep vae in cpu (for low vram)
--mmap whether to memory-map model
--diffusion-fa use flash attention in the diffusion model
--fa use flash attention
--diffusion-fa use flash attention in the diffusion model only
--diffusion-conv-direct use ggml_conv2d_direct in the diffusion model
--vae-conv-direct use ggml_conv2d_direct in the vae model
--circular enable circular padding for convolutions
2 changes: 1 addition & 1 deletion ggml_extend.hpp
@@ -2623,7 +2623,7 @@ class MultiheadAttention : public GGMLBlock {
v = v_proj->forward(ctx, x);
}

x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask); // [N, n_token, embed_dim]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, n_head, mask, false); // [N, n_token, embed_dim]

x = out_proj->forward(ctx, x); // [N, n_token, embed_dim]
return x;
34 changes: 25 additions & 9 deletions stable-diffusion.cpp
@@ -445,7 +445,7 @@ class StableDiffusionGGML {
}
}
if (is_chroma) {
if (sd_ctx_params->diffusion_flash_attn && sd_ctx_params->chroma_use_dit_mask) {
if ((sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) && sd_ctx_params->chroma_use_dit_mask) {
LOG_WARN(
"!!!It looks like you are using Chroma with flash attention. "
"This is currently unsupported. "
@@ -571,14 +571,6 @@ class StableDiffusionGGML {
}
}

if (sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attn_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attn_enabled(true);
}
}

cond_stage_model->alloc_params_buffer();
cond_stage_model->get_param_tensors(tensors);

@@ -725,6 +717,28 @@ class StableDiffusionGGML {
pmid_model->get_param_tensors(tensors, "pmid");
}

if (sd_ctx_params->flash_attn) {
LOG_INFO("Using flash attention");
cond_stage_model->set_flash_attention_enabled(true);
if (clip_vision) {
clip_vision->set_flash_attention_enabled(true);
}
if (first_stage_model) {
first_stage_model->set_flash_attention_enabled(true);
}
if (tae_first_stage) {
tae_first_stage->set_flash_attention_enabled(true);
}
}

if (sd_ctx_params->flash_attn || sd_ctx_params->diffusion_flash_attn) {
LOG_INFO("Using flash attention in the diffusion model");
diffusion_model->set_flash_attention_enabled(true);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_flash_attention_enabled(true);
}
}

diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
if (high_noise_diffusion_model) {
high_noise_diffusion_model->set_circular_axes(sd_ctx_params->circular_x, sd_ctx_params->circular_y);
@@ -2942,6 +2956,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"keep_clip_on_cpu: %s\n"
"keep_control_net_on_cpu: %s\n"
"keep_vae_on_cpu: %s\n"
"flash_attn: %s\n"
"diffusion_flash_attn: %s\n"
"circular_x: %s\n"
"circular_y: %s\n"
@@ -2973,6 +2988,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
BOOL_STR(sd_ctx_params->keep_control_net_on_cpu),
BOOL_STR(sd_ctx_params->keep_vae_on_cpu),
BOOL_STR(sd_ctx_params->flash_attn),
BOOL_STR(sd_ctx_params->diffusion_flash_attn),
BOOL_STR(sd_ctx_params->circular_x),
BOOL_STR(sd_ctx_params->circular_y),
1 change: 1 addition & 0 deletions stable-diffusion.h
@@ -189,6 +189,7 @@ typedef struct {
bool keep_clip_on_cpu;
bool keep_control_net_on_cpu;
bool keep_vae_on_cpu;
bool flash_attn;
bool diffusion_flash_attn;
bool tae_preview_only;
bool diffusion_conv_direct;
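At the public API level this adds a second, broader switch next to the existing `diffusion_flash_attn`: setting `flash_attn` enables flash attention in the conditioners, CLIP vision, the VAE/TAE and the diffusion model, while `diffusion_flash_attn` keeps it to the diffusion model only (see the wiring in stable-diffusion.cpp above). A hedged usage sketch follows; `sd_ctx_params_init()` and `new_sd_ctx()` are assumed from stable-diffusion.h (only `sd_ctx_params_to_str()` and the two flags are confirmed by this diff), and all other fields are left at their defaults.

```cpp
#include <cstdio>
#include <cstdlib>

#include "stable-diffusion.h"

int main() {
    sd_ctx_params_t params;
    sd_ctx_params_init(&params);  // assumed default-initializer from stable-diffusion.h

    params.flash_attn           = true;   // --fa: flash attention wherever it is supported
    params.diffusion_flash_attn = false;  // --diffusion-fa: diffusion model only (subsumed by flash_attn here)

    // sd_ctx_params_to_str() appears in this diff and now reports both flags.
    char* desc = sd_ctx_params_to_str(&params);
    printf("%s\n", desc);
    free(desc);  // assuming the returned string is heap-allocated and owned by the caller

    // A real application would fill in model paths and then call new_sd_ctx(&params).
    return 0;
}
```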
2 changes: 1 addition & 1 deletion vae.hpp
@@ -141,7 +141,7 @@ class AttnBlock : public UnaryBlock {
v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
}

h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false);
h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled);

if (use_linear) {
h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
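In the VAE attention block the hard-coded `false` in the last argument of `ggml_ext_attention_ext` is replaced by `ctx->flash_attn_enabled`, so the autoencoder now honors the global toggle too (wan.hpp gets the same substitution just below). Either way the operation computed is plain scaled-dot-product attention; flash attention only changes how it is evaluated, fusing the softmax so the full token-by-token score matrix is never materialized. For intuition, here is a minimal single-head reference of that math in plain C++; it is a sketch, not the ggml implementation:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Reference scaled-dot-product attention for one head.
// Q, K, V are row-major [n_token x d]; returns [n_token x d].
// A flash-attention kernel produces the same output, but tiles and fuses the
// softmax so the n_token x n_token score matrix is never stored in full.
std::vector<float> attention_ref(const std::vector<float>& Q,
                                 const std::vector<float>& K,
                                 const std::vector<float>& V,
                                 int n_token, int d) {
    std::vector<float> out(n_token * d, 0.0f);
    const float scale = 1.0f / std::sqrt((float)d);
    for (int i = 0; i < n_token; i++) {
        // scores[j] = (q_i . k_j) * scale
        std::vector<float> scores(n_token);
        float max_score = -INFINITY;
        for (int j = 0; j < n_token; j++) {
            float s = 0.0f;
            for (int x = 0; x < d; x++) {
                s += Q[i * d + x] * K[j * d + x];
            }
            scores[j] = s * scale;
            max_score = std::max(max_score, scores[j]);
        }
        // numerically stable softmax over the scores
        float sum = 0.0f;
        for (int j = 0; j < n_token; j++) {
            scores[j] = std::exp(scores[j] - max_score);
            sum += scores[j];
        }
        // out_i = sum_j softmax(scores)_j * v_j
        for (int j = 0; j < n_token; j++) {
            const float w = scores[j] / sum;
            for (int x = 0; x < d; x++) {
                out[i * d + x] += w * V[j * d + x];
            }
        }
    }
    return out;
}
```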
4 changes: 2 additions & 2 deletions wan.hpp
@@ -572,8 +572,8 @@ namespace WAN {
auto v = qkv_vec[2];
v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]

v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, false); // [t, h * w, c]
v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, true, ctx->flash_attn_enabled); // [t, h * w, c]

x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]