@@ -11803,9 +11803,9 @@ static void ggml_compute_forward_add_rel_pos(
 static void ggml_compute_forward_rwkv_wkv6_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const int64_t T = dst->src[1]->ne[3];
+    const int64_t T = dst->src[1]->ne[2];
     const int64_t C = dst->ne[0];
-    const int64_t HEADS = dst->src[1]->ne[2];
+    const int64_t HEADS = dst->src[1]->ne[1];
     const int64_t n_seqs = dst->src[5]->ne[1];
     const int64_t head_size = C / HEADS;
 
@@ -12000,6 +12000,197 @@ static void ggml_compute_forward_rwkv_wkv6(
     }
 }
 
+// ggml_compute_forward_gla
+
+static void ggml_compute_forward_gla_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+    const int64_t T = dst->src[1]->ne[2];
+    const int64_t C = dst->ne[0];
+    const int64_t HEADS = dst->src[1]->ne[1];
+    const int64_t n_seqs = dst->src[4]->ne[1];
+    const int64_t head_size = C / HEADS;
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    float * dst_data = (float *) dst->data;
+    float * state = ((float *) dst->data) + C * T;
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    if (ith >= HEADS) {
+        return;
+    }
+
+    const int h_start = (HEADS * ith) / nth;
+    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
+                      (HEADS * (ith + 1)) / nth : HEADS;
+
+    float * k = (float *) dst->src[0]->data;
+    float * v = (float *) dst->src[1]->data;
+    float * q = (float *) dst->src[2]->data;
+    float * g = (float *) dst->src[3]->data;
+
+    size_t t_stride = HEADS * head_size; // Same to C
+
+    size_t h_stride = C / HEADS;
+    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
+    size_t h_stride_2d = head_size * head_size;
+
+    if (ith == 0) {
+        memset(dst_data, 0, T * C * sizeof(float));
+    }
+    ggml_barrier(params->threadpool);
+
+
+    #if defined(__AVX__) && !defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x8
+        #define GGML_F32X_SET1 GGML_F32x8_SET1
+        #define GGML_F32X_LOAD GGML_F32x8_LOAD
+        #define GGML_F32X_STORE GGML_F32x8_STORE
+        #define GGML_F32X_MUL GGML_F32x8_MUL
+        #define GGML_F32X_FMA GGML_F32x8_FMA
+        #define GLA_VECTOR_SIZE 8
+    #elif defined(__AVX512F__)
+        #define GGML_F32X GGML_F32x16
+        #define GGML_F32X_SET1 GGML_F32x16_SET1
+        #define GGML_F32X_LOAD GGML_F32x16_LOAD
+        #define GGML_F32X_STORE GGML_F32x16_STORE
+        #define GGML_F32X_MUL GGML_F32x16_MUL
+        #define GGML_F32X_FMA GGML_F32x16_FMA
+        #define GLA_VECTOR_SIZE 16
+    #elif defined(__ARM_NEON) && defined(__aarch64__)
+        #define GGML_F32X GGML_F32x4
+        #define GGML_F32X_SET1 GGML_F32x4_SET1
+        #define GGML_F32X_LOAD GGML_F32x4_LOAD
+        #define GGML_F32X_STORE GGML_F32x4_STORE
+        #define GGML_F32X_MUL GGML_F32x4_MUL
+        #define GGML_F32X_FMA GGML_F32x4_FMA
+        #define GLA_VECTOR_SIZE 4
+    #endif
+
+    #ifdef GLA_VECTOR_SIZE
+        const int64_t vec_count = head_size / GLA_VECTOR_SIZE;
+
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    // Broadcast scalar values to vectors
+                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
+                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
+                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
+
+                    for (int64_t j = 0; j < vec_count; j++) {
+                        size_t base_j = j * GLA_VECTOR_SIZE;
+                        size_t t_h_j_offset = t_h_offset + base_j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
+
+                        // Load x elements at once
+                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
+                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
+                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
+
+                        // Compute kv = v * k
+                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
+
+                        // Compute temp = prev_state * g + kv
+                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
+
+                        // Update dst: dst += temp * q
+                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
+                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
+
+                        // Update state
+                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
+                    }
+
+                    // Handle remaining elements, this will not be used.
+                    for (int64_t j = vec_count * GLA_VECTOR_SIZE; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = kv_val + prev_state_val * g_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+
+    #else
+        for (int64_t t = 0; t < T; t++) {
+            size_t t_offset = t * t_stride;
+            size_t state_offset = head_size * C * (t / (T / n_seqs));
+            float * state_cur = state + state_offset;
+            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
+
+            for (int64_t h = h_start; h < h_end; h++) {
+                size_t h_offset = h * h_stride;
+                size_t t_h_offset = t_offset + h_offset;
+                size_t h_2d_offset = h * h_stride_2d;
+
+                for (int64_t i = 0; i < head_size; i++) {
+                    size_t t_h_i_offset = t_h_offset + i;
+                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                    float k_val = k[t_h_i_offset];
+                    float q_val = q[t_h_i_offset] * scale;
+                    float g_val = g[t_h_i_offset];
+
+                    for (int64_t j = 0; j < head_size; j++) {
+                        size_t t_h_j_offset = t_h_offset + j;
+                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                        float v_val = v[t_h_j_offset];
+                        float kv_val = v_val * k_val;
+                        float prev_state_val = state_prev[h_2d_i_j_offset];
+                        float temp_val = prev_state_val * g_val + kv_val;
+                        dst_data[t_h_j_offset] += temp_val * q_val;
+                        state_cur[h_2d_i_j_offset] = temp_val;
+                    }
+                }
+            }
+        }
+    #endif
+}
+
+
+static void ggml_compute_forward_gla(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_gla_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
@@ -12749,6 +12940,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rwkv_wkv6(params, tensor);
             } break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            {
+                ggml_compute_forward_gla(params, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 ggml_unary_op_f32_t fun;
@@ -13047,6 +13242,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_WIN_UNPART:
        case GGML_OP_GET_REL_POS:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
         case GGML_OP_MAP_CUSTOM1_F32:
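
For reference, below is a minimal single-sequence scalar sketch (not part of the patch) of the gated-linear-attention recurrence the new kernel computes: for each head, the per-head state matrix is decayed channel-wise by the gate g, updated with the k*v outer product, and the output accumulates the scale-weighted query readout. It mirrors the non-vectorized fallback path above; the function name gla_reference, the flat [T][HEADS][head_size] layout, and the in-place state handling are illustrative assumptions, not ggml API.

#include <stdint.h>
#include <string.h>

// Illustrative scalar reference. Assumed layout (contiguous, row-major):
//   k, v, q, g : [T][HEADS][head_size]
//   state      : [HEADS][head_size][head_size], holds the initial state on entry
//   out        : [T][HEADS][head_size]
static void gla_reference(int64_t T, int64_t HEADS, int64_t head_size, float scale,
                          const float * k, const float * v, const float * q, const float * g,
                          float * state, float * out) {
    memset(out, 0, (size_t)(T * HEADS * head_size) * sizeof(float));
    for (int64_t t = 0; t < T; t++) {
        for (int64_t h = 0; h < HEADS; h++) {
            for (int64_t i = 0; i < head_size; i++) {
                const int64_t ti = (t*HEADS + h)*head_size + i;   // channel i of head h at time t
                for (int64_t j = 0; j < head_size; j++) {
                    const int64_t tj = (t*HEADS + h)*head_size + j;
                    const int64_t s  = (h*head_size + i)*head_size + j;
                    // decay the previous state by the gate, add the k*v outer-product term
                    state[s] = state[s]*g[ti] + k[ti]*v[tj];
                    // accumulate the scaled query readout into the output
                    out[tj] += q[ti]*scale*state[s];
                }
            }
        }
    }
}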