Skip to content

Commit 3c9db4f

Browse files
Merge pull request #24 from JuliaReinforcementLearning/jpsl/tweaks
Add Feedback and expand test
2 parents c473aef + da410ca commit 3c9db4f

File tree

3 files changed

+34
-21
lines changed

3 files changed

+34
-21
lines changed

Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ version = "0.1.14"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8-
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
98

109
[compat]
1110
Adapt = "2, 3, 4"

src/CircularArrayBuffers.jl

+24-18
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Base.isempty(cb::CircularArrayBuffer) = cb.nframes == 0
7979
"""
8080
_buffer_index(cb::CircularArrayBuffer, i::Int)
8181
82-
Return the index of the `i`-th element in the buffer.
82+
Return the index of the `i`-th element in the buffer. Note the `i` is assumed to be the linear indexing of `cb`.
8383
"""
8484
@inline function _buffer_index(cb::CircularArrayBuffer, i::Int)
8585
idx = (cb.first - 1) * cb.step_size + i
@@ -90,25 +90,24 @@ end
9090
"""
9191
wrap_index(idx, n)
9292
93-
Return the index of the `idx`-th element in the buffer, if index is one past the size, return 1, else error.
93+
Return the index of the `idx`-th element in the buffer, if index is one past the size, return 1, else error.
9494
"""
95-
function wrap_index(idx, n)
95+
function wrap_index(idx::Int, n::Int)
9696
if idx <= n
9797
return idx
9898
elseif idx <= 2n
9999
return idx - n
100100
else
101-
@info "oops! idx $(idx) > 2n $(2n)"
102-
return idx - n
101+
return -1 # NOTE: This should never happen, due to @boundscheck
103102
end
104103
end
105104

106105
"""
107106
_buffer_frame(cb::CircularArrayBuffer, i::Int)
108107
109-
Return the index of the `i`-th frame in the buffer.
108+
Here `i` is assumed to be the last dimension of `cb`. Each `frame` means a slice of the last dimension. Since we use *circular frames* (the `data` buffer) underlying, this function transforms the logical `i`-th frame to the real frame of the internal buffer.
110109
"""
111-
@inline function _buffer_frame(cb::CircularArrayBuffer, i::Int)
110+
@inline function _buffer_frame(cb::CircularArrayBuffer{T,N}, i::Int) where {T,N}
112111
n = capacity(cb)
113112
idx = cb.first + i - 1
114113
return wrap_index(idx, n)
@@ -123,19 +122,26 @@ function Base.empty!(cb::CircularArrayBuffer)
123122
cb
124123
end
125124

126-
function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
127-
if cb.nframes == capacity(cb)
125+
function _update_first_and_nframes!(cb)
126+
if isfull(cb)
128127
cb.first = (cb.first == capacity(cb) ? 1 : cb.first + 1)
129128
else
130129
cb.nframes += 1
131130
end
132-
if N == 1
133-
i = _buffer_frame(cb, cb.nframes)
134-
cb.buffer[i:i] .= Ref(data)
135-
else
136-
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
137-
end
138-
cb
131+
return cb
132+
end
133+
134+
function Base.push!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
135+
_update_first_and_nframes!(cb)
136+
cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)] .= data
137+
return cb
138+
end
139+
140+
function Base.push!(cb::CircularVectorBuffer{T}, data) where {T}
141+
_update_first_and_nframes!(cb)
142+
i = _buffer_frame(cb, cb.nframes)
143+
cb.buffer[i:i] .= Ref(data)
144+
return cb
139145
end
140146

141147
function Base.append!(cb::CircularArrayBuffer{T,N}, data) where {T,N}
@@ -180,7 +186,7 @@ function Base.pop!(cb::CircularArrayBuffer{T,N}) where {T,N}
180186
else
181187
res = @views cb.buffer[ntuple(_ -> (:), N - 1)..., _buffer_frame(cb, cb.nframes)]
182188
cb.nframes -= 1
183-
res
189+
return res
184190
end
185191
end
186192

@@ -194,7 +200,7 @@ function Base.popfirst!(cb::CircularArrayBuffer{T,N}) where {T,N}
194200
if cb.first > capacity(cb)
195201
cb.first = 1
196202
end
197-
res
203+
return res
198204
end
199205
end
200206

test/runtests.jl

+10-2
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,21 @@ CUDA.allowscalar(false)
2323
@test_throws BoundsError @view b[:, 9]
2424
end
2525

26-
@testset "Bounds error for zero-length buffer" begin
26+
@testset "Bounds error for zero-length / underfilled buffer" begin
2727
b = CircularVectorBuffer{Bool}(10)
28+
@test_throws BoundsError b[1]
2829
@test_throws BoundsError b[end]
29-
for i in 1:5
30+
31+
push!(b, true)
32+
@test b[1] == true
33+
@test b[end] == true
34+
@test_throws BoundsError b[2]
35+
for i in 1:15
3036
push!(b, true)
3137
end
3238
@test b[end] == true
39+
@test b[10] == true
40+
@test_throws BoundsError b[15]
3341
end
3442

3543
@testset "1D vector" begin

0 commit comments

Comments
 (0)