@@ -11,76 +11,62 @@ namespace Microsoft.ML.Runtime.Internal.CpuMath
11
11
{
12
12
internal static partial class CpuMathUtils
13
13
{
14
- // The count of bytes in Vector128<T>, corresponding to _cbAlign in AlignedArray
15
- private const int Vector128Alignment = 16 ;
16
-
17
- // The count of bytes in Vector256<T>, corresponding to _cbAlign in AlignedArray
18
- private const int Vector256Alignment = 32 ;
19
-
20
- // The count of bytes in a 32-bit float, corresponding to _cbAlign in AlignedArray
21
- private const int FloatAlignment = 4 ;
22
-
23
- // If neither AVX nor SSE is supported, return basic alignment for a 4-byte float.
24
- [ MethodImplAttribute ( MethodImplOptions . AggressiveInlining ) ]
25
- public static int GetVectorAlignment ( )
26
- => Avx . IsSupported ? Vector256Alignment : ( Sse . IsSupported ? Vector128Alignment : FloatAlignment ) ;
27
-
28
- public static void MatrixTimesSource ( bool transpose , AlignedArray matrix , AlignedArray source , AlignedArray destination , int stride )
14
+ public static void MatrixTimesSource ( bool transpose , ReadOnlySpan < float > matrix , ReadOnlySpan < float > source , Span < float > destination , int stride )
29
15
{
30
- Contracts . Assert ( matrix . Size == destination . Size * source . Size ) ;
16
+ Contracts . AssertNonEmpty ( matrix ) ;
17
+ Contracts . AssertNonEmpty ( source ) ;
18
+ Contracts . AssertNonEmpty ( destination ) ;
19
+ Contracts . Assert ( matrix . Length == destination . Length * source . Length ) ;
31
20
Contracts . Assert ( stride >= 0 ) ;
32
21
33
- if ( Avx . IsSupported )
22
+ if ( ! transpose )
34
23
{
35
- if ( ! transpose )
24
+ if ( Avx . IsSupported && source . Length >= 8 )
36
25
{
37
- Contracts . Assert ( stride <= destination . Size ) ;
38
- AvxIntrinsics . MatMul ( matrix , source , destination , stride , source . Size ) ;
26
+ Contracts . Assert ( stride <= destination . Length ) ;
27
+ AvxIntrinsics . MatMul ( matrix , source , destination , stride , source . Length ) ;
39
28
}
40
- else
29
+ else if ( Sse . IsSupported && source . Length >= 4 )
41
30
{
42
- Contracts . Assert ( stride <= source . Size ) ;
43
- AvxIntrinsics . MatMulTran ( matrix , source , destination , destination . Size , stride ) ;
44
- }
45
- }
46
- else if ( Sse . IsSupported )
47
- {
48
- if ( ! transpose )
49
- {
50
- Contracts . Assert ( stride <= destination . Size ) ;
51
- SseIntrinsics . MatMul ( matrix , source , destination , stride , source . Size ) ;
31
+ Contracts . Assert ( stride <= destination . Length ) ;
32
+ SseIntrinsics . MatMul ( matrix , source , destination , stride , source . Length ) ;
52
33
}
53
34
else
54
35
{
55
- Contracts . Assert ( stride <= source . Size ) ;
56
- SseIntrinsics . MatMulTran ( matrix , source , destination , destination . Size , stride ) ;
57
- }
58
- }
59
- else
60
- {
61
- if ( ! transpose )
62
- {
63
- Contracts . Assert ( stride <= destination . Size ) ;
36
+ Contracts . Assert ( stride <= destination . Length ) ;
64
37
for ( int i = 0 ; i < stride ; i ++ )
65
38
{
66
39
float dotProduct = 0 ;
67
- for ( int j = 0 ; j < source . Size ; j ++ )
40
+ for ( int j = 0 ; j < source . Length ; j ++ )
68
41
{
69
- dotProduct += matrix [ i * source . Size + j ] * source [ j ] ;
42
+ dotProduct += matrix [ i * source . Length + j ] * source [ j ] ;
70
43
}
71
44
72
45
destination [ i ] = dotProduct ;
73
46
}
74
47
}
48
+ }
49
+ else
50
+ {
51
+ if ( Avx . IsSupported && destination . Length >= 8 )
52
+ {
53
+ Contracts . Assert ( stride <= source . Length ) ;
54
+ AvxIntrinsics . MatMulTran ( matrix , source , destination , destination . Length , stride ) ;
55
+ }
56
+ else if ( Sse . IsSupported && destination . Length >= 4 )
57
+ {
58
+ Contracts . Assert ( stride <= source . Length ) ;
59
+ SseIntrinsics . MatMulTran ( matrix , source , destination , destination . Length , stride ) ;
60
+ }
75
61
else
76
62
{
77
- Contracts . Assert ( stride <= source . Size ) ;
78
- for ( int i = 0 ; i < destination . Size ; i ++ )
63
+ Contracts . Assert ( stride <= source . Length ) ;
64
+ for ( int i = 0 ; i < destination . Length ; i ++ )
79
65
{
80
66
float dotProduct = 0 ;
81
67
for ( int j = 0 ; j < stride ; j ++ )
82
68
{
83
- dotProduct += matrix [ j * source . Size + i ] * source [ j ] ;
69
+ dotProduct += matrix [ j * destination . Length + i ] * source [ j ] ;
84
70
}
85
71
86
72
destination [ i ] = dotProduct ;
@@ -89,17 +75,22 @@ public static void MatrixTimesSource(bool transpose, AlignedArray matrix, Aligne
89
75
}
90
76
}
91
77
92
- public static void MatrixTimesSource ( AlignedArray matrix , ReadOnlySpan < int > rgposSrc , AlignedArray sourceValues ,
93
- int posMin , int iposMin , int iposLimit , AlignedArray destination , int stride )
78
+ public static void MatrixTimesSource ( ReadOnlySpan < float > matrix , ReadOnlySpan < int > rgposSrc , ReadOnlySpan < float > sourceValues ,
79
+ int posMin , int iposMin , int iposLimit , Span < float > destination , int stride )
94
80
{
95
81
Contracts . Assert ( iposMin >= 0 ) ;
96
82
Contracts . Assert ( iposMin <= iposLimit ) ;
97
83
Contracts . Assert ( iposLimit <= rgposSrc . Length ) ;
98
- Contracts . Assert ( matrix . Size == destination . Size * sourceValues . Size ) ;
84
+ Contracts . AssertNonEmpty ( matrix ) ;
85
+ Contracts . AssertNonEmpty ( sourceValues ) ;
86
+ Contracts . AssertNonEmpty ( destination ) ;
87
+ Contracts . AssertNonEmpty ( rgposSrc ) ;
88
+ Contracts . Assert ( stride > 0 ) ;
89
+ Contracts . Assert ( matrix . Length == destination . Length * sourceValues . Length ) ;
99
90
100
91
if ( iposMin >= iposLimit )
101
92
{
102
- destination . ZeroItems ( ) ;
93
+ destination . Clear ( ) ;
103
94
return ;
104
95
}
105
96
@@ -108,24 +99,24 @@ public static void MatrixTimesSource(AlignedArray matrix, ReadOnlySpan<int> rgpo
108
99
109
100
if ( Avx . IsSupported )
110
101
{
111
- Contracts . Assert ( stride <= destination . Size ) ;
112
- AvxIntrinsics . MatMulP ( matrix , rgposSrc , sourceValues , posMin , iposMin , iposLimit , destination , stride , sourceValues . Size ) ;
102
+ Contracts . Assert ( stride <= destination . Length ) ;
103
+ AvxIntrinsics . MatMulP ( matrix , rgposSrc , sourceValues , posMin , iposMin , iposLimit , destination , stride , sourceValues . Length ) ;
113
104
}
114
105
else if ( Sse . IsSupported )
115
106
{
116
- Contracts . Assert ( stride <= destination . Size ) ;
117
- SseIntrinsics . MatMulP ( matrix , rgposSrc , sourceValues , posMin , iposMin , iposLimit , destination , stride , sourceValues . Size ) ;
107
+ Contracts . Assert ( stride <= destination . Length ) ;
108
+ SseIntrinsics . MatMulP ( matrix , rgposSrc , sourceValues , posMin , iposMin , iposLimit , destination , stride , sourceValues . Length ) ;
118
109
}
119
110
else
120
111
{
121
- Contracts . Assert ( stride <= destination . Size ) ;
112
+ Contracts . Assert ( stride <= destination . Length ) ;
122
113
for ( int i = 0 ; i < stride ; i ++ )
123
114
{
124
115
float dotProduct = 0 ;
125
116
for ( int j = iposMin ; j < iposLimit ; j ++ )
126
117
{
127
118
int col = rgposSrc [ j ] - posMin ;
128
- dotProduct += matrix [ i * sourceValues . Size + col ] * sourceValues [ col ] ;
119
+ dotProduct += matrix [ i * sourceValues . Length + col ] * sourceValues [ col ] ;
129
120
}
130
121
destination [ i ] = dotProduct ;
131
122
}
@@ -636,71 +627,6 @@ public static float L2DistSquared(ReadOnlySpan<float> left, ReadOnlySpan<float>
636
627
}
637
628
}
638
629
639
- public static void ZeroMatrixItems ( AlignedArray destination , int ccol , int cfltRow , int [ ] indices )
640
- {
641
- Contracts . Assert ( ccol > 0 ) ;
642
- Contracts . Assert ( ccol <= cfltRow ) ;
643
-
644
- if ( ccol == cfltRow )
645
- {
646
- ZeroItemsU ( destination , destination . Size , indices , indices . Length ) ;
647
- }
648
- else
649
- {
650
- ZeroMatrixItemsCore ( destination , destination . Size , ccol , cfltRow , indices , indices . Length ) ;
651
- }
652
- }
653
-
654
- private static unsafe void ZeroItemsU ( AlignedArray destination , int c , int [ ] indices , int cindices )
655
- {
656
- fixed ( float * pdst = & destination . Items [ 0 ] )
657
- fixed ( int * pidx = & indices [ 0 ] )
658
- {
659
- for ( int i = 0 ; i < cindices ; ++ i )
660
- {
661
- int index = pidx [ i ] ;
662
- Contracts . Assert ( index >= 0 ) ;
663
- Contracts . Assert ( index < c ) ;
664
- pdst [ index ] = 0 ;
665
- }
666
- }
667
- }
668
-
669
- private static unsafe void ZeroMatrixItemsCore ( AlignedArray destination , int c , int ccol , int cfltRow , int [ ] indices , int cindices )
670
- {
671
- fixed ( float * pdst = & destination . Items [ 0 ] )
672
- fixed ( int * pidx = & indices [ 0 ] )
673
- {
674
- int ivLogMin = 0 ;
675
- int ivLogLim = ccol ;
676
- int ivPhyMin = 0 ;
677
-
678
- for ( int i = 0 ; i < cindices ; ++ i )
679
- {
680
- int index = pidx [ i ] ;
681
- Contracts . Assert ( index >= 0 ) ;
682
- Contracts . Assert ( index < c ) ;
683
-
684
- int col = index - ivLogMin ;
685
- if ( ( uint ) col >= ( uint ) ccol )
686
- {
687
- Contracts . Assert ( ivLogMin > index || index >= ivLogLim ) ;
688
-
689
- int row = index / ccol ;
690
- ivLogMin = row * ccol ;
691
- ivLogLim = ivLogMin + ccol ;
692
- ivPhyMin = row * cfltRow ;
693
-
694
- Contracts . Assert ( index >= ivLogMin ) ;
695
- Contracts . Assert ( index < ivLogLim ) ;
696
- col = index - ivLogMin ;
697
- }
698
-
699
- pdst [ ivPhyMin + col ] = 0 ;
700
- }
701
- }
702
- }
703
-
704
630
public static void SdcaL1UpdateDense ( float primalUpdate , int count , ReadOnlySpan < float > source , float threshold , Span < float > v , Span < float > w )
705
631
{
706
632
Contracts . AssertNonEmpty ( source ) ;
0 commit comments