DiffRules
Many differentiation methods rely on the notion of "primitive" differentiation rules that can be composed via various formulations of the chain rule. Using DiffRules, you can define new differentiation rules, query whether or not a given rule exists, and symbolically apply rules to simple Julia expressions.
Note that DiffRules is not a fully-fledged symbolic differentiation tool. It is a (very) simple global database of common derivative definitions, and was developed with the goal of improving derivative coverage in downstream tools.
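To show how these primitive rules compose via the chain rule, here is a minimal sketch of a tiny symbolic differentiator built on DiffRules.diffrule and DiffRules.hasdiffrule. The helper derivative is hypothetical and not part of DiffRules. It only handles numbers, symbols, and calls to functions with rules registered under Base, and it does no simplification.

```julia
using DiffRules

# Hypothetical helper (not part of DiffRules): symbolically differentiate `ex`
# with respect to the symbol `var` by looking up Base primitives in DiffRules
# and combining them with the chain rule. No simplification is attempted.
function derivative(ex, var::Symbol)
    ex isa Number && return 0
    ex isa Symbol && return ex === var ? 1 : 0
    (ex isa Expr && ex.head === :call && ex.args[1] isa Symbol) ||
        error("unsupported expression: $ex")
    f, args = ex.args[1], ex.args[2:end]
    DiffRules.hasdiffrule(:Base, f, length(args)) ||
        error("no DiffRules rule for Base.$f with $(length(args)) argument(s)")
    partials = DiffRules.diffrule(:Base, f, args...)
    # Unary rules return one expression; n-ary rules return an n-tuple.
    partials isa Tuple || (partials = (partials,))
    # Chain rule: d/dvar f(g1, ..., gn) = sum_i (∂f/∂gi) * (dgi/dvar)
    terms = [:($p * $(derivative(a, var))) for (p, a) in zip(partials, args)]
    return length(terms) == 1 ? terms[1] : Expr(:call, :+, terms...)
end

d = derivative(:(exp(sin(x * x))), :x)
println(d)
```

Because nothing is simplified, the result contains factors such as x * 1. A real tool would simplify the expression, or would evaluate the rules numerically instead of building expressions.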
DiffRules.@define_diffrule (macro)
@define_diffrule M.f(x) = :(df_dx($x))
@define_diffrule M.f(x, y) = :(df_dx($x, $y)), :(df_dy($x, $y))
⋮
Define a new differentiation rule for the function M.f and the given arguments, which should be treated as bindings to Julia expressions. Return the defined rule's key.
The LHS should be a function call with a non-splatted argument list. The RHS should be the derivative expression or, in the n-ary case, an n-tuple of expressions where the i-th expression is the derivative of f with respect to the i-th argument. Arguments should be interpolated wherever they are used on the RHS.
Note that differentiation rules are purely symbolic, so no type annotations should be used.
Examples
@define_diffrule Base.cos(x) = :(-sin($x))
@define_diffrule Base.:/(x, y) = :(inv($y)), :(-$x / ($y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma($m + 1, $x))
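The following sketch registers a rule for a made-up function and then queries it. MyMod.cube is a hypothetical name chosen for illustration. Rules are stored under symbol keys, so, as far as the rule table is concerned, neither the module nor the function has to exist. The interpolated $x is what lets any argument expression be substituted into the RHS. Note that the rule is added to DiffRules' global table for the rest of the session.

```julia
using DiffRules

# `MyMod.cube` is a hypothetical function: rules are stored under symbol keys,
# so neither `MyMod` nor `cube` has to exist to register a rule for it.
# Note that this adds an entry to DiffRules' global rule table.
key = DiffRules.@define_diffrule MyMod.cube(x) = :(3 * $x^2)
println(key)  # the returned key: (module symbol, function symbol, arity)

# `$x` is interpolated, so any expression can stand in for the argument.
println(DiffRules.diffrule(:MyMod, :cube, :(a + b)))
```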
DiffRules.diffrule (function)
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)
Return the derivative expression for M.f at the given argument(s), with the argument(s) interpolated into the returned expression.
In the n-ary case, an n-tuple of expressions is returned, where the i-th expression is the derivative of f with respect to the i-th argument.
Examples
julia> DiffRules.diffrule(:Base, :sin, 1)
:(cos(1))
julia> DiffRules.diffrule(:Base, :sin, :x)
:(cos(x))
julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
:(cos(x * y ^ 2))
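For a rule with more than one argument, diffrule returns a tuple with one expression per argument. The sketch below registers a rule for a hypothetical binary function MyMod.mulsq(x, y) = x * y^2 and then destructures the returned tuple. Only the rule is registered. The function itself is never defined.

```julia
using DiffRules

# Hypothetical binary function MyMod.mulsq(x, y) = x * y^2; only its
# derivative rule is registered, the function itself is never defined.
DiffRules.@define_diffrule MyMod.mulsq(x, y) = :($y^2), :(2 * $x * $y)

# For an n-ary rule, `diffrule` returns an n-tuple (∂f/∂x, ∂f/∂y), with the
# argument expressions interpolated into each entry.
dfdx, dfdy = DiffRules.diffrule(:MyMod, :mulsq, :a, :(b + 1))
println(dfdx)
println(dfdy)
```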
DiffRules.hasdiffrule (function)
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)
Return true if a differentiation rule is defined for M.f and arity, or return false otherwise.
Here, arity refers to the number of arguments accepted by f.
Examples
julia> DiffRules.hasdiffrule(:Base, :sin, 1)
true
julia> DiffRules.hasdiffrule(:Base, :sin, 2)
false
julia> DiffRules.hasdiffrule(:Base, :-, 1)
true
julia> DiffRules.hasdiffrule(:Base, :-, 2)
true
julia> DiffRules.hasdiffrule(:Base, :-, 3)
false
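diffrule looks up rules by module, function name, and number of arguments. A caller that does not know in advance whether a rule exists can therefore check with hasdiffrule first. The helper try_diffrule below is hypothetical and only consults rules registered under Base. It returns nothing instead of letting diffrule fail on a missing key.

```julia
using DiffRules

# Hypothetical helper: return the derivative expression(s) of Base.f at `args`
# if a rule with matching arity exists, and `nothing` otherwise.
function try_diffrule(f::Symbol, args...)
    DiffRules.hasdiffrule(:Base, f, length(args)) || return nothing
    return DiffRules.diffrule(:Base, f, args...)
end

println(try_diffrule(:sin, :x))      # a one-argument rule exists for Base.sin
println(try_diffrule(:sin, :x, :y))  # no two-argument rule, so `nothing`
```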
DiffRules.diffrules (function)
diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))
Return a list of keys that can be used to access all defined differentiation rules for modules in filter_modules.
Each key is of the form (M::Symbol, f::Symbol, arity::Int). Here, arity refers to the number of arguments accepted by f, and M is one of the modules in filter_modules.
To include all rules, specify filter_modules = nothing.
Calling diffrules() with the implicit default keyword argument filter_modules does not return all rules defined by this package. It returns only the rules for the packages for which DiffRules 1.0 provided rules. This is done in order not to break downstream packages that assumed this list would never change. It is planned to change diffrules() to return all rules, i.e., to use the default keyword argument filter_modules=nothing, in an upcoming breaking release of DiffRules.
Examples
julia> first(DiffRules.diffrules())
(:Base, :log2, 1)
If you call diffrules(), only rules for Base, SpecialFunctions, and NaNMath are returned, and no rules for LogExpFunctions:
julia> any(M === :LogExpFunctions for (M, _, _) in DiffRules.diffrules())
false
If you set filter_modules=nothing, all rules defined in DiffRules are returned, including the rules for LogExpFunctions:
julia> any(
M === :LogExpFunctions
for (M, _, _) in DiffRules.diffrules(; filter_modules=nothing)
)
true
If you set filter_modules=(:Base,), only rules for functions in Base are returned:
julia> all(M === :Base for (M, _, _) in DiffRules.diffrules(; filter_modules=(:Base,)))
true
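Each key carries its module as the first element, so the keys are easy to aggregate. This minimal sketch counts rules per module for both the default filter and filter_modules=nothing. The helper count_by_module is hypothetical. The actual counts depend on the installed DiffRules version, so none are shown here.

```julia
using DiffRules

# Hypothetical helper: tally rule keys (M, f, arity) by their module M.
function count_by_module(rulekeys)
    counts = Dict{Any,Int}()
    for (M, _, _) in rulekeys
        counts[M] = get(counts, M, 0) + 1
    end
    return counts
end

all_counts = count_by_module(DiffRules.diffrules(; filter_modules=nothing))
default_counts = count_by_module(DiffRules.diffrules())

# Print the per-module totals for all rules, sorted by module name.
for (M, n) in sort!(collect(all_counts); by=string ∘ first)
    println(rpad(string(M), 20), n)
end
```

For every module in the default set, the two tallies should agree. The only difference is which modules are included.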