Skip to content

Commit 7d9660a

Browse files
authored
Common Implementation for MatMul and MatMulTran for both aligned and unaligned arrays (#1218)
* implementation and unit tests added * added performance test for matmul and matmulTrans * load combined with math operation * add flag removed * TransPA removed as nobody uses this combination of flags * removed firstTime and corrected nativePerformanceTests * removed branch from hot path sseintrinsics
1 parent 8b19930 commit 7d9660a

File tree

16 files changed

+1221
-722
lines changed

16 files changed

+1221
-722
lines changed

src/Microsoft.ML.CpuMath/Avx.cs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public static bool CheckAvx()
3434
return Thunk.ChkAvx();
3535
}
3636

37-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
37+
public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
3838
{
3939
Contracts.Assert(Compat(mat));
4040
Contracts.Assert(Compat(src));
@@ -50,18 +50,18 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr
5050
if (!tran)
5151
{
5252
Contracts.Assert(0 <= crun && crun <= dst.Size);
53-
Thunk.MatMulX(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size);
53+
Thunk.MatMulX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size);
5454
}
5555
else
5656
{
5757
Contracts.Assert(0 <= crun && crun <= src.Size);
58-
Thunk.MatMulTranX(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun);
58+
Thunk.MatMulTranX(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun);
5959
}
6060
}
6161
}
6262
}
6363

64-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
64+
public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
6565
int posMin, int iposMin, int iposLim, AlignedArray dst, int crun)
6666
{
6767
Contracts.Assert(Compat(mat));
@@ -73,8 +73,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo
7373

7474
if (iposMin >= iposLim)
7575
{
76-
if (!add)
77-
dst.ZeroItems();
76+
dst.ZeroItems();
7877
return;
7978
}
8079
Contracts.AssertNonEmpty(rgposSrc);
@@ -85,16 +84,8 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo
8584
fixed (float* psrc = &srcValues.Items[0])
8685
fixed (int* ppossrc = &rgposSrc[0])
8786
{
88-
if (!tran)
89-
{
90-
Contracts.Assert(0 <= crun && crun <= dst.Size);
91-
Thunk.MatMulPX(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
92-
}
93-
else
94-
{
95-
Contracts.Assert(0 <= crun && crun <= srcValues.Size);
96-
Thunk.MatMulTranPX(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size);
97-
}
87+
Contracts.Assert(0 <= crun && crun <= dst.Size);
88+
Thunk.MatMulPX(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
9889
}
9990
}
10091
}

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 380 additions & 170 deletions
Large diffs are not rendered by default.

src/Microsoft.ML.CpuMath/CpuAligenedMathUtils.cs

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -75,40 +75,32 @@ public static void AssertCompatible(ICpuFullMatrix mat, ICpuVector src, ICpuVect
7575

7676
/// <summary>
7777
/// Matrix multiplication:
78-
/// if (add)
79-
/// dst = mat * src
80-
/// else
81-
/// dest += mat * src
78+
/// dst = mat * src
8279
/// </summary>
83-
/// <param name="add">The addition flag</param>
8480
/// <param name="mat">The multiplier matrix</param>
8581
/// <param name="src">The source vector</param>
8682
/// <param name="dst">The destination vector</param>
87-
public static void MatTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICpuVector dst)
83+
public static void MatTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst)
8884
{
8985
bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol);
9086
AssertCompatible(mat, src, dst);
9187
var m = A(mat);
92-
CpuMathUtils.MatTimesSrc(colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
88+
CpuMathUtils.MatTimesSrc(colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
9389
}
9490

9591
/// <summary>
9692
/// Matrix transpose multiplication:
97-
/// if (add)
98-
/// dst = mat' * src
99-
/// else
100-
/// dest += mat' * src
93+
/// dst = mat' * src
10194
/// </summary>
102-
/// <param name="add">The addition flag</param>
10395
/// <param name="mat">The multiplier matrix</param>
10496
/// <param name="src">The source vector</param>
10597
/// <param name="dst">The destination vector</param>
106-
public static void MatTranTimesSrc(bool add, ICpuFullMatrix mat, ICpuVector src, ICpuVector dst)
98+
public static void MatTranTimesSrc(ICpuFullMatrix mat, ICpuVector src, ICpuVector dst)
10799
{
108100
bool colMajor = typeof(TMatrix) == typeof(CpuAlignedMatrixCol);
109101
AssertCompatible(mat, dst, src);
110102
var m = A(mat);
111-
CpuMathUtils.MatTimesSrc(!colMajor, add, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
103+
CpuMathUtils.MatTimesSrc(!colMajor, m.Items, A(src).Items, A(dst).Items, m.RunCnt);
112104
}
113105
}
114106

src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs

Lines changed: 20 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public static partial class CpuMathUtils
2424
public static int GetVectorAlignment()
2525
=> Avx.IsSupported ? Vector256Alignment : (Sse.IsSupported ? Vector128Alignment : FloatAlignment);
2626

27-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
27+
public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
2828
{
2929
Contracts.Assert(mat.Size == dst.Size * src.Size);
3030
Contracts.Assert(crun >= 0);
@@ -34,25 +34,25 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr
3434
if (!tran)
3535
{
3636
Contracts.Assert(crun <= dst.Size);
37-
AvxIntrinsics.MatMulX(add, mat, src, dst, crun, src.Size);
37+
AvxIntrinsics.MatMulX(mat, src, dst, crun, src.Size);
3838
}
3939
else
4040
{
4141
Contracts.Assert(crun <= src.Size);
42-
AvxIntrinsics.MatMulTranX(add, mat, src, dst, dst.Size, crun);
42+
AvxIntrinsics.MatMulTranX(mat, src, dst, dst.Size, crun);
4343
}
4444
}
4545
else if (Sse.IsSupported)
4646
{
4747
if (!tran)
4848
{
4949
Contracts.Assert(crun <= dst.Size);
50-
SseIntrinsics.MatMulA(add, mat, src, dst, crun, src.Size);
50+
SseIntrinsics.MatMul(mat, src, dst, crun, src.Size);
5151
}
5252
else
5353
{
5454
Contracts.Assert(crun <= src.Size);
55-
SseIntrinsics.MatMulTranA(add, mat, src, dst, dst.Size, crun);
55+
SseIntrinsics.MatMulTran(mat, src, dst, dst.Size, crun);
5656
}
5757
}
5858
else
@@ -68,14 +68,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr
6868
dotProduct += mat[i * src.Size + j] * src[j];
6969
}
7070

