@@ -75,108 +75,6 @@ def conditional(self, name, Xnew, jitter=1e-6, **kwargs):
75
75
return pm .MvNormal (name , mu = mu , chol = chol )
76
76
77
77
78
- class HSGP (pm .gp .Latent ):
79
- ## inputs: M, c
80
-
81
- def __init__ (
82
- self , n_basis , c = 3 / 2 , * , mean_func = pm .gp .mean .Zero (), cov_func = pm .gp .cov .Constant (0.0 )
83
- ):
84
- ## TODO: specify either c or L
85
- self .M = n_basis
86
- self .c = c
87
- super ().__init__ (mean_func = mean_func , cov_func = cov_func )
88
-
89
- def _validate_cov_func (self , cov_func ):
90
- ## TODO: actually validate it. Right now this fails unless cov func is exactly
91
- # in the form eta**2 * pm.gp.cov.Matern12(...) and will error otherwise.
92
- cov , scaling_factor = cov_func .factor_list
93
- return scaling_factor , cov .ls , cov .spectral_density
94
-
95
- def prior (self , name , X , ** kwargs ):
96
- f , Phi , L , spd , beta , Xmu , Xsd = self ._build_prior (name , X , ** kwargs )
97
- self .X , self .f = X , f
98
- self .Phi , self .L , self .spd , self .beta = Phi , L , spd , beta
99
- self .Xmu , self .Xsd = Xmu , Xsd
100
- return f
101
-
102
- def _generate_basis (self , X , L ):
103
- indices = pt .arange (1 , self .M + 1 )
104
- m1 = (np .pi / (2.0 * L )) * pt .tile (L + X , self .M )
105
- m2 = pt .diag (indices )
106
- Phi = pt .sin (m1 @ m2 ) / pt .sqrt (L )
107
- omega = (np .pi * indices ) / (2.0 * L )
108
- return Phi , omega
109
-
110
- def _build_prior (self , name , X , ** kwargs ):
111
- n_obs = np .shape (X )[0 ]
112
-
113
- # standardize input scale
114
- X = pt .as_tensor_variable (X )
115
- Xmu = pt .mean (X , axis = 0 )
116
- Xsd = pt .std (X , axis = 0 )
117
- Xz = (X - Xmu ) / Xsd
118
-
119
- # define L using Xz and c
120
- La = pt .abs (pt .min (Xz )) # .eval()?
121
- Lb = pt .max (Xz )
122
- L = self .c * pt .max ([La , Lb ])
123
-
124
- # make basis and omega, spectral density
125
- Phi , omega = self ._generate_basis (Xz , L )
126
- scale , ls , spectral_density = self ._validate_cov_func (self .cov_func )
127
- spd = scale * spectral_density (omega , ls / Xsd ).flatten ()
128
-
129
- beta = pm .Normal (f"{ name } _coeffs_" , size = self .M )
130
- f = pm .Deterministic (name , self .mean_func (X ) + pt .dot (Phi * pt .sqrt (spd ), beta ))
131
- return f , Phi , L , spd , beta , Xmu , Xsd
132
-
133
- def _build_conditional (self , Xnew , Xmu , Xsd , L , beta ):
134
- Xnewz = (Xnew - Xmu ) / Xsd
135
- Phi , omega = self ._generate_basis (Xnewz , L )
136
- scale , ls , spectral_density = self ._validate_cov_func (self .cov_func )
137
- spd = scale * spectral_density (omega , ls / Xsd ).flatten ()
138
- return self .mean_func (Xnew ) + pt .dot (Phi * pt .sqrt (spd ), beta )
139
-
140
- def conditional (self , name , Xnew ):
141
- # warn about extrapolation
142
- fnew = self ._build_conditional (Xnew , self .Xmu , self .Xsd , self .L , self .beta )
143
- return pm .Deterministic (name , fnew )
144
-
145
-
146
- class ExpQuad (pm .gp .cov .ExpQuad ):
147
- @staticmethod
148
- def spectral_density (omega , ls ):
149
- # univariate spectral denisty, implement multi
150
- return pt .sqrt (2 * np .pi ) * ls * pt .exp (- 0.5 * ls ** 2 * omega ** 2 )
151
-
152
-
153
- class Matern52 (pm .gp .cov .Matern52 ):
154
- @staticmethod
155
- def spectral_density (omega , ls ):
156
- # univariate spectral denisty, implement multi
157
- # https://arxiv.org/pdf/1611.06740.pdf
158
- lam = pt .sqrt (5 ) * (1.0 / ls )
159
- return (16.0 / 3.0 ) * lam ** 5 * (1.0 / (lam ** 2 + omega ** 2 ) ** 3 )
160
-
161
-
162
- class Matern32 (pm .gp .cov .Matern32 ):
163
- @staticmethod
164
- def spectral_density (omega , ls ):
165
- # univariate spectral denisty, implement multi
166
- # https://arxiv.org/pdf/1611.06740.pdf
167
- lam = np .sqrt (3.0 ) * (1.0 / ls )
168
- return 4.0 * lam ** 3 * (1.0 / pt .square (lam ** 2 + omega ** 2 ))
169
-
170
-
171
- class Matern12 (pm .gp .cov .Matern12 ):
172
- @staticmethod
173
- def spectral_density (omega , ls ):
174
- # univariate spectral denisty, implement multi
175
- # https://arxiv.org/pdf/1611.06740.pdf
176
- lam = 1.0 / ls
177
- return 2.0 * lam * (1.0 / (lam ** 2 + omega ** 2 ))
178
-
179
-
180
78
class KarhunenLoeveExpansion (pm .gp .Latent ):
181
79
def __init__ (
182
80
self ,
0 commit comments