
Commit 3855416

ggml : introduce bfloat16 support (#6412)
* Introduce bfloat16 support

Many models on Hugging Face (e.g. Mistral, TinyLLaMA) use bfloat16 as
their canonical floating point format.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───┐
    0b0000000000000000 brain16

This encoding has the same number of exponent bits as float32. That
makes conversion relatively straightforward, even in the absence of
hardware support. For example, converting brain16 to binary32 means
simply shifting 16 bits to the left.

      ┌sign
      │
      │   ┌exponent
      │   │
      │   │      ┌mantissa
      │   │      │
      │┌──┴───┐┌─┴───────────────────┐
    0b00000000000000000000000000000000 IEEE binary32

The issue is that converting bf16 to fp16 can result in information
loss. Only 13% of bf16 numbers can be represented exactly in fp16,
although in practice those values cover 99.71% of Mistral 7b v0.2's
weights; until now the only way to preserve the remaining weights was
to convert to fp32.

      ┌sign
      │
      │  ┌exponent
      │  │
      │  │    ┌mantissa
      │  │    │
      │┌─┴─┐┌─┴──────┐
    0b0000000000000000 IEEE binary16

This change fixes that by adding a bf16 data type to GGML. Support for
CPU inference has been implemented, along with optimizations for the
AVX2, AVX512, and AVX512BF16 ISAs. Perplexity on Mistral 7b 0.2
improves by somewhere between -0.0024 and -0.0046 compared to using fp16.

* Remove GGML code that's not needed
* Minimize the GGML API surface area for BF16
* Remove bf16 luts
* Make the GGML header look nicer
* Fix documentation
* Apply ggerganov's fixes for test-backend-ops
* Add BF16 code for new ggml_validate_row_data() function
1 parent c0e6fbf commit 3855416
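The conversion described in the commit message can be sketched in a few lines of standalone C. This is an illustration only, not code from the commit; the bit pattern 0x4049 is an arbitrary sample value (it decodes to 3.140625):

    // Minimal illustration: decoding brain16 by a 16-bit left shift,
    // as described in the commit message above.
    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    static float bf16_bits_to_float(uint16_t bits) {
        uint32_t u = (uint32_t)bits << 16;   // bf16 is the top 16 bits of a float32
        float f;
        memcpy(&f, &u, sizeof f);            // reinterpret the bit pattern
        return f;
    }

    int main(void) {
        uint16_t sample = 0x4049;            // sign 0, exponent 128, mantissa 73
        printf("0x%04x -> %f\n", (unsigned)sample, bf16_bits_to_float(sample));
        // prints: 0x4049 -> 3.140625
        return 0;
    }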

11 files changed (+1228, -102 lines)

examples/finetune/finetune.cpp

+1, -1
@@ -575,7 +575,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);

     auto add_to_f32 = [] (struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
-        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16) {
+        if (ggml_is_quantized(a->type) || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16) {
             return ggml_add_cast(ctx, a, b, GGML_TYPE_F32);
         } else if (a->type == GGML_TYPE_F32) {
             return ggml_add(ctx, a, b);

examples/quantize/quantize.cpp

+2, -1
@@ -46,7 +46,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 4.45G, +0.0122 ppl @ LLaMA-v1-7B", },
     { "Q6_K",   LLAMA_FTYPE_MOSTLY_Q6_K,   " 5.15G, +0.0008 ppl @ LLaMA-v1-7B", },
     { "Q8_0",   LLAMA_FTYPE_MOSTLY_Q8_0,   " 6.70G, +0.0004 ppl @ LLaMA-v1-7B", },
-    { "F16",    LLAMA_FTYPE_MOSTLY_F16,    "13.00G @ 7B", },
+    { "F16",    LLAMA_FTYPE_MOSTLY_F16,    "14.00G, -0.0020 ppl @ Mistral-7B", },
+    { "BF16",   LLAMA_FTYPE_MOSTLY_BF16,   "14.00G, -0.0050 ppl @ Mistral-7B", },
     { "F32",    LLAMA_FTYPE_ALL_F32,       "26.00G @ 7B", },
     // Note: Ensure COPY comes after F32 to avoid ftype 0 from matching.
     { "COPY",   LLAMA_FTYPE_ALL_F32,       "only copy tensors, no quantizing", },

ggml-impl.h

+77
@@ -17,6 +17,83 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))

+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This function is binary identical to AMD Zen4 VCVTNEPS2BF16.
+ * Subnormals shall be flushed to zero, and NANs will be quiet.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    if (!(u.i & 0x7f800000)) { /* subnormal */
+        h.bits = (u.i & 0x80000000) >> 16; /* flush to zero */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
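For readers who want to poke at the helpers above, here is a self-contained sketch (not part of the commit) that copies the same conversion logic. It uses a local bf16_t stand-in for ggml_bf16_t, assumed here to be a struct with a single uint16_t bits field as the code above implies, and round-trips float pi to show the truncation to a 7-bit mantissa:

    // Illustrative round trip through the bf16 conversion logic shown above.
    // bf16_t is a local stand-in so the example compiles on its own.
    #include <stdint.h>
    #include <stdio.h>

    typedef struct { uint16_t bits; } bf16_t;   // stand-in for ggml_bf16_t

    static bf16_t fp32_to_bf16(float s) {       // same logic as ggml_compute_fp32_to_bf16
        union { float f; uint32_t i; } u = { s };
        bf16_t h;
        if ((u.i & 0x7fffffff) > 0x7f800000) {  // NaN: keep sign, force quiet bit
            h.bits = (u.i >> 16) | 64;
            return h;
        }
        if (!(u.i & 0x7f800000)) {              // subnormal: flush to signed zero
            h.bits = (u.i & 0x80000000) >> 16;
            return h;
        }
        h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;  // round to nearest even
        return h;
    }

    static float bf16_to_fp32(bf16_t h) {       // same logic as ggml_compute_bf16_to_fp32
        union { float f; uint32_t i; } u;
        u.i = (uint32_t)h.bits << 16;
        return u.f;
    }

    int main(void) {
        float pi = 3.14159274f;                 // closest float32 to pi
        bf16_t h = fp32_to_bf16(pi);
        printf("%.8f -> 0x%04x -> %.8f\n", pi, (unsigned)h.bits, bf16_to_fp32(h));
        // expected: 3.14159274 -> 0x4049 -> 3.14062500 (7 mantissa bits remain)
        return 0;
    }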

ggml-metal.m

+1, -1
@@ -803,7 +803,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->ne[3] == 1;
+                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
             }
         default:
             return false;

ggml-quants.c

+18
@@ -12450,6 +12450,24 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
     const size_t nb = nbytes/ggml_type_size(type);

     switch (type) {
+        case GGML_TYPE_BF16:
+            {
+                int nans = 0;
+                int infs = 0;
+                const unsigned short * f = (const unsigned short *) data;
+                for (size_t i = 0; i < nb; ++i) {
+                    nans += (f[i] & 0x7fff) > 0x7f80;
+                    infs += (f[i] & 0x7fff) == 0x7f80;
+                }
+                if (nans) {
+                    fprintf(stderr, "%s: found %d NaNs in row of %zu BF16 values\n", __func__, nans, nb);
+                    return false;
+                }
+                if (infs) {
+                    fprintf(stderr, "%s: found %d infinities in row of %zu BF16 values\n", __func__, infs, nb);
+                    return false;
+                }
+            } break;
         case GGML_TYPE_F16:
             {
                 const ggml_fp16_t * f = (const ggml_fp16_t *) data;
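The masks in this check follow from the bf16 layout shown in ggml-impl.h: after clearing the sign bit, a value equal to 0x7f80 has an all-ones exponent and zero mantissa (an infinity), while anything greater has a non-zero mantissa (a NaN). A standalone sketch of the same scan over a hand-built row, for illustration only:

    // Standalone illustration of the NaN/Inf scan above on a hand-built bf16 row.
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // 1.0f, +inf, a quiet NaN, -2.0f as raw bf16 bit patterns
        const uint16_t row[] = { 0x3f80, 0x7f80, 0x7fc0, 0xc000 };
        const size_t nb = sizeof(row) / sizeof(row[0]);

        int nans = 0;
        int infs = 0;
        for (size_t i = 0; i < nb; ++i) {
            nans += (row[i] & 0x7fff) > 0x7f80;   // exponent all ones, mantissa != 0
            infs += (row[i] & 0x7fff) == 0x7f80;  // exponent all ones, mantissa == 0
        }
        printf("%d NaNs, %d infinities in %zu values\n", nans, infs, nb);
        // prints: 1 NaNs, 1 infinities in 4 values
        return 0;
    }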
