@@ -46,25 +46,6 @@ internal static class AvxIntrinsics

        private static readonly Vector256<float> _absMask256 = Avx.StaticCast<int, float>(Avx.SetAllVector256(0x7FFFFFFF));

-       private const int Vector256Alignment = 32;
-
-       [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
-       private static bool HasCompatibleAlignment(AlignedArray alignedArray)
-       {
-           Contracts.AssertValue(alignedArray);
-           Contracts.Assert(alignedArray.Size > 0);
-           return (alignedArray.CbAlign % Vector256Alignment) == 0;
-       }
-
-       [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
-       private static unsafe float* GetAlignedBase(AlignedArray alignedArray, float* unalignedBase)
-       {
-           Contracts.AssertValue(alignedArray);
-           float* alignedBase = unalignedBase + alignedArray.GetBase((long)unalignedBase);
-           Contracts.Assert(((long)alignedBase % Vector256Alignment) == 0);
-           return alignedBase;
-       }
-
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector128<float> GetHigh(in Vector256<float> x)
            => Avx.ExtractVector128(x, 1);
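Note on the deletion above: HasCompatibleAlignment and GetAlignedBase existed only to guarantee that AlignedArray buffers could be used with 32-byte aligned AVX loads and stores. The Span-based rewrites below cope with arbitrary alignment via mask tables, so the helpers can go. For readers unfamiliar with the trick, here is a minimal sketch of the align-up arithmetic GetAlignedBase performed; AlignUp32 is an illustrative name, not from the PR, and it assumes AlignedArray.GetBase returns the element offset to the aligned base.

    // Sketch: advance a float pointer to the next 32-byte boundary.
    private static unsafe float* AlignUp32(float* p)
    {
        long rem = (long)p % 32;              // bytes past the previous boundary
        return rem == 0 ? p : (float*)((byte*)p + (32 - rem));
    }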
@@ -170,19 +151,19 @@ private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<flo
        }

        // Multiply matrix times vector into vector.
-       public static unsafe void MatMulX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+       public static unsafe void MatMul(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
        {
-           Contracts.Assert(crow % 4 == 0);
-           Contracts.Assert(ccol % 4 == 0);
-
-           MatMulX(mat.Items, src.Items, dst.Items, crow, ccol);
+           MatMul(mat.Items, src.Items, dst.Items, crow, ccol);
        }

-       public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+       public static unsafe void MatMul(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
        {
-           fixed (float* psrc = &src[0])
-           fixed (float* pdst = &dst[0])
-           fixed (float* pmat = &mat[0])
+           Contracts.Assert(crow % 4 == 0);
+           Contracts.Assert(ccol % 4 == 0);
+
+           fixed (float* psrc = &MemoryMarshal.GetReference(src))
+           fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+           fixed (float* pmat = &MemoryMarshal.GetReference(mat))
            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
            {
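Two points worth flagging at the new signature (the body of this hunk continues below). Arrays convert implicitly to ReadOnlySpan<float>/Span<float>, so existing callers keep working, and pinning through MemoryMarshal.GetReference, unlike &src[0], does not bounds-check or throw for an empty span. A hypothetical call site, with illustrative sizes that satisfy the crow/ccol asserts:

    // Hypothetical usage: multiply an 8x8 row-major matrix by an 8-vector.
    float[] mat = new float[8 * 8];  // crow rows x ccol columns, row-major
    float[] src = new float[8];      // ccol inputs
    float[] dst = new float[8];      // crow outputs
    AvxIntrinsics.MatMul(mat, src, dst, crow: 8, ccol: 8);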
@@ -312,32 +293,134 @@ public static unsafe void MatMulX(float[] mat, float[] src, float[] dst, int cro
        }

        // Partial sparse source vector.
-       public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArray src,
-           int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
+       public static unsafe void MatMulP(AlignedArray mat, ReadOnlySpan<int> rgposSrc, AlignedArray src,
+           int posMin, int iposMin, int iposEnd, AlignedArray dst, int crow, int ccol)
        {
-           Contracts.Assert(HasCompatibleAlignment(mat));
-           Contracts.Assert(HasCompatibleAlignment(src));
-           Contracts.Assert(HasCompatibleAlignment(dst));
+           MatMulP(mat.Items, rgposSrc, src.Items, posMin, iposMin, iposEnd, dst.Items, crow, ccol);
+       }
+
+       public static unsafe void MatMulP(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
+           int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
+       {
+           Contracts.Assert(crow % 8 == 0);
+           Contracts.Assert(ccol % 8 == 0);

            // REVIEW: For extremely sparse inputs, interchanging the loops would
            // likely be more efficient.
-           fixed (float* pSrcStart = &src.Items[0])
-           fixed (float* pDstStart = &dst.Items[0])
-           fixed (float* pMatStart = &mat.Items[0])
-           fixed (int* pposSrc = &rgposSrc[0])
+           fixed (float* psrc = &MemoryMarshal.GetReference(src))
+           fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+           fixed (float* pmat = &MemoryMarshal.GetReference(mat))
+           fixed (int* pposSrc = &MemoryMarshal.GetReference(rgposSrc))
+           fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+           fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
            {
-               float* psrc = GetAlignedBase(src, pSrcStart);
-               float* pdst = GetAlignedBase(dst, pDstStart);
-               float* pmat = GetAlignedBase(mat, pMatStart);
-
                int* pposMin = pposSrc + iposMin;
                int* pposEnd = pposSrc + iposEnd;
                float* pDstEnd = pdst + crow;
                float* pm0 = pmat - posMin;
                float* pSrcCurrent = psrc - posMin;
                float* pDstCurrent = pdst;

-               while (pDstCurrent < pDstEnd)
+               nuint address = (nuint)(pDstCurrent);
+               int misalignment = (int)(address % 32);
+               int length = crow;
+               int remainder = 0;
+
+               if ((misalignment & 3) != 0)
+               {
+                   while (pDstCurrent < pDstEnd)
+                   {
+                       Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+                       pDstCurrent += 8;
+                       pm0 += 8 * ccol;
+                   }
+               }
+               else
+               {
+                   if (misalignment != 0)
+                   {
+                       misalignment >>= 2;
+                       misalignment = 8 - misalignment;
+
+                       Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+
+                       float* pm1 = pm0 + ccol;
+                       float* pm2 = pm1 + ccol;
+                       float* pm3 = pm2 + ccol;
+                       Vector256<float> result = Avx.SetZeroVector256<float>();
+
+                       int* ppos = pposMin;
+
+                       while (ppos < pposEnd)
+                       {
+                           int col1 = *ppos;
+                           int col2 = col1 + 4 * ccol;
+                           Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+                                                                   pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+
+                           x1 = Avx.And(mask, x1);
+                           Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+                           result = MultiplyAdd(x2, x1, result);
+                           ppos++;
+                       }
+
+                       Avx.Store(pDstCurrent, result);
+                       pDstCurrent += misalignment;
+                       pm0 += misalignment * ccol;
+                       length -= misalignment;
+                   }
+
+                   if (length > 7)
+                   {
+                       remainder = length % 8;
+                       while (pDstCurrent < pDstEnd)
+                       {
+                           Avx.Store(pDstCurrent, SparseMultiplicationAcrossRow());
+                           pDstCurrent += 8;
+                           pm0 += 8 * ccol;
+                       }
+                   }
+                   else
+                   {
+                       remainder = length;
+                   }
+
+                   if (remainder != 0)
+                   {
+                       pDstCurrent -= (8 - remainder);
+                       pm0 -= (8 - remainder) * ccol;
+                       Vector256<float> trailingMask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+                       Vector256<float> leadingMask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + ((8 - remainder) * 8));
+
+                       float* pm1 = pm0 + ccol;
+                       float* pm2 = pm1 + ccol;
+                       float* pm3 = pm2 + ccol;
+                       Vector256<float> result = Avx.SetZeroVector256<float>();
+
+                       int* ppos = pposMin;
+
+                       while (ppos < pposEnd)
+                       {
+                           int col1 = *ppos;
+                           int col2 = col1 + 4 * ccol;
+                           Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
+                                                                   pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+                           x1 = Avx.And(x1, trailingMask);
+
+                           Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
+                           result = MultiplyAdd(x2, x1, result);
+                           ppos++;
+                       }
+
+                       result = Avx.Add(result, Avx.And(leadingMask, Avx.LoadVector256(pDstCurrent)));
+
+                       Avx.Store(pDstCurrent, result);
+                       pDstCurrent += 8;
+                       pm0 += 8 * ccol;
+                   }
+               }
+
+               Vector256<float> SparseMultiplicationAcrossRow()
                {
                    float* pm1 = pm0 + ccol;
                    float* pm2 = pm1 + ccol;
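A reading aid for the MatMulP rewrite above (the function body continues in the next hunk): stripped of vectorization and the alignment head/tail handling, the kernel is a dense-rows-times-sparse-columns product. The scalar sketch below is my reading of the pointer arithmetic (pm0 = pmat - posMin, pSrcCurrent = psrc - posMin), offered as an assumption rather than as code from the PR.

    // Scalar equivalent (sketch): each active position in rgposSrc is rebased
    // by posMin before indexing both the matrix row and the source vector.
    static void MatMulPScalar(ReadOnlySpan<float> mat, ReadOnlySpan<int> rgposSrc, ReadOnlySpan<float> src,
        int posMin, int iposMin, int iposEnd, Span<float> dst, int crow, int ccol)
    {
        for (int r = 0; r < crow; r++)
        {
            float sum = 0;
            for (int ipos = iposMin; ipos < iposEnd; ipos++)
            {
                int col = rgposSrc[ipos] - posMin;     // rebased column index
                sum += mat[r * ccol + col] * src[col];
            }
            dst[r] = sum;
        }
    }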
@@ -351,33 +434,30 @@ public static unsafe void MatMulPX(AlignedArray mat, int[] rgposSrc, AlignedArra
                        int col1 = *ppos;
                        int col2 = col1 + 4 * ccol;
                        Vector256<float> x1 = Avx.SetVector256(pm3[col2], pm2[col2], pm1[col2], pm0[col2],
-                           pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
+                                                               pm3[col1], pm2[col1], pm1[col1], pm0[col1]);
                        Vector256<float> x2 = Avx.SetAllVector256(pSrcCurrent[col1]);
                        result = MultiplyAdd(x2, x1, result);
-
                        ppos++;
                    }

-                   Avx.StoreAligned(pDstCurrent, result);
-                   pDstCurrent += 8;
-                   pm0 += 8 * ccol;
+                   return result;
                }
            }
        }

-       public static unsafe void MatMulTranX(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
+       public static unsafe void MatMulTran(AlignedArray mat, AlignedArray src, AlignedArray dst, int crow, int ccol)
        {
-           Contracts.Assert(crow % 4 == 0);
-           Contracts.Assert(ccol % 4 == 0);
-
-           MatMulTranX(mat.Items, src.Items, dst.Items, crow, ccol);
+           MatMulTran(mat.Items, src.Items, dst.Items, crow, ccol);
        }

-       public static unsafe void MatMulTranX(float[] mat, float[] src, float[] dst, int crow, int ccol)
+       public static unsafe void MatMulTran(ReadOnlySpan<float> mat, ReadOnlySpan<float> src, Span<float> dst, int crow, int ccol)
        {
-           fixed (float* psrc = &src[0])
-           fixed (float* pdst = &dst[0])
-           fixed (float* pmat = &mat[0])
+           Contracts.Assert(crow % 4 == 0);
+           Contracts.Assert(ccol % 4 == 0);
+
+           fixed (float* psrc = &MemoryMarshal.GetReference(src))
+           fixed (float* pdst = &MemoryMarshal.GetReference(dst))
+           fixed (float* pmat = &MemoryMarshal.GetReference(mat))
            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
            {
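Both MatMulP and MatMulTran pin the same two mask tables, whose contents are not part of this diff. From the way they are indexed (one row of eight uints per possible element count), the assumed layout is that row n of LeadingAlignmentMask selects the first n lanes of a Vector256<float> and row n of TrailingAlignmentMask the last n, which is consistent with how the masked head and tail in MatMulP blend with existing destination values. A sketch of that assumed construction:

    // Assumed 8x8 layout (my reconstruction, not shown in this diff): row n of
    // the leading table sets its first n lanes; the trailing table its last n.
    private static uint[] BuildMaskTable(bool leading)
    {
        var table = new uint[8 * 8];
        for (int n = 0; n < 8; n++)
            for (int lane = 0; lane < 8; lane++)
                if (leading ? lane < n : lane >= 8 - n)
                    table[n * 8 + lane] = 0xFFFFFFFF;
        return table;
    }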