
Commit ee7136c

MollySophia authored, with ggerganov and compilade
llama: add support for QRWKV6 model architecture (#11001)
* WIP: Add support for RWKV6Qwen2
* RWKV: Some graph simplification
* Add support for RWKV6Qwen2 with cpu and cuda GLA
* RWKV6[QWEN2]: Concat lerp weights together to reduce cpu overhead
* Fix some typos
* code format changes
* Fix wkv test & add gla test
* Fix cuda warning
* Update README.md
* Update ggml/src/ggml-cuda/gla.cu (co-authored by Georgi Gerganov)
* better sanity check skipping for QRWKV6 in llama-quant, thanks @compilade
* Fix fused lerp weights loading with RWKV6

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: compilade <[email protected]>
1 parent c6860cc commit ee7136c

23 files changed: +863 −125 lines

README.md (+1)

@@ -99,6 +99,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat)
 - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a)
 - [x] [RWKV-6](https://github.com/BlinkDL/RWKV-LM)
+- [x] [QRWKV-6](https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1)
 - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct)

 #### Multimodal

convert_hf_to_gguf.py (+77 −4)

@@ -326,6 +326,7 @@ def prepare_tensors(self):
             gguf.MODEL_TENSOR.TIME_MIX_W2,
             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
             gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
+            gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
             gguf.MODEL_TENSOR.POSNET_NORM1,
             gguf.MODEL_TENSOR.POSNET_NORM2,
         )

@@ -3316,6 +3317,8 @@ def set_gguf_parameters(self):
         # required by llama.cpp, unused
         self.gguf_writer.add_head_count(0)

+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)

@@ -3331,14 +3334,84 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         if new_name.endswith("time_mix_decay.weight") or "lerp" in new_name:
             data_torch = data_torch.squeeze()

-        rescale_every_n_layers = self.hparams["rescale_every"]
-        if rescale_every_n_layers > 0:
-            if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
-                data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+        try:
+            rescale_every_n_layers = self.hparams["rescale_every"]
+            if rescale_every_n_layers > 0:
+                if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"):
+                    data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers))
+        except KeyError:
+            pass
+
+        # concat time_mix_lerp weights to reduce some cpu overhead
+        # also reduces the number of tensors in the model
+        if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name:
+            try:
+                self.lerp_weights[bid][new_name] = data_torch
+            except KeyError:
+                self.lerp_weights[bid] = {new_name: data_torch}
+            if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]):
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1)
+                yield (new_name, data)
+            return

         yield (new_name, data_torch)


+@Model.register("RWKV6Qwen2ForCausalLM")
+class RWKV6Qwen2Model(Rwkv6Model):
+    model_arch = gguf.MODEL_ARCH.RWKV6QWEN2
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        num_attention_heads = self.hparams["num_attention_heads"]
+        num_key_value_heads = self.hparams["num_key_value_heads"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = hidden_size // num_attention_heads
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        time_mix_extra_dim = 64 if hidden_size >= 4096 else 32
+        time_decay_extra_dim = 128 if hidden_size >= 4096 else 64
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_time_mix_extra_dim(time_mix_extra_dim)
+        self.gguf_writer.add_time_decay_extra_dim(time_decay_extra_dim)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # special parameters for time_mixing in RWKV6QWEN2
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_token_shift_count(1)
+        # RWKV6QWEN2 use grouped key/value like GQA
+        self.gguf_writer.add_head_count_kv(num_key_value_heads)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        for new_name, data in super().modify_tensors(data_torch, name, bid):
+            if "time_mix_w1" in new_name or "time_mix_w2" in new_name:
+                data = data.view(5, -1, data.shape[-1])
+                # rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg
+                # permute them here to avoid code changes
+                data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1])
+                if "w2" in new_name:
+                    data = data.view(5, -1, data.shape[-1])
+                yield (new_name, data)
+                continue
+            yield (new_name, data)
+
+
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA

ggml/include/ggml.h (+10)

@@ -501,6 +501,7 @@ extern "C" {
         GGML_OP_GET_REL_POS,
         GGML_OP_ADD_REL_POS,
         GGML_OP_RWKV_WKV6,
+        GGML_OP_GATED_LINEAR_ATTN,

         GGML_OP_UNARY,

@@ -1859,6 +1860,15 @@ extern "C" {
             struct ggml_tensor * td,
             struct ggml_tensor * state);

+    GGML_API struct ggml_tensor * ggml_gated_linear_attn(
+            struct ggml_context * ctx,
+            struct ggml_tensor * k,
+            struct ggml_tensor * v,
+            struct ggml_tensor * q,
+            struct ggml_tensor * g,
+            struct ggml_tensor * state,
+            float scale);
+
     // custom operators

     typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *);
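For orientation, here is a minimal sketch of how the new operator might be wired into a graph. The helper name build_gla_node and the 1/sqrt(head_size) query scale are illustrative assumptions, not part of this commit; the shape notes in the comments are inferred from the CPU kernel below rather than stated by this header.

// Illustrative sketch only (not from this diff): building a GLA node.
// Assumed layout, inferred from ggml_compute_forward_gla_f32:
//   k, v, q, g : one float per channel, per head, per token
//   state      : head_size * head_size floats per head, per sequence
// The result packs C * n_tokens output floats followed by the updated state,
// the same layout ggml_rwkv_wkv6 uses.
#include <math.h>
#include "ggml.h"

static struct ggml_tensor * build_gla_node(
        struct ggml_context * ctx,
        struct ggml_tensor * k,
        struct ggml_tensor * v,
        struct ggml_tensor * q,
        struct ggml_tensor * g,
        struct ggml_tensor * state,
        int64_t head_size) {
    // Attention-style query scaling; an assumption here, not mandated by the API.
    const float scale = 1.0f / sqrtf((float) head_size);
    return ggml_gated_linear_attn(ctx, k, v, q, g, state, scale);
}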

ggml/src/ggml-cpu/ggml-cpu.c (+198 −2)

@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;

@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }

+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary

 static void ggml_compute_forward_map_unary_f32(

@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;

@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
         case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
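As a reading aid (not part of the commit), the inner loops of ggml_compute_forward_gla_f32 implement, per head and per token $t$, the gated linear-attention recurrence

$$S_t[i,j] = g_t[i]\,S_{t-1}[i,j] + k_t[i]\,v_t[j], \qquad o_t[j] = \sum_i \big(\mathrm{scale}\cdot q_t[i]\big)\,S_t[i,j],$$

or in matrix form $S_t = \mathrm{diag}(g_t)\,S_{t-1} + k_t v_t^{\top}$ and $o_t = \mathrm{scale}\cdot S_t^{\top} q_t$, where $S$ is the per-head head_size × head_size state, re-read from src[4] at the first token of each sequence and written back after the last.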

ggml/src/ggml-cuda/ggml-cuda.cu (+5)

@@ -37,6 +37,7 @@
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
 #include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/gla.cuh"

 #include <algorithm>
 #include <array>

@@ -2167,6 +2168,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cuda_op_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;

@@ -3011,6 +3015,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
            return true;
        case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
