diff --git a/src/Microsoft.ML.CpuMath/Avx.cs b/src/Microsoft.ML.CpuMath/Avx.cs index 5d4610d9bc..f7769b3295 100644 --- a/src/Microsoft.ML.CpuMath/Avx.cs +++ b/src/Microsoft.ML.CpuMath/Avx.cs @@ -34,7 +34,7 @@ public static bool CheckAvx() return Thunk.ChkAvx(); } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); Contracts.Assert(Compat(src)); @@ -50,18 +50,18 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulX(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + Thunk.MatMulX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); } else { Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTranX(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + Thunk.MatMulTranX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); } } } } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); @@ -73,8 +73,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (iposMin >= iposLim) { - if (!add) - dst.ZeroItems(); + dst.ZeroItems(); return; } Contracts.AssertNonEmpty(rgposSrc); @@ -85,16 +84,8 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo fixed (float* psrc = &srcValues.Items[0]) fixed (int* ppossrc = &rgposSrc[0]) { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPX(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= srcValues.Size); - Thunk.MatMulTranPX(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); - } + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } } } diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index 415f1656c4..1107723e87 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -142,20 +142,22 @@ private static Vector256 GetNewDst256(in Vector256 xDst1, in Vecto } // Multiply matrix times vector into vector. 
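// For reference, with the accumulate flag removed, MatTimesSrc and the kernels below always
// overwrite dst. A minimal scalar sketch of that contract (illustrative helper, not part of
// this patch), assuming mat holds crow x ccol row-major values for the plain product and that
// source column j occupies mat[j * crow .. j * crow + crow) for the transposed product, as the
// pointer walks in the kernels suggest:
internal static class MatMulContractSketch
{
    // dst[i] = sum_j mat[i, j] * src[j]
    public static void MatMul(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        for (int i = 0; i < crow; i++)
        {
            float dot = 0;
            for (int j = 0; j < ccol; j++)
                dot += mat[i * ccol + j] * src[j];
            dst[i] = dot; // overwrite; a caller that still wants "+=" must accumulate itself
        }
    }

    // dst[i] = sum_j mat[j, i] * src[j]
    public static void MatMulTran(float[] mat, float[] src, float[] dst, int crow, int ccol)
    {
        for (int i = 0; i < crow; i++)
        {
            float dot = 0;
            for (int j = 0; j < ccol; j++)
                dot += mat[j * crow + i] * src[j];
            dst[i] = dot;
        }
    }
}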
- public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); + MatMulX(mat.Items, src.Items, dst.Items, crow, ccol); + } + public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + { float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -164,29 +166,117 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, while (pDstCurrent < pDstEnd) { Vector256 res0 = Avx.SetZeroVector256(); - Vector256 res1 = res0; - Vector256 res2 = res0; - Vector256 res3 = res0; + Vector256 res1 = Avx.SetZeroVector256(); + Vector256 res2 = Avx.SetZeroVector256(); + Vector256 res3 = Avx.SetZeroVector256(); + int length = ccol; float* pSrcCurrent = psrc; - while (pSrcCurrent < pSrcEnd) + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); + + int remainder = 0; + if ((misalignment & 3) != 0) { - float* pMatTemp = pMatCurrent; + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent < pSrcEnd) + { + Vector256 vector = Avx.LoadVector256(pSrcCurrent); - Vector256 x01 = Avx.LoadAlignedVector256(pMatTemp); - Vector256 x11 = Avx.LoadAlignedVector256(pMatTemp += ccol); - Vector256 x21 = Avx.LoadAlignedVector256(pMatTemp += ccol); - Vector256 x31 = Avx.LoadAlignedVector256(pMatTemp += ccol); - Vector256 x02 = Avx.LoadAlignedVector256(pSrcCurrent); + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); - res0 = Avx.Add(res0, Avx.Multiply(x01, x02)); - res1 = Avx.Add(res1, Avx.Multiply(x11, x02)); - res2 = Avx.Add(res2, Avx.Multiply(x21, x02)); - res3 = Avx.Add(res3, Avx.Multiply(x31, x02)); + res0 = Avx.Add(res0, x01); + res1 = Avx.Add(res1, x11); + res2 = Avx.Add(res2, x21); + res3 = Avx.Add(res3, x31); - pSrcCurrent += 8; - pMatCurrent += 8; + pSrcCurrent += 8; + pMatCurrent += 8; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. 
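// Sketch of the alignment decision taken just above (illustrative only): when the matrix
// pointer is not even 4-byte aligned, stepping by whole floats can never reach a 32-byte
// boundary, so the kernel stays on unaligned loads; otherwise it peels (8 - misalignment)
// leading elements with the mask and then runs the aligned main loop.
static (bool canAlign, int elementsToPeel) ClassifyAvxAlignment(long address)
{
    long byteOffset = address % 32;
    if ((byteOffset & 3) != 0)
        return (false, 0);                     // never alignable by whole-float steps
    int floats = (int)(byteOffset >> 2);       // floats past the previous 32-byte boundary
    return (true, floats == 0 ? 0 : 8 - floats);
}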
+ float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); + + res0 = Avx.Multiply(x01, vector); + res1 = Avx.Multiply(x11, vector); + res2 = Avx.Multiply(x21, vector); + res3 = Avx.Multiply(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 7) + { + remainder = length % 8; + while (pSrcCurrent + 8 <= pSrcEnd) + { + Vector256 vector = Avx.LoadVector256(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.Multiply(vector, Avx.LoadVector256(pMatTemp += ccol)); + + res0 = Avx.Add(res0, x01); + res1 = Avx.Add(res1, x11); + res2 = Avx.Add(res2, x21); + res3 = Avx.Add(res3, x31); + + pSrcCurrent += 8; + pMatCurrent += 8; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pSrcCurrent -= (8 - remainder); + + Vector256 mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x01 = Avx.And(mask, Avx.LoadVector256(pMatTemp)); + Vector256 x11 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x21 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 x31 = Avx.And(mask, Avx.LoadVector256(pMatTemp += ccol)); + Vector256 vector = Avx.And(mask, Avx.LoadVector256(pSrcCurrent)); + + res0 = Avx.Add(res0, Avx.Multiply(x01, vector)); + res1 = Avx.Add(res1, Avx.Multiply(x11, vector)); + res2 = Avx.Add(res2, Avx.Multiply(x21, vector)); + res3 = Avx.Add(res3, Avx.Multiply(x31, vector)); + + pMatCurrent += 8; + pSrcCurrent += 8; + } } // Add up the entries of each, with the 4 results in res0 @@ -195,11 +285,7 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, res0 = Avx.HorizontalAdd(res0, res2); Vector128 sum = Sse.Add(Avx.GetLowerHalf(res0), GetHigh(in res0)); - if (add) - { - sum = Sse.Add(sum, Sse.LoadAlignedVector128(pDstCurrent)); - } - Sse.StoreAligned(pDstCurrent, sum); + Sse.Store(pDstCurrent, sum); pDstCurrent += 4; pMatCurrent += 3 * ccol; @@ -208,7 +294,7 @@ public static unsafe void MatMulX(bool add, AlignedArray mat, AlignedArray src, } // Partial sparse source vector. 
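// Scalar sketch of the tail handling used above (illustrative): with `remainder` elements left,
// the pointers are stepped back by (8 - remainder) so one full 8-wide load can be issued, and
// the trailing mask zeroes the lanes that the main loop already consumed, so nothing is counted
// twice in the dot product.
static float MaskedTailDot(float[] matRow, float[] src, int processed, int remainder)
{
    const int Width = 8;                           // AVX float lane count
    int start = processed - (Width - remainder);   // back up so the window is a full vector
    float sum = 0;
    for (int lane = 0; lane < Width; lane++)
    {
        // TrailingAlignmentMask row `remainder` keeps only the last `remainder` lanes.
        if (lane >= Width - remainder)
            sum += matRow[start + lane] * src[start + lane];
    }
    return sum;
}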
- public static unsafe void MatMulPX(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, + public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { Contracts.Assert(HasCompatibleAlignment(mat)); @@ -255,193 +341,317 @@ public static unsafe void MatMulPX(bool add, AlignedArray mat, int[] rgposSrc, A ppos++; } - if (add) - { - result = Avx.Add(result, Avx.LoadAlignedVector256(pDstCurrent)); - } Avx.StoreAligned(pDstCurrent, result); - pDstCurrent += 8; pm0 += 8 * ccol; } } } - public static unsafe void MatMulTranX(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); + MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol); + } + public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + { float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - // We do 4-way unrolling - if (!add) + // The reason behind adding the if condtion instead of boolean flag + // is to avoid branching in codegen. + if (pSrcCurrent < pSrcEnd) { - Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent); + Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. 
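// Scalar sketch of the partial-sparse product MatMulPX computes (illustrative; the indexing
// mirrors the managed fallback in CpuMathUtils.netcoreapp.cs further below): only the source
// slots listed in rgposSrc[iposMin..iposLim) contribute, and dst is overwritten.
static void MatMulSparseSketch(float[] mat, int[] rgposSrc, float[] srcValues,
    int posMin, int iposMin, int iposLim, float[] dst, int crow, int ccol)
{
    for (int i = 0; i < crow; i++)
    {
        float dot = 0;
        for (int j = iposMin; j < iposLim; j++)
        {
            int col = rgposSrc[j] - posMin;
            dot += mat[i * ccol + col] * srcValues[col];
        }
        dst[i] = dot;
    }
}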
- Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B - Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C - Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D - h01 = Sse.Shuffle(h01, h01, 0x00); // A + Vector128 h11 = Avx.Permute(h01, 0x55); // B + Vector128 h21 = Avx.Permute(h01, 0xAA); // C + Vector128 h31 = Avx.Permute(h01, 0xFF); // D + h01 = Avx.Permute(h01, 0x00); // A Vector256 x01 = Avx.SetHighLow(h01, h01); Vector256 x11 = Avx.SetHighLow(h11, h11); Vector256 x21 = Avx.SetHighLow(h21, h21); Vector256 x31 = Avx.SetHighLow(h31, h31); - pSrcCurrent += 4; - + int length = crow; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector256 x02 = Avx.LoadAlignedVector256(pMatTemp); - Vector256 x12 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadAlignedVector256(pMatTemp += crow); - - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); - - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); - Avx.StoreAligned(pDstCurrent, x02); + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. 
+ float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + + Avx.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 7) + { + remainder = length % 8; + while (pDstCurrent + 8 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + remainder = length; + } - pDstCurrent += 8; - pMatCurrent += 8; + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } } pMatCurrent += 3 * crow; + pSrcCurrent += 4; } + // We do 4-way unrolling while (pSrcCurrent < pSrcEnd) { - Vector128 h01 = Sse.LoadAlignedVector128(pSrcCurrent); + Vector128 h01 = Sse.LoadVector128(pSrcCurrent); // Replicate each slot of h01 (ABCD) into its own register. 
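// Scalar sketch of the control flow in MatMulTranX (illustrative, collapsing the 4-way source
// grouping to one column at a time): the first source group overwrites dst and every later
// group adds into it, which preserves the old "initialize, then accumulate" behavior of the
// removed `add` flag without a per-iteration branch. Layout assumption: source column c
// occupies mat[c * crow .. c * crow + crow).
static void MatMulTranSketch(float[] mat, float[] src, float[] dst, int crow, int ccol)
{
    for (int c = 0; c < ccol; c++)
    {
        float v = src[c];
        for (int r = 0; r < crow; r++)
        {
            float contrib = mat[c * crow + r] * v;
            dst[r] = (c == 0) ? contrib : dst[r] + contrib; // first column stores, the rest accumulate
        }
    }
}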
- Vector128 h11 = Sse.Shuffle(h01, h01, 0x55); // B - Vector128 h21 = Sse.Shuffle(h01, h01, 0xAA); // C - Vector128 h31 = Sse.Shuffle(h01, h01, 0xFF); // D - h01 = Sse.Shuffle(h01, h01, 0x00); // A + Vector128 h11 = Avx.Permute(h01, 0x55); // B + Vector128 h21 = Avx.Permute(h01, 0xAA); // C + Vector128 h31 = Avx.Permute(h01, 0xFF); // D + h01 = Avx.Permute(h01, 0x00); // A Vector256 x01 = Avx.SetHighLow(h01, h01); Vector256 x11 = Avx.SetHighLow(h11, h11); Vector256 x21 = Avx.SetHighLow(h21, h21); Vector256 x31 = Avx.SetHighLow(h31, h31); + int length = crow; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector256 x02 = Avx.LoadAlignedVector256(pMatTemp); - Vector256 x12 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x22 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x32 = Avx.LoadAlignedVector256(pMatTemp += crow); - Vector256 x3 = Avx.LoadAlignedVector256(pDstCurrent); + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 32); - x02 = Avx.Multiply(x01, x02); - x12 = Avx.Multiply(x11, x12); - x22 = Avx.Multiply(x21, x22); - x32 = Avx.Multiply(x31, x32); + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); - x02 = Avx.Add(x02, x12); - x22 = Avx.Add(x22, x32); - x02 = Avx.Add(x02, x22); - x3 = Avx.Add(x02, x3); + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); - Avx.StoreAligned(pDstCurrent, x3); + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); - pDstCurrent += 8; - pMatCurrent += 8; + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } } - - pMatCurrent += 3 * crow; - pSrcCurrent += 4; - } - } - } - - // Partial sparse source vector. - public static unsafe void MatMulTranPX(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow) - { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); - - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - fixed (int* pposSrc = &rgposSrc[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - - int* ppos = pposSrc + iposMin; - int* pposEnd = pposSrc + iposEnd; - float* pDstEnd = pdst + crow; - - if (!add) - { - int col = *ppos - posMin; - ppos++; - - Vector256 x0 = Avx.SetAllVector256(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; - - while (pDstCurrent < pDstEnd) + else { - Vector256 x1 = Avx.LoadAlignedVector256(pMatCurrent); - x1 = Avx.Multiply(x1, x0); - Avx.StoreAligned(pDstCurrent, x1); - - pDstCurrent += 8; - pMatCurrent += 8; - } - } - - // REVIEW: Should we explore unrolling the outer loop? 
- while (ppos < pposEnd) - { - int col = *ppos - posMin; - - Vector256 x0 = Avx.SetAllVector256(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 8 - misalignment; + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); + + // We only align pMat since it has significantly more reads. + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(leadingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, trailingMask)); + + x02 = Avx.Add(x02, Avx.And(x3, leadingMask)); + + Avx.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 7) + { + remainder = length % 8; + while (pDstCurrent + 8 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector256 x02 = Avx.Multiply(x01, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.Multiply(x11, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.Multiply(x21, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.Multiply(x31, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + x02 = Avx.Add(x02, Avx.LoadVector256(pDstCurrent)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + pMatCurrent += 8; + } + } + else + { + remainder = length; + } - while (pDstCurrent < pDstEnd) - { - Vector256 x1 = Avx.LoadAlignedVector256(pMatCurrent); - Vector256 x2 = Avx.LoadAlignedVector256(pDstCurrent); - x1 = Avx.Multiply(x1, x0); - x2 = Avx.Add(x2, x1); - Avx.StoreAligned(pDstCurrent, x2); - - pDstCurrent += 8; - pMatCurrent += 8; + if (remainder != 0) + { + pMatCurrent -= (8 - remainder); + pDstCurrent -= (8 - remainder); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8)); + + float* pMatTemp = pMatCurrent; + Vector256 x02 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp)); + Vector256 x12 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x22 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + Vector256 x32 = Avx.And(trailingMask, Avx.LoadVector256(pMatTemp += crow)); + + x02 = Avx.Multiply(x01, x02); + x12 = Avx.Multiply(x11, x12); + x22 = Avx.Multiply(x21, x22); + x32 = Avx.Multiply(x31, x32); + + x02 = Avx.Add(x02, x12); + x22 = Avx.Add(x22, x32); + x02 = Avx.Add(x02, x22); + + Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8)); + Vector256 x3 = Avx.LoadVector256(pDstCurrent); + x02 = Avx.Or(x02, Avx.And(x3, leadingMask)); + + x02 = Avx.Add(x02, Avx.And(x3, trailingMask)); + + Avx.Store(pDstCurrent, x02); + pDstCurrent += 8; + 
pMatCurrent += 8; + } } - ppos++; + pMatCurrent += 3 * crow; + pSrcCurrent += 4; } } } @@ -499,7 +709,7 @@ public static unsafe void Scale(float scale, Span dst) if (length < 8) { - switch(length) + switch (length) { case 7: dst[6] *= scale; goto case 6; case 6: dst[5] *= scale; goto case 5; @@ -541,7 +751,7 @@ public static unsafe void Scale(float scale, Span dst) Vector256 result = Avx.LoadVector256(pDstCurrent); Vector256 leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8)); - Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (( 8 - misalignment) * 8)); + Vector256 trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8)); Vector256 temp = Avx.And(result, leadingMask); result = Avx.And(result, trailingMask); diff --git a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs index 30308f219d..bcc92234d0 100644 --- a/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs +++ b/src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs @@ -75,40 +75,32 @@ public static void AssertCompatible(ICpuFullMatrix mat, ICpuVector src, ICpuVect /// /// Matrix multiplication: - /// if (add) - /// dst = mat * src - /// else - /// dest += mat * src + /// dst = mat * src /// - /// The addition flag /// The multiplier matrix /// The source vector /// The destination vector - public static void MatTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) + public static void MatTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) { bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol); AssertCompatible(mat, src, dst); var m = A(mat); - CpuMathUtils.MatTimesSrc(colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt); + CpuMathUtils.MatTimesSrc(colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt); } /// /// Matrix transpose multiplication: - /// if (add) - /// dst = mat' * src - /// else - /// dest += mat' * src + /// dst = mat' * src /// - /// The addition flag /// The multiplier matrix /// The source vector /// The destination vector - public static void MatTranTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) + public static void MatTranTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst) { bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol); AssertCompatible(mat, dst, src); var m = A(mat); - CpuMathUtils.MatTimesSrc(!colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt); + CpuMathUtils.MatTimesSrc(!colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt); } } diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs index fcbbb35222..aa8ff85bc6 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs @@ -24,7 +24,7 @@ public static partial class CpuMathUtils public static int GetVectorAlignment() => Avx.IsSupported ? Vector256Alignment : (Sse.IsSupported ? 
Vector128Alignment : FloatAlignment); - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) { Contracts.Assert(mat.Size == dst.Size * src.Size); Contracts.Assert(crun >= 0); @@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulX(add, mat, src, dst, crun, src.Size); + AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size); } else { Contracts.Assert(crun <= src.Size); - AvxIntrinsics.MatMulTranX(add, mat, src, dst, dst.Size, crun); + AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun); } } else if (Sse.IsSupported) @@ -47,12 +47,12 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMulA(add, mat, src, dst, crun, src.Size); + SseIntrinsics.MatMul(mat, src, dst, crun, src.Size); } else { Contracts.Assert(crun <= src.Size); - SseIntrinsics.MatMulTranA(add, mat, src, dst, dst.Size, crun); + SseIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun); } } else @@ -68,14 +68,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr dotProduct += mat[i * src.Size + j] * src[j]; } - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + dst[i] = dotProduct; } } else @@ -89,20 +82,13 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr dotProduct += mat[j * src.Size + i] * src[j]; } - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + dst[i] = dotProduct; } } } } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.AssertValue(rgposSrc); @@ -113,8 +99,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (iposMin >= iposLim) { - if (!add) - dst.ZeroItems(); + dst.ZeroItems(); return; } @@ -123,76 +108,26 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (Avx.IsSupported) { - if (!tran) - { - Contracts.Assert(crun <= dst.Size); - AvxIntrinsics.MatMulPX(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); - } - else - { - Contracts.Assert(crun <= srcValues.Size); - AvxIntrinsics.MatMulTranPX(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); - } + Contracts.Assert(crun <= dst.Size); + AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else if (Sse.IsSupported) { - if (!tran) - { - Contracts.Assert(crun <= dst.Size); - SseIntrinsics.MatMulPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); - } - else - { - Contracts.Assert(crun <= srcValues.Size); - SseIntrinsics.MatMulTranPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size); - } + Contracts.Assert(crun <= dst.Size); + SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size); } else { - if (!tran) - { - Contracts.Assert(crun <= dst.Size); - for (int i = 0; i < crun; i++) - { - float dotProduct = 0; - for (int j = iposMin; j < iposLim; j++) - { - int col = 
rgposSrc[j] - posMin; - dotProduct += mat[i * srcValues.Size + col] * srcValues[col]; - } - - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } - } - } - else + Contracts.Assert(crun <= dst.Size); + for (int i = 0; i < crun; i++) { - Contracts.Assert(crun <= srcValues.Size); - for (int i = 0; i < dst.Size; i++) + float dotProduct = 0; + for (int j = iposMin; j < iposLim; j++) { - float dotProduct = 0; - for (int j = iposMin; j < iposLim; j++) - { - int col = rgposSrc[j] - posMin; - dotProduct += mat[col * dst.Size + i] * srcValues[col]; - } - - if (add) - { - dst[i] += dotProduct; - } - else - { - dst[i] = dotProduct; - } + int col = rgposSrc[j] - posMin; + dotProduct += mat[i * srcValues.Size + col] * srcValues[col]; } - + dst[i] = dotProduct; } } } diff --git a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs index 927bf32673..d3fbaea2dc 100644 --- a/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs +++ b/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs @@ -16,10 +16,10 @@ public static partial class CpuMathUtils public static int GetVectorAlignment() => Vector128Alignment; - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, src, dst, crun); + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, mat, src, dst, crun); - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, - int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun); + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun); public static void Add(float a, Span dst) => SseUtils.Add(a, dst); diff --git a/src/Microsoft.ML.CpuMath/Sse.cs b/src/Microsoft.ML.CpuMath/Sse.cs index 3c965ad89f..4281893852 100644 --- a/src/Microsoft.ML.CpuMath/Sse.cs +++ b/src/Microsoft.ML.CpuMath/Sse.cs @@ -30,7 +30,7 @@ private static bool Compat(AlignedArray a) return q; } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) + public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); Contracts.Assert(Compat(src)); @@ -46,18 +46,18 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr if (!tran) { Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulA(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); + Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size); } else { Contracts.Assert(0 <= crun && crun <= src.Size); - Thunk.MatMulTranA(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); + Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun); } } } } - public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, + public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues, int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) { Contracts.Assert(Compat(mat)); @@ -69,8 
+69,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo if (iposMin >= iposLim) { - if (!add) - dst.ZeroItems(); + dst.ZeroItems(); return; } Contracts.AssertNonEmpty(rgposSrc); @@ -81,16 +80,8 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo fixed (float* psrc = &srcValues.Items[0]) fixed (int* ppossrc = &rgposSrc[0]) { - if (!tran) - { - Contracts.Assert(0 <= crun && crun <= dst.Size); - Thunk.MatMulPA(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); - } - else - { - Contracts.Assert(0 <= crun && crun <= srcValues.Size); - Thunk.MatMulTranPA(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size); - } + Contracts.Assert(0 <= crun && crun <= dst.Size); + Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size); } } } diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index e84ac5a6ec..1a99c5bd92 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -24,13 +24,6 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath { internal static class SseIntrinsics { - internal static readonly Vector128 AbsMask128 = Sse2.IsSupported ? - Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : - Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); - - // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray - private const int Vector128Alignment = 16; - public static readonly uint[] LeadingAlignmentMask = new uint[16] { 0x00000000, 0x00000000, 0x00000000, 0x00000000, @@ -47,6 +40,13 @@ internal static class SseIntrinsics 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, }; + internal static readonly Vector128 AbsMask128 = Sse2.IsSupported ? + Sse.StaticCast(Sse2.SetAllVector128(0x7FFFFFFF)) : + Sse.SetAllVector128(BitConverter.Int32BitsToSingle(0x7FFFFFFF)); + + // The count of bytes in Vector128, corresponding to _cbAlign in AlignedArray + private const int Vector128Alignment = 16; + [MethodImplAttribute(MethodImplOptions.AggressiveInlining)] private static bool HasCompatibleAlignment(AlignedArray alignedArray) { @@ -135,20 +135,22 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect } // Multiply matrix times vector into vector. 
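// Sketch of what the two mask tables above encode (illustrative): row k of LeadingAlignmentMask
// keeps the first k lanes of a 4-float vector and row k of TrailingAlignmentMask keeps the last
// k lanes; the kernels pick a row by element count and AND it with an unaligned load to zero the
// lanes they must not touch.
static uint[] BuildMaskTable(int lanes, bool leading)
{
    var table = new uint[lanes * lanes];
    for (int k = 0; k < lanes; k++)                // row k: keep k lanes
    {
        for (int lane = 0; lane < lanes; lane++)
        {
            bool keep = leading ? lane < k : lane >= lanes - k;
            table[k * lanes + lane] = keep ? 0xFFFFFFFFu : 0x00000000u;
        }
    }
    return table;
}
// BuildMaskTable(4, leading: true) reproduces the 16-entry LeadingAlignmentMask above, and
// BuildMaskTable(4, leading: false) reproduces TrailingAlignmentMask.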
- public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); + MatMul(mat.Items, src.Items, dst.Items, crow, ccol); + } + public static unsafe void MatMul(float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) + { float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pDstCurrent = pdst; @@ -157,29 +159,117 @@ public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, while (pDstCurrent < pDstEnd) { Vector128 res0 = Sse.SetZeroVector128(); - Vector128 res1 = res0; - Vector128 res2 = res0; - Vector128 res3 = res0; + Vector128 res1 = Sse.SetZeroVector128(); + Vector128 res2 = Sse.SetZeroVector128(); + Vector128 res3 = Sse.SetZeroVector128(); + int length = ccol; float* pSrcCurrent = psrc; - while (pSrcCurrent < pSrcEnd) + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); + int remainder = 0; + + if ((misalignment & 3) != 0) { - float* pMatTemp = pMatCurrent; + // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations + while (pSrcCurrent < pSrcEnd) + { + Vector128 vector = Sse.LoadVector128(pSrcCurrent); - Vector128 x01 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x11 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x21 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x31 = Sse.LoadAlignedVector128(pMatTemp += ccol); - Vector128 x02 = Sse.LoadAlignedVector128(pSrcCurrent); + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); - res0 = Sse.Add(res0, Sse.Multiply(x01, x02)); - res1 = Sse.Add(res1, Sse.Multiply(x11, x02)); - res2 = Sse.Add(res2, Sse.Multiply(x21, x02)); - res3 = Sse.Add(res3, Sse.Multiply(x31, x02)); + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); - pSrcCurrent += 4; - pMatCurrent += 4; + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + Vector128 mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
+ float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); + + res0 = Sse.Multiply(x01, vector); + res1 = Sse.Multiply(x11, vector); + res2 = Sse.Multiply(x21, vector); + res3 = Sse.Multiply(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 4) + { + remainder = length % 4; + while (pSrcCurrent < pSrcEnd) + { + Vector128 vector = Sse.LoadVector128(pSrcCurrent); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.Multiply(vector, Sse.LoadVector128(pMatTemp += ccol)); + + res0 = Sse.Add(res0, x01); + res1 = Sse.Add(res1, x11); + res2 = Sse.Add(res2, x21); + res3 = Sse.Add(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); + + Vector128 mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + float* pMatTemp = pMatCurrent; + Vector128 x01 = Sse.And(mask, Sse.LoadVector128(pMatTemp)); + Vector128 x11 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x21 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 x31 = Sse.And(mask, Sse.LoadVector128(pMatTemp += ccol)); + Vector128 vector = Sse.And(mask, Sse.LoadVector128(pSrcCurrent)); + + res0 = Sse.Add(res0, Sse.Multiply(x01, vector)); + res1 = Sse.Add(res1, Sse.Multiply(x11, vector)); + res2 = Sse.Add(res2, Sse.Multiply(x21, vector)); + res3 = Sse.Add(res3, Sse.Multiply(x31, vector)); + + pMatCurrent += 4; + pSrcCurrent += 4; + } } // Add up the entries of each, with the 4 results in res0 @@ -187,12 +277,7 @@ public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, res2 = Sse3.HorizontalAdd(res2, res3); res0 = Sse3.HorizontalAdd(res0, res2); - if (add) - { - res0 = Sse.Add(res0, Sse.LoadAlignedVector128(pDstCurrent)); - } - Sse.StoreAligned(pDstCurrent, res0); - + Sse.Store(pDstCurrent, res0); pDstCurrent += 4; pMatCurrent += 3 * ccol; } @@ -200,7 +285,7 @@ public static unsafe void MatMulA(bool add, AlignedArray mat, AlignedArray src, } // Partial sparse source vector. 
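// Scalar sketch of the reduction used above (illustrative): after the column loop, res0..res3
// each hold four partial sums for one output row; two rounds of Sse3.HorizontalAdd collapse
// them so that lane i of the final vector is the complete dot product for row i, which is then
// stored as a block of four dst elements.
static float[] HorizontalReduceSketch(float[] res0, float[] res1, float[] res2, float[] res3)
{
    float Sum(float[] v) => v[0] + v[1] + v[2] + v[3];
    // hadd(res0, res1), hadd(res2, res3), then hadd of those two results yields exactly:
    return new[] { Sum(res0), Sum(res1), Sum(res2), Sum(res3) };
}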
- public static unsafe void MatMulPA(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, + public static unsafe void MatMulPA(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { Contracts.Assert(HasCompatibleAlignment(mat)); @@ -245,182 +330,304 @@ public static unsafe void MatMulPA(bool add, AlignedArray mat, int[] rgposSrc, A ppos++; } - if (add) - { - result = Sse.Add(result, Sse.LoadAlignedVector128(pDstCurrent)); - } Sse.StoreAligned(pDstCurrent, result); - pDstCurrent += 4; pm0 += 4 * ccol; } } } - public static unsafe void MatMulTranA(bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) + public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); + Contracts.Assert(crow % 4 == 0); + Contracts.Assert(ccol % 4 == 0); + MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); + } - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) + public static unsafe void MatMulTran(float[] mat, float[] src, float[] dst, int crow, int ccol) + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* pmat = &mat[0]) + fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0]) + fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0]) { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - float* pSrcEnd = psrc + ccol; float* pDstEnd = pdst + crow; float* pSrcCurrent = psrc; float* pMatCurrent = pmat; - if (!add) + // The reason behind adding the if condtion instead of boolean flag + // is to avoid branching in codegen. + if (pSrcCurrent < pSrcEnd) { - Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); + Vector128 x01 = Sse.LoadVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A - pSrcCurrent += 4; - + int length = crow; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); - - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); - Sse.StoreAligned(pDstCurrent, x02); + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
+ float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); + + Sse.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 4) + { + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } - pDstCurrent += 4; - pMatCurrent += 4; + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); + + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } } pMatCurrent += 3 * crow; + pSrcCurrent += 4; } + // We do 4-way unrolling while (pSrcCurrent < pSrcEnd) { - Vector128 x01 = Sse.LoadAlignedVector128(pSrcCurrent); + Vector128 x01 = Sse.LoadVector128(pSrcCurrent); // Replicate each 32-bit slot of x01 (ABCD) into its own register. 
Vector128 x11 = Sse.Shuffle(x01, x01, 0x55); // B Vector128 x21 = Sse.Shuffle(x01, x01, 0xAA); // C Vector128 x31 = Sse.Shuffle(x01, x01, 0xFF); // D x01 = Sse.Shuffle(x01, x01, 0x00); // A + int length = crow; float* pDstCurrent = pdst; - while (pDstCurrent < pDstEnd) - { - float* pMatTemp = pMatCurrent; - - Vector128 x02 = Sse.LoadAlignedVector128(pMatTemp); - Vector128 x12 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x22 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x32 = Sse.LoadAlignedVector128(pMatTemp += crow); - Vector128 x3 = Sse.LoadAlignedVector128(pDstCurrent); + nuint address = (nuint)(pMatCurrent); + int misalignment = (int)(address % 16); - x02 = Sse.Multiply(x01, x02); - x12 = Sse.Multiply(x11, x12); - x22 = Sse.Multiply(x21, x22); - x32 = Sse.Multiply(x31, x32); + if ((misalignment & 3) != 0) + { + while (pDstCurrent < pDstEnd) + { + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); - x02 = Sse.Add(x02, x12); - x22 = Sse.Add(x22, x32); - x02 = Sse.Add(x02, x22); - x3 = Sse.Add(x02, x3); + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); - Sse.StoreAligned(pDstCurrent, x3); + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); - pDstCurrent += 4; - pMatCurrent += 4; + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } } - - pMatCurrent += 3 * crow; - pSrcCurrent += 4; - } - } - } - - // Partial sparse source vector. - public static unsafe void MatMulTranPA(bool add, AlignedArray mat, int[] rgposSrc, AlignedArray src, - int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow) - { - Contracts.Assert(HasCompatibleAlignment(mat)); - Contracts.Assert(HasCompatibleAlignment(src)); - Contracts.Assert(HasCompatibleAlignment(dst)); - - fixed (float* pSrcStart = &src.Items[0]) - fixed (float* pDstStart = &dst.Items[0]) - fixed (float* pMatStart = &mat.Items[0]) - fixed (int* pposSrc = &rgposSrc[0]) - { - float* psrc = GetAlignedBase(src, pSrcStart); - float* pdst = GetAlignedBase(dst, pDstStart); - float* pmat = GetAlignedBase(mat, pMatStart); - - int* ppos = pposSrc + iposMin; - int* pposEnd = pposSrc + iposEnd; - float* pDstEnd = pdst + crow; - - if (!add) - { - int col = *ppos - posMin; - ppos++; - - Vector128 x0 = Sse.SetAllVector128(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; - - while (pDstCurrent < pDstEnd) + else { - Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent); - x1 = Sse.Multiply(x1, x0); - Sse.StoreAligned(pDstCurrent, x1); - - pDstCurrent += 4; - pMatCurrent += 4; - } - } - - // REVIEW: Should we explore unrolling the outer loop? - while (ppos < pposEnd) - { - int col = *ppos - posMin; - - Vector128 x0 = Sse.SetAllVector128(psrc[col]); - float* pDstCurrent = pdst; - float* pMatCurrent = pmat + col * crow; + int remainder = 0; + if (misalignment != 0) + { + // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then + // masking any elements that will be included in the first aligned read + misalignment >>= 2; + misalignment = 4 - misalignment; + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
+ float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(leadingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + ((4 - misalignment) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Or(x02, Sse.And(x3, trailingMask)); + + x02 = Sse.Add(x02, Sse.And(x3, leadingMask)); + + Sse.Store(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + if (length > 4) + { + remainder = length % 4; + while (pDstCurrent + 4 <= pDstEnd) + { + float* pMatTemp = pMatCurrent; + + Vector128 x02 = Sse.Multiply(x01, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.Multiply(x11, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.Multiply(x21, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.Multiply(x31, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + x02 = Sse.Add(x02, Sse.LoadVector128(pDstCurrent)); + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } - while (pDstCurrent < pDstEnd) - { - Vector128 x1 = Sse.LoadAlignedVector128(pMatCurrent); - Vector128 x2 = Sse.LoadAlignedVector128(pDstCurrent); - x1 = Sse.Multiply(x1, x0); - x2 = Sse.Add(x2, x1); - Sse.StoreAligned(pDstCurrent, x2); - - pDstCurrent += 4; - pMatCurrent += 4; + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + Vector128 trailingMask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4)); + + float* pMatTemp = pMatCurrent; + Vector128 x02 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp)); + Vector128 x12 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x22 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + Vector128 x32 = Sse.And(trailingMask, Sse.LoadVector128(pMatTemp += crow)); + + x02 = Sse.Multiply(x01, x02); + x12 = Sse.Multiply(x11, x12); + x22 = Sse.Multiply(x21, x22); + x32 = Sse.Multiply(x31, x32); + + x02 = Sse.Add(x02, x12); + x22 = Sse.Add(x22, x32); + x02 = Sse.Add(x02, x22); + + Vector128 leadingMask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + ((4 - remainder) * 4)); + Vector128 x3 = Sse.LoadVector128(pDstCurrent); + x02 = Sse.Or(x02, Sse.And(x3, leadingMask)); + + x02 = Sse.Add(x02, Sse.And(x3, trailingMask)); + Sse.Store(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } } - ppos++; + pMatCurrent += 3 * crow; + pSrcCurrent += 4; } } } diff --git a/src/Microsoft.ML.CpuMath/Thunk.cs b/src/Microsoft.ML.CpuMath/Thunk.cs index 23192fc277..22938bb0d7 100644 --- a/src/Microsoft.ML.CpuMath/Thunk.cs +++ b/src/Microsoft.ML.CpuMath/Thunk.cs @@ -16,16 +16,16 @@ internal static unsafe class Thunk public static extern bool ChkAvx(); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulA(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMul(/*const*/ float* pmat, /*const*/ 
float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulX(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMulX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulPA(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulPA(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulPX(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, + public static extern void MatMulPX(/*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, int posMin, int iposMin, int iposLim, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] @@ -65,16 +65,9 @@ public static extern void RespNormU(bool add, float alpha, float beta, bool avgO // and columns from that perspective. Alternatively, crow is the number of rows in the transpose of pmat // (thought of as row-major order). [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranA(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); + public static extern void MatMulTran(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranX(bool add, /*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); - - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranPA(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, - int posMin, int iposMin, int iposLim, float* pdst, int crow); - [DllImport(NativePath), SuppressUnmanagedCodeSecurity] - public static extern void MatMulTranPX(bool add, /*const*/ float* pmat, /*const*/ int* pposSrc, /*const*/ float* psrc, - int posMin, int iposMin, int iposLim, float* pdst, int crow); + public static extern void MatMulTranX(/*const*/ float* pmat, /*const*/ float* psrc, float* pdst, int crow, int ccol); [DllImport(NativePath), SuppressUnmanagedCodeSecurity] public static extern void MatMulTranRU(bool add, /*const*/ int* pstarts, /*const*/ int* pindices, /*const*/ float* pcoefs, diff --git a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs index 7f0fe543f6..5c01472ed9 100644 --- a/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs +++ b/src/Microsoft.ML.TimeSeries/AdaptiveSingularSpectrumSequenceModeler.cs @@ -1130,10 +1130,10 @@ public void Consume(ref Single input, bool updateModel = false) _x[_windowSize - 1] = input; // Computing y: Eq. 
(11) in https://hal-institut-mines-telecom.archives-ouvertes.fr/hal-00479772/file/twocolumns.pdf - CpuAligenedMathUtils.MatTimesSrc(false, _wTrans, _x, _y); + CpuAligenedMathUtils.MatTimesSrc(_wTrans, _x, _y); // Updating the state vector - CpuAligenedMathUtils.MatTranTimesSrc(false, _wTrans, _y, _xSmooth); + CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); _nextPrediction = _autoregressionNoiseMean + _observationNoiseMean; for (i = 0; i < _windowSize - 2; ++i) @@ -1337,7 +1337,7 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) nu += _y[i] * _y[i]; } - CpuAligenedMathUtils.MatTranTimesSrc(false, _wTrans, _y, _xSmooth); + CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); for (i = 0; i < _windowSize - 1; ++i) _alpha[i] = _xSmooth[i] / (1 - nu); @@ -1382,8 +1382,8 @@ private void TrainCore(Single[] dataArray, int originalSeriesLength) _x[i - originalSeriesLength + _windowSize] = dataArray[i]; } - CpuAligenedMathUtils.MatTimesSrc(false, _wTrans, _x, _y); - CpuAligenedMathUtils.MatTranTimesSrc(false, _wTrans, _y, _xSmooth); + CpuAligenedMathUtils.MatTimesSrc(_wTrans, _x, _y); + CpuAligenedMathUtils.MatTranTimesSrc(_wTrans, _y, _xSmooth); for (i = 1; i < _windowSize; ++i) { diff --git a/src/Microsoft.ML.Transforms/RffTransform.cs b/src/Microsoft.ML.Transforms/RffTransform.cs index 218cc30b56..0dd2c95ffd 100644 --- a/src/Microsoft.ML.Transforms/RffTransform.cs +++ b/src/Microsoft.ML.Transforms/RffTransform.cs @@ -610,7 +610,7 @@ private void TransformFeatures(ref VBuffer src, ref VBuffer dst, T if (src.IsDense) { featuresAligned.CopyFrom(src.Values, 0, src.Length); - CpuMathUtils.MatTimesSrc(false, false, transformInfo.RndFourierVectors, featuresAligned, productAligned, + CpuMathUtils.MatTimesSrc(false, transformInfo.RndFourierVectors, featuresAligned, productAligned, transformInfo.NewDim); } else @@ -618,7 +618,7 @@ private void TransformFeatures(ref VBuffer src, ref VBuffer dst, T // This overload of MatTimesSrc ignores the values in slots that are not in src.Indices, so there is // no need to zero them out. featuresAligned.CopyFrom(src.Indices, src.Values, 0, 0, src.Count, zeroItems: false); - CpuMathUtils.MatTimesSrc(false, false, transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0, + CpuMathUtils.MatTimesSrc(transformInfo.RndFourierVectors, src.Indices, featuresAligned, 0, 0, src.Count, productAligned, transformInfo.NewDim); } diff --git a/src/Native/CpuMathNative/Avx.cpp b/src/Native/CpuMathNative/Avx.cpp index fd2e78ed5a..b52e28cd80 100644 --- a/src/Native/CpuMathNative/Avx.cpp +++ b/src/Native/CpuMathNative/Avx.cpp @@ -394,46 +394,6 @@ EXPORT_API(void) MatMulTranX(bool add, _In_ const float * pmat, _In_ const float _vleave(); } -// Partial sparse source vector. -EXPORT_API(void) MatMulTranPX(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, - int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow) -{ - const int * ppos = pposSrc + iposMin; - const int * pposLim = pposSrc + iposLim; - const float * pdLim = pdst + crow; - - if (!add) - { - int col = *ppos++ - posMin; - const float * pm = pmat + col * crow; - __m256 x0 = _mm256_set1_ps(psrc[col]); - for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8) - { - __m256 x1 = _mm256_load_ps(pm); - x1 = _mm256_mul_ps(x1, x0); - _mm256_store_ps(pd, x1); - } - } - - // REVIEW: Should we explore unrolling the outer loop? 
- for (; ppos < pposLim; ppos++) - { - int col = *ppos - posMin; - __m256 x0 = _mm256_set1_ps(psrc[col]); - const float * pm = pmat + col * crow; - for (float * pd = pdst; pd < pdLim; pd += 8, pm += 8) - { - __m256 x1 = _mm256_load_ps(pm); - __m256 x2 = _mm256_load_ps(pd); - x1 = _mm256_mul_ps(x1, x0); - x2 = _mm256_add_ps(x2, x1); - _mm256_store_ps(pd, x2); - } - } - - _vleave(); -} - // Sparse matrix. EXPORT_API(void) MatMulTranRX(bool add, _In_ const int * pstarts, _In_ const int * pindices, _In_ const float * pcoefs, _In_ const float * psrc, _Inout_ float * pd, int crow, int ccol) diff --git a/src/Native/CpuMathNative/Sse.cpp b/src/Native/CpuMathNative/Sse.cpp index 4a2d30e979..9fde43fb12 100644 --- a/src/Native/CpuMathNative/Sse.cpp +++ b/src/Native/CpuMathNative/Sse.cpp @@ -122,33 +122,124 @@ EXPORT_API(bool) ChkAvx() } // Multiply matrix times vector into vector. -EXPORT_API(void) MatMulA(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) MatMul(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { - const float * psLim = psrc + ccol; - const float * pdLim = pdst + crow; - const float * pm = pmat; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 3 * ccol) + const float * pSrcEnd = psrc + ccol; + const float * pDstEnd = pdst + crow; + float* pDstCurrent = pdst; + const float* pMatCurrent = pmat; + + while (pDstCurrent < pDstEnd) { __m128 res0 = _mm_setzero_ps(); __m128 res1 = res0; __m128 res2 = res0; __m128 res3 = res0; - for (const float * ps = psrc; ps < psLim; ps += 4, pm += 4) - { - const float * pmTmp; - __m128 x01 = _mm_load_ps(pmTmp = pm); - __m128 x11 = _mm_load_ps(pmTmp += ccol); - __m128 x21 = _mm_load_ps(pmTmp += ccol); - __m128 x31 = _mm_load_ps(pmTmp += ccol); - __m128 x02 = _mm_load_ps(ps); - x01 = _mm_mul_ps(x01, x02); - x11 = _mm_mul_ps(x11, x02); - x21 = _mm_mul_ps(x21, x02); - x31 = _mm_mul_ps(x31, x02); - res0 = _mm_add_ps(res0, x01); - res1 = _mm_add_ps(res1, x11); - res2 = _mm_add_ps(res2, x21); - res3 = _mm_add_ps(res3, x31); + + int length = ccol; + const float* pSrcCurrent = psrc; + + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) + { + while (pSrcCurrent < pSrcEnd) + { + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_loadu_ps(pMatTemp += ccol)); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
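+                    // The leading mask keeps only the first 'misalignment' floats, i.e. the columns
+                    // of pMat that sit before the next 16-byte boundary. Those columns are handled
+                    // here with unaligned loads; advancing the pointers by 'misalignment' afterwards
+                    // lets the main loop below read pMat with aligned loads.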
+ const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + + res0 = _mm_mul_ps(x01, vector); + res1 = _mm_mul_ps(x11, vector); + res2 = _mm_mul_ps(x21, vector); + res3 = _mm_mul_ps(x31, vector); + + pMatCurrent += misalignment; + pSrcCurrent += misalignment; + length -= misalignment; + } + + if (length > 3) + { + remainder = length % 4; + while(pSrcCurrent < pSrcEnd) + { + __m128 vector = _mm_loadu_ps(pSrcCurrent); + + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp)); + __m128 x11 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x21 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + __m128 x31 = _mm_mul_ps(vector, _mm_load_ps(pMatTemp += ccol)); + + res0 = _mm_add_ps(res0, x01); + res1 = _mm_add_ps(res1, x11); + res2 = _mm_add_ps(res2, x21); + res3 = _mm_add_ps(res3, x31); + + pSrcCurrent += 4; + pMatCurrent += 4; + } + } + else + { + remainder = length; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pSrcCurrent -= (4 - remainder); + + __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x01 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp)); + __m128 x11 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x21 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 x31 = _mm_and_ps(mask, _mm_loadu_ps(pMatTemp += ccol)); + __m128 vector = _mm_and_ps(mask, _mm_loadu_ps(pSrcCurrent)); + + res0 = _mm_add_ps(x01, _mm_mul_ps(x01, vector)); + res1 = _mm_add_ps(x11, _mm_mul_ps(x11, vector)); + res2 = _mm_add_ps(x21, _mm_mul_ps(x21, vector)); + res3 = _mm_add_ps(x31, _mm_mul_ps(x31, vector)); + + pMatCurrent += 4; + pSrcCurrent += 4; + } } // Add up the entries of each, with the 4 results in res0 @@ -156,14 +247,15 @@ EXPORT_API(void) MatMulA(bool add, _In_ const float * pmat, _In_ const float * p res2 = _mm_hadd_ps(res2, res3); res0 = _mm_hadd_ps(res0, res2); - if (add) - res0 = _mm_add_ps(res0, _mm_load_ps(pd)); - _mm_store_ps(pd, res0); + _mm_storeu_ps(pDstCurrent, res0); + + pDstCurrent += 4; + pMatCurrent += 3 * ccol; } } // Partial sparse source vector. 
-EXPORT_API(void) MatMulPA(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, +EXPORT_API(void) MatMulPA(_In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow, int ccol) { // REVIEW: For extremely sparse inputs, interchanging the loops would @@ -188,8 +280,6 @@ EXPORT_API(void) MatMulPA(bool add, _In_ const float * pmat, _In_ const int * pp res = _mm_add_ps(res, x2); } - if (add) - res = _mm_add_ps(res, _mm_load_ps(pd)); _mm_store_ps(pd, res); } } @@ -495,108 +585,284 @@ EXPORT_API(void) RespNormU(bool add, float alpha, float beta, bool avgOverFullKe } } -EXPORT_API(void) MatMulTranA(bool add, _In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) +EXPORT_API(void) MatMulTran(_In_ const float * pmat, _In_ const float * psrc, _Inout_ float * pdst, int crow, int ccol) { - const float * psLim = psrc + ccol; - const float * pdLim = pdst + crow; - const float * pm = pmat; - const float * ps = psrc; + const float * pSrcEnd = psrc + ccol; + const float * pDstEnd = pdst + crow; + + const float* pMatCurrent = pmat; + const float* pSrcCurrent = psrc; - if (!add) + if (pSrcCurrent < pSrcEnd) { - __m128 x01 = _mm_load_ps(ps); + __m128 x01 = _mm_loadu_ps(pSrcCurrent); // Replicate each slot of x01 into its own register. __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); - ps += 4; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) + + int length = crow; + float* pDstCurrent = pdst; + + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; + + if ((misalignment & 3) != 0) { - const float * pmTmp; - __m128 x02 = _mm_load_ps(pmTmp = pm); - __m128 x12 = _mm_load_ps(pmTmp += crow); - __m128 x22 = _mm_load_ps(pmTmp += crow); - __m128 x32 = _mm_load_ps(pmTmp += crow); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - _mm_store_ps(pd, x02); + while (pDstCurrent < pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + _mm_storeu_ps(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } } + else + { + int remainder = 0; + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
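+                    // Masked handling of the first 'misalignment' destination elements so that the
+                    // loop below can use aligned loads on pMat. The unused upper lanes of the store
+                    // are refilled from the existing dst values (trailing mask) and are then
+                    // overwritten by the following aligned iterations.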
+ const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (( 4 - misalignment) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } - pm += 3 * crow; + if(length > 3) + { + remainder = length % 4; + while (pDstCurrent < pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + length = remainder; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; + } + } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; } - for (; ps < psLim; ps += 4) + while (pSrcCurrent < pSrcEnd) { - __m128 x01 = _mm_load_ps(ps); + __m128 x01 = _mm_loadu_ps(pSrcCurrent); // Replicate each slot of x01 into its own register. __m128 x11 = _mm_shuffle_ps(x01, x01, 0x55); __m128 x21 = _mm_shuffle_ps(x01, x01, 0xAA); __m128 x31 = _mm_shuffle_ps(x01, x01, 0xFF); x01 = _mm_shuffle_ps(x01, x01, 0x00); - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) - { - const float * pmTmp; - __m128 x02 = _mm_load_ps(pmTmp = pm); - __m128 x12 = _mm_load_ps(pmTmp += crow); - __m128 x22 = _mm_load_ps(pmTmp += crow); - __m128 x32 = _mm_load_ps(pmTmp += crow); - __m128 x3 = _mm_load_ps(pd); - x02 = _mm_mul_ps(x01, x02); - x12 = _mm_mul_ps(x11, x12); - x22 = _mm_mul_ps(x21, x22); - x32 = _mm_mul_ps(x31, x32); - x02 = _mm_add_ps(x02, x12); - x22 = _mm_add_ps(x22, x32); - x02 = _mm_add_ps(x02, x22); - x3 = _mm_add_ps(x02, x3); - _mm_store_ps(pd, x3); - } - pm += 3 * crow; - } -} + int length = crow; + float* pDstCurrent = pdst; -// Partial sparse source vector. 
-EXPORT_API(void) MatMulTranPA(bool add, _In_ const float * pmat, _In_ const int * pposSrc, _In_ const float * psrc, - int posMin, int iposMin, int iposLim, _Inout_ float * pdst, int crow) -{ - const int * ppos = pposSrc + iposMin; - const int * pposLim = pposSrc + iposLim; - const float * pdLim = pdst + crow; + uintptr_t address = (uintptr_t)(pMatCurrent); + uintptr_t misalignment = address % 16; + int remainder = 0; - if (!add) - { - int col = *ppos++ - posMin; - const float * pm = pmat + col * crow; - __m128 x0 = _mm_set1_ps(psrc[col]); - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) + if ((misalignment & 3) != 0) { - __m128 x1 = _mm_load_ps(pm); - x1 = _mm_mul_ps(x1, x0); - _mm_store_ps(pd, x1); - } - } + while (pDstCurrent < pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_loadu_ps(pMatTemp += crow)); - // REVIEW: Should we explore unrolling the outer loop? - for (; ppos < pposLim; ppos++) - { - int col = *ppos - posMin; - __m128 x0 = _mm_set1_ps(psrc[col]); - const float * pm = pmat + col * crow; - for (float * pd = pdst; pd < pdLim; pd += 4, pm += 4) + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else { - __m128 x1 = _mm_load_ps(pm); - __m128 x2 = _mm_load_ps(pd); - x1 = _mm_mul_ps(x1, x0); - x2 = _mm_add_ps(x2, x1); - _mm_store_ps(pd, x2); + int remainder = 0; + if (misalignment != 0) + { + misalignment >>= 2; + misalignment = 4 - misalignment; + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4)); + + // We only align pMat since it has significantly more reads. 
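+                    // Same leading-element handling as the first pass, except that this pass
+                    // accumulates into dst: the leading lanes add in the existing dst values
+                    // (leading mask), while the trailing lanes are carried through unchanged
+                    // (trailing mask) for the next aligned iteration.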
+ const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(leadingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (( 4 - misalignment) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, trailingMask)); + x02 = _mm_add_ps(x02, _mm_and_ps(x3, leadingMask)); + + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += misalignment; + pDstCurrent += misalignment; + length -= misalignment; + } + + if(length > 3) + { + remainder = length % 4; + while (pDstCurrent < pDstEnd) + { + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_mul_ps(x01, _mm_load_ps(pMatTemp)); + __m128 x12 = _mm_mul_ps(x11, _mm_load_ps(pMatTemp += crow)); + __m128 x22 = _mm_mul_ps(x21, _mm_load_ps(pMatTemp += crow)); + __m128 x32 = _mm_mul_ps(x31, _mm_load_ps(pMatTemp += crow)); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + x02 = _mm_add_ps(x02, _mm_loadu_ps(pDstCurrent)); + + _mm_storeu_ps(pDstCurrent, x02); + + pDstCurrent += 4; + pMatCurrent += 4; + } + } + else + { + length = remainder; + } + + if (remainder != 0) + { + pMatCurrent -= (4 - remainder); + pDstCurrent -= (4 - remainder); + + __m128 trailingMask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4)); + + const float* pMatTemp = pMatCurrent; + __m128 x02 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp)); + __m128 x12 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x22 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + __m128 x32 = _mm_and_ps(trailingMask, _mm_loadu_ps(pMatTemp += crow)); + + x02 = _mm_mul_ps(x01, x02); + x12 = _mm_mul_ps(x11, x12); + x22 = _mm_mul_ps(x21, x22); + x32 = _mm_mul_ps(x31, x32); + + x02 = _mm_add_ps(x02, x12); + x22 = _mm_add_ps(x22, x32); + x02 = _mm_add_ps(x02, x22); + + __m128 leadingMask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (( 4 - remainder) * 4)); + __m128 x3 = _mm_loadu_ps(pDstCurrent); + x02 = _mm_or_ps(x02, _mm_and_ps(x3, leadingMask)); + + x02 = _mm_add_ps(x02, _mm_and_ps(x3, trailingMask)); + _mm_storeu_ps(pDstCurrent, x02); + pMatCurrent += 4; + pDstCurrent += 4; + } } + + pMatCurrent += 3 * crow; + pSrcCurrent += 4; } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs index b9deabd455..16b04ba6bd 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/AvxPerformanceTests.cs @@ -98,5 +98,12 @@ public void SdcaL1UpdateU() [Benchmark] public void SdcaL1UpdateSU() => AvxIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result); + [Benchmark] + public void MatMulX() + => AvxIntrinsics.MatMulX(src, src1, dst, 1000, 1000); + + [Benchmark] + public void MatMulTranX() + => AvxIntrinsics.MatMulTranX(src, src1, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs index 
3cce45046c..07958c19ed 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/NativePerformanceTests.cs @@ -228,5 +228,27 @@ public unsafe void SdcaL1UpdateSU() CpuMathNativeUtils.SdcaL1UpdateSU(DefaultScale, psrc, pidx, DefaultScale, pdst, pres, IndexLength); } } + + [Benchmark] + public unsafe void MatMul() + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* psrc1 = &src1[0]) + { + Thunk.MatMul(psrc1, psrc, pdst, 1000, 1000); + } + } + + [Benchmark] + public unsafe void MatMulTran() + { + fixed (float* psrc = &src[0]) + fixed (float* pdst = &dst[0]) + fixed (float* psrc1 = &src1[0]) + { + Thunk.MatMulTran(psrc1, psrc, pdst, 1000, 1000); + } + } } } diff --git a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs index c10ad936c2..8e94dabb96 100644 --- a/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs +++ b/test/Microsoft.ML.CpuMath.PerformanceTests/SsePerformanceTests.cs @@ -98,5 +98,13 @@ public void SdcaL1UpdateU() [Benchmark] public void SdcaL1UpdateSU() => SseIntrinsics.SdcaL1UpdateSU(DefaultScale, IndexLength, src, idx, DefaultScale, dst, result); + + [Benchmark] + public void MatMulX() + => SseIntrinsics.MatMul(src, src1, dst, 1000, 1000); + + [Benchmark] + public void MatMulTranX() + => SseIntrinsics.MatMulTran(src, src1, dst, 1000, 1000); } } diff --git a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs index 7c92f747e2..8ce7878f83 100644 --- a/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs +++ b/test/Microsoft.ML.CpuMath.UnitTests.netcoreapp/UnitTests.cs @@ -84,29 +84,13 @@ public CpuMathUtilsUnitTests() [InlineData(0, 0, 0, new float[] { -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f, -416.6801f })] [InlineData(1, 1, 0, new float[] { 1496f, 3672f, 5848f, 8024f, 10200f, 12376f, 14552f, 16728f })] [InlineData(1, 0, 1, new float[] { 204f, 492f, 780f, 1068f, 1356f, 1644f, 1932f, 2220f, 2508f, 2796f, 3084f, 3372f, 3660f, 3948f, 4236f, 4524f })] - public void MatMulATest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatMulTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = _testSrcVectors[srcTest]; AlignedArray dst = _testDstVectors[dstTest]; - CpuMathUtils.MatTimesSrc(false, false, mat, src, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { -416.6801f, -415.6801f, -414.6801f, -413.6801f, -412.6801f, -411.6801f, -410.6801f, -409.6801f })] - [InlineData(1, 1, 0, new float[] { 1496f, 3673f, 5850f, 8027f, 10204f, 12381f, 14558f, 16735f })] - [InlineData(1, 0, 1, new float[] { 204f, 493f, 782f, 1071f, 1360f, 1649f, 1938f, 2227f, 2516f, 2805f, 3094f, 3383f, 3672f, 3961f, 4250f, 4539f })] - public void MatMulAAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - - CpuMathUtils.MatTimesSrc(false, true, mat, src, dst, dst.Size); + CpuMathUtils.MatTimesSrc(false, mat, src, dst, dst.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); 
Assert.Equal(expected, actual, _matMulComparer); @@ -116,29 +100,13 @@ public void MatMulAAddTest(int matTest, int srcTest, int dstTest, float[] expect [InlineData(0, 0, 0, new float[] { 70.56001f, -85.68f, -351.36f, 498.24f, -3829.32f, -969.48f, 1168.2f, 118.44f })] [InlineData(1, 0, 1, new float[] { 2724f, 2760f, 2796f, 2832f, 2868f, 2904f, 2940f, 2976f, 3012f, 3048f, 3084f, 3120f, 3156f, 3192f, 3228f, 3264f })] [InlineData(1, 1, 0, new float[] { 11016f, 11152f, 11288f, 11424f, 11560f, 11696f, 11832f, 11968f })] - public void MatMulTranATest(int matTest, int srcTest, int dstTest, float[] expected) + public void MatMulTranTest(int matTest, int srcTest, int dstTest, float[] expected) { AlignedArray mat = _testMatrices[matTest]; AlignedArray src = _testSrcVectors[srcTest]; AlignedArray dst = _testDstVectors[dstTest]; - CpuMathUtils.MatTimesSrc(true, false, mat, src, dst, src.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 70.56001f, -84.68f, -349.36f, 501.24f, -3825.32f, -964.48f, 1174.2f, 125.44f })] - [InlineData(1, 0, 1, new float[] { 2724f, 2761f, 2798f, 2835f, 2872f, 2909f, 2946f, 2983f, 3020f, 3057f, 3094f, 3131f, 3168f, 3205f, 3242f, 3279f })] - [InlineData(1, 1, 0, new float[] { 11016f, 11153f, 11290f, 11427f, 11564f, 11701f, 11838f, 11975f })] - public void MatMulTranAAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - - CpuMathUtils.MatTimesSrc(true, true, mat, src, dst, src.Size); + CpuMathUtils.MatTimesSrc(true, mat, src, dst, src.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer); @@ -155,58 +123,7 @@ public void MatMulPATest(int matTest, int srcTest, int dstTest, float[] expected AlignedArray dst = _testDstVectors[dstTest]; int[] idx = _testIndexArray; - CpuMathUtils.MatTimesSrc(false, false, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 38.25002f, 39.25002f, 40.25002f, 41.25002f, 42.25002f, 43.25002f, 44.25002f, 45.25002f })] - [InlineData(1, 1, 0, new float[] { 910f, 2191f, 3472f, 4753f, 6034f, 7315f, 8596f, 9877f })] - [InlineData(1, 0, 1, new float[] { 95f, 232f, 369f, 506f, 643f, 780f, 917f, 1054f, 1191f, 1328f, 1465f, 1602f, 1739f, 1876f, 2013f, 2150f })] - public void MatMulPAAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - int[] idx = _testIndexArray; - - CpuMathUtils.MatTimesSrc(false, true, mat, idx, src, 0, 0, (srcTest == 0) ? 
4 : 9, dst, dst.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 33.32f, -40.46f, -165.92f, 235.28f, -1808.29f, -457.81f, 551.65f, 55.93f })] - [InlineData(1, 0, 1, new float[] { 1265f, 1282f, 1299f, 1316f, 1333f, 1350f, 1367f, 1384f, 1401f, 1418f, 1435f, 1452f, 1469f, 1486f, 1503f, 1520f })] - [InlineData(1, 1, 0, new float[] { 6720f, 6800f, 6880f, 6960f, 7040f, 7120f, 7200f, 7280f })] - public void MatMulTranPATest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - int[] idx = _testIndexArray; - - CpuMathUtils.MatTimesSrc(true, false, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, src.Size); - float[] actual = new float[dst.Size]; - dst.CopyTo(actual, 0, dst.Size); - Assert.Equal(expected, actual, _matMulComparer); - } - - [Theory] - [InlineData(0, 0, 0, new float[] { 33.32f, -39.46f, -163.92f, 238.28f, -1804.29f, -452.81f, 557.65f, 62.93f })] - [InlineData(1, 0, 1, new float[] { 1265f, 1283f, 1301f, 1319f, 1337f, 1355f, 1373f, 1391f, 1409f, 1427f, 1445f, 1463f, 1481f, 1499f, 1517f, 1535f })] - [InlineData(1, 1, 0, new float[] { 6720f, 6801f, 6882f, 6963f, 7044f, 7125f, 7206f, 7287f })] - public void MatMulTranPAAddTest(int matTest, int srcTest, int dstTest, float[] expected) - { - AlignedArray mat = _testMatrices[matTest]; - AlignedArray src = _testSrcVectors[srcTest]; - AlignedArray dst = _testDstVectors[dstTest]; - int[] idx = _testIndexArray; - - CpuMathUtils.MatTimesSrc(true, true, mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, src.Size); + CpuMathUtils.MatTimesSrc(mat, idx, src, 0, 0, (srcTest == 0) ? 4 : 9, dst, dst.Size); float[] actual = new float[dst.Size]; dst.CopyTo(actual, 0, dst.Size); Assert.Equal(expected, actual, _matMulComparer);