71-
if (add)
72-
{
73-
dst[i] += dotProduct;
74-
}
75-
else
76-
{
77-
dst[i] = dotProduct;
78-
}
71+
dst[i] = dotProduct;
7972
}
8073
}
8174
else
@@ -89,20 +82,13 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr
8982
dotProduct += mat[j * src.Size + i] * src[j];
9083
}
9184

92-
if (add)
93-
{
94-
dst[i] += dotProduct;
95-
}
96-
else
97-
{
98-
dst[i] = dotProduct;
99-
}
85+
dst[i] = dotProduct;
10086
}
10187
}
10288
}
10389
}
10490

105-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
91+
public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
10692
int posMin, int iposMin, int iposLim, AlignedArray dst, int crun)
10793
{
10894
Contracts.AssertValue(rgposSrc);
@@ -113,8 +99,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo
11399

114100
if (iposMin >= iposLim)
115101
{
116-
if (!add)
117-
dst.ZeroItems();
102+
dst.ZeroItems();
118103
return;
119104
}
120105

@@ -123,76 +108,26 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo
123108

124109
if (Avx.IsSupported)
125110
{
126-
if (!tran)
127-
{
128-
Contracts.Assert(crun <= dst.Size);
129-
AvxIntrinsics.MatMulPX(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
130-
}
131-
else
132-
{
133-
Contracts.Assert(crun <= srcValues.Size);
134-
AvxIntrinsics.MatMulTranPX(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size);
135-
}
111+
Contracts.Assert(crun <= dst.Size);
112+
AvxIntrinsics.MatMulPX(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
136113
}
137114
else if (Sse.IsSupported)
138115
{
139-
if (!tran)
140-
{
141-
Contracts.Assert(crun <= dst.Size);
142-
SseIntrinsics.MatMulPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
143-
}
144-
else
145-
{
146-
Contracts.Assert(crun <= srcValues.Size);
147-
SseIntrinsics.MatMulTranPA(add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, dst.Size);
148-
}
116+
Contracts.Assert(crun <= dst.Size);
117+
SseIntrinsics.MatMulPA(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun, srcValues.Size);
149118
}
150119
else
151120
{
152-
if (!tran)
153-
{
154-
Contracts.Assert(crun <= dst.Size);
155-
for (int i = 0; i < crun; i++)
156-
{
157-
float dotProduct = 0;
158-
for (int j = iposMin; j < iposLim; j++)
159-
{
160-
int col = rgposSrc[j] - posMin;
161-
dotProduct += mat[i * srcValues.Size + col] * srcValues[col];
162-
}
163-
164-
if (add)
165-
{
166-
dst[i] += dotProduct;
167-
}
168-
else
169-
{
170-
dst[i] = dotProduct;
171-
}
172-
}
173-
}
174-
else
121+
Contracts.Assert(crun <= dst.Size);
122+
for (int i = 0; i < crun; i++)
175123
{
176-
Contracts.Assert(crun <= srcValues.Size);
177-
for (int i = 0; i < dst.Size; i++)
124+
float dotProduct = 0;
125+
for (int j = iposMin; j < iposLim; j++)
178126
{
179-
float dotProduct = 0;
180-
for (int j = iposMin; j < iposLim; j++)
181-
{
182-
int col = rgposSrc[j] - posMin;
183-
dotProduct += mat[col * dst.Size + i] * srcValues[col];
184-
}
185-
186-
if (add)
187-
{
188-
dst[i] += dotProduct;
189-
}
190-
else
191-
{
192-
dst[i] = dotProduct;
193-
}
127+
int col = rgposSrc[j] - posMin;
128+
dotProduct += mat[i * srcValues.Size + col] * srcValues[col];
194129
}
195-
130+
dst[i] = dotProduct;
196131
}
197132
}
198133
}

src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ public static partial class CpuMathUtils
1616
public static int GetVectorAlignment()
1717
=> Vector128Alignment;
1818

19-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, src, dst, crun);
19+
public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, mat, src, dst, crun);
2020

