Skip to content

Commit de01fb3

Browse files
Merge pull request #72 from dharux/sartsa-fix
Fix issues with SARTSATraces
2 parents b23dd43 + 29a6a3e commit de01fb3

9 files changed

+164
-36
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReinforcementLearningTrajectories"
22
uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c"
3-
version = "0.4"
3+
version = "0.4.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/common/CircularArraySARTSATraces.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ function CircularArraySARTSATraces(;
2424
reward_eltype, reward_size = reward
2525
terminal_eltype, terminal_size = terminal
2626

27-
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
27+
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) +
2828
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
2929
Traces(
30-
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
31-
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
30+
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1),
31+
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1),
3232
)
3333
end
3434

src/common/CircularPrioritizedTraces.jl

+22-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ end
1212
function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
1313
new_names = (:key, :priority, names...)
1414
new_Ts = Tuple{Int,Float32,Ts.parameters...}
15-
c = capacity(traces)
15+
if traces isa CircularArraySARTSATraces
16+
c = capacity(traces) - 1
17+
else
18+
c = capacity(traces)
19+
end
1620
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
1721
CircularVectorBuffer{Int}(c),
1822
SumTree(c),
@@ -34,6 +38,22 @@ function Base.push!(t::CircularPrioritizedTraces, x)
3438
end
3539
end
3640

41+
# Push `x` into prioritized SARTSA traces and keep `t.keys` / `t.priorities`
# in step with the number of stored transitions.
#
# A new key (with `t.default_priority`) is registered only when the push
# produced a transition:
#   * the very first transition (length goes from 0 to 1) receives key 1;
#   * any later push that increases the length, or that happens while the
#     trace was already full, receives the previous key plus one.
# Every other push is a partial insert and registers nothing.
function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x)
    initial_length = length(t.traces)
    push!(t.traces, x)
    new_length = length(t.traces)

    grew = new_length > initial_length
    # `capacity(t.traces) - 1` is the same bound the constructor uses to size
    # the key buffer for SARTSA traces, i.e. the length of a full trace.
    was_full = initial_length == capacity(t.traces) - 1

    if initial_length == 0 && new_length == 1
        # First complete transition. Checking `initial_length == 0` (rather
        # than only `new_length == 1`) prevents a partial insert that leaves
        # the length at 1 from registering key 1 a second time.
        push!(t.keys, 1)
        push!(t.priorities, t.default_priority)
    elseif new_length >= 1 && (grew || was_full)
        # only add a key if the length changed after inserting the tuple
        # or if the trace was already at capacity
        push!(t.keys, t.keys[end] + 1)
        push!(t.priorities, t.default_priority)
    else
        # partial insert that did not complete a transition; ignore it
    end
end
56+
3757
function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
3858
if k === :priority
3959
@assert length(vs) == length(keys)
@@ -48,6 +68,7 @@ function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys)
4868
end
4969

5070
# The size of the prioritized wrapper is the size of the traces it wraps.
function Base.size(t::CircularPrioritizedTraces)
    inner = t.traces
    return size(inner)
end

# `max_length` is likewise forwarded to the wrapped traces.
function max_length(t::CircularPrioritizedTraces)
    inner = t.traces
    return max_length(inner)
end
5172

5273
function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
5374
if s === :priority

src/episodes.jl

+30-4
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces)
138138

139139
# For an EpisodesBuffer backed by prioritized traces, unwrap twice so that
# `fill_multiplex` receives the traces held inside the CircularPrioritizedTraces.
function fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces})
    prioritized = eb.traces
    return fill_multiplex(prioritized.traces)
end

# `max_length` of an EpisodesBuffer is that of the traces it stores.
function max_length(eb::EpisodesBuffer)
    return max_length(eb.traces)
end
142+
141143
function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
142144
push!(eb.traces, xs)
143145
partial = ispartial_insert(eb, xs)
@@ -146,10 +148,12 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple)
146148
push!(eb.episodes_lengths, 0)
147149
push!(eb.sampleable_inds, 0)
148150
elseif !partial #typical inserting
149-
if length(eb.traces) < length(eb) && length(eb) > 2 #case when PartialNamedTuple is used. Steps are indexable one step later
150-
eb.sampleable_inds[end-1] = 1
151-
else #case when we don't, length of traces and eb will match.
152-
eb.sampleable_inds[end] = 1 #previous step is now indexable
151+
if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched
152+
if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode
153+
eb.sampleable_inds[end-1] = 1 # steps are indexable one step later
154+
end
155+
else
156+
eb.sampleable_inds[end] = 1 # otherwise, previous step is now indexable
153157
end
154158
push!(eb.sampleable_inds, 0) #this one is no longer
155159
ep_length = last(eb.step_numbers)
@@ -172,6 +176,28 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl
172176
eb.sampleable_inds[end-1] = 1 #completes the episode trajectory.
173177
end
174178

