diff --git a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs index a3f0582800..5127a92fd0 100644 --- a/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/AvxIntrinsics.cs @@ -153,6 +153,8 @@ private static Vector256 MultiplyAdd(Vector256 src1, Vector256 mat, ReadOnlySpan sr { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); + Contracts.Assert(src.Length == dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -307,6 +310,8 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { + Contracts.Assert(src.Size == dst.Size); + MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); } @@ -315,6 +320,7 @@ public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgp { Contracts.Assert(crow % 8 == 0); Contracts.Assert(ccol % 8 == 0); + Contracts.Assert(src.Length == dst.Length); // REVIEW: For extremely sparse inputs, interchanging the loops would // likely be more efficient. @@ -468,6 +474,8 @@ Vector256 SparseMultiplicationAcrossRow() public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { + Contracts.Assert(src.Size == dst.Size); + MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); } @@ -475,6 +483,7 @@ public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan dst) public static unsafe void ScaleSrcU(float scale, ReadOnlySpan src, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1044,6 +1055,8 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1096,6 +1109,8 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span src, ReadOnlySpan dst, Span result, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (float* pres = &MemoryMarshal.GetReference(result)) @@ -1150,6 +1165,8 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1198,6 +1215,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO public static unsafe void AddU(ReadOnlySpan src, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1245,6 +1264,8 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1726,6 +1747,8 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1778,6 +1801,8 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan dst, ReadOnlySpan idx, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) @@ -1832,6 +1857,8 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { diff --git a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs index cf85a98132..93be0b9d1b 100644 --- a/src/Microsoft.ML.CpuMath/SseIntrinsics.cs +++ b/src/Microsoft.ML.CpuMath/SseIntrinsics.cs @@ -117,6 +117,8 @@ internal static Vector128 GetNewDst128(in Vector128 xDst1, in Vect // Multiply matrix times vector into vector. public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { + Contracts.Assert(src.Size == dst.Size); + MatMul(mat.Items, src.Items, dst.Items, crow, ccol); } @@ -124,6 +126,7 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr { Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); + Contracts.Assert(src.Length == dst.Length); fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -279,12 +282,15 @@ public static unsafe void MatMul(ReadOnlySpan mat, ReadOnlySpan sr public static unsafe void MatMulP(AlignedArray mat, int[] rgposSrc, AlignedArray src, int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol) { + Contracts.Assert(src.Size == dst.Size); + MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol); } public static unsafe void MatMulP(ReadOnlySpan mat, ReadOnlySpan rgposSrc, ReadOnlySpan src, int posMin, int iposMin, int iposEnd, Span dst, int crow, int ccol) { + Contracts.Assert(src.Length == dst.Length); Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); @@ -443,11 +449,14 @@ Vector128 SparseMultiplicationAcrossRow() public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol) { + Contracts.Assert(src.Size == dst.Size); + MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol); } public static unsafe void MatMulTran(ReadOnlySpan mat, ReadOnlySpan src, Span dst, int crow, int ccol) { + Contracts.Assert(src.Length == dst.Length); Contracts.Assert(crow % 4 == 0); Contracts.Assert(ccol % 4 == 0); @@ -893,6 +902,8 @@ public static unsafe void Scale(float scale, Span dst) public static unsafe void ScaleSrcU(float scale, ReadOnlySpan src, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -961,6 +972,8 @@ public static unsafe void ScaleAddU(float a, float b, Span dst) public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1000,6 +1013,8 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan src, Span src, ReadOnlySpan dst, Span result, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (float* pres = &MemoryMarshal.GetReference(result)) @@ -1041,6 +1056,8 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan src, Re public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1077,6 +1094,8 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan src, ReadO public static unsafe void AddU(ReadOnlySpan src, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1112,6 +1131,8 @@ public static unsafe void AddU(ReadOnlySpan src, Span dst, int cou public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, Span dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1145,6 +1166,9 @@ public static unsafe void AddSU(ReadOnlySpan src, ReadOnlySpan idx, public static unsafe void MulElementWiseU(ReadOnlySpan src1, ReadOnlySpan src2, Span dst, int count) { + Contracts.Assert(src1.Length == dst.Length); + Contracts.Assert(src2.Length == dst.Length); + fixed (float* psrc1 = &MemoryMarshal.GetReference(src1)) fixed (float* psrc2 = &MemoryMarshal.GetReference(src2)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) @@ -1479,6 +1503,8 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan src) public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) { @@ -1518,6 +1544,8 @@ public static unsafe float DotU(ReadOnlySpan src, ReadOnlySpan dst public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan dst, ReadOnlySpan idx, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) fixed (int* pidx = &MemoryMarshal.GetReference(idx)) @@ -1559,6 +1587,8 @@ public static unsafe float DotSU(ReadOnlySpan src, ReadOnlySpan ds public static unsafe float Dist2(ReadOnlySpan src, ReadOnlySpan dst, int count) { + Contracts.Assert(src.Length == dst.Length); + fixed (float* psrc = &MemoryMarshal.GetReference(src)) fixed (float* pdst = &MemoryMarshal.GetReference(dst)) {