Skip to content

Commit 71c9ff3

Browse files
jwood803 authored and shauheen committed
Update loops in CpuMath to be more efficient (#1177)
* Update loops to be more efficient
1 parent 273e36c commit 71c9ff3

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -793,6 +793,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
793793
{
794794
float* pDstEnd = pdst + dst.Length;
795795
float* pDstCurrent = pdst;
796+
int destinationEnd = pDstEnd - 4;
796797

797798
Vector256<float> scalarVector256 = Avx.SetAllVector256(scalar);
798799

@@ -807,7 +808,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
807808

808809
Vector128<float> scalarVector128 = Sse.SetAllVector128(scalar);
809810

810-
if (pDstCurrent + 4 <= pDstEnd)
811+
if (pDstCurrent <= destinationEnd)
811812
{
812813
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
813814
dstVector = Sse.Add(dstVector, scalarVector128);
@@ -956,6 +957,7 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<f
956957
float* pDstEnd = pdst + count;
957958
float* pSrcCurrent = psrc;
958959
float* pDstCurrent = pdst;
960+
int destinationEnd = pDstEnd - 4;
959961

960962
Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
961963

@@ -971,7 +973,7 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<f
971973

972974
Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
973975

974-
if (pDstCurrent + 4 <= pDstEnd)
976+
if (pDstCurrent <= destinationEnd)
975977
{
976978
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
977979
srcVector = Sse.Multiply(srcVector, scaleVector128);
@@ -1000,6 +1002,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
10001002
{
10011003
float* pDstEnd = pdst + dst.Length;
10021004
float* pDstCurrent = pdst;
1005+
int destinationEnd = pDstEnd - 4;
10031006

10041007
Vector256<float> a256 = Avx.SetAllVector256(a);
10051008
Vector256<float> b256 = Avx.SetAllVector256(b);
@@ -1017,7 +1020,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
10171020
Vector128<float> a128 = Sse.SetAllVector128(a);
10181021
Vector128<float> b128 = Sse.SetAllVector128(b);
10191022

1020-
if (pDstCurrent + 4 <= pDstEnd)
1023+
if (pDstCurrent <= destinationEnd)
10211024
{
10221025
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
10231026
dstVector = Sse.Add(dstVector, b128);

src/Microsoft.ML.CpuMath/SseIntrinsics.cs

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -755,10 +755,11 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
755755
{
756756
float* pDstEnd = pdst + dst.Length;
757757
float* pDstCurrent = pdst;
758+
int destinationEnd = pDstEnd - 4;
758759

759760
Vector128<float> scalarVector = Sse.SetAllVector128(scalar);
760761

761-
while (pDstCurrent + 4 <= pDstEnd)
762+
while (pDstCurrent <= destinationEnd)
762763
{
763764
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
764765
dstVector = Sse.Add(dstVector, scalarVector);
@@ -898,10 +899,11 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<f
898899
float* pDstEnd = pdst + count;
899900
float* pSrcCurrent = psrc;
900901
float* pDstCurrent = pdst;
902+
int destinationEnd = pDstEnd - 4;
901903

902904
Vector128<float> scaleVector = Sse.SetAllVector128(scale);
903905

904-
while (pDstCurrent + 4 <= pDstEnd)
906+
while (pDstCurrent <= destinationEnd)
905907
{
906908
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
907909
srcVector = Sse.Multiply(srcVector, scaleVector);
@@ -930,11 +932,12 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
930932
{
931933
float* pDstEnd = pdst + dst.Length;
932934
float* pDstCurrent = pdst;
935+
int destinationEnd = pDstEnd - 4;
933936

934937
Vector128<float> aVector = Sse.SetAllVector128(a);
935938
Vector128<float> bVector = Sse.SetAllVector128(b);
936939

937-
while (pDstCurrent + 4 <= pDstEnd)
940+
while (pDstCurrent <= destinationEnd)
938941
{
939942
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
940943
dstVector = Sse.Add(dstVector, bVector);

0 commit comments

Comments
 (0)