diff --git a/docs/src/api.md b/docs/src/api.md
index 378bf72a..434ee70a 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -85,4 +85,5 @@ It is defined in Functors.jl and re-exported by Optimisers.jl here for convenien
 Functors.KeyPath
 Functors.haskeypath
 Functors.getkeypath
+Functors.setkeypath!
 ```
diff --git a/src/interface.jl b/src/interface.jl
index ac9b90bc..be0427ab 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -67,6 +67,12 @@ function update(tree, model, grad, higher...)
   update!(t′, x′, grad, higher...)
 end
 
+function update!(::AbstractRule, model, grad, higher...)
+  throw(ArgumentError("""update! must be called with an optimiser state tree, not a rule.
+    Call `opt_state = setup(rule, model)` first, then `update!(opt_state, model, grad)`.
+    """))
+end
+
 function update!(tree, model, grad, higher...)
   # First walk is to accumulate the gradient. This recursion visits every copy of
   # shared leaves, but stops when branches are absent from the gradient:
diff --git a/test/runtests.jl b/test/runtests.jl
index fc0fe57f..bb5cbec9 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -104,6 +104,10 @@ end
     @test isnan(m3n.γ[3])
   end
 
+  @testset "friendly error when using rule instead of state" begin
+    @test_throws ArgumentError Optimisers.update!(Adam(), rand(2), rand(2))
+  end
+
   @testset "Dict support" begin
     @testset "simple dict" begin
       d = Dict(:a => [1.0,2.0], :b => [3.0,4.0], :c => 1)
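
For reference, a minimal sketch of the behaviour the `src/interface.jl` change introduces, using the public Optimisers.jl API (`setup`, `update!`, `Adam`) that appears in the diff; the model and gradient below are illustrative placeholders, not part of the patch:

```julia
using Optimisers

model = (w = rand(3),)   # any Functors-compatible structure
grad  = (w = ones(3),)   # gradient with the same structure as the model

# With this patch, passing a rule where the state tree belongs throws the
# friendly ArgumentError added above, instead of failing further down:
#   Optimisers.update!(Adam(), model, grad)

# Correct usage: build the optimiser state tree first, then update in place.
opt_state = Optimisers.setup(Adam(), model)
opt_state, model = Optimisers.update!(opt_state, model, grad)
```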