@@ -257,7 +257,9 @@ def _infmean(input_array):
257
257
)
258
258
)
259
259
else :
260
- if n < 10 :
260
+ if n == 0 :
261
+ logger .info (f"Initialization only" )
262
+ elif n < 10 :
261
263
logger .info (f"Finished [100%]: Loss = { scores [- 1 ]:,.5g} " )
262
264
else :
263
265
avg_loss = _infmean (scores [max (0 , i - 1000 ) : i + 1 ])
@@ -466,7 +468,7 @@ class FullRankADVI(KLqp):
466
468
random_seed: None or int
467
469
leave None to use package global RandomStream or other
468
470
valid value to create instance specific one
469
- start: `Point `
471
+ start: `dict[str, np.ndarray]` or `StartDict `
470
472
starting point for inference
471
473
472
474
References
@@ -534,13 +536,11 @@ class SVGD(ImplicitGradient):
534
536
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
535
537
temperature: float
536
538
parameter responsible for exploration, higher temperature gives more broad posterior estimate
537
- start: `dict`
539
+ start: `dict[str, np.ndarray]` or `StartDict `
538
540
initial point for inference
539
541
random_seed: None or int
540
542
leave None to use package global RandomStream or other
541
543
valid value to create instance specific one
542
- start: `Point`
543
- starting point for inference
544
544
kwargs: other keyword arguments passed to estimator
545
545
546
546
References
@@ -631,7 +631,11 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
631
631
"is often **underestimated** when using temperature = 1."
632
632
)
633
633
if approx is None :
634
- approx = FullRank (model = kwargs .pop ("model" , None ))
634
+ approx = FullRank (
635
+ model = kwargs .pop ("model" , None ),
636
+ random_seed = kwargs .pop ("random_seed" , None ),
637
+ start = kwargs .pop ("start" , None ),
638
+ )
635
639
super ().__init__ (estimator = estimator , approx = approx , kernel = kernel , ** kwargs )
636
640
637
641
def fit (
0 commit comments