Commit 08eb991

metal : add cpy f16 -> f32 kernel
1 parent a742d9f commit 08eb991

4 files changed: +84 -15 lines changed

convert.py
Lines changed: 5 additions & 5 deletions

@@ -63,10 +63,10 @@ class UnquantizedDataType(DataType):
     pass
 
 
-DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
-DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
-DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
-DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
+DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
+DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
+DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
+DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
 
 
 @dataclass(frozen=True)
@@ -996,7 +996,7 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM
 
 
 def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
-    wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
+    wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
 
     if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
         return GGMLFileType.AllF32

ggml-metal.m
Lines changed: 32 additions & 4 deletions

@@ -155,6 +155,7 @@
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
     GGML_METAL_DECL_KERNEL(concat);
     GGML_METAL_DECL_KERNEL(sqr);
     GGML_METAL_DECL_KERNEL(sum_rows);
@@ -424,6 +425,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
         //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
         GGML_METAL_ADD_KERNEL(concat);
         GGML_METAL_ADD_KERNEL(sqr);
         GGML_METAL_ADD_KERNEL(sum_rows);
@@ -539,6 +541,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
     //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
     GGML_METAL_DEL_KERNEL(concat);
     GGML_METAL_DEL_KERNEL(sqr);
     GGML_METAL_DEL_KERNEL(sum_rows);
@@ -867,12 +870,37 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
         case GGML_OP_ARGSORT:
-        case GGML_OP_DUP:
-        case GGML_OP_CPY:
-        case GGML_OP_CONT:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return true;
+        case GGML_OP_CPY:
+        case GGML_OP_DUP:
+        case GGML_OP_CONT:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                        switch (op->type) {
+                            case GGML_TYPE_F16:
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_Q8_0:
+                            case GGML_TYPE_Q4_0:
+                            case GGML_TYPE_Q4_1:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    case GGML_TYPE_F16:
+                        switch (op->type) {
+                            case GGML_TYPE_F16:
+                            case GGML_TYPE_F32:
+                                return true;
+                            default:
+                                return false;
+                        }
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
             {
                 return op->ne[0] % 4 == 0;
@@ -2021,7 +2049,7 @@ void ggml_metal_graph_compute(
                     {
                         switch (dstt) {
                             case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                            case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
                            default: GGML_ASSERT(false && "not implemented");
                        };
                    } break;
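For context (not part of the commit), a minimal host-side sketch of the kind of graph node the new path serves: a GGML_OP_CPY node whose source is F16 and whose destination is F32, which ggml_metal_supports_op now accepts and ggml_metal_graph_compute dispatches to pipeline_cpy_f16_f32. The tensor shapes and the context size below are arbitrary illustration values.

    #include "ggml.h"

    int main(void) {
        // illustrative sizes only
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // source in half precision, destination in single precision
        struct ggml_tensor * src = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 64, 4);
        struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 4);

        // GGML_OP_CPY with srct == GGML_TYPE_F16 and dstt == GGML_TYPE_F32:
        // previously the Metal backend asserted "cpy_f16_f32 not implemented",
        // with this commit it is dispatched to kernel_cpy_f16_f32
        struct ggml_tensor * cpy = ggml_cpy(ctx, src, dst);

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, cpy);

        // ... hand gf to the Metal backend (or any other backend) to execute

        ggml_free(ctx);
        return 0;
    }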

ggml-metal.metal
Lines changed: 43 additions & 2 deletions

@@ -1698,8 +1698,8 @@ template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_ar
 template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
 
 kernel void kernel_cpy_f16_f16(
-        device const half * src0,
-        device half * dst,
+        device const half * src0,
+        device half * dst,
         constant int64_t & ne00,
         constant int64_t & ne01,
         constant int64_t & ne02,
@@ -1738,6 +1738,47 @@ kernel void kernel_cpy_f16_f16(
     }
 }
 
+kernel void kernel_cpy_f16_f32(
+        device const half * src0,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01,
+        constant int64_t & ne02,
+        constant int64_t & ne03,
+        constant uint64_t & nb00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb02,
+        constant uint64_t & nb03,
+        constant int64_t & ne0,
+        constant int64_t & ne1,
+        constant int64_t & ne2,
+        constant int64_t & ne3,
+        constant uint64_t & nb0,
+        constant uint64_t & nb1,
+        constant uint64_t & nb2,
+        constant uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = src[0];
+    }
+}
+
 kernel void kernel_cpy_f32_f16(
         device const float * src0,
         device half * dst,
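In brief, the new kernel follows the same pattern as the other cpy kernels in this file: each threadgroup handles one source row (i01, i02, i03), n is the flattened index of that row's first element, and the divisions and remainders recover the matching destination coordinates (i0, i1, i2, i3) from the destination shape, so the copy also works when source and destination use different layouts. A small worked example with illustrative destination sizes ne0=4, ne1=3, ne2=2, ne3=1 and n=17:

    i3 = 17 / (2*3*4)     = 0
    i2 = (17 - 0) / (3*4) = 1
    i1 = (17 - 12) / 4    = 1
    i0 =  17 - 12 - 4     = 1

so flattened element 17 lands at destination coordinates (i0, i1, i2, i3) = (1, 1, 1, 0); check: 0*24 + 1*12 + 1*4 + 1 = 17. The per-element addressing then uses the byte strides nb0..nb3 and nb00..nb03, and the threads of the group stride over the ne00 elements of the row.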

llama.cpp
Lines changed: 4 additions & 4 deletions

@@ -4277,23 +4277,23 @@ struct llm_build_context {
                 ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
                 cb(logits, "ffn_moe_logits", il);
 
-                ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+                ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
                 cb(probs, "ffn_moe_probs", il);
 
                 // select experts
-                ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+                ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
                 cb(selected_experts->src[0], "ffn_moe_argsort", il);
 
                 ggml_tensor * weights = ggml_get_rows(ctx0,
                         ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
                 cb(weights, "ffn_moe_weights", il);
 
-                weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+                weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
 
                 ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
                 cb(weights_sum, "ffn_moe_weights_sum", il);
 
-                weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+                weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
                 cb(weights, "ffn_moe_weights_norm", il);
 
                 // compute expert outputs
