@@ -38,10 +38,14 @@ int matmul_double(double_cmat matC, double_cmat matA, double_cmat matB){
38
38
return 0 ;
39
39
}
40
40
41
- inline int matmul_double_blas (double_cmat C , double_cmat A_slice , double_cmat B_slice ) {
41
+ int matmul_double_blas (double_cmat C , double_cmat A_slice , double_cmat B_slice ) {
42
42
//printf("start matmul blas %d %d\n", A_slice.shape[0], A_slice.shape[1]);
43
43
// Check dimensions for compatibility
44
- //memset(&C.data[0][0], 0, sizeof(C.data[0][0])*C.shape[0]*C.shape[1]);
44
+ //if (A_slice.shape[1] != B_slice.shape[0] || A_slice.shape[0] != C.shape[0] || B_slice.shape[1] != C.shape[1]) {
45
+ // fprintf(stderr, "Matrix dimensions do not match for multiplication.\n");
46
+ // return -1;
47
+ //}
48
+ //for(int arenai=0; arenai<NS; arenai++) C.arena[arenai] = 0;
45
49
// Call dgemm
46
50
cblas_dgemm (CblasRowMajor , CblasNoTrans , CblasNoTrans ,
47
51
C .shape [0 ], C .shape [1 ], A_slice .shape [1 ],
@@ -85,29 +89,29 @@ int matmul_double_strassen_winograd(double_cmat matC, double_cmat matA, double_c
85
89
return -1 ;
86
90
}
87
91
int N = matA .shape [0 ];
88
- int I = matA .shape [0 ]/2 ;
92
+ int II = matA .shape [0 ]/2 ;
89
93
if (N <= 4096 ) {
90
94
return matmul_double_blas (matC , matA , matB );
91
95
}
92
- double_cmat A11 = slice_double_matrix (matA , pairint {0 ,I }, pairint {0 ,I });
93
- double_cmat A12 = slice_double_matrix (matA , pairint {0 ,I }, pairint {I ,N });
94
- double_cmat A21 = slice_double_matrix (matA , pairint {I ,N }, pairint {0 ,I });
95
- double_cmat A22 = slice_double_matrix (matA , pairint {I ,N }, pairint {I ,N });
96
- double_cmat B11 = slice_double_matrix (matB , pairint {0 ,I }, pairint {0 ,I });
97
- double_cmat B12 = slice_double_matrix (matB , pairint {0 ,I }, pairint {I ,N });
98
- double_cmat B21 = slice_double_matrix (matB , pairint {I ,N }, pairint {0 ,I });
99
- double_cmat B22 = slice_double_matrix (matB , pairint {I ,N }, pairint {I ,N });
100
- double_cmat C11 = slice_double_matrix (matC , pairint {0 ,I }, pairint {0 ,I });
101
- double_cmat C12 = slice_double_matrix (matC , pairint {0 ,I }, pairint {I ,N });
102
- double_cmat C21 = slice_double_matrix (matC , pairint {I ,N }, pairint {0 ,I });
103
- double_cmat C22 = slice_double_matrix (matC , pairint {I ,N }, pairint {I ,N });
96
+ double_cmat A11 = slice_double_matrix (matA , pairint {0 ,II }, pairint {0 ,II });
97
+ double_cmat A12 = slice_double_matrix (matA , pairint {0 ,II }, pairint {II ,N });
98
+ double_cmat A21 = slice_double_matrix (matA , pairint {II ,N }, pairint {0 ,II });
99
+ double_cmat A22 = slice_double_matrix (matA , pairint {II ,N }, pairint {II ,N });
100
+ double_cmat B11 = slice_double_matrix (matB , pairint {0 ,II }, pairint {0 ,II });
101
+ double_cmat B12 = slice_double_matrix (matB , pairint {0 ,II }, pairint {II ,N });
102
+ double_cmat B21 = slice_double_matrix (matB , pairint {II ,N }, pairint {0 ,II });
103
+ double_cmat B22 = slice_double_matrix (matB , pairint {II ,N }, pairint {II ,N });
104
+ double_cmat C11 = slice_double_matrix (matC , pairint {0 ,II }, pairint {0 ,II });
105
+ double_cmat C12 = slice_double_matrix (matC , pairint {0 ,II }, pairint {II ,N });
106
+ double_cmat C21 = slice_double_matrix (matC , pairint {II ,N }, pairint {0 ,II });
107
+ double_cmat C22 = slice_double_matrix (matC , pairint {II ,N }, pairint {II ,N });
104
108
double_cmat S1 , T1 , M1 , M2 , M6 , M7 ;
105
- create_double_matrix (pairint {I , I }, & S1 );
106
- create_double_matrix (pairint {I , I }, & T1 );
107
- create_double_matrix (pairint {I , I }, & M1 );
108
- create_double_matrix (pairint {I , I }, & M2 );
109
- create_double_matrix (pairint {I , I }, & M6 );
110
- create_double_matrix (pairint {I , I }, & M7 );
109
+ create_double_matrix (pairint {II , II }, & S1 );
110
+ create_double_matrix (pairint {II , II }, & T1 );
111
+ create_double_matrix (pairint {II , II }, & M1 );
112
+ create_double_matrix (pairint {II , II }, & M2 );
113
+ create_double_matrix (pairint {II , II }, & M6 );
114
+ create_double_matrix (pairint {II , II }, & M7 );
111
115
112
116
113
117
matmul_double_strassen_winograd (M1 , A11 , B11 );
@@ -161,32 +165,32 @@ int matmul_double_recursive_bilinear(double_cmat matC, double_cmat matA, double_
161
165
return -1 ;
162
166
}
163
167
int N = matA .shape [0 ];
164
- int I = matA .shape [0 ]/2 ;
168
+ int II = matA .shape [0 ]/2 ;
165
169
if (N <= 4096 ) {
166
170
return matmul_double_blas (matC , matA , matB );
167
171
}
168
- double_cmat A11 = slice_double_matrix (matA , pairint {0 ,I }, pairint {0 ,I });
169
- double_cmat A12 = slice_double_matrix (matA , pairint {0 ,I }, pairint {I ,N });
170
- double_cmat A21 = slice_double_matrix (matA , pairint {I ,N }, pairint {0 ,I });
171
- double_cmat A22 = slice_double_matrix (matA , pairint {I ,N }, pairint {I ,N });
172
- double_cmat B11 = slice_double_matrix (matB , pairint {0 ,I }, pairint {0 ,I });
173
- double_cmat B12 = slice_double_matrix (matB , pairint {0 ,I }, pairint {I ,N });
174
- double_cmat B21 = slice_double_matrix (matB , pairint {I ,N }, pairint {0 ,I });
175
- double_cmat B22 = slice_double_matrix (matB , pairint {I ,N }, pairint {I ,N });
176
- double_cmat C11 = slice_double_matrix (matC , pairint {0 ,I }, pairint {0 ,I });
177
- double_cmat C12 = slice_double_matrix (matC , pairint {0 ,I }, pairint {I ,N });
178
- double_cmat C21 = slice_double_matrix (matC , pairint {I ,N }, pairint {0 ,I });
179
- double_cmat C22 = slice_double_matrix (matC , pairint {I ,N }, pairint {I ,N });
172
+ double_cmat A11 = slice_double_matrix (matA , pairint {0 ,II }, pairint {0 ,II });
173
+ double_cmat A12 = slice_double_matrix (matA , pairint {0 ,II }, pairint {II ,N });
174
+ double_cmat A21 = slice_double_matrix (matA , pairint {II ,N }, pairint {0 ,II });
175
+ double_cmat A22 = slice_double_matrix (matA , pairint {II ,N }, pairint {II ,N });
176
+ double_cmat B11 = slice_double_matrix (matB , pairint {0 ,II }, pairint {0 ,II });
177
+ double_cmat B12 = slice_double_matrix (matB , pairint {0 ,II }, pairint {II ,N });
178
+ double_cmat B21 = slice_double_matrix (matB , pairint {II ,N }, pairint {0 ,II });
179
+ double_cmat B22 = slice_double_matrix (matB , pairint {II ,N }, pairint {II ,N });
180
+ double_cmat C11 = slice_double_matrix (matC , pairint {0 ,II }, pairint {0 ,II });
181
+ double_cmat C12 = slice_double_matrix (matC , pairint {0 ,II }, pairint {II ,N });
182
+ double_cmat C21 = slice_double_matrix (matC , pairint {II ,N }, pairint {0 ,II });
183
+ double_cmat C22 = slice_double_matrix (matC , pairint {II ,N }, pairint {II ,N });
180
184
double_cmat S5 , T5 , M1 , M2 , M3 , M4 , M5 , M6 , M7 ;
181
- create_double_matrix (pairint {I , I }, & S5 );
182
- create_double_matrix (pairint {I , I }, & T5 );
183
- create_double_matrix (pairint {I , I }, & M1 );
184
- create_double_matrix (pairint {I , I }, & M2 );
185
- create_double_matrix (pairint {I , I }, & M3 );
186
- create_double_matrix (pairint {I , I }, & M4 );
187
- create_double_matrix (pairint {I , I }, & M5 );
188
- create_double_matrix (pairint {I , I }, & M6 );
189
- create_double_matrix (pairint {I , I }, & M7 );
185
+ create_double_matrix (pairint {II , II }, & S5 );
186
+ create_double_matrix (pairint {II , II }, & T5 );
187
+ create_double_matrix (pairint {II , II }, & M1 );
188
+ create_double_matrix (pairint {II , II }, & M2 );
189
+ create_double_matrix (pairint {II , II }, & M3 );
190
+ create_double_matrix (pairint {II , II }, & M4 );
191
+ create_double_matrix (pairint {II , II }, & M5 );
192
+ create_double_matrix (pairint {II , II }, & M6 );
193
+ create_double_matrix (pairint {II , II }, & M7 );
190
194
191
195
192
196
@@ -209,8 +213,8 @@ int matmul_double_recursive_bilinear(double_cmat matC, double_cmat matA, double_
209
213
matadd_double (C11 , M1 , M2 );
210
214
matsub_double (C12 , M5 , M7 );
211
215
matadd_double (C21 , M3 , M6 );
212
- for (int i = 0 ; i < I ; i ++ ) {
213
- for (int j = 0 ; j < I ; j ++ ) {
216
+ for (int i = 0 ; i < II ; i ++ ) {
217
+ for (int j = 0 ; j < II ; j ++ ) {
214
218
C22 .data [i ][j ] = M5 .data [i ][j ] + M6 .data [i ][j ] - M2 .data [i ][j ] - M4 .data [i ][j ];
215
219
}
216
220
}
@@ -244,22 +248,22 @@ int matmul_double_schwartz2024(double_cmat matC, double_cmat matA, double_cmat m
244
248
return -1 ;
245
249
}
246
250
int N = matA .shape [0 ];
247
- int I = matA .shape [0 ]/2 ;
251
+ int II = matA .shape [0 ]/2 ;
248
252
if (N <= 4096 ) {
249
253
return matmul_double_blas (matC , matA , matB );
250
254
}
251
- double_cmat A11 = slice_double_matrix (matA , pairint {0 ,I }, pairint {0 ,I });
252
- double_cmat A12 = slice_double_matrix (matA , pairint {0 ,I }, pairint {I ,N });
253
- double_cmat A21 = slice_double_matrix (matA , pairint {I ,N }, pairint {0 ,I });
254
- double_cmat A22 = slice_double_matrix (matA , pairint {I ,N }, pairint {I ,N });
255
- double_cmat B11 = slice_double_matrix (matB , pairint {0 ,I }, pairint {0 ,I });
256
- double_cmat B12 = slice_double_matrix (matB , pairint {0 ,I }, pairint {I ,N });
257
- double_cmat B21 = slice_double_matrix (matB , pairint {I ,N }, pairint {0 ,I });
258
- double_cmat B22 = slice_double_matrix (matB , pairint {I ,N }, pairint {I ,N });
259
- double_cmat C11 = slice_double_matrix (matC , pairint {0 ,I }, pairint {0 ,I });
260
- double_cmat C12 = slice_double_matrix (matC , pairint {0 ,I }, pairint {I ,N });
261
- double_cmat C21 = slice_double_matrix (matC , pairint {I ,N }, pairint {0 ,I });
262
- double_cmat C22 = slice_double_matrix (matC , pairint {I ,N }, pairint {I ,N });
255
+ double_cmat A11 = slice_double_matrix (matA , pairint {0 ,II }, pairint {0 ,II });
256
+ double_cmat A12 = slice_double_matrix (matA , pairint {0 ,II }, pairint {II ,N });
257
+ double_cmat A21 = slice_double_matrix (matA , pairint {II ,N }, pairint {0 ,II });
258
+ double_cmat A22 = slice_double_matrix (matA , pairint {II ,N }, pairint {II ,N });
259
+ double_cmat B11 = slice_double_matrix (matB , pairint {0 ,II }, pairint {0 ,II });
260
+ double_cmat B12 = slice_double_matrix (matB , pairint {0 ,II }, pairint {II ,N });
261
+ double_cmat B21 = slice_double_matrix (matB , pairint {II ,N }, pairint {0 ,II });
262
+ double_cmat B22 = slice_double_matrix (matB , pairint {II ,N }, pairint {II ,N });
263
+ double_cmat C11 = slice_double_matrix (matC , pairint {0 ,II }, pairint {0 ,II });
264
+ double_cmat C12 = slice_double_matrix (matC , pairint {0 ,II }, pairint {II ,N });
265
+ double_cmat C21 = slice_double_matrix (matC , pairint {II ,N }, pairint {0 ,II });
266
+ double_cmat C22 = slice_double_matrix (matC , pairint {II ,N }, pairint {II ,N });
263
267
// pa11 = a21 + a22
264
268
// pa12 = a22
265
269
// pa21 = -a11 - a12
@@ -279,8 +283,8 @@ int matmul_double_schwartz2024(double_cmat matC, double_cmat matA, double_cmat m
279
283
// pb22 = b11 + b22
280
284
double_cmat tmp1 ;
281
285
double_cmat tmp2 ;
282
- create_double_matrix (pairint {I , I }, & tmp1 );
283
- create_double_matrix (pairint {I , I }, & tmp2 );
286
+ create_double_matrix (pairint {II , II }, & tmp1 );
287
+ create_double_matrix (pairint {II , II }, & tmp2 );
284
288
assign_double_clone (tmp2 , A11 );
285
289
matneg_double (tmp1 , A11 );
286
290
matadd_double (A11 , A21 , A22 );
@@ -313,8 +317,8 @@ int matmul_double_schwartz2024(double_cmat matC, double_cmat matA, double_cmat m
313
317
assign_double_clone (tmp1 , C21 );
314
318
assign_double_clone (C21 , C11 );
315
319
matsub_double (C11 , tmp1 , C22 );
316
- for (int i = 0 ; i < I ; i ++ ) {
317
- for (int j = 0 ; j < I ; j ++ ) {
320
+ for (int i = 0 ; i < II ; i ++ ) {
321
+ for (int j = 0 ; j < II ; j ++ ) {
318
322
C22 .data [i ][j ] = C12 .data [i ][j ] - C21 .data [i ][j ] - C22 .data [i ][j ];
319
323
}
320
324
}
0 commit comments