diff --git a/Makefile b/Makefile index c20bd551fac..33335c21aa9 100644 --- a/Makefile +++ b/Makefile @@ -50,7 +50,11 @@ endif # TODO: probably these flags need to be tweaked on some architectures # feel free to update the Makefile for your architecture and send a pull request or issue ifeq ($(UNAME_M),x86_64) - CFLAGS += -mavx -mavx2 -mfma -mf16c + # AVX 512 + CFLAGS += -mavx512f -mfma -mf16c + + # AVX 256 + #CFLAGS += -mavx -mavx2 -mfma -mf16c endif ifeq ($(UNAME_M),amd64) CFLAGS += -mavx -mavx2 -mfma -mf16c diff --git a/ggml.c b/ggml.c index 79b910bbac6..5c213fc78de 100644 --- a/ggml.c +++ b/ggml.c @@ -327,6 +327,45 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float for (int i = n16; i < n; ++i) { sumf += x[i]*y[i]; } +#elif defined(__AVX512F__) + const int n64 = (n & ~63); + + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_loadu_ps(x + i + 0); + x1 = _mm512_loadu_ps(x + i + 16); + x2 = _mm512_loadu_ps(x + i + 32); + x3 = _mm512_loadu_ps(x + i + 48); + + y0 = _mm512_loadu_ps(y + i + 0); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); + + sum0 = _mm512_fmadd_ps(x0, y0, sum0); + sum1 = _mm512_fmadd_ps(x1, y1, sum1); + sum2 = _mm512_fmadd_ps(x2, y2, sum2); + sum3 = _mm512_fmadd_ps(x3, y3, sum3); + } + + sum0 = _mm512_add_ps(sum0, sum1); + sum2 = _mm512_add_ps(sum2, sum3); + sum0 = _mm512_add_ps(sum0, sum2); + + sumf = sum0[0] + sum0[1] + sum0[2] + sum0[3] + sum0[4] + sum0[5] + sum0[6] + sum0[7] + + sum0[8] + sum0[9] + sum0[10] + sum0[11] + sum0[12] + sum0[13] + sum0[14] + sum0[15]; + + // leftovers + for (int i = n64; i < n; ++i) { + sumf += x[i]*y[i]; + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -524,6 +563,47 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t for (int i = n32; i < n; ++i) { sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); } +#elif defined(__AVX512F__) + // AVX 512-bit + const int n64 = (n & ~63); + + __m512 sum0 = _mm512_setzero_ps(); + __m512 sum1 = _mm512_setzero_ps(); + __m512 sum2 = _mm512_setzero_ps(); + __m512 sum3 = _mm512_setzero_ps(); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 ))); + x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16))); + x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); + x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); + + y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 ))); + y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16))); + y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32))); + y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48))); + + sum0 = _mm512_fmadd_ps(x0, y0, sum0); + sum1 = _mm512_fmadd_ps(x1, y1, sum1); + sum2 = _mm512_fmadd_ps(x2, y2, sum2); + sum3 = _mm512_fmadd_ps(x3, y3, sum3); + } + + const __m512 sum01 = _mm512_add_ps(sum0, sum1); + const __m512 sum23 = _mm512_add_ps(sum2, sum3); + const __m512 sum0123 = _mm512_add_ps(sum01, sum23); + + sumf = sum0123[0] + sum0123[1] + sum0123[2] + sum0123[3] + sum0123[4] + sum0123[5] + sum0123[6] + sum0123[7] + + sum0123[8] + sum0123[9] + sum0123[10] + sum0123[11] + sum0123[12] + sum0123[13] + sum0123[14] + sum0123[15]; + + // leftovers + for (int i = n64; i < n; ++i) { + //GGML_ASSERT(false); + sumf += ggml_fp16_to_fp32(x[i])*ggml_fp16_to_fp32(y[i]); + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); @@ -630,7 +710,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float // NEON 128-bit const int n16 = (n & ~15); - const float32x4_t v4 = vdupq_n_f32(v); + const float32x4_t v0 = vdupq_n_f32(v); float32x4_t x0, x1, x2, x3; float32x4_t y0, y1, y2, y3; @@ -646,14 +726,14 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = vld1q_f32(y + i + 8); y3 = vld1q_f32(y + i + 12); - y0 = vfmaq_f32(y0, x0, v4); - y1 = vfmaq_f32(y1, x1, v4); - y2 = vfmaq_f32(y2, x2, v4); - y3 = vfmaq_f32(y3, x3, v4); + y0 = vfmaq_f32(y0, x0, v0); + y1 = vfmaq_f32(y1, x1, v0); + y2 = vfmaq_f32(y2, x2, v0); + y3 = vfmaq_f32(y3, x3, v0); - vst1q_f32(y + i + 0, y0); - vst1q_f32(y + i + 4, y1); - vst1q_f32(y + i + 8, y2); + vst1q_f32(y + i + 0, y0); + vst1q_f32(y + i + 4, y1); + vst1q_f32(y + i + 8, y2); vst1q_f32(y + i + 12, y3); } @@ -661,11 +741,46 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float for (int i = n16; i < n; ++i) { y[i] += x[i]*v; } +#elif defined(__AVX512F__) + // AVX512 512-bit + const int n64 = (n & ~63); + + const __m512 v0 = _mm512_set1_ps(v); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_loadu_ps(x + i + 0); + x1 = _mm512_loadu_ps(x + i + 16); + x2 = _mm512_loadu_ps(x + i + 32); + x3 = _mm512_loadu_ps(x + i + 48); + + y0 = _mm512_loadu_ps(y + i + 0); + y1 = _mm512_loadu_ps(y + i + 16); + y2 = _mm512_loadu_ps(y + i + 32); + y3 = _mm512_loadu_ps(y + i + 48); + + y0 = _mm512_fmadd_ps(x0, v0, y0); + y1 = _mm512_fmadd_ps(x1, v0, y1); + y2 = _mm512_fmadd_ps(x2, v0, y2); + y3 = _mm512_fmadd_ps(x3, v0, y3); + + _mm512_storeu_ps(y + i + 0, y0); + _mm512_storeu_ps(y + i + 16, y1); + _mm512_storeu_ps(y + i + 32, y2); + _mm512_storeu_ps(y + i + 48, y3); + } + + // leftovers + for (int i = n64; i < n; ++i) { + y[i] += x[i]*v; + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); - const __m256 v4 = _mm256_set1_ps(v); + const __m256 v0 = _mm256_set1_ps(v); __m256 x0, x1, x2, x3; __m256 y0, y1, y2, y3; @@ -681,13 +796,13 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = _mm256_loadu_ps(y + i + 16); y3 = _mm256_loadu_ps(y + i + 24); - y0 = _mm256_fmadd_ps(x0, v4, y0); - y1 = _mm256_fmadd_ps(x1, v4, y1); - y2 = _mm256_fmadd_ps(x2, v4, y2); - y3 = _mm256_fmadd_ps(x3, v4, y3); + y0 = _mm256_fmadd_ps(x0, v0, y0); + y1 = _mm256_fmadd_ps(x1, v0, y1); + y2 = _mm256_fmadd_ps(x2, v0, y2); + y3 = _mm256_fmadd_ps(x3, v0, y3); - _mm256_storeu_ps(y + i + 0, y0); - _mm256_storeu_ps(y + i + 8, y1); + _mm256_storeu_ps(y + i + 0, y0); + _mm256_storeu_ps(y + i + 8, y1); _mm256_storeu_ps(y + i + 16, y2); _mm256_storeu_ps(y + i + 24, y3); } @@ -700,7 +815,7 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float // WASM SIMD 128-bit const int n16 = (n & ~15); - const v128_t v4 = wasm_f32x4_splat(v); + const v128_t v0 = wasm_f32x4_splat(v); v128_t x0, x1, x2, x3; v128_t y0, y1, y2, y3; @@ -716,10 +831,10 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float y2 = wasm_v128_load(y + i + 8); y3 = wasm_v128_load(y + i + 12); - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v0)); + y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v0)); + y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v0)); + y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v0)); wasm_v128_store(y + i + 0, y0); wasm_v128_store(y + i + 4, y1); @@ -745,7 +860,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ const int n32 = (n & ~31); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - const float16x8_t v8 = vdupq_n_f16(v); + const float16x8_t v0 = vdupq_n_f16(v); float16x8_t x0, x1, x2, x3; float16x8_t y0, y1, y2, y3; @@ -761,10 +876,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ x2 = vld1q_f16(x + i + 16); x3 = vld1q_f16(x + i + 24); - y0 = vfmaq_f16(y0, x0, v8); - y1 = vfmaq_f16(y1, x1, v8); - y2 = vfmaq_f16(y2, x2, v8); - y3 = vfmaq_f16(y3, x3, v8); + y0 = vfmaq_f16(y0, x0, v0); + y1 = vfmaq_f16(y1, x1, v0); + y2 = vfmaq_f16(y2, x2, v0); + y3 = vfmaq_f16(y3, x3, v0); vst1q_f16(y + i + 0 , y0); vst1q_f16(y + i + 8 , y1); @@ -772,8 +887,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ vst1q_f16(y + i + 24, y3); } #else - const float32x4_t v40 = vdupq_n_f32(v); - const float32x4_t v41 = vdupq_n_f32(v); + const float32x4_t v0 = vdupq_n_f32(v); float32x4_t x0, x1, x2, x3, x4, x5, x6, x7; float32x4_t y0, y1, y2, y3, y4, y5, y6, y7; @@ -797,14 +911,14 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ x6 = vcvt_f32_f16(vld1_f16(x + i + 24)); x7 = vcvt_f32_f16(vld1_f16(x + i + 28)); - y0 = vfmaq_f32(y0, x0, v40); - y1 = vfmaq_f32(y1, x1, v40); - y2 = vfmaq_f32(y2, x2, v40); - y3 = vfmaq_f32(y3, x3, v40); - y4 = vfmaq_f32(y4, x4, v41); - y5 = vfmaq_f32(y5, x5, v41); - y6 = vfmaq_f32(y6, x6, v41); - y7 = vfmaq_f32(y7, x7, v41); + y0 = vfmaq_f32(y0, x0, v0); + y1 = vfmaq_f32(y1, x1, v0); + y2 = vfmaq_f32(y2, x2, v0); + y3 = vfmaq_f32(y3, x3, v0); + y4 = vfmaq_f32(y4, x4, v0); + y5 = vfmaq_f32(y5, x5, v0); + y6 = vfmaq_f32(y6, x6, v0); + y7 = vfmaq_f32(y7, x7, v0); vst1_f16(y + i + 0 , vcvt_f16_f32(y0)); vst1_f16(y + i + 4 , vcvt_f16_f32(y1)); @@ -822,11 +936,47 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ GGML_ASSERT(false); y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); } +#elif defined(__AVX512F__) + // AVX 512-bit + const int n64 = (n & ~63); + + const __m512 v0 = _mm512_set1_ps(v); + + __m512 x0, x1, x2, x3; + __m512 y0, y1, y2, y3; + + for (int i = 0; i < n64; i += 64) { + x0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 0 ))); + x1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 16))); + x2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 32))); + x3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + i + 48))); + + y0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 0 ))); + y1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 16))); + y2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 32))); + y3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + i + 48))); + + y0 = _mm512_fmadd_ps(x0, v0, y0); + y1 = _mm512_fmadd_ps(x1, v0, y1); + y2 = _mm512_fmadd_ps(x2, v0, y2); + y3 = _mm512_fmadd_ps(x3, v0, y3); + + _mm256_storeu_si256((__m256i*)(y + i + 0 ), _mm512_cvtps_ph(y0, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 16), _mm512_cvtps_ph(y1, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 32), _mm512_cvtps_ph(y2, 0)); + _mm256_storeu_si256((__m256i*)(y + i + 48), _mm512_cvtps_ph(y3, 0)); + } + + // leftovers + for (int i = n64; i < n; ++i) { + GGML_ASSERT(false); + y[i] = ggml_fp32_to_fp16(ggml_fp16_to_fp32(y[i]) + ggml_fp16_to_fp32(x[i])*v); + } #elif defined(__AVX2__) // AVX 256-bit const int n32 = (n & ~31); - const __m256 v8 = _mm256_set1_ps(v); + const __m256 v0 = _mm256_set1_ps(v); __m256 x0, x1, x2, x3; __m256 y0, y1, y2, y3; @@ -842,10 +992,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ x2 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 16))); x3 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(x + i + 24))); - y0 = _mm256_fmadd_ps(x0, v8, y0); - y1 = _mm256_fmadd_ps(x1, v8, y1); - y2 = _mm256_fmadd_ps(x2, v8, y2); - y3 = _mm256_fmadd_ps(x3, v8, y3); + y0 = _mm256_fmadd_ps(x0, v0, y0); + y1 = _mm256_fmadd_ps(x1, v0, y1); + y2 = _mm256_fmadd_ps(x2, v0, y2); + y3 = _mm256_fmadd_ps(x3, v0, y3); _mm_storeu_si128((__m128i*)(y + i + 0 ), _mm256_cvtps_ph(y0, 0)); _mm_storeu_si128((__m128i*)(y + i + 8 ), _mm256_cvtps_ph(y1, 0)); @@ -862,7 +1012,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ // WASM SIMD 128-bit const int n16 = (n & ~15); - const v128_t v4 = wasm_f32x4_splat(v); + const v128_t v0 = wasm_f32x4_splat(v); v128_t x0, x1, x2, x3; v128_t y0, y1, y2, y3; @@ -886,10 +1036,10 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, ggml_ y2 = wasm_v128_load(ty + 8); y3 = wasm_v128_load(ty + 12); - y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v4)); - y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v4)); - y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v4)); - y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v4)); + y0 = wasm_f32x4_add(y0, wasm_f32x4_mul(x0, v0)); + y1 = wasm_f32x4_add(y1, wasm_f32x4_mul(x1, v0)); + y2 = wasm_f32x4_add(y2, wasm_f32x4_mul(x2, v0)); + y3 = wasm_f32x4_add(y3, wasm_f32x4_mul(x3, v0)); wasm_v128_store(ty + 0, y0); wasm_v128_store(ty + 4, y1);