Skip to content

Commit c0315a4

Browse files
committed
update lib
1 parent 13cd8b0 commit c0315a4

File tree

3 files changed

+18
-26
lines changed

3 files changed

+18
-26
lines changed

cmat.c

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ int create_float_matrix(int shape[2], float_cmat* p_new_mat) {
5050
}
5151

5252
int create_double_matrix(int shape[2], double_cmat* p_new_mat) {
53+
//printf("create mat\n");
54+
//printf("shape %d %d\n", shape[0], shape[1]);
5355
// create a matrix with 0.0 in shape [shape[0], shape[1]]
5456
int N = shape[0] * shape[1];
5557
(*p_new_mat).data = (double **)malloc(shape[0] * sizeof(double *));
@@ -66,6 +68,7 @@ int create_double_matrix(int shape[2], double_cmat* p_new_mat) {
6668
for (int j = 0; j < N; j++) {
6769
(*p_new_mat).arena[j] = 0.0;
6870
}
71+
//printf("create mat finished\n");
6972
return 0;
7073
}
7174

@@ -205,6 +208,7 @@ float_cmat slice_float_matrix(float_cmat mat, int slice0[2], int slice1[2]) {
205208
}
206209

207210
double_cmat slice_double_matrix(double_cmat mat, int slice0[2], int slice1[2]) {
211+
//printf("slice double\n");
208212
double_cmat empty_mat;
209213
empty_mat.data = NULL;
210214
empty_mat.arena = NULL;
@@ -244,6 +248,7 @@ double_cmat slice_double_matrix(double_cmat mat, int slice0[2], int slice1[2]) {
244248
}
245249

246250
int create_slice_double_matrix_contiguous(double_cmat *dst, double_cmat mat, int slice0[2], int slice1[2]) {
251+
//printf("slice conting\n");
247252
// dst = mat[slice0, slice1]
248253
if (slice0[1] < 0) {
249254
slice0[1] += mat.shape[0];
@@ -267,6 +272,7 @@ int create_slice_double_matrix_contiguous(double_cmat *dst, double_cmat mat, int
267272
}
268273

269274
int create_double_contiguous_from_slice(double_cmat *dest, double_cmat *src) {
275+
//printf("create contig from slice\n");
270276
// dest = src.contiguous().copy()
271277
int i, j;
272278
int rows = src->shape[0];
@@ -333,6 +339,7 @@ int assign_float_slice(float_cmat m1, float_cmat m2, int slice0[2], int slice1[2
333339
}
334340

335341
int assign_double_slice(double_cmat m1, double_cmat m2, int slice0[2], int slice1[2]) {
342+
//printf("assign slice\n");
336343
// assign m2 to a slice of m1 defined by slice0(x) and slice1(y)
337344
// m1[slice0, slice1] = m2
338345
if (slice0[1] < 0) {
@@ -355,6 +362,7 @@ int assign_double_slice(double_cmat m1, double_cmat m2, int slice0[2], int slice
355362
}
356363

357364
int assign_double_clone(double_cmat m1, double_cmat m2) {
365+
//printf("assign double clone\n");
358366
// m1 = m2.copy()
359367
if (!(m1.shape[0] == m2.shape[0] && m1.shape[1] == m2.shape[1])) {
360368
return -1;
@@ -366,7 +374,8 @@ int assign_double_clone(double_cmat m1, double_cmat m2) {
366374
return 0;
367375
}
368376

369-
int matlincomb_double_contiguous(double_cmat res, int n_mats, double_cmat* mats, double* coeffs) {
377+
int matlincomb_double_contiguous(double_cmat res, int n_mats, double_cmat* mats, int8_t* coeffs) {
378+
//printf("lincomb\n");
370379
// res = coeffs[0] * mats[0] + ... + coeffs[n_mats-1] * mats[n_mats-1]
371380
// memset(&res.data[0][0], 0, sizeof(res.data[0][0])*res.shape[0]*res.shape[1]); // should not reset because it could appear in RHS
372381
if (n_mats <= 0) {
@@ -663,6 +672,7 @@ int free_float_matrix(float_cmat m) {
663672
}
664673

665674
int free_double_matrix(double_cmat m) {
675+
//printf("free\n");
666676
free(m.data);
667677
if (m.arena_shape[0] == m.shape[0] && m.arena_shape[1] == m.shape[1]) {
668678
free(m.arena);

cmat.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ struct float_cmat {
2424
};
2525

2626
struct double_cmat {
27-
double* arena; // keep all the data in one arena
28-
double** data; // two dim indexing
27+
double* restrict arena; // keep all the data in one arena
28+
double** restrict data; // two dim indexing
2929
int shape[2];
3030
int arena_shape[2]; // to record the original matrix shape for sliced matrix indexing
3131
int offset[2]; // to offset sliced matrix index
@@ -68,7 +68,7 @@ int assign_double_slice(double_cmat m1, double_cmat m2, int slice0[2], int slice
6868

6969
int assign_double_clone(double_cmat m1, double_cmat m2);
7070

71-
int matlincomb_double_contiguous(double_cmat res, int n_mats, double_cmat* mats, double* coeffs);
71+
int matlincomb_double_contiguous(double_cmat res, int n_mats, double_cmat* mats, int8_t* coeffs);
7272

7373
int matadd_int(int_cmat m1, int_cmat m2, int_cmat m3);
7474

matmul.c

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -39,38 +39,20 @@ int matmul_double(double_cmat matC, double_cmat matA, double_cmat matB){
3939
}
4040

4141
int matmul_double_blas(double_cmat C, double_cmat A_slice, double_cmat B_slice) {
42+
//printf("start matmul blas %d %d\n", A_slice.shape[0], A_slice.shape[1]);
4243
// Check dimensions for compatibility
4344
if (A_slice.shape[1] != B_slice.shape[0] || A_slice.shape[0] != C.shape[0] || B_slice.shape[1] != C.shape[1]) {
4445
fprintf(stderr, "Matrix dimensions do not match for multiplication.\n");
4546
return -1;
4647
}
47-
double_cmat CC;
48-
create_double_matrix(pairint {C.shape[0], C.shape[1]}, &CC);
49-
50-
if (is_contiguous_double(A_slice) && is_contiguous_double(B_slice)) {
48+
//memset(&C.data[0][0], 0, sizeof(C.data[0][0])*C.shape[0]*C.shape[1]);
5149
// Call dgemm
5250
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
5351
C.shape[0], C.shape[1], A_slice.shape[1],
5452
1.0, A_slice.data[0], A_slice.arena_shape[1],
5553
B_slice.data[0], B_slice.arena_shape[1],
56-
0.0, CC.arena, CC.arena_shape[1]);
57-
}
58-
else {
59-
double_cmat A,B;
60-
create_double_contiguous_from_slice(&A, &A_slice);
61-
create_double_contiguous_from_slice(&B, &B_slice);
62-
63-
// Call dgemm
64-
cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
65-
C.shape[0], C.shape[1], A.shape[1],
66-
1.0, A.arena, A.arena_shape[1],
67-
B.arena, B.arena_shape[1],
68-
0.0, CC.arena, CC.arena_shape[1]);
69-
free_double_matrix(A);
70-
free_double_matrix(B);
71-
}
72-
assign_double_clone(C, CC);
73-
free_double_matrix(CC);
54+
0.0, C.arena, C.arena_shape[1]);
55+
//printf("finish matmul blas\n");
7456
return 0;
7557
}
7658

0 commit comments

Comments
 (0)