Skip to content

Commit 3b3ad42

Browse files
committed
tests : add SVD experiments
1 parent a6acb33 commit 3b3ad42

File tree

2 files changed

+230
-0
lines changed

2 files changed

+230
-0
lines changed

tests/CMakeLists.txt

+12
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,15 @@ set(TEST_TARGET test3)
104104
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
105105
target_link_libraries(${TEST_TARGET} PRIVATE ggml)
106106
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
107+
108+
#
109+
# test-svd0 (arm)
110+
111+
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" AND NOT GGML_NO_ACCELERATE)
112+
set(TEST_TARGET test-svd0)
113+
add_executable(${TEST_TARGET} ${TEST_TARGET}.c)
114+
target_link_libraries(${TEST_TARGET} PRIVATE ggml ${GGML_EXTRA_LIBS})
115+
target_compile_options(${TEST_TARGET} PRIVATE ${GGML_EXTRA_FLAGS})
116+
add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}>)
117+
endif()
118+

tests/test-svd0.c

+218
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// SVD dimensionality reduction
2+
3+
#include <float.h>
4+
#include <stdint.h>
5+
#include <stdio.h>
6+
#include <assert.h>
7+
#include <stdlib.h>
8+
#include <string.h>
9+
#include <time.h>
10+
#include <math.h>
11+
12+
#include <sys/time.h>
13+
14+
#ifdef GGML_USE_ACCELERATE
15+
#include <Accelerate/Accelerate.h>
16+
#endif
17+
18+
float frand() {
19+
return (float) rand() / (float) RAND_MAX;
20+
}
21+
22+
//int sgesvd_(char *__jobu, char *__jobvt, __CLPK_integer *__m,
23+
// __CLPK_integer *__n, __CLPK_real *__a, __CLPK_integer *__lda,
24+
// __CLPK_real *__s, __CLPK_real *__u, __CLPK_integer *__ldu,
25+
// __CLPK_real *__vt, __CLPK_integer *__ldvt, __CLPK_real *__work,
26+
// __CLPK_integer *__lwork,
27+
// __CLPK_integer *__info)
28+
29+
int main(int argc, const char ** argv) {
30+
int m = 10;
31+
int n = 5;
32+
33+
float * A = (float *) malloc(n * m * sizeof(float));
34+
float * A0 = (float *) malloc(n * m * sizeof(float));
35+
36+
for (int i = 0; i < n; ++i) {
37+
for (int j = 0; j < m; ++j) {
38+
A[i * m + j] = (float) (10.0f*(i + 1) + 1.0f * frand());
39+
//A[i * m + j] = (float) (10.0f*(i%2 + 1) + 0.1f * frand());
40+
//if (i == 2) {
41+
// A[i * m + j] += 20*frand();
42+
//}
43+
if ((i == 1 || i == 3) && j > m/2) {
44+
A[i * m + j] = -A[i * m + j];
45+
}
46+
}
47+
}
48+
49+
// average vector
50+
//float * M = (float *) malloc(m * sizeof(float));
51+
52+
//{
53+
// for (int j = 0; j < m; ++j) {
54+
// M[j] = 0.0f;
55+
// }
56+
// for (int i = 0; i < n; ++i) {
57+
// for (int j = 0; j < m; ++j) {
58+
// M[j] += A[i * m + j];
59+
// }
60+
// }
61+
// for (int j = 0; j < m; ++j) {
62+
// M[j] /= (float) n;
63+
// }
64+
//}
65+
66+
//// subtract average vector
67+
//for (int i = 0; i < n; ++i) {
68+
// for (int j = 0; j < m; ++j) {
69+
// A[i * m + j] -= M[j];
70+
// }
71+
//}
72+
73+
memcpy(A0, A, n * m * sizeof(float));
74+
75+
// print A
76+
printf("A:\n");
77+
for (int i = 0; i < n; ++i) {
78+
printf("col %d : ", i);
79+
for (int j = 0; j < m; ++j) {
80+
printf("%9.5f ", A[i * m + j]);
81+
}
82+
printf("\n");
83+
}
84+
printf("\n");
85+
86+
// SVD
87+
// A = U * S * V^T
88+
89+
float * U = (float *) malloc(n * m * sizeof(float));
90+
float * S = (float *) malloc(n * sizeof(float));
91+
float * V = (float *) malloc(n * n * sizeof(float));
92+
93+
int lda = m;
94+
int ldu = m;
95+
int ldvt = n;
96+
97+
float work_size;
98+
int lwork = -1;
99+
int info = 0;
100+
101+
sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, &work_size, &lwork, &info);
102+
103+
lwork = (int) work_size;
104+
105+
printf("work_size = %f, info = %d, lwork = %d\n", work_size, info, lwork);
106+
107+
float * work = (float *) malloc(lwork * sizeof(float));
108+
109+
sgesvd_("S", "S", &m, &n, A, &lda, S, U, &ldu, V, &ldvt, work, &lwork, &info);
110+
111+
// print U
112+
printf("U:\n");
113+
for (int i = 0; i < n; ++i) {
114+
printf("col %d : ", i);
115+
for (int j = 0; j < m; ++j) {
116+
printf("%9.5f ", U[i * m + j]);
117+
}
118+
printf("\n");
119+
}
120+
printf("\n");
121+
122+
// normalize S
123+
{
124+
double sum = 0.0;
125+
for (int i = 0; i < n; ++i) {
126+
sum += S[i];
127+
}
128+
sum *= sqrt((double) m);
129+
for (int i = 0; i < n; ++i) {
130+
S[i] /= sum;
131+
}
132+
}
133+
134+
// print S
135+
printf("S:\n");
136+
for (int i = 0; i < n; ++i) {
137+
printf("- %d = %9.5f\n", i, S[i]);
138+
}
139+
printf("\n");
140+
141+
// print V
142+
printf("V:\n");
143+
for (int i = 0; i < n; ++i) {
144+
printf("col %d : ", i);
145+
for (int j = 0; j < n; ++j) {
146+
printf("%9.5f ", V[i * n + j]);
147+
}
148+
printf("\n");
149+
}
150+
printf("\n");
151+
152+
// print A
153+
printf("A:\n");
154+
for (int i = 0; i < n; ++i) {
155+
printf("col %d : ", i);
156+
for (int j = 0; j < m; ++j) {
157+
printf("%9.5f ", A[i * m + j]);
158+
}
159+
printf("\n");
160+
}
161+
printf("\n");
162+
163+
// compute singular vectors in U
164+
for (int i = 0; i < n; ++i) {
165+
for (int j = 0; j < m; ++j) {
166+
U[i * m + j] *= S[i];
167+
}
168+
}
169+
170+
// normalize U
171+
for (int i = 0; i < n; ++i) {
172+
double sum = 0.0;
173+
for (int j = 0; j < m; ++j) {
174+
sum += U[i * m + j] * U[i * m + j];
175+
}
176+
sum = sqrt(sum);
177+
for (int j = 0; j < m; ++j) {
178+
U[i * m + j] /= sum*sqrt((double) m);
179+
}
180+
}
181+
182+
// print U
183+
printf("U:\n");
184+
for (int i = 0; i < n; ++i) {
185+
printf("col %d : ", i);
186+
for (int j = 0; j < m; ++j) {
187+
printf("%9.5f ", U[i * m + j]);
188+
}
189+
printf("\n");
190+
}
191+
printf("\n");
192+
193+
194+
// project A0 onto U
195+
float * A1 = (float *) malloc(n * n * sizeof(float));
196+
197+
for (int i = 0; i < n; ++i) {
198+
for (int j = 0; j < n; ++j) {
199+
A1[i * n + j] = 0.0f;
200+
for (int k = 0; k < m; ++k) {
201+
A1[i * n + j] += A0[i * m + k] * U[j * m + k];
202+
}
203+
}
204+
}
205+
206+
// print A1
207+
printf("A1:\n");
208+
for (int i = 0; i < n; ++i) {
209+
printf("col %d : ", i);
210+
for (int j = 0; j < n; ++j) {
211+
printf("%9.5f ", A1[i * n + j]);
212+
}
213+
printf("\n");
214+
}
215+
printf("\n");
216+
217+
return 0;
218+
}

0 commit comments

Comments
 (0)