@@ -93,31 +93,28 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
93
93
(-2.3378770664093453, [1.0])
94
94
```
95
95
"""
96
- struct LogDensityFunction{
97
- M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext ,AD<: Union{Nothing,ADTypes.AbstractADType}
98
- }
96
+ struct LogDensityFunction{M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext }
99
97
" model used for evaluation"
100
98
model:: M
101
99
" varinfo used for evaluation"
102
100
varinfo:: V
103
101
" context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
104
102
context:: C
105
- " AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
106
- adtype:: AD
107
103
" (internal use only) gradient preparation object for the model"
108
104
prep:: Union{Nothing,DI.GradientPrep}
109
105
110
106
function LogDensityFunction (
111
107
model:: Model ,
112
108
varinfo:: AbstractVarInfo = VarInfo (model),
113
- context:: AbstractContext = leafcontext (model. context);
114
- adtype:: Union{ADTypes.AbstractADType,Nothing} = model. adtype,
109
+ context:: AbstractContext = leafcontext (model. context),
115
110
)
111
+ adtype = model. adtype
116
112
if adtype === nothing
117
113
prep = nothing
118
114
else
119
115
# Make backend-specific tweaks to the adtype
120
116
adtype = tweak_adtype (adtype, model, varinfo, context)
117
+ model = Model (model, adtype)
121
118
# Check whether it is supported
122
119
is_supported (adtype) ||
123
120
@warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -138,8 +135,8 @@ struct LogDensityFunction{
138
135
)
139
136
end
140
137
end
141
- return new {typeof(model),typeof(varinfo),typeof(context),typeof(adtype) } (
142
- model, varinfo, context, adtype, prep
138
+ return new {typeof(model),typeof(varinfo),typeof(context)} (
139
+ model, varinfo, context, prep
143
140
)
144
141
end
145
142
end
@@ -157,10 +154,10 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g
157
154
"""
    LogDensityFunction(
        f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
    )

Create a new `LogDensityFunction` using the model, varinfo, and context from `f`,
but with the AD type replaced by `adtype`. If `adtype` is already the AD type
stored in `f.model`, `f` itself is returned unchanged, so that the gradient
preparation object does not have to be recomputed.
"""
function LogDensityFunction(
    f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
)
    # Same AD type as before: hand back `f` untouched to avoid rebuilding `prep`.
    adtype === f.model.adtype && return f
    # Otherwise, attach the new AD type to the model and rebuild from scratch.
    return LogDensityFunction(Model(f.model, adtype), f.varinfo, f.context)
end
166
163
@@ -187,35 +184,45 @@ end
187
184
# ## LogDensityProblems interface
188
185
189
186
# Zeroth-order capabilities: when the model's AD type parameter is `Nothing`,
# no gradient can be calculated, so only plain log-density evaluation is
# advertised to the LogDensityProblems interface.
function LogDensityProblems.capabilities(
    ::Type{
        <:LogDensityFunction{
            Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C
        },
    },
) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C}
    return LogDensityProblems.LogDensityOrder{0}()
end
194
195
# First-order capabilities: a concrete `ADTypes.AbstractADType` stored in the
# model means gradients are available via `logdensity_and_gradient`.
function LogDensityProblems.capabilities(
    ::Type{
        <:LogDensityFunction{
            Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C
        },
    },
) where {
    F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<:ADTypes.AbstractADType
}
    return LogDensityProblems.LogDensityOrder{1}()
end
199
206
# Plain log-density evaluation: forwards directly to the shared helper,
# using the model, varinfo, and context carried by `f`.
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
    return logdensity_at(x, f.model, f.varinfo, f.context)
end
202
209
function LogDensityProblems. logdensity_and_gradient (
203
- f:: LogDensityFunction{M,V,C,AD } , x:: AbstractVector
204
- ) where {M,V,C,AD <: ADTypes.AbstractADType }
210
+ f:: LogDensityFunction{M,V,C} , x:: AbstractVector
211
+ ) where {M,V,C}
205
212
f. prep === nothing &&
206
213
error (" Gradient preparation not available; this should not happen" )
207
214
x = map (identity, x) # Concretise type
208
215
# Make branching statically inferrable, i.e. type-stable (even if the two
209
216
# branches happen to return different types)
210
- return if use_closure (f. adtype)
217
+ return if use_closure (f. model . adtype)
211
218
DI. value_and_gradient (
212
- x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. adtype, x
219
+ x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. model . adtype, x
213
220
)
214
221
else
215
222
DI. value_and_gradient (
216
223
logdensity_at,
217
224
f. prep,
218
- f. adtype,
225
+ f. model . adtype,
219
226
x,
220
227
DI. Constant (f. model),
221
228
DI. Constant (f. varinfo),
@@ -292,7 +299,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292
299
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
293
300
"""
294
301
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
    # The AD type travels with the model itself (`model.adtype`), so no keyword
    # argument is needed; the constructor picks it up from `model` directly.
    return LogDensityFunction(model, f.varinfo, f.context)
end
297
304
298
305
"""
0 commit comments