From 6221e1b1ddea6c785c6fa04267cd7687542e48c8 Mon Sep 17 00:00:00 2001 From: Anipik Date: Wed, 10 Oct 2018 14:35:24 -0700 Subject: [PATCH 1/7] implemenatation and unitTests added --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 413 ++++++++++++++---- .../CpuMathUtils.netcoreapp.cs | 4 +- src/Microsoft.ML.CpuMath/Sse.cs | 4 +- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 388 +++++++++++----- src/Microsoft.ML.CpuMath/Thunk.cs | 4 +- src/Native/CpuMathNative/Sse.cpp | 383 ++++++++++++---- .../UnitTests.cs | 8 +- 7 files changed, 932 insertions(+), 272 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 0d2c015fa3..e2bd8031fe 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -143,18 +143,15 @@ private static Vector256 GetNewDst256(in Vector256 xDst1, in Vecto // Multiply matrix times vector into vector. public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) + fixed (float* psrc = &src.Items[0]) + fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat.Items[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -167,25 +164,141 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, Vector256 res2 = res0; Vector256 res3 = res0; + int length = ccol; float* pSrcCurrent = psrc; - - while (pSrcCurrent < pSrcEnd) + if (ccol < 8) { float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + res0 = Avx.Multiply(x01, vector); + res1 = Avx.Multiply(x11, vector); + res2 = Avx.Multiply(x21, vector); + res3 = Avx.Multiply(x31, vector); + pMatCurrent += ccol; + } + else + { + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); - Vector256 x01 = Avx.LoadAlignedVector256(pMatTemp); - Vector256 x11 = Avx.LoadAlignedVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadAlignedVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadAlignedVector256(pMatTemp += ccol); - Vector256 x02 = Avx.LoadAlignedVector256(pSrcCurrent); - - res0 = Avx.Add(res0, Avx.Multiply(x01, x02)); - res1 = Avx.Add(res1, Avx.Multiply(x11, x02)); - res2 = Avx.Add(res2, Avx.Multiply(x21, x02)); - res3 = Avx.Add(res3, Avx.Multiply(x31, x02)); - - pSrcCurrent += 8; - pMatCurrent += 8; + int remainder = 0; + if ((misalignment & 3) != 0) + { + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent + 8 <= pSrcEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + 
Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + res0 = Avx.Add(res0, Avx.Multiply(x01, vector)); + res1 = Avx.Add(res1, Avx.Multiply(x11, vector)); + res2 = Avx.Add(res2, Avx.Multiply(x21, vector)); + res3 = Avx.Add(res3, Avx.Multiply(x31, vector)); + + pSrcCurrent += 8; + pMatCurrent += 8; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + Vector256 tempX01 = Avx.And(x01, mask); + Vector256 tempX11 = Avx.And(x11, mask); + Vector256 tempX21 = Avx.And(x21, mask); + Vector256 tempX31 = Avx.And(x31, mask); + + Vector256 tempVec = Avx.And(vector, mask); + + res0 = Avx.Multiply(tempX01, tempVec); + res1 = Avx.Multiply(tempX11, tempVec); + res2 = Avx.Multiply(tempX21, tempVec); + res3 = Avx.Multiply(tempX31, tempVec); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 7) + { + remainder = length % 8; + while (pSrcCurrent + 8 <= pSrcEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + res0 = Avx.Add(res0, Avx.Multiply(x01, vector)); + res1 = Avx.Add(res1, Avx.Multiply(x11, vector)); + res2 = Avx.Add(res2, Avx.Multiply(x21, vector)); + res3 = Avx.Add(res3, Avx.Multiply(x31, vector)); + + pSrcCurrent += 8; + pMatCurrent += 8; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pSrcCurrent -= (8 - remainder); + + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.LoadVector256(pMatTemp); + Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + Vector256 tempX01 = Avx.And(x01, mask); + Vector256 tempX11 = Avx.And(x11, mask); + Vector256 tempX21 = Avx.And(x21, mask); + Vector256 tempX31 = Avx.And(x31, mask); + + Vector256 tempVec = Avx.And(vector, mask); + + res0 = Avx.Add(res0, Avx.Multiply(tempVec, tempX01)); + res1 = Avx.Add(res1, Avx.Multiply(tempVec, tempX11)); + res2 = Avx.Add(res2, Avx.Multiply(tempVec, tempX21)); + res3 = Avx.Add(res3, Avx.Multiply(tempVec, tempX31)); + + pMatCurrent += 8; + pSrcCurrent += 8; + } + } } // Add up the entries of each, with the 4 results in res0 @@ -196,9 +309,9 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); if (add) { - sum = Sse.Add(sum, Sse.LoadAlignedVector128(pDstCurrent)); + sum = Sse.Add(sum, 
Sse.LoadVector128(pDstCurrent)); } - Sse.StoreAligned(pDstCurrent, sum); + Sse.Store(pDstCurrent, sum); pDstCurrent += 4; pMatCurrent += 3 * ccol; @@ -268,27 +381,25 @@ public static unsafe void MatMulPX(bool add, AlignedArray mat, int[] rgposSrc, A public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) + fixed (float* psrc = &src.Items[0]) + fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat.Items[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; + bool firstTime = true; // We do 4-way unrolling - if (!add) + while (pSrcCurrent < pSrcEnd) { - Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent); + Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C @@ -300,17 +411,18 @@ public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray s Vector256 x21 = Avx.SetHighLow(h21, h21); Vector256 x31 = Avx.SetHighLow(h31, h31); - pSrcCurrent += 4; - + int length = crow; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) + if (crow < 8) { float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.LoadAlignedVector256(pMatTemp); - Vector256 x12 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadAlignedVector256(pMatTemp += crow); + + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); x02 = Avx.Multiply(x01, x02); x12 = Avx.Multiply(x11, x12); @@ -321,57 +433,180 @@ public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray s x22 = Avx.Add(x22, x32); x02 = Avx.Add(x02, x22); - Avx.StoreAligned(pDstCurrent, x02); + if (add || !firstTime) + { + x02 = Avx.Add(x02, x3); + } + Avx.Store(pDstCurrent, x02); pDstCurrent += 8; pMatCurrent += 8; } - - pMatCurrent += 3 * crow; - } - - while (pSrcCurrent < pSrcEnd) - { - Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent); - // Replicate each slot of h01 (ABCD) into its own register. 
- Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B - Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C - Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D - h01 = Sse.Shuffle(h01, h01, 0x00); // A - - Vector256 x01 = Avx.SetHighLow(h01, h01); - Vector256 x11 = Avx.SetHighLow(h11, h11); - Vector256 x21 = Avx.SetHighLow(h21, h21); - Vector256 x31 = Avx.SetHighLow(h31, h31); - - float* pDstCurrent = pdst; - - while (pDstCurrent < pDstEnd) + else { - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.LoadAlignedVector256(pMatTemp); - Vector256 x12 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x3 = Avx.LoadAlignedVector256(pDstCurrent); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - x3 = Avx.Add(x02, x3); + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); - Avx.StoreAligned(pDstCurrent, x3); - - pDstCurrent += 8; - pMatCurrent += 8; + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + if (add || !firstTime) + { + x02 = Avx.Add(x02, x3); + } + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); + + x02 = Avx.And(x02, leadingMask); + x12 = Avx.And(x12, leadingMask); + x22 = Avx.And(x22, leadingMask); + x32 = Avx.And(x32, leadingMask); + + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + + if (add || !firstTime) + { + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + } + + Avx.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 7) + { + remainder = length % 8; + while (pDstCurrent + 8 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = 
Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + if (add || !firstTime) + { + x02 = Avx.Add(x02, x3); + } + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.LoadVector256(pMatTemp); + Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); + Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); + + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + + x02 = Avx.And(x02, trailingMask); + x12 = Avx.And(x12, trailingMask); + x22 = Avx.And(x22, trailingMask); + x32 = Avx.And(x32, trailingMask); + + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + + if (add || !firstTime) + { + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); + } + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } } + firstTime = false; pMatCurrent += 3 * crow; pSrcCurrent += 4; } diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index 09738684f5..0171b68505 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -47,12 +47,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMulA(add, mat, src, dst, crun, src.Size); + SseIntrinsics.MatMul(add, mat, src, dst, crun, src.Size); } else { Contracts.Assert(crun <= src.Size); - SseIntrinsics.MatMulTranA(add, mat, src, dst, dst.Size, crun); + SseIntrinsics.MatMulTran(add, mat, src, dst, dst.Size, crun); } } else diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs index d541c02533..d5780a4fae 100644 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ b/src/Microsoft.ML.CpuMath/Sse.cs @@ -43,12 +43,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulA(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + Thunk.MatMul(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); } else { Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTranA(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + Thunk.MatMulTran(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); } } } diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index ede8a7da97..bf04985143 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -23,13 +23,6 @@ namespace 
Microsoft.ML.Runtime.Internal.CpuMath { internal static class SseIntrinsics { - internal static readonly Vector128 AbsMask128 = Sse2.IsSupported ? - Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : - Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); - - // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray - private const int Vector128Alignment = 16; - public static readonly uint[] LeadingAlignmentMask = new uint[16] { 0x00000000, 0x00000000, 0x00000000, 0x00000000, @@ -46,6 +39,13 @@ internal static class SseIntrinsics 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, }; + internal static readonly Vector128 AbsMask128 = Sse2.IsSupported ? + Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : + Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); + + // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray + private const int Vector128Alignment = 16; + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] private static bool HasCompatibleAlignment(AlignedArray alignedArray) { @@ -134,20 +134,17 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect } // Multiply matrix times vector into vector. - public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) + fixed (float* psrc = &src.Items[0]) + fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat.Items[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -160,25 +157,122 @@ public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, Vector128 res2 = res0; Vector128 res3 = res0; + int length = ccol; float* pSrcCurrent = psrc; - while (pSrcCurrent < pSrcEnd) - { - float* pMatTemp = pMatCurrent; + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); + int remainder = 0; - Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x02 = Sse.LoadAlignedVector128(pSrcCurrent); + if ((misalignment & 3) != 0) + { + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent + 4 <= pSrcEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.LoadVector128(pMatTemp); + Vector128 x11 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 x21 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 x31 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); + res2 = Sse.Add(res2, Sse.Multiply(x21, 
vector)); + res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.LoadVector128(pMatTemp); + Vector128 x11 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 x21 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 x31 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + Vector128 tempX01 = Sse.And(x01, mask); + Vector128 tempX11 = Sse.And(x11, mask); + Vector128 tempX21 = Sse.And(x21, mask); + Vector128 tempX31 = Sse.And(x31, mask); + Vector128 tempVec = Sse.And(vector, mask); + + res0 = Sse.Multiply(tempX01, tempVec); + res1 = Sse.Multiply(tempX11, tempVec); + res2 = Sse.Multiply(tempX21, tempVec); + res3 = Sse.Multiply(tempX31, tempVec); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } - res0 = Sse.Add(res0, Sse.Multiply(x01, x02)); - res1 = Sse.Add(res1, Sse.Multiply(x11, x02)); - res2 = Sse.Add(res2, Sse.Multiply(x21, x02)); - res3 = Sse.Add(res3, Sse.Multiply(x31, x02)); + if (length > 4) + { + remainder = length % 4; + while (pSrcCurrent + 4 <= pSrcEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol); + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); + res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); + res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } - pSrcCurrent += 4; - pMatCurrent += 4; + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.LoadVector128(pMatTemp); + Vector128 x11 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 x21 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 x31 = Sse.LoadVector128(pMatTemp += ccol); + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + Vector128 tempX01 = Sse.And(x01, mask); + Vector128 tempX11 = Sse.And(x11, mask); + Vector128 tempX21 = Sse.And(x21, mask); + Vector128 tempX31 = Sse.And(x31, mask); + Vector128 tempVec = Sse.And(vector, mask); + + res0 = Sse.Add(res0, Sse.Multiply(tempVec, tempX01)); + res1 = Sse.Add(res1, Sse.Multiply(tempVec, tempX11)); + res2 = Sse.Add(res2, Sse.Multiply(tempVec, tempX21)); + res3 = Sse.Add(res3, Sse.Multiply(tempVec, tempX31)); + + pMatCurrent += 4; + pSrcCurrent += 4; + } } // Add up the entries of each, with the 4 results in res0 @@ -188,10 +282,10 @@ public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, if (add) { - res0 = Sse.Add(res0, Sse.LoadAlignedVector128(pDstCurrent)); + res0 = Sse.Add(res0, Sse.LoadVector128(pDstCurrent)); } - Sse.StoreAligned(pDstCurrent, res0); + Sse.Store(pDstCurrent, res0); pDstCurrent += 4; pMatCurrent 
+= 3 * ccol; } @@ -256,26 +350,25 @@ public static unsafe void MatMulPA(bool add, AlignedArray mat, int[] rgposSrc, A } } - public static unsafe void MatMulTranA(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) + fixed (float* psrc = &src.Items[0]) + fixed (float* pdst = &dst.Items[0]) + fixed (float* pmat = &mat.Items[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; + bool firstTime = true; - if (!add) + // We do 4-way unrolling + while (pSrcCurrent < pSrcEnd) { Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. @@ -284,73 +377,170 @@ public static unsafe void MatMulTranA(bool add, AlignedArray mat, AlignedArray s Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A - pSrcCurrent += 4; - + int length = crow; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); - Sse.StoreAligned(pDstCurrent, x02); - - pDstCurrent += 4; - pMatCurrent += 4; + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.LoadVector128(pMatTemp); + Vector128 x12 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + if (add || !firstTime) + { + x02 = Sse.Add(x02, x3); + } + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.LoadVector128(pMatTemp); + Vector128 x12 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x22 = 
Sse.LoadVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); + + x02 = Sse.And(x02, leadingMask); + x12 = Sse.And(x12, leadingMask); + x22 = Sse.And(x22, leadingMask); + x32 = Sse.And(x32, leadingMask); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); + + if (add || !firstTime) + { + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); + } + + Sse.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 4) + { + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); + Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + if (add || !firstTime) + { + x02 = Sse.Add(x02, x3); + } + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } - pMatCurrent += 3 * crow; - } + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + float* pMatTemp = pMatCurrent; - while (pSrcCurrent < pSrcEnd) - { - Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); - // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
- Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B - Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C - Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D - x01 = Sse.Shuffle(x01, x01, 0x00); // A + Vector128 x02 = Sse.LoadVector128(pMatTemp); + Vector128 x12 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x22 = Sse.LoadVector128(pMatTemp += crow); + Vector128 x32 = Sse.LoadVector128(pMatTemp += crow); - float* pDstCurrent = pdst; + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; + x02 = Sse.And(x02, trailingMask); + x12 = Sse.And(x12, trailingMask); + x22 = Sse.And(x22, trailingMask); + x32 = Sse.And(x32, trailingMask); - Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x3 = Sse.LoadAlignedVector128(pDstCurrent); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - x3 = Sse.Add(x02, x3); + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); - Sse.StoreAligned(pDstCurrent, x3); + x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - pDstCurrent += 4; - pMatCurrent += 4; + if (add || !firstTime) + { + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + } + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } } + firstTime = false; pMatCurrent += 3 * crow; pSrcCurrent += 4; } @@ -526,7 +716,7 @@ public static unsafe void Scale(float scale, Span dst) for (float* pEnd = pDstCurrent + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 4) { - // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads + // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads // (due to semantics of the legacy encoding). // We don't need an assert, since the instruction will throw for unaligned inputs. Vector128 temp = Sse.LoadAlignedVector128(pDstCurrent); diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 23192fc277..3291684873 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -16,7 +16,7 @@ internal static unsafe class Thunk public static extern bool ChkAvx(); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulA(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMul(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMulX(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); @@ -65,7 +65,7 @@ public static extern void RespNormU(bool add, float alpha, float beta, bool avgO // and columns from that perspective. Alternatively, crow is the number of rows in the transpose of pmat // (thought of as row-major order). 
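    // As an illustration (hypothetical sizes): with ccol = 2 and crow = 4, pmat holds two rows of
    // length 4, and MatMulTran computes pdst[j] = psrc[0] * pmat[j] + psrc[1] * pmat[crow + j]
    // for j in [0, crow), adding into the existing pdst values when add is true.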
[DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranA(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMulTran(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMulTranX(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 4a2d30e979..72b89fc765 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -122,33 +122,149 @@ EXPORT_API(bool) ChkAvx() } // Multiply matrix times vector into vector. -EXPORT_API(void) MatMulA(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { - const float * psLim = psrc + ccol; - const float * pdLim = pdst + crow; - const float * pm = pmat; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 3 * ccol) + const float * pSrcEnd = psrc + ccol; + const float * pDstEnd = pdst + crow; + float* pDstCurrent = pdst; + const float* pMatCurrent = pmat; + + while (pDstCurrent < pDstEnd) { __m128 res0 = _mm_setzero_ps(); __m128 res1 = res0; __m128 res2 = res0; __m128 res3 = res0; - for (const float * ps = psrc; ps < psLim; ps += 4, pm += 4) - { - const float * pmTmp; - __m128 x01 = _mm_load_ps(pmTmp = pm); - __m128 x11 = _mm_load_ps(pmTmp += ccol); - __m128 x21 = _mm_load_ps(pmTmp += ccol); - __m128 x31 = _mm_load_ps(pmTmp += ccol); - __m128 x02 = _mm_load_ps(ps); - x01 = _mm_mul_ps(x01, x02); - x11 = _mm_mul_ps(x11, x02); - x21 = _mm_mul_ps(x21, x02); - x31 = _mm_mul_ps(x31, x02); - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); + + int length = ccol; + const float* pSrcCurrent = psrc; + + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + while (pSrcCurrent + 4 <= pSrcEnd) + { + remainder = ccol % 4; + + __m128 x01 = _mm_loadu_ps(pMatCurrent); + __m128 x11 = _mm_loadu_ps(pMatCurrent + ccol); + __m128 x21 = _mm_loadu_ps(pMatCurrent + 2 * ccol); + __m128 x31 = _mm_loadu_ps(pMatCurrent + 3 * ccol); + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + x01 = _mm_mul_ps(x01, vector); + x11 = _mm_mul_ps(x11, vector); + x21 = _mm_mul_ps(x21, vector); + x31 = _mm_mul_ps(x31, vector); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 x01 = _mm_loadu_ps(pMatCurrent); + __m128 x11 = _mm_loadu_ps(pMatCurrent + ccol); + __m128 x21 = _mm_loadu_ps(pMatCurrent + 2 * ccol); + __m128 x31 = _mm_loadu_ps(pMatCurrent + 3 * ccol); + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + __m128 tempX01 = _mm_and_ps(x01, mask); + __m128 tempX11 = _mm_and_ps(x11, mask); + __m128 tempX21 = _mm_and_ps(x21, mask); + __m128 tempX31 = _mm_and_ps(x31, mask); + + __m128 tempVec = _mm_and_ps(vector, mask); + + res0 = _mm_mul_ps(tempX01, tempVec); + res1 = 
_mm_mul_ps(tempX11, tempVec); + res2 = _mm_mul_ps(tempX21, tempVec); + res3 = _mm_mul_ps(tempX31, tempVec); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 3) + { + remainder = length % 4; + while(pSrcCurrent + 4 <= pSrcEnd) + { + __m128 x01 = _mm_load_ps(pMatCurrent); + __m128 x11 = _mm_load_ps(pMatCurrent + ccol); + __m128 x21 = _mm_load_ps(pMatCurrent + 2 * ccol); + __m128 x31 = _mm_load_ps(pMatCurrent + 3 * ccol); + + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + x01 = _mm_mul_ps(x01, vector); + x11 = _mm_mul_ps(x11, vector); + x21 = _mm_mul_ps(x21, vector); + x31 = _mm_mul_ps(x31, vector); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); + + __m128 x01 = _mm_loadu_ps(pMatCurrent); + __m128 x11 = _mm_loadu_ps(pMatCurrent + ccol); + __m128 x21 = _mm_loadu_ps(pMatCurrent + 2 * ccol); + __m128 x31 = _mm_loadu_ps(pMatCurrent + 3 * ccol); + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + __m128 tempX01 = _mm_and_ps(x01, mask); + __m128 tempX11 = _mm_and_ps(x11, mask); + __m128 tempX21 = _mm_and_ps(x21, mask); + __m128 tempX31 = _mm_and_ps(x31, mask); + + __m128 tempVec = _mm_and_ps(vector, mask); + + x01 = _mm_mul_ps(tempX01, tempVec); + x11 = _mm_mul_ps(tempX11, tempVec); + x21 = _mm_mul_ps(tempX21, tempVec); + x31 = _mm_mul_ps(tempX31, tempVec); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pMatCurrent += 4; + pSrcCurrent += 4; + } } // Add up the entries of each, with the 4 results in res0 @@ -157,8 +273,12 @@ EXPORT_API(void) MatMulA(bool add, _In_ const float * pmat, _In_ const float * p res0 = _mm_hadd_ps(res0, res2); if (add) - res0 = _mm_add_ps(res0, _mm_load_ps(pd)); - _mm_store_ps(pd, res0); + res0 = _mm_add_ps(res0, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, res0); + + pDstCurrent += 4; + pMatCurrent += 3 * ccol; } } @@ -495,70 +615,185 @@ EXPORT_API(void) RespNormU(bool add, float alpha, float beta, bool avgOverFullKe } } -EXPORT_API(void) MatMulTranA(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { - const float * psLim = psrc + ccol; - const float * pdLim = pdst + crow; - const float * pm = pmat; - const float * ps = psrc; + const float * pSrcEnd = psrc + ccol; + const float * pDstEnd = pdst + crow; + + const float* pMatCurrent = pmat; + const float* pSrcCurrent = psrc; + bool firstTime = true; - if (!add) + while (pSrcCurrent < pSrcEnd) { - __m128 x01 = _mm_load_ps(ps); + __m128 x01 = _mm_loadu_ps(pSrcCurrent); // Replicate each slot of x01 into its own register. 
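        // The shuffle immediates broadcast a single lane of x01 across all four lanes:
        // 0x00 -> lane 0, 0x55 -> lane 1, 0xAA -> lane 2, 0xFF -> lane 3, so each of the
        // four registers below carries one source element for the 4-way unrolled column walk.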
__m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); - ps += 4; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) - { - const float * pmTmp; - __m128 x02 = _mm_load_ps(pmTmp = pm); - __m128 x12 = _mm_load_ps(pmTmp += crow); - __m128 x22 = _mm_load_ps(pmTmp += crow); - __m128 x32 = _mm_load_ps(pmTmp += crow); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - _mm_store_ps(pd, x02); - } - pm += 3 * crow; - } + int length = crow; + float* pDstCurrent = pdst; - for (; ps < psLim; ps += 4) - { - __m128 x01 = _mm_load_ps(ps); - // Replicate each slot of x01 into its own register. - __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); - __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); - __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); - x01 = _mm_shuffle_ps(x01, x01, 0x00); - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) { - const float * pmTmp; - __m128 x02 = _mm_load_ps(pmTmp = pm); - __m128 x12 = _mm_load_ps(pmTmp += crow); - __m128 x22 = _mm_load_ps(pmTmp += crow); - __m128 x32 = _mm_load_ps(pmTmp += crow); - __m128 x3 = _mm_load_ps(pd); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - x3 = _mm_add_ps(x02, x3); - _mm_store_ps(pd, x3); + while (pDstCurrent < pDstEnd) + { + __m128 x02 = _mm_loadu_ps(pMatCurrent); + __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); + __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); + __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + if (add || !firstTime) + { + x02 = _mm_add_ps(x02, x3); + } + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; + } } + else + { + int remainder = 0; + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 x02 = _mm_loadu_ps(pMatCurrent); + __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); + __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); + __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); + + __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + __m128 mask2 = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (( 4 - misalignment) * 4)); + + x02 = _mm_and_ps(x02, mask); + x12 = _mm_and_ps(x12, mask); + x22 = _mm_and_ps(x22, mask); + x32 = _mm_and_ps(x32, mask); - pm += 3 * crow; + __m128 x3 = _mm_loadu_ps(pDstCurrent); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_or_ps(x02, _mm_and_ps(x3, mask2)); + + if (add || !firstTime) + { + x02 = _mm_add_ps(x02, _mm_and_ps(x3, mask)); + } + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + + if(length > 3) + { + 
remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) + { + __m128 x02 = _mm_loadu_ps(pMatCurrent); + __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); + __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); + __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + if (add || !firstTime) + { + x02 = _mm_add_ps(x02, x3); + } + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + length = remainder; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + + __m128 x02 = _mm_loadu_ps(pMatCurrent); + __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); + __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); + __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); + + __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + __m128 mask2 = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); + + x02 = _mm_and_ps(x02, mask); + x12 = _mm_and_ps(x12, mask); + x22 = _mm_and_ps(x22, mask); + x32 = _mm_and_ps(x32, mask); + + __m128 x3 = _mm_loadu_ps(pDstCurrent); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_or_ps(x02, _mm_and_ps(x3, mask2)); + + if (add || !firstTime) + { + x02 = _mm_add_ps(x02, _mm_and_ps(x3, mask)); + } + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; + } + } + + firstTime = false; + pMatCurrent += 3 * crow; + pSrcCurrent += 4; } } diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index 6c0f1cbaf0..e9cffb3c50 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -84,7 +84,7 @@ public CpuMathUtilsUnitTests() [InlineData(0, 0, 0, new float[] { -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f })] [InlineData(1, 1, 0, new float[] { 1496f, 3672f, 5848f, 8024f, 10200f, 12376f, 14552f, 16728f })] [InlineData(1, 0, 1, new float[] { 204f, 492f, 780f, 1068f, 1356f, 1644f, 1932f, 2220f, 2508f, 2796f, 3084f, 3372f, 3660f, 3948f, 4236f, 4524f })] - public void MatMulATest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatMulTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = _testSrcVectors[srcTest]; @@ -100,7 +100,7 @@ public void MatMulATest(int matTest, int srcTest, int dstTest, float[] expected) [InlineData(0, 0, 0, new float[] { -416.6801f, -415.6801f, -414.6801f, -413.6801f, -412.6801f, -411.6801f, -410.6801f, -409.6801f })] [InlineData(1, 1, 0, new float[] { 1496f, 3673f, 5850f, 8027f, 10204f, 12381f, 14558f, 16735f })] [InlineData(1, 0, 1, new float[] { 204f, 493f, 782f, 1071f, 1360f, 1649f, 1938f, 2227f, 2516f, 2805f, 3094f, 3383f, 3672f, 3961f, 4250f, 4539f })] - public void MatMulAAddTest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatMulAddTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = 
_testSrcVectors[srcTest]; @@ -116,7 +116,7 @@ public void MatMulAAddTest(int matTest, int srcTest, int dstTest, float[] expect [InlineData(0, 0, 0, new float[] { 70.56001f, -85.68f, -351.36f, 498.24f, -3829.32f, -969.48f, 1168.2f, 118.44f })] [InlineData(1, 0, 1, new float[] { 2724f, 2760f, 2796f, 2832f, 2868f, 2904f, 2940f, 2976f, 3012f, 3048f, 3084f, 3120f, 3156f, 3192f, 3228f, 3264f })] [InlineData(1, 1, 0, new float[] { 11016f, 11152f, 11288f, 11424f, 11560f, 11696f, 11832f, 11968f })] - public void MatMulTranATest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = _testSrcVectors[srcTest]; @@ -132,7 +132,7 @@ public void MatMulTranATest(int matTest, int srcTest, int dstTest, float[] expec [InlineData(0, 0, 0, new float[] { 70.56001f, -84.68f, -349.36f, 501.24f, -3825.32f, -964.48f, 1174.2f, 125.44f })] [InlineData(1, 0, 1, new float[] { 2724f, 2761f, 2798f, 2835f, 2872f, 2909f, 2946f, 2983f, 3020f, 3057f, 3094f, 3131f, 3168f, 3205f, 3242f, 3279f })] [InlineData(1, 1, 0, new float[] { 11016f, 11153f, 11290f, 11427f, 11564f, 11701f, 11838f, 11975f })] - public void MatMulTranAAddTest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatMulTranAddTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = _testSrcVectors[srcTest]; From 9515ebebe8ec2a16c8f5e4d80c595ee03f65e580 Mon Sep 17 00:00:00 2001 From: Anipik Date: Wed, 10 Oct 2018 15:15:36 -0700 Subject: [PATCH 2/7] added performance test for matmul and matmulTrans --- .../AvxPerformanceTests.cs | 28 +++++++++++++++++++ .../NativePerformanceTests.cs | 20 +++++++++++++ .../SsePerformanceTests.cs | 28 +++++++++++++++++++ 3 files changed, 76 insertions(+) diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs index 5930090bc5..9e4fc83032 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs @@ -11,6 +11,26 @@ namespace Microsoft.ML.CpuMath.PerformanceTests { public class AvxPerformanceTests : PerformanceTests { + private AlignedArray _testMatrices; + private AlignedArray _testSrcVectors; + private AlignedArray _testDstVectors; + + [GlobalSetup(Targets = new string[] { nameof(MatMulX), nameof(MatMulTranX) })] + public void MatMulSetup() + { + Setup(); + int vectorAlignment = CpuMathUtils.GetVectorAlignment(); + + _testMatrices = new AlignedArray(1000 * 1000, vectorAlignment); + _testMatrices.CopyFrom(src, 0, 1000 * 1000); + + _testSrcVectors = new AlignedArray(1000, vectorAlignment); + _testSrcVectors.CopyFrom(src, 0, 1000); + + _testDstVectors = new AlignedArray(1000, vectorAlignment); + _testDstVectors.CopyFrom(dst, 0, 1000); + } + [Benchmark] public void AddScalarU() => AvxIntrinsics.AddScalarU(DefaultScale, new Span(dst, 0, Length)); @@ -99,5 +119,13 @@ public void SdcaL1UpdateU() [Benchmark] public void SdcaL1UpdateSU() => AvxIntrinsics.SdcaL1UpdateSU(DefaultScale, new Span(src, 0, IndexLength), new Span(idx, 0, IndexLength), DefaultScale, new Span(dst), new Span(result)); + + [Benchmark] + public void MatMulX() + => AvxIntrinsics.MatMulX(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); + + [Benchmark] + public void MatMulTranX() + => AvxIntrinsics.MatMulTranX(true, _testMatrices, 
_testMatrices, _testDstVectors, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs index 3cce45046c..91dcf32d40 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs @@ -228,5 +228,25 @@ public unsafe void SdcaL1UpdateSU() CpuMathNativeUtils.SdcaL1UpdateSU(DefaultScale, psrc, pidx, DefaultScale, pdst, pres, IndexLength); } } + + [Benchmark] + public unsafe void MatMulX() + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + { + Thunk.MatMulX(true, psrc, psrc, pdst, 1000, 1000); + } + } + + [Benchmark] + public unsafe void MatMulTranX() + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + { + Thunk.MatMulTranX(true, psrc, psrc, pdst, 1000, 1000); + } + } } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs index 923d7c539f..9a9978851e 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs @@ -11,6 +11,26 @@ namespace Microsoft.ML.CpuMath.PerformanceTests { public class SsePerformanceTests : PerformanceTests { + private AlignedArray _testMatrices; + private AlignedArray _testSrcVectors; + private AlignedArray _testDstVectors; + + [GlobalSetup(Targets = new string[] { nameof(MatMulX), nameof(MatMulTranX) })] + public void MatMulSetup() + { + Setup(); + int vectorAlignment = CpuMathUtils.GetVectorAlignment(); + + _testMatrices = new AlignedArray(1000 * 1000, vectorAlignment); + _testMatrices.CopyFrom(src, 0, 1000 * 1000); + + _testSrcVectors = new AlignedArray(1000, vectorAlignment); + _testSrcVectors.CopyFrom(src, 0, 1000); + + _testDstVectors = new AlignedArray(1000, vectorAlignment); + _testDstVectors.CopyFrom(dst, 0, 1000); + } + [Benchmark] public void AddScalarU() => SseIntrinsics.AddScalarU(DefaultScale, new Span(dst, 0, Length)); @@ -99,5 +119,13 @@ public void SdcaL1UpdateU() [Benchmark] public void SdcaL1UpdateSU() => SseIntrinsics.SdcaL1UpdateSU(DefaultScale, new Span(src, 0, IndexLength), new Span(idx, 0, IndexLength), DefaultScale, new Span(dst), new Span(result)); + + [Benchmark] + public void MatMulX() + => SseIntrinsics.MatMul(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); + + [Benchmark] + public void MatMulTranX() + => SseIntrinsics.MatMulTran(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); } } From 9c34ff24124983c412555b6fbf0f1e3b4d1ac901 Mon Sep 17 00:00:00 2001 From: Anipik Date: Thu, 11 Oct 2018 15:20:26 -0700 Subject: [PATCH 3/7] load combined with math operation --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 474 ++++++++---------- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 192 ++++--- src/Native/CpuMathNative/Sse.cpp | 190 +++---- .../AvxPerformanceTests.cs | 24 +- .../SsePerformanceTests.cs | 24 +- 5 files changed, 361 insertions(+), 543 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index e2bd8031fe..760cd718db 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -146,9 +146,14 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - fixed (float* psrc = &src.Items[0]) - fixed 
(float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) + MatMulX(add, mat.Items, src.Items, dst.Items, crow, ccol); + } + + public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { @@ -166,43 +171,79 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, int length = ccol; float* pSrcCurrent = psrc; - if (ccol < 8) + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); + + int remainder = 0; + if ((misalignment & 3) != 0) { - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.LoadVector256(pMatTemp); - Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - res0 = Avx.Multiply(x01, vector); - res1 = Avx.Multiply(x11, vector); - res2 = Avx.Multiply(x21, vector); - res3 = Avx.Multiply(x31, vector); - pMatCurrent += ccol; + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent < pSrcEnd) + { + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + + res0 = Avx.Add(res0, x01); + res1 = Avx.Add(res1, x11); + res2 = Avx.Add(res2, x21); + res3 = Avx.Add(res3, x31); + + pSrcCurrent += 8; + pMatCurrent += 8; + } } else { - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. 
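+                        // Row 'misalignment' of LeadingAlignmentMask keeps the first 'misalignment' lanes and zeroes the
+                        // rest, so this unaligned pass only contributes the floats that precede the next 32-byte boundary
+                        // (for example, an address 12 bytes past a boundary gives misalignment = 8 - 12 / 4 = 5 floats);
+                        // the lanes zeroed here are re-read by the first iteration of the main loop below, which now
+                        // starts at a 32-byte boundary.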
+ float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); + + res0 = Avx.Multiply(x01, vector); + res1 = Avx.Multiply(x11, vector); + res2 = Avx.Multiply(x21, vector); + res3 = Avx.Multiply(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } - int remainder = 0; - if ((misalignment & 3) != 0) + if (length > 7) { - // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent + 8 <= pSrcEnd) + remainder = length % 8; + while (pSrcCurrent < pSrcEnd) { - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.LoadVector256(pMatTemp); - Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); Vector256 vector = Avx.LoadVector256(pSrcCurrent); - res0 = Avx.Add(res0, Avx.Multiply(x01, vector)); - res1 = Avx.Add(res1, Avx.Multiply(x11, vector)); - res2 = Avx.Add(res2, Avx.Multiply(x21, vector)); - res3 = Avx.Add(res3, Avx.Multiply(x31, vector)); + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + + res0 = Avx.Add(res0, x01); + res1 = Avx.Add(res1, x11); + res2 = Avx.Add(res2, x21); + res3 = Avx.Add(res3, x31); pSrcCurrent += 8; pMatCurrent += 8; @@ -210,94 +251,30 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, } else { - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.LoadVector256(pMatTemp); - Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - - Vector256 tempX01 = Avx.And(x01, mask); - Vector256 tempX11 = Avx.And(x11, mask); - Vector256 tempX21 = Avx.And(x21, mask); - Vector256 tempX31 = Avx.And(x31, mask); - - Vector256 tempVec = Avx.And(vector, mask); - - res0 = Avx.Multiply(tempX01, tempVec); - res1 = Avx.Multiply(tempX11, tempVec); - res2 = Avx.Multiply(tempX21, tempVec); - res3 = Avx.Multiply(tempX31, tempVec); - - pMatCurrent += misalignment; - pSrcCurrent += misalignment; - length -= misalignment; - } - - if (length > 7) - { - remainder = length % 8; - while (pSrcCurrent + 8 <= pSrcEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.LoadVector256(pMatTemp); - Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 vector = Avx.LoadVector256(pSrcCurrent); - - res0 = Avx.Add(res0, Avx.Multiply(x01, vector)); - res1 = 
Avx.Add(res1, Avx.Multiply(x11, vector)); - res2 = Avx.Add(res2, Avx.Multiply(x21, vector)); - res3 = Avx.Add(res3, Avx.Multiply(x31, vector)); - - pSrcCurrent += 8; - pMatCurrent += 8; - } - } - else - { - remainder = length; - } - - if (remainder != 0) - { - pMatCurrent -= (8 - remainder); - pSrcCurrent -= (8 - remainder); - - float* pMatTemp = pMatCurrent; - Vector256 x01 = Avx.LoadVector256(pMatTemp); - Vector256 x11 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadVector256(pMatTemp += ccol); - Vector256 vector = Avx.LoadVector256(pSrcCurrent); + remainder = length; + } - Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pSrcCurrent -= (8 - remainder); - Vector256 tempX01 = Avx.And(x01, mask); - Vector256 tempX11 = Avx.And(x11, mask); - Vector256 tempX21 = Avx.And(x21, mask); - Vector256 tempX31 = Avx.And(x31, mask); + Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - Vector256 tempVec = Avx.And(vector, mask); + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); - res0 = Avx.Add(res0, Avx.Multiply(tempVec, tempX01)); - res1 = Avx.Add(res1, Avx.Multiply(tempVec, tempX11)); - res2 = Avx.Add(res2, Avx.Multiply(tempVec, tempX21)); - res3 = Avx.Add(res3, Avx.Multiply(tempVec, tempX31)); + res0 = Avx.Add(res0, Avx.Multiply(x01, vector)); + res1 = Avx.Add(res1, Avx.Multiply(x11, vector)); + res2 = Avx.Add(res2, Avx.Multiply(x21, vector)); + res3 = Avx.Add(res3, Avx.Multiply(x31, vector)); - pMatCurrent += 8; - pSrcCurrent += 8; - } + pMatCurrent += 8; + pSrcCurrent += 8; } } @@ -384,9 +361,14 @@ public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray s Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) + MatMulTranX(add, mat.Items, src.Items, dst.Items, crow, ccol); + } + + public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { @@ -414,55 +396,86 @@ public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray s int length = crow; float* pDstCurrent = pdst; - if (crow < 8) - { - float* pMatTemp = pMatCurrent; + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); - Vector256 x02 = Avx.LoadVector256(pMatTemp); - Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, 
Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); + if (add || !firstTime) + { + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + } - if (add || !firstTime) - { - x02 = Avx.Add(x02, x3); + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; } - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; } else { - nuint address = (nuint)(pMatCurrent); - int misalignment = (int)(address % 32); + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + + if (add || !firstTime) + { + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + } - if ((misalignment & 3) != 0) + Avx.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 7) { + remainder = length % 8; while (pDstCurrent < pDstEnd) { float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.LoadVector256(pMatTemp); - Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); x02 = Avx.Add(x02, x12); x22 = Avx.Add(x22, x32); @@ -470,7 +483,7 @@ public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray s if (add || !firstTime) { - x02 = Avx.Add(x02, x3); + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); } Avx.Store(pDstCurrent, x02); @@ -480,129 +493,42 @@ public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray s } else { - int remainder = 0; - if (misalignment != 0) - { - // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then - // 
masking any elements that will be included in the first aligned read - misalignment >>= 2; - misalignment = 8 - misalignment; - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.LoadVector256(pMatTemp); - Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - - Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); - - x02 = Avx.And(x02, leadingMask); - x12 = Avx.And(x12, leadingMask); - x22 = Avx.And(x22, leadingMask); - x32 = Avx.And(x32, leadingMask); - - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - - if (add || !firstTime) - { - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - } + remainder = length; + } - Avx.Store(pDstCurrent, x02); - pMatCurrent += misalignment; - pDstCurrent += misalignment; - length -= misalignment; - } - if (length > 7) - { - remainder = length % 8; - while (pDstCurrent + 8 <= pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.LoadVector256(pMatTemp); - Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - if (add || !firstTime) - { - x02 = Avx.Add(x02, x3); - } - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } - } - else + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + + if (add || !firstTime) { - remainder = length; + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); } - if (remainder != 0) - { - pMatCurrent -= (8 - remainder); - pDstCurrent -= (8 - remainder); - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.LoadVector256(pMatTemp); - Vector256 x12 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadVector256(pMatTemp += crow); - - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); - Vector256 leadingMask = 
Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); - - x02 = Avx.And(x02, trailingMask); - x12 = Avx.And(x12, trailingMask); - x22 = Avx.And(x22, trailingMask); - x32 = Avx.And(x32, trailingMask); - - Vector256 x3 = Avx.LoadVector256(pDstCurrent); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - - x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - - if (add || !firstTime) - { - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - } - - Avx.Store(pDstCurrent, x02); - pDstCurrent += 8; - pMatCurrent += 8; - } + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; } } @@ -733,7 +659,7 @@ public static unsafe void Scale(float scale, Span dst) if (length < 8) { - switch(length) + switch (length) { case 7: dst[6] *= scale; goto case 6; case 6: dst[5] *= scale; goto case 5; @@ -775,7 +701,7 @@ public static unsafe void Scale(float scale, Span dst) Vector256 result = Avx.LoadVector256(pDstCurrent); Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (( 8 - misalignment) * 8)); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); Vector256 temp = Avx.And(result, leadingMask); result = Avx.And(result, trailingMask); diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index bf04985143..b4ad17b6a1 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -139,9 +139,14 @@ public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, A Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) + MatMul(add, mat.Items, src.Items, dst.Items, crow, ccol); + } + + public static unsafe void MatMul(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { @@ -167,19 +172,20 @@ public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, A if ((misalignment & 3) != 0) { // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent < pSrcEnd) { - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.LoadVector128(pMatTemp); - Vector128 x11 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 x21 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 x31 = Sse.LoadVector128(pMatTemp += ccol); Vector128 vector = Sse.LoadVector128(pSrcCurrent); - res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); - res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); - res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); - res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += 
ccol)); + + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -194,25 +200,20 @@ public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, A misalignment >>= 2; misalignment = 4 - misalignment; - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.LoadVector128(pMatTemp); - Vector128 x11 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 x21 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 x31 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - Vector128 tempX01 = Sse.And(x01, mask); - Vector128 tempX11 = Sse.And(x11, mask); - Vector128 tempX21 = Sse.And(x21, mask); - Vector128 tempX31 = Sse.And(x31, mask); - Vector128 tempVec = Sse.And(vector, mask); + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - res0 = Sse.Multiply(tempX01, tempVec); - res1 = Sse.Multiply(tempX11, tempVec); - res2 = Sse.Multiply(tempX21, tempVec); - res3 = Sse.Multiply(tempX31, tempVec); + res0 = Sse.Multiply(x01, vector); + res1 = Sse.Multiply(x11, vector); + res2 = Sse.Multiply(x21, vector); + res3 = Sse.Multiply(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -222,19 +223,20 @@ public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, A if (length > 4) { remainder = length % 4; - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent < pSrcEnd) { - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol); Vector128 vector = Sse.LoadVector128(pSrcCurrent); - res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); - res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); - res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); - res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); pSrcCurrent += 4; pMatCurrent += 4; @@ -250,25 +252,19 @@ public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, A pMatCurrent -= (4 - remainder); pSrcCurrent -= (4 - remainder); - float* pMatTemp = pMatCurrent; - Vector128 x01 = Sse.LoadVector128(pMatTemp); - Vector128 x11 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 x21 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 x31 = Sse.LoadVector128(pMatTemp += ccol); - Vector128 vector = Sse.LoadVector128(pSrcCurrent); - Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - Vector128 tempX01 = Sse.And(x01, mask); 
- Vector128 tempX11 = Sse.And(x11, mask); - Vector128 tempX21 = Sse.And(x21, mask); - Vector128 tempX31 = Sse.And(x31, mask); - Vector128 tempVec = Sse.And(vector, mask); + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); - res0 = Sse.Add(res0, Sse.Multiply(tempVec, tempX01)); - res1 = Sse.Add(res1, Sse.Multiply(tempVec, tempX11)); - res2 = Sse.Add(res2, Sse.Multiply(tempVec, tempX21)); - res3 = Sse.Add(res3, Sse.Multiply(tempVec, tempX31)); + res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); + res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); + res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); pMatCurrent += 4; pSrcCurrent += 4; @@ -354,10 +350,14 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); + MatMulTran(add, mat.Items, src.Items, dst.Items, crow, ccol); + } - fixed (float* psrc = &src.Items[0]) - fixed (float* pdst = &dst.Items[0]) - fixed (float* pmat = &mat.Items[0]) + public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { @@ -370,7 +370,7 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr // We do 4-way unrolling while (pSrcCurrent < pSrcEnd) { - Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); + Vector128 x01 = Sse.LoadVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
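                    // The shuffle immediate selects one source lane per 2-bit field, so 0x00, 0x55, 0xAA and 0xFF
                    // broadcast lanes 0 (A), 1 (B), 2 (C) and 3 (D) respectively, e.g. 0x55 = 0b01_01_01_01.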
Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C @@ -388,17 +388,10 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr while (pDstCurrent < pDstEnd) { float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.LoadVector128(pMatTemp); - Vector128 x12 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); x02 = Sse.Add(x02, x12); x22 = Sse.Add(x22, x32); @@ -406,7 +399,7 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr if (add || !firstTime) { - x02 = Sse.Add(x02, x3); + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); } Sse.Store(pDstCurrent, x02); @@ -423,21 +416,15 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr // masking any elements that will be included in the first aligned read misalignment >>= 2; misalignment = 4 - misalignment; - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.LoadVector128(pMatTemp); - Vector128 x12 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); - x02 = Sse.And(x02, leadingMask); - x12 = Sse.And(x12, leadingMask); - x22 = Sse.And(x22, leadingMask); - x32 = Sse.And(x32, leadingMask); + // We only align pMat since it has significantly more reads. 
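+                        // Unlike MatMul, this path stores a full 4-wide result back to pDstCurrent while only the
+                        // leading `misalignment` lanes belong to this pass; the Or with (dst & trailingMask) further
+                        // down keeps the existing contents of the lanes owned by the following aligned iterations,
+                        // which later columns still accumulate into.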
+ float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); x02 = Sse.Multiply(x01, x02); x12 = Sse.Multiply(x11, x12); @@ -448,6 +435,8 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr x22 = Sse.Add(x22, x32); x02 = Sse.Add(x02, x22); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); if (add || !firstTime) @@ -467,16 +456,10 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr { float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); x02 = Sse.Add(x02, x12); x22 = Sse.Add(x22, x32); @@ -484,7 +467,7 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr if (add || !firstTime) { - x02 = Sse.Add(x02, x3); + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); } Sse.Store(pDstCurrent, x02); @@ -501,22 +484,13 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr { pMatCurrent -= (4 - remainder); pDstCurrent -= (4 - remainder); - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.LoadVector128(pMatTemp); - Vector128 x12 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadVector128(pMatTemp += crow); - Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); - Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); - - x02 = Sse.And(x02, trailingMask); - x12 = Sse.And(x12, trailingMask); - x22 = Sse.And(x22, trailingMask); - x32 = Sse.And(x32, trailingMask); - Vector128 x3 = Sse.LoadVector128(pDstCurrent); + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); x02 = Sse.Multiply(x01, x02); x12 = Sse.Multiply(x11, x12); @@ -527,6 +501,8 @@ public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray sr x22 = Sse.Add(x22, x32); x02 = Sse.Add(x02, x22); + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); if (add || !firstTime) diff --git a/src/Native/CpuMathNative/Sse.cpp 
b/src/Native/CpuMathNative/Sse.cpp index 72b89fc765..830e2c8a45 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -145,21 +145,16 @@ EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * ps if ((misalignment & 3) != 0) { - while (pSrcCurrent + 4 <= pSrcEnd) + while (pSrcCurrent < pSrcEnd) { - remainder = ccol % 4; - - __m128 x01 = _mm_loadu_ps(pMatCurrent); - __m128 x11 = _mm_loadu_ps(pMatCurrent + ccol); - __m128 x21 = _mm_loadu_ps(pMatCurrent + 2 * ccol); - __m128 x31 = _mm_loadu_ps(pMatCurrent + 3 * ccol); __m128 vector = _mm_loadu_ps(pSrcCurrent); - x01 = _mm_mul_ps(x01, vector); - x11 = _mm_mul_ps(x11, vector); - x21 = _mm_mul_ps(x21, vector); - x31 = _mm_mul_ps(x31, vector); - + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + res0 = _mm_add_ps(res0, x01); res1 = _mm_add_ps(res1, x11); res2 = _mm_add_ps(res2, x21); @@ -176,25 +171,20 @@ EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * ps misalignment >>= 2; misalignment = 4 - misalignment; - __m128 x01 = _mm_loadu_ps(pMatCurrent); - __m128 x11 = _mm_loadu_ps(pMatCurrent + ccol); - __m128 x21 = _mm_loadu_ps(pMatCurrent + 2 * ccol); - __m128 x31 = _mm_loadu_ps(pMatCurrent + 3 * ccol); - __m128 vector = _mm_loadu_ps(pSrcCurrent); - __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - - __m128 tempX01 = _mm_and_ps(x01, mask); - __m128 tempX11 = _mm_and_ps(x11, mask); - __m128 tempX21 = _mm_and_ps(x21, mask); - __m128 tempX31 = _mm_and_ps(x31, mask); - - __m128 tempVec = _mm_and_ps(vector, mask); - res0 = _mm_mul_ps(tempX01, tempVec); - res1 = _mm_mul_ps(tempX11, tempVec); - res2 = _mm_mul_ps(tempX21, tempVec); - res3 = _mm_mul_ps(tempX31, tempVec); + // We only align pMat since it has significantly more reads. 
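+                // After this masked leading pass pMatCurrent lands on a 16-byte boundary, which is what lets the
+                // main loop below switch to the aligned _mm_load_ps; the remainder pass then backs the pointers up
+                // by (4 - remainder) and applies TrailingAlignmentMask so the overlapping lanes are not counted twice.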
+ const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + + res0 = _mm_mul_ps(x01, vector); + res1 = _mm_mul_ps(x11, vector); + res2 = _mm_mul_ps(x21, vector); + res3 = _mm_mul_ps(x31, vector); pMatCurrent += misalignment; pSrcCurrent += misalignment; @@ -204,19 +194,15 @@ EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * ps if (length > 3) { remainder = length % 4; - while(pSrcCurrent + 4 <= pSrcEnd) + while(pSrcCurrent < pSrcEnd) { - __m128 x01 = _mm_load_ps(pMatCurrent); - __m128 x11 = _mm_load_ps(pMatCurrent + ccol); - __m128 x21 = _mm_load_ps(pMatCurrent + 2 * ccol); - __m128 x31 = _mm_load_ps(pMatCurrent + 3 * ccol); - __m128 vector = _mm_loadu_ps(pSrcCurrent); - x01 = _mm_mul_ps(x01, vector); - x11 = _mm_mul_ps(x11, vector); - x21 = _mm_mul_ps(x21, vector); - x31 = _mm_mul_ps(x31, vector); + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); res0 = _mm_add_ps(res0, x01); res1 = _mm_add_ps(res1, x11); @@ -236,31 +222,20 @@ EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * ps { pMatCurrent -= (4 - remainder); pSrcCurrent -= (4 - remainder); - - __m128 x01 = _mm_loadu_ps(pMatCurrent); - __m128 x11 = _mm_loadu_ps(pMatCurrent + ccol); - __m128 x21 = _mm_loadu_ps(pMatCurrent + 2 * ccol); - __m128 x31 = _mm_loadu_ps(pMatCurrent + 3 * ccol); - __m128 vector = _mm_loadu_ps(pSrcCurrent); __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - __m128 tempX01 = _mm_and_ps(x01, mask); - __m128 tempX11 = _mm_and_ps(x11, mask); - __m128 tempX21 = _mm_and_ps(x21, mask); - __m128 tempX31 = _mm_and_ps(x31, mask); - - __m128 tempVec = _mm_and_ps(vector, mask); - - x01 = _mm_mul_ps(tempX01, tempVec); - x11 = _mm_mul_ps(tempX11, tempVec); - x21 = _mm_mul_ps(tempX21, tempVec); - x31 = _mm_mul_ps(tempX31, tempVec); - - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + + res0 = _mm_add_ps(x01, _mm_mul_ps(x01, vector)); + res1 = _mm_add_ps(x11, _mm_mul_ps(x11, vector)); + res2 = _mm_add_ps(x21, _mm_mul_ps(x21, vector)); + res3 = _mm_add_ps(x31, _mm_mul_ps(x31, vector)); pMatCurrent += 4; pSrcCurrent += 4; @@ -644,16 +619,11 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float { while (pDstCurrent < pDstEnd) { - __m128 x02 = _mm_loadu_ps(pMatCurrent); - __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); - __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); - __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = 
_mm_mul_ps(x31, x32); + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); x02 = _mm_add_ps(x02, x12); x22 = _mm_add_ps(x22, x32); @@ -661,10 +631,10 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float if (add || !firstTime) { - x02 = _mm_add_ps(x02, x3); + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); } - _mm_storeu_ps(pDstCurrent, x02); + _mm_storeu_ps(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; } @@ -676,21 +646,15 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float { misalignment >>= 2; misalignment = 4 - misalignment; + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - __m128 x02 = _mm_loadu_ps(pMatCurrent); - __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); - __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); - __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); - - __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); - __m128 mask2 = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (( 4 - misalignment) * 4)); - - x02 = _mm_and_ps(x02, mask); - x12 = _mm_and_ps(x12, mask); - x22 = _mm_and_ps(x22, mask); - x32 = _mm_and_ps(x32, mask); - - __m128 x3 = _mm_loadu_ps(pDstCurrent); + // We only align pMat since it has significantly more reads. + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); x02 = _mm_mul_ps(x01, x02); x12 = _mm_mul_ps(x11, x12); @@ -701,11 +665,13 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float x22 = _mm_add_ps(x22, x32); x02 = _mm_add_ps(x02, x22); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, mask2)); + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (( 4 - misalignment) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); if (add || !firstTime) { - x02 = _mm_add_ps(x02, _mm_and_ps(x3, mask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); } _mm_storeu_ps(pDstCurrent, x02); @@ -717,18 +683,13 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float if(length > 3) { remainder = length % 4; - while (pDstCurrent + 4 <= pDstEnd) + while (pDstCurrent < pDstEnd) { - __m128 x02 = _mm_loadu_ps(pMatCurrent); - __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); - __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); - __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); - __m128 x3 = _mm_loadu_ps(pDstCurrent); - - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); x02 = _mm_add_ps(x02, x12); x22 = _mm_add_ps(x22, x32); @@ -736,7 +697,7 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float if (add || !firstTime) { - x02 = _mm_add_ps(x02, x3); + 
x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); } _mm_storeu_ps(pDstCurrent, x02); @@ -753,21 +714,14 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float { pMatCurrent -= (4 - remainder); pDstCurrent -= (4 - remainder); - - __m128 x02 = _mm_loadu_ps(pMatCurrent); - __m128 x12 = _mm_loadu_ps(pMatCurrent + crow); - __m128 x22 = _mm_loadu_ps(pMatCurrent + 2 * crow); - __m128 x32 = _mm_loadu_ps(pMatCurrent + 3 * crow); - - __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); - __m128 mask2 = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); - x02 = _mm_and_ps(x02, mask); - x12 = _mm_and_ps(x12, mask); - x22 = _mm_and_ps(x22, mask); - x32 = _mm_and_ps(x32, mask); - - __m128 x3 = _mm_loadu_ps(pDstCurrent); + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); x02 = _mm_mul_ps(x01, x02); x12 = _mm_mul_ps(x11, x12); @@ -778,11 +732,13 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float x22 = _mm_add_ps(x22, x32); x02 = _mm_add_ps(x02, x22); - x02 = _mm_or_ps(x02, _mm_and_ps(x3, mask2)); + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); if (add || !firstTime) { - x02 = _mm_add_ps(x02, _mm_and_ps(x3, mask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); } _mm_storeu_ps(pDstCurrent, x02); diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs index 9e4fc83032..ec9f1a874e 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs @@ -11,26 +11,6 @@ namespace Microsoft.ML.CpuMath.PerformanceTests { public class AvxPerformanceTests : PerformanceTests { - private AlignedArray _testMatrices; - private AlignedArray _testSrcVectors; - private AlignedArray _testDstVectors; - - [GlobalSetup(Targets = new string[] { nameof(MatMulX), nameof(MatMulTranX) })] - public void MatMulSetup() - { - Setup(); - int vectorAlignment = CpuMathUtils.GetVectorAlignment(); - - _testMatrices = new AlignedArray(1000 * 1000, vectorAlignment); - _testMatrices.CopyFrom(src, 0, 1000 * 1000); - - _testSrcVectors = new AlignedArray(1000, vectorAlignment); - _testSrcVectors.CopyFrom(src, 0, 1000); - - _testDstVectors = new AlignedArray(1000, vectorAlignment); - _testDstVectors.CopyFrom(dst, 0, 1000); - } - [Benchmark] public void AddScalarU() => AvxIntrinsics.AddScalarU(DefaultScale, new Span(dst, 0, Length)); @@ -122,10 +102,10 @@ public void SdcaL1UpdateSU() [Benchmark] public void MatMulX() - => AvxIntrinsics.MatMulX(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); + => AvxIntrinsics.MatMulX(true, src, src1, dst, 1000, 1000); [Benchmark] public void MatMulTranX() - => AvxIntrinsics.MatMulTranX(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); + => AvxIntrinsics.MatMulTranX(true, src, src1, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs 
b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs index 9a9978851e..1d6c9f7c40 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs @@ -11,26 +11,6 @@ namespace Microsoft.ML.CpuMath.PerformanceTests { public class SsePerformanceTests : PerformanceTests { - private AlignedArray _testMatrices; - private AlignedArray _testSrcVectors; - private AlignedArray _testDstVectors; - - [GlobalSetup(Targets = new string[] { nameof(MatMulX), nameof(MatMulTranX) })] - public void MatMulSetup() - { - Setup(); - int vectorAlignment = CpuMathUtils.GetVectorAlignment(); - - _testMatrices = new AlignedArray(1000 * 1000, vectorAlignment); - _testMatrices.CopyFrom(src, 0, 1000 * 1000); - - _testSrcVectors = new AlignedArray(1000, vectorAlignment); - _testSrcVectors.CopyFrom(src, 0, 1000); - - _testDstVectors = new AlignedArray(1000, vectorAlignment); - _testDstVectors.CopyFrom(dst, 0, 1000); - } - [Benchmark] public void AddScalarU() => SseIntrinsics.AddScalarU(DefaultScale, new Span(dst, 0, Length)); @@ -122,10 +102,10 @@ public void SdcaL1UpdateSU() [Benchmark] public void MatMulX() - => SseIntrinsics.MatMul(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); + => SseIntrinsics.MatMul(true, src, src1, dst, 1000, 1000); [Benchmark] public void MatMulTranX() - => SseIntrinsics.MatMulTran(true, _testMatrices, _testMatrices, _testDstVectors, 1000, 1000); + => SseIntrinsics.MatMulTran(true, src, src1, dst, 1000, 1000); } } From 179310974b88b3ba1fd2bcefc9fb45872bbb10f4 Mon Sep 17 00:00:00 2001 From: Anipik Date: Sat, 13 Oct 2018 14:11:00 -0700 Subject: [PATCH 4/7] add flag removed --- src/Microsoft.ML.CpuMath/Avx.cs | 10 +-- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 82 ++++++++----------- .../CpuAligenedMathUtils.cs | 20 ++--- .../CpuMathUtils.netcoreapp.cs | 59 ++++--------- .../CpuMathUtils.netstandard.cs | 6 +- src/Microsoft.ML.CpuMath/Sse.cs | 15 ++-- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 69 +++++++--------- src/Microsoft.ML.CpuMath/Thunk.cs | 16 ++-- ...AdaptiveSingularSpectrumSequenceModeler.cs | 10 +-- src/Microsoft.ML.Transforms/RffTransform.cs | 4 +- src/Native/CpuMathNative/Sse.cpp | 40 ++++----- .../AvxPerformanceTests.cs | 4 +- .../NativePerformanceTests.cs | 4 +- .../SsePerformanceTests.cs | 4 +- .../UnitTests.cs | 74 +---------------- 15 files changed, 140 insertions(+), 277 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/Avx.cs b/src/Microsoft.ML.CpuMath/Avx.cs index 5d4610d9bc..2736ad9137 100644 --- a/src/Microsoft.ML.CpuMath/Avx.cs +++ b/src/Microsoft.ML.CpuMath/Avx.cs @@ -34,7 +34,7 @@ public static bool CheckAvx() return Thunk.ChkAvx(); } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); Contracts.Assert(Compat(src)); @@ -50,12 +50,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulX(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + Thunk.MatMulX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); } else { Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTranX(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + Thunk.MatMulTranX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), 
dst.Size, crun); } } } @@ -88,12 +88,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPX(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); + Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } else { Contracts.Assert(0 <= crun && crun <= srcValues.Size); - Thunk.MatMulTranPX(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); + Thunk.MatMulTranPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); } } } diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 760cd718db..0183d47d8e 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -141,15 +141,15 @@ private static Vector256 GetNewDst256(in Vector256 xDst1, in Vecto } // Multiply matrix times vector into vector. - public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - MatMulX(add, mat.Items, src.Items, dst.Items, crow, ccol); + MatMulX(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol) { fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) @@ -165,9 +165,9 @@ public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] ds while (pDstCurrent < pDstEnd) { Vector256 res0 = Avx.SetZeroVector256(); - Vector256 res1 = res0; - Vector256 res2 = res0; - Vector256 res3 = res0; + Vector256 res1 = Avx.SetZeroVector256(); + Vector256 res2 = Avx.SetZeroVector256(); + Vector256 res3 = Avx.SetZeroVector256(); int length = ccol; float* pSrcCurrent = psrc; @@ -209,7 +209,7 @@ public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] ds Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - // We only align pMat since it has significantly more reads. + // We only align pMat since it has significantly more reads. float* pMatTemp = pMatCurrent; Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); @@ -230,7 +230,7 @@ public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] ds if (length > 7) { remainder = length % 8; - while (pSrcCurrent < pSrcEnd) + while (pSrcCurrent + 8 <= pSrcEnd) { Vector256 vector = Avx.LoadVector256(pSrcCurrent); @@ -284,10 +284,6 @@ public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] ds res0 = Avx.HorizontalAdd(res0, res2); Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); - if (add) - { - sum = Sse.Add(sum, Sse.LoadVector128(pDstCurrent)); - } Sse.Store(pDstCurrent, sum); pDstCurrent += 4; @@ -297,7 +293,7 @@ public static unsafe void MatMulX(bool add, float[] mat, float[] src, float[] ds } // Partial sparse source vector. 
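        // Here "partial sparse" means the source is given as index/value pairs: rgposSrc[iposMin..iposEnd) holds
        // the active column positions (offset by posMin) and src supplies their values, so each output row only
        // accumulates mat[row, col] * src[col] over those columns.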
- public static unsafe void MatMulPX(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, + public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { Contracts.Assert(HasCompatibleAlignment(mat)); @@ -344,27 +340,22 @@ public static unsafe void MatMulPX(bool add, AlignedArray mat, int[] rgposSrc, A ppos++; } - if (add) - { - result = Avx.Add(result, Avx.LoadAlignedVector256(pDstCurrent)); - } Avx.StoreAligned(pDstCurrent, result); - pDstCurrent += 8; pm0 += 8 * ccol; } } } - public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - MatMulTranX(add, mat.Items, src.Items, dst.Items, crow, ccol); + MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol) { fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) @@ -383,10 +374,10 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ { Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. - Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B - Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C - Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D - h01 = Sse.Shuffle(h01, h01, 0x00); // A + Vector128 h11 = Avx.Permute(h01, 0x55); // B + Vector128 h21 = Avx.Permute(h01, 0xAA); // C + Vector128 h31 = Avx.Permute(h01, 0xFF); // D + h01 = Avx.Permute(h01, 0x00); // A Vector256 x01 = Avx.SetHighLow(h01, h01); Vector256 x11 = Avx.SetHighLow(h11, h11); @@ -413,7 +404,7 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ x22 = Avx.Add(x22, x32); x02 = Avx.Add(x02, x22); - if (add || !firstTime) + if (!firstTime) { x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); } @@ -455,7 +446,7 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ Vector256 x3 = Avx.LoadVector256(pDstCurrent); x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - if (add || !firstTime) + if (!firstTime) { x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); } @@ -468,7 +459,7 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ if (length > 7) { remainder = length % 8; - while (pDstCurrent < pDstEnd) + while (pDstCurrent + 8 <= pDstEnd) { float* pMatTemp = pMatCurrent; @@ -481,7 +472,7 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ x22 = Avx.Add(x22, x32); x02 = Avx.Add(x02, x22); - if (add || !firstTime) + if (!firstTime) { x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); } @@ -521,7 +512,7 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ Vector256 x3 = Avx.LoadVector256(pDstCurrent); x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - if (add || !firstTime) + if (!firstTime) { x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); } @@ -540,7 +531,7 @@ public static unsafe void MatMulTranX(bool add, float[] mat, float[] src, float[ } // Partial sparse source vector. 
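        // With the add flag gone, the first active column now always initializes dst (the old `if (!add)` guard
        // becomes unconditional) and the loop over the remaining positions accumulates on top of it.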
- public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, + public static unsafe void MatMulTranPX(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow) { Contracts.Assert(HasCompatibleAlignment(mat)); @@ -560,26 +551,23 @@ public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSr int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; - if (!add) - { - int col = *ppos - posMin; - ppos++; + int col = *ppos - posMin; + ppos++; - Vector256 x0 = Avx.SetAllVector256(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; + Vector256 x0 = Avx.SetAllVector256(psrc[col]); + float* pDstCurrent = pdst; + float* pMatCurrent = pmat + col * crow; - while (pDstCurrent < pDstEnd) - { - Vector256 x1 = Avx.LoadAlignedVector256(pMatCurrent); - x1 = Avx.Multiply(x1, x0); - Avx.StoreAligned(pDstCurrent, x1); + while (pDstCurrent < pDstEnd) + { + Vector256 x1 = Avx.LoadAlignedVector256(pMatCurrent); + x1 = Avx.Multiply(x1, x0); + Avx.StoreAligned(pDstCurrent, x1); - pDstCurrent += 8; - pMatCurrent += 8; - } + pDstCurrent += 8; + pMatCurrent += 8; } - + // REVIEW: Should we explore unrolling the outer loop? while (ppos < pposEnd) { diff --git a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs index 30308f219d..bcc92234d0 100644 --- a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs +++ b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs @@ -75,40 +75,32 @@ public static void AssertCompatible(ICpuFullMatrix mat, ICpuVector src, ICpuVect /// /// Matrix multiplication: - /// if (add) - /// dst = mat * src - /// else - /// dest += mat * src + /// dst = mat * src /// - /// The addition flag /// The multiplier matrix /// The source vector /// The destination vector - public static void MatTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) + public static void MatTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) { bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol); AssertCompatible(mat, src, dst); var m = A(mat); - CpuMathUtils.MatTimesSrc(colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt); + CpuMathUtils.MatTimesSrc(colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt); } /// /// Matrix transpose multiplication: - /// if (add) - /// dst = mat' * src - /// else - /// dest += mat' * src + /// dst = mat' * src /// - /// The addition flag /// The multiplier matrix /// The source vector /// The destination vector - public static void MatTranTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) + public static void MatTranTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) { bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol); AssertCompatible(mat, dst, src); var m = A(mat); - CpuMathUtils.MatTimesSrc(!colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt); + CpuMathUtils.MatTimesSrc(!colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt); } } diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index 0171b68505..356d68f72b 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -24,7 +24,7 @@ public static partial class CpuMathUtils public static int GetVectorAlignment() => Avx.IsSupported ? Vector256Alignment : (Sse.IsSupported ? 
Vector128Alignment : FloatAlignment); - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) { Contracts.Assert(mat.Size == dst.Size * src.Size); Contracts.Assert(crun >= 0); @@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulX(add, mat, src, dst, crun, src.Size); + AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size); } else { Contracts.Assert(crun <= src.Size); - AvxIntrinsics.MatMulTranX(add, mat, src, dst, dst.Size, crun); + AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun); } } else if (Sse.IsSupported) @@ -47,12 +47,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMul(add, mat, src, dst, crun, src.Size); + SseIntrinsics.MatMul(mat, src, dst, crun, src.Size); } else { Contracts.Assert(crun <= src.Size); - SseIntrinsics.MatMulTran(add, mat, src, dst, dst.Size, crun); + SseIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun); } } else @@ -68,14 +68,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr dotProduct += mat[i * src.Size + j] * src[j]; } - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + dst[i] = dotProduct; } } else @@ -89,20 +82,13 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr dotProduct += mat[j * src.Size + i] * src[j]; } - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + dst[i] = dotProduct; } } } } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.AssertValue(rgposSrc); @@ -113,8 +99,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (iposMin >= iposLim) { - if (!add) - dst.ZeroItems(); + dst.ZeroItems(); return; } @@ -126,12 +111,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (!tran) { Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulPX(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); + AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else { Contracts.Assert(crun <= srcValues.Size); - AvxIntrinsics.MatMulTranPX(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); + AvxIntrinsics.MatMulTranPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); } } else if (Sse.IsSupported) @@ -139,12 +124,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (!tran) { Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMulPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); + SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else { Contracts.Assert(crun <= srcValues.Size); - SseIntrinsics.MatMulTranPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); + SseIntrinsics.MatMulTranPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); } } else @@ 
-161,14 +146,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo dotProduct += mat[i * srcValues.Size + col] * srcValues[col]; } - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + dst[i] = dotProduct; } } else @@ -183,14 +161,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo dotProduct += mat[col * dst.Size + i] * srcValues[col]; } - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + dst[i] = dotProduct; } } diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs index b35f171388..72d5210e20 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs @@ -15,10 +15,10 @@ public static partial class CpuMathUtils public static int GetVectorAlignment() => Vector128Alignment; - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, src, dst, crun); + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, mat, src, dst, crun); - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun); + public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun); public static void Add(float a, float[] dst, int count) => SseUtils.Add(a, dst, count); diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs index d5780a4fae..58f80318d0 100644 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ b/src/Microsoft.ML.CpuMath/Sse.cs @@ -27,7 +27,7 @@ private static bool Compat(AlignedArray a) return q; } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); Contracts.Assert(Compat(src)); @@ -43,18 +43,18 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMul(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); } else { Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTran(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); } } } } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); @@ -66,8 +66,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (iposMin >= iposLim) { - if (!add) - dst.ZeroItems(); + dst.ZeroItems(); return; } 
Contracts.AssertNonEmpty(rgposSrc); @@ -81,12 +80,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPA(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); + Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } else { Contracts.Assert(0 <= crun && crun <= srcValues.Size); - Thunk.MatMulTranPA(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); + Thunk.MatMulTranPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); } } } diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index b4ad17b6a1..0a3e57953f 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -134,15 +134,15 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect } // Multiply matrix times vector into vector. - public static unsafe void MatMul(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - MatMul(add, mat.Items, src.Items, dst.Items, crow, ccol); + MatMul(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMul(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMul(float[] mat, float[] src, float[] dst, int crow, int ccol) { fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) @@ -158,9 +158,9 @@ public static unsafe void MatMul(bool add, float[] mat, float[] src, float[] dst while (pDstCurrent < pDstEnd) { Vector128 res0 = Sse.SetZeroVector128(); - Vector128 res1 = res0; - Vector128 res2 = res0; - Vector128 res3 = res0; + Vector128 res1 = Sse.SetZeroVector128(); + Vector128 res2 = Sse.SetZeroVector128(); + Vector128 res3 = Sse.SetZeroVector128(); int length = ccol; float* pSrcCurrent = psrc; @@ -276,11 +276,6 @@ public static unsafe void MatMul(bool add, float[] mat, float[] src, float[] dst res2 = Sse3.HorizontalAdd(res2, res3); res0 = Sse3.HorizontalAdd(res0, res2); - if (add) - { - res0 = Sse.Add(res0, Sse.LoadVector128(pDstCurrent)); - } - Sse.Store(pDstCurrent, res0); pDstCurrent += 4; pMatCurrent += 3 * ccol; @@ -289,7 +284,7 @@ public static unsafe void MatMul(bool add, float[] mat, float[] src, float[] dst } // Partial sparse source vector. 
- public static unsafe void MatMulPA(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, + public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { Contracts.Assert(HasCompatibleAlignment(mat)); @@ -334,26 +329,21 @@ public static unsafe void MatMulPA(bool add, AlignedArray mat, int[] rgposSrc, A ppos++; } - if (add) - { - result = Sse.Add(result, Sse.LoadAlignedVector128(pDstCurrent)); - } Sse.StoreAligned(pDstCurrent, result); - pDstCurrent += 4; pm0 += 4 * ccol; } } } - public static unsafe void MatMulTran(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); - MatMulTran(add, mat.Items, src.Items, dst.Items, crow, ccol); + MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); } - public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] dst, int crow, int ccol) + public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int crow, int ccol) { fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) @@ -397,7 +387,7 @@ public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] x22 = Sse.Add(x22, x32); x02 = Sse.Add(x02, x22); - if (add || !firstTime) + if (!firstTime) { x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); } @@ -439,7 +429,7 @@ public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] Vector128 x3 = Sse.LoadVector128(pDstCurrent); x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); - if (add || !firstTime) + if (!firstTime) { x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); } @@ -465,7 +455,7 @@ public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] x22 = Sse.Add(x22, x32); x02 = Sse.Add(x02, x22); - if (add || !firstTime) + if (!firstTime) { x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); } @@ -505,7 +495,7 @@ public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] Vector128 x3 = Sse.LoadVector128(pDstCurrent); x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - if (add || !firstTime) + if (!firstTime) { x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); } @@ -524,7 +514,7 @@ public static unsafe void MatMulTran(bool add, float[] mat, float[] src, float[] } // Partial sparse source vector. 
- public static unsafe void MatMulTranPA(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, + public static unsafe void MatMulTranPA(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow) { Contracts.Assert(HasCompatibleAlignment(mat)); @@ -544,26 +534,23 @@ public static unsafe void MatMulTranPA(bool add, AlignedArray mat, int[] rgposSr int* pposEnd = pposSrc + iposEnd; float* pDstEnd = pdst + crow; - if (!add) - { - int col = *ppos - posMin; - ppos++; + int col = *ppos - posMin; + ppos++; - Vector128 x0 = Sse.SetAllVector128(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; + Vector128 x0 = Sse.SetAllVector128(psrc[col]); + float* pDstCurrent = pdst; + float* pMatCurrent = pmat + col * crow; - while (pDstCurrent < pDstEnd) - { - Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent); - x1 = Sse.Multiply(x1, x0); - Sse.StoreAligned(pDstCurrent, x1); + while (pDstCurrent < pDstEnd) + { + Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent); + x1 = Sse.Multiply(x1, x0); + Sse.StoreAligned(pDstCurrent, x1); - pDstCurrent += 4; - pMatCurrent += 4; - } + pDstCurrent += 4; + pMatCurrent += 4; } - + // REVIEW: Should we explore unrolling the outer loop? while (ppos < pposEnd) { diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 3291684873..897fd0901a 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -16,16 +16,16 @@ internal static unsafe class Thunk public static extern bool ChkAvx(); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMul(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMul(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulX(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMulX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulPA(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulPA(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulPX(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulPX(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] @@ -65,15 +65,15 @@ public static extern void RespNormU(bool add, float alpha, float beta, bool avgO // and columns from that perspective. Alternatively, crow is the number of rows in the transpose of pmat // (thought of as row-major order). 
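To make the convention in the comment above concrete before the extern declarations that follow: pmat, read row-major, holds ccol rows of length crow, and the transposed product fills all crow entries of pdst. An illustrative scalar equivalent, which is not part of the library and only restates that indexing:

    // dst = transpose(mat) * src under the convention described above:
    // mat holds ccol rows of length crow, so element (row j, column i) is mat[j * crow + i].
    public static void MatMulTranReference(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        for (int i = 0; i < crow; i++)
        {
            float sum = 0;
            for (int j = 0; j < ccol; j++)
                sum += mat[j * crow + i] * src[j];

            dst[i] = sum;
        }
    }
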
[DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTran(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMulTran(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranX(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMulTranX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranPA(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulTranPA(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranPX(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulTranPX(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index 7f0fe543f6..5c01472ed9 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -1130,10 +1130,10 @@ public void Consume(ref Single input, bool updateModel = false) _x[_windowSize - 1] = input; // Computing y: Eq. 
(11) in https://hal-institut-mines-telecom.archives-ouvertes.fr/hal-00479772/file/twocolumns.pdf - CpuAligenedMathUtils.MatTimesSrc(false, _wTrans, _x, _y); + CpuAligenedMathUtils.MatTimesSrc(_wTrans, _x, _y); // Updating the state vector - CpuAligenedMathUtils.MatTranTimesSrc(false, _wTrans, _y, _xSmooth); + CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); _nextPrediction = _autoregressionNoiseMean + _observationNoiseMean; for (i = 0; i < _windowSize - 2; ++i) @@ -1337,7 +1337,7 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) nu += _y[i] * _y[i]; } - CpuAligenedMathUtils.MatTranTimesSrc(false, _wTrans, _y, _xSmooth); + CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); for (i = 0; i < _windowSize - 1; ++i) _alpha[i] = _xSmooth[i] / (1 - nu); @@ -1382,8 +1382,8 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) _x[i - originalSeriesLength + _windowSize] = dataArray[i]; } - CpuAligenedMathUtils.MatTimesSrc(false, _wTrans, _x, _y); - CpuAligenedMathUtils.MatTranTimesSrc(false, _wTrans, _y, _xSmooth); + CpuAligenedMathUtils.MatTimesSrc(_wTrans, _x, _y); + CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); for (i = 1; i < _windowSize; ++i) { diff --git a/src/Microsoft.ML.Transforms/RffTransform.cs b/src/Microsoft.ML.Transforms/RffTransform.cs index 218cc30b56..8cce5bdeeb 100644 --- a/src/Microsoft.ML.Transforms/RffTransform.cs +++ b/src/Microsoft.ML.Transforms/RffTransform.cs @@ -610,7 +610,7 @@ private void TransformFeatures(ref VBuffer src, ref VBuffer dst, T if (src.IsDense) { featuresAligned.CopyFrom(src.Values, 0, src.Length); - CpuMathUtils.MatTimesSrc(false, false, transformInfo.RndFourierVectors, featuresAligned, productAligned, + CpuMathUtils.MatTimesSrc(false, transformInfo.RndFourierVectors, featuresAligned, productAligned, transformInfo.NewDim); } else @@ -618,7 +618,7 @@ private void TransformFeatures(ref VBuffer src, ref VBuffer dst, T // This overload of MatTimesSrc ignores the values in slots that are not in src.Indices, so there is // no need to zero them out. featuresAligned.CopyFrom(src.Indices, src.Values, 0, 0, src.Count, zeroItems: false); - CpuMathUtils.MatTimesSrc(false, false, transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0, + CpuMathUtils.MatTimesSrc(false, transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0, src.Count, productAligned, transformInfo.NewDim); } diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 830e2c8a45..9ca29868fe 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -122,7 +122,7 @@ EXPORT_API(bool) ChkAvx() } // Multiply matrix times vector into vector. -EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { const float * pSrcEnd = psrc + ccol; const float * pDstEnd = pdst + crow; @@ -247,9 +247,6 @@ EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * ps res2 = _mm_hadd_ps(res2, res3); res0 = _mm_hadd_ps(res0, res2); - if (add) - res0 = _mm_add_ps(res0, _mm_loadu_ps(pDstCurrent)); - _mm_storeu_ps(pDstCurrent, res0); pDstCurrent += 4; @@ -258,7 +255,7 @@ EXPORT_API(void) MatMul(bool add, _In_ const float * pmat, _In_ const float * ps } // Partial sparse source vector. 
-EXPORT_API(void) MatMulPA(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, +EXPORT_API(void) MatMulPA(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow, int ccol) { // REVIEW: For extremely sparse inputs, interchanging the loops would @@ -283,8 +280,6 @@ EXPORT_API(void) MatMulPA(bool add, _In_ const float * pmat, _In_ const int * pp res = _mm_add_ps(res, x2); } - if (add) - res = _mm_add_ps(res, _mm_load_ps(pd)); _mm_store_ps(pd, res); } } @@ -590,7 +585,7 @@ EXPORT_API(void) RespNormU(bool add, float alpha, float beta, bool avgOverFullKe } } -EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { const float * pSrcEnd = psrc + ccol; const float * pDstEnd = pdst + crow; @@ -629,7 +624,7 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float x22 = _mm_add_ps(x22, x32); x02 = _mm_add_ps(x02, x22); - if (add || !firstTime) + if (!firstTime) { x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); } @@ -669,7 +664,7 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float __m128 x3 = _mm_loadu_ps(pDstCurrent); x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); - if (add || !firstTime) + if (!firstTime) { x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); } @@ -695,7 +690,7 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float x22 = _mm_add_ps(x22, x32); x02 = _mm_add_ps(x02, x22); - if (add || !firstTime) + if (!firstTime) { x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); } @@ -736,7 +731,7 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float __m128 x3 = _mm_loadu_ps(pDstCurrent); x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); - if (add || !firstTime) + if (!firstTime) { x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); } @@ -754,26 +749,23 @@ EXPORT_API(void) MatMulTran(bool add, _In_ const float * pmat, _In_ const float } // Partial sparse source vector. -EXPORT_API(void) MatMulTranPA(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, +EXPORT_API(void) MatMulTranPA(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow) { const int * ppos = pposSrc + iposMin; const int * pposLim = pposSrc + iposLim; const float * pdLim = pdst + crow; - if (!add) + int col = *ppos++ - posMin; + const float * pm = pmat + col * crow; + __m128 x0 = _mm_set1_ps(psrc[col]); + for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) { - int col = *ppos++ - posMin; - const float * pm = pmat + col * crow; - __m128 x0 = _mm_set1_ps(psrc[col]); - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) - { - __m128 x1 = _mm_load_ps(pm); - x1 = _mm_mul_ps(x1, x0); - _mm_store_ps(pd, x1); - } + __m128 x1 = _mm_load_ps(pm); + x1 = _mm_mul_ps(x1, x0); + _mm_store_ps(pd, x1); } - + // REVIEW: Should we explore unrolling the outer loop? 
for (; ppos < pposLim; ppos++) { diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs index ec9f1a874e..5400985319 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs @@ -102,10 +102,10 @@ public void SdcaL1UpdateSU() [Benchmark] public void MatMulX() - => AvxIntrinsics.MatMulX(true, src, src1, dst, 1000, 1000); + => AvxIntrinsics.MatMulX(src, src1, dst, 1000, 1000); [Benchmark] public void MatMulTranX() - => AvxIntrinsics.MatMulTranX(true, src, src1, dst, 1000, 1000); + => AvxIntrinsics.MatMulTranX(src, src1, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs index 91dcf32d40..1ee1c2bb53 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs @@ -235,7 +235,7 @@ public unsafe void MatMulX() fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) { - Thunk.MatMulX(true, psrc, psrc, pdst, 1000, 1000); + Thunk.MatMulX(psrc, psrc, pdst, 1000, 1000); } } @@ -245,7 +245,7 @@ public unsafe void MatMulTranX() fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) { - Thunk.MatMulTranX(true, psrc, psrc, pdst, 1000, 1000); + Thunk.MatMulTranX(psrc, psrc, pdst, 1000, 1000); } } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs index 1d6c9f7c40..f248baca4a 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs @@ -102,10 +102,10 @@ public void SdcaL1UpdateSU() [Benchmark] public void MatMulX() - => SseIntrinsics.MatMul(true, src, src1, dst, 1000, 1000); + => SseIntrinsics.MatMul(src, src1, dst, 1000, 1000); [Benchmark] public void MatMulTranX() - => SseIntrinsics.MatMulTran(true, src, src1, dst, 1000, 1000); + => SseIntrinsics.MatMulTran(src, src1, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index e9cffb3c50..4456469f2d 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -90,23 +90,7 @@ public void MatMulTest(int matTest, int srcTest, int dstTest, float[] expected) AlignedArray src = _testSrcVectors[srcTest]; AlignedArray dst = _testDstVectors[dstTest]; - CpuMathUtils.MatTimesSrc(false, false, mat, src, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { -416.6801f, -415.6801f, -414.6801f, -413.6801f, -412.6801f, -411.6801f, -410.6801f, -409.6801f })] - [InlineData(1, 1, 0, new float[] { 1496f, 3673f, 5850f, 8027f, 10204f, 12381f, 14558f, 16735f })] - [InlineData(1, 0, 1, new float[] { 204f, 493f, 782f, 1071f, 1360f, 1649f, 1938f, 2227f, 2516f, 2805f, 3094f, 3383f, 3672f, 3961f, 4250f, 4539f })] - public void MatMulAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - - 
CpuMathUtils.MatTimesSrc(false, true, mat, src, dst, dst.Size); + CpuMathUtils.MatTimesSrc(false, mat, src, dst, dst.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer); @@ -122,23 +106,7 @@ public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expect AlignedArray src = _testSrcVectors[srcTest]; AlignedArray dst = _testDstVectors[dstTest]; - CpuMathUtils.MatTimesSrc(true, false, mat, src, dst, src.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 70.56001f, -84.68f, -349.36f, 501.24f, -3825.32f, -964.48f, 1174.2f, 125.44f })] - [InlineData(1, 0, 1, new float[] { 2724f, 2761f, 2798f, 2835f, 2872f, 2909f, 2946f, 2983f, 3020f, 3057f, 3094f, 3131f, 3168f, 3205f, 3242f, 3279f })] - [InlineData(1, 1, 0, new float[] { 11016f, 11153f, 11290f, 11427f, 11564f, 11701f, 11838f, 11975f })] - public void MatMulTranAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - - CpuMathUtils.MatTimesSrc(true, true, mat, src, dst, src.Size); + CpuMathUtils.MatTimesSrc(true, mat, src, dst, src.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer); @@ -155,24 +123,7 @@ public void MatMulPATest(int matTest, int srcTest, int dstTest, float[] expected AlignedArray dst = _testDstVectors[dstTest]; int[] idx = _testIndexArray; - CpuMathUtils.MatTimesSrc(false, false, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 38.25002f, 39.25002f, 40.25002f, 41.25002f, 42.25002f, 43.25002f, 44.25002f, 45.25002f })] - [InlineData(1, 1, 0, new float[] { 910f, 2191f, 3472f, 4753f, 6034f, 7315f, 8596f, 9877f })] - [InlineData(1, 0, 1, new float[] { 95f, 232f, 369f, 506f, 643f, 780f, 917f, 1054f, 1191f, 1328f, 1465f, 1602f, 1739f, 1876f, 2013f, 2150f })] - public void MatMulPAAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - int[] idx = _testIndexArray; - - CpuMathUtils.MatTimesSrc(false, true, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); + CpuMathUtils.MatTimesSrc(false, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer); @@ -189,24 +140,7 @@ public void MatMulTranPATest(int matTest, int srcTest, int dstTest, float[] expe AlignedArray dst = _testDstVectors[dstTest]; int[] idx = _testIndexArray; - CpuMathUtils.MatTimesSrc(true, false, mat, idx, src, 0, 0, (srcTest == 0) ? 
4 : 9, dst, src.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 33.32f, -39.46f, -163.92f, 238.28f, -1804.29f, -452.81f, 557.65f, 62.93f })] - [InlineData(1, 0, 1, new float[] { 1265f, 1283f, 1301f, 1319f, 1337f, 1355f, 1373f, 1391f, 1409f, 1427f, 1445f, 1463f, 1481f, 1499f, 1517f, 1535f })] - [InlineData(1, 1, 0, new float[] { 6720f, 6801f, 6882f, 6963f, 7044f, 7125f, 7206f, 7287f })] - public void MatMulTranPAAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - int[] idx = _testIndexArray; - - CpuMathUtils.MatTimesSrc(true, true, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, src.Size); + CpuMathUtils.MatTimesSrc(true, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, src.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer); From 421e33d1b9155eca70d12f46fa392e271a2ee9ee Mon Sep 17 00:00:00 2001 From: Anipik Date: Sat, 13 Oct 2018 14:32:21 -0700 Subject: [PATCH 5/7] TransPA removed as nobody uses this combination of flags --- src/Microsoft.ML.CpuMath/Avx.cs | 17 ++--- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 64 ------------------- .../CpuMathUtils.netcoreapp.cs | 60 ++++------------- .../CpuMathUtils.netstandard.cs | 4 +- src/Microsoft.ML.CpuMath/Sse.cs | 14 +--- src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 64 ------------------- src/Microsoft.ML.CpuMath/Thunk.cs | 7 -- src/Microsoft.ML.Transforms/RffTransform.cs | 2 +- src/Native/CpuMathNative/Avx.cpp | 40 ------------ src/Native/CpuMathNative/Sse.cpp | 35 ---------- .../UnitTests.cs | 19 +----- 11 files changed, 23 insertions(+), 303 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/Avx.cs b/src/Microsoft.ML.CpuMath/Avx.cs index 2736ad9137..f7769b3295 100644 --- a/src/Microsoft.ML.CpuMath/Avx.cs +++ b/src/Microsoft.ML.CpuMath/Avx.cs @@ -61,7 +61,7 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al } } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); @@ -73,8 +73,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (iposMin >= iposLim) { - if (!add) - dst.ZeroItems(); + dst.ZeroItems(); return; } Contracts.AssertNonEmpty(rgposSrc); @@ -85,16 +84,8 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo fixed (float* psrc = &srcValues.Items[0]) fixed (int* ppossrc = &rgposSrc[0]) { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= srcValues.Size); - Thunk.MatMulTranPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); - } + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } } } diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs 
b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 0183d47d8e..edddfdfb1e 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -530,70 +530,6 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int } } - // Partial sparse source vector. - public static unsafe void MatMulTranPX(AlignedArray mat, int[] rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow) - { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); - - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - fixed (int* pposSrc = &rgposSrc[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - - int* ppos = pposSrc + iposMin; - int* pposEnd = pposSrc + iposEnd; - float* pDstEnd = pdst + crow; - - int col = *ppos - posMin; - ppos++; - - Vector256 x0 = Avx.SetAllVector256(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; - - while (pDstCurrent < pDstEnd) - { - Vector256 x1 = Avx.LoadAlignedVector256(pMatCurrent); - x1 = Avx.Multiply(x1, x0); - Avx.StoreAligned(pDstCurrent, x1); - - pDstCurrent += 8; - pMatCurrent += 8; - } - - // REVIEW: Should we explore unrolling the outer loop? - while (ppos < pposEnd) - { - int col = *ppos - posMin; - - Vector256 x0 = Avx.SetAllVector256(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; - - while (pDstCurrent < pDstEnd) - { - Vector256 x1 = Avx.LoadAlignedVector256(pMatCurrent); - Vector256 x2 = Avx.LoadAlignedVector256(pDstCurrent); - x1 = Avx.Multiply(x1, x0); - x2 = Avx.Add(x2, x1); - Avx.StoreAligned(pDstCurrent, x2); - - pDstCurrent += 8; - pMatCurrent += 8; - } - - ppos++; - } - } - } - // dst[i] += scale public static unsafe void AddScalarU(float scalar, Span dst) { diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index 356d68f72b..64b27f8965 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -88,7 +88,7 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al } } - public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.AssertValue(rgposSrc); @@ -108,62 +108,26 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, Alig if (Avx.IsSupported) { - if (!tran) - { - Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); - } - else - { - Contracts.Assert(crun <= srcValues.Size); - AvxIntrinsics.MatMulTranPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); - } + Contracts.Assert(crun <= dst.Size); + AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else if (Sse.IsSupported) { - if (!tran) - { - Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); - } - else - { - Contracts.Assert(crun <= srcValues.Size); 
- SseIntrinsics.MatMulTranPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); - } + Contracts.Assert(crun <= dst.Size); + SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else { - if (!tran) - { - Contracts.Assert(crun <= dst.Size); - for (int i = 0; i < crun; i++) - { - float dotProduct = 0; - for (int j = iposMin; j < iposLim; j++) - { - int col = rgposSrc[j] - posMin; - dotProduct += mat[i * srcValues.Size + col] * srcValues[col]; - } - - dst[i] = dotProduct; - } - } - else + Contracts.Assert(crun <= dst.Size); + for (int i = 0; i < crun; i++) { - Contracts.Assert(crun <= srcValues.Size); - for (int i = 0; i < dst.Size; i++) + float dotProduct = 0; + for (int j = iposMin; j < iposLim; j++) { - float dotProduct = 0; - for (int j = iposMin; j < iposLim; j++) - { - int col = rgposSrc[j] - posMin; - dotProduct += mat[col * dst.Size + i] * srcValues[col]; - } - - dst[i] = dotProduct; + int col = rgposSrc[j] - posMin; + dotProduct += mat[i * srcValues.Size + col] * srcValues[col]; } - + dst[i] = dotProduct; } } } diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs index 72d5210e20..36518df075 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs @@ -17,8 +17,8 @@ public static int GetVectorAlignment() public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, mat, src, dst, crun); - public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun); + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun); public static void Add(float a, float[] dst, int count) => SseUtils.Add(a, dst, count); diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs index 58f80318d0..abbed0528b 100644 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ b/src/Microsoft.ML.CpuMath/Sse.cs @@ -54,7 +54,7 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al } } - public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); @@ -77,16 +77,8 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, int[] rgposSrc, Alig fixed (float* psrc = &srcValues.Items[0]) fixed (int* ppossrc = &rgposSrc[0]) { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= srcValues.Size); - Thunk.MatMulTranPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); - } + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } } } diff --git 
a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index 0a3e57953f..eb3b0a1c0d 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -513,70 +513,6 @@ public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int } } - // Partial sparse source vector. - public static unsafe void MatMulTranPA(AlignedArray mat, int[] rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow) - { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); - - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - fixed (int* pposSrc = &rgposSrc[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - - int* ppos = pposSrc + iposMin; - int* pposEnd = pposSrc + iposEnd; - float* pDstEnd = pdst + crow; - - int col = *ppos - posMin; - ppos++; - - Vector128 x0 = Sse.SetAllVector128(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; - - while (pDstCurrent < pDstEnd) - { - Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent); - x1 = Sse.Multiply(x1, x0); - Sse.StoreAligned(pDstCurrent, x1); - - pDstCurrent += 4; - pMatCurrent += 4; - } - - // REVIEW: Should we explore unrolling the outer loop? - while (ppos < pposEnd) - { - int col = *ppos - posMin; - - Vector128 x0 = Sse.SetAllVector128(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; - - while (pDstCurrent < pDstEnd) - { - Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent); - Vector128 x2 = Sse.LoadAlignedVector128(pDstCurrent); - x1 = Sse.Multiply(x1, x0); - x2 = Sse.Add(x2, x1); - Sse.StoreAligned(pDstCurrent, x2); - - pDstCurrent += 4; - pMatCurrent += 4; - } - - ppos++; - } - } - } - // dst[i] += scale public static unsafe void AddScalarU(float scalar, Span dst) { diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 897fd0901a..22938bb0d7 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -69,13 +69,6 @@ public static extern void RespNormU(bool add, float alpha, float beta, bool avgO [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMulTranX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranPA(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, - int posMin, int iposMin, int iposLim, float* pdst, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranPX(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, - int posMin, int iposMin, int iposLim, float* pdst, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMulTranRU(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs, /*const*/ float* psrc, float* pdst, int crow, int ccol); diff --git a/src/Microsoft.ML.Transforms/RffTransform.cs b/src/Microsoft.ML.Transforms/RffTransform.cs index 8cce5bdeeb..0dd2c95ffd 100644 --- a/src/Microsoft.ML.Transforms/RffTransform.cs +++ b/src/Microsoft.ML.Transforms/RffTransform.cs @@ -618,7 +618,7 @@ private void 
TransformFeatures(ref VBuffer src, ref VBuffer dst, T // This overload of MatTimesSrc ignores the values in slots that are not in src.Indices, so there is // no need to zero them out. featuresAligned.CopyFrom(src.Indices, src.Values, 0, 0, src.Count, zeroItems: false); - CpuMathUtils.MatTimesSrc(false, transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0, + CpuMathUtils.MatTimesSrc(transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0, src.Count, productAligned, transformInfo.NewDim); } diff --git a/src/Native/CpuMathNative/Avx.cpp b/src/Native/CpuMathNative/Avx.cpp index fd2e78ed5a..b52e28cd80 100644 --- a/src/Native/CpuMathNative/Avx.cpp +++ b/src/Native/CpuMathNative/Avx.cpp @@ -394,46 +394,6 @@ EXPORT_API(void) MatMulTranX(bool add, _In_ const float * pmat, _In_ const float _vleave(); } -// Partial sparse source vector. -EXPORT_API(void) MatMulTranPX(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, - int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow) -{ - const int * ppos = pposSrc + iposMin; - const int * pposLim = pposSrc + iposLim; - const float * pdLim = pdst + crow; - - if (!add) - { - int col = *ppos++ - posMin; - const float * pm = pmat + col * crow; - __m256 x0 = _mm256_set1_ps(psrc[col]); - for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8) - { - __m256 x1 = _mm256_load_ps(pm); - x1 = _mm256_mul_ps(x1, x0); - _mm256_store_ps(pd, x1); - } - } - - // REVIEW: Should we explore unrolling the outer loop? - for (; ppos < pposLim; ppos++) - { - int col = *ppos - posMin; - __m256 x0 = _mm256_set1_ps(psrc[col]); - const float * pm = pmat + col * crow; - for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8) - { - __m256 x1 = _mm256_load_ps(pm); - __m256 x2 = _mm256_load_ps(pd); - x1 = _mm256_mul_ps(x1, x0); - x2 = _mm256_add_ps(x2, x1); - _mm256_store_ps(pd, x2); - } - } - - _vleave(); -} - // Sparse matrix. EXPORT_API(void) MatMulTranRX(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol) diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 9ca29868fe..d29f3312d2 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -748,41 +748,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I } } -// Partial sparse source vector. -EXPORT_API(void) MatMulTranPA(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, - int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow) -{ - const int * ppos = pposSrc + iposMin; - const int * pposLim = pposSrc + iposLim; - const float * pdLim = pdst + crow; - - int col = *ppos++ - posMin; - const float * pm = pmat + col * crow; - __m128 x0 = _mm_set1_ps(psrc[col]); - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) - { - __m128 x1 = _mm_load_ps(pm); - x1 = _mm_mul_ps(x1, x0); - _mm_store_ps(pd, x1); - } - - // REVIEW: Should we explore unrolling the outer loop? - for (; ppos < pposLim; ppos++) - { - int col = *ppos - posMin; - __m128 x0 = _mm_set1_ps(psrc[col]); - const float * pm = pmat + col * crow; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) - { - __m128 x1 = _mm_load_ps(pm); - __m128 x2 = _mm_load_ps(pd); - x1 = _mm_mul_ps(x1, x0); - x2 = _mm_add_ps(x2, x1); - _mm_store_ps(pd, x2); - } - } -} - // Sparse matrix. 
EXPORT_API(void) MatMulTranRU(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol) diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index 4456469f2d..cb943db500 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -123,24 +123,7 @@ public void MatMulPATest(int matTest, int srcTest, int dstTest, float[] expected AlignedArray dst = _testDstVectors[dstTest]; int[] idx = _testIndexArray; - CpuMathUtils.MatTimesSrc(false, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 33.32f, -40.46f, -165.92f, 235.28f, -1808.29f, -457.81f, 551.65f, 55.93f })] - [InlineData(1, 0, 1, new float[] { 1265f, 1282f, 1299f, 1316f, 1333f, 1350f, 1367f, 1384f, 1401f, 1418f, 1435f, 1452f, 1469f, 1486f, 1503f, 1520f })] - [InlineData(1, 1, 0, new float[] { 6720f, 6800f, 6880f, 6960f, 7040f, 7120f, 7200f, 7280f })] - public void MatMulTranPATest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - int[] idx = _testIndexArray; - - CpuMathUtils.MatTimesSrc(true, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, src.Size); + CpuMathUtils.MatTimesSrc(mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer); From 348f329eeb988e88a8a989cc48322695e2d88656 Mon Sep 17 00:00:00 2001 From: Anipik Date: Mon, 15 Oct 2018 14:33:39 -0700 Subject: [PATCH 6/7] removed firstTime and corrected nativePerformanceTests --- src/Microsoft.ML.CpuMath/AvxIntrinsics.cs | 161 +++++++++++++++-- src/Native/CpuMathNative/Sse.cpp | 162 +++++++++++++++--- .../NativePerformanceTests.cs | 10 +- 3 files changed, 289 insertions(+), 44 deletions(-) diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index edddfdfb1e..fcf4c1fdd7 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -367,10 +367,10 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - bool firstTime = true; - // We do 4-way unrolling - while (pSrcCurrent < pSrcEnd) + // The reason behind adding the if condtion instead of boolean flag + // is to avoid branching in codegen. + if (pSrcCurrent < pSrcEnd) { Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. 
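The hunk above peels the first source block out of the loop instead of carrying a firstTime flag through every iteration, so the store-versus-accumulate decision is made once rather than branching inside the hot loop. The same pattern in miniature, as a standalone sketch rather than the production kernel:

    // The first pass is peeled out and does a plain store; the remaining passes
    // accumulate. This removes the per-iteration "is this the first pass?" branch
    // from the loop body.
    static void SumColumns(float[][] columns, float[] dst)
    {
        int c = 0;
        if (c < columns.Length)
        {
            for (int i = 0; i < dst.Length; i++)
                dst[i] = columns[c][i];   // store
            c++;
        }

        for (; c < columns.Length; c++)
        {
            for (int i = 0; i < dst.Length; i++)
                dst[i] += columns[c][i];  // accumulate
        }
    }
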
@@ -404,10 +404,145 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int x22 = Avx.Add(x22, x32); x02 = Avx.Add(x02, x22); - if (!firstTime) + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + + Avx.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 7) + { + remainder = length % 8; + while (pDstCurrent + 8 <= pDstEnd) { - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } + + // We do 4-way unrolling + while (pSrcCurrent < pSrcEnd) + { + Vector128 h01 = Sse.LoadVector128(pSrcCurrent); + // Replicate each slot of h01 (ABCD) into its own register. 
+ Vector128 h11 = Avx.Permute(h01, 0x55); // B + Vector128 h21 = Avx.Permute(h01, 0xAA); // C + Vector128 h31 = Avx.Permute(h01, 0xFF); // D + h01 = Avx.Permute(h01, 0x00); // A + + Vector256 x01 = Avx.SetHighLow(h01, h01); + Vector256 x11 = Avx.SetHighLow(h11, h11); + Vector256 x21 = Avx.SetHighLow(h21, h21); + Vector256 x31 = Avx.SetHighLow(h31, h31); + + int length = crow; + float* pDstCurrent = pdst; + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); Avx.Store(pDstCurrent, x02); pDstCurrent += 8; @@ -446,10 +581,7 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int Vector256 x3 = Avx.LoadVector256(pDstCurrent); x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); - if (!firstTime) - { - x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); - } + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); Avx.Store(pDstCurrent, x02); pMatCurrent += misalignment; @@ -472,10 +604,7 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int x22 = Avx.Add(x22, x32); x02 = Avx.Add(x02, x22); - if (!firstTime) - { - x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); - } + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); Avx.Store(pDstCurrent, x02); pDstCurrent += 8; @@ -512,10 +641,7 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int Vector256 x3 = Avx.LoadVector256(pDstCurrent); x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); - if (!firstTime) - { - x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); - } + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); Avx.Store(pDstCurrent, x02); pDstCurrent += 8; @@ -523,7 +649,6 @@ public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int } } - firstTime = false; pMatCurrent += 3 * crow; pSrcCurrent += 4; } diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index d29f3312d2..9fde43fb12 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -592,11 +592,10 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I const float* pMatCurrent = pmat; const float* pSrcCurrent = psrc; - bool firstTime = true; - while (pSrcCurrent < pSrcEnd) + if (pSrcCurrent < pSrcEnd) { - __m128 x01 = _mm_loadu_ps(pSrcCurrent); + __m128 x01 = _mm_loadu_ps(pSrcCurrent); // Replicate each slot of x01 into its own register. 
__m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); @@ -624,11 +623,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I x22 = _mm_add_ps(x22, x32); x02 = _mm_add_ps(x02, x22); - if (!firstTime) - { - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); - } - _mm_storeu_ps(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; @@ -664,11 +658,6 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I __m128 x3 = _mm_loadu_ps(pDstCurrent); x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); - if (!firstTime) - { - x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); - } - _mm_storeu_ps(pDstCurrent, x02); pMatCurrent += misalignment; pDstCurrent += misalignment; @@ -689,11 +678,7 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I x02 = _mm_add_ps(x02, x12); x22 = _mm_add_ps(x22, x32); x02 = _mm_add_ps(x02, x22); - - if (!firstTime) - { - x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); - } + _mm_storeu_ps(pDstCurrent, x02); pDstCurrent += 4; @@ -731,18 +716,151 @@ EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _I __m128 x3 = _mm_loadu_ps(pDstCurrent); x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); - if (!firstTime) + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } + + while (pSrcCurrent < pSrcEnd) + { + __m128 x01 = _mm_loadu_ps(pSrcCurrent); + // Replicate each slot of x01 into its own register. + __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); + __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); + __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); + x01 = _mm_shuffle_ps(x01, x01, 0x00); + + int length = crow; + float* pDstCurrent = pdst; + + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
+ const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (( 4 - misalignment) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + + if(length > 3) + { + remainder = length % 4; + while (pDstCurrent < pDstEnd) { - x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; } + } + else + { + length = remainder; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); + + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); _mm_storeu_ps(pDstCurrent, x02); pMatCurrent += 4; pDstCurrent += 4; } } - - firstTime = false; + pMatCurrent += 3 * crow; pSrcCurrent += 4; } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs index 1ee1c2bb53..07958c19ed 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs @@ -230,22 +230,24 @@ public unsafe void SdcaL1UpdateSU() } [Benchmark] - public unsafe void MatMulX() + public unsafe void MatMul() { fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) + fixed (float* psrc1 = &src1[0]) { - Thunk.MatMulX(psrc, psrc, pdst, 1000, 1000); + Thunk.MatMul(psrc1, psrc, pdst, 1000, 1000); } } [Benchmark] - public unsafe void MatMulTranX() + public unsafe void MatMulTran() { fixed (float* psrc = &src[0]) fixed (float* pdst = &dst[0]) + fixed (float* 
psrc1 = &src1[0])
 {
- Thunk.MatMulTranX(psrc, psrc, pdst, 1000, 1000);
+ Thunk.MatMulTran(psrc1, psrc, pdst, 1000, 1000);
 }
 }
 }

From aa805f0c094604cb83f53335e94b31604561eb8e Mon Sep 17 00:00:00 2001
From: Anipik
Date: Tue, 16 Oct 2018 10:55:27 -0700
Subject: [PATCH 7/7] Remove branch from hot path in SseIntrinsics

---
 src/Microsoft.ML.CpuMath/SseIntrinsics.cs | 158 +++++++++++++++++++---
 1 file changed, 138 insertions(+), 20 deletions(-)

diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
index 7f9cb16788..1a99c5bd92 100644
--- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
+++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs
@@ -356,10 +356,10 @@ public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int
 float* pDstEnd = pdst + crow;
 float* pSrcCurrent = psrc;
 float* pMatCurrent = pmat;
- bool firstTime = true;

- // We do 4-way unrolling
- while (pSrcCurrent < pSrcEnd)
+ // An if block is used here instead of a firstTime boolean flag
+ // so that the hot loop below carries no branch in the generated code.
+ if (pSrcCurrent < pSrcEnd)
 {
 Vector128 x01 = Sse.LoadVector128(pSrcCurrent);
 // Replicate each 32-bit slot of x01 (ABCD) into its own register.
@@ -388,10 +388,140 @@ public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int
 x22 = Sse.Add(x22, x32);
 x02 = Sse.Add(x02, x22);

- if (!firstTime)
+ Sse.Store(pDstCurrent, x02);
+ pDstCurrent += 4;
+ pMatCurrent += 4;
+ }
+ }
+ else
+ {
+ int remainder = 0;
+ if (misalignment != 0)
+ {
+ // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
+ // masking any elements that will be included in the first aligned read
+ misalignment >>= 2;
+ misalignment = 4 - misalignment;
+
+ Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4));
+
+ // We only align pMat since it has significantly more reads.
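+ // This is the first pass over dst, so the leading-mask lanes are written directly (no accumulate needed);
+ // the trailing-mask blend below only carries the existing dst bits through the full-width store.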
+ float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); + + Sse.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 4) + { + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) { - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; + } + + // We do 4-way unrolling + while (pSrcCurrent < pSrcEnd) + { + Vector128 x01 = Sse.LoadVector128(pSrcCurrent); + // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
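+ // (Shuffle controls: 0x00 broadcasts lane A, 0x55 lane B, 0xAA lane C, 0xFF lane D.)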
+ Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B + Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C + Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D + x01 = Sse.Shuffle(x01, x01, 0x00); // A + + int length = crow; + float* pDstCurrent = pdst; + + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); + + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; @@ -430,10 +560,7 @@ public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int Vector128 x3 = Sse.LoadVector128(pDstCurrent); x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); - if (!firstTime) - { - x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); - } + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); Sse.Store(pDstCurrent, x02); pMatCurrent += misalignment; @@ -456,11 +583,7 @@ public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int x22 = Sse.Add(x22, x32); x02 = Sse.Add(x02, x22); - if (!firstTime) - { - x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); - } - + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; @@ -496,18 +619,13 @@ public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int Vector128 x3 = Sse.LoadVector128(pDstCurrent); x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); - if (!firstTime) - { - x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); - } - + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); Sse.Store(pDstCurrent, x02); pDstCurrent += 4; pMatCurrent += 4; } } - firstTime = false; pMatCurrent += 3 * crow; pSrcCurrent += 4; }
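
The effect of the restructuring in this last commit is easier to see in scalar form. The sketch below is illustrative only and not part of the patch: it assumes the mat[c * crow + r] layout implied by the pointer arithmetic above (pMatTemp += crow between the four broadcast lanes, pMatCurrent += 3 * crow per source group) and at least one source column, and the MatMulTranScalar name is hypothetical. Peeling the first pass into its own block lets every later pass accumulate into dst unconditionally, which is exactly what removing the firstTime flag achieves.

// Scalar reference for the same dst = transpose(mat) * src computation. Illustrative sketch only;
// MatMulTranScalar is not part of the library.
internal static class MatMulTranReference
{
    public static void MatMulTranScalar(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        // Peeled first pass: overwrite dst, so later passes never need an "is this the first pass?" check.
        for (int r = 0; r < crow; r++)
        {
            dst[r] = src[0] * mat[r];
        }

        // All remaining passes accumulate unconditionally; the branch is gone from the hot loop.
        for (int c = 1; c < ccol; c++)
        {
            for (int r = 0; r < crow; r++)
            {
                dst[r] += src[c] * mat[c * crow + r];
            }
        }
    }
}

The vectorized code applies the same peel at four-float granularity and additionally uses the leading/trailing masks so that the head and tail of each column group, where pmat is not 16-byte aligned, can still be processed with full-width loads and a masked blend against dst instead of a scalar fix-up loop.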