Lux
The 🔥 Deep Learning Framework
Installation
] add Lux
Getting Started
using Lux, Random, Optimisers, Zygote

# --- Reproducibility ----------------------------------------------------------
# Every random draw below goes through this one RNG, seeded so runs repeat.
rng = Random.default_rng()
Random.seed!(rng, 0)

# --- Model definition ---------------------------------------------------------
# BatchNorm over the 128 inputs, a 128 -> 256 tanh layer, another BatchNorm,
# then a nested Chain that goes 256 -> 1 (tanh) -> 10.
model = Chain(BatchNorm(128), Dense(128, 256, tanh), BatchNorm(256),
              Chain(Dense(256, 1, tanh), Dense(1, 10)))

# --- Parameters and state -----------------------------------------------------
# Lux keeps parameters (`ps`) and layer state (`st`) outside the model.
# `Lux.setup` creates both, and each one is then moved with `gpu`.
ps, st = map(gpu, Lux.setup(rng, model))

# --- Dummy input --------------------------------------------------------------
# A random 128×2 Float32 array, moved to the same device as the parameters.
x = gpu(rand(rng, Float32, 128, 2))

# --- Forward pass -------------------------------------------------------------
# `Lux.apply` returns the output together with the updated state.
y, st = Lux.apply(model, x, ps, st)

# --- Gradients ----------------------------------------------------------------
# Differentiate the summed model output with respect to the parameters.
# `gradient` returns one entry per argument, so keep the first (and only) one.
gs = first(gradient(p -> sum(first(Lux.apply(model, x, p, st))), ps))

# --- Optimization -------------------------------------------------------------
# Create the optimiser state for `ps`, then apply a single update step.
st_opt = Optimisers.setup(Optimisers.ADAM(0.0001), ps)
st_opt, ps = Optimisers.update(st_opt, ps, gs)
Citation
If you find this library useful in your academic work, please cite:
@misc{pal2022lux,
author = {Pal, Avik},
title = {Lux: Explicit Parameterization of Deep Neural Networks in Julia},
year = {2022},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/avik-pal/Lux.jl/}}
}
Also consider starring our GitHub repository.