179+
# Push the content of a PartialNamedTuple into an EpisodesBuffer backed by
# CircularArraySARTSATraces; this completes the episode trajectory.
function Base.push!(
    eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces},
    xs::PartialNamedTuple,
)
    # Drop the oldest entry of the buffer first when its longest trace has
    # reached the capacity of the traces.
    max_length(eb) == capacity(eb.traces) && popfirst!(eb)
    push!(eb.traces, xs.namedtuple)
    # The second-to-last index becomes sampleable.
    eb.sampleable_inds[end - 1] = 1
end
186+
187+
# Push a PartialNamedTuple into an EpisodesBuffer whose traces are
# CircularPrioritizedTraces wrapping CircularArraySARTSATraces.
# This method only matches a PartialNamedTuple whose wrapped tuple is exactly
# `(action::Int64,)`.
function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}})
188+
# Branch taken when the longest trace has reached the capacity of the traces.
if max_length(eb) == capacity(eb.traces)
189+
# Zero-valued placeholders for the fields the partial tuple does not carry.
addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal])
190+
# `xs` is rebound here from the PartialNamedTuple to the merged NamedTuple.
xs = merge(xs.namedtuple, addition)
191+
# Push the action together with the placeholders ...
push!(eb.traces, xs)
192+
# ... then remove the three placeholder entries again.
# NOTE(review): presumably the padded push exists so that the prioritized
# push! registers a key when the trace is full - confirm.
pop!(eb.traces[:state].trace)
193+
pop!(eb.traces[:reward])
194+
pop!(eb.traces[:terminal])
# NOTE(review): unlike the else branch, this branch does not set
# eb.sampleable_inds[end-1]; confirm that this is intended.
195+
else
196+
# Not at capacity: push the action alone ...
push!(eb.traces, xs.namedtuple)
197+
# ... and mark the second-to-last index as sampleable.
eb.sampleable_inds[end-1] = 1
198+
end
199+
end
200+
175201
for f in (:pop!, :popfirst!)
176202
@eval function Base.$f(eb::EpisodesBuffer)
177203
$f(eb.episodes_lengths)

src/samplers.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7474
t = e.traces
7575
p = collect(deepcopy(t.priorities))
7676
w = StatsBase.FrequencyWeights(p)
77-
w .*= e.sampleable_inds[1:end-1]
77+
w .*= e.sampleable_inds[1:length(t)]
7878
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
7979
NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...))
8080
end
@@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
247247
p = collect(deepcopy(t.priorities))
248248
w = StatsBase.FrequencyWeights(p)
249249
valids, ns = valid_range(s,e)
250-
w .*= valids[1:end-1]
250+
w .*= valids[1:length(t)]
251251
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
252252
merge(
253253
(key=t.keys[inds], priority=p[inds]),
@@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <
362362
p = collect(deepcopy(t.priorities))
363363
w = StatsBase.FrequencyWeights(p)
364364
valids, ns = valid_range(s,e)
365-
w .*= valids[1:end-1]
365+
w .*= valids[1:length(t)]
366366
inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize)
367367
merge(
368368
(key=t.keys[inds], priority=p[inds]),

src/traces.jl

+1
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T
247247
end
248248

249249
# The size of a `Traces` is a 1-tuple holding the length of its shortest trace.
function Base.size(t::Traces)
    shortest = mapreduce(length, min, t.traces)
    return (shortest,)
end

# `max_length` is the length of the longest trace.
function max_length(t::Traces)
    return mapreduce(length, max, t.traces)
end
250251

251252
function capacity(t::Traces{names,Trs,N,E}) where {names,Trs,N,E}
252253
minimum(map(idx->capacity(t[idx]), names))

test/common.jl

+70-10
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
@test length(t) == 0
2525
end
2626

27-
@testset "CircularArraySARTSTraces" begin
27+
@testset "CircularArraySARTSATraces" begin
2828
t = CircularArraySARTSATraces(;
2929
capacity=3,
3030
state=Float32 => (2, 3),
@@ -35,13 +35,14 @@ end
3535

3636
@test t isa CircularArraySARTSATraces
3737

38-
push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu)
38+
push!(t, (state=ones(Float32, 2, 3),))
39+
push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu)
3940
@test length(t) == 0
4041

