Skip to content

Commit b4992b6

Browse files
committed
Fix meanvar & convolve2NNGradient API unit tests
1 parent 5203d02 commit b4992b6

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

tests/simple/signal.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def simple_signal(verbose=False):
103103

104104
c = af.convolve2NN(a, b)
105105
display_func(c)
106-
g = af.convolve2NN(a, b, c, gradType=af.CONV_GRADIENT.DATA)
106+
in_dims = c.dims()
107+
incoming_grad = af.constant(1, in_dims[0], in_dims[1]);
108+
g = af.convolve2GradientNN(incoming_grad, a, b, c)
107109
display_func(g)
108110

109111
a = af.randu(5, 5, 3)

tests/simple/statistics.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@ def simple_statistics(verbose=False):
3434
print_func(af.var(a, isbiased=True))
3535
print_func(af.var(a, weights=w))
3636

37-
mean, var = af.mean_var(a, dim=0)
37+
mean, var = af.meanvar(a, dim=0)
3838
display_func(mean)
3939
display_func(var)
40-
mean, var = af.mean_var(a, weights=w, bias=VARIANCE.SAMPLE, dim=0)
40+
mean, var = af.meanvar(a, weights=w, bias=af.VARIANCE.SAMPLE, dim=0)
4141
display_func(mean)
4242
display_func(var)
4343

0 commit comments

Comments
 (0)