
Commit 263a67b

Same implementation for Sparse Multiplication for aligned and unaligned arrays (#1274)
* Sparse vector corrected
* Removing dead code, correcting names, adding assert checks in the correct place, adding span overloads and a function for common code
* Fixing build on Unix
* CMake file corrected, ifdef removed from sse.cpp, and unit test name modified
* Performance test corrected, merge conflicts resolved, FMA support added
1 parent 00021b6 commit 263a67b

File tree

14 files changed: +543 −5913 lines changed
src/Microsoft.ML.CpuMath/Avx.cs

-1,165
This file was deleted.

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

+136 −56
@@ -46,25 +46,6 @@ internal static class AvxIntrinsics
 
         private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));
 
-        private const int Vector256Alignment = 32;
-
-        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
-        private static bool HasCompatibleAlignment(AlignedArray alignedArray)
-        {
-            Contracts.AssertValue(alignedArray);
-            Contracts.Assert(alignedArray.Size > 0);
-            return (alignedArray.CbAlign % Vector256Alignment) == 0;
-        }
-
-        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
-        private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
-        {
-            Contracts.AssertValue(alignedArray);
-            float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
-            Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
-            return alignedBase;
-        }
-
         [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
         private static Vector128<float> GetHigh(in Vector256<float> x)
             => Avx.ExtractVector128(x, 1);
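Side note: the two deleted helpers existed only to prove that an AlignedArray really sat on a 32-byte boundary before aligned loads and stores were issued. A one-line sketch of the invariant they checked (the helper name here is illustrative, not from the commit):

    // Illustrative sketch: "AVX-aligned" means the address is a multiple of
    // 32 bytes. The rewritten code no longer needs this guarantee, because
    // the leading/trailing mask loads below handle any alignment.
    static unsafe bool IsAvxAligned(float* p) => ((nuint)p % 32) == 0;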
@@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
         }
 
         // Multiply matrix times vector into vector.
-        public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+        public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
         {
-            Contracts.Assert(crow % 4 == 0);
-            Contracts.Assert(ccol % 4 == 0);
-
-            MatMulX(mat.Items, src.Items, dst.Items, crow, ccol);
+            MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
         }
 
-        public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+        public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
         {
-            fixed (float* psrc = &src[0])
-            fixed (float* pdst = &dst[0])
-            fixed (float* pmat = &mat[0])
+            Contracts.Assert(crow % 4 == 0);
+            Contracts.Assert(ccol % 4 == 0);
+
+            fixed (float* psrc = &MemoryMarshal.GetReference(src))
+            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+            fixed (float* pmat = &MemoryMarshal.GetReference(mat))
             fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
             fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
             {
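For reference, `fixed (float* p = &MemoryMarshal.GetReference(span))` is the standard way to pin a span and obtain a raw pointer, and it is what lets one span-based body serve both the aligned and unaligned callers. A minimal self-contained sketch of the pattern (type and method names are illustrative, not from the commit):

using System;
using System.Runtime.InteropServices;

internal static class SpanPinningSketch
{
    // Pin the span, then walk it through a raw pointer, exactly the pattern
    // the new MatMul span overload uses above.
    public static unsafe float Sum(ReadOnlySpan<float> values)
    {
        float sum = 0;
        fixed (float* p = &MemoryMarshal.GetReference(values))
        {
            for (int i = 0; i < values.Length; i++)
                sum += p[i];
        }
        return sum;
    }
}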
@@ -312,32 +293,134 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
         }
 
         // Partial sparse source vector.
-        public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
-            int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
+        public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
+            int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
         {
-            Contracts.Assert(HasCompatibleAlignment(mat));
-            Contracts.Assert(HasCompatibleAlignment(src));
-            Contracts.Assert(HasCompatibleAlignment(dst));
+            MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
+        }
+
+        public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
+            int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
+        {
+            Contracts.Assert(crow % 8 == 0);
+            Contracts.Assert(ccol % 8 == 0);
 
             // REVIEW: For extremely sparse inputs, interchanging the loops would
             // likely be more efficient.
-            fixed (float* pSrcStart = &src.Items[0])
-            fixed (float* pDstStart = &dst.Items[0])
-            fixed (float* pMatStart = &mat.Items[0])
-            fixed (int* pposSrc = &rgposSrc[0])
+            fixed (float* psrc = &MemoryMarshal.GetReference(src))
+            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+            fixed (float* pmat = &MemoryMarshal.GetReference(mat))
+            fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
+            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
             {
-                float* psrc = GetAlignedBase(src, pSrcStart);
-                float* pdst = GetAlignedBase(dst, pDstStart);
-                float* pmat = GetAlignedBase(mat, pMatStart);
-
                 int* pposMin = pposSrc + iposMin;
                 int* pposEnd = pposSrc + iposEnd;
                 float* pDstEnd = pdst + crow;
                 float* pm0 = pmat - posMin;
                 float* pSrcCurrent = psrc - posMin;
                 float* pDstCurrent = pdst;
 
-                while (pDstCurrent < pDstEnd)
+                nuint address = (nuint)(pDstCurrent);
+                int misalignment = (int)(address % 32);
+                int length = crow;
+                int remainder = 0;
+
+                if ((misalignment & 3) != 0)
+                {
+                    while (pDstCurrent < pDstEnd)
+                    {
+                        Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+                        pDstCurrent += 8;
+                        pm0 += 8 * ccol;
+                    }
+                }
+                else
+                {
+                    if (misalignment != 0)
+                    {
+                        misalignment >>= 2;
+                        misalignment = 8 - misalignment;
+
+                        Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+
+                        float* pm1 = pm0 + ccol;
+                        float* pm2 = pm1 + ccol;
+                        float* pm3 = pm2 + ccol;
+                        Vector256<float> result = Avx.SetZeroVector256<float>();
+
+                        int* ppos = pposMin;
+
+                        while (ppos < pposEnd)
+                        {
+                            int col1 = *ppos;
+                            int col2 = col1 + 4 * ccol;
+                            Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+                                                                    pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+
+                            x1 = Avx.And(mask, x1);
+                            Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+                            result = MultiplyAdd(x2, x1, result);
+                            ppos++;
+                        }
+
+                        Avx.Store(pDstCurrent, result);
+                        pDstCurrent += misalignment;
+                        pm0 += misalignment * ccol;
+                        length -= misalignment;
+                    }
+
+                    if (length > 7)
+                    {
+                        remainder = length % 8;
+                        while (pDstCurrent < pDstEnd)
+                        {
+                            Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+                            pDstCurrent += 8;
+                            pm0 += 8 * ccol;
+                        }
+                    }
+                    else
+                    {
+                        remainder = length;
+                    }
+
+                    if (remainder != 0)
+                    {
+                        pDstCurrent -= (8 - remainder);
+                        pm0 -= (8 - remainder) * ccol;
+                        Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+                        Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
+
+                        float* pm1 = pm0 + ccol;
+                        float* pm2 = pm1 + ccol;
+                        float* pm3 = pm2 + ccol;
+                        Vector256<float> result = Avx.SetZeroVector256<float>();
+
+                        int* ppos = pposMin;
+
+                        while (ppos < pposEnd)
+                        {
+                            int col1 = *ppos;
+                            int col2 = col1 + 4 * ccol;
+                            Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+                                                                    pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+                            x1 = Avx.And(x1, trailingMask);
+
+                            Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+                            result = MultiplyAdd(x2, x1, result);
+                            ppos++;
+                        }
+
+                        result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));
+
+                        Avx.Store(pDstCurrent, result);
+                        pDstCurrent += 8;
+                        pm0 += 8 * ccol;
+                    }
+                }
+
+                Vector256<float> SparseMultiplicationAcrossRow()
                 {
                     float* pm1 = pm0 + ccol;
                     float* pm2 = pm1 + ccol;
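The added control flow splits each pass over the `crow` output rows into three phases: a masked head that advances `pDstCurrent` to the next 32-byte boundary, a full-speed body of whole 8-float vectors, and a masked tail for what remains. A scalar sketch of that bookkeeping (hypothetical helper, not part of the commit; assumes the pointer is at least 4-byte aligned, as the `(misalignment & 3)` check above guarantees):

using System;

internal static class AlignmentSketch
{
    // How 'length' decomposes once the destination pointer's offset from a
    // 32-byte boundary is known.
    internal static unsafe (int Head, int Body, int Tail) SplitForAvx(float* pDst, int length)
    {
        int misalignedFloats = (int)((nuint)pDst % 32) / sizeof(float);
        int head = misalignedFloats == 0 ? 0 : Math.Min(8 - misalignedFloats, length);
        int body = (length - head) / 8 * 8; // whole Vector256<float> blocks
        int tail = length - head - body;    // finished with a trailing mask
        return (head, body, tail);
    }
}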
@@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
                         int col1 = *ppos;
                         int col2 = col1 + 4 * ccol;
                         Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
-                            pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+                                                                pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
                         Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
                         result = MultiplyAdd(x2, x1, result);
-
                         ppos++;
                     }
 
-                    Avx.StoreAligned(pDstCurrent, result);
-                    pDstCurrent += 8;
-                    pm0 += 8 * ccol;
+                    return result;
                 }
             }
         }
 
-        public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+        public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
         {
-            Contracts.Assert(crow % 4 == 0);
-            Contracts.Assert(ccol % 4 == 0);
-
-            MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
+            MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
         }
 
-        public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+        public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
         {
-            fixed (float* psrc = &src[0])
-            fixed (float* pdst = &dst[0])
-            fixed (float* pmat = &mat[0])
+            Contracts.Assert(crow % 4 == 0);
+            Contracts.Assert(ccol % 4 == 0);
+
+            fixed (float* psrc = &MemoryMarshal.GetReference(src))
+            fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+            fixed (float* pmat = &MemoryMarshal.GetReference(mat))
             fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
             fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
             {
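The commit message notes "FMA support added": the `MultiplyAdd` helper used throughout these loops fuses the multiply and the accumulate when the CPU offers FMA. A sketch against today's stable System.Runtime.Intrinsics surface (this commit targeted the older preview API, so the exact names differed):

using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;

internal static class FmaSketch
{
    // result = (a * b) + c, emitted as a single vfmadd instruction when FMA
    // is available, with a two-instruction AVX fallback otherwise.
    internal static Vector256<float> MultiplyAdd(Vector256<float> a, Vector256<float> b, Vector256<float> c)
        => Fma.IsSupported
            ? Fma.MultiplyAdd(a, b, c)
            : Avx.Add(Avx.Multiply(a, b), c);
}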

src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs

+4 −4
@@ -34,12 +34,12 @@ public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, Al
             if (!tran)
             {
                 Contracts.Assert(crun <= dst.Size);
-                AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size);
+                AvxIntrinsics.MatMul(mat, src, dst, crun, src.Size);
             }
             else
             {
                 Contracts.Assert(crun <= src.Size);
-                AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun);
+                AvxIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun);
             }
         }
         else if (Sse.IsSupported)
@@ -109,12 +109,12 @@ public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray sr
             if (Avx.IsSupported)
             {
                 Contracts.Assert(crun <= dst.Size);
-                AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
+                AvxIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
             }
             else if (Sse.IsSupported)
             {
                 Contracts.Assert(crun <= dst.Size);
-                SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
+                SseIntrinsics.MatMulP(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
             }
             else
             {
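One practical consequence of the renames above: since the new overloads take spans, callers inside the assembly no longer need an AlignedArray to reach them, because arrays convert to spans implicitly. A hedged usage sketch (sizes are illustrative and follow the crow/ccol divisibility contracts asserted earlier):

using System.Runtime.Intrinsics.X86;

// 8x8 row-major matrix times an 8-element vector.
float[] mat = new float[8 * 8];
float[] src = new float[8];
float[] dst = new float[8];

if (Avx.IsSupported)
    AvxIntrinsics.MatMul(mat, src, dst, crow: 8, ccol: 8);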