4142
push!(t, (reward=1.0f0, terminal=false) |> gpu)
42-
@test length(t) == 0 # next_state and next_action is still missing
43+
@test length(t) == 0 # next_action is still missing
4344

44-
push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu)
45+
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu)
4546
@test length(t) == 1
4647

4748
# this will trigger the scalar indexing of CuArray
@@ -55,17 +56,18 @@ end
5556
)
5657

5758
push!(t, (reward=2.0f0, terminal=false))
58-
push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu)
59+
push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu)
5960

6061
@test length(t) == 2
6162

6263
push!(t, (reward=3.0f0, terminal=false))
63-
push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu)
64+
push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu)
6465

6566
@test length(t) == 3
6667

6768
push!(t, (reward=4.0f0, terminal=false))
68-
push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu)
69+
push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu)
70+
push!(t, (reward=5.0f0, terminal=false))
6971

7072
@test length(t) == 3
7173

@@ -127,9 +129,9 @@ end
127129
@test t isa CircularArraySLARTTraces
128130
end
129131

130-
@testset "CircularPrioritizedTraces" begin
132+
@testset "CircularPrioritizedTraces-SARTS" begin
131133
t = CircularPrioritizedTraces(
132-
CircularArraySARTSATraces(;
134+
CircularArraySARTSTraces(;
133135
capacity=3
134136
),
135137
default_priority=1.0f0
@@ -160,7 +162,7 @@ end
160162

161163
#EpisodesBuffer
162164
t = CircularPrioritizedTraces(
163-
CircularArraySARTSATraces(;
165+
CircularArraySARTSTraces(;
164166
capacity=10
165167
),
166168
default_priority=1.0f0
@@ -186,3 +188,61 @@ end
186188
eb[:priority, [1, 2]] = [0, 0]
187189
@test eb[:priority] == [zeros(2);ones(8)]
188190
end
191+
192+
# Exercises CircularPrioritizedTraces wrapping CircularArraySARTSATraces:
# first used directly, then through an EpisodesBuffer.
@testset "CircularPrioritizedTraces-SARTSA" begin
193+
t = CircularPrioritizedTraces(
194+
CircularArraySARTSATraces(;
195+
capacity=3
196+
),
197+
default_priority=1.0f0
198+
)
199+
200+
# initial state and action only; reward and terminal arrive with the pushes below
push!(t, (state=0, action=0))
201+
202+
# five further pushes into a trace created with capacity=3
for i in 1:5
203+
push!(t, (reward=1.0f0, terminal=false, state=i, action=i))
204+
end
205+
206+
# the length is expected to equal the requested capacity
@test length(t) == 3
207+
208+
s = BatchSampler(5)
209+
210+
b = sample(s, t)
211+
212+
t[:priority, [1, 2]] = [0, 0]
213+
214+
# priorities should be unchanged since [1, 2] are old keys
215+
@test t[:priority] == [1.0f0, 1.0f0, 1.0f0]
216+
217+
# leave key 4 as the only key with a non-zero priority
t[:priority, [3, 4, 5]] = [0, 1, 0]
218+
219+
b = sample(s, t)
220+
221+
@test b.key == [4, 4, 4, 4, 4] # the priorities of the remaining transitions are set to 0
222+
223+
# EpisodesBuffer
224+
t = CircularPrioritizedTraces(
225+
CircularArraySARTSATraces(;
226+
capacity=10
227+
),
228+
default_priority=1.0f0
229+
)
230+
231+
eb = EpisodesBuffer(t)
232+
# first episode: an initial state, five full steps (states 2..6), then the
# closing action wrapped in a PartialNamedTuple
push!(eb, (state = 1,))
233+
for i = 1:5
234+
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
235+
end
236+
push!(eb, PartialNamedTuple((action = 6,)))
237+
# second episode: initial state 7, four full steps (states 8..11), then the
# closing action
push!(eb, (state = 7,))
238+
for (j,i) = enumerate(8:11)
239+
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
240+
end
241+
push!(eb, PartialNamedTuple((action=12,)))
242+
# draw a large batch and tally the sampled state values
# (`counter` is presumably DataStructures.counter - confirm against the test imports)
s = BatchSampler(1000)
243+
b = sample(s, eb)
244+
cm = counter(b[:state])
245+
# the last state pushed in each episode (6 and 11) must never be sampled
@test !haskey(cm, 6)
246+
@test !haskey(cm, 11)
247+
# every other state value must appear among the sampled states
@test all(in(keys(cm)), [1:5;7:10])
248+
end

test/episodes.jl

+26-8
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ using Test
100100
for i = 1:5
101101
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
102102
@test eb.sampleable_inds[end] == 0
103-
@test eb.sampleable_inds[end-1] == 1
103+
@test eb.sampleable_inds[end-1] == 0
104+
if length(eb) >= 1
105+
@test eb.sampleable_inds[end-2] == 1
106+
end
104107
@test eb.step_numbers[end] == i + 1
105108
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
106109
end
@@ -123,18 +126,24 @@ using Test
123126
ep2_len += 1
124127
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
125128
@test eb.sampleable_inds[end] == 0
126-
@test eb.sampleable_inds[end-1] == 1
129+
@test eb.sampleable_inds[end-1] == 0
130+
if eb.step_numbers[end] > 2
131+
@test eb.sampleable_inds[end-2] == 1
132+
end
127133
@test eb.step_numbers[end] == j + 1
128134
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
129135
end
130-
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
136+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0]
131137
@test length(eb.traces) == 9 #an action is missing at this stage
132138
#three last steps replace oldest steps in the buffer.
133139
for (i, s) = enumerate(12:13)
134140
ep2_len += 1
135141
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
136142
@test eb.sampleable_inds[end] == 0
137-
@test eb.sampleable_inds[end-1] == 1
143+
@test eb.sampleable_inds[end-1] == 0
144+
if eb.step_numbers[end] > 2
145+
@test eb.sampleable_inds[end-2] == 1
146+
end
138147
@test eb.step_numbers[end] == i + 1 + 4
139148
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
140149
end
@@ -299,7 +308,10 @@ using Test
299308
for i = 1:5
300309
push!(eb, (state = i+1, action =i, reward = i, terminal = false))
301310
@test eb.sampleable_inds[end] == 0
302-
@test eb.sampleable_inds[end-1] == 1
311+
@test eb.sampleable_inds[end-1] == 0
312+
if eb.step_numbers[end] > 2
313+
@test eb.sampleable_inds[end-2] == 1
314+
end
303315
@test eb.step_numbers[end] == i + 1
304316
@test eb.episodes_lengths[end-i:end] == fill(i, i+1)
305317
end
@@ -321,17 +333,23 @@ using Test
321333
ep2_len += 1
322334
push!(eb, (state = i, action =i-1, reward = i-1, terminal = false))
323335
@test eb.sampleable_inds[end] == 0
324-
@test eb.sampleable_inds[end-1] == 1
336+
@test eb.sampleable_inds[end-1] == 0
337+
if eb.step_numbers[end] > 2
338+
@test eb.sampleable_inds[end-2] == 1
339+
end
325340
@test eb.step_numbers[end] == j + 1
326341
@test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1)
327342
end
328-
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0]
343+
@test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0]
329344
@test length(eb.traces) == 9 #an action is missing at this stage
330345
for (i, s) = enumerate(12:13)
331346
ep2_len += 1
332347
push!(eb, (state = s, action =s-1, reward = s-1, terminal = false))
333348
@test eb.sampleable_inds[end] == 0
334-
@test eb.sampleable_inds[end-1] == 1
349+
@test eb.sampleable_inds[end-1] == 0
350+
if eb.step_numbers[end] > 2
351+
@test eb.sampleable_inds[end-2] == 1
352+
end
335353
@test eb.step_numbers[end] == i + 1 + 4
336354
@test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1)
337355
end

test/samplers.jl

+8-6
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch
130130
batchsize = 4
131131
eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
132132
s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize)
133-
134-
push!(eb, (state = 1, action = 1))
133+
134+
push!(eb, (state = 1,))
135135
for i = 1:5
136-
push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5))
136+
push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5))
137137
end
138-
push!(eb, (state = 7, action = 7))
139-
for (j,i) = enumerate(8:11)
140-
push!(eb, (state = i, action =i, reward = i-1, terminal = false))
138+
push!(eb, PartialNamedTuple((action=6,)))
139+
push!(eb, (state = 7,))
140+
for (j,i) = enumerate(7:10)
141+
push!(eb, (state = i+1, action =i, reward = i, terminal = i==10))
141142
end
143+
push!(eb, PartialNamedTuple((action = 11,)))
142144
weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
143145
inds = [i for i in eachindex(weights) if weights[i] == 1]
144146
batch = sample(s1, eb)

0 commit comments

Comments
 (0)