Abstract Types

Lux.AbstractExplicitLayer (Type)
AbstractExplicitLayer

Abstract type for all Lux layers.

Users implementing a custom layer must implement:

  • initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer) – This returns a NamedTuple containing the trainable parameters for the layer.
  • initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer) – This returns a NamedTuple containing the current state for the layer. For most layers this is empty; layers that typically carry state include BatchNorm, LSTM, and GRU.

Optionally:

  • parameterlength(layer::CustomAbstractExplicitLayer) – Returns the total number of parameters of the layer. This can be calculated automatically, but defining it explicitly is recommended.
  • statelength(layer::CustomAbstractExplicitLayer) – Returns the total number of states of the layer. This can be calculated automatically, but defining it explicitly is recommended.
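
As an illustration, here is a minimal sketch of a custom layer that follows this contract. To keep it runnable without Lux installed, it defines a local stand-in for AbstractExplicitLayer and local functions with the documented names; in real code you would subtype Lux.AbstractExplicitLayer and extend Lux.initialparameters, Lux.initialstates, and so on. The layer MyDense is hypothetical, and its call method follows the model(x, ps, st) convention used by apply, returning the output together with the (unchanged) state.

```julia
using Random

# Local stand-in so the sketch runs without Lux. In real code, subtype
# Lux.AbstractExplicitLayer and extend the Lux functions instead.
abstract type AbstractExplicitLayer end

# Hypothetical dense layer: y = weight * x .+ bias
struct MyDense <: AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
end

# Required: the trainable parameters as a NamedTuple.
initialparameters(rng::AbstractRNG, l::MyDense) =
    (weight = randn(rng, Float32, l.out_dims, l.in_dims),
     bias = zeros(Float32, l.out_dims))

# Required: the state. A dense layer has none, so return an empty NamedTuple.
initialstates(::AbstractRNG, ::MyDense) = NamedTuple()

# Optional but recommended: explicit counts matching the arrays above.
parameterlength(l::MyDense) = l.out_dims * l.in_dims + l.out_dims
statelength(::MyDense) = 0

# Forward pass: takes explicit parameters and state, returns (output, state).
(l::MyDense)(x, ps, st::NamedTuple) = (ps.weight * x .+ ps.bias, st)

rng = Random.default_rng()
layer = MyDense(3, 2)
ps = initialparameters(rng, layer)
st = initialstates(rng, layer)
y, st_new = layer(ones(Float32, 3), ps, st)
```

Because parameters and state live outside the layer struct, the same MyDense value can be reused with different parameter sets.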

See also AbstractExplicitContainerLayer.

Lux.AbstractExplicitContainerLayer (Type)
AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer

Abstract container type for certain Lux layers. The type parameter layers is a tuple of the layer's field names, and the parameters and states are constructed from the sub-layers stored in those fields.

Users implementing a custom container layer can extend the same functions as for AbstractExplicitLayer.
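
The following sketch shows the idea of building a container's parameters and states from the field names in layers. It uses local stand-ins rather than Lux itself, and the types Scale and TwoScales are hypothetical. Lux's own implementation may differ in details, for example in how it handles a container that lists only a single field.

```julia
using Random

# Local stand-ins so the sketch runs without Lux (not Lux's definitions).
abstract type AbstractExplicitLayer end
abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer end

# Hypothetical leaf layer with one parameter vector and no state.
struct Scale <: AbstractExplicitLayer
    n::Int
end
initialparameters(::AbstractRNG, l::Scale) = (scale = ones(Float32, l.n),)
initialstates(::AbstractRNG, ::Scale) = NamedTuple()

# Hypothetical container whose sub-layers live in the fields :first and :second.
struct TwoScales <: AbstractExplicitContainerLayer{(:first, :second)}
    first::Scale
    second::Scale
end

# Generic rule: one NamedTuple entry per field name listed in `layers`,
# each filled by the corresponding sub-layer.
function initialparameters(rng::AbstractRNG,
                           l::AbstractExplicitContainerLayer{layers}) where {layers}
    return NamedTuple{layers}(map(name -> initialparameters(rng, getfield(l, name)), layers))
end
function initialstates(rng::AbstractRNG,
                       l::AbstractExplicitContainerLayer{layers}) where {layers}
    return NamedTuple{layers}(map(name -> initialstates(rng, getfield(l, name)), layers))
end

rng = Random.default_rng()
model = TwoScales(Scale(2), Scale(3))
ps = initialparameters(rng, model)
st = initialstates(rng, model)
```

The resulting parameters mirror the container's structure: ps.first.scale has length 2 and ps.second.scale has length 3.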

General

Lux.apply (Function)
apply(model::AbstractExplicitLayer, x, ps::Union{ComponentArray,NamedTuple},
      st::NamedTuple)

Simply calls model(x, ps, st) and returns its result.

Lux.setup (Function)
setup(rng::AbstractRNG, l::AbstractExplicitLayer)

Shorthand for getting the parameters and states of the layer l. It is equivalent to (initialparameters(rng, l), initialstates(rng, l)).
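
A minimal sketch of how setup and apply fit together, using local stand-ins with the documented behaviour: setup returns the pair of initial parameters and states, and apply simply calls the model. The layer AddBias is hypothetical, and the sketch drops the ComponentArray option from the ps type so that it needs only the Random standard library.

```julia
using Random

# Hypothetical layer that adds a learned bias and carries no state.
struct AddBias
    n::Int
end
initialparameters(::AbstractRNG, l::AddBias) = (bias = zeros(Float64, l.n),)
initialstates(::AbstractRNG, ::AddBias) = NamedTuple()
(l::AddBias)(x, ps, st::NamedTuple) = (x .+ ps.bias, st)

# setup: parameters and states from the same RNG, returned as a pair.
setup(rng::AbstractRNG, l) = (initialparameters(rng, l), initialstates(rng, l))

# apply: forwards directly to the layer's call method.
apply(model, x, ps, st::NamedTuple) = model(x, ps, st)

rng = Random.default_rng()
model = AddBias(3)
ps, st = setup(rng, model)
y, st = apply(model, [1.0, 2.0, 3.0], ps, st)
```

Rebinding st to the returned state is the usual pattern, since stateful layers may return an updated state from each call.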

Parameters

Lux.initialparameters (Function)
initialparameters(rng::AbstractRNG, l)

Generate the initial parameters of the layer l.

Lux.parameterlength (Function)
parameterlength(l)

Return the total number of parameters of the layer l.
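
When parameterlength is not defined by hand, it can in principle be derived by initializing the parameters and counting their scalar entries. The helper count_entries below is a hypothetical sketch of that counting step, not Lux's code. Non-array leaves such as Val flags count as zero, so the same helper also works on state NamedTuples. A count of this kind has to materialize the arrays first, which is one reason to define parameterlength explicitly.

```julia
# Hypothetical helper: count every scalar entry in a (possibly nested)
# NamedTuple or Tuple of arrays.
count_entries(x::AbstractArray{<:Number}) = length(x)
count_entries(x::Number) = 1
count_entries(nt::Union{NamedTuple,Tuple}) = isempty(nt) ? 0 : sum(count_entries, nt)
count_entries(_) = 0  # any other leaf (e.g. Val(true)) holds no parameters

# Parameters of a hypothetical two-layer model.
ps = (layer_1 = (weight = zeros(4, 3), bias = zeros(4)),
      layer_2 = (weight = zeros(2, 4), bias = zeros(2)))
n = count_entries(ps)  # 12 + 4 + 8 + 2
```

The isempty check is needed because sum over an empty collection has no starting value to return.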

States

Lux.initialstates (Function)
initialstates(rng::AbstractRNG, l)

Generate the initial states of the layer l.

Lux.statelength (Function)
statelength(l)

Return the total number of states of the layer l.

Lux.testmode (Function)
testmode(st::NamedTuple)

Set all occurrences of training in the state st to Val(false).

Lux.trainmode (Function)
trainmode(st::NamedTuple)

Set all occurrences of training in the state st to Val(true).

Lux.update_state (Function)
update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key))

Recursively update all occurrences of key in the state st with value.
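
The sketch below shows a simplified, hypothetical update_state that walks a nested NamedTuple and replaces every field named key, and expresses testmode and trainmode in terms of it with the key :training. It omits the layer_check keyword, whose default (_default_layer_check) is internal and not described here, so it approximates the documented behaviour rather than reproducing Lux's implementation.

```julia
# Hypothetical, simplified update_state (no layer_check keyword).
function update_state(st::NamedTuple, key::Symbol, value)
    names = keys(st)
    vals = map(names) do k
        v = st[k]
        if k === key
            value                          # replace the matching field
        elseif v isa NamedTuple
            update_state(v, key, value)    # recurse into nested states
        else
            v                              # leave everything else untouched
        end
    end
    return NamedTuple{names}(vals)
end

# testmode/trainmode as thin wrappers over update_state.
testmode(st::NamedTuple) = update_state(st, :training, Val(false))
trainmode(st::NamedTuple) = update_state(st, :training, Val(true))

# A nested state like one a model with a normalization layer might carry.
st = (layer_1 = NamedTuple(),
      layer_2 = (running_mean = zeros(2), training = Val(true)),
      layer_3 = (inner = (training = Val(true),),))
st_test = testmode(st)
```

Using Val(true)/Val(false) rather than plain Bool puts the mode in the type of the state, so code can dispatch on it.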
