DiffRules

Many differentiation methods rely on the notion of "primitive" differentiation rules that can be composed via various formulations of the chain rule. Using DiffRules, you can define new differentiation rules, query whether or not a given rule exists, and symbolically apply rules to simple Julia expressions.
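The chain-rule composition described above can be sketched directly with DiffRules. This is a minimal illustration that assumes DiffRules is installed. It combines the primitive rules for `sin` and `exp` to differentiate `sin(exp(x))`. The generated function name `dsinexp` is hypothetical.

```julia
using DiffRules

# Chain rule: d/dx f(g(x)) = f'(g(x)) * g'(x), with f = sin and g = exp.
inner       = :(exp(x))
outer_deriv = DiffRules.diffrule(:Base, :sin, inner)  # f'(g(x)), i.e. cos(exp(x))
inner_deriv = DiffRules.diffrule(:Base, :exp, :x)     # g'(x), i.e. exp(x)
chain       = :($outer_deriv * $inner_deriv)

# Compile the composed symbolic derivative into an ordinary function.
@eval dsinexp(x) = $chain

x0 = 0.3
println(dsinexp(x0) ≈ cos(exp(x0)) * exp(x0))
```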

Note that DiffRules is not a fully-fledged symbolic differentiation tool. It is a (very) simple global database of common derivative definitions, and was developed with the goal of improving derivative coverage in downstream tools.

DiffRules.@define_diffrule (Macro)
@define_diffrule M.f(x) = :(df_dx($x))
@define_diffrule M.f(x, y) = :(df_dx($x, $y)), :(df_dy($x, $y))
⋮

Define a new differentiation rule for the function M.f and the given arguments, which should be treated as bindings to Julia expressions. Return the defined rule's key, a tuple of the form (M, f, arity).

The LHS should be a function call with a non-splatted argument list, and the RHS should be the derivative expression or, in the n-ary case, an n-tuple of expressions where the ith expression is the derivative of f w.r.t. the ith argument. Arguments should be interpolated (with $) wherever they are used on the RHS.

Note that differentiation rules are purely symbolic, so no type annotations should be used.
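The following sketch shows why interpolation matters. It assumes DiffRules is installed, and the module `Demo` and its functions are hypothetical. The two rules differ only in whether `x` is written as `$x`, and only the interpolated version substitutes the expression it receives.

```julia
using DiffRules

# Hypothetical module used only for this illustration.
module Demo
sq(x) = x^2
sq_bad(x) = x^2
end

# Correct: `$x` splices in whatever expression is passed as the argument.
DiffRules.@define_diffrule Demo.sq(x) = :(2 * $x)

# Incorrect: without `$`, the result always refers to the literal symbol `x`.
DiffRules.@define_diffrule Demo.sq_bad(x) = :(2 * x)

@assert DiffRules.diffrule(:Demo, :sq, :(a + 1)) == :(2 * (a + 1))
@assert DiffRules.diffrule(:Demo, :sq_bad, :(a + 1)) == :(2 * x)  # argument ignored
```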

Examples

@define_diffrule Base.cos(x)                      = :(-sin($x))
@define_diffrule Base.:/(x, y)                    = :(inv($y)), :(-$x / ($y^2))
@define_diffrule SpecialFunctions.polygamma(m, x) = :NaN,       :(polygamma($m + 1, $x))
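The examples above define rules for existing functions. The next sketch registers a rule for a user-defined binary function and then looks it up through the query API. It assumes DiffRules is installed and uses a hypothetical module `MyMod` with a function `sumsq`.

```julia
using DiffRules

# Hypothetical user module.
module MyMod
sumsq(x, y) = x^2 + y^2
end

# One derivative expression per argument: (∂/∂x, ∂/∂y).
DiffRules.@define_diffrule MyMod.sumsq(x, y) = :(2 * $x), :(2 * $y)

# The new rule is stored under the key (M, f, arity)...
@assert DiffRules.hasdiffrule(:MyMod, :sumsq, 2)
@assert (:MyMod, :sumsq, 2) in DiffRules.diffrules(; filter_modules=nothing)

# ...and applying it yields a 2-tuple of expressions.
dx, dy = DiffRules.diffrule(:MyMod, :sumsq, :a, :(b + c))
@assert dx == :(2 * a)
@assert dy == :(2 * (b + c))
```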
DiffRules.diffrule (Function)
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)

Return the derivative expression for M.f at the given argument(s), with the argument(s) interpolated into the returned expression.

In the n-ary case, an n-tuple of expressions will be returned where the ith expression is the derivative of f w.r.t. the ith argument.

Examples

julia> DiffRules.diffrule(:Base, :sin, 1)
:(cos(1))

julia> DiffRules.diffrule(:Base, :sin, :x)
:(cos(x))

julia> DiffRules.diffrule(:Base, :sin, :(x * y^2))
:(cos(x * y ^ 2))
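Downstream tools often splice the returned expressions into generated code. The following sketch assumes the default rules for `Base.sin` and `Base.:/` are loaded. The names `sin_deriv` and `div_grad` are only illustrative. The code checks the generated functions numerically and does not depend on the exact form of each rule expression.

```julia
using DiffRules

# Unary case: a single expression with the argument spliced in.
dsin = DiffRules.diffrule(:Base, :sin, :x)

# Binary case: a 2-tuple (∂/∂x, ∂/∂y) for x / y.
ddiv_dx, ddiv_dy = DiffRules.diffrule(:Base, :/, :x, :y)

# Turn the symbolic derivatives into ordinary Julia functions.
@eval sin_deriv(x) = $dsin
@eval div_grad(x, y) = ($ddiv_dx, $ddiv_dy)

@assert sin_deriv(0.0) == 1.0          # cos(0) == 1
gx, gy = div_grad(1.0, 2.0)
@assert gx ≈ 0.5 && gy ≈ -0.25         # 1/y and -x/y^2 at (1, 2)
```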
DiffRules.hasdiffrule (Function)
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)

Return true if a differentiation rule is defined for M.f with the given arity, and false otherwise.

Here, arity refers to the number of arguments accepted by f.

Examples

julia> DiffRules.hasdiffrule(:Base, :sin, 1)
true

julia> DiffRules.hasdiffrule(:Base, :sin, 2)
false

julia> DiffRules.hasdiffrule(:Base, :-, 1)
true

julia> DiffRules.hasdiffrule(:Base, :-, 2)
true

julia> DiffRules.hasdiffrule(:Base, :-, 3)
false
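`diffrule` needs a rule that matches the exact arity, so a caller may want to check `hasdiffrule` first. Here is a minimal sketch of that pattern. The helper `try_diffrule` is hypothetical and not part of DiffRules.

```julia
using DiffRules

# Hypothetical helper: apply a rule only if one exists for this arity,
# returning `nothing` instead of an error when the database has no entry.
function try_diffrule(M::Symbol, f::Symbol, args...)
    DiffRules.hasdiffrule(M, f, length(args)) || return nothing
    return DiffRules.diffrule(M, f, args...)
end

@assert try_diffrule(:Base, :sin, :x) == :(cos(x))
@assert try_diffrule(:Base, :sin, :x, :y) === nothing  # no binary rule for sin
@assert try_diffrule(:Base, :-, :x, :y) isa Tuple      # one expression per argument
```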
DiffRules.diffrules (Function)
diffrules(; filter_modules=(:Base, :SpecialFunctions, :NaNMath))

Return a list of keys that can be used to access all defined differentiation rules for modules in filter_modules.

Each key is of the form (M::Symbol, f::Symbol, arity::Int). Here, arity refers to the number of arguments accepted by f and M is one of the modules in filter_modules.

To include all rules, specify filter_modules = nothing.
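Each key is a plain `(M, f, arity)` tuple, so ordinary iteration is enough to slice the database. The following sketch assumes the default rule set and lists the Base functions that have a one-argument rule.

```julia
using DiffRules

# Collect the names of Base functions with a one-argument rule.
unary_base = sort!([f for (_, f, arity) in DiffRules.diffrules(; filter_modules=(:Base,))
                    if arity == 1])

@assert :sin in unary_base
@assert :exp in unary_base
println(length(unary_base), " unary Base rules, e.g. ", first(unary_base, 5))
```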

Note

Calling diffrules() with the implicit default keyword argument filter_modules does not return all rules defined by this package, but only rules for the packages for which DiffRules 1.0 provided rules. This is done in order not to break downstream packages that assumed this list would never change. It is planned to change diffrules() to return all rules, i.e., to use the default keyword argument filter_modules=nothing, in an upcoming breaking release of DiffRules.

Examples

julia> first(DiffRules.diffrules()) isa Tuple{Symbol,Symbol,Int}
true

julia> (:Base, :log, 1) in DiffRules.diffrules()
true

julia> (:Base, :*, 2) in DiffRules.diffrules()
true

If you call diffrules(), only rules for Base, SpecialFunctions, and NaNMath are returned, and rules for LogExpFunctions are not included:

julia> any(M === :LogExpFunctions for (M, _, _) in DiffRules.diffrules())
false

If you set filter_modules=nothing, all rules defined in DiffRules are returned, including the rules for LogExpFunctions:

julia> any(
           M === :LogExpFunctions
           for (M, _, _) in DiffRules.diffrules(; filter_modules=nothing)
       )
true

If you set filter_modules=(:Base,), only rules for functions in Base are returned:

julia> all(M === :Base for (M, _, _) in DiffRules.diffrules(; filter_modules=(:Base,)))
true
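The keyword argument can also be combined with ordinary iteration to summarize the database. The following sketch counts rules per module. The exact counts depend on the installed DiffRules version, so the code only prints them and does not check specific values.

```julia
using DiffRules

# Count registered rules per module across the whole database.
counts = Dict{Any,Int}()
for (M, _, _) in DiffRules.diffrules(; filter_modules=nothing)
    counts[M] = get(counts, M, 0) + 1
end

# With the default filter, only the modules covered by DiffRules 1.0 appear.
default_modules = Set(M for (M, _, _) in DiffRules.diffrules())
@assert issubset(default_modules, Set([:Base, :SpecialFunctions, :NaNMath]))

for (M, n) in sort!(collect(counts); by = p -> string(first(p)))
    println(rpad(string(M), 18), n)
end
```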