Skip to content

Commit 715526f

Browse files
committed
Improvements to ThreadSafeVarInfo (#425)
The `ThreadSafeVarInfo` still has quite a few implementation details inherited from compat with `VarInfo`. This PR makes `ThreadSafeVarInfo` work better with other implementations, e.g. `SimpleVarInfo`. It also fixes TuringLang/Turing.jl#1878 (comment)
1 parent e31a790 commit 715526f

File tree

5 files changed

+128
-13
lines changed

5 files changed

+128
-13
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.20.1"
3+
version = "0.20.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/compiler.jl

+2-5
Original file line numberDiff line numberDiff line change
@@ -684,17 +684,14 @@ For example, if `T === Float64` and `spl::Hamiltonian`, the matching type is
684684
"""
685685
get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} = T
686686
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Union{Missing,AbstractFloat}})
687-
return Union{Missing,floatof(eltype(vi, spl))}
687+
return Union{Missing,float_type_with_fallback(eltype(vi, spl))}
688688
end
689689
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:AbstractFloat})
690-
return floatof(eltype(vi, spl))
690+
return float_type_with_fallback(eltype(vi, spl))
691691
end
692692
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T,N}}) where {T,N}
693693
return Array{get_matching_type(spl, vi, T),N}
694694
end
695695
function get_matching_type(spl::AbstractSampler, vi, ::Type{<:Array{T}}) where {T}
696696
return Array{get_matching_type(spl, vi, T)}
697697
end
698-
699-
floatof(::Type{T}) where {T<:Real} = typeof(one(T) / one(T))
700-
floatof(::Type) = Real # fallback if type inference failed

src/simple_varinfo.jl

+32-4
Original file line numberDiff line numberDiff line change
@@ -202,18 +202,37 @@ struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo
202202
transformation::C
203203
end
204204

205-
SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation())
205+
# Makes things a bit more readable vs. putting `Float64` everywhere.
206+
const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64
206207

208+
function SimpleVarInfo{NT,T}(values, logp) where {NT,T}
209+
return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation())
210+
end
207211
function SimpleVarInfo{T}(θ) where {T<:Real}
208-
return SimpleVarInfo(θ, zero(T))
212+
return SimpleVarInfo{typeof(θ),T}(θ, zero(T))
213+
end
214+
215+
# Constructors without type-specification.
216+
SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ)
217+
function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict})
218+
return if isempty(θ)
219+
# Can't infer from values, so we just use default.
220+
SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ)
221+
else
222+
# Infer from `values`.
223+
SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ)
224+
end
209225
end
226+
227+
SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp)
228+
229+
# Using `kwargs` to specify the values.
210230
function SimpleVarInfo{T}(; kwargs...) where {T<:Real}
211231
return SimpleVarInfo{T}(NamedTuple(kwargs))
212232
end
213233
function SimpleVarInfo(; kwargs...)
214-
return SimpleVarInfo{Float64}(NamedTuple(kwargs))
234+
return SimpleVarInfo(NamedTuple(kwargs))
215235
end
216-
SimpleVarInfo(θ) = SimpleVarInfo{Float64}(θ)
217236

218237
# Constructor from `Model`.
219238
SimpleVarInfo(model::Model, args...) = SimpleVarInfo{Float64}(model, args...)
@@ -582,3 +601,12 @@ julia> # Truth.
582601
```
583602
"""
584603
Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ))
604+
605+
# Threadsafe stuff.
606+
# For `SimpleVarInfo` we don't really need `Ref` so let's not use it.
607+
function ThreadSafeVarInfo(vi::SimpleVarInfo)
608+
return ThreadSafeVarInfo(vi, zeros(typeof(getlogp(vi)), Threads.nthreads()))
609+
end
610+
function ThreadSafeVarInfo(vi::SimpleVarInfo{<:Any,<:Ref})
611+
return ThreadSafeVarInfo(vi, [Ref(zero(getlogp(vi))) for _ in 1:Threads.nthreads()])
612+
end

src/threadsafe.jl

+17-3
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,40 @@ function ThreadSafeVarInfo(vi::AbstractVarInfo)
1313
end
1414
ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi
1515

16+
const ThreadSafeVarInfoWithRef{V<:AbstractVarInfo} = ThreadSafeVarInfo{
17+
V,<:AbstractArray{<:Ref}
18+
}
19+
1620
# Instead of updating the log probability of the underlying variables we
1721
# just update the array of log probabilities.
1822
function acclogp!!(vi::ThreadSafeVarInfo, logp)
23+
vi.logps[Threads.threadid()] += logp
24+
return vi
25+
end
26+
function acclogp!!(vi::ThreadSafeVarInfoWithRef, logp)
1927
vi.logps[Threads.threadid()][] += logp
2028
return vi
2129
end
2230

