ChainRules.WithSomeZeros — Type
WithSomeZeros{T}
This is a union of LinearAlgebra types, all of which consist partly of structural zeros, with a simple backing array given by parent(x). All have methods of withsomezeros_rewrap to re-create them.
This exists to solve a type instability: broadcasting, for instance λ .* Diagonal(rand(3)), gives a dense matrix when λ==Inf. But withsomezeros_rewrap(x, λ .* parent(x)) is type-stable.
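To make the contrast concrete, here is a minimal sketch for the Diagonal case only. The helpers rewrap_sketch and scale_sketch are hypothetical stand-ins, not the package's own withsomezeros_rewrap. Scaling the backing array and re-wrapping never multiplies the structural zeros, so the result type cannot depend on the value of λ.

```julia
using LinearAlgebra

D = Diagonal([1.0, 2.0, 3.0])

# Hypothetical stand-in for withsomezeros_rewrap, covering only Diagonal:
# rebuild the same wrapper type around new backing data.
rewrap_sketch(::Diagonal, data) = Diagonal(data)

# Scale the backing vector, then re-wrap. The structural zeros are never
# touched, so Inf * 0 (which is NaN) cannot appear off the diagonal.
scale_sketch(λ, x) = rewrap_sketch(x, λ .* parent(x))

finite = scale_sketch(2.0, D)
infinite = scale_sketch(Inf, D)
```

Both results are Diagonal, whereas the type returned by the plain broadcast Inf .* D may differ from that of 2.0 .* D.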
ChainRules.OneElement — Type
OneElement(val, ind, axes) <: AbstractArray
Extremely simple struct used for the gradient of scalar getindex.
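As an illustration of the idea, the following sketch defines a read-only array that is zero everywhere except at one index. The type OneElementSketch is hypothetical and stores plain dimensions rather than axes, so it is a simplification and not the package's struct.

```julia
# Hypothetical simplified analogue of OneElement: one stored value, zero elsewhere.
struct OneElementSketch{T,N} <: AbstractArray{T,N}
    val::T
    ind::NTuple{N,Int}
    dims::NTuple{N,Int}
end

Base.size(A::OneElementSketch) = A.dims

# Return the stored value at the chosen index and zero at every other position.
function Base.getindex(A::OneElementSketch{T,N}, i::Vararg{Int,N}) where {T,N}
    return i == A.ind ? A.val : zero(T)
end

# Gradient of y = x[2] for a length-3 vector x, with cotangent 5.0.
g = OneElementSketch(5.0, (2,), (3,))
m = OneElementSketch(1, (2, 1), (2, 2))
```

Only the value, the index, and the shape are stored, so building the gradient of a scalar getindex does not allocate a full array.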
ChainRules._instantiate_zeros — Method
_instantiate_zeros(ẋs, xs)
Forward rules for vect, cat, etc. may receive a mixture of data and ZeroTangents. To avoid vect(1, ZeroTangent(), 3) or, worse, vcat([1,2], ZeroTangent(), [6,7]), this materialises each zero ẋ to be zero(x).
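A rough sketch of this materialisation step follows. It uses a hypothetical marker type ZeroSketch in place of ZeroTangent so that it runs without any packages, and instantiate_sketch is likewise an illustrative name.

```julia
# Hypothetical stand-in for ChainRulesCore.ZeroTangent.
struct ZeroSketch end

# Replace each symbolic zero tangent by a concrete zero shaped like its primal.
instantiate_sketch(ẋs, xs) = map((ẋ, x) -> ẋ isa ZeroSketch ? zero(x) : ẋ, ẋs, xs)

xs = ([1.0, 2.0], [3.0], [6.0, 7.0])            # primal inputs to vcat
ẋs = ([0.1, 0.2], ZeroSketch(), [0.6, 0.7])     # tangents, one of them a symbolic zero

dense = instantiate_sketch(ẋs, xs)
stacked = vcat(dense...)                        # now safe to concatenate
```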
ChainRules._matfun — Function
_matfun(f, A) -> (Y, intermediates)
Compute the matrix function Y=f(A) for matrix A. The function returns a tuple containing the result and a tuple of intermediates to be reused by _matfun_frechet to compute the Fréchet derivative.
ChainRules._matfun! — Function
_matfun!(f, A) -> (Y, intermediates)
Similar to _matfun, but where A may be overwritten.
ChainRules._matfun — Method
_matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
Compute the matrix function f(A) for real or complex Hermitian A. The function returns a tuple containing the result and a tuple of intermediates to be reused by _matfun_frechet to compute the Fréchet derivative.
Note that any function f used with this must have an frule defined for it.
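One standard way to evaluate a matrix function of a Hermitian matrix is through its eigendecomposition. The sketch below assumes that approach. The name matfun_sketch and the particular choice of intermediates are hypothetical and need not match what the package stores.

```julia
using LinearAlgebra

# Hypothetical helper: for Hermitian A = U * Diagonal(λ) * U',
# the matrix function is f(A) = U * Diagonal(f.(λ)) * U'.
function matfun_sketch(f, A::Hermitian)
    λ, U = eigen(A)
    fλ = f.(λ)
    Y = U * Diagonal(fλ) * U'
    # Keep the decomposition so a derivative computation could reuse it.
    return Y, (λ, U, fλ)
end

A = Hermitian([2.0 1.0; 1.0 2.0])
Y, intermediates = matfun_sketch(exp, A)
```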
ChainRules._matfun_frechet — Function
_matfun_frechet(f, E, A, Y, intermediates)
Compute the Fréchet derivative of the matrix function $Y = f(A)$ at $A$ in the direction of $E$, where intermediates is the second argument returned by _matfun.
The Fréchet derivative is the unique linear map $L_f \colon E \mapsto L_f(A, E)$, such that
\[L_f(A, E) = f(A + E) - f(A) + o(\lVert E \rVert).\]
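The definition can be checked numerically. The sketch below does not call the package. It uses the well-known block-matrix identity, in which the top-right block of f([A E; 0 A]) equals L_f(A, E), and compares it with a central finite difference. Both helper names are hypothetical.

```julia
using LinearAlgebra

# Block-matrix identity: the top-right block of f([A E; 0 A]) is L_f(A, E).
function frechet_block_sketch(f, A, E)
    n = size(A, 1)
    M = [A E; zero(A) A]
    return f(M)[1:n, (n + 1):(2n)]
end

# Central finite difference of f along the direction E, for comparison.
frechet_fd_sketch(f, A, E; h=1e-6) = (f(A + h * E) - f(A - h * E)) / (2h)

A = [1.0 2.0; 0.5 -1.0]
E = [0.3 -0.1; 0.2 0.4]

L_block = frechet_block_sketch(exp, A, E)
L_fd = frechet_fd_sketch(exp, A, E)
```

The two results should agree up to finite-difference error.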
ChainRules._matfun_frechet! — Function
_matfun_frechet!(f, E, A, Y, intermediates)
Similar to _matfun_frechet, but where E may be overwritten.
ChainRules._matfun_frechet_adjoint! — Method
_matfun_frechet_adjoint!(f, E, A, Y, intermediates)
Similar to _matfun_frechet_adjoint, but where E may be overwritten.
ChainRules._matfun_frechet_adjoint — Method
_matfun_frechet_adjoint(f, E, A, Y, intermediates)
Compute the adjoint of the Fréchet derivative of the matrix function $Y = f(A)$ at $A$ in the direction of $E$, where intermediates is the second argument returned by _matfun.
Given the Fréchet derivative $L_f(A, E)$ computed by _matfun_frechet, its adjoint $L_f^⋆(A, E)$ is defined by the identity
\[\langle B, L_f(A, C) \rangle = \langle L_f^⋆(A, B), C \rangle.\]
This identity is satisfied by $L_f^⋆(A, E) = L_f(A, E')'$.
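The adjoint identity can be verified numerically for f = exp with the Frobenius inner product. This sketch is independent of the package. The helper frechet_exp_sketch is a hypothetical name and relies on the block-matrix formula for the Fréchet derivative rather than on _matfun_frechet.

```julia
using LinearAlgebra

# Fréchet derivative of exp via the block-matrix formula (an assumption of
# this sketch; the package computes it differently).
function frechet_exp_sketch(A, E)
    n = size(A, 1)
    return exp([A E; zero(A) A])[1:n, (n + 1):(2n)]
end

# Adjoint built from the stated formula: L⋆(A, E) = L(A, E')'.
frechet_adjoint_sketch(A, E) = frechet_exp_sketch(A, E')'

A = [1.0 2.0; 0.5 -1.0]
B = [0.7 0.1; -0.3 0.2]
C = [0.3 -0.1; 0.2 0.4]

lhs = dot(B, frechet_exp_sketch(A, C))        # ⟨B, L(A, C)⟩
rhs = dot(frechet_adjoint_sketch(A, B), C)    # ⟨L⋆(A, B), C⟩
```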
ChainRules._setindex_zero — Method
_setindex_zero(x, dy, inds...)
This returns roughly dx = zero(x), except that this is guaranteed to be mutable via similar, and its element type is wide enough to allow setindex!(dx, dy, inds...), which is exactly what ∇getindex does next.
It's unfortunate to close over x, but similar(typeof(x), axes(x)) doesn't allow eltype(dy), nor does it work for many structured matrices.
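The element-type widening matters when the cotangent is wider than the primal, for example a Float64 cotangent for an Int array. A minimal sketch, restricted to numeric arrays and using the hypothetical name setindex_zero_sketch, might look like this:

```julia
# Hypothetical sketch: a mutable array of zeros shaped like x, whose element
# type is wide enough to hold both eltype(x) and eltype(dy).
function setindex_zero_sketch(x::AbstractArray{<:Number}, dy, inds...)
    T = promote_type(eltype(x), eltype(dy))
    return fill!(similar(x, T, axes(x)), zero(T))
end

x = [1, 2, 3]        # Int primal
dy = 0.5             # Float64 cotangent for x[2]

dx = setindex_zero_sketch(x, dy, 2)
dx[2] = dy           # would throw an InexactError if dx were an Int array
```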
ChainRules._tuple_N — Method
For a given tuple type, returns a Val{N}, where N is the length of the tuple.
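A function with this behaviour can be written by dispatching on the length parameter of the tuple type. The sketch below uses the hypothetical name tuple_length_sketch and assumes that a Val instance, rather than the Val type, is returned.

```julia
# Hypothetical re-implementation: read N off the tuple type via dispatch.
tuple_length_sketch(::Type{<:NTuple{N,Any}}) where {N} = Val(N)

v = tuple_length_sketch(Tuple{Int,Float64,String})
```

Because N is taken from the type, the length is available to the compiler as a constant.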
ChainRules._uses_input_only — Method
_uses_input_only(f, xT::Type)
Returns true if it can prove that derivatives_given_output will work using only the input of the given type. Thus there is no need to store the output y = f(x::xT), allowing us to take a fast path in the rrule for sum(f, xs).
Works by seeing if the result of derivatives_given_output(nothing, f, x) can be inferred. The method of derivatives_given_output usually comes from @scalar_rule.
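The inference trick can be imitated without the package. In the sketch below, dgo_sketch is a hypothetical stand-in for derivatives_given_output, and Base.promote_op (an unexported Base helper that relies on type inference) plays the role of the inference query. The extra check that the partial is a Number is an assumption of this sketch, added so that a rule which merely passes the nothing placeholder through is rejected.

```julia
# Hypothetical stand-ins for derivatives_given_output, returning ((partial,),).
dgo_sketch(y, ::typeof(sin), x) = ((cos(x),),)   # needs only the input x
dgo_sketch(y, ::typeof(exp), x) = ((y,),)        # reuses the output y

# Ask inference what the rule returns when the output is replaced by nothing.
function uses_input_only_sketch(f::F, ::Type{xT}) where {F,xT}
    gT = Base.promote_op(dgo_sketch, Nothing, F, xT)
    return isconcretetype(gT) && gT <: Tuple{Tuple{Number}}
end

sin_ok = uses_input_only_sketch(sin, Float64)
exp_ok = uses_input_only_sketch(exp, Float64)
```

For sin the partial is computed from x alone, so the query succeeds. For exp the placeholder flows into the result, which is not a Number, so the fast path is ruled out.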
ChainRules.unzip — Method
unzip(A)
Converts an array of tuples into a tuple of arrays. Eager. Will work by reinterpret when possible.
julia> ChainRules.unzip([(1,2), (30,40), (500,600)]) # makes two new Arrays:
([1, 30, 500], [2, 40, 600])
julia> typeof(ans)
Tuple{Vector{Int64}, Vector{Int64}}
julia> ChainRules.unzip([(1,nothing) (3,nothing) (5,nothing)]) # this can reinterpret:
([1 3 5], [nothing nothing nothing])
julia> ans[1]
1×3 reinterpret(Int64, ::Matrix{Tuple{Int64, Nothing}}):
1 3 5
ChainRules.unzip — Method
unzip(t)
Also works on a tuple of tuples:
julia> unzip(((1,2), (30,40), (500,600)))
((1, 30, 500), (2, 40, 600))
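For comparison, the same results can be produced by a naive version that maps over the collection once per tuple position. The function unzip_sketch is hypothetical and makes no attempt at the reinterpret optimisation mentioned for arrays.

```julia
# Hypothetical naive unzip: one map per tuple position.
# Works for an array of tuples and for a tuple of tuples.
unzip_sketch(xs) = ntuple(i -> map(t -> t[i], xs), length(first(xs)))

from_array = unzip_sketch([(1, 2), (30, 40), (500, 600)])
from_tuple = unzip_sketch(((1, 2), (30, 40), (500, 600)))
```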
ChainRules.unzip_broadcast — Method
unzip_broadcast(f, args...)
For a function f which returns a tuple, this is equal to unzip(broadcast(f, args...)), but performed using StructArrays for efficiency. Used in the gradient of broadcasting.
Examples
julia> using ChainRules: unzip_broadcast, unzip
julia> unzip_broadcast(x -> (x,2x), 1:3)
([1, 2, 3], [2, 4, 6])
julia> mats = @btime unzip_broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000)); # 2 arrays, each 7.63 MiB
min 1.776 ms, mean 20.421 ms (4 allocations, 15.26 MiB)
julia> mats == @btime unzip(broadcast((x,y) -> (x+y, x-y), 1:1000, transpose(1:1000))) # intermediate matrix of tuples
min 2.660 ms, mean 40.007 ms (6 allocations, 30.52 MiB)
true
ChainRules.∇getindex — Method
∇getindex(x, dy, inds...)
For the rrule of y = x[inds...], this function is roughly setindex(zero(x), dy, inds...), returning the array dx. Differentiable. Includes ProjectTo(x)(dx).
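The "roughly" part can be sketched in a few lines for numeric arrays. The function getindex_grad_sketch is a hypothetical name. It omits the ProjectTo step and does nothing special for repeated indices, so it illustrates the idea only.

```julia
# Hypothetical sketch of the pullback of y = x[inds...]:
# an array of zeros shaped like x, with dy written at the indexed positions.
function getindex_grad_sketch(x::AbstractArray{<:Number}, dy, inds...)
    T = promote_type(eltype(x), eltype(dy))
    dx = fill!(similar(x, T, axes(x)), zero(T))
    dx[inds...] = dy
    return dx
end

x = [10.0, 20.0, 30.0, 40.0]

dx_range = getindex_grad_sketch(x, [1.0, 2.0], 2:3)   # for y = x[2:3]
dx_scalar = getindex_grad_sketch(x, 5.0, 4)           # for y = x[4]
```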