diff --git a/ggml_extend.hpp b/ggml_extend.hpp
index 7dac03738..efd320f0a 100644
--- a/ggml_extend.hpp
+++ b/ggml_extend.hpp
@@ -1577,7 +1577,7 @@ struct WeightAdapter {
         bool force_prec_f32 = false;
         float scale         = 1.f;
     } linear;
-    struct {
+    struct conv2d_params_t {
        int s0 = 1;
        int s1 = 1;
        int p0 = 0;
@@ -2630,4 +2630,103 @@ class MultiheadAttention : public GGMLBlock {
     }
 };
 
+__STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
+    struct ggml_context* ctx,
+    struct ggml_tensor* h,    // Input: [q, batch] or [W, H, q, batch]
+    struct ggml_tensor* w1,   // Outer factor (full rank)
+    struct ggml_tensor* w1a,  // Outer factor A (low-rank part 1)
+    struct ggml_tensor* w1b,  // Outer factor B (low-rank part 2)
+    struct ggml_tensor* w2,   // Inner factor (full rank)
+    struct ggml_tensor* w2a,  // Inner factor A (low-rank part 1)
+    struct ggml_tensor* w2b,  // Inner factor B (low-rank part 2)
+    bool is_conv,
+    WeightAdapter::ForwardParams::conv2d_params_t conv_params,
+    float scale) {
+    GGML_ASSERT(w1 != NULL || (w1a != NULL && w1b != NULL));
+    GGML_ASSERT(w2 != NULL || (w2a != NULL && w2b != NULL));
+
+    int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0];
+    int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1];
+
+    int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0];
+    int vq       = q_actual / uq;
+
+    int vp = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1])
+                          : (int)w2a->ne[1];
+    GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKr split");
+
+    struct ggml_tensor* hb;
+
+    if (!is_conv) {
+        int batch                 = (int)h->ne[1];
+        struct ggml_tensor* h_mat = ggml_reshape_2d(ctx, h, vq, uq * batch);
+
+        if (w2 != NULL) {
+            hb = ggml_mul_mat(ctx, w2, h_mat);
+        } else {
+            hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_mat));
+        }
+
+        struct ggml_tensor* hb_unbundled = ggml_reshape_3d(ctx, hb, vp, uq, batch);
+        struct ggml_tensor* hb_t         = ggml_transpose(ctx, hb_unbundled);
+
+        struct ggml_tensor* hc;
+        if (w1 != NULL) {
+            hc = ggml_mul_mat(ctx, w1, hb_t);
+        } else {
+            hc = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
+        }
+
+        struct ggml_tensor* hc_t = ggml_transpose(ctx, hc);
+        struct ggml_tensor* out  = ggml_reshape_2d(ctx, ggml_cont(ctx, hc_t), up * vp, batch);
+        return ggml_scale(ctx, out, scale);
+    } else {
+        int batch = (int)h->ne[3];
+
+        // Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
+        struct ggml_tensor* h_grouped = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);
+
+        if (w2 != NULL) {
+            hb = ggml_conv_2d(ctx, w2, h_grouped, conv_params.s0, conv_params.s1,
+                              conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
+        } else {
+            // Low-rank decomposition: w2b is the spatial/down-projection kernel,
+            // w2a is the 1x1 up-projection.
+            int rank = (int)w2b->ne[1];
+            int k    = (int)sqrt(w2b->ne[0] / vq);
+
+            struct ggml_tensor* w2b_4d = (ggml_n_dims(w2b) < 3) ? ggml_reshape_4d(ctx, w2b, k, k, vq, rank) : w2b;
+            struct ggml_tensor* w2a_4d = (ggml_n_dims(w2a) < 3) ? ggml_reshape_4d(ctx, w2a, 1, 1, rank, vp) : w2a;
+
+            struct ggml_tensor* ha = ggml_conv_2d(ctx, w2b_4d, h_grouped, conv_params.s0, conv_params.s1,
+                                                  conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
+            hb = ggml_conv_2d(ctx, w2a_4d, ha, 1, 1, 0, 0, 1, 1);
+        }
+
+        int w_out = (int)hb->ne[0];
+        int h_out = (int)hb->ne[1];
+
+        struct ggml_tensor* hb_flat = ggml_reshape_3d(ctx, hb, w_out * h_out * vp, uq, batch);
+        struct ggml_tensor* hb_t    = ggml_transpose(ctx, hb_flat);
+
+        struct ggml_tensor* hc;
+        struct ggml_tensor* w1_mat = (w1 != NULL) ? ggml_reshape_2d(ctx, w1, uq, up) : NULL;
+
+        if (w1_mat != NULL) {
+            hc = ggml_mul_mat(ctx, w1_mat, hb_t);
+        } else {
+            hc = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
+        }
+
+        struct ggml_tensor* hc_t    = ggml_transpose(ctx, hc);
+        struct ggml_tensor* hc_res  = ggml_reshape_4d(ctx, ggml_cont(ctx, hc_t), vp, w_out * h_out, up, batch);
+        struct ggml_tensor* hc_perm = ggml_permute(ctx, hc_res, 1, 2, 0, 3);
+        struct ggml_tensor* out     = ggml_reshape_4d(ctx, ggml_cont(ctx, hc_perm), w_out, h_out, up * vp, batch);
+
+        return ggml_scale(ctx, out, scale);
+    }
+}
+
 #endif  // __GGML_EXTEND__HPP__
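The matmul branch above relies on the standard Kronecker identity: applying `kron(W1, W2)` to a vector never requires materializing the full `[up*vp, uq*vq]` matrix, because `(W1 ⊗ W2) x` equals `W1 · X · W2ᵀ` flattened row-major, where `X` is `x` reshaped to `[uq, vq]`. The snippet below is a minimal standalone sketch of that identity, not repo code; `Mat`, `matmul`, and `kron` are illustrative helpers and the dimension names follow the patch:

```cpp
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>

// Row-major matrix: Mat(r, c) stores element (i, j) at d[i * c + j].
struct Mat {
    int r, c;
    std::vector<float> d;
    Mat(int r_, int c_) : r(r_), c(c_), d(r_ * c_) {}
    float& at(int i, int j) { return d[i * c + j]; }
    float at(int i, int j) const { return d[i * c + j]; }
};

static Mat matmul(const Mat& A, const Mat& B) {
    Mat C(A.r, B.c);  // zero-initialized by the vector
    for (int i = 0; i < A.r; i++)
        for (int k = 0; k < A.c; k++)
            for (int j = 0; j < B.c; j++)
                C.at(i, j) += A.at(i, k) * B.at(k, j);
    return C;
}

static Mat kron(const Mat& A, const Mat& B) {
    Mat K(A.r * B.r, A.c * B.c);
    for (int i = 0; i < A.r; i++)
        for (int j = 0; j < A.c; j++)
            for (int p = 0; p < B.r; p++)
                for (int q = 0; q < B.c; q++)
                    K.at(i * B.r + p, j * B.c + q) = A.at(i, j) * B.at(p, q);
    return K;
}

int main() {
    int up = 3, uq = 2;  // outer factor W1: [up, uq]
    int vp = 4, vq = 5;  // inner factor W2: [vp, vq]

    Mat W1(up, uq), W2(vp, vq), x(uq * vq, 1);
    for (auto& v : W1.d) v = (float)rand() / RAND_MAX;
    for (auto& v : W2.d) v = (float)rand() / RAND_MAX;
    for (auto& v : x.d)  v = (float)rand() / RAND_MAX;

    // Reference: materialize the full [up*vp, uq*vq] Kronecker product.
    Mat y_ref = matmul(kron(W1, W2), x);

    // Fast path: reshape x to X [uq, vq], compute W1 * X * W2^T, flatten.
    Mat X(uq, vq);
    X.d = x.d;  // same row-major layout, no data movement
    Mat W2t(vq, vp);
    for (int i = 0; i < vp; i++)
        for (int j = 0; j < vq; j++)
            W2t.at(j, i) = W2.at(i, j);
    Mat Y = matmul(matmul(W1, X), W2t);  // [up, vp]

    float max_err = 0.f;
    for (int i = 0; i < up * vp; i++)
        max_err = std::fmax(max_err, std::fabs(Y.d[i] - y_ref.d[i]));
    printf("max |fast - reference| = %g\n", max_err);  // ~1e-6
}
```

This is why the non-conv path costs two small matmuls per activation instead of one `[up*vp, uq*vq]` matmul against a reconstructed weight.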
diff --git a/lora.hpp b/lora.hpp
index e5d9906ff..fd461086a 100644
--- a/lora.hpp
+++ b/lora.hpp
@@ -483,7 +483,7 @@ struct LoraModel : public GGMLRunner {
             diff = get_loha_weight_diff(model_tensor_name, ctx);
         }
         // lokr
-        if (diff == nullptr) {
+        if (diff == nullptr && with_lora) {
             diff = get_lokr_weight_diff(model_tensor_name, ctx);
         }
         if (diff != nullptr) {
@@ -501,6 +501,8 @@ struct LoraModel : public GGMLRunner {
         return diff;
     }
 
+
+
     ggml_tensor* get_out_diff(ggml_context* ctx,
                               ggml_tensor* x,
                               WeightAdapter::ForwardParams forward_params,
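For contrast with the on-the-fly path added in the next hunk, here is a minimal sketch of what the merged path wired up above (`get_lokr_weight_diff`) conceptually produces: a full weight delta `ΔW = scale * kron(outer, inner)`, where each factor is either stored whole (`lokr_w1` / `lokr_w2`) or rebuilt from its low-rank pair (`lokr_w*_a * lokr_w*_b`), and `scale = multiplier * alpha / rank`. All values and loop code below are illustrative, not repo code:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Shapes as in the forward path: outer [up, uq], inner [vp, vq].
    int up = 2, uq = 2, vp = 3, vq = 3, rank = 2;
    float alpha = 1.0f, multiplier = 0.8f;

    // Same scaling convention as the patch: alpha / rank, times the
    // user-supplied multiplier (alpha scaling is skipped when rank stays 1,
    // i.e. when both factors are stored full rank).
    float scale = (alpha / rank) * multiplier;

    // Outer stored full rank; inner rebuilt as w2_a [vp, rank] * w2_b [rank, vq].
    std::vector<float> outer(up * uq, 0.5f);
    std::vector<float> w2a(vp * rank, 0.1f), w2b(rank * vq, 0.2f);
    std::vector<float> inner(vp * vq, 0.f);
    for (int i = 0; i < vp; i++)
        for (int k = 0; k < rank; k++)
            for (int j = 0; j < vq; j++)
                inner[i * vq + j] += w2a[i * rank + k] * w2b[k * vq + j];

    // delta_W[o*vp + p][u*vq + v] = scale * outer[o][u] * inner[p][v]
    std::vector<float> dW((up * vp) * (uq * vq));
    for (int o = 0; o < up; o++)
        for (int u = 0; u < uq; u++)
            for (int p = 0; p < vp; p++)
                for (int v = 0; v < vq; v++)
                    dW[(o * vp + p) * (uq * vq) + (u * vq + v)] =
                        scale * outer[o * uq + u] * inner[p * vq + v];

    printf("delta W: [%d, %d], scale = %g\n", up * vp, uq * vq, scale);
}
```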
+ key + ".alpha"; + + ggml_tensor* lokr_w1 = nullptr; + ggml_tensor* lokr_w1_a = nullptr; + ggml_tensor* lokr_w1_b = nullptr; + ggml_tensor* lokr_w2 = nullptr; + ggml_tensor* lokr_w2_a = nullptr; + ggml_tensor* lokr_w2_b = nullptr; + + if (iter != lora_tensors.end()) { + lokr_w1 = iter->second; + if (is_conv2d && lokr_w1->type != GGML_TYPE_F16) { + lokr_w1 = ggml_cast(ctx, lokr_w1, GGML_TYPE_F16); + } + } + iter = iter_a; + if (iter != lora_tensors.end()) { + lokr_w1_a = iter->second; + if (is_conv2d && lokr_w1_a->type != GGML_TYPE_F16) { + lokr_w1_a = ggml_cast(ctx, lokr_w1_a, GGML_TYPE_F16); + } + } + iter = lora_tensors.find(lokr_w1_b_name); + if (iter != lora_tensors.end()) { + lokr_w1_b = iter->second; + if (is_conv2d && lokr_w1_b->type != GGML_TYPE_F16) { + lokr_w1_b = ggml_cast(ctx, lokr_w1_b, GGML_TYPE_F16); + } + } + + iter = lora_tensors.find(lokr_w2_name); + if (iter != lora_tensors.end()) { + lokr_w2 = iter->second; + if (is_conv2d && lokr_w2->type != GGML_TYPE_F16) { + lokr_w2 = ggml_cast(ctx, lokr_w2, GGML_TYPE_F16); + } + } + iter = lora_tensors.find(lokr_w2_a_name); + if (iter != lora_tensors.end()) { + lokr_w2_a = iter->second; + if (is_conv2d && lokr_w2_a->type != GGML_TYPE_F16) { + lokr_w2_a = ggml_cast(ctx, lokr_w2_a, GGML_TYPE_F16); + } + } + iter = lora_tensors.find(lokr_w2_b_name); + if (iter != lora_tensors.end()) { + lokr_w2_b = iter->second; + if (is_conv2d && lokr_w2_b->type != GGML_TYPE_F16) { + lokr_w2_b = ggml_cast(ctx, lokr_w2_b, GGML_TYPE_F16); + } + } + + int rank = 1; + if (lokr_w1_b) { + rank = lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1]; + } + if (lokr_w2_b) { + rank = lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1]; + } + + float scale_value = 1.0f; + iter = lora_tensors.find(alpha_name); + if (iter != lora_tensors.end()) { + float alpha = ggml_ext_backend_tensor_get_f32(iter->second); + scale_value = alpha / rank; + applied_lora_tensors.insert(alpha_name); + } + + if (rank == 1) { + scale_value = 1.0f; + } + scale_value *= multiplier; + + auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value); + if (out_diff == nullptr) { + out_diff = curr_out_diff; + } else { + out_diff = ggml_concat(ctx, out_diff, curr_out_diff, 0); + } + + if(lokr_w1) applied_lora_tensors.insert(lokr_w1_name); + if(lokr_w1_a) applied_lora_tensors.insert(lokr_w1_a_name); + if(lokr_w1_b) applied_lora_tensors.insert(lokr_w1_b_name); + if(lokr_w2) applied_lora_tensors.insert(lokr_w2_name); + if(lokr_w2_a) applied_lora_tensors.insert(lokr_w2_name); + if(lokr_w2_b) applied_lora_tensors.insert(lokr_w2_b_name); + applied_lora_tensors.insert(alpha_name); + + + index++; + continue; + } + + // not a lork, normal lora path + + std::string lora_down_name = "lora." + key + ".lora_down"; std::string lora_up_name = "lora." + key + ".lora_up"; @@ -525,9 +636,8 @@ struct LoraModel : public GGMLRunner { ggml_tensor* lora_mid = nullptr; ggml_tensor* lora_down = nullptr; - bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D; - auto iter = lora_tensors.find(lora_up_name); + iter = lora_tensors.find(lora_up_name); if (iter != lora_tensors.end()) { lora_up = iter->second; if (is_conv2d && lora_up->type != GGML_TYPE_F16) {