# ExplicitFluxLayers

Explicit parameterization of Flux layers.
## Installation

```julia
] add ExplicitFluxLayers
```
## Getting Started

```julia
using ExplicitFluxLayers, Random, Optimisers, Zygote

# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)

# Construct the model
model = EFL.Chain(
    EFL.BatchNorm(128),
    EFL.Dense(128, 256, tanh),
    EFL.BatchNorm(256),
    EFL.Chain(
        EFL.Dense(256, 1, tanh),
        EFL.Dense(1, 10)
    )
)

# Parameter and state variables
ps, st = EFL.setup(rng, model) .|> EFL.gpu

# Dummy input
x = rand(rng, Float32, 128, 2) |> EFL.gpu

# Run the model
y, st = EFL.apply(model, x, ps, st)

# Gradients
gs = gradient(p -> sum(EFL.apply(model, x, p, st)[1]), ps)[1]

# Optimization
st_opt = Optimisers.setup(Optimisers.ADAM(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
```
## Design Principles

- Layers must be immutable -- they cannot store any parameters/states; instead they store the information needed to construct them.
- Layers return a `Tuple` containing the result and the updated state.
- Layers are pure functions.
- Given the same inputs, the outputs must be the same -- stochasticity is controlled by seeds passed in the state variables.
- Easily extensible for custom layers. Each custom layer should be a subtype of either:
  1. `AbstractExplicitLayer`: useful for base layers. The following functions need to be defined:
     - `initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- returns a `ComponentArray`/`NamedTuple` containing the trainable parameters for the layer.
     - `initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` -- returns a `NamedTuple` containing the current state for the layer. For most layers this is empty; layers that need state include `BatchNorm`, recurrent neural networks, etc.
     - `parameterlength(layer::CustomAbstractExplicitLayer)` & `statelength(layer::CustomAbstractExplicitLayer)` -- these can be calculated automatically, but it is recommended that the user define them.
  2. `AbstractExplicitContainerLayer`: used when the layer stores other `AbstractExplicitLayer`s or `AbstractExplicitContainerLayer`s. This provides good default dispatches for the functions mentioned in the previous point.
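The custom-layer interface above can be sketched with a minimal example. This is a hedged illustration, not code from the package: the layer name `Scale` and its behavior are hypothetical, and it assumes the `EFL` alias and the `(layer)(x, ps, st)` calling convention used elsewhere in this README.

```julia
using ExplicitFluxLayers, Random

# Hypothetical custom layer: multiplies each feature by a learnable scale.
# It is immutable and stores only what is needed to construct parameters.
struct Scale <: EFL.AbstractExplicitLayer
    dims::Int
end

# Trainable parameters: a vector of per-feature scales, initialized to one.
EFL.initialparameters(rng::Random.AbstractRNG, l::Scale) = (scale=ones(Float32, l.dims),)

# This layer is stateless, so the state is an empty NamedTuple.
EFL.initialstates(rng::Random.AbstractRNG, l::Scale) = NamedTuple()

# Recommended (though derivable) length definitions.
EFL.parameterlength(l::Scale) = l.dims
EFL.statelength(l::Scale) = 0

# Forward pass: pure function returning (output, updated_state).
(l::Scale)(x, ps, st) = (ps.scale .* x, st)

rng = Random.default_rng()
model = Scale(4)
ps, st = EFL.setup(rng, model)
y, st = model(rand(rng, Float32, 4, 2), ps, st)
```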
## Why use ExplicitFluxLayers over Flux?

- Large Neural Networks
  - For small neural networks we recommend SimpleChains.jl.
  - For SciML applications (Neural ODEs, Deep Equilibrium Models) solvers typically expect a monolithic parameter vector. Flux enables this via its `destructure` mechanism; however, it often leads to subtle bugs. EFL forces users to make an explicit distinction between state variables and parameter variables to avoid these issues.
- Comes batteries-included for distributed training using FluxMPI.jl
- Sensible display of Custom Layers -- Ever wanted PyTorch-like network printouts, or wondered how to extend the pretty printing of Flux's layers? ExplicitFluxLayers handles all of that by default.
- Less Bug-ridden Code
- No arbitrary internal mutations -- all layers are implemented as pure functions.
- All layers are deterministic given the parameters and state -- if a layer is supposed to be stochastic (say `Dropout`), the state must contain a seed, which is then updated after the function call.
- Easy Parameter Manipulation -- Wondering why Flux doesn't have `WeightNorm`, `SpectralNorm`, etc.? The implicit parameter handling makes it extremely hard to pass parameters around without mutations, which AD systems don't like. With ExplicitFluxLayers, implementing them is straightforward.
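The determinism guarantee above can be sketched as follows. This is an illustrative example, assuming `Dropout(p)` takes a drop probability and that its state carries the RNG seed, as described in this README:

```julia
using ExplicitFluxLayers, Random

rng = Random.default_rng()
Random.seed!(rng, 0)

d = EFL.Dropout(0.5)          # stochastic layer
ps, st = EFL.setup(rng, d)
x = rand(rng, Float32, 4, 2)

# Same parameters and same state: the two calls must produce identical outputs,
# because the randomness lives entirely in `st`.
y1, st1 = EFL.apply(d, x, ps, st)
y2, _   = EFL.apply(d, x, ps, st)

# Passing the *updated* state `st1` advances the randomness, so a fresh
# dropout mask is drawn.
y3, _ = EFL.apply(d, x, ps, st1)
```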
## Usage Examples

- Differential Equations + Deep Learning
  - Neural ODEs for MNIST Image Classification -- example borrowed from DiffEqFlux.jl
  - Deep Equilibrium Models
- Image Classification
- Distributed Training using MPI.jl -- FluxMPI + FastDEQ
## Recommended Libraries for Various ML Tasks

ExplicitFluxLayers is exclusively focused on designing neural network architectures. All other parts of the DL training/evaluation pipeline should be offloaded to the following frameworks:
- Data Manipulation/Loading -- Augmentor.jl, DataLoaders.jl, Images.jl
- Optimisation -- Optimisers.jl, ParameterSchedulers.jl
- Automatic Differentiation -- Zygote.jl
- Parameter Manipulation -- Functors.jl
- Model Checkpointing -- Serialization.jl
- Activation Functions / Common Neural Network Primitives -- NNlib.jl
- Distributed Training -- FluxMPI.jl
- Training Visualization -- Wandb.jl
If you find other packages useful, please open a PR and add them to this list.
## Implemented Layers

We don't have a documentation page as of now, but all these layers have docstrings which can be accessed in the REPL help mode.

- Containers: `Chain`, `Parallel`, `SkipConnection`, `BranchLayer`, `PairwiseFusion`
- Dense: `Dense`, `Diagonal`
- Convolution/Pooling: `Conv`, `MaxPool`, `MeanPool`, `GlobalMaxPool`, `GlobalMeanPool`, `Upsample`, `AdaptiveMaxPool`, `AdaptiveMeanPool`
- Normalization: `BatchNorm`, `WeightNorm`, `GroupNorm`
- Reshaping/Wrapping: `ReshapeLayer`, `SelectDim`, `FlattenLayer`, `NoOpLayer`, `WrappedFunction`
- Dropout: `Dropout`, `VariationalHiddenDropout`
## TODOs

- Support Recurrent Neural Networks
- Add wider support for Flux layers
  - Convolution --> `ConvTranspose`, `CrossCor`
  - Upsampling --> `PixelShuffle`
  - General purpose --> `Maxout`, `Bilinear`, `Embedding`, `AlphaDropout`
  - Normalization --> `LayerNorm`, `InstanceNorm`
- Port tests over from Flux