Skip to content

Commit eaa9f40

Browse files
committed
Remove redundant deepcopies where safe to do so
1 parent ed03b9e commit eaa9f40

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

src/GeneralisedFilters.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function instantiate end
2626

2727
# Default method
2828
function instantiate(model, alg, initial; kwargs...)
29-
return Intermediate(initial, deepcopy(initial))
29+
return Intermediate(initial, initial)
3030
end
3131

3232
"""

src/algorithms/bootstrap.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ function instantiate(
5050
::StateSpaceModel{T}, filter::BootstrapFilter, initial; kwargs...
5151
) where {T}
5252
N = filter.N
53-
return ParticleIntermediate(initial, deepcopy(initial), Vector{Int}(undef, N))
53+
return ParticleIntermediate(initial, initial, Vector{Int}(undef, N))
5454
end
5555

5656
function initialise(
@@ -78,7 +78,8 @@ function predict(
7878
new_particles = map(
7979
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(filtered)
8080
)
81-
proposed = ParticleDistribution(new_particles, deepcopy(filtered.log_weights))
81+
# Don't need to deepcopy weights as filtered will be overwritten in the update step
82+
proposed = ParticleDistribution(new_particles, filtered.log_weights)
8283

8384
return update_ref!(proposed, ref_state, step)
8485
end

src/algorithms/rbpf.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222

2323
function instantiate(::HierarchicalSSM{T}, filter::RBPF, initial; kwargs...) where {T}
2424
N = filter.N
25-
return ParticleIntermediate(initial, deepcopy(initial), Vector{Int}(undef, N))
25+
return ParticleIntermediate(initial, initial, Vector{Int}(undef, N))
2626
end
2727

2828
function initialise(
@@ -56,7 +56,8 @@ function predict(
5656
new_particles = map(
5757
x -> marginal_predict(rng, model, algo, t, x; kwargs...), filtered.particles
5858
)
59-
proposed = ParticleDistribution(new_particles, deepcopy(filtered.log_weights))
59+
# Don't need to deepcopy weights as filtered will be overwritten in the update step
60+
proposed = ParticleDistribution(new_particles, filtered.log_weights)
6061

6162
return update_ref!(proposed, ref_state, t)
6263
end
@@ -156,7 +157,7 @@ end
156157

157158
function instantiate(model::HierarchicalSSM, algo::BatchRBPF, initial; kwargs...)
158159
N = algo.N
159-
return ParticleIntermediate(initial, deepcopy(initial), CuArray{Int}(undef, N))
160+
return ParticleIntermediate(initial, initial, CuArray{Int}(undef, N))
160161
end
161162

162163
function initialise(
@@ -205,8 +206,9 @@ function predict(
205206
new_outer=new_xs,
206207
kwargs...,
207208
)
209+
# Don't need to deepcopy weights as filtered will be overwritten in the update step
208210
proposed = RaoBlackwellisedParticleDistribution(
209-
BatchRaoBlackwellisedParticles(new_xs, new_zs), deepcopy(filtered.log_weights)
211+
BatchRaoBlackwellisedParticles(new_xs, new_zs), filtered.log_weights
210212
)
211213

212214
# return states
@@ -232,8 +234,9 @@ function update(
232234
)
233235

234236
new_weights = proposed.log_weights + inner_lls
237+
# Don't need to deepcopy particles as update will be overwritten in the next step
235238
filtered = RaoBlackwellisedParticleDistribution(
236-
BatchRaoBlackwellisedParticles(deepcopy(proposed.particles.xs), new_zs), new_weights
239+
BatchRaoBlackwellisedParticles(proposed.particles.xs, new_zs), new_weights
237240
)
238241

239242
step_ll = logsumexp(filtered.log_weights) - logsumexp(proposed.log_weights)

src/resamplers.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ function resample(rng::AbstractRNG, resampler::AbstractResampler, states)
1616
end
1717

1818
function construct_new_state(states::ParticleDistribution{PT,WT}, idxs) where {PT,WT}
19-
return ParticleDistribution(deepcopy(states.particles[idxs]), zeros(WT, length(states)))
19+
return ParticleDistribution(states.particles[idxs], zeros(WT, length(states)))
2020
end
2121

2222
function construct_new_state(
2323
states::RaoBlackwellisedParticleDistribution{T}, idxs
2424
) where {T}
2525
return RaoBlackwellisedParticleDistribution(
2626
BatchRaoBlackwellisedParticles(
27-
deepcopy(states.particles.xs[:, idxs]), deepcopy(states.particles.zs[idxs])
27+
states.particles.xs[:, idxs], states.particles.zs[idxs]
2828
),
2929
CUDA.zeros(T, length(states)),
3030
)

0 commit comments

Comments
 (0)