FluxNLPModels.jl Tutorial

Setting up

This step-by-step example assumes prior knowledge of Julia and Flux.jl. See the Julia tutorial and the Flux.jl tutorial for more details.

We have aligned this tutorial with the MLP_MNIST example and reused some of its functions.

What we cover in this tutorial

We will cover the following:

  • Define a Neural Network (NN) model in Flux
    • Fully connected model
  • Define or set the loss function
  • Data loading
    • MNIST
    • Divide the data into train and test sets
  • Define a method for calculating accuracy and loss
  • Transfer the NN model to FluxNLPModel
  • Use FluxNLPModels to access
    • the gradient at the current weights
    • the objective (or loss) evaluated at the current weights

Packages needed

using FluxNLPModels
using Flux, NLPModels
using Flux.Data: DataLoader
using Flux: onehotbatch, onecold
using Flux.Losses: logitcrossentropy
using MLDatasets
using JSOSolvers

Setting up the Neural Network (NN) model

First, an NN model needs to be defined in Flux.jl. Our model is very simple: it consists of one "hidden layer" with 32 "neurons", each connected to every input pixel. Each neuron has a ReLU nonlinearity and is connected to every "neuron" in the output layer. Finally, a softmax turns the output scores into probabilities, i.e., positive numbers that add up to 1; here the softmax is applied implicitly by the logitcrossentropy loss used below, so the model itself returns raw scores (logits).

One can create a method that returns the model. This method can encapsulate the specific architecture and parameters of the model, making it easier to reuse and manage. It provides a convenient way to define and initialize the model when needed.

function build_model(; imgsize = (28, 28, 1), nclasses = 10)
  return Chain(Dense(prod(imgsize), 32, relu), Dense(32, nclasses))
end
build_model (generic function with 1 method)
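
As a quick check of the returned architecture (a small sketch, not part of the original example), the chain maps a batch of flattened 28 × 28 images to 10 raw scores per sample:

m = build_model()
x = rand(Float32, 28 * 28, 3)  # a dummy batch of 3 flattened images
size(m(x))                     # (10, 3): one raw score per class and per sample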

Loss function

We can define any loss function that we need; here we use Flux's built-in logitcrossentropy function.

## Loss function
const loss = Flux.logitcrossentropy
logitcrossentropy (generic function with 1 method)
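
As an illustration (a small sketch with made-up logits and labels), logitcrossentropy compares raw model outputs to one-hot labels, applying softmax internally:

ŷ = randn(Float32, 10, 5)            # raw scores (logits) for 5 samples
y = onehotbatch(rand(0:9, 5), 0:9)   # one-hot encoded labels
loss(ŷ, y)                           # mean cross-entropy over the batch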

Load datasets and define minibatches

In this section, we will cover the process of loading datasets and defining minibatches for training your model using Flux. Loading and preprocessing data is an essential step in machine learning, as it allows you to train your model on real-world examples.

We will specifically focus on loading the MNIST dataset. We will divide the data into training and testing sets, ensuring that we have separate data for model training and evaluation.

Additionally, we will define minibatches, which are subsets of the dataset that are used during the training process. Minibatches enable efficient training by processing a small batch of examples at a time, instead of the entire dataset. This technique helps in managing memory resources and improving convergence speed.

function getdata(bs)
  ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"

  # Loading Dataset
  xtrain, ytrain = MLDatasets.MNIST(Tx = Float32, split = :train)[:]
  xtest, ytest = MLDatasets.MNIST(Tx = Float32, split = :test)[:]

  # Reshape Data in order to flatten each image into a linear array
  xtrain = Flux.flatten(xtrain)
  xtest = Flux.flatten(xtest)

  # One-hot-encode the labels
  ytrain, ytest = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)

  # Create DataLoaders (mini-batch iterators)
  train_loader = DataLoader((xtrain, ytrain), batchsize = bs, shuffle = true)
  test_loader = DataLoader((xtest, ytest), batchsize = bs)

  return train_loader, test_loader
end
getdata (generic function with 1 method)
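
A quick way to inspect what the loaders yield (a sketch, assuming a batch size of 128): each training minibatch is a pair of flattened images and one-hot labels.

train_loader, test_loader = getdata(128)
x1, y1 = first(train_loader)
size(x1), size(y1)   # ((784, 128), (10, 128))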

Transferring to FluxNLPModels

  device = cpu
  train_loader, test_loader = getdata(128)

  ## Construct model
  model = build_model() |> device

  # Now we wrap the Flux model in a FluxNLPModel
  nlp = FluxNLPModel(model, train_loader, test_loader; loss_f = loss)
FluxNLPModel{Float32, Vector{Float32}, typeof(Flux.Losses.logitcrossentropy)}
  Problem name: Generic
   All variables: ████████████████████ 25450  All constraints: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
            free: ████████████████████ 25450             free: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
           lower: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                lower: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
           upper: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                upper: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
         low/upp: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0              low/upp: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
           fixed: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                fixed: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
          infeas: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0               infeas: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
            nnzh: (  0.00% sparsity)   323863975          linear: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
                                                    nonlinear: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
                                                         nnzj: (------% sparsity)         

  Counters:
             obj: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                 grad: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                 cons: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
        cons_lin: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0             cons_nln: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                 jcon: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
           jgrad: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                  jac: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0              jac_lin: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
         jac_nln: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                jprod: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0            jprod_lin: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
       jprod_nln: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0               jtprod: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0           jtprod_lin: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
      jtprod_nln: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                 hess: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0                hprod: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     
           jhess: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0               jhprod: ⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅⋅ 0     

Tools associated with a FluxNLPModel

The problem dimension n, where w ∈ ℝⁿ:

n = nlp.meta.nvar
25450
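
This count corresponds to the parameters of the two dense layers (a quick check):

# (784 × 32 weights + 32 biases) for the hidden layer, plus
# (32 × 10 weights + 10 biases) for the output layer
784 * 32 + 32 + 32 * 10 + 10   # 25450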

Get the current network weights:

w = nlp.w
25450-element Vector{Float32}:
  0.015744066
 -0.00038297143
 -0.023328608
 -0.0037001388
 -0.0018926343
  0.040415697
 -0.052152522
  0.07806019
  0.036663763
 -0.014904902
  ⋮
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0

Evaluate the loss function (i.e. the objective function) at w:

using NLPModels
NLPModels.obj(nlp, w)
2.4031613f0

The length of w must be nlp.meta.nvar.
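
Any vector of that length is a valid point; for instance (a sketch evaluating the objective at a random point, not part of the original tutorial):

w_rand = randn(Float32, nlp.meta.nvar)
NLPModels.obj(nlp, w_rand)   # objective on the current minibatch at w_rand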

Evaluate the gradient at w:

g = similar(w)
NLPModels.grad!(nlp, w, g)
25450-element Vector{Float32}:
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  0.0
  ⋮
 -0.099184796
  0.016802719
  0.013390605
  0.048606716
  0.01852965
  0.009330362
  0.03778672
 -0.07561049
  0.03411634
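
The objective and gradient can also be evaluated together with objgrad!, as in the standard NLPModels API (assuming it is available for this model), avoiding two separate passes:

fx, _ = NLPModels.objgrad!(nlp, w, g)   # fills g in place and returns (f, g)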

Train a neural network with JSOSolvers.R2

max_time = 60.0 # run at most 1 min

# The callback advances to the next training minibatch at every iteration,
# so each objective/gradient evaluation uses a fresh batch.
callback = (nlp, solver, stats) -> FluxNLPModels.minibatch_next_train!(nlp)

solver_stats = R2(nlp; callback, max_time)
test_accuracy = FluxNLPModels.accuracy(nlp) # check the accuracy
0.4367f0
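
The value returned by R2 is an execution-statistics object (from SolverCore); a few fields one would typically inspect (a sketch):

solver_stats.status      # why the solver stopped (e.g. :max_time or :first_order)
solver_stats.objective   # objective value at the final iterate
solver_stats.iter        # number of iterations performed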