
Commit 73bdcb3

AndrewGodfrey authored (with ggerganov and cebtenzzre)
finetune : add -ngl parameter (#3762)
* Add '-ngl' support to finetune.cpp

* Add fprintf in ggml_cuda_op_add

  When I tried CUDA offloading during finetuning following the readme, I got an assert here. This probably isn't an important case because inference later gives a warning saying you should use f16 or f32 instead when using lora.

* Add 'finetune.sh', which currently fails when using GPU: "error: operator (): Finetuning on tensors with type 'f16' is not yet supported"

* tweak finetune.sh

* Suppress some warnings in ggml.c

* Add f16 implementation to ggml_compute_forward_add_f16_f32

* Add an f16 case to ggml_add_cast_impl and llama_build_lora_finetune_graphs

* finetune.sh: Edit comments

* Add "add_f16_f32_f32_cuda"

* Tweak an error message

* finetune.sh: Add an optional LLAMA_MODEL_DIR variable

* finetune.sh: Add an optional LLAMA_TRAINING_DIR variable

* train : minor

* tabs to spaces

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: cebtenzzre <[email protected]>
1 parent f0e2093 commit 73bdcb3
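
For context, a minimal sketch of how the new flag might be invoked, assuming a build with GPU offload support; the model and training-data paths are placeholders taken from the example script added in this commit, not requirements:

# offload 25 layers to the GPU while finetuning (paths are placeholders)
./finetune \
    --model-base ./models/openllama-3b-v2.gguf \
    --train-data ./shakespeare.txt \
    --lora-out lora-shakespeare-ITERATION.bin \
    -ngl 25

Without GPU offload support compiled in, the option is accepted but ignored, and the warnings added in train_params_parse are printed instead.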

File tree

8 files changed: +108 -17 lines

common/train.cpp

Lines changed: 2 additions & 0 deletions
@@ -1045,6 +1045,7 @@ struct train_params_common get_default_train_params_common() {
     params.n_batch = 8;
     params.n_gradient_accumulation = 1;
     params.n_epochs = -1;
+    params.n_gpu_layers = 0;
 
     params.custom_n_ctx = false;
 
@@ -1080,6 +1081,7 @@ struct train_params_common get_default_train_params_common() {
     params.adam_beta2 = 0.999f;
     params.adam_gclip = 1.0f;
     params.adam_eps_f = 0.0f;
+
     return params;
 }
 

common/train.h

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ struct train_params_common {
     int n_batch;
     int n_gradient_accumulation;
     int n_epochs;
+    int n_gpu_layers;
 
     bool custom_n_ctx;
 

examples/finetune/finetune.cpp

Lines changed: 13 additions & 1 deletion
@@ -652,7 +652,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
 
     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type)) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
            return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
        } else if (a->type == GGML_TYPE_F32) {
            return ggml_add(ctx, a, b);
@@ -1459,6 +1459,17 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par
            }
            params->n_rank_w3 = std::stoi(argv[i]);
            params->custom_n_rank_w3 = true;
+        } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+            params->common.n_gpu_layers = std::stoi(argv[i]);
+#else
+            fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+            fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
        } else {
            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
            train_print_usage(argc, argv, &default_params);
@@ -1545,6 +1556,7 @@ int main(int argc, char ** argv) {
     srand(params.common.seed);
 
     struct llama_model_params llama_mparams = llama_model_default_params();
+    llama_mparams.n_gpu_layers = params.common.n_gpu_layers;
     llama_mparams.vocab_only = false;
 
     printf("%s: model base = '%s'\n", __func__, params.fn_model_base);

examples/finetune/finetune.sh

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+#!/bin/bash
+cd `dirname $0`
+cd ../..
+
+EXE="./finetune"
+
+if [[ ! $LLAMA_MODEL_DIR ]]; then LLAMA_MODEL_DIR="./models"; fi
+if [[ ! $LLAMA_TRAINING_DIR ]]; then LLAMA_TRAINING_DIR="."; fi
+
+# MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2-q8_0.gguf" # This is the model the readme uses.
+MODEL="$LLAMA_MODEL_DIR/openllama-3b-v2.gguf" # An f16 model. Note that in this case with "-g", you get an f32-format .bin file that isn't yet supported if you use it with "main --lora" with GPU inferencing.
+
+while getopts "dg" opt; do
+  case $opt in
+    d)
+      DEBUGGER="gdb --args"
+      ;;
+    g)
+      EXE="./build/bin/Release/finetune"
+      GPUARG="--gpu-layers 25"
+      ;;
+  esac
+done
+
+$DEBUGGER $EXE \
+        --model-base $MODEL \
+        $GPUARG \
+        --checkpoint-in chk-ol3b-shakespeare-LATEST.gguf \
+        --checkpoint-out chk-ol3b-shakespeare-ITERATION.gguf \
+        --lora-out lora-ol3b-shakespeare-ITERATION.bin \
+        --train-data "$LLAMA_TRAINING_DIR/shakespeare.txt" \
+        --save-every 10 \
+        --threads 10 --adam-iter 30 --batch 4 --ctx 64 \
+        --use-checkpointing
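
As a usage note (a sketch, not part of the commit): the script is driven by the two optional environment variables and the getopts flags above, so a GPU run, assuming the script is invoked from a repository checkout and the CMake Release binary used by "-g" exists, might look like:

# -g selects ./build/bin/Release/finetune and adds --gpu-layers 25; -d would wrap the run in gdb
LLAMA_MODEL_DIR=./models LLAMA_TRAINING_DIR=. ./examples/finetune/finetune.sh -g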

ggml-cuda.cu

Lines changed: 17 additions & 0 deletions
@@ -513,6 +513,15 @@ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * d
     dst[i] = __hadd(x[i], __float2half(y[i]));
 }
 
+static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __half2float(x[i]) + y[i];
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -4693,6 +4702,11 @@ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, co
     add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
 }
 
+static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
 static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
     const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -5996,7 +6010,10 @@ inline void ggml_cuda_op_add(
         add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
         add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream);
     } else {
+        fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type);
         GGML_ASSERT(false);
     }
 

ggml-quants.c

Lines changed: 2 additions & 0 deletions
@@ -716,6 +716,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
         __riscv_vse8_v_i8m1(y[i].qs , vs, vl);
     }
 #else
+    UNUSED(nb);
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
@@ -969,6 +970,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
         y[i].s = sum*d;
     }
 #else
+    UNUSED(nb);
     // scalar
     quantize_row_q8_1_reference(x, y, k);
 #endif

ggml.c

Lines changed: 38 additions & 15 deletions
@@ -3153,7 +3153,7 @@ static struct ggml_tensor * ggml_add_cast_impl(
     // TODO: support less-strict constraint
     // GGML_ASSERT(ggml_can_repeat(b, a));
     GGML_ASSERT(ggml_can_repeat_rows(b, a));
-    GGML_ASSERT(ggml_is_quantized(a->type)); // currently only supported for quantized input
+    GGML_ASSERT(ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16); // currently only supported for quantized input and f16
 
     bool is_node = false;
 
@@ -6927,9 +6927,15 @@ static void ggml_compute_forward_add_f16_f32(
 
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F16);
 
-    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    if (dst->type == GGML_TYPE_F32) {
+        GGML_ASSERT( nb0 == sizeof(float));
+    }
+    else {
+        GGML_ASSERT(dst->type == GGML_TYPE_F16);
+        GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
+    }
+
     GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 
     // rows per thread
@@ -6940,18 +6946,35 @@ static void ggml_compute_forward_add_f16_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     if (nb10 == sizeof(float)) {
-        for (int ir = ir0; ir < ir1; ++ir) {
-            // src0, src1 and dst are same shape => same indices
-            const int i3 = ir/(ne2*ne1);
-            const int i2 = (ir - i3*ne2*ne1)/ne1;
-            const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
-
-            ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
-            ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
-            float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
-
-            for (int i = 0; i < ne0; i++) {
-                dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+        if (dst->type == GGML_TYPE_F16) {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i]);
+                }
+            }
+        } else {
+            for (int ir = ir0; ir < ir1; ++ir) {
+                // src0, src1 and dst are same shape => same indices
+                const int i3 = ir/(ne2*ne1);
+                const int i2 = (ir - i3*ne2*ne1)/ne1;
+                const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+                float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
+                ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11);
+
+                for (int i = 0; i < ne0; i++) {
+                    dst_ptr[i] = GGML_FP16_TO_FP32(src0_ptr[i]) + src1_ptr[i];
+                }
             }
         }
     }

llama.cpp

Lines changed: 1 addition & 1 deletion
@@ -8003,7 +8003,7 @@ static int llama_apply_lora_from_file_internal(
         if (dest_t->backend == GGML_BACKEND_GPU || dest_t->backend == GGML_BACKEND_GPU_SPLIT) {
             if (dest_t->type != GGML_TYPE_F16) {
                 throw std::runtime_error(format(
-                    "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models", __func__));
+                    "%s: error: the simultaneous use of LoRAs and GPU acceleration is only supported for f16 models. dest_t->type: %d", __func__, dest_t->type));
             }
             offload_func = ggml_cuda_assign_buffers;
             offload_func_force_inplace = ggml_cuda_assign_buffers_force_inplace;
