Skip to content

Commit 6519d41

Browse files
committed
Workaround failure in bench_cg. Evaluate x array early
1 parent 988f432 commit 6519d41

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

Diff for: examples/benchmarks/bench_cg.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def calc_arrayfire(A, b, x0, maxiter=10):
8080
beta_num = af.dot(r, r)
8181
beta = beta_num/alpha_num
8282
p = r + af.tile(beta, p.dims()[0]) * p
83+
af.eval(x)
8384
res = x0 - x
8485
return x, af.dot(res, res)
8586

@@ -137,11 +138,11 @@ def timeit(calc, iters, args):
137138

138139
def test():
139140
print("\nTesting benchmark functions...")
140-
A, b, x0 = setup_input(50) # dense A
141+
A, b, x0 = setup_input(50, 7) # dense A
141142
Asp = to_sparse(A)
142143
x1, _ = calc_arrayfire(A, b, x0)
143144
x2, _ = calc_arrayfire(Asp, b, x0)
144-
if af.sum(af.abs(x1 - x2)/x2 > 1e-6):
145+
if af.sum(af.abs(x1 - x2)/x2 > 1e-5):
145146
raise ValueError("arrayfire test failed")
146147
if np:
147148
An = to_numpy(A)
@@ -162,11 +163,13 @@ def test():
162163

163164

164165
def bench(n=4*1024, sparsity=7, maxiter=10, iters=10):
166+
165167
# generate data
166168
print("\nGenerating benchmark data for n = %i ..." %n)
167169
A, b, x0 = setup_input(n, sparsity) # dense A
168170
Asp = to_sparse(A) # sparse A
169171
input_info(A, Asp)
172+
170173
# make benchmarks
171174
print("Benchmarking CG solver for n = %i ..." %n)
172175
t1 = timeit(calc_arrayfire, iters, args=(A, b, x0, maxiter))
@@ -192,9 +195,8 @@ def bench(n=4*1024, sparsity=7, maxiter=10, iters=10):
192195
if (len(sys.argv) > 1):
193196
af.set_device(int(sys.argv[1]))
194197

195-
af.info()
196-
198+
af.info()
197199
test()
198-
200+
199201
for n in (128, 256, 512, 1024, 2048, 4096):
200202
bench(n)

0 commit comments

Comments
 (0)