
Commit 5e20acd

Merge branch 'master' of https://github.com/dotnet/machinelearning into naindicator

2 parents 9c1e758 + 0b84350

12 files changed: +357 -77 lines changed

src/Microsoft.ML.CpuMath/Avx.cs

Lines changed: 1 addition & 1 deletion

@@ -625,7 +625,7 @@ public static void Scale(float a, float[] dst, int count)
             unsafe
             {
                 fixed (float* pd = &dst[0])
-                    Thunk.ScaleU(a, pd, count);
+                    Thunk.Scale(a, pd, count);
             }
         }
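
On this path, Thunk is the thin managed shim over the native CpuMath binary; as the Sse.cs hunks below show, the rename collapses the separate aligned (ScaleA) and unaligned (ScaleU) exports into a single Scale entry point that accepts any pointer. A minimal sketch of that shim pattern, with the binding details (DLL name, attributes) assumed for illustration rather than taken from the repo's Thunk.cs:

    using System.Runtime.InteropServices;
    using System.Security;

    internal static unsafe class ThunkSketch
    {
        // Hypothetical binding: one native export that handles any alignment internally.
        [DllImport("CpuMathNative"), SuppressUnmanagedCodeSecurity]
        public static extern void Scale(float a, float* pd, int c);
    }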

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs

Lines changed: 124 additions & 22 deletions
@@ -13,11 +13,36 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.Intrinsics;
 using System.Runtime.Intrinsics.X86;
+using nuint = System.UInt64;

 namespace Microsoft.ML.Runtime.Internal.CpuMath
 {
     internal static class AvxIntrinsics
     {
+        public static readonly uint[] LeadingAlignmentMask = new uint[64]
+        {
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000,
+            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
+        };
+
+        public static readonly uint[] TrailingAlignmentMask = new uint[64]
+        {
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+            0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
+        };
+
         private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));

         private const int Vector256Alignment = 32;
@@ -461,45 +486,122 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
             }
         }

-        public static unsafe void ScaleU(float scale, Span<float> dst)
+        public static unsafe void Scale(float scale, Span<float> dst)
         {
-            fixed (float* pdst = dst)
+            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
+            fixed (float* pd = dst)
             {
-                float* pDstCurrent = pdst;
-                float* pEnd = pdst + dst.Length;
-
+                float* pDstCurrent = pd;
+                int length = dst.Length;
                 Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);

-                while (pDstCurrent + 8 <= pEnd)
+                if (length < 8)
                 {
-                    Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
+                    switch (length)
+                    {
+                        case 7: dst[6] *= scale; goto case 6;
+                        case 6: dst[5] *= scale; goto case 5;
+                        case 5: dst[4] *= scale; goto case 4;
+                        case 4: dst[3] *= scale; goto case 3;
+                        case 3: dst[2] *= scale; goto case 2;
+                        case 2: dst[1] *= scale; goto case 1;
+                        case 1: dst[0] *= scale; break;
+                    }
+                    return;
+                }

-                    dstVector = Avx.Multiply(scaleVector256, dstVector);
-                    Avx.Store(pDstCurrent, dstVector);
+                nuint address = (nuint)(pd);
+                int misalignment = (int)(address % 32);
+                int remainder = 0;

-                    pDstCurrent += 8;
+                if ((misalignment & 3) != 0)
+                {
+                    // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
+                    remainder = length % 8;
+
+                    for (float* pEnd = pd + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
+                    {
+                        Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
+                        temp = Avx.Multiply(scaleVector256, temp);
+                        Avx.Store(pDstCurrent, temp);
+                    }
                 }
+                else
+                {
+                    if (misalignment != 0)
+                    {
+                        // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
+                        // masking any elements that will be included in the first aligned read

-                Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
+                        misalignment >>= 2;
+                        misalignment = 8 - misalignment;

-                if (pDstCurrent + 4 <= pEnd)
-                {
-                    Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
+                        Vector256<float> result = Avx.LoadVector256(pDstCurrent);

-                    dstVector = Sse.Multiply(scaleVector128, dstVector);
-                    Sse.Store(pDstCurrent, dstVector);
+                        Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+                        Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8));

-                    pDstCurrent += 4;
+                        Vector256<float> temp = Avx.And(result, leadingMask);
+                        result = Avx.And(result, trailingMask);
+
+                        temp = Avx.Multiply(scaleVector256, temp);
+                        result = Avx.Or(temp, result);
+
+                        Avx.Store(pDstCurrent, result);
+
+                        pDstCurrent += misalignment;
+                        length -= misalignment;
+                    }
+
+                    if (length > 7)
+                    {
+                        // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
+
+                        remainder = length % 8;
+
+                        for (float* pEnd = pDstCurrent + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
+                        {
+                            // The JIT will only fold away unaligned loads due to the semantics behind
+                            // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
+                            // modern hardware has unaligned loads that are as fast as aligned loads,
+                            // when it doesn't cross a cache-line/page boundary, we will just assert
+                            // that the alignment is correct and allow for the more-efficient codegen.
+
+                            Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0);
+                            Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
+                            temp = Avx.Multiply(scaleVector256, temp);
+                            Avx.Store(pDstCurrent, temp);
+                        }
+                    }
+                    else
+                    {
+                        // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
+                        // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
+                        // unaligned loads where we mask the input each time.
+                        remainder = length;
+                    }
                 }

-                while (pDstCurrent < pEnd)
+                if (remainder != 0)
                 {
-                    Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
+                    // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
+                    // unaligned load will read to the end of the array and then mask out any elements already processed

-                    dstVector = Sse.MultiplyScalar(scaleVector128, dstVector);
-                    Sse.StoreScalar(pDstCurrent, dstVector);
+                    pDstCurrent -= (8 - remainder);

-                    pDstCurrent++;
+                    Vector256<float> result = Avx.LoadVector256(pDstCurrent);
+
+                    Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+                    Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
+
+                    Vector256<float> temp = Avx.And(result, trailingMask);
+                    result = Avx.And(result, leadingMask);
+
+                    temp = Avx.Multiply(scaleVector256, temp);
+                    temp = Avx.Or(temp, result);
+
+                    Avx.Store(pDstCurrent, temp);
                 }
             }
         }
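
The heart of this rewrite is the lead/trail masking trick: a full 256-bit vector is loaded, the mask selects only the lanes that still need processing, and the untouched lanes are merged back bit-for-bit before the store, so each element is scaled exactly once. A self-contained sketch of that idea, written against today's System.Runtime.Intrinsics surface (the commit predates the final API names such as Vector256.Create; the buffer and lane count k here are illustrative):

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    class LeadingMaskSketch
    {
        static unsafe void Main()
        {
            if (!Avx.IsSupported)
                return;

            float[] data = { 1, 2, 3, 4, 5, 6, 7, 8 };
            const float scale = 10f;
            const int k = 3; // pretend only the first 3 lanes precede the aligned boundary

            // Lane i of the leading mask is all-ones iff i < k, mirroring row k of LeadingAlignmentMask.
            Span<int> maskBits = stackalloc int[8];
            for (int i = 0; i < 8; i++)
                maskBits[i] = i < k ? -1 : 0;

            fixed (float* pd = data)
            fixed (int* pm = maskBits)
            {
                Vector256<float> lead = Avx.LoadVector256((float*)pm);
                Vector256<float> v = Avx.LoadVector256(pd);

                Vector256<float> head = Avx.And(v, lead);    // lanes 0..k-1: to be scaled
                Vector256<float> tail = Avx.AndNot(lead, v); // lanes k..7: preserved bit-for-bit
                head = Avx.Multiply(Vector256.Create(scale), head);
                Avx.Store(pd, Avx.Or(head, tail));
            }

            Console.WriteLine(string.Join(", ", data)); // 10, 20, 30, 4, 5, 6, 7, 8
        }
    }

The committed code reaches the same merge through the complementary TrailingAlignmentMask row instead of AndNot, which lets both masks come from the same pair of precomputed tables.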

src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs

Lines changed: 2 additions & 2 deletions
@@ -248,11 +248,11 @@ private static void Scale(float a, Span<float> dst)
         {
             if (Avx.IsSupported)
             {
-                AvxIntrinsics.ScaleU(a, dst);
+                AvxIntrinsics.Scale(a, dst);
             }
             else if (Sse.IsSupported)
             {
-                SseIntrinsics.ScaleU(a, dst);
+                SseIntrinsics.Scale(a, dst);
             }
             else
             {
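
Each IsSupported property in this dispatcher is a JIT-time constant, so the branches not taken on the current hardware are dropped from the generated code entirely. A compact, runnable sketch of the same dispatch shape, written against today's intrinsics API with an illustrative scalar tail that doubles as the no-SIMD fallback (names are mine, not the library's):

    using System;
    using System.Runtime.Intrinsics;
    using System.Runtime.Intrinsics.X86;

    static class ScaleDispatchSketch
    {
        public static unsafe void Scale(float a, Span<float> dst)
        {
            int i = 0;
            if (Avx.IsSupported)
            {
                fixed (float* pd = dst)
                {
                    Vector256<float> va = Vector256.Create(a);
                    for (; i <= dst.Length - 8; i += 8) // full 256-bit blocks
                        Avx.Store(pd + i, Avx.Multiply(va, Avx.LoadVector256(pd + i)));
                }
            }
            for (; i < dst.Length; i++) // scalar tail, and the full path without AVX
                dst[i] *= a;
        }
    }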

src/Microsoft.ML.CpuMath/Sse.cs

Lines changed: 3 additions & 3 deletions
@@ -606,7 +606,7 @@ public static void Scale(float a, AlignedArray dst)
             unsafe
             {
                 fixed (float* pdst = &dst.Items[0])
-                    Thunk.ScaleA(a, Ptr(dst, pdst), dst.Size);
+                    Thunk.Scale(a, Ptr(dst, pdst), dst.Size);
             }
         }

@@ -618,7 +618,7 @@ public static void Scale(float a, float[] dst, int count)
             unsafe
             {
                 fixed (float* pd = &dst[0])
-                    Thunk.ScaleU(a, pd, count);
+                    Thunk.Scale(a, pd, count);
             }
         }

@@ -631,7 +631,7 @@ public static void Scale(float a, float[] dst, int offset, int count)
             unsafe
             {
                 fixed (float* pd = &dst[offset])
-                    Thunk.ScaleU(a, pd, count);
+                    Thunk.Scale(a, pd, count);
             }
         }
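
Since the Scale(float a, float[] dst, int offset, int count) overload above can hand the implementation a pointer at any offset, every alignment branch is reachable from managed code. A hedged sketch of a parity check one could run over the new path, sweeping small (offset, count) pairs so the head-mask, aligned-body, and tail-mask branches are all exercised; the harness is illustrative, and the delegate stands in for the overload shown above:

    using System;

    static class ScaleParitySketch
    {
        public static void Check(Action<float, float[], int, int> scale)
        {
            var rng = new Random(42);
            for (int count = 1; count < 40; count++)
            {
                for (int offset = 0; offset < 8; offset++)
                {
                    float[] actual = new float[offset + count];
                    for (int i = 0; i < actual.Length; i++)
                        actual[i] = (float)rng.NextDouble();
                    float[] expected = (float[])actual.Clone();

                    for (int i = 0; i < count; i++)
                        expected[offset + i] *= 1.7f; // scalar reference result

                    scale(1.7f, actual, offset, count); // vectorized path under test

                    for (int i = 0; i < actual.Length; i++)
                    {
                        // Single-precision multiply is bit-identical in the scalar and
                        // vector paths, so exact comparison is valid here.
                        if (expected[i] != actual[i])
                            throw new Exception($"Mismatch at {i} (offset={offset}, count={count})");
                    }
                }
            }
        }
    }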
