Skip to content

Commit 366e79f

Browse files
committed
avoid I collision with complex number
1 parent 47f5621 commit 366e79f

File tree

1 file changed

+66
-62
lines changed

1 file changed

+66
-62
lines changed

matmul.c

+66-62
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ int matmul_double(double_cmat matC, double_cmat matA, double_cmat matB){
3838
return 0;
3939
}
4040

41-
inline int matmul_double_blas(double_cmat C, double_cmat A_slice, double_cmat B_slice) {
41+
int matmul_double_blas(double_cmat C, double_cmat A_slice, double_cmat B_slice) {
4242
//printf("start matmul blas %d %d\n", A_slice.shape[0], A_slice.shape[1]);
4343
// Check dimensions for compatibility
44-
//memset(&C.data[0][0], 0, sizeof(C.data[0][0])*C.shape[0]*C.shape[1]);
44+
//if (A_slice.shape[1] != B_slice.shape[0] || A_slice.shape[0] != C.shape[0] || B_slice.shape[1] != C.shape[1]) {
45+
// fprintf(stderr, "Matrix dimensions do not match for multiplication.\n");
46+
// return -1;
47+
//}
48+
//for(int arenai=0; arenai<NS; arenai++) C.arena[arenai] = 0;
4549
// Call dgemm
4650
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
4751
C.shape[0], C.shape[1], A_slice.shape[1],
@@ -85,29 +89,29 @@ int matmul_double_strassen_winograd(double_cmat matC, double_cmat matA, double_c
8589
return -1;
8690
}
8791
int N = matA.shape[0];
88-
int I = matA.shape[0]/2;
92+
int II = matA.shape[0]/2;
8993
if (N <= 4096) {
9094
return matmul_double_blas(matC, matA, matB);
9195
}
92-
double_cmat A11 = slice_double_matrix(matA, pairint {0,I}, pairint {0,I});
93-
double_cmat A12 = slice_double_matrix(matA, pairint {0,I}, pairint {I,N});
94-
double_cmat A21 = slice_double_matrix(matA, pairint {I,N}, pairint {0,I});
95-
double_cmat A22 = slice_double_matrix(matA, pairint {I,N}, pairint {I,N});
96-
double_cmat B11 = slice_double_matrix(matB, pairint {0,I}, pairint {0,I});
97-
double_cmat B12 = slice_double_matrix(matB, pairint {0,I}, pairint {I,N});
98-
double_cmat B21 = slice_double_matrix(matB, pairint {I,N}, pairint {0,I});
99-
double_cmat B22 = slice_double_matrix(matB, pairint {I,N}, pairint {I,N});
100-
double_cmat C11 = slice_double_matrix(matC, pairint {0,I}, pairint {0,I});
101-
double_cmat C12 = slice_double_matrix(matC, pairint {0,I}, pairint {I,N});
102-
double_cmat C21 = slice_double_matrix(matC, pairint {I,N}, pairint {0,I});
103-
double_cmat C22 = slice_double_matrix(matC, pairint {I,N}, pairint {I,N});
96+
double_cmat A11 = slice_double_matrix(matA, pairint {0,II}, pairint {0,II});
97+
double_cmat A12 = slice_double_matrix(matA, pairint {0,II}, pairint {II,N});
98+
double_cmat A21 = slice_double_matrix(matA, pairint {II,N}, pairint {0,II});
99+
double_cmat A22 = slice_double_matrix(matA, pairint {II,N}, pairint {II,N});
100+
double_cmat B11 = slice_double_matrix(matB, pairint {0,II}, pairint {0,II});
101+
double_cmat B12 = slice_double_matrix(matB, pairint {0,II}, pairint {II,N});
102+
double_cmat B21 = slice_double_matrix(matB, pairint {II,N}, pairint {0,II});
103+
double_cmat B22 = slice_double_matrix(matB, pairint {II,N}, pairint {II,N});
104+
double_cmat C11 = slice_double_matrix(matC, pairint {0,II}, pairint {0,II});
105+
double_cmat C12 = slice_double_matrix(matC, pairint {0,II}, pairint {II,N});
106+
double_cmat C21 = slice_double_matrix(matC, pairint {II,N}, pairint {0,II});
107+
double_cmat C22 = slice_double_matrix(matC, pairint {II,N}, pairint {II,N});
104108
double_cmat S1, T1, M1, M2, M6, M7;
105-
create_double_matrix(pairint {I,I}, &S1);
106-
create_double_matrix(pairint {I,I}, &T1);
107-
create_double_matrix(pairint {I,I}, &M1);
108-
create_double_matrix(pairint {I,I}, &M2);
109-
create_double_matrix(pairint {I,I}, &M6);
110-
create_double_matrix(pairint {I,I}, &M7);
109+
create_double_matrix(pairint {II,II}, &S1);
110+
create_double_matrix(pairint {II,II}, &T1);
111+
create_double_matrix(pairint {II,II}, &M1);
112+
create_double_matrix(pairint {II,II}, &M2);
113+
create_double_matrix(pairint {II,II}, &M6);
114+
create_double_matrix(pairint {II,II}, &M7);
111115

112116

113117
matmul_double_strassen_winograd(M1, A11, B11);
@@ -161,32 +165,32 @@ int matmul_double_recursive_bilinear(double_cmat matC, double_cmat matA, double_
161165
return -1;
162166
}
163167
int N = matA.shape[0];
164-
int I = matA.shape[0]/2;
168+
int II = matA.shape[0]/2;
165169
if (N <= 4096) {
166170
return matmul_double_blas(matC, matA, matB);
167171
}
168-
double_cmat A11 = slice_double_matrix(matA, pairint {0,I}, pairint {0,I});
169-
double_cmat A12 = slice_double_matrix(matA, pairint {0,I}, pairint {I,N});
170-
double_cmat A21 = slice_double_matrix(matA, pairint {I,N}, pairint {0,I});
171-
double_cmat A22 = slice_double_matrix(matA, pairint {I,N}, pairint {I,N});
172-
double_cmat B11 = slice_double_matrix(matB, pairint {0,I}, pairint {0,I});
173-
double_cmat B12 = slice_double_matrix(matB, pairint {0,I}, pairint {I,N});
174-
double_cmat B21 = slice_double_matrix(matB, pairint {I,N}, pairint {0,I});
175-
double_cmat B22 = slice_double_matrix(matB, pairint {I,N}, pairint {I,N});
176-
double_cmat C11 = slice_double_matrix(matC, pairint {0,I}, pairint {0,I});
177-
double_cmat C12 = slice_double_matrix(matC, pairint {0,I}, pairint {I,N});
178-
double_cmat C21 = slice_double_matrix(matC, pairint {I,N}, pairint {0,I});
179-
double_cmat C22 = slice_double_matrix(matC, pairint {I,N}, pairint {I,N});
172+
double_cmat A11 = slice_double_matrix(matA, pairint {0,II}, pairint {0,II});
173+
double_cmat A12 = slice_double_matrix(matA, pairint {0,II}, pairint {II,N});
174+
double_cmat A21 = slice_double_matrix(matA, pairint {II,N}, pairint {0,II});
175+
double_cmat A22 = slice_double_matrix(matA, pairint {II,N}, pairint {II,N});
176+
double_cmat B11 = slice_double_matrix(matB, pairint {0,II}, pairint {0,II});
177+
double_cmat B12 = slice_double_matrix(matB, pairint {0,II}, pairint {II,N});
178+
double_cmat B21 = slice_double_matrix(matB, pairint {II,N}, pairint {0,II});
179+
double_cmat B22 = slice_double_matrix(matB, pairint {II,N}, pairint {II,N});
180+
double_cmat C11 = slice_double_matrix(matC, pairint {0,II}, pairint {0,II});
181+
double_cmat C12 = slice_double_matrix(matC, pairint {0,II}, pairint {II,N});
182+
double_cmat C21 = slice_double_matrix(matC, pairint {II,N}, pairint {0,II});
183+
double_cmat C22 = slice_double_matrix(matC, pairint {II,N}, pairint {II,N});
180184
double_cmat S5, T5, M1, M2, M3, M4, M5, M6, M7;
181-
create_double_matrix(pairint {I,I}, &S5);
182-
create_double_matrix(pairint {I,I}, &T5);
183-
create_double_matrix(pairint {I,I}, &M1);
184-
create_double_matrix(pairint {I,I}, &M2);
185-
create_double_matrix(pairint {I,I}, &M3);
186-
create_double_matrix(pairint {I,I}, &M4);
187-
create_double_matrix(pairint {I,I}, &M5);
188-
create_double_matrix(pairint {I,I}, &M6);
189-
create_double_matrix(pairint {I,I}, &M7);
185+
create_double_matrix(pairint {II,II}, &S5);
186+
create_double_matrix(pairint {II,II}, &T5);
187+
create_double_matrix(pairint {II,II}, &M1);
188+
create_double_matrix(pairint {II,II}, &M2);
189+
create_double_matrix(pairint {II,II}, &M3);
190+
create_double_matrix(pairint {II,II}, &M4);
191+
create_double_matrix(pairint {II,II}, &M5);
192+
create_double_matrix(pairint {II,II}, &M6);
193+
create_double_matrix(pairint {II,II}, &M7);
190194

191195

192196

@@ -209,8 +213,8 @@ int matmul_double_recursive_bilinear(double_cmat matC, double_cmat matA, double_
209213
matadd_double(C11, M1, M2);
210214
matsub_double(C12, M5, M7);
211215
matadd_double(C21, M3, M6);
212-
for (int i=0; i < I; i++) {
213-
for (int j=0; j < I; j++) {
216+
for (int i=0; i < II; i++) {
217+
for (int j=0; j < II; j++) {
214218
C22.data[i][j] = M5.data[i][j] + M6.data[i][j] - M2.data[i][j] - M4.data[i][j];
215219
}
216220
}
@@ -244,22 +248,22 @@ int matmul_double_schwartz2024(double_cmat matC, double_cmat matA, double_cmat m
244248
return -1;
245249
}
246250
int N = matA.shape[0];
247-
int I = matA.shape[0]/2;
251+
int II = matA.shape[0]/2;
248252
if (N <= 4096) {
249253
return matmul_double_blas(matC, matA, matB);
250254
}
251-
double_cmat A11 = slice_double_matrix(matA, pairint {0,I}, pairint {0,I});
252-
double_cmat A12 = slice_double_matrix(matA, pairint {0,I}, pairint {I,N});
253-
double_cmat A21 = slice_double_matrix(matA, pairint {I,N}, pairint {0,I});
254-
double_cmat A22 = slice_double_matrix(matA, pairint {I,N}, pairint {I,N});
255-
double_cmat B11 = slice_double_matrix(matB, pairint {0,I}, pairint {0,I});
256-
double_cmat B12 = slice_double_matrix(matB, pairint {0,I}, pairint {I,N});
257-
double_cmat B21 = slice_double_matrix(matB, pairint {I,N}, pairint {0,I});
258-
double_cmat B22 = slice_double_matrix(matB, pairint {I,N}, pairint {I,N});
259-
double_cmat C11 = slice_double_matrix(matC, pairint {0,I}, pairint {0,I});
260-
double_cmat C12 = slice_double_matrix(matC, pairint {0,I}, pairint {I,N});
261-
double_cmat C21 = slice_double_matrix(matC, pairint {I,N}, pairint {0,I});
262-
double_cmat C22 = slice_double_matrix(matC, pairint {I,N}, pairint {I,N});
255+
double_cmat A11 = slice_double_matrix(matA, pairint {0,II}, pairint {0,II});
256+
double_cmat A12 = slice_double_matrix(matA, pairint {0,II}, pairint {II,N});
257+
double_cmat A21 = slice_double_matrix(matA, pairint {II,N}, pairint {0,II});
258+
double_cmat A22 = slice_double_matrix(matA, pairint {II,N}, pairint {II,N});
259+
double_cmat B11 = slice_double_matrix(matB, pairint {0,II}, pairint {0,II});
260+
double_cmat B12 = slice_double_matrix(matB, pairint {0,II}, pairint {II,N});
261+
double_cmat B21 = slice_double_matrix(matB, pairint {II,N}, pairint {0,II});
262+
double_cmat B22 = slice_double_matrix(matB, pairint {II,N}, pairint {II,N});
263+
double_cmat C11 = slice_double_matrix(matC, pairint {0,II}, pairint {0,II});
264+
double_cmat C12 = slice_double_matrix(matC, pairint {0,II}, pairint {II,N});
265+
double_cmat C21 = slice_double_matrix(matC, pairint {II,N}, pairint {0,II});
266+
double_cmat C22 = slice_double_matrix(matC, pairint {II,N}, pairint {II,N});
263267
// pa11 = a21 + a22
264268
// pa12 = a22
265269
// pa21 = -a11 - a12
@@ -279,8 +283,8 @@ int matmul_double_schwartz2024(double_cmat matC, double_cmat matA, double_cmat m
279283
// pb22 = b11 + b22
280284
double_cmat tmp1;
281285
double_cmat tmp2;
282-
create_double_matrix(pairint {I,I}, &tmp1);
283-
create_double_matrix(pairint {I,I}, &tmp2);
286+
create_double_matrix(pairint {II,II}, &tmp1);
287+
create_double_matrix(pairint {II,II}, &tmp2);
284288
assign_double_clone(tmp2, A11);
285289
matneg_double(tmp1, A11);
286290
matadd_double(A11, A21, A22);
@@ -313,8 +317,8 @@ int matmul_double_schwartz2024(double_cmat matC, double_cmat matA, double_cmat m
313317
assign_double_clone(tmp1, C21);
314318
assign_double_clone(C21, C11);
315319
matsub_double(C11, tmp1, C22);
316-
for (int i=0; i < I; i++) {
317-
for (int j=0; j < I; j++) {
320+
for (int i=0; i < II; i++) {
321+
for (int j=0; j < II; j++) {
318322
C22.data[i][j] = C12.data[i][j] - C21.data[i][j] - C22.data[i][j];
319323
}
320324
}

0 commit comments

Comments
 (0)