101 changes: 100 additions & 1 deletion ggml_extend.hpp
@@ -1577,7 +1577,7 @@ struct WeightAdapter {
bool force_prec_f32 = false;
float scale = 1.f;
} linear;
struct {
struct conv2d_params_t {
int s0 = 1;
int s1 = 1;
int p0 = 0;
@@ -2630,4 +2630,103 @@ class MultiheadAttention : public GGMLBlock {
}
};

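// LoKr forward pass: the weight delta is modeled as a Kronecker product of an outer
// factor (w1, or the low-rank pair w1a/w1b) and an inner factor (w2, or w2a/w2b).
// Instead of materializing the full Kronecker product, the input is reshaped so each
// factor is applied separately, and the result is multiplied by `scale`.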
__STATIC_INLINE__ struct ggml_tensor* ggml_ext_lokr_forward(
struct ggml_context* ctx,
struct ggml_tensor* h, // Input: [q, batch] or [W, H, q, batch]
struct ggml_tensor* w1, // Outer C (Full rank)
struct ggml_tensor* w1a, // Outer A (Low rank part 1)
struct ggml_tensor* w1b, // Outer B (Low rank part 2)
struct ggml_tensor* w2, // Inner BA (Full rank)
struct ggml_tensor* w2a, // Inner A (Low rank part 1)
struct ggml_tensor* w2b, // Inner B (Low rank part 2)
bool is_conv,
WeightAdapter::ForwardParams::conv2d_params_t conv_params,
float scale) {

GGML_ASSERT((w1 != NULL || (w1a != NULL && w1b != NULL)));
GGML_ASSERT((w2 != NULL || (w2a != NULL && w2b != NULL)));

int uq = (w1 != NULL) ? (int)w1->ne[0] : (int)w1a->ne[0];
int up = (w1 != NULL) ? (int)w1->ne[1] : (int)w1b->ne[1];

int q_actual = is_conv ? (int)h->ne[2] : (int)h->ne[0];
int vq = q_actual / uq;

int vp = (w2 != NULL) ? (is_conv ? (int)w2->ne[3] : (int)w2->ne[1])
: (int)w2a->ne[1];
GGML_ASSERT(q_actual == (uq * vq) && "Input dimension mismatch for LoKR split");

struct ggml_tensor* hb;

if (!is_conv) {
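// Matrix case: view the input [uq*vq, batch] as [vq, uq, batch], apply the inner
// factor along vq, then the outer factor along uq, and flatten back to [up*vp, batch].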
int batch = (int)h->ne[1];
struct ggml_tensor* h_mat = ggml_reshape_2d(ctx, h, vq, uq * batch);

if (w2 != NULL) {
hb = ggml_mul_mat(ctx, w2, h_mat);
} else {
hb = ggml_mul_mat(ctx, w2b, ggml_mul_mat(ctx, w2a, h_mat));
}

struct ggml_tensor* hb_unbundled = ggml_reshape_3d(ctx, hb, vp, uq, batch);
struct ggml_tensor* hb_t = ggml_transpose(ctx, hb_unbundled);

struct ggml_tensor* hc;
if (w1 != NULL) {
hc = ggml_mul_mat(ctx, w1, hb_t);
} else {
hc = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
}

struct ggml_tensor* hc_t = ggml_transpose(ctx, hc);
struct ggml_tensor* out = ggml_reshape_2d(ctx, ggml_cont(ctx, hc_t), up * vp, batch);
return ggml_scale(ctx, out, scale);

} else {
int batch = (int)h->ne[3];

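// Conv case: fold the uq groups into the conv batch dimension so the inner factor
// (a conv over vq input channels) is applied to each group independently.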
// Reshape input: [W, H, vq*uq, batch] -> [W, H, vq, uq * batch]
struct ggml_tensor* h_grouped = ggml_reshape_4d(ctx, h, h->ne[0], h->ne[1], vq, uq * batch);

if (w2 != NULL) {
hb = ggml_conv_2d(ctx, w2, h_grouped, conv_params.s0, conv_params.s1,
conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
} else {
// Low-rank inner factor: w2b holds the spatial (down-projection) kernel,
// w2a the 1x1 up-projection.
int rank = (int)w2b->ne[1];
int k = (int)sqrt(w2b->ne[0] / vq);

struct ggml_tensor* w2b_4d = (ggml_n_dims(w2b) < 3) ? ggml_reshape_4d(ctx, w2b, k, k, vq, rank) : w2b;
struct ggml_tensor* w2a_4d = (ggml_n_dims(w2a) < 3) ? ggml_reshape_4d(ctx, w2a, 1, 1, rank, vp) : w2a;

struct ggml_tensor* ha = ggml_conv_2d(ctx, w2b_4d, h_grouped, conv_params.s0, conv_params.s1,
conv_params.p0, conv_params.p1, conv_params.d0, conv_params.d1);
hb = ggml_conv_2d(ctx, w2a_4d, ha, 1, 1, 0, 0, 1, 1);
}

int w_out = (int)hb->ne[0];
int h_out = (int)hb->ne[1];

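// Collapse the spatial dims and the vp channels so the outer factor can be applied
// as a matmul over the uq groups.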
struct ggml_tensor* hb_flat = ggml_reshape_3d(ctx, hb, w_out * h_out * vp, uq, batch);
struct ggml_tensor* hb_t = ggml_transpose(ctx, hb_flat);

struct ggml_tensor* hc;
struct ggml_tensor* w1_mat = (w1 != NULL) ? ggml_reshape_2d(ctx, w1, uq, up) : NULL;

if (w1_mat != NULL) {
hc = ggml_mul_mat(ctx, w1_mat, hb_t);
} else {
hc = ggml_mul_mat(ctx, w1b, ggml_mul_mat(ctx, w1a, hb_t));
}

struct ggml_tensor* hc_t = ggml_transpose(ctx, hc);
// ggml_cont(hc_t) is already laid out as [w_out, h_out, vp, up, batch] in memory
// (spatial fastest, then inner channel, then outer channel), so a single reshape
// yields the conv output with channel index u*vp + v, matching the matrix branch above.
struct ggml_tensor* out = ggml_reshape_4d(ctx, ggml_cont(ctx, hc_t), w_out, h_out, up * vp, batch);

return ggml_scale(ctx, out, scale);
}
}

#endif // __GGML_EXTEND__HPP__
116 changes: 113 additions & 3 deletions lora.hpp
@@ -483,7 +483,7 @@ struct LoraModel : public GGMLRunner {
diff = get_loha_weight_diff(model_tensor_name, ctx);
}
// lokr
if (diff == nullptr) {
if (diff == nullptr && with_lora) {
diff = get_lokr_weight_diff(model_tensor_name, ctx);
}
if (diff != nullptr) {
@@ -501,6 +501,8 @@ struct LoraModel : public GGMLRunner {
return diff;
}

ggml_tensor* get_out_diff(ggml_context* ctx,
ggml_tensor* x,
WeightAdapter::ForwardParams forward_params,
@@ -514,6 +516,115 @@
} else {
key = model_tensor_name + "." + std::to_string(index);
}
bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;

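// LoKr tensor names follow the LyCORIS convention: lokr_w1 (or the low-rank pair
// lokr_w1_a / lokr_w1_b) is the outer Kronecker factor, lokr_w2 (or lokr_w2_a /
// lokr_w2_b) the inner one, with an optional alpha for scaling.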
std::string lokr_w1_name = "lora." + key + ".lokr_w1";
std::string lokr_w1_a_name = "lora." + key + ".lokr_w1_a";
// if either of these is found, then we have a lokr lora
auto iter = lora_tensors.find(lokr_w1_name);
auto iter_a = lora_tensors.find(lokr_w1_a_name);
if (iter != lora_tensors.end() || iter_a != lora_tensors.end()) {
std::string lokr_w1_b_name = "lora." + key + ".lokr_w1_b";
std::string lokr_w2_name = "lora." + key + ".lokr_w2";
std::string lokr_w2_a_name = "lora." + key + ".lokr_w2_a";
std::string lokr_w2_b_name = "lora." + key + ".lokr_w2_b";
std::string alpha_name = "lora." + key + ".alpha";

ggml_tensor* lokr_w1 = nullptr;
ggml_tensor* lokr_w1_a = nullptr;
ggml_tensor* lokr_w1_b = nullptr;
ggml_tensor* lokr_w2 = nullptr;
ggml_tensor* lokr_w2_a = nullptr;
ggml_tensor* lokr_w2_b = nullptr;

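// For the conv2d path the factors are cast to f16 below, mirroring how the regular
// LoRA conv weights are handled further down in this function.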
if (iter != lora_tensors.end()) {
lokr_w1 = iter->second;
if (is_conv2d && lokr_w1->type != GGML_TYPE_F16) {
lokr_w1 = ggml_cast(ctx, lokr_w1, GGML_TYPE_F16);
}
}
iter = iter_a;
if (iter != lora_tensors.end()) {
lokr_w1_a = iter->second;
if (is_conv2d && lokr_w1_a->type != GGML_TYPE_F16) {
lokr_w1_a = ggml_cast(ctx, lokr_w1_a, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w1_b_name);
if (iter != lora_tensors.end()) {
lokr_w1_b = iter->second;
if (is_conv2d && lokr_w1_b->type != GGML_TYPE_F16) {
lokr_w1_b = ggml_cast(ctx, lokr_w1_b, GGML_TYPE_F16);
}
}

iter = lora_tensors.find(lokr_w2_name);
if (iter != lora_tensors.end()) {
lokr_w2 = iter->second;
if (is_conv2d && lokr_w2->type != GGML_TYPE_F16) {
lokr_w2 = ggml_cast(ctx, lokr_w2, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w2_a_name);
if (iter != lora_tensors.end()) {
lokr_w2_a = iter->second;
if (is_conv2d && lokr_w2_a->type != GGML_TYPE_F16) {
lokr_w2_a = ggml_cast(ctx, lokr_w2_a, GGML_TYPE_F16);
}
}
iter = lora_tensors.find(lokr_w2_b_name);
if (iter != lora_tensors.end()) {
lokr_w2_b = iter->second;
if (is_conv2d && lokr_w2_b->type != GGML_TYPE_F16) {
lokr_w2_b = ggml_cast(ctx, lokr_w2_b, GGML_TYPE_F16);
}
}

int rank = 1;
if (lokr_w1_b) {
rank = lokr_w1_b->ne[ggml_n_dims(lokr_w1_b) - 1];
}
if (lokr_w2_b) {
rank = lokr_w2_b->ne[ggml_n_dims(lokr_w2_b) - 1];
}

float scale_value = 1.0f;
iter = lora_tensors.find(alpha_name);
if (iter != lora_tensors.end()) {
float alpha = ggml_ext_backend_tensor_get_f32(iter->second);
scale_value = alpha / rank;
applied_lora_tensors.insert(alpha_name);
}

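// rank stays at its default of 1 when no low-rank B factor is present (both factors
// given full-rank), in which case alpha/rank scaling is skipped.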
if (rank == 1) {
scale_value = 1.0f;
}
scale_value *= multiplier;

auto curr_out_diff = ggml_ext_lokr_forward(ctx, x, lokr_w1, lokr_w1_a, lokr_w1_b, lokr_w2, lokr_w2_a, lokr_w2_b, is_conv2d, forward_params.conv2d, scale_value);
if (out_diff == nullptr) {
out_diff = curr_out_diff;
} else {
out_diff = ggml_concat(ctx, out_diff, curr_out_diff, 0);
}

if (lokr_w1) applied_lora_tensors.insert(lokr_w1_name);
if (lokr_w1_a) applied_lora_tensors.insert(lokr_w1_a_name);
if (lokr_w1_b) applied_lora_tensors.insert(lokr_w1_b_name);
if (lokr_w2) applied_lora_tensors.insert(lokr_w2_name);
if (lokr_w2_a) applied_lora_tensors.insert(lokr_w2_a_name);
if (lokr_w2_b) applied_lora_tensors.insert(lokr_w2_b_name);
applied_lora_tensors.insert(alpha_name);

index++;
continue;
}

// not a LoKr entry, fall back to the regular LoRA path

std::string lora_down_name = "lora." + key + ".lora_down";
std::string lora_up_name = "lora." + key + ".lora_up";
@@ -525,9 +636,8 @@ struct LoraModel : public GGMLRunner {
ggml_tensor* lora_mid = nullptr;
ggml_tensor* lora_down = nullptr;

bool is_conv2d = forward_params.op_type == WeightAdapter::ForwardParams::op_type_t::OP_CONV2D;

auto iter = lora_tensors.find(lora_up_name);
iter = lora_tensors.find(lora_up_name);
if (iter != lora_tensors.end()) {
lora_up = iter->second;
if (is_conv2d && lora_up->type != GGML_TYPE_F16) {