1
1
#version 450
2
- #extension GL_EXT_shader_explicit_arithmetic_types : require
2
+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
3
3
4
4
#include "mul_mat_vec_base.comp"
5
5
@@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
40
40
41
41
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
42
42
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
43
- f16vec2 d = data_a[ib0 + i].d;
44
- const FLOAT_TYPE dall = d.x;
45
- const FLOAT_TYPE dmin = d.y;
43
+ vec2 d = vec2( data_a[ib0 + i].d) ;
44
+ const FLOAT_TYPE dall = FLOAT_TYPE( d.x) ;
45
+ const FLOAT_TYPE dmin = FLOAT_TYPE( d.y) ;
46
46
47
47
uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
48
48
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
@@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
63
63
uvec2 qs16 = uvec2(unpack8(qs16_u16));
64
64
65
65
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
66
- B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0] ;
67
- B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8] ;
68
- B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
69
- B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
70
- B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
71
- B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
72
- B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
73
- B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
66
+ vec2 b0 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]) ;
67
+ vec2 b16 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]) ;
68
+ vec2 b32 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]) ;
69
+ vec2 b48 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]) ;
70
+ vec2 b64 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]) ;
71
+ vec2 b80 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]) ;
72
+ vec2 b96 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]) ;
73
+ vec2 b112 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]) ;
74
74
75
75
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
76
76
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
0 commit comments