@@ -580,6 +580,32 @@ kernel void kernel_alibi_f32(
580
580
}
581
581
}
582
582
583
// Linear ramp used to blend between two rotary-embedding angle variants.
//
// Returns 1.0 while i0/2 is at or below `low`, 0.0 once i0/2 is at or above
// `high`, and decreases linearly in between.
//
//   low, high - ramp boundaries, expressed in the same units as i0/2
//   i0        - element index; halved with integer division before use
static float rope_ntkv2_ramp(const float low, const float high, const int i0) {
    // Floor the ramp width at 0.001 so that low == high cannot divide by zero.
    // This has to be max(): min() would cap the width at 0.001 and collapse the
    // ramp into a step for any range wider than that.
    const float width = max(0.001f, high - low);
    const float y     = (i0 / 2 - low) / width;

    // Clamp to [0, 1] and invert, so the result is 1 below the ramp and 0 above it.
    return 1.0f - min(1.0f, max(0.0f, y));
}
587
+
588
// NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Blends three candidate rotation angles into one, in two stages:
//   1. theta_linear and theta_ntk are mixed using the ramp defined by
//      corr_factors[0..1], weighted by ntk_factor.
//   2. the stage-1 result and theta_base are mixed using the ramp defined by
//      corr_factors[2..3], weighted by extrapolation_factor.
//
//   theta_base           - angle that receives the stage-2 ramp weight
//   theta_linear         - angle that receives the complement of the stage-1 weight
//   theta_ntk            - angle that receives the stage-1 ramp weight
//   corr_factors         - {low, high} of the stage-1 ramp, then {low, high} of the stage-2 ramp
//   i0                   - element index forwarded to rope_ntkv2_ramp (converted to int there)
//   ntk_factor           - multiplier applied to the stage-1 ramp value
//   extrapolation_factor - multiplier applied to the stage-2 ramp value
//
// Returns the blended angle.
static float rope_ntkv2(
        const float theta_base,
        const float theta_linear,
        const float theta_ntk,
        device const float corr_factors[4],
        const int64_t i0,
        const float ntk_factor,
        const float extrapolation_factor) {
    // Stage 1: linear <-> NTK.
    const float ntk_weight   = rope_ntkv2_ramp(corr_factors[0], corr_factors[1], i0) * ntk_factor;
    const float theta_stage1 = theta_linear * (1 - ntk_weight) + theta_ntk * ntk_weight;

    // Stage 2: stage-1 result <-> base.
    const float base_weight = rope_ntkv2_ramp(corr_factors[2], corr_factors[3], i0) * extrapolation_factor;
    return theta_stage1 * (1 - base_weight) + theta_base * base_weight;
}
608
+
583
609
kernel void kernel_rope (
584
610
device const void * src0,
585
611
device float * dst,
@@ -604,24 +630,33 @@ kernel void kernel_rope(
604
630
constant int & mode,
605
631
constant float & freq_base,
606
632
constant float & freq_scale,
633
+ constant float & ntk_factor,
634
+ constant float & extrapolation_factor,
607
635
uint3 tpig[[thread_position_in_grid]]) {
608
636
const int64_t i3 = tpig[2 ];
609
637
const int64_t i2 = tpig[1 ];
610
638
const int64_t i1 = tpig[0 ];
611
639
612
- const bool is_neox = mode & 2 ;
613
- const float theta_scale = pow (freq_base, -2 .0f /n_dims);
640
+ const float theta_scale = powf (freq_base, -2 .0f /n_dims);
641
+ const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
642
+ device float corr_factors[4 ];
643
+ ggml_rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors);
614
644
615
- const int64_t p = ((mode & 1 ) == 0 ? n_past + i2 : i2);
645
+ float theta_base = (mode & 1 ) == 0 ? n_past + i2 : i2;
646
+ float theta_ntk = theta_base;
616
647
617
- float theta = freq_scale * ( float )p ;
648
+ const bool is_neox = mode & 2 ;
618
649
619
650
if (!is_neox) {
620
651
for (int64_t i0 = 0 ; i0 < ne0; i0 += 2 ) {
621
- const float cos_theta = cos (theta);
622
- const float sin_theta = sin (theta);
652
+ const float theta_linear = freq_scale * theta_base;
653
+ const float theta = rope_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors,
654
+ i0, ntk_factor, extrapolation_factor);
655
+ const float cos_theta = cosf (theta);
656
+ const float sin_theta = sinf (theta);
623
657
624
- theta *= theta_scale;
658
+ theta_base *= theta_scale;
659
+ theta_ntk *= theta_ntk_scale;
625
660
626
661
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
627
662
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -633,6 +668,7 @@ kernel void kernel_rope(
633
668
dst_data[1 ] = x0*sin_theta + x1*cos_theta;
634
669
}
635
670
} else {
671
+ theta_base *= freq_scale;
636
672
// TODO: implement
637
673
}
638
674
}
0 commit comments