13 | 13 | using System.Runtime.CompilerServices;
14 | 14 | using System.Runtime.Intrinsics;
15 | 15 | using System.Runtime.Intrinsics.X86;
   | 16 | +using nuint = System.UInt64;
16 | 17 |
17 | 18 | namespace Microsoft.ML.Runtime.Internal.CpuMath
18 | 19 | {
19 | 20 |     internal static class AvxIntrinsics
20 | 21 |     {
   | 22 | +        public static readonly uint[] LeadingAlignmentMask = new uint[64]
   | 23 | +        {
   | 24 | +            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
   | 25 | +            0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
   | 26 | +            0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
   | 27 | +            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
   | 28 | +            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
   | 29 | +            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000,
   | 30 | +            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000,
   | 31 | +            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000,
   | 32 | +        };
   | 33 | +
   | 34 | +        public static readonly uint[] TrailingAlignmentMask = new uint[64]
   | 35 | +        {
   | 36 | +            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000,
   | 37 | +            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF,
   | 38 | +            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF,
   | 39 | +            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
   | 40 | +            0x00000000, 0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
   | 41 | +            0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
   | 42 | +            0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
   | 43 | +            0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF,
   | 44 | +        };
   | 45 | +
21 | 46 |         private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));
22 | 47 |
23 | 48 |         private const int Vector256Alignment = 32;
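Note on the two tables above: row i of LeadingAlignmentMask has its first i lanes set to 0xFFFFFFFF, and row i of TrailingAlignmentMask has its last i lanes set, so ANDing float data against a row keeps exactly those lanes and zeroes the rest. A minimal sketch of the lookup the code in the next hunk performs, assuming the same preview intrinsics API this file compiles against (LoadLeadingMask is a hypothetical helper name, not part of the change):

    // Load row n (0 through 7) of LeadingAlignmentMask as a float mask.
    // Each row is 8 uints (one 256-bit vector), so row n starts at element n * 8.
    // Lanes of 0xFFFFFFFF survive an Avx.And; lanes of 0x00000000 are zeroed.
    private static unsafe Vector256<float> LoadLeadingMask(uint* pLeadingAlignmentMask, int n)
    {
        return Avx.LoadVector256(((float*)pLeadingAlignmentMask) + (n * 8));
    }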
@@ -461,45 +486,122 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)
461 | 486 |             }
462 | 487 |         }
463 | 488 |
464 |     | -        public static unsafe void ScaleU(float scale, Span<float> dst)
    | 489 | +        public static unsafe void Scale(float scale, Span<float> dst)
465 | 490 |         {
466 |     | -            fixed (float* pdst = dst)
    | 491 | +            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
    | 492 | +            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
    | 493 | +            fixed (float* pd = dst)
467 | 494 |             {
468 |     | -                float* pDstCurrent = pdst;
469 |     | -                float* pEnd = pdst + dst.Length;
470 |     | -
    | 495 | +                float* pDstCurrent = pd;
    | 496 | +                int length = dst.Length;
471 | 497 |                 Vector256<float> scaleVector256 = Avx.SetAllVector256(scale);
472 | 498 |
473 |     | -                while (pDstCurrent + 8 <= pEnd)
    | 499 | +                if (length < 8)
474 | 500 |                 {
475 |     | -                    Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
    | 501 | +                    switch (length)
    | 502 | +                    {
    | 503 | +                        case 7: dst[6] *= scale; goto case 6;
    | 504 | +                        case 6: dst[5] *= scale; goto case 5;
    | 505 | +                        case 5: dst[4] *= scale; goto case 4;
    | 506 | +                        case 4: dst[3] *= scale; goto case 3;
    | 507 | +                        case 3: dst[2] *= scale; goto case 2;
    | 508 | +                        case 2: dst[1] *= scale; goto case 1;
    | 509 | +                        case 1: dst[0] *= scale; break;
    | 510 | +                    }
    | 511 | +                    return;
    | 512 | +                }
476 | 513 |
477 |     | -                    dstVector = Avx.Multiply(scaleVector256, dstVector);
478 |     | -                    Avx.Store(pDstCurrent, dstVector);
    | 514 | +                nuint address = (nuint)(pd);
    | 515 | +                int misalignment = (int)(address % 32);
    | 516 | +                int remainder = 0;
479 | 517 |
480 |     | -                    pDstCurrent += 8;
    | 518 | +                if ((misalignment & 3) != 0)
    | 519 | +                {
    | 520 | +                    // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
    | 521 | +                    remainder = length % 8;
    | 522 | +
    | 523 | +                    for (float* pEnd = pd + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
    | 524 | +                    {
    | 525 | +                        Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
    | 526 | +                        temp = Avx.Multiply(scaleVector256, temp);
    | 527 | +                        Avx.Store(pDstCurrent, temp);
    | 528 | +                    }
481 | 529 |                 }
    | 530 | +                else
    | 531 | +                {
    | 532 | +                    if (misalignment != 0)
    | 533 | +                    {
    | 534 | +                        // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
    | 535 | +                        // masking any elements that will be included in the first aligned read
482 | 536 |
483 |     | -                Vector128<float> scaleVector128 = Sse.SetAllVector128(scale);
    | 537 | +                        misalignment >>= 2;
    | 538 | +                        misalignment = 8 - misalignment;
484 | 539 |
485 |     | -                if (pDstCurrent + 4 <= pEnd)
486 |     | -                {
487 |     | -                    Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
    | 540 | +                        Vector256<float> result = Avx.LoadVector256(pDstCurrent);
488 | 541 |
489 |     | -                    dstVector = Sse.Multiply(scaleVector128, dstVector);
490 |     | -                    Sse.Store(pDstCurrent, dstVector);
    | 542 | +                        Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
    | 543 | +                        Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + ((8 - misalignment) * 8));
491 | 544 |
492 |     | -                    pDstCurrent += 4;
    | 545 | +                        Vector256<float> temp = Avx.And(result, leadingMask);
    | 546 | +                        result = Avx.And(result, trailingMask);
    | 547 | +
    | 548 | +                        temp = Avx.Multiply(scaleVector256, temp);
    | 549 | +                        result = Avx.Or(temp, result);
    | 550 | +
    | 551 | +                        Avx.Store(pDstCurrent, result);
    | 552 | +
    | 553 | +                        pDstCurrent += misalignment;
    | 554 | +                        length -= misalignment;
    | 555 | +                    }
    | 556 | +
    | 557 | +                    if (length > 7)
    | 558 | +                    {
    | 559 | +                        // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
    | 560 | +
    | 561 | +                        remainder = length % 8;
    | 562 | +
    | 563 | +                        for (float* pEnd = pDstCurrent + (length - remainder); pDstCurrent < pEnd; pDstCurrent += 8)
    | 564 | +                        {
    | 565 | +                            // The JIT will only fold away unaligned loads due to the semantics behind
    | 566 | +                            // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
    | 567 | +                            // modern hardware has unaligned loads that are as fast as aligned loads,
    | 568 | +                            // when it doesn't cross a cache-line/page boundary, we will just assert
    | 569 | +                            // that the alignment is correct and allow for the more-efficient codegen.
    | 570 | +
    | 571 | +                            Contracts.Assert(((nuint)(pDstCurrent) % 32) == 0);
    | 572 | +                            Vector256<float> temp = Avx.LoadVector256(pDstCurrent);
    | 573 | +                            temp = Avx.Multiply(scaleVector256, temp);
    | 574 | +                            Avx.Store(pDstCurrent, temp);
    | 575 | +                        }
    | 576 | +                    }
    | 577 | +                    else
    | 578 | +                    {
    | 579 | +                        // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
    | 580 | +                        // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
    | 581 | +                        // unaligned loads where we mask the input each time.
    | 582 | +                        remainder = length;
    | 583 | +                    }
493 | 584 |                 }
494 | 585 |
495 |     | -                while (pDstCurrent < pEnd)
    | 586 | +                if (remainder != 0)
496 | 587 |                 {
497 |     | -                    Vector128<float> dstVector = Sse.LoadScalarVector128(pDstCurrent);
    | 588 | +                    // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
    | 589 | +                    // unaligned load will read to the end of the array and then mask out any elements already processed
498 | 590 |
499 |     | -                    dstVector = Sse.MultiplyScalar(scaleVector128, dstVector);
500 |     | -                    Sse.StoreScalar(pDstCurrent, dstVector);
    | 591 | +                    pDstCurrent -= (8 - remainder);
501 | 592 |
502 |     | -                    pDstCurrent++;
    | 593 | +                    Vector256<float> result = Avx.LoadVector256(pDstCurrent);
    | 594 | +
    | 595 | +                    Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
    | 596 | +                    Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
    | 597 | +
    | 598 | +                    Vector256<float> temp = Avx.And(result, trailingMask);
    | 599 | +                    result = Avx.And(result, leadingMask);
    | 600 | +
    | 601 | +                    temp = Avx.Multiply(scaleVector256, temp);
    | 602 | +                    temp = Avx.Or(temp, result);
    | 603 | +
    | 604 | +                    Avx.Store(pDstCurrent, temp);
503 | 605 |                 }
504 | 606 |             }
505 | 607 |         }
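To make the alignment arithmetic above concrete: if pd sits 8 bytes past a 32-byte boundary, misalignment >>= 2 yields 2 (float lanes already past the boundary) and 8 - misalignment yields 6, so the masked prolog scales the 6 leading elements and pDstCurrent lands exactly on a 32-byte boundary. A small plain-C# sketch of that element split (SplitForAlignment is a hypothetical helper, not part of this change; it just mirrors the prolog/main/epilog bookkeeping so the coverage can be checked by hand):

    // Mirrors Scale's bookkeeping: how many elements the masked prolog, the
    // aligned 8-wide main loop, and the backed-up masked epilog each cover,
    // for a buffer starting misalignmentBytes (a multiple of 4) past a
    // 32-byte boundary.
    static (int Prolog, int Main, int Epilog) SplitForAlignment(int misalignmentBytes, int length)
    {
        int prolog = (misalignmentBytes == 0) ? 0 : 8 - (misalignmentBytes >> 2);
        int rest = length - prolog;
        int main = (rest > 7) ? rest - (rest % 8) : 0;  // full aligned 256-bit blocks
        int epilog = rest - main;                       // finished by stepping back 8 - epilog
        return (prolog, main, epilog);
    }

For example, SplitForAlignment(8, 20) gives (6, 8, 6): the epilog steps pDstCurrent back by 2 so its final unaligned load ends at the array's last element; the trailing mask then selects the 6 new lanes for multiplication while the leading mask preserves the 2 lanes the main loop already scaled, so no element is scaled twice.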