
Commit 9c987ee

Introduce enum llama_ftype
1 parent d8d4e86 commit 9c987ee

File tree

examples/quantize/quantize.cpp
llama.cpp
llama.h

3 files changed: +46 -31 lines

examples/quantize/quantize.cpp

+11 -3
@@ -12,8 +12,8 @@ int main(int argc, char ** argv) {
 
     if (argc != 4) {
         fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]);
-        fprintf(stderr, "  type = 2 - q4_0\n");
-        fprintf(stderr, "  type = 3 - q4_1\n");
+        fprintf(stderr, "  type = %d - q4_0\n", LLAMA_FTYPE_MOSTLY_Q4_0);
+        fprintf(stderr, "  type = %d - q4_1\n", LLAMA_FTYPE_MOSTLY_Q4_1);
         return 1;
     }
 
@@ -27,7 +27,15 @@ int main(int argc, char ** argv) {
     const std::string fname_inp = argv[1];
     const std::string fname_out = argv[2];
 
-    const int itype = atoi(argv[3]);
+    const enum llama_ftype itype = (enum llama_ftype)atoi(argv[3]);
+    switch (itype) {
+        case LLAMA_FTYPE_MOSTLY_Q4_0:
+        case LLAMA_FTYPE_MOSTLY_Q4_1:
+            break;
+        default:
+            fprintf(stderr, "Invalid model file type %d\n", itype);
+            return 1;
+    }
 
     const int64_t t_main_start_us = ggml_time_us();
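With the enum supplying the numeric values, the usage message and the validation switch both follow llama_ftype and cannot drift apart. Assuming the built example binary is named quantize (as the directory name suggests), a q4_0 conversion would be invoked along the lines of the usage string above:

    quantize model-f32.bin model-quant.bin 2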

llama.cpp

+25 -27
@@ -54,6 +54,8 @@ enum e_model {
     MODEL_65B,
 };
 
+static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", "gptq" };
+
 static const size_t MB = 1024*1024;
 
 // computed for n_ctx == 2048
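Hoisting ftype_str to file scope means the loader and the quantizer now share one table (note the added "gptq" entry for value 4), so its entries must stay index-aligned with the llama_ftype values declared in llama.h. A minimal sketch of how that coupling could be made explicit (not part of this commit, assumes C++11 or later):

    // Hypothetical guard: one string per llama_ftype value.
    static_assert(sizeof(ftype_str)/sizeof(ftype_str[0]) == LLAMA_FTYPE_PER_LAYER_IS_Q4_1 + 1,
                  "ftype_str needs one entry per llama_ftype value");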
@@ -100,7 +102,7 @@ struct llama_hparams {
     int32_t n_head = 32;
     int32_t n_layer = 32;
     int32_t n_rot = 64;
-    int32_t f16 = 1;
+    int32_t f16 = LLAMA_FTYPE_MOSTLY_F16;
 };
 
 struct llama_layer {
@@ -435,7 +437,7 @@ static bool llama_model_load(
     }
 
     // temp warning to tell the user to use "--n_parts"
-    if (hparams.f16 == 4 && n_parts != 1) {
+    if (hparams.f16 == LLAMA_FTYPE_PER_LAYER_IS_Q4_1 && n_parts != 1) {
         fprintf(stderr, "%s: GPTQ model detected - are you sure n_parts should be %d? we normally expect it to be 1\n", __func__, n_parts);
         fprintf(stderr, "%s: use '--n_parts 1' if necessary\n", __func__);
     }
@@ -508,11 +510,14 @@ static bool llama_model_load(
     // wtype is for per-layer weights, while vtype is for other weights
     ggml_type wtype, vtype;
     switch (model.hparams.f16) {
-        case 0: wtype = vtype = GGML_TYPE_F32; break;
-        case 1: wtype = vtype = GGML_TYPE_F16; break;
-        case 2: wtype = vtype = GGML_TYPE_Q4_0; break;
-        case 3: wtype = vtype = GGML_TYPE_Q4_1; break;
-        case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break;
+        case LLAMA_FTYPE_ALL_F32: wtype = vtype = GGML_TYPE_F32; break;
+        case LLAMA_FTYPE_MOSTLY_F16: wtype = vtype = GGML_TYPE_F16; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_0: wtype = vtype = GGML_TYPE_Q4_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_1: wtype = vtype = GGML_TYPE_Q4_1; break;
+        case LLAMA_FTYPE_PER_LAYER_IS_Q4_1:
+            wtype = GGML_TYPE_Q4_1;
+            vtype = GGML_TYPE_F16;
+            break;
         default:
                 {
                     fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
@@ -684,16 +689,15 @@ static bool llama_model_load(
                 return false;
             }
             if (0) {
-                static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
                 fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
             }
 
             switch (ftype) {
-                case 0: // f32
-                case 1: // f16
+                case LLAMA_FTYPE_ALL_F32:
+                case LLAMA_FTYPE_MOSTLY_F16:
                     break;
-                case 2: // q4_0
-                case 3: // q4_1
+                case LLAMA_FTYPE_MOSTLY_Q4_0:
+                case LLAMA_FTYPE_MOSTLY_Q4_1:
                     assert(ne[0] % 64 == 0);
                     break;
                 default:
@@ -1273,20 +1277,15 @@ static llama_vocab::id llama_sample_top_p_top_k(
 //
 
 // TODO: reuse code from the llama_model_load() somehow
-static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
-    ggml_type type = GGML_TYPE_Q4_1;
+static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum llama_ftype itype) {
+    ggml_type type;
 
     switch (itype) {
-        case 2: type = GGML_TYPE_Q4_0; break;
-        case 3: type = GGML_TYPE_Q4_1; break;
-        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
+        case LLAMA_FTYPE_MOSTLY_Q4_0: type = GGML_TYPE_Q4_0; break;
+        case LLAMA_FTYPE_MOSTLY_Q4_1: type = GGML_TYPE_Q4_1; break;
+        default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return false;
     };
 
-    if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
-        fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
-        return false;
-    }
-
     llama_vocab vocab;
 
     printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
@@ -1438,7 +1437,6 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
         }
 
         {
-            static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
             printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
         }
 
@@ -1459,12 +1457,12 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
         quantize &= (n_dims == 2);
 
         if (quantize) {
-            if (ftype != 0 && ftype != 1) {
+            if (ftype != LLAMA_FTYPE_ALL_F32 && ftype != LLAMA_FTYPE_MOSTLY_F16) {
                 fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
                 return false;
             }
 
-            if (ftype == 1) {
+            if (ftype == LLAMA_FTYPE_MOSTLY_F16) {
                 data_f16.resize(nelements);
                 finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
                 data_f32.resize(nelements);
@@ -1478,7 +1476,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
 
             ftype = itype;
         } else {
-            const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
+            const int bpe = (ftype == LLAMA_FTYPE_ALL_F32) ? sizeof(float) : sizeof(uint16_t);
 
             data_u8.resize(nelements*bpe);
             finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
@@ -1659,7 +1657,7 @@ void llama_free(struct llama_context * ctx) {
 int llama_model_quantize(
         const char * fname_inp,
         const char * fname_out,
-             int itype) {
+  enum llama_ftype itype) {
     if (!llama_model_quantize_internal(fname_inp, fname_out, itype)) {
         fprintf(stderr, "%s: failed to quantize\n", __func__);
         return 1;

llama.h

+10 -1
@@ -64,6 +64,15 @@ extern "C" {
         void * progress_callback_user_data;
     };
 
+    // model file types
+    enum llama_ftype {
+        LLAMA_FTYPE_ALL_F32 = 0,
+        LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
+        LLAMA_FTYPE_PER_LAYER_IS_Q4_1 = 4, // but tok_embeddings.weight and output.weight are F16
+    };
+
     LLAMA_API struct llama_context_params llama_context_default_params();
 
     // Various functions for loading a ggml llama model.
@@ -81,7 +90,7 @@ extern "C" {
     LLAMA_API int llama_model_quantize(
             const char * fname_inp,
             const char * fname_out,
-                   int itype);
+      enum llama_ftype itype);
 
     // Returns the KV cache that will contain the context for the
     // ongoing prediction with the model.
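For API consumers the change is mechanical: pass one of the llama_ftype values instead of a bare integer. A minimal usage sketch, assuming a zero return indicates success (the commit only shows the failure path returning 1) and using placeholder file names taken from the quantize example's usage string:

    #include <cstdio>
    #include "llama.h"

    int main() {
        const char * fname_inp = "model-f32.bin";   // placeholder input path
        const char * fname_out = "model-quant.bin"; // placeholder output path

        // Quantize the whole model to q4_0 using the new enum value.
        if (llama_model_quantize(fname_inp, fname_out, LLAMA_FTYPE_MOSTLY_Q4_0) != 0) {
            fprintf(stderr, "quantization failed\n");
            return 1;
        }
        return 0;
    }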
