diff --git a/Project.toml b/Project.toml index 193d5a0..040791e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ReinforcementLearningTrajectories" uuid = "6486599b-a3cd-4e92-a99a-2cea90cc8c3c" -version = "0.4" +version = "0.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/common/CircularArraySARTSATraces.jl b/src/common/CircularArraySARTSATraces.jl index 94f2aa1..393e64b 100644 --- a/src/common/CircularArraySARTSATraces.jl +++ b/src/common/CircularArraySARTSATraces.jl @@ -24,11 +24,11 @@ function CircularArraySARTSATraces(; reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+2)) + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) + Traces( - reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), + reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity+1), + terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity+1), ) end diff --git a/src/common/CircularPrioritizedTraces.jl b/src/common/CircularPrioritizedTraces.jl index 755b9a1..09b2ffd 100644 --- a/src/common/CircularPrioritizedTraces.jl +++ b/src/common/CircularPrioritizedTraces.jl @@ -12,7 +12,11 @@ end function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts} new_names = (:key, :priority, names...) new_Ts = Tuple{Int,Float32,Ts.parameters...} - c = capacity(traces) + if traces isa CircularArraySARTSATraces + c = capacity(traces) - 1 + else + c = capacity(traces) + end CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}( CircularVectorBuffer{Int}(c), SumTree(c), @@ -34,6 +38,22 @@ function Base.push!(t::CircularPrioritizedTraces, x) end end +function Base.push!(t::CircularPrioritizedTraces{<:CircularArraySARTSATraces}, x) + initial_length = length(t.traces) + push!(t.traces, x) + if length(t.traces) == 1 + push!(t.keys, 1) + push!(t.priorities, t.default_priority) + elseif length(t.traces) > 1 && (initial_length < length(t.traces) || initial_length == capacity(t.traces)-1 ) + # only add a key if the length changes after insertion of the tuple + # or if the trace is already at capacity + push!(t.keys, t.keys[end] + 1) + push!(t.priorities, t.default_priority) + else + # may be partial inserting at the first step, ignore it + end +end + function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys) if k === :priority @assert length(vs) == length(keys) @@ -48,6 +68,7 @@ function Base.setindex!(t::CircularPrioritizedTraces, vs, k::Symbol, keys) end Base.size(t::CircularPrioritizedTraces) = size(t.traces) +max_length(t::CircularPrioritizedTraces) = max_length(t.traces) function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol) if s === :priority diff --git a/src/episodes.jl b/src/episodes.jl index e93fb92..90a8b79 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -138,6 +138,8 @@ fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces) fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces) +max_length(eb::EpisodesBuffer) = max_length(eb.traces) + function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.traces, xs) partial = ispartial_insert(eb, xs) @@ -146,10 +148,12 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.episodes_lengths, 0) push!(eb.sampleable_inds, 0) elseif !partial #typical inserting - if length(eb.traces) < length(eb) && length(eb) > 2 #case when PartialNamedTuple is used. Steps are indexable one step later - eb.sampleable_inds[end-1] = 1 - else #case when we don't, length of traces and eb will match. - eb.sampleable_inds[end] = 1 #previous step is now indexable + if haskey(eb,:next_action) && length(eb) < max_length(eb) # if trace has next_action and lengths are mismatched + if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode + eb.sampleable_inds[end-1] = 1 # steps are indexable one step later + end + else + eb.sampleable_inds[end] = 1 # otherwise, previous step is now indexable end push!(eb.sampleable_inds, 0) #this one is no longer ep_length = last(eb.step_numbers) @@ -172,6 +176,28 @@ function Base.push!(eb::EpisodesBuffer, xs::PartialNamedTuple) #wrap a NamedTupl eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. end +function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularArraySARTSATraces}, xs::PartialNamedTuple) + if max_length(eb) == capacity(eb.traces) + popfirst!(eb) + end + push!(eb.traces, xs.namedtuple) + eb.sampleable_inds[end-1] = 1 #completes the episode trajectory. +end + +function Base.push!(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces{<:CircularArraySARTSATraces}}, xs::PartialNamedTuple{@NamedTuple{action::Int64}}) + if max_length(eb) == capacity(eb.traces) + addition = (name => zero(eltype(eb.traces[name])) for name in [:state, :reward, :terminal]) + xs = merge(xs.namedtuple, addition) + push!(eb.traces, xs) + pop!(eb.traces[:state].trace) + pop!(eb.traces[:reward]) + pop!(eb.traces[:terminal]) + else + push!(eb.traces, xs.namedtuple) + eb.sampleable_inds[end-1] = 1 + end +end + for f in (:pop!, :popfirst!) @eval function Base.$f(eb::EpisodesBuffer) $f(eb.episodes_lengths) diff --git a/src/samplers.jl b/src/samplers.jl index 8701189..e5443a7 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -74,7 +74,7 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir t = e.traces p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) - w .*= e.sampleable_inds[1:end-1] + w .*= e.sampleable_inds[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...)) end @@ -247,7 +247,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) valids, ns = valid_range(s,e) - w .*= valids[1:end-1] + w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) merge( (key=t.keys[inds], priority=p[inds]), @@ -362,7 +362,7 @@ function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, < p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) valids, ns = valid_range(s,e) - w .*= valids[1:end-1] + w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) merge( (key=t.keys[inds], priority=p[inds]), diff --git a/src/traces.jl b/src/traces.jl index 1cd36bb..35e1f77 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -247,6 +247,7 @@ function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T end Base.size(t::Traces) = (mapreduce(length, min, t.traces),) +max_length(t::Traces) = mapreduce(length, max, t.traces) function capacity(t::Traces{names,Trs,N,E}) where {names,Trs,N,E} minimum(map(idx->capacity(t[idx]), names)) diff --git a/test/common.jl b/test/common.jl index afcc89b..714520b 100644 --- a/test/common.jl +++ b/test/common.jl @@ -24,7 +24,7 @@ @test length(t) == 0 end -@testset "CircularArraySARTSTraces" begin +@testset "CircularArraySARTSATraces" begin t = CircularArraySARTSATraces(; capacity=3, state=Float32 => (2, 3), @@ -35,13 +35,14 @@ end @test t isa CircularArraySARTSATraces - push!(t, (state=ones(Float32, 2, 3), action=ones(Float32, 2)) |> gpu) + push!(t, (state=ones(Float32, 2, 3),)) + push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu) @test length(t) == 0 push!(t, (reward=1.0f0, terminal=false) |> gpu) - @test length(t) == 0 # next_state and next_action is still missing + @test length(t) == 0 # next_action is still missing - push!(t, (next_state=ones(Float32, 2, 3) * 2, next_action=ones(Float32, 2) * 2) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 2) |> gpu) @test length(t) == 1 # this will trigger the scalar indexing of CuArray @@ -55,17 +56,18 @@ end ) push!(t, (reward=2.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 3, action=ones(Float32, 2) * 3) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu) @test length(t) == 2 push!(t, (reward=3.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 4) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu) @test length(t) == 3 push!(t, (reward=4.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 5) |> gpu) + push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu) + push!(t, (reward=5.0f0, terminal=false)) @test length(t) == 3 @@ -127,9 +129,9 @@ end @test t isa CircularArraySLARTTraces end -@testset "CircularPrioritizedTraces" begin +@testset "CircularPrioritizedTraces-SARTS" begin t = CircularPrioritizedTraces( - CircularArraySARTSATraces(; + CircularArraySARTSTraces(; capacity=3 ), default_priority=1.0f0 @@ -160,7 +162,7 @@ end #EpisodesBuffer t = CircularPrioritizedTraces( - CircularArraySARTSATraces(; + CircularArraySARTSTraces(; capacity=10 ), default_priority=1.0f0 @@ -186,3 +188,61 @@ end eb[:priority, [1, 2]] = [0, 0] @test eb[:priority] == [zeros(2);ones(8)] end + +@testset "CircularPrioritizedTraces-SARTSA" begin + t = CircularPrioritizedTraces( + CircularArraySARTSATraces(; + capacity=3 + ), + default_priority=1.0f0 + ) + + push!(t, (state=0, action=0)) + + for i in 1:5 + push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) + end + + @test length(t) == 3 + + s = BatchSampler(5) + + b = sample(s, t) + + t[:priority, [1, 2]] = [0, 0] + + # shouldn't be changed since [1,2] are old keys + @test t[:priority] == [1.0f0, 1.0f0, 1.0f0] + + t[:priority, [3, 4, 5]] = [0, 1, 0] + + b = sample(s, t) + + @test b.key == [4, 4, 4, 4, 4] # the priority of the rest transitions are set to 0 + + #EpisodesBuffer + t = CircularPrioritizedTraces( + CircularArraySARTSATraces(; + capacity=10 + ), + default_priority=1.0f0 + ) + + eb = EpisodesBuffer(t) + push!(eb, (state = 1,)) + for i = 1:5 + push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + end + push!(eb, PartialNamedTuple((action = 6,))) + push!(eb, (state = 7,)) + for (j,i) = enumerate(8:11) + push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + end + push!(eb, PartialNamedTuple((action=12,))) + s = BatchSampler(1000) + b = sample(s, eb) + cm = counter(b[:state]) + @test !haskey(cm, 6) + @test !haskey(cm, 11) + @test all(in(keys(cm)), [1:5;7:10]) +end diff --git a/test/episodes.jl b/test/episodes.jl index 7c297d5..0932416 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -100,7 +100,10 @@ using Test for i = 1:5 push!(eb, (state = i+1, action =i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if length(eb) >= 1 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(i, i+1) end @@ -123,18 +126,24 @@ using Test ep2_len += 1 push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == j + 1 @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] @test length(eb.traces) == 9 #an action is missing at this stage #three last steps replace oldest steps in the buffer. for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end @@ -299,7 +308,10 @@ using Test for i = 1:5 push!(eb, (state = i+1, action =i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(i, i+1) end @@ -321,17 +333,23 @@ using Test ep2_len += 1 push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == j + 1 @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] @test length(eb.traces) == 9 #an action is missing at this stage for (i, s) = enumerate(12:13) ep2_len += 1 push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) @test eb.sampleable_inds[end] == 0 - @test eb.sampleable_inds[end-1] == 1 + @test eb.sampleable_inds[end-1] == 0 + if eb.step_numbers[end] > 2 + @test eb.sampleable_inds[end-2] == 1 + end @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end diff --git a/test/samplers.jl b/test/samplers.jl index 1b914bb..2565ddd 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -130,15 +130,17 @@ import ReinforcementLearningTrajectories.fetch batchsize = 4 eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0)) s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batchsize=batchsize) - - push!(eb, (state = 1, action = 1)) + + push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action =i+1, reward = i, terminal = i == 5)) + push!(eb, (state = i+1, action =i, reward = i, terminal = i == 5)) end - push!(eb, (state = 7, action = 7)) - for (j,i) = enumerate(8:11) - push!(eb, (state = i, action =i, reward = i-1, terminal = false)) + push!(eb, PartialNamedTuple((action=6,))) + push!(eb, (state = 7,)) + for (j,i) = enumerate(7:10) + push!(eb, (state = i+1, action =i, reward = i, terminal = i==10)) end + push!(eb, PartialNamedTuple((action = 11,))) weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb) inds = [i for i in eachindex(weights) if weights[i] == 1] batch = sample(s1, eb)