|
164 | 164 | # https://github.com/FluxML/Zygote.jl/pull/1001 tries to use broadcast_forward (using Dual numbers)
|
165 | 165 | # whenever possible, this was previously used only for CuArrays. It is usually much faster.
|
166 | 166 |
|
167 |
| -# https://github.com/JuliaDiff/ChainRules.jl/pull/644 implements broadcasting. |
168 |
| -# Its generic rule would be applied before the one defined here, with AbstractArrayStyle |
169 |
| -# @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F} |
170 |
| -# but does not pass all Zygote's tests. So disable it: |
171 |
| -ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N} |
172 |
| - |
173 | 167 | @generated inclen(::NTuple{N,Any}) where N = Val(N+1)
|
174 | 168 |
|
175 | 169 | # Avoid hitting special cases for `Adjoint` etc.
|
@@ -290,3 +284,35 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
|
290 | 284 |
|
291 | 285 | pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
|
292 | 286 |
|
| 287 | +# ChainRules opt-out |
| 288 | +# ================= |
| 289 | + |
| 290 | +# https://github.com/JuliaDiff/ChainRules.jl/pull/644 implements broadcasting. |
| 291 | +# Its generic rule would be applied before the one defined above, with AbstractArrayStyle: |
| 292 | +# @adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F} |
| 293 | +# but does not pass all Zygote's tests. So disable it: |
| 294 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(Broadcast.broadcasted), f::F, args::Vararg{Any,N}) where {F,N} |
| 295 | +# That expands to |
| 296 | +# rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), f::F, args::Vararg{Any, N}) where {F, N}) = nothing |
| 297 | +# which is now ambiguous with many other rrules defined there. So we need more opt-outs: |
| 298 | + |
| 299 | +const _NumericOrBroadcast = Union{Number, AbstractArray{<:Number}, NTuple{<:Any,Number}, Broadcast.Broadcasted} |
| 300 | + |
| 301 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(+), xs::_NumericOrBroadcast...) |
| 302 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(-), x::_NumericOrBroadcast, y::_NumericOrBroadcast) |
| 303 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(-), x::_NumericOrBroadcast) |
| 304 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(*), x::_NumericOrBroadcast, y::_NumericOrBroadcast) |
| 305 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(Base.literal_pow), ::typeof(^), x::_NumericOrBroadcast, ::Val{2}) |
| 306 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(/), x::_NumericOrBroadcast, y::Number) |
| 307 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(identity), x::_NumericOrBroadcast) |
| 308 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::Type{T}, x::_NumericOrBroadcast) where {T<:Number} |
| 309 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(float), x::_NumericOrBroadcast) |
| 310 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(conj), x::_NumericOrBroadcast) |
| 311 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(adjoint), x::_NumericOrBroadcast) |
| 312 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(conj), x::AbstractArray{<:Real}) |
| 313 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(adjoint), x::AbstractArray{<:Real}) |
| 314 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(real), x::_NumericOrBroadcast) |
| 315 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(real), x::AbstractArray{<:Real}) |
| 316 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(imag), x::_NumericOrBroadcast) |
| 317 | +ChainRulesCore.@opt_out rrule(cfg::ZygoteRuleConfig, ::typeof(broadcasted), ::typeof(complex), x::_NumericOrBroadcast) |
| 318 | + |
0 commit comments