@@ -597,6 +597,55 @@ kernel void kernel_alibi_f32(
597
597
}
598
598
}
599
599
600
+ static float rope_ntkv2_ramp (const float low, const float high, const int i0) {
601
+ const float y = (i0 / 2 - low) / min (0 .001f , high - low);
602
+ return 1 .0f - min (1 .0f , max (0 .0f , y));
603
+ }
604
+
605
+ // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
606
+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
607
+ static float rope_ntkv2 (
608
+ const float theta_base,
609
+ const float theta_linear,
610
+ const float theta_ntk,
611
+ const float corr_factors[4 ],
612
+ const int64_t i0,
613
+ const float ntk_factor,
614
+ const float ext_factor) {
615
+ float ramp_mix;
616
+ float theta;
617
+
618
+ ramp_mix = rope_ntkv2_ramp (corr_factors[0 ], corr_factors[1 ], i0) * ntk_factor;
619
+ theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
620
+
621
+ ramp_mix = rope_ntkv2_ramp (corr_factors[2 ], corr_factors[3 ], i0) * ext_factor;
622
+ theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
623
+ return theta;
624
+ }
625
+
626
+ // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
627
+ // Do not change unless there is a good reason for doing so!
628
+ constant float BETA_0 = 1 .75f ;
629
+ constant float BETA_1 = 1 .25f ;
630
+ constant float GAMMA_0 = 16 .0f ;
631
+ constant float GAMMA_1 = 2 .0f ;
632
+
633
+ constant float max_pos_emb = 2048 ;
634
+
635
+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
636
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
637
+ static float rope_ntkv2_corr_factor (const int n_dims, const float n_rot, const float base) {
638
+ return n_dims * log (max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log (base));
639
+ }
640
+
641
+ static void rope_ntkv2_corr_factors (int n_dims, const float freq_base, float factors[4 ]) {
642
+ // start and end correction factors
643
+ factors[0 ] = max (0 .0f , floor (rope_ntkv2_corr_factor (n_dims, BETA_0, freq_base)));
644
+ factors[1 ] = min (n_dims - 1 .0f , ceil (rope_ntkv2_corr_factor (n_dims, BETA_1, freq_base)));
645
+ factors[2 ] = max (0 .0f , floor (rope_ntkv2_corr_factor (n_dims, GAMMA_0, freq_base)));
646
+ factors[3 ] = min (n_dims - 1 .0f , ceil (rope_ntkv2_corr_factor (n_dims, GAMMA_1, freq_base)));
647
+ }
648
+
600
649
kernel void kernel_rope (
601
650
device const void * src0,
602
651
device float * dst,
@@ -621,24 +670,33 @@ kernel void kernel_rope(
621
670
constant int & mode,
622
671
constant float & freq_base,
623
672
constant float & freq_scale,
673
+ constant float & ntk_factor,
674
+ constant float & ext_factor,
624
675
uint3 tpig[[thread_position_in_grid]]) {
625
676
const int64_t i3 = tpig[2 ];
626
677
const int64_t i2 = tpig[1 ];
627
678
const int64_t i1 = tpig[0 ];
628
679
629
- const bool is_neox = mode & 2 ;
630
680
const float theta_scale = pow (freq_base, -2 .0f /n_dims);
681
+ const float theta_ntk_scale = pow (freq_base * pow (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
682
+ float corr_factors[4 ];
683
+ rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors);
631
684
632
- const int64_t p = ((mode & 1 ) == 0 ? n_past + i2 : i2);
685
+ float theta_base = (mode & 1 ) == 0 ? n_past + i2 : i2;
686
+ float theta_ntk = theta_base;
633
687
634
- float theta = freq_scale * ( float )p ;
688
+ const bool is_neox = mode & 2 ;
635
689
636
690
if (!is_neox) {
637
691
for (int64_t i0 = 0 ; i0 < ne0; i0 += 2 ) {
692
+ const float theta_linear = freq_scale * theta_base;
693
+ const float theta = rope_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors,
694
+ i0, ntk_factor, ext_factor);
638
695
const float cos_theta = cos (theta);
639
696
const float sin_theta = sin (theta);
640
697
641
- theta *= theta_scale;
698
+ theta_base *= theta_scale;
699
+ theta_ntk *= theta_ntk_scale;
642
700
643
701
device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
644
702
device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -650,6 +708,7 @@ kernel void kernel_rope(
650
708
dst_data[1 ] = x0*sin_theta + x1*cos_theta;
651
709
}
652
710
} else {
711
+ theta_base *= freq_scale;
653
712
// TODO: implement
654
713
}
655
714
}
0 commit comments