Skip to content

Commit da7b5a7

Browse files
committed
Change bound checking in SSE/AVX intrinsics to avoid integer overflow
1 parent 9383dd1 commit da7b5a7

File tree

2 files changed

+66
-66
lines changed

2 files changed

+66
-66
lines changed

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 44 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
420420

421421
Vector256<float> scalarVector256 = Avx.SetAllVector256(scalar);
422422

423-
while (pDstCurrent + 8 <= pDstEnd)
423+
while (pDstEnd - pDstCurrent >= 8)
424424
{
425425
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
426426
dstVector = Avx.Add(dstVector, scalarVector256);
@@ -431,7 +431,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
431431

432432
Vector128<float> scalarVector128 = Sse.SetAllVector128(scalar);
433433

434-
if (pDstCurrent + 4 <= pDstEnd)
434+
if (pDstEnd - pDstCurrent >= 4)
435435
{
436436
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
437437
dstVector = Sse.Add(dstVector, scalarVector128);
@@ -460,7 +460,7 @@ public static unsafe void ScaleU(float scale, Span<float> dst)
460460

461461
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
462462

463-
while (pDstCurrent + 8 <= pEnd)
463+
while (pEnd - pDstCurrent >= 8)
464464
{
465465
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
466466

@@ -472,7 +472,7 @@ public static unsafe void ScaleU(float scale, Span<float> dst)
472472

473473
Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
474474

475-
if (pDstCurrent + 4 <= pEnd)
475+
if (pEnd - pDstCurrent >= 4)
476476
{
477477
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
478478

@@ -505,7 +505,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds
505505

506506
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
507507

508-
while (pDstCurrent + 8 <= pDstEnd)
508+
while (pDstEnd - pDstCurrent >= 8)
509509
{
510510
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
511511
srcVector = Avx.Multiply(srcVector, scaleVector256);
@@ -517,7 +517,7 @@ public static unsafe void ScaleSrcU(float scale, Span<float> src, Span<float> ds
517517

518518
Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
519519

520-
if (pDstCurrent + 4 <= pDstEnd)
520+
if (pDstEnd - pDstCurrent >= 4)
521521
{
522522
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
523523
srcVector = Sse.Multiply(srcVector, scaleVector128);
@@ -550,7 +550,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
550550
Vector256<float> a256 = Avx.SetAllVector256(a);
551551
Vector256<float> b256 = Avx.SetAllVector256(b);
552552

553-
while (pDstCurrent + 8 <= pDstEnd)
553+
while (pDstEnd - pDstCurrent >= 8)
554554
{
555555
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
556556
dstVector = Avx.Add(dstVector, b256);
@@ -563,7 +563,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
563563
Vector128<float> a128 = Sse.SetAllVector128(a);
564564
Vector128<float> b128 = Sse.SetAllVector128(b);
565565

566-
if (pDstCurrent + 4 <= pDstEnd)
566+
if (pDstEnd - pDstCurrent >= 4)
567567
{
568568
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
569569
dstVector = Sse.Add(dstVector, b128);
@@ -596,7 +596,7 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds
596596

597597
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
598598

599-
while (pDstCurrent + 8 <= pEnd)
599+
while (pEnd - pDstCurrent >= 8)
600600
{
601601
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
602602
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
@@ -611,7 +611,7 @@ public static unsafe void AddScaleU(float scale, Span<float> src, Span<float> ds
611611

612612
Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
613613

614-
if (pDstCurrent + 4 <= pEnd)
614+
if (pEnd - pDstCurrent >= 4)
615615
{
616616
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
617617
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
@@ -652,7 +652,7 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float
652652

653653
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
654654

655-
while (pResCurrent + 8 <= pResEnd)
655+
while (pResEnd - pResCurrent >= 8)
656656
{
657657
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
658658
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
@@ -667,7 +667,7 @@ public static unsafe void AddScaleCopyU(float scale, Span<float> src, Span<float
667667

668668
Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
669669

670-
if (pResCurrent + 4 <= pResEnd)
670+
if (pResEnd - pResCurrent >= 4)
671671
{
672672
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
673673
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
@@ -708,7 +708,7 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
708708

709709
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
710710

711-
while (pIdxCurrent + 8 <= pEnd)
711+
while (pEnd - pIdxCurrent >= 8)
712712
{
713713
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
714714
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
@@ -723,7 +723,7 @@ public static unsafe void AddScaleSU(float scale, Span<float> src, Span<int> idx
723723

724724
Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
725725

726-
if (pIdxCurrent + 4 <= pEnd)
726+
if (pEnd - pIdxCurrent >= 4)
727727
{
728728
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
729729
Vector128<float> dstVector = SseIntrinsics.Load4(pDstCurrent, pIdxCurrent);
@@ -755,7 +755,7 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
755755
float* pDstCurrent = pdst;
756756
float* pEnd = psrc + src.Length;
757757

758-
while (pSrcCurrent + 8 <= pEnd)
758+
while (pEnd - pSrcCurrent >= 8)
759759
{
760760
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
761761
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
@@ -767,7 +767,7 @@ public static unsafe void AddU(Span<float> src, Span<float> dst)
767767
pDstCurrent += 8;
768768
}
769769

770-
if (pSrcCurrent + 4 <= pEnd)
770+
if (pEnd - pSrcCurrent >= 4)
771771
{
772772
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
773773
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
@@ -804,7 +804,7 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
804804
float* pDstCurrent = pdst;
805805
int* pEnd = pidx + idx.Length;
806806

807-
while (pIdxCurrent + 8 <= pEnd)
807+
while (pEnd - pIdxCurrent >= 8)
808808
{
809809
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
810810
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
@@ -816,7 +816,7 @@ public static unsafe void AddSU(Span<float> src, Span<int> idx, Span<float> dst)
816816
pSrcCurrent += 8;
817817
}
818818

819-
if (pIdxCurrent + 4 <= pEnd)
819+
if (pEnd - pIdxCurrent >= 4)
820820
{
821821
Vector128<float> dstVector = SseIntrinsics.Load4(pDstCurrent, pIdxCurrent);
822822
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
@@ -849,7 +849,7 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
849849
float* pDstCurrent = pdst;
850850
float* pEnd = pdst + dst.Length;
851851

852-
while (pDstCurrent + 8 <= pEnd)
852+
while (pEnd - pDstCurrent >= 8)
853853
{
854854
Vector256<float> src1Vector = Avx.LoadVector256(pSrc1Current);
855855
Vector256<float> src2Vector = Avx.LoadVector256(pSrc2Current);
@@ -861,7 +861,7 @@ public static unsafe void MulElementWiseU(Span<float> src1, Span<float> src2, Sp
861861
pDstCurrent += 8;
862862
}
863863

864-
if (pDstCurrent + 4 <= pEnd)
864+
if (pEnd - pDstCurrent >= 4)
865865
{
866866
Vector128<float> src1Vector = Sse.LoadVector128(pSrc1Current);
867867
Vector128<float> src2Vector = Sse.LoadVector128(pSrc2Current);
@@ -896,7 +896,7 @@ public static unsafe float SumU(Span<float> src)
896896

897897
Vector256<float> result256 = Avx.SetZeroVector256<float>();
898898

899-
while (pSrcCurrent + 8 <= pSrcEnd)
899+
while (pSrcEnd - pSrcCurrent >= 8)
900900
{
901901
result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent));
902902
pSrcCurrent += 8;
@@ -907,7 +907,7 @@ public static unsafe float SumU(Span<float> src)
907907

908908
Vector128<float> result128 = Sse.SetZeroVector128();
909909

910-
if (pSrcCurrent + 4 <= pSrcEnd)
910+
if (pSrcEnd - pSrcCurrent >= 4)
911911
{
912912
result128 = Sse.Add(result128, Sse.LoadVector128(pSrcCurrent));
913913
pSrcCurrent += 4;
@@ -934,7 +934,7 @@ public static unsafe float SumSqU(Span<float> src)
934934

935935
Vector256<float> result256 = Avx.SetZeroVector256<float>();
936936

937-
while (pSrcCurrent + 8 <= pSrcEnd)
937+
while (pSrcEnd - pSrcCurrent >= 8)
938938
{
939939
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
940940
result256 = Avx.Add(result256, Avx.Multiply(srcVector, srcVector));
@@ -947,7 +947,7 @@ public static unsafe float SumSqU(Span<float> src)
947947

948948
Vector128<float> result128 = Sse.SetZeroVector128();
949949

950-
if (pSrcCurrent + 4 <= pSrcEnd)
950+
if (pSrcEnd - pSrcCurrent >= 4)
951951
{
952952
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
953953
result128 = Sse.Add(result128, Sse.Multiply(srcVector, srcVector));
@@ -979,7 +979,7 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
979979
Vector256<float> result256 = Avx.SetZeroVector256<float>();
980980
Vector256<float> meanVector256 = Avx.SetAllVector256(mean);
981981

982-
while (pSrcCurrent + 8 <= pSrcEnd)
982+
while (pSrcEnd - pSrcCurrent >= 8)
983983
{
984984
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
985985
srcVector = Avx.Subtract(srcVector, meanVector256);
@@ -994,7 +994,7 @@ public static unsafe float SumSqDiffU(float mean, Span<float> src)
994994
Vector128<float> result128 = Sse.SetZeroVector128();
995995
Vector128<float> meanVector128 = Sse.SetAllVector128(mean);
996996

997-
if (pSrcCurrent + 4 <= pSrcEnd)
997+
if (pSrcEnd - pSrcCurrent >= 4)
998998
{
999999
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
10001000
srcVector = Sse.Subtract(srcVector, meanVector128);
@@ -1027,7 +1027,7 @@ public static unsafe float SumAbsU(Span<float> src)
10271027

10281028
Vector256<float> result256 = Avx.SetZeroVector256<float>();
10291029

1030-
while (pSrcCurrent + 8 <= pSrcEnd)
1030+
while (pSrcEnd - pSrcCurrent >= 8)
10311031
{
10321032
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
10331033
result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256));
@@ -1040,7 +1040,7 @@ public static unsafe float SumAbsU(Span<float> src)
10401040

10411041
Vector128<float> result128 = Sse.SetZeroVector128();
10421042

1043-
if (pSrcCurrent + 4 <= pSrcEnd)
1043+
if (pSrcEnd - pSrcCurrent >= 4)
10441044
{
10451045
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
10461046
result128 = Sse.Add(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128));
@@ -1072,7 +1072,7 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
10721072
Vector256<float> result256 = Avx.SetZeroVector256<float>();
10731073
Vector256<float> meanVector256 = Avx.SetAllVector256(mean);
10741074

1075-
while (pSrcCurrent + 8 <= pSrcEnd)
1075+
while (pSrcEnd - pSrcCurrent >= 8)
10761076
{
10771077
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
10781078
srcVector = Avx.Subtract(srcVector, meanVector256);
@@ -1087,7 +1087,7 @@ public static unsafe float SumAbsDiffU(float mean, Span<float> src)
10871087
Vector128<float> result128 = Sse.SetZeroVector128();
10881088
Vector128<float> meanVector128 = Sse.SetAllVector128(mean);
10891089

1090-
if (pSrcCurrent + 4 <= pSrcEnd)
1090+
if (pSrcEnd - pSrcCurrent >= 4)
10911091
{
10921092
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
10931093
srcVector = Sse.Subtract(srcVector, meanVector128);
@@ -1120,7 +1120,7 @@ public static unsafe float MaxAbsU(Span<float> src)
11201120

11211121
Vector256<float> result256 = Avx.SetZeroVector256<float>();
11221122

1123-
while (pSrcCurrent + 8 <= pSrcEnd)
1123+
while (pSrcEnd - pSrcCurrent >= 8)
11241124
{
11251125
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
11261126
result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256));
@@ -1133,7 +1133,7 @@ public static unsafe float MaxAbsU(Span<float> src)
11331133

11341134
Vector128<float> result128 = Sse.SetZeroVector128();
11351135

1136-
if (pSrcCurrent + 4 <= pSrcEnd)
1136+
if (pSrcEnd - pSrcCurrent >= 4)
11371137
{
11381138
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
11391139
result128 = Sse.Max(result128, Sse.And(srcVector, SseIntrinsics.AbsMask128));
@@ -1165,7 +1165,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
11651165
Vector256<float> result256 = Avx.SetZeroVector256<float>();
11661166
Vector256<float> meanVector256 = Avx.SetAllVector256(mean);
11671167

1168-
while (pSrcCurrent + 8 <= pSrcEnd)
1168+
while (pSrcEnd - pSrcCurrent >= 8)
11691169
{
11701170
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
11711171
srcVector = Avx.Subtract(srcVector, meanVector256);
@@ -1180,7 +1180,7 @@ public static unsafe float MaxAbsDiffU(float mean, Span<float> src)
11801180
Vector128<float> result128 = Sse.SetZeroVector128();
11811181
Vector128<float> meanVector128 = Sse.SetAllVector128(mean);
11821182

1183-
if (pSrcCurrent + 4 <= pSrcEnd)
1183+
if (pSrcEnd - pSrcCurrent >= 4)
11841184
{
11851185
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
11861186
srcVector = Sse.Subtract(srcVector, meanVector128);
@@ -1215,7 +1215,7 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)
12151215

12161216
Vector256<float> result256 = Avx.SetZeroVector256<float>();
12171217

1218-
while (pSrcCurrent + 8 <= pSrcEnd)
1218+
while (pSrcEnd - pSrcCurrent >= 8)
12191219
{
12201220
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
12211221
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
@@ -1231,7 +1231,7 @@ public static unsafe float DotU(Span<float> src, Span<float> dst)
12311231

12321232
Vector128<float> result128 = Sse.SetZeroVector128();
12331233

1234-
if (pSrcCurrent + 4 <= pSrcEnd)
1234+
if (pSrcEnd - pSrcCurrent >= 4)
12351235
{
12361236
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
12371237
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
@@ -1272,7 +1272,7 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
12721272

12731273
Vector256<float> result256 = Avx.SetZeroVector256<float>();
12741274

1275-
while (pIdxCurrent + 8 <= pIdxEnd)
1275+
while (pIdxEnd - pIdxCurrent >= 8)
12761276
{
12771277
Vector256<float> srcVector = Load8(pSrcCurrent, pIdxCurrent);
12781278
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
@@ -1288,7 +1288,7 @@ public static unsafe float DotSU(Span<float> src, Span<float> dst, Span<int> idx
12881288

12891289
Vector128<float> result128 = Sse.SetZeroVector128();
12901290

1291-
if (pIdxCurrent + 4 <= pIdxEnd)
1291+
if (pIdxEnd - pIdxCurrent >= 4)
12921292
{
12931293
Vector128<float> srcVector = SseIntrinsics.Load4(pSrcCurrent, pIdxCurrent);
12941294
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
@@ -1327,7 +1327,7 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)
13271327

13281328
Vector256<float> sqDistanceVector256 = Avx.SetZeroVector256<float>();
13291329

1330-
while (pSrcCurrent + 8 <= pSrcEnd)
1330+
while (pSrcEnd - pSrcCurrent >= 8)
13311331
{
13321332
Vector256<float> distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent),
13331333
Avx.LoadVector256(pDstCurrent));
@@ -1343,7 +1343,7 @@ public static unsafe float Dist2(Span<float> src, Span<float> dst)
13431343

13441344
Vector128<float> sqDistanceVector128 = Sse.SetZeroVector128();
13451345

1346-
if (pSrcCurrent + 4 <= pSrcEnd)
1346+
if (pSrcEnd - pSrcCurrent >= 4)
13471347
{
13481348
Vector128<float> distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent),
13491349
Sse.LoadVector128(pDstCurrent));
@@ -1384,7 +1384,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
13841384
Vector256<float> xPrimal256 = Avx.SetAllVector256(primalUpdate);
13851385
Vector256<float> xThreshold256 = Avx.SetAllVector256(threshold);
13861386

1387-
while (pSrcCurrent + 8 <= pSrcEnd)
1387+
while (pSrcEnd - pSrcCurrent >= 8)
13881388
{
13891389
Vector256<float> xSrc = Avx.LoadVector256(pSrcCurrent);
13901390

@@ -1403,7 +1403,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, Span<float> src, flo
14031403
Vector128<float> xPrimal128 = Sse.SetAllVector128(primalUpdate);
14041404
Vector128<float> xThreshold128 = Sse.SetAllVector128(threshold);
14051405

1406-
if (pSrcCurrent + 4 <= pSrcEnd)
1406+
if (pSrcEnd - pSrcCurrent >= 4)
14071407
{
14081408
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);
14091409

@@ -1446,7 +1446,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Sp
14461446
Vector256<float> xPrimal256 = Avx.SetAllVector256(primalUpdate);
14471447
Vector256<float> xThreshold = Avx.SetAllVector256(threshold);
14481448

1449-
while (pIdxCurrent + 8 <= pIdxEnd)
1449+
while (pIdxEnd - pIdxCurrent >= 8)
14501450
{
14511451
Vector256<float> xSrc = Avx.LoadVector256(pSrcCurrent);
14521452

@@ -1464,7 +1464,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, Span<float> src, Sp
14641464
Vector128<float> xPrimal128 = Sse.SetAllVector128(primalUpdate);
14651465
Vector128<float> xThreshold128 = Sse.SetAllVector128(threshold);
14661466

1467-
if (pIdxCurrent + 4 <= pIdxEnd)
1467+
if (pIdxEnd - pIdxCurrent >= 4)
14681468
{
14691469
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);
14701470

0 commit comments

Comments
 (0)