2331
# The current log probability of the variables has to be computed from
2432
# both the wrapped variables and the thread-specific log probabilities.
25-
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps)
33+
getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(vi.logps)
34+
getlogp(vi::ThreadSafeVarInfoWithRef) = getlogp(vi.varinfo) + sum(getindex, vi.logps)
2635

2736
# TODO: Make remaining methods thread-safe.
28-
2937
function resetlogp!!(vi::ThreadSafeVarInfo)
38+
return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), zero(vi.logps))
39+
end
40+
function resetlogp!!(vi::ThreadSafeVarInfoWithRef)
3041
for x in vi.logps
3142
x[] = zero(x[])
3243
end
3344
return ThreadSafeVarInfo(resetlogp!!(vi.varinfo), vi.logps)
3445
end
3546
function setlogp!!(vi::ThreadSafeVarInfo, logp)
47+
return ThreadSafeVarInfo(setlogp!!(vi.varinfo, logp), zero(vi.logps))
48+
end
49+
function setlogp!!(vi::ThreadSafeVarInfoWithRef, logp)
3650
for x in vi.logps
3751
x[] = zero(x[])
3852
end
@@ -104,7 +118,7 @@ end
104118

105119
isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo)
106120
function BangBang.empty!!(vi::ThreadSafeVarInfo)
107-
return resetlogp!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo)))
121+
return resetlogp!!(Setfield.@set!(vi.varinfo = empty!!(vi.varinfo)))
108122
end
109123

110124
function BangBang.push!!(

src/utils.jl

+76
Original file line numberDiff line numberDiff line change
@@ -546,3 +546,79 @@ function nested_haskey(dict::AbstractDict, vn::VarName)
546546

547547
return canview(child, value)
548548
end
549+
550+
"""
551+
float_type_with_fallback(x)
552+
553+
Return the floating-point type `float(x)` when the type `x` is a subtype of `Real`; otherwise return `Real`.
554+
"""
555+
float_type_with_fallback(::Type) = Real
556+
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)
557+
558+
"""
559+
infer_nested_eltype(x::Type)
560+
561+
Recursively unwrap the type `x`, returning the first type for which `eltype(x) === x`.
562+
563+
This is useful for obtaining a reasonable default `eltype` in deeply nested types.
564+
565+
# Examples
566+
```jldoctest
567+
julia> # `AbstractArray`
568+
DynamicPPL.infer_nested_eltype(typeof([1.0]))
569+
Float64
570+
571+
julia> # `NamedTuple` with `Float32`
572+
DynamicPPL.infer_nested_eltype(typeof((x = [1f0], )))
573+
Float32
574+
575+
julia> # `AbstractDict`
576+
DynamicPPL.infer_nested_eltype(typeof(Dict(:x => [1.0, ])))
577+
Float64
578+
579+
julia> # Nesting of containers.
580+
DynamicPPL.infer_nested_eltype(typeof([Dict(:x => 1.0,) ]))
581+
Float64
582+
583+
julia> DynamicPPL.infer_nested_eltype(typeof([Dict(:x => [1.0,],) ]))
584+
Float64
585+
586+
julia> # Empty `Tuple`.
587+
DynamicPPL.infer_nested_eltype(typeof(()))
588+
Any
589+
590+
julia> # Empty `Dict`.
591+
DynamicPPL.infer_nested_eltype(typeof(Dict()))
592+
Any
593+
```
594+
"""
595+
function infer_nested_eltype(::Type{T}) where {T}
596+
ET = eltype(T)
597+
return ET === T ? T : infer_nested_eltype(ET)
598+
end
599+
600+
# We can do a better job than just `Any` with `Union`.
601+
infer_nested_eltype(::Type{Union{}}) = Any
602+
function infer_nested_eltype(::Type{U}) where {U<:Union}
603+
return promote_type(U.a, infer_nested_eltype(U.b))
604+
end
605+
606+
# Handle `NamedTuple` and `Tuple` specially given how common they are.
607+
function infer_nested_eltype(::Type{<:NamedTuple{<:Any,V}}) where {V}
608+
return infer_nested_eltype(V)
609+
end
610+
611+
# Recursively deal with `Tuple` so it has the potential of being compiled away.
612+
infer_nested_eltype(::Type{Tuple{T}}) where {T} = infer_nested_eltype(T)
613+
function infer_nested_eltype(::Type{T}) where {T<:Tuple{<:Any,Vararg{Any}}}
614+
return promote_type(
615+
infer_nested_eltype(Base.tuple_type_tail(T)),
616+
infer_nested_eltype(Base.tuple_type_head(T)),
617+
)
618+
end
619+
620+
# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
621+
infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET)
622+
623+
# Not needed for differentiation, and it causes issues for some AD backends, e.g. Zygote.
624+
ChainRulesCore.@non_differentiable infer_nested_eltype(x)

0 commit comments

Comments
 (0)