Llama training finetuning interface #2246

Draft · wants to merge 2 commits into master
llama.cpp (124 changes: 120 additions & 4 deletions)
@@ -191,6 +191,12 @@ struct llama_layer {
struct ggml_tensor * w1;
struct ggml_tensor * w2;
struct ggml_tensor * w3;

// optional LoRA adapter tensors
struct ggml_tensor * wq_a;
struct ggml_tensor * wq_b;
struct ggml_tensor * wv_a;
struct ggml_tensor * wv_b;
};

struct llama_kv_cache {
@@ -303,6 +309,7 @@ struct llama_context {

const llama_model & model;
const llama_vocab & vocab;
std::vector<llama_lora_layers> lora_layers;

bool model_owner = false;

@@ -1366,12 +1373,37 @@ static bool llama_eval_internal(

// self-attention
{
struct ggml_tensor * wq = model.layers[il].wq;
struct ggml_tensor * wk = model.layers[il].wk;

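// note: with the LoRA tensors present, the effective weight used below is
// W + B*A (a low-rank delta); the alpha/r scaling from the LoRA paper is
// kept under "#if 0" and is not applied in this draft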
if (model.layers[il].wq_a != nullptr) {
// apply lora
ggml_tensor * BA = ggml_mul_mat(ctx0, model.layers[il].wq_a, model.layers[il].wq_b);
offload_func(BA);
ggml_set_name(BA, "BA");

#if 0
if (scaling != 1.0f) {
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
ggml_set_name(scale_tensor, "scale_tensor");

BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
offload_func(BA);
ggml_set_name(BA, "BA_scaled");
}
#endif

wq = ggml_add(ctx0, wq, BA);
offload_func(wq);
ggml_set_name(wq, "lora_wq");
}

// compute Q and K and RoPE them
- struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
+ struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, wk, cur);
offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk");

- struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");

@@ -1386,8 +1418,30 @@
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
struct ggml_tensor * wv = model.layers[il].wv;
if (model.layers[il].wv_a != nullptr) {
// apply lora
ggml_tensor * BA = ggml_mul_mat(ctx0, model.layers[il].wv_a, model.layers[il].wv_b);
offload_func(BA);
ggml_set_name(BA, "BA");

#if 0
if (scaling != 1.0f) {
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
ggml_set_name(scale_tensor, "scale_tensor");

BA = ggml_scale_inplace(lora_ctx, BA, scale_tensor);
offload_func(BA);
ggml_set_name(BA, "BA_scaled");
}
#endif

wv = ggml_add(ctx0, wv, BA);
offload_func(wv);
ggml_set_name(wv, "lora_wv");
}

- struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, wv, cur);
offload_func_v(tmpv);
ggml_set_name(tmpv, "tmpv");

@@ -2709,7 +2763,7 @@ int llama_model_quantize(
}
}

- int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) {
+ static int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) {
fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);

const int64_t t_start_lora_us = ggml_time_us();
@@ -3525,3 +3579,65 @@ const char * llama_print_system_info(void) {
const std::vector<std::pair<std::string, struct ggml_tensor *>>& llama_internal_get_tensor_map(struct llama_context * ctx) {
return ctx->model.tensors_by_name;
}

// finetune related code
int llama_enable_finetune(struct llama_model * model, enum llama_finetune_type flags, int n_lora) {
const auto& hparams = model->hparams;

const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd = hparams.n_embd;

struct ggml_context* ctx0 = model->ctx;

if (flags & LLAMA_FINETUNE_FULL) {
ggml_set_param(ctx0, model->tok_embeddings);
ggml_set_param(ctx0, model->norm);

for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];

ggml_set_param(ctx0, layer.attention_norm);
ggml_set_param(ctx0, layer.wq);
ggml_set_param(ctx0, layer.wk);
ggml_set_param(ctx0, layer.wv);
ggml_set_param(ctx0, layer.wo);
ggml_set_param(ctx0, layer.ffn_norm);
ggml_set_param(ctx0, layer.w1);
ggml_set_param(ctx0, layer.w2);
ggml_set_param(ctx0, layer.w3);
}
} else if (flags & LLAMA_FINETUNE_LORA) {
// create the LoRA A/B tensors if they are not present
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = model->layers[i];

if (flags & LLAMA_FINETUNE_LORA_Q) {
if (layer.wq_a == nullptr || layer.wq_b == nullptr) {
layer.wq_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_lora, n_embd);
layer.wq_b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_embd, n_lora);
// initialize

// offload

}
ggml_set_param(ctx0, layer.wq_a);
ggml_set_param(ctx0, layer.wq_b);
}

if (flags & LLAMA_FINETUNE_LORA_V) {
if (layer.wv_a == nullptr || layer.wv_b == nullptr) {
layer.wv_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_lora, n_embd);
layer.wv_b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F16, n_embd, n_lora);
// initialize

// offload

}
ggml_set_param(ctx0, layer.wv_a);
ggml_set_param(ctx0, layer.wv_b);
}
}
}

return 0;
}
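The "// initialize" and "// offload" placeholders above are left empty in this draft. For reference, here is a minimal sketch of the usual LoRA initialization (small random values for A, zeros for B, so the delta B*A starts at zero), using only existing ggml helpers; the helper name llama_lora_init_pair and the 0.02 range are illustrative assumptions, not part of this PR:

#include <cstdlib>   // rand, RAND_MAX (ggml.h is already included by llama.cpp)

static void llama_lora_init_pair(struct ggml_tensor * a, struct ggml_tensor * b) {
    // A: small symmetric random values, so training starts from a near-zero delta
    const int64_t n = ggml_nelements(a);
    for (int64_t i = 0; i < n; ++i) {
        const float r = (float) rand() / (float) RAND_MAX;   // uniform in [0, 1)
        ggml_set_f32_1d(a, (int) i, 0.02f*(2.0f*r - 1.0f));
    }
    // B: all zeros, so W + B*A == W before the first optimizer step
    ggml_set_f32(b, 0.0f);
}

The call sites would be the two places marked "// initialize" above, e.g. llama_lora_init_pair(layer.wq_a, layer.wq_b).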
llama.h (15 changes: 15 additions & 0 deletions)
@@ -126,6 +126,16 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
};

enum llama_finetune_type {
LLAMA_FINETUNE_FULL = 0x01,
LLAMA_FINETUNE_LORA = 0x10,

LLAMA_FINETUNE_LORA_W = 0x1000, // the *_W/K/Q/V flags are valid only with LLAMA_FINETUNE_LORA
LLAMA_FINETUNE_LORA_K = 0x2000,
LLAMA_FINETUNE_LORA_Q = 0x4000,
LLAMA_FINETUNE_LORA_V = 0x8000,
};

// model quantization parameters
typedef struct llama_model_quantize_params {
int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
@@ -242,6 +252,11 @@ extern "C" {
// IMPORTANT: do not use for anything else other than debugging and testing!
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);

// Enable finetuning on the context; the flags select which kind of finetuning to perform
LLAMA_API int llama_enable_finetune(struct llama_context * ctx, enum llama_finetune_type flags);
Review comment (Member):
Suggested change
- LLAMA_API int llama_enable_finetune(struct llama_context * ctx, enum llama_finetune_type flags);
+ LLAMA_API int llama_finetune_enable(struct llama_context * ctx, enum llama_finetune_type flags);


LLAMA_API int llama_finetune(struct llama_context * ctx, void * input, void * output);

// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
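Rough usage sketch of the proposed interface. It follows the llama.h declaration above (the llama.cpp definition in this draft takes a llama_model pointer and an n_lora argument instead); the wrapper function and the flag combination are illustrative only, not part of the diff:

#include "llama.h"

// Hypothetical call site: request LoRA finetuning of the Q and V projections.
static int enable_lora_qv(struct llama_context * ctx) {
    const int flags = LLAMA_FINETUNE_LORA | LLAMA_FINETUNE_LORA_Q | LLAMA_FINETUNE_LORA_V;
    return llama_enable_finetune(ctx, (enum llama_finetune_type) flags);
}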