EarlyStopping.jl

A small package for applying early stopping criteria to loss-generating iterative algorithms, with a view to training and optimizing machine learning models.

The basis of IterationControl.jl, a package for externally controlling iterative algorithms.

Includes the stopping criteria surveyed in Prechelt, Lutz (1998): "Early Stopping - But When?", in Neural Networks: Tricks of the Trade, ed. G. Orr, Springer.

Installation

using Pkg
Pkg.add("EarlyStopping")

Sample usage

The EarlyStopper objects defined in this package consume a sequence of numbers called losses generated by some external algorithm - generally the training loss or out-of-sample loss of some iterative statistical model - and decide when those losses have dropped sufficiently to warrant terminating the algorithm. A number of commonly applied stopping criteria, listed under Criteria below, are provided out-of-the-box.

Here's an example of using an EarlyStopper object to check two of these criteria simultaneously (a stop is triggered when either applies):

using EarlyStopping

stopper = EarlyStopper(Patience(2), InvalidValue()) # multiple criteria
done!(stopper, 0.123) # false
done!(stopper, 0.234) # false
done!(stopper, 0.345) # true

julia> message(stopper)
"Early stop triggered by Patience(2) stopping criterion. "

One may force an EarlyStopper to report its evolving state by specifying a positive verbosity level at construction:

losses = [10.0, 11.0, 10.0, 11.0, 12.0, 10.0];
stopper = EarlyStopper(Patience(2), verbosity=1);

for loss in losses
    done!(stopper, loss) && break
end
[ Info: loss: 10.0       state: (loss = 10.0, n_increases = 0)
[ Info: loss: 11.0       state: (loss = 11.0, n_increases = 1)
[ Info: loss: 10.0       state: (loss = 10.0, n_increases = 0)
[ Info: loss: 11.0       state: (loss = 11.0, n_increases = 1)
[ Info: loss: 12.0       state: (loss = 12.0, n_increases = 2)

The "object-oriented" interface demonstrated here is not code-optimized but will suffice for the majority of use-cases. For performant code, use the functional interface described under Implementing new criteria below.

Criteria

To list all stopping criteria, do subtypes(StoppingCriterion). Each subtype T has a detailed doc-string, which can be queried with ?T at the REPL. Here is a short summary:

criterion            | description                                         | notation in Prechelt
---------------------|-----------------------------------------------------|---------------------
Never()              | Never stop                                          |
InvalidValue()       | Stop when NaN, Inf or -Inf encountered              |
TimeLimit(t=0.5)     | Stop after t hours                                  |
NumberLimit(n=100)   | Stop after n loss updates (excl. "training losses") |
NumberSinceBest(n=6) | Stop when best loss occurred n loss updates ago     |
Threshold(value=0.0) | Stop when loss < value                              |
GL(alpha=2.0)        | Stop after "Generalization Loss" exceeds alpha      | GL_α
PQ(alpha=0.75, k=5)  | Stop after "Progress-modified GL" exceeds alpha     | PQ_α
Patience(n=5)        | Stop after n consecutive loss increases             | UP_s
Disjunction(c...)    | Stop when any of the criteria c apply               |
Warmup(c; n=1)       | Wait for n loss updates before checking criteria c  |
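
For example, based on the constructor signatures in this table, one can delay the application of a criterion until a few losses have been received by wrapping it in Warmup (a small sketch):

using EarlyStopping

# ignore the first two loss updates, then apply Patience(1):
stopper = EarlyStopper(Warmup(Patience(1), n=2))

for (i, loss) in enumerate([10.0, 11.0, 12.0, 13.0])
    done!(stopper, loss) && (@info "Stopped at update $i"; break)
end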

Criteria tracking both training and out-of-sample losses

For criteria tracking both an "out-of-sample" loss and a "training" loss (e.g., stopping criteria of type PQ), specify training=true if the update is for a training loss, as in

done!(stopper, 0.123, training=true)

In these cases, the out-of-sample update must always come after the corresponding training update. Multiple training updates may precede the out-of-sample update, as in the following example:

criterion = PQ(alpha=2.0, k=2)
needs_training_losses(criterion) # true

stopper = EarlyStopper(criterion)

done!(stopper, 9.5, training=true) # false
done!(stopper, 9.3, training=true) # false
done!(stopper, 10.0) # false

done!(stopper, 9.3, training=true) # false
done!(stopper, 9.1, training=true) # false
done!(stopper, 8.9, training=true) # false
done!(stopper, 8.0) # false

done!(stopper, 8.3, training=true) # false
done!(stopper, 8.4, training=true) # false
done!(stopper, 9.0) # true

Important. If there is no distinction between in-sample and out-of-sample losses, then any criterion can be applied, and in that case training=true is never specified (regardless of the actual interpretation of the losses being tracked).

Stopping times

To determine the stopping time for an iterator losses, use stopping_time(criterion, losses). This is useful for debugging new criteria (see below). If the iterator terminates without a stop, 0 is returned.

julia> stopping_time(InvalidValue(), [10.0, 3.0, Inf, 4.0])
3

julia> stopping_time(Patience(3), [10.0, 3.0, 4.0, 5.0], verbosity=1)
[ Info: loss updates: 1
[ Info: state: (loss = 10.0, n_increases = 0)
[ Info: loss updates: 2
[ Info: state: (loss = 3.0, n_increases = 0)
[ Info: loss updates: 3
[ Info: state: (loss = 4.0, n_increases = 1)
[ Info: loss updates: 4
[ Info: state: (loss = 5.0, n_increases = 2)
0

If the losses include both training and out-of-sample losses as described above, pass an extra Bool vector marking the training losses with true, as in

stopping_time(PQ(),
              [0.123, 0.321, 0.52, 0.55, 0.56, 0.58],
              [true, true, false, true, true, false])

Implementing new criteria

To implement a new stopping criterion, one must:

  • Define a new struct for the criterion, which must subtype StoppingCriterion.
  • Overload methods update and done for the new type, as in the following skeleton:
struct NewCriteria <: StoppingCriterion
    # Put relevant fields here
end

# Provide a default constructor with all keyword arguments
NewCriteria(; kwargs...) = ...

# Return the initial state of the NewCriteria after
# receiving an out-of-sample loss
update(c::NewCriteria, loss, ::Nothing) = ...

# Return an updated state for NewCriteria given a `loss`
# and the current `state`
update(c::NewCriteria, loss, state) = ...

# Return true if NewCriteria should stop given `state`.
# Always return false if `state === nothing`
done(c::NewCriteria, state) = state === nothing ? false : ...

Optionally, one may define the following:

  • Overload message to customize the final report.
  • Handle training losses by overloading update_training and the trait needs_training_losses.
# Final message when NewCriteria triggers a stop
message(c::NewCriteria, state) = ...

# Methods for initializing/updating the state given a training loss
update_training(c::NewCriteria, loss, ::Nothing) = ...
update_training(c::NewCriteria, loss, state) = ...

Wrappers. If your criterion wraps another criterion (as Warmup does), then the wrapped criterion must be stored in a field named criterion.
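
Here, for illustration only, is a hypothetical wrapper skeleton following this convention (the name SkipFirst and its fields are not part of the package):

using EarlyStopping

struct SkipFirst <: StoppingCriterion
    criterion::StoppingCriterion # the wrapped criterion, stored in `criterion`
    n::Int                       # number of initial loss updates to skip
end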

New Criteria Example

We demonstrate this with a simplified version of the code for Patience:

Defining the new type

using EarlyStopping

struct Patience <: StoppingCriterion
    n::Int
end
Patience(; n=5) = Patience(n)

Overloading update and done

All information to be "remembered" must be passed around in an object called state below, which is the return value of update (and update_training). The update function has two methods:

  • Initialization: update(c::NewCriteria, loss, ::Nothing)
  • Subsequent Loss Updates: update(c::NewCriteria, loss, state)

Here state is the return value of the previous call to update or update_training, and state === nothing indicates an uninitialized criterion.

import EarlyStopping: update, done

function update(criterion::Patience, loss, ::Nothing)
    return (loss=loss, n_increases=0) # state
end

function update(criterion::Patience, loss, state)
    old_loss, n = state
    if loss > old_loss
        n += 1
    else
        n = 0
    end
    return (loss=loss, n_increases=n) # state
end

The done method returns true or false depending on the state, but always returns false if state === nothing.

done(criterion::Patience, state) =
    state === nothing ? false : state.n_increases == criterion.n

Optional methods

The final message of an EarlyStopper is generated using a message method for StoppingCriterion. Here is the fallback (which does not use state):

EarlyStopping.message(criterion::StoppingCriterion, state) =
    "Early stop triggered by $criterion stopping criterion. "

The optional update_training methods (two for each criterion) have the same signature as the update methods above. Refer to the PQ code for an example.
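
As a rough sketch only (the actual PQ implementation differs), a criterion needing the most recent training loss might thread it through the state like this:

import EarlyStopping: update_training

# hypothetical: record the latest training loss in the state, so that
# `update` can compare it with the next out-of-sample loss
update_training(c::NewCriteria, loss, ::Nothing) = (training_loss=loss,)
update_training(c::NewCriteria, loss, state) = merge(state, (training_loss=loss,))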

If a stopping criterion requires one or more update_training calls per update call to work, you should overload the trait needs_training_losses for that type, as in this example from the source code:

EarlyStopping.needs_training_losses(::Type{<:PQ}) = true

Unit Testing

The following are provided to facilitate testing of new criteria:

  • stopping_time: returns the stopping time for an iterator losses under a given criterion (see Stopping times above).
  • @test_criteria NewCriteria(): runs a suite of unit tests against the provided StoppingCriterion. This macro is only part of the test suite and is not part of the API.