Skip to content

Commit 10a2186

Browse files
committed
add lots more opt outs
1 parent 1a8ea7a commit 10a2186

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

src/lib/broadcast.jl

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,6 @@ end
164164
# https://github.com/FluxML/Zygote.jl/pull/1001 tries to use broadcast_forward (using Dual numbers)
165165
# whenever possible, this was previously used only for CuArrays. It is usually much faster.
166166

167-
# https://github.com/JuliaDiff/ChainRules.jl/pull/644 implements broadcasting.
168-
# Its generic rule would be applied before the one defined here, with AbstractArrayStyle
169-
# @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
170-
# but does not pass all Zygote's tests. So disable it:
171-
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
172-
173167
@generated inclen(::NTuple{N,Any}) where N = Val(N+1)
174168

175169
# Avoid hitting special cases for `Adjoint` etc.
@@ -290,3 +284,35 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
290284

291285
pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
292286

287+
# ChainRules opt-out
288+
# =================
289+
290+
# https://github.com/JuliaDiff/ChainRules.jl/pull/644 implements broadcasting.
291+
# Its generic rule would be applied before the one defined above, with AbstractArrayStyle:
292+
# @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
293+
# but does not pass all Zygote's tests. So disable it:
294+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N}
295+
# That expands to
296+
# rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), f::F, args::Vararg{Any, N}) where {F, N}) = nothing
297+
# which is now ambiguous with many other rrules defined there. So we need more opt-outs:
298+
299+
const _NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted}
300+
301+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::_NumericOrBroadcast...)
302+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(-), x::_NumericOrBroadcast, y::_NumericOrBroadcast)
303+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(-), x::_NumericOrBroadcast)
304+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(*), x::_NumericOrBroadcast, y::_NumericOrBroadcast)
305+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::_NumericOrBroadcast, ::Val{2})
306+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(/), x::_NumericOrBroadcast, y::Number)
307+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(identity), x::_NumericOrBroadcast)
308+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::Type{T}, x::_NumericOrBroadcast) where {T<:Number}
309+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(float), x::_NumericOrBroadcast)
310+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(conj), x::_NumericOrBroadcast)
311+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(adjoint), x::_NumericOrBroadcast)
312+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real})
313+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(adjoint), x::AbstractArray{<:Real})
314+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(real), x::_NumericOrBroadcast)
315+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real})
316+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(imag), x::_NumericOrBroadcast)
317+
ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(complex), x::_NumericOrBroadcast)
318+

0 commit comments

Comments
 (0)