From 2e0033e9301a144c98aa610fcf442b60ac170405 Mon Sep 17 00:00:00 2001 From: Brian Lui Date: Thu, 6 Sep 2018 17:08:00 -0700 Subject: [PATCH 1/3] Removed out-of-bound pointer access for AddScalarU SSE and AVX intrinsics --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 20 ++++++++++++-------- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 14 ++++++++------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index b31a427139..0abb157801 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -20,6 +20,10 @@ internal static class AvxIntrinsics { private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); + // The count of 32-bit floats in Vector256 + private const int AvxAlignment = 8; + + // The count of bytes in Vector256, corresponding to _cbAlign in AlignedArray private const int Vector256Alignment = 32; [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] @@ -415,32 +419,32 @@ public static unsafe void AddScalarU(float scalar, Span dst) { fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; - float* pDstCurrent = pdst; - Vector256 scalarVector256 = Avx.SetAllVector256(scalar); + int countAvx = Math.DivRem(dst.Length, AvxAlignment, out int remainderAvx); + float* pDstCurrent = pdst; - while (pDstCurrent + 8 <= pDstEnd) + for (int i = 0; i < countAvx; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, scalarVector256); Avx.Store(pDstCurrent, dstVector); - pDstCurrent += 8; + pDstCurrent += AvxAlignment; } Vector128 scalarVector128 = Sse.SetAllVector128(scalar); + int countSse = Math.DivRem(remainderAvx, SseIntrinsics.SseAlignment, out int remainderSse); - if (pDstCurrent + 4 <= pDstEnd) + if (countSse > 0) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, scalarVector128); Sse.Store(pDstCurrent, dstVector); - pDstCurrent += 4; + pDstCurrent += SseIntrinsics.SseAlignment; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainderSse; i++) { Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); dstVector = Sse.AddScalar(dstVector, scalarVector128); diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 0f4fb54d18..08bc1f246a 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -26,6 +26,9 @@ internal static class SseIntrinsics Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); + // The count of 32-bit floats in Vector128 + internal const int SseAlignment = 4; + // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray private const int Vector128Alignment = 16; @@ -412,21 +415,20 @@ public static unsafe void AddScalarU(float scalar, Span dst) { fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; - float* pDstCurrent = pdst; - Vector128 scalarVector = Sse.SetAllVector128(scalar); + int count = Math.DivRem(dst.Length, SseAlignment, out int remainder); + float* pDstCurrent = pdst; - while (pDstCurrent + 4 <= pDstEnd) + for (int i = 0; i < count; i++) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, scalarVector); Sse.Store(pDstCurrent, dstVector); - pDstCurrent += 4; + pDstCurrent += SseAlignment; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder; i++) { Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); dstVector = Sse.AddScalar(dstVector, scalarVector); From a9558f433a330d16741b867ab49642850dc70adb Mon Sep 17 00:00:00 2001 From: Brian Lui Date: Fri, 7 Sep 2018 12:28:25 -0700 Subject: [PATCH 2/3] Respond to PR feedback: remove 2nd Math.DivRem and use indexing for scalar operations --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 21 ++++++++------------- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 12 ++++-------- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 0abb157801..e6cb94afc2 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -21,7 +21,7 @@ internal static class AvxIntrinsics private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); // The count of 32-bit floats in Vector256 - private const int AvxAlignment = 8; + private const int Vector256SingleElementCount = 8; // The count of bytes in Vector256, corresponding to _cbAlign in AlignedArray private const int Vector256Alignment = 32; @@ -420,37 +420,32 @@ public static unsafe void AddScalarU(float scalar, Span dst) fixed (float* pdst = dst) { Vector256 scalarVector256 = Avx.SetAllVector256(scalar); - int countAvx = Math.DivRem(dst.Length, AvxAlignment, out int remainderAvx); + int count = Math.DivRem(dst.Length, Vector256SingleElementCount, out int remainder); float* pDstCurrent = pdst; - for (int i = 0; i < countAvx; i++) + for (int i = 0; i < count; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, scalarVector256); Avx.Store(pDstCurrent, dstVector); - pDstCurrent += AvxAlignment; + pDstCurrent += Vector256SingleElementCount; } Vector128 scalarVector128 = Sse.SetAllVector128(scalar); - int countSse = Math.DivRem(remainderAvx, SseIntrinsics.SseAlignment, out int remainderSse); - if (countSse > 0) + if (remainder >= 4) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, scalarVector128); Sse.Store(pDstCurrent, dstVector); - pDstCurrent += SseIntrinsics.SseAlignment; + pDstCurrent += SseIntrinsics.Vector128SingleElementCount; } - for (int i = 0; i < remainderSse; i++) + for (int i = 0; i < remainder - 4; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - dstVector = Sse.AddScalar(dstVector, scalarVector128); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] += scalar; } } } diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 08bc1f246a..263779e478 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -27,7 +27,7 @@ internal static class SseIntrinsics Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); // The count of 32-bit floats in Vector128 - internal const int SseAlignment = 4; + internal const int Vector128SingleElementCount = 4; // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray private const int Vector128Alignment = 16; @@ -416,7 +416,7 @@ public static unsafe void AddScalarU(float scalar, Span dst) fixed (float* pdst = dst) { Vector128 scalarVector = Sse.SetAllVector128(scalar); - int count = Math.DivRem(dst.Length, SseAlignment, out int remainder); + int count = Math.DivRem(dst.Length, Vector128SingleElementCount, out int remainder); float* pDstCurrent = pdst; for (int i = 0; i < count; i++) @@ -425,16 +425,12 @@ public static unsafe void AddScalarU(float scalar, Span dst) dstVector = Sse.Add(dstVector, scalarVector); Sse.Store(pDstCurrent, dstVector); - pDstCurrent += SseAlignment; + pDstCurrent += Vector128SingleElementCount; } for (int i = 0; i < remainder; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - dstVector = Sse.AddScalar(dstVector, scalarVector); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] += scalar; } } } From f26f11277d33769cb6e917c2f8e9b052513619d7 Mon Sep 17 00:00:00 2001 From: Brian Lui Date: Fri, 7 Sep 2018 16:51:47 -0700 Subject: [PATCH 3/3] Respond to PR feedback: Replaced scalar operations by indexed code for all AVX intrinsics, except MatMul's and those involving AbsMask --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 400 +++++++++------------- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 7 +- 2 files changed, 165 insertions(+), 242 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index e6cb94afc2..0294804b29 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -20,9 +20,6 @@ internal static class AvxIntrinsics { private static readonly Vector256 _absMask256 = Avx.StaticCast(Avx.SetAllVector256(0x7FFFFFFF)); - // The count of 32-bit floats in Vector256 - private const int Vector256SingleElementCount = 8; - // The count of bytes in Vector256, corresponding to _cbAlign in AlignedArray private const int Vector256Alignment = 32; @@ -420,7 +417,8 @@ public static unsafe void AddScalarU(float scalar, Span dst) fixed (float* pdst = dst) { Vector256 scalarVector256 = Avx.SetAllVector256(scalar); - int count = Math.DivRem(dst.Length, Vector256SingleElementCount, out int remainder); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pDstCurrent = pdst; for (int i = 0; i < count; i++) @@ -429,7 +427,7 @@ public static unsafe void AddScalarU(float scalar, Span dst) dstVector = Avx.Add(dstVector, scalarVector256); Avx.Store(pDstCurrent, dstVector); - pDstCurrent += Vector256SingleElementCount; + pDstCurrent += 8; } Vector128 scalarVector128 = Sse.SetAllVector128(scalar); @@ -440,10 +438,10 @@ public static unsafe void AddScalarU(float scalar, Span dst) dstVector = Sse.Add(dstVector, scalarVector128); Sse.Store(pDstCurrent, dstVector); - pDstCurrent += SseIntrinsics.Vector128SingleElementCount; + pDstCurrent += 4; } - for (int i = 0; i < remainder - 4; i++) + for (int i = 0; i < remainder % 4; i++) { pDstCurrent[i] += scalar; } @@ -454,12 +452,12 @@ public static unsafe void ScaleU(float scale, Span dst) { fixed (float* pdst = dst) { - float* pDstCurrent = pdst; - float* pEnd = pdst + dst.Length; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - while (pDstCurrent + 8 <= pEnd) + int count = Math.DivRem(dst.Length, 8, out int remainder); + float* pDstCurrent = pdst; + + for (int i = 0; i < count; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -471,7 +469,7 @@ public static unsafe void ScaleU(float scale, Span dst) Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pDstCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -481,14 +479,9 @@ public static unsafe void ScaleU(float scale, Span dst) pDstCurrent += 4; } - while (pDstCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - dstVector = Sse.MultiplyScalar(scaleVector128, dstVector); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] *= scale; } } } @@ -498,13 +491,13 @@ public static unsafe void ScaleSrcU(float scale, Span src, Span ds fixed (float* psrc = src) fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; + Vector256 scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - - while (pDstCurrent + 8 <= pDstEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Multiply(srcVector, scaleVector256); @@ -516,7 +509,7 @@ public static unsafe void ScaleSrcU(float scale, Span src, Span ds Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pDstCurrent + 4 <= pDstEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Multiply(srcVector, scaleVector128); @@ -526,14 +519,9 @@ public static unsafe void ScaleSrcU(float scale, Span src, Span ds pDstCurrent += 4; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - srcVector = Sse.MultiplyScalar(srcVector, scaleVector128); - Sse.StoreScalar(pDstCurrent, srcVector); - - pSrcCurrent++; - pDstCurrent++; + pDstCurrent[i] = pSrcCurrent[i] * scale; } } } @@ -543,13 +531,13 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) { fixed (float* pdst = dst) { - float* pDstEnd = pdst + dst.Length; - float* pDstCurrent = pdst; - Vector256 a256 = Avx.SetAllVector256(a); Vector256 b256 = Avx.SetAllVector256(b); - while (pDstCurrent + 8 <= pDstEnd) + int count = Math.DivRem(dst.Length, 8, out int remainder); + float* pDstCurrent = pdst; + + for (int i = 0; i < count; i++) { Vector256 dstVector = Avx.LoadVector256(pDstCurrent); dstVector = Avx.Add(dstVector, b256); @@ -562,7 +550,7 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) Vector128 a128 = Sse.SetAllVector128(a); Vector128 b128 = Sse.SetAllVector128(b); - if (pDstCurrent + 4 <= pDstEnd) + if (remainder >= 4) { Vector128 dstVector = Sse.LoadVector128(pDstCurrent); dstVector = Sse.Add(dstVector, b128); @@ -572,14 +560,9 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) pDstCurrent += 4; } - while (pDstCurrent < pDstEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - dstVector = Sse.AddScalar(dstVector, b128); - dstVector = Sse.MultiplyScalar(dstVector, a128); - Sse.StoreScalar(pDstCurrent, dstVector); - - pDstCurrent++; + pDstCurrent[i] = a * (pDstCurrent[i] + b); } } } @@ -589,13 +572,13 @@ public static unsafe void AddScaleU(float scale, Span src, Span ds fixed (float* psrc = src) fixed (float* pdst = dst) { + Vector256 scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pEnd = pdst + dst.Length; - - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - while (pDstCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -610,7 +593,7 @@ public static unsafe void AddScaleU(float scale, Span src, Span ds Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pDstCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -623,17 +606,9 @@ public static unsafe void AddScaleU(float scale, Span src, Span ds pDstCurrent += 4; } - while (pDstCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - srcVector = Sse.MultiplyScalar(srcVector, scaleVector128); - dstVector = Sse.AddScalar(dstVector, srcVector); - Sse.StoreScalar(pDstCurrent, dstVector); - - pSrcCurrent++; - pDstCurrent++; + pDstCurrent[i] += scale * pSrcCurrent[i]; } } } @@ -644,14 +619,14 @@ public static unsafe void AddScaleCopyU(float scale, Span src, Span scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; float* pResCurrent = pres; - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - - while (pResCurrent + 8 <= pResEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -666,7 +641,7 @@ public static unsafe void AddScaleCopyU(float scale, Span src, Span scaleVector128 = Sse.SetAllVector128(scale); - if (pResCurrent + 4 <= pResEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -679,17 +654,9 @@ public static unsafe void AddScaleCopyU(float scale, Span src, Span srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - srcVector = Sse.MultiplyScalar(srcVector, scaleVector128); - dstVector = Sse.AddScalar(dstVector, srcVector); - Sse.StoreScalar(pResCurrent, dstVector); - - pSrcCurrent++; - pDstCurrent++; - pResCurrent++; + pResCurrent[i] = pDstCurrent[i] + scale * pSrcCurrent[i]; } } } @@ -700,14 +667,14 @@ public static unsafe void AddScaleSU(float scale, Span src, Span idx fixed (int* pidx = idx) fixed (float* pdst = dst) { + Vector256 scaleVector256 = Avx.SetAllVector256(scale); + + int count = Math.DivRem(idx.Length, 8, out int remainder); float* pSrcCurrent = psrc; int* pIdxCurrent = pidx; float* pDstCurrent = pdst; - int* pEnd = pidx + idx.Length; - - Vector256 scaleVector256 = Avx.SetAllVector256(scale); - while (pIdxCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); @@ -722,7 +689,7 @@ public static unsafe void AddScaleSU(float scale, Span src, Span idx Vector128 scaleVector128 = Sse.SetAllVector128(scale); - if (pIdxCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = SseIntrinsics.Load4(pDstCurrent, pIdxCurrent); @@ -735,12 +702,10 @@ public static unsafe void AddScaleSU(float scale, Span src, Span idx pSrcCurrent += 4; } - while (pIdxCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - pDstCurrent[*pIdxCurrent] += scale * (*pSrcCurrent); - - pIdxCurrent++; - pSrcCurrent++; + int index = pIdxCurrent[i]; + pDstCurrent[index] += scale * pSrcCurrent[i]; } } } @@ -750,11 +715,11 @@ public static unsafe void AddU(Span src, Span dst) fixed (float* psrc = src) fixed (float* pdst = dst) { + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pEnd = psrc + src.Length; - while (pSrcCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -766,7 +731,7 @@ public static unsafe void AddU(Span src, Span dst) pDstCurrent += 8; } - if (pSrcCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -778,16 +743,9 @@ public static unsafe void AddU(Span src, Span dst) pDstCurrent += 4; } - while (pSrcCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - Vector128 result = Sse.AddScalar(srcVector, dstVector); - Sse.StoreScalar(pDstCurrent, result); - - pSrcCurrent++; - pDstCurrent++; + pDstCurrent[i] += pSrcCurrent[i]; } } } @@ -798,12 +756,12 @@ public static unsafe void AddSU(Span src, Span idx, Span dst) fixed (int* pidx = idx) fixed (float* pdst = dst) { + int count = Math.DivRem(idx.Length, 8, out int remainder); float* pSrcCurrent = psrc; int* pIdxCurrent = pidx; float* pDstCurrent = pdst; - int* pEnd = pidx + idx.Length; - while (pIdxCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 dstVector = Load8(pDstCurrent, pIdxCurrent); Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); @@ -815,7 +773,7 @@ public static unsafe void AddSU(Span src, Span idx, Span dst) pSrcCurrent += 8; } - if (pIdxCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 dstVector = SseIntrinsics.Load4(pDstCurrent, pIdxCurrent); Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); @@ -827,12 +785,10 @@ public static unsafe void AddSU(Span src, Span idx, Span dst) pSrcCurrent += 4; } - while (pIdxCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - pDstCurrent[*pIdxCurrent] += *pSrcCurrent; - - pIdxCurrent++; - pSrcCurrent++; + int index = pIdxCurrent[i]; + pDstCurrent[index] += pSrcCurrent[i]; } } } @@ -843,12 +799,12 @@ public static unsafe void MulElementWiseU(Span src1, Span src2, Sp fixed (float* psrc2 = src2) fixed (float* pdst = dst) { + int count = Math.DivRem(dst.Length, 8, out int remainder); float* pSrc1Current = psrc1; float* pSrc2Current = psrc2; float* pDstCurrent = pdst; - float* pEnd = pdst + dst.Length; - while (pDstCurrent + 8 <= pEnd) + for (int i = 0; i < count; i++) { Vector256 src1Vector = Avx.LoadVector256(pSrc1Current); Vector256 src2Vector = Avx.LoadVector256(pSrc2Current); @@ -860,7 +816,7 @@ public static unsafe void MulElementWiseU(Span src1, Span src2, Sp pDstCurrent += 8; } - if (pDstCurrent + 4 <= pEnd) + if (remainder >= 4) { Vector128 src1Vector = Sse.LoadVector128(pSrc1Current); Vector128 src2Vector = Sse.LoadVector128(pSrc2Current); @@ -872,16 +828,9 @@ public static unsafe void MulElementWiseU(Span src1, Span src2, Sp pDstCurrent += 4; } - while (pDstCurrent < pEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 src1Vector = Sse.LoadScalarVector128(pSrc1Current); - Vector128 src2Vector = Sse.LoadScalarVector128(pSrc2Current); - src2Vector = Sse.MultiplyScalar(src1Vector, src2Vector); - Sse.StoreScalar(pDstCurrent, src2Vector); - - pSrc1Current++; - pSrc2Current++; - pDstCurrent++; + pDstCurrent[i] = pSrc1Current[i] * pSrc2Current[i]; } } } @@ -890,12 +839,12 @@ public static unsafe float SumU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent)); pSrcCurrent += 8; @@ -906,21 +855,21 @@ public static unsafe float SumU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { result128 = Sse.Add(result128, Sse.LoadVector128(pSrcCurrent)); pSrcCurrent += 4; } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - result128 = Sse.AddScalar(result128, Sse.LoadScalarVector128(pSrcCurrent)); - pSrcCurrent++; + result += pSrcCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -928,12 +877,12 @@ public static unsafe float SumSqU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Add(result256, Avx.Multiply(srcVector, srcVector)); @@ -946,7 +895,7 @@ public static unsafe float SumSqU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result128 = Sse.Add(result128, Sse.Multiply(srcVector, srcVector)); @@ -955,16 +904,14 @@ public static unsafe float SumSqU(Span src) } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, srcVector)); - - pSrcCurrent++; + result += pSrcCurrent[i] * pSrcCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -972,13 +919,13 @@ public static unsafe float SumSqDiffU(float mean, Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); Vector256 meanVector256 = Avx.SetAllVector256(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -993,7 +940,7 @@ public static unsafe float SumSqDiffU(float mean, Span src) Vector128 result128 = Sse.SetZeroVector128(); Vector128 meanVector128 = Sse.SetAllVector128(mean); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector128); @@ -1003,17 +950,15 @@ public static unsafe float SumSqDiffU(float mean, Span src) } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - srcVector = Sse.SubtractScalar(srcVector, meanVector128); - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, srcVector)); - - pSrcCurrent++; + float difference = pSrcCurrent[i] - mean; + result += difference * difference; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -1021,12 +966,12 @@ public static unsafe float SumAbsU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256)); @@ -1039,7 +984,7 @@ public static unsafe float SumAbsU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result128 = Sse.Add(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1049,7 +994,7 @@ public static unsafe float SumAbsU(Span src) result128 = SseIntrinsics.VectorSum128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); result128 = Sse.AddScalar(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1065,13 +1010,13 @@ public static unsafe float SumAbsDiffU(float mean, Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); Vector256 meanVector256 = Avx.SetAllVector256(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1086,7 +1031,7 @@ public static unsafe float SumAbsDiffU(float mean, Span src) Vector128 result128 = Sse.SetZeroVector128(); Vector128 meanVector128 = Sse.SetAllVector128(mean); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector128); @@ -1097,7 +1042,7 @@ public static unsafe float SumAbsDiffU(float mean, Span src) result128 = SseIntrinsics.VectorSum128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); srcVector = Sse.SubtractScalar(srcVector, meanVector128); @@ -1114,12 +1059,12 @@ public static unsafe float MaxAbsU(Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256)); @@ -1132,7 +1077,7 @@ public static unsafe float MaxAbsU(Span src) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); result128 = Sse.Max(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1142,7 +1087,7 @@ public static unsafe float MaxAbsU(Span src) result128 = SseIntrinsics.VectorMax128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); result128 = Sse.MaxScalar(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128)); @@ -1158,13 +1103,13 @@ public static unsafe float MaxAbsDiffU(float mean, Span src) { fixed (float* psrc = src) { - float* pSrcEnd = psrc + src.Length; - float* pSrcCurrent = psrc; - Vector256 result256 = Avx.SetZeroVector256(); Vector256 meanVector256 = Avx.SetAllVector256(mean); - while (pSrcCurrent + 8 <= pSrcEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); srcVector = Avx.Subtract(srcVector, meanVector256); @@ -1179,7 +1124,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span src) Vector128 result128 = Sse.SetZeroVector128(); Vector128 meanVector128 = Sse.SetAllVector128(mean); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); srcVector = Sse.Subtract(srcVector, meanVector128); @@ -1190,7 +1135,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span src) result128 = SseIntrinsics.VectorMax128(in result128); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); srcVector = Sse.SubtractScalar(srcVector, meanVector128); @@ -1208,13 +1153,13 @@ public static unsafe float DotU(Span src, Span dst) fixed (float* psrc = src) fixed (float* pdst = dst) { + Vector256 result256 = Avx.SetZeroVector256(); + + int count = Math.DivRem(src.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pSrcEnd = psrc + src.Length; - - Vector256 result256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Avx.LoadVector256(pSrcCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -1230,7 +1175,7 @@ public static unsafe float DotU(Span src, Span dst) Vector128 result128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 srcVector = Sse.LoadVector128(pSrcCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1242,19 +1187,14 @@ public static unsafe float DotU(Span src, Span dst) } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = Sse.LoadScalarVector128(pSrcCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, dstVector)); - - pSrcCurrent++; - pDstCurrent++; + result += pSrcCurrent[i] * pDstCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -1264,14 +1204,14 @@ public static unsafe float DotSU(Span src, Span dst, Span idx fixed (float* pdst = dst) fixed (int* pidx = idx) { + Vector256 result256 = Avx.SetZeroVector256(); + + int count = Math.DivRem(idx.Length, 8, out int remainder); float* pSrcCurrent = psrc; - float* pDstCurrent = pdst; int* pIdxCurrent = pidx; - int* pIdxEnd = pidx + idx.Length; - - Vector256 result256 = Avx.SetZeroVector256(); + float* pDstCurrent = pdst; - while (pIdxCurrent + 8 <= pIdxEnd) + for (int i = 0; i < count; i++) { Vector256 srcVector = Load8(pSrcCurrent, pIdxCurrent); Vector256 dstVector = Avx.LoadVector256(pDstCurrent); @@ -1287,7 +1227,7 @@ public static unsafe float DotSU(Span src, Span dst, Span idx Vector128 result128 = Sse.SetZeroVector128(); - if (pIdxCurrent + 4 <= pIdxEnd) + if (remainder >= 4) { Vector128 srcVector = SseIntrinsics.Load4(pSrcCurrent, pIdxCurrent); Vector128 dstVector = Sse.LoadVector128(pDstCurrent); @@ -1299,19 +1239,15 @@ public static unsafe float DotSU(Span src, Span dst, Span idx } result128 = SseIntrinsics.VectorSum128(in result128); + float result = Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); - while (pIdxCurrent < pIdxEnd) + for (int i = 0; i < remainder % 4; i++) { - Vector128 srcVector = SseIntrinsics.Load1(pSrcCurrent, pIdxCurrent); - Vector128 dstVector = Sse.LoadScalarVector128(pDstCurrent); - - result128 = Sse.AddScalar(result128, Sse.MultiplyScalar(srcVector, dstVector)); - - pIdxCurrent++; - pDstCurrent++; + int index = pIdxCurrent[i]; + result += pSrcCurrent[index] * pDstCurrent[i]; } - return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded)); + return result; } } @@ -1320,13 +1256,13 @@ public static unsafe float Dist2(Span src, Span dst) fixed (float* psrc = src) fixed (float* pdst = dst) { + Vector256 sqDistanceVector256 = Avx.SetZeroVector256(); + + int count = Math.DivRem(src.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDstCurrent = pdst; - float* pSrcEnd = psrc + src.Length; - - Vector256 sqDistanceVector256 = Avx.SetZeroVector256(); - while (pSrcCurrent + 8 <= pSrcEnd) + for (int i = 0; i < count; i++) { Vector256 distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent), Avx.LoadVector256(pDstCurrent)); @@ -1342,7 +1278,7 @@ public static unsafe float Dist2(Span src, Span dst) Vector128 sqDistanceVector128 = Sse.SetZeroVector128(); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent), Sse.LoadVector128(pDstCurrent)); @@ -1354,15 +1290,12 @@ public static unsafe float Dist2(Span src, Span dst) } sqDistanceVector128 = SseIntrinsics.VectorSum128(in sqDistanceVector128); - float norm = Sse.ConvertToSingle(Sse.AddScalar(sqDistanceVector128, sqDistanceVectorPadded)); - while (pSrcCurrent < pSrcEnd) + + for (int i = 0; i < remainder % 4; i++) { - float distance = (*pSrcCurrent) - (*pDstCurrent); + float distance = pSrcCurrent[i] - pDstCurrent[i]; norm += distance * distance; - - pSrcCurrent++; - pDstCurrent++; } return norm; @@ -1375,15 +1308,15 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span src, flo fixed (float* pdst1 = v) fixed (float* pdst2 = w) { - float* pSrcEnd = psrc + src.Length; + Vector256 xPrimal256 = Avx.SetAllVector256(primalUpdate); + Vector256 xThreshold256 = Avx.SetAllVector256(threshold); + + int count = Math.DivRem(src.Length, 8, out int remainder); float* pSrcCurrent = psrc; float* pDst1Current = pdst1; float* pDst2Current = pdst2; - Vector256 xPrimal256 = Avx.SetAllVector256(primalUpdate); - Vector256 xThreshold256 = Avx.SetAllVector256(threshold); - - while (pSrcCurrent + 8 <= pSrcEnd) + for (int i = 0; i < count; i++) { Vector256 xSrc = Avx.LoadVector256(pSrcCurrent); @@ -1402,7 +1335,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span src, flo Vector128 xPrimal128 = Sse.SetAllVector128(primalUpdate); Vector128 xThreshold128 = Sse.SetAllVector128(threshold); - if (pSrcCurrent + 4 <= pSrcEnd) + if (remainder >= 4) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); @@ -1418,15 +1351,11 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span src, flo pDst2Current += 4; } - while (pSrcCurrent < pSrcEnd) + for (int i = 0; i < remainder % 4; i++) { - *pDst1Current += (*pSrcCurrent) * primalUpdate; - float dst1 = *pDst1Current; - *pDst2Current = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0; - - pSrcCurrent++; - pDst1Current++; - pDst2Current++; + pDst1Current[i] += primalUpdate * pSrcCurrent[i]; + float dst1 = pDst1Current[i]; + pDst2Current[i] = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0; } } } @@ -1438,14 +1367,14 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span src, Sp fixed (float* pdst1 = v) fixed (float* pdst2 = w) { - int* pIdxEnd = pidx + indices.Length; - float* pSrcCurrent = psrc; - int* pIdxCurrent = pidx; - Vector256 xPrimal256 = Avx.SetAllVector256(primalUpdate); Vector256 xThreshold = Avx.SetAllVector256(threshold); - while (pIdxCurrent + 8 <= pIdxEnd) + int count = Math.DivRem(src.Length, 8, out int remainder); + float* pSrcCurrent = psrc; + int* pIdxCurrent = pidx; + + for (int i = 0; i < count; i++) { Vector256 xSrc = Avx.LoadVector256(pSrcCurrent); @@ -1463,7 +1392,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span src, Sp Vector128 xPrimal128 = Sse.SetAllVector128(primalUpdate); Vector128 xThreshold128 = Sse.SetAllVector128(threshold); - if (pIdxCurrent + 4 <= pIdxEnd) + if (remainder >= 4) { Vector128 xSrc = Sse.LoadVector128(pSrcCurrent); @@ -1478,15 +1407,12 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span src, Sp pSrcCurrent += 4; } - while (pIdxCurrent < pIdxEnd) + for (int i = 0; i < remainder % 4; i++) { - int index = *pIdxCurrent; - pdst1[index] += (*pSrcCurrent) * primalUpdate; + int index = pIdxCurrent[i]; + pdst1[index] += primalUpdate * pSrcCurrent[i]; float dst1 = pdst1[index]; pdst2[index] = Math.Abs(dst1) > threshold ? (dst1 > 0 ? dst1 - threshold : dst1 + threshold) : 0; - - pIdxCurrent++; - pSrcCurrent++; } } } diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 263779e478..76e6ed52fc 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -26,9 +26,6 @@ internal static class SseIntrinsics Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); - // The count of 32-bit floats in Vector128 - internal const int Vector128SingleElementCount = 4; - // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray private const int Vector128Alignment = 16; @@ -416,7 +413,7 @@ public static unsafe void AddScalarU(float scalar, Span dst) fixed (float* pdst = dst) { Vector128 scalarVector = Sse.SetAllVector128(scalar); - int count = Math.DivRem(dst.Length, Vector128SingleElementCount, out int remainder); + int count = Math.DivRem(dst.Length, 4, out int remainder); float* pDstCurrent = pdst; for (int i = 0; i < count; i++) @@ -425,7 +422,7 @@ public static unsafe void AddScalarU(float scalar, Span dst) dstVector = Sse.Add(dstVector, scalarVector); Sse.Store(pDstCurrent, dstVector); - pDstCurrent += Vector128SingleElementCount; + pDstCurrent += 4; } for (int i = 0; i < remainder; i++)