Skip to content

Commit 3ce839e

Browse files
jwood803Zruty0
authored andcommitted
Add debug asserts (#1566)
* Add debug asserts * Add debug asserts to AvxIntrinsics class
1 parent a1f5ac3 commit 3ce839e

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,16 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
153153
// Multiply matrix times vector into vector.
154154
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
155155
{
156+
Contracts.Assert(src.Size == dst.Size);
157+
156158
MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
157159
}
158160

159161
public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
160162
{
161163
Contracts.Assert(crow % 4 == 0);
162164
Contracts.Assert(ccol % 4 == 0);
165+
Contracts.Assert(src.Length == dst.Length);
163166

164167
fixed (float* psrc = &MemoryMarshal.GetReference(src))
165168
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -307,6 +310,8 @@ public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> sr
307310
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
308311
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
309312
{
313+
Contracts.Assert(src.Size == dst.Size);
314+
310315
MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
311316
}
312317

@@ -315,6 +320,7 @@ public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgp
315320
{
316321
Contracts.Assert(crow % 8 == 0);
317322
Contracts.Assert(ccol % 8 == 0);
323+
Contracts.Assert(src.Length == dst.Length);
318324

319325
// REVIEW: For extremely sparse inputs, interchanging the loops would
320326
// likely be more efficient.
@@ -468,13 +474,16 @@ Vector256<float> SparseMultiplicationAcrossRow()
468474

469475
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
470476
{
477+
Contracts.Assert(src.Size == dst.Size);
478+
471479
MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
472480
}
473481

474482
public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
475483
{
476484
Contracts.Assert(crow % 4 == 0);
477485
Contracts.Assert(ccol % 4 == 0);
486+
Contracts.Assert(src.Length == dst.Length);
478487

479488
fixed (float* psrc = &MemoryMarshal.GetReference(src))
480489
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -951,6 +960,8 @@ public static unsafe void Scale(float scale, Span<float> dst)
951960

952961
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
953962
{
963+
Contracts.Assert(src.Length == dst.Length);
964+
954965
fixed (float* psrc = &MemoryMarshal.GetReference(src))
955966
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
956967
{
@@ -1044,6 +1055,8 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
10441055

10451056
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
10461057
{
1058+
Contracts.Assert(src.Length == dst.Length);
1059+
10471060
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10481061
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
10491062
{
@@ -1096,6 +1109,8 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f
10961109

10971110
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
10981111
{
1112+
Contracts.Assert(src.Length == dst.Length);
1113+
10991114
fixed (float* psrc = &MemoryMarshal.GetReference(src))
11001115
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
11011116
fixed (float* pres = &MemoryMarshal.GetReference(result))
@@ -1150,6 +1165,8 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re
11501165

11511166
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
11521167
{
1168+
Contracts.Assert(src.Length == dst.Length);
1169+
11531170
fixed (float* psrc = &MemoryMarshal.GetReference(src))
11541171
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
11551172
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -1198,6 +1215,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO
11981215

11991216
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
12001217
{
1218+
Contracts.Assert(src.Length == dst.Length);
1219+
12011220
fixed (float* psrc = &MemoryMarshal.GetReference(src))
12021221
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
12031222
{
@@ -1245,6 +1264,8 @@ public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int cou
12451264

12461265
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
12471266
{
1267+
Contracts.Assert(src.Length == dst.Length);
1268+
12481269
fixed (float* psrc = &MemoryMarshal.GetReference(src))
12491270
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
12501271
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -1726,6 +1747,8 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
17261747

17271748
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
17281749
{
1750+
Contracts.Assert(src.Length == dst.Length);
1751+
17291752
fixed (float* psrc = &MemoryMarshal.GetReference(src))
17301753
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
17311754
{
@@ -1778,6 +1801,8 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst
17781801

17791802
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
17801803
{
1804+
Contracts.Assert(src.Length == dst.Length);
1805+
17811806
fixed (float* psrc = &MemoryMarshal.GetReference(src))
17821807
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
17831808
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
@@ -1832,6 +1857,8 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds
18321857

18331858
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
18341859
{
1860+
Contracts.Assert(src.Length == dst.Length);
1861+
18351862
fixed (float* psrc = &MemoryMarshal.GetReference(src))
18361863
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
18371864
{

src/Microsoft.ML.CpuMath/SseIntrinsics.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,16 @@ internal static Vector128<float> GetNewDst128(in Vector128<float> xDst1, in Vect
117117
// Multiply matrix times vector into vector.
118118
public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
119119
{
120+
Contracts.Assert(src.Size == dst.Size);
121+
120122
MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
121123
}
122124

123125
public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
124126
{
125127
Contracts.Assert(crow % 4 == 0);
126128
Contracts.Assert(ccol % 4 == 0);
129+
Contracts.Assert(src.Length == dst.Length);
127130

128131
fixed (float* psrc = &MemoryMarshal.GetReference(src))
129132
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -279,12 +282,15 @@ public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> sr
279282
public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
280283
int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
281284
{
285+
Contracts.Assert(src.Size == dst.Size);
286+
282287
MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
283288
}
284289

285290
public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
286291
int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
287292
{
293+
Contracts.Assert(src.Length == dst.Length);
288294
Contracts.Assert(crow % 4 == 0);
289295
Contracts.Assert(ccol % 4 == 0);
290296

@@ -443,11 +449,14 @@ Vector128<float> SparseMultiplicationAcrossRow()
443449

444450
public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
445451
{
452+
Contracts.Assert(src.Size == dst.Size);
453+
446454
MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
447455
}
448456

449457
public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
450458
{
459+
Contracts.Assert(src.Length == dst.Length);
451460
Contracts.Assert(crow % 4 == 0);
452461
Contracts.Assert(ccol % 4 == 0);
453462

@@ -893,6 +902,8 @@ public static unsafe void Scale(float scale, Span<float> dst)
893902

894903
public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
895904
{
905+
Contracts.Assert(src.Length == dst.Length);
906+
896907
fixed (float* psrc = &MemoryMarshal.GetReference(src))
897908
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
898909
{
@@ -961,6 +972,8 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
961972

962973
public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<float> dst, int count)
963974
{
975+
Contracts.Assert(src.Length == dst.Length);
976+
964977
fixed (float* psrc = &MemoryMarshal.GetReference(src))
965978
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
966979
{
@@ -1000,6 +1013,8 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f
10001013

10011014
public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<float> dst, Span<float> result, int count)
10021015
{
1016+
Contracts.Assert(src.Length == dst.Length);
1017+
10031018
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10041019
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
10051020
fixed (float* pres = &MemoryMarshal.GetReference(result))
@@ -1041,6 +1056,8 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re
10411056

10421057
public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
10431058
{
1059+
Contracts.Assert(src.Length == dst.Length);
1060+
10441061
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10451062
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
10461063
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -1077,6 +1094,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO
10771094

10781095
public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int count)
10791096
{
1097+
Contracts.Assert(src.Length == dst.Length);
1098+
10801099
fixed (float* psrc = &MemoryMarshal.GetReference(src))
10811100
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
10821101
{
@@ -1112,6 +1131,8 @@ public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int cou
11121131

11131132
public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, Span<float> dst, int count)
11141133
{
1134+
Contracts.Assert(src.Length == dst.Length);
1135+
11151136
fixed (float* psrc = &MemoryMarshal.GetReference(src))
11161137
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
11171138
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -1145,6 +1166,9 @@ public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx,
11451166

11461167
public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan<float> src2, Span<float> dst, int count)
11471168
{
1169+
Contracts.Assert(src1.Length == dst.Length);
1170+
Contracts.Assert(src2.Length == dst.Length);
1171+
11481172
fixed (float* psrc1 = &MemoryMarshal.GetReference(src1))
11491173
fixed (float* psrc2 = &MemoryMarshal.GetReference(src2))
11501174
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
@@ -1479,6 +1503,8 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
14791503

14801504
public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
14811505
{
1506+
Contracts.Assert(src.Length == dst.Length);
1507+
14821508
fixed (float* psrc = &MemoryMarshal.GetReference(src))
14831509
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
14841510
{
@@ -1518,6 +1544,8 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst
15181544

15191545
public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, ReadOnlySpan<int> idx, int count)
15201546
{
1547+
Contracts.Assert(src.Length == dst.Length);
1548+
15211549
fixed (float* psrc = &MemoryMarshal.GetReference(src))
15221550
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
15231551
fixed (int* pidx = &MemoryMarshal.GetReference(idx))
@@ -1559,6 +1587,8 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds
15591587

15601588
public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> dst, int count)
15611589
{
1590+
Contracts.Assert(src.Length == dst.Length);
1591+
15621592
fixed (float* psrc = &MemoryMarshal.GetReference(src))
15631593
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
15641594
{

0 commit comments

Comments
 (0)