Skip to content

Commit 24fa396

Browse files
zuhengxutorfjeldedevmotion
authored
making rqs compatible with float32 input (#267)
* making rqs compatible with float32 input * minor format edit * Update Format.yml * fix format * Update Format.yml * Update Format.yml * save additional allocations in rqs layer Co-authored-by: David Widmann <[email protected]> * rm allocations in rqs layer Co-authored-by: David Widmann <[email protected]> * bump version to 0.12.6 * add tests for rqs --------- Co-authored-by: Tor Erlend Fjelde <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent dd8a24b commit 24fa396

File tree

4 files changed

+49
-5
lines changed

4 files changed

+49
-5
lines changed

.github/workflows/Format.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ jobs:
3535
if: github.event_name == 'pull_request'
3636
with:
3737
tool_name: JuliaFormatter
38-
fail_on_error: true
38+
fail_on_error: true

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Bijectors"
22
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
3-
version = "0.12.5"
3+
version = "0.12.6"
44

55
[deps]
66
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/bijectors/rational_quadratic_spline.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ function RationalQuadraticSpline(
100100
widths::A, heights::A, derivatives::A, B::T2
101101
) where {T1,T2,A<:AbstractVector{T1}}
102102
return RationalQuadraticSpline(
103-
(cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .- 0.5) * 2 * B,
104-
(cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .- 0.5) * 2 * B,
103+
cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .* (2 * B) .- B,
104+
cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .* (2 * B) .- B,
105105
vcat([one(T1)], LogExpFunctions.log1pexp.(derivatives), [one(T1)]),
106106
)
107107
end
@@ -118,7 +118,7 @@ function RationalQuadraticSpline(
118118
)
119119

120120
return RationalQuadraticSpline(
121-
(2 * B) .* (cumsum(ws; dims=2) .- 0.5), (2 * B) .* (cumsum(hs; dims=2) .- 0.5), ds
121+
(2 * B) .* cumsum(ws; dims=2) .- B, (2 * B) .* cumsum(hs; dims=2) .- B, ds
122122
)
123123
end
124124

test/bijectors/rational_quadratic_spline.jl

+44
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Test
22
using Bijectors
33
using Bijectors: RationalQuadraticSpline
4+
using LogExpFunctions
45

56
@testset "RationalQuadraticSpline" begin
67
# Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points".
@@ -59,4 +60,47 @@ using Bijectors: RationalQuadraticSpline
5960
x = [-5.0, 5.0]
6061
test_bijector(b, x; y=x, logjac=zero(eltype(x)))
6162
end
63+
64+
@testset "Float32 support" begin
65+
ws = randn(Float32, K)
66+
hs = randn(Float32, K)
67+
ds = randn(Float32, K - 1)
68+
69+
Ws = randn(Float32, d, K)
70+
Hs = randn(Float32, d, K)
71+
Ds = randn(Float32, d, K - 1)
72+
73+
# success of construction
74+
b = RationalQuadraticSpline(ws, hs, ds, B)
75+
bb = RationalQuadraticSpline(Ws, Hs, Ds, B)
76+
end
77+
78+
@testset "consistency after commit" begin
79+
ws = randn(K)
80+
hs = randn(K)
81+
ds = randn(K - 1)
82+
83+
Ws = randn(d, K)
84+
Hs = randn(d, K)
85+
Ds = randn(d, K - 1)
86+
87+
Ws_t = hcat(zeros(size(Ws, 1)), LogExpFunctions.softmax(Ws; dims=2))
88+
Hs_t = hcat(zeros(size(Ws, 1)), LogExpFunctions.softmax(Hs; dims=2))
89+
90+
# success of construction
91+
b = RationalQuadraticSpline(ws, hs, ds, B)
92+
b_mv = RationalQuadraticSpline(Ws, Hs, Ds, B)
93+
94+
# consistency of evaluation
95+
@test all(
96+
(cumsum(vcat([zero(Float64)], LogExpFunctions.softmax(ws))) .- 0.5) * 2 * B .≈
97+
b.widths,
98+
)
99+
@test all(
100+
(cumsum(vcat([zero(Float64)], LogExpFunctions.softmax(hs))) .- 0.5) * 2 * B .≈
101+
b.heights,
102+
)
103+
@test all((2 * B) .* (cumsum(Ws_t; dims=2) .- 0.5) .≈ b_mv.widths)
104+
@test all((2 * B) .* (cumsum(Hs_t; dims=2) .- 0.5) .≈ b_mv.heights)
105+
end
62106
end

0 commit comments

Comments
 (0)