4
4
5
5
import numpy as np
6
6
7
+ from adaptive .learner .average_mixin import DataPoint
7
8
from adaptive .learner .base_learner import BaseLearner
8
9
from adaptive .notebook_integration import ensure_holoviews
9
10
from adaptive .utils import cache_latest
@@ -40,14 +41,15 @@ def __init__(self, function, atol=None, rtol=None):
40
41
if rtol is None :
41
42
rtol = np .inf
42
43
43
- self .data = {}
44
+ self .data = DataPoint ()
44
45
self .pending_points = set ()
45
46
self .function = function
46
47
self .atol = atol
47
48
self .rtol = rtol
48
- self .npoints = 0
49
- self .sum_f = 0
50
- self .sum_f_sq = 0
49
+
50
+ @property
51
+ def npoints (self ):
52
+ return self .data .n
51
53
52
54
@property
53
55
def n_requested (self ):
@@ -72,35 +74,23 @@ def ask(self, n, tell_pending=True):
72
74
73
75
def tell (self , n , value ):
74
76
if n in self .data :
75
- # The point has already been added before.
76
77
return
77
-
78
78
self .data [n ] = value
79
79
self .pending_points .discard (n )
80
- self .sum_f += value
81
- self .sum_f_sq += value ** 2
82
- self .npoints += 1
83
80
84
81
def tell_pending (self , n ):
85
82
self .pending_points .add (n )
86
83
87
84
@property
88
85
def mean (self ):
89
86
"""The average of all values in `data`."""
90
- return self .sum_f / self . npoints
87
+ return self .data . mean
91
88
92
89
@property
93
90
def std (self ):
94
91
"""The corrected sample standard deviation of the values
95
92
in `data`."""
96
- n = self .npoints
97
- if n < 2 :
98
- return np .inf
99
- numerator = self .sum_f_sq - n * self .mean ** 2
100
- if numerator < 0 :
101
- # in this case the numerator ~ -1e-15
102
- return 0
103
- return sqrt (numerator / (n - 1 ))
93
+ return self .data .std
104
94
105
95
@cache_latest
106
96
def loss (self , real = True , * , n = None ):
@@ -110,10 +100,8 @@ def loss(self, real=True, *, n=None):
110
100
n = n
111
101
if n < 2 :
112
102
return np .inf
113
- standard_error = self .std / sqrt (n )
114
- return max (
115
- standard_error / self .atol , standard_error / abs (self .mean ) / self .rtol
116
- )
103
+ sem = self .data .standard_error
104
+ return max (sem / self .atol , sem / abs (self .mean ) / self .rtol )
117
105
118
106
def _loss_improvement (self , n ):
119
107
loss = self .loss ()
@@ -142,7 +130,7 @@ def plot(self):
142
130
return hv .operation .histogram (vals , num_bins = num_bins , dimension = 1 )
143
131
144
132
def _get_data (self ):
145
- return (self .data , self . npoints , self . sum_f , self . sum_f_sq )
133
+ return dict (self .data )
146
134
147
135
def _set_data (self , data ):
148
- self .data , self . npoints , self . sum_f , self . sum_f_sq = data
136
+ self .data = DataPoint ( data )
0 commit comments