21-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
22-
int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(tran, add, mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun);
21+
public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
22+
int posMin, int iposMin, int iposLim, AlignedArray dst, int crun) => SseUtils.MatTimesSrc(mat, rgposSrc, srcValues, posMin, iposMin, iposLim, dst, crun);
2323

2424
public static void Add(float a, Span<float> dst) => SseUtils.Add(a, dst);
2525

src/Microsoft.ML.CpuMath/Sse.cs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ private static bool Compat(AlignedArray a)
3030
return q;
3131
}
3232

33-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
33+
public static void MatTimesSrc(bool tran, AlignedArray mat, AlignedArray src, AlignedArray dst, int crun)
3434
{
3535
Contracts.Assert(Compat(mat));
3636
Contracts.Assert(Compat(src));
@@ -46,18 +46,18 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, AlignedArr
4646
if (!tran)
4747
{
4848
Contracts.Assert(0 <= crun && crun <= dst.Size);
49-
Thunk.MatMulA(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size);
49+
Thunk.MatMul(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), crun, src.Size);
5050
}
5151
else
5252
{
5353
Contracts.Assert(0 <= crun && crun <= src.Size);
54-
Thunk.MatMulTranA(add, Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun);
54+
Thunk.MatMulTran(Ptr(mat, pmat), Ptr(src, psrc), Ptr(dst, pdst), dst.Size, crun);
5555
}
5656
}
5757
}
5858
}
5959

60-
public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
60+
public static void MatTimesSrc(AlignedArray mat, int[] rgposSrc, AlignedArray srcValues,
6161
int posMin, int iposMin, int iposLim, AlignedArray dst, int crun)
6262
{
6363
Contracts.Assert(Compat(mat));
@@ -69,8 +69,7 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo
6969

7070
if (iposMin >= iposLim)
7171
{
72-
if (!add)
73-
dst.ZeroItems();
72+
dst.ZeroItems();
7473
return;
7574
}
7675
Contracts.AssertNonEmpty(rgposSrc);
@@ -81,16 +80,8 @@ public static void MatTimesSrc(bool tran, bool add, AlignedArray mat, int[] rgpo
8180
fixed (float* psrc = &srcValues.Items[0])
8281
fixed (int* ppossrc = &rgposSrc[0])
8382
{
84-
if (!tran)
85-
{
86-
Contracts.Assert(0 <= crun && crun <= dst.Size);
87-
Thunk.MatMulPA(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
88-
}
89-
else
90-
{
91-
Contracts.Assert(0 <= crun && crun <= srcValues.Size);
92-
Thunk.MatMulTranPA(add, Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), dst.Size);
93-
}
83+
Contracts.Assert(0 <= crun && crun <= dst.Size);
84+
Thunk.MatMulPA(Ptr(mat, pmat), ppossrc, Ptr(srcValues, psrc), posMin, iposMin, iposLim, Ptr(dst, pdst), crun, srcValues.Size);
9485
}
9586
}
9687
}

0 commit comments

Comments
 (0)