abstract type Callback

Supertype of all callbacks. Callbacks add custom functionality to the training loop by hooking into different Events.

Any Callback can be used by passing it to Learner. See subtypes(FluxTraining.Callback) for implementations.


See the Custom callbacks guide for a longer, tutorial-style introduction.

  1. Create a struct MyCallback that subtypes FluxTraining.Callback.

  2. Add event handlers by implementing methods for on(event, phase, callback, learner). Methods should always dispatch on your callback, and may dispatch on specific Phases and Events.

    For example, to implement an event handler that runs at the end of every step during training: on(::StepEnd, ::AbstractTrainingPhase, ::MyCallback, learner).

  3. Define what state the callback accesses and/or modifies by implementing stateaccess(::MyCallback). While learner is always passed as an argument to on event handlers, by default a callback cannot read or write any of its fields. See stateaccess for more detail.

    If a callback needs to write some state that other callbacks should be able to access, it can store it in learner.cbstate, provided a corresponding permission is added in stateaccess.

  4. If the callback needs some one-time initialization, implement init!, which is run at least once before any step is run.
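The dispatch pattern from steps 1 and 2 can be sketched without the package. This is a minimal, self-contained sketch under stated assumptions: Event, StepBegin, StepEnd, Phase, AbstractTrainingPhase, TrainingPhase, ValidationPhase, Callback and on are local stand-ins defined here (the real types live in FluxTraining, and real handlers extend FluxTraining.on), and StepCounter is a hypothetical callback. It skips stateaccess because the handler never touches learner.

```julia
# Local stand-ins for FluxTraining's Event, Phase and Callback types
# (hypothetical, defined here only so the sketch runs without the package).
abstract type Event end
struct StepBegin <: Event end
struct StepEnd <: Event end

abstract type Phase end
abstract type AbstractTrainingPhase <: Phase end
struct TrainingPhase <: AbstractTrainingPhase end
struct ValidationPhase <: Phase end

abstract type Callback end

# Fallback handler: a callback ignores every event it does not explicitly handle.
on(::Event, ::Phase, ::Callback, learner) = nothing

# Step 1: a hypothetical callback struct; mutable so it can keep a counter.
mutable struct StepCounter <: Callback
    n::Int
end

# Step 2: dispatch on the callback type, and here also on StepEnd and training phases.
function on(::StepEnd, ::AbstractTrainingPhase, cb::StepCounter, learner)
    cb.n += 1
    return nothing
end

cb = StepCounter(0)
learner = nothing                              # placeholder; the handler never reads it
on(StepEnd(), TrainingPhase(), cb, learner)    # handled: counter becomes 1
on(StepEnd(), ValidationPhase(), cb, learner)  # not a training phase: fallback no-op
on(StepBegin(), TrainingPhase(), cb, learner)  # different event: fallback no-op
println(cb.n)                                  # prints 1
```

Because the specific method is more specific in every argument than the fallback, Julia's dispatch picks it only for the exact event/phase combination it names.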


Throw a CancelEpochException during fitting to cancel the currently running epoch. This prematurely ends the current epoch without throwing an error. Must be thrown inside the context of runepoch.


runepoch(learner, phase) do _
    for batch in batches
        step!(learner, phase, batch)
        if learner.step.loss < 1.0
            throw(CancelEpochException("Reached target loss"))
        end
    end
end
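The cancellation pattern can be illustrated with plain Julia exceptions. In this sketch, CancelEpoch and runepoch_sketch are hypothetical stand-ins, not FluxTraining's CancelEpochException and runepoch; the only assumption is that the surrounding context catches the cancellation exception and ends the epoch quietly while letting other errors propagate. Unlike the real runepoch, the do-block here takes no handle argument.

```julia
# Hypothetical stand-in for FluxTraining's CancelEpochException.
struct CancelEpoch <: Exception
    msg::String
end

# Hypothetical stand-in for runepoch: run the epoch body and treat a
# CancelEpoch as an early, error-free end of the epoch.
function runepoch_sketch(body)
    try
        body()
    catch e
        e isa CancelEpoch || rethrow()   # other exceptions still propagate
        println("Epoch cancelled: ", e.msg)
    end
    return nothing
end

losses = [2.0, 1.5, 0.8, 0.5]   # made-up per-step losses
seen = Float64[]
runepoch_sketch() do
    for loss in losses
        push!(seen, loss)       # stands in for step!(learner, phase, batch)
        loss < 1.0 && throw(CancelEpoch("Reached target loss"))
    end
end
println(seen)                   # [2.0, 1.5, 0.8]: the last batch never ran
```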

Throw a CancelStepException during fitting to cancel the currently running step. This prematurely ends the current step without throwing an error. Must be thrown inside the context of runstep.


runepoch(learner, phase) do _
    for (xs, ys) in batches
        runstep(learner, phase, (; xs, ys)) do _, state
            # training logic...
            if isnan(state.loss)
                throw(CancelStepException("Skipping NaN loss"))
            end
        end
    end
end

CustomCallback(f, Event, [TPhase = Phase, access = (;)])

A callback that runs f(learner) every time an event of type Event occurs during a phase whose type is a subtype of TPhase.

If f needs to access learner state, pass access, a named tuple in the same form as stateaccess.

Instead of using CustomCallback, it is recommended to implement a proper Callback.


We can get a quick idea of when a new epoch starts as follows:

cb = CustomCallback(learner -> println("New epoch!"), EpochBegin)

EarlyStopping(criteria...; kwargs...)

Stop training early when criteria are met. See EarlyStopping.jl for available stopping criteria.

Passing an integer n uses the simple patience criterion: stop if the validation loss hasn't decreased for n epochs.

You can control which phases are taken to measure the out-of-sample loss and the training loss with keyword arguments trainphase (default AbstractTrainingPhase) and testphase (default AbstractValidationPhase).


Learner(model, lossfn; callbacks = [EarlyStopping(3)])

import FluxTraining.ES: Disjunction, InvalidValue, TimeLimit

callback = EarlyStopping(Disjunction(InvalidValue(), TimeLimit(0.5)))
Learner(model, lossfn; callbacks = [callback])
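As a rough illustration of the patience criterion, the hypothetical helper below scans a vector of per-epoch validation losses and reports the epoch at which training would stop. It assumes one specific reading of the rule (n consecutive epochs in which the loss does not decrease relative to the previous epoch); EarlyStopping.jl's own criteria may count differently, so treat this as a sketch rather than the library's logic.

```julia
# Hypothetical helper: one reading of the patience rule.
# Returns the (1-based) epoch at which training would stop, or `nothing`.
function patience_stop(losses::AbstractVector{<:Real}, n::Integer)
    streak = 0
    for i in 2:length(losses)
        # count consecutive epochs where the loss failed to decrease
        streak = losses[i] >= losses[i-1] ? streak + 1 : 0
        streak >= n && return i
    end
    return nothing
end

println(patience_stop([1.0, 0.8, 0.9, 0.95, 1.1], 3))  # 5
println(patience_stop([1.0, 0.8, 0.7], 3))             # nothing
```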

A hyperparameter is any state that influences the training and is not a parameter of the model.

Hyperparameters can be scheduled using the Scheduler callback.

Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])

Holds and coordinates all state of the training. model is trained by optimizing lossfn with optimizer on data.


Positional arguments:

  • model: A Flux.jl model or a NamedTuple of models.
  • lossfn: Loss function with signature lossfn(model(x), y) -> Number.

Keyword arguments (optional):

  • data = (): Data iterators. A 2-tuple will be treated as (trainingdataiter, validdataiter). You can also pass in an empty tuple () and use the epoch! method with a dataiter as third argument.

    A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple (xs, ys).

  • optimizer = ADAM(): The optimizer used to update the model's weights.

  • callbacks = []: A list of callbacks that should be used. If usedefaultcallbacks == true, this will be extended by the default callbacks.

  • usedefaultcallbacks = true: Whether to add some basic callbacks. Included are Metrics, Recorder, ProgressPrinter, StopOnNaNLoss, and MetricsPrinter.

  • cbrunner = LinearRunner(): Callback runner to use.
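A data iterator can be as simple as a lazy generator over index ranges. The sketch below is an assumption-laden example, not a FluxTraining utility: makebatches is a hypothetical helper that slices column-major arrays (one sample per column) into (xs, ys) tuples using only Base.

```julia
# Toy data: 10 samples with 3 features each, stored one sample per column.
xs = rand(Float32, 3, 10)
ys = rand(Float32, 1, 10)

# Hypothetical helper: a lazy iterable of (xs, ys) batch tuples of size `bs`.
makebatches(xs, ys, bs) =
    ((xs[:, idx], ys[:, idx]) for idx in Iterators.partition(1:size(xs, 2), bs))

traindata = makebatches(xs, ys, 4)
for (bx, by) in traindata
    println(size(bx), " ", size(by))   # two full batches of 4, then a batch of 2
end
```

An iterable like traindata, together with a validation iterator built the same way, could then be passed to Learner as a 2-tuple via the data keyword.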


Fields of Learner (use this as a reference when implementing callbacks):

  • model, optimizer, and lossfn are stored as passed in
  • data is a PropDict of data iterators, usually :training and :validation.
  • params: An instance of model's parameters of type Flux.Params. If model is a NamedTuple, then params is a NamedTuple as well.
  • step::PropDict: State of the last step. Contents depend on the last run Phase.
  • cbstate::PropDict: Special state container that callbacks can save state to for other callbacks. Its keys depend on what callbacks are being used. See the custom callbacks guide for more info.
LogHistograms(backends...[; freq = 100]) <: Callback

Callback that logs histograms of model weights to LoggerBackends backends every freq steps.

If histograms should be logged every step, pass freq = nothing.

LogVisualization(visfn, backends...[; freq = 100])

Logs images created by visfn(learner.step) to backends every freq steps.

Metric(metricfn[; statistic, device, phase, name])

Implementation of AbstractMetric that can be used with the Metrics callback.



  • metricfn(ŷs, ys) should return a number.


  • statistic is a OnlineStats.Statistic that is updated after every step. The default is OnlineStats.Mean()
  • name is used for printing.
  • device is a function applied to ŷs and ys before passing them to metricfn. The default is Flux.cpu so that the callback works if metricfn doesn't support arrays from other device types. If, for example, metricfn works on CuArrays, you can pass device = Flux.gpu.
  • phase = Phase: a (sub)type of Phase that restricts for which phases the metric is computed.
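To make the per-step statistic concrete, the sketch below mimics what a mean statistic does with metric values: each step's metric is folded into a running average. RunningMean, update! and mse are hypothetical local definitions standing in for OnlineStats.Mean and Flux.mse, so the example runs with Base Julia only.

```julia
# Hypothetical running mean, standing in for OnlineStats.Mean().
mutable struct RunningMean
    n::Int
    μ::Float64
end
RunningMean() = RunningMean(0, 0.0)

# Fold one new observation into the mean (incremental update).
function update!(m::RunningMean, x::Real)
    m.n += 1
    m.μ += (x - m.μ) / m.n
    return m
end

# Hypothetical metric function with the metricfn(ŷs, ys) signature.
mse(ŷs, ys) = sum(abs2, ŷs .- ys) / length(ys)

meanstat = RunningMean()
steps = [([1.0, 2.0], [1.0, 1.0]),   # (ŷs, ys) for step 1: mse = 0.5
         ([0.0, 0.0], [1.0, 1.0])]   # (ŷs, ys) for step 2: mse = 1.0
for (ŷs, ys) in steps
    update!(meanstat, mse(ŷs, ys))  # update after every step
end
println(meanstat.μ)                 # 0.75
```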


  • Metric(accuracy)
  • Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
  • Metric(Flux.mae, device = gpu)
cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

If a metric is expensive to compute and you don't want it to slow down the training phase, you can compute it on the validation phase only:

cb = Metric(expensivemetric, phase = ValidationPhase)

Metrics(metrics...) <: Callback

Callback that tracks metrics during training.

You can pass any number of metrics, each argument being either an AbstractMetric such as Metric, or a function f(ŷs, ys) that returns a number.

This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false. A Loss metric tracking learner.lossfn is included by default.

The computed metrics can be accessed in learner.cbstate.metricsstep and learner.cbstate.metricsepoch for steps and epochs, respectively.


Track accuracy:

cb = Metrics(accuracy)

Pass in Metric objects:

cb = Metrics(
    Metric(Flux.mse, device = gpu),
    Metric(Flux.mae, device = gpu)
)

MetricsPrinter() <: Callback

Callback that prints metrics after every epoch. It relies on the metrics computed by Metrics, so it will error if no Metrics callback is used.

This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false.

abstract type NoConflict <: ConflictResolution

Return from resolveconflict to indicate that, while the callbacks modify the same state, they can be used together without any problems.

abstract type NotDefined <: ConflictResolution

The default return value of resolveconflict. If a conflict is detected, this ensures an error message is printed.


A PropDict is like a Dict{Symbol}, but its values can also be accessed with attribute syntax (d.key instead of d[:key]).
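A minimal sketch of the attribute-syntax idea, assuming a Dict{Symbol,Any} as backing storage. MiniPropDict is a hypothetical type, not FluxTraining's PropDict, and supports only the operations defined here.

```julia
# Hypothetical minimal PropDict-like type backed by a Dict{Symbol,Any}.
struct MiniPropDict
    d::Dict{Symbol,Any}
end
MiniPropDict() = MiniPropDict(Dict{Symbol,Any}())

# Attribute syntax: p.key reads and p.key = v writes the underlying Dict.
# getfield is needed because getproperty itself is overridden.
Base.getproperty(p::MiniPropDict, k::Symbol) = getfield(p, :d)[k]
Base.setproperty!(p::MiniPropDict, k::Symbol, v) = (getfield(p, :d)[k] = v)

# Dict-style access keeps working alongside attribute syntax.
Base.getindex(p::MiniPropDict, k::Symbol) = getfield(p, :d)[k]
Base.haskey(p::MiniPropDict, k::Symbol) = haskey(getfield(p, :d), k)

state = MiniPropDict()
state.loss = 0.25             # stored under :loss
println(state[:loss])         # 0.25
println(haskey(state, :xs))   # false
```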

abstract type RunFirst <: ConflictResolution

Return RunFirst(cb1) or RunFirst(cb2) from resolveconflict(cb1, cb2) to indicate that the given callback should always run before the other.

SanityCheck([checks; usedefault = true])

Callback that runs sanity Checks when the Learner is initialized. If usedefault is true, it will run all checks in FluxTraining.CHECKS in addition to the ones you pass in.


StopOnNaNLoss stops the training when a NaN loss is encountered.

This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false.

ToDevice(movedatafn, movemodelfn) <: Callback

Moves model and step data to a device using movedatafn for step data and movemodelfn for the model. For example, ToDevice(Flux.gpu, Flux.gpu) moves them to a GPU if available. See ToGPU.

By default, only moves step.xs and step.ys, but this can be extended to other state by implementing on(::StepBegin, ::MyCustomPhase, ::ToDevice, learner).

Traces(preprocess[, phase])

Record a trace during phase by applying each preprocessing function in preprocess to the Learner to produce a trace value. The trace is recorded at the end of each learning step.

See LogTraces for logging of the trace values.

using LinearAlgebra: norm
using Statistics: mean

cb = Traces((loss2 = learner -> learner.step.loss^2,
             avg_gnorm = learner -> mean(norm(g) for (_, g) in pairs(learner.step.grads))))

Every nsteps steps, forces garbage collection. Use this if you get memory leaks from, for example, parallel data loading.

Performs an additional C-call on Linux systems that can sometimes help.


ToGPU is a callback that moves model and batch data to the GPU during training. Convenience for ToDevice(Flux.gpu, Flux.gpu).


Enumerate all valid state accesses (as field paths) whose permission is of kind perm.

accesses((x = Read(),), Read()) == [(:x,)]
accesses((x = Read(),), Write()) == []
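The enumeration can be sketched as a recursive walk over a nested named tuple that collects field paths. Read, Write and accesses_sketch are local, hypothetical definitions; the sketch assumes a field matches only when its permission equals perm exactly, which reproduces the two examples above but may differ from the library in other cases.

```julia
# Hypothetical stand-ins for FluxTraining's permission types.
struct Read end
struct Write end

# Walk a (possibly nested) NamedTuple and collect the field paths whose
# permission equals `perm`. Nested NamedTuples extend the path prefix.
function accesses_sketch(perms::NamedTuple, perm, prefix::Tuple = ())
    paths = Tuple[]
    for (k, v) in pairs(perms)
        if v isa NamedTuple
            append!(paths, accesses_sketch(v, perm, (prefix..., k)))
        elseif v == perm
            push!(paths, (prefix..., k))
        end
    end
    return paths
end

perms = (model = Write(), step = (xs = Read(), ys = Read()))
println(accesses_sketch(perms, Read()))   # paths (:step, :xs) and (:step, :ys)
println(accesses_sketch(perms, Write()))  # path (:model,)
```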

addcallback!(learner, callback)

Adds callback to learner and updates the dependency graph.

callbackgraph(callbacks) -> SimpleDiGraph

Creates a directed acyclic graph from a list of callbacks. Ordering is given through runafter and resolveconflict.

If a write conflict cannot be resolved (i.e. resolveconflict is not implemented), throws an error.
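Once dependencies are known, a run order follows from a topological sort of the graph. The sketch below uses Kahn's algorithm on plain vectors rather than Graphs.jl's SimpleDiGraph; callbacks are numbered 1..n, and an edge a => b (the edges here are hypothetical) means a must run before b. A cycle means no valid order exists, so the sketch throws.

```julia
# Kahn's algorithm: order nodes 1..n so every edge a => b has a before b.
function toposort(n::Int, edges::Vector{Pair{Int,Int}})
    indeg = zeros(Int, n)                 # number of unmet dependencies per node
    out = [Int[] for _ in 1:n]            # adjacency lists
    for (a, b) in edges
        push!(out[a], b)
        indeg[b] += 1
    end
    queue = [i for i in 1:n if indeg[i] == 0]
    order = Int[]
    while !isempty(queue)
        v = popfirst!(queue)
        push!(order, v)
        for w in out[v]
            indeg[w] -= 1
            indeg[w] == 0 && push!(queue, w)
        end
    end
    # If some node never reached in-degree 0, the graph has a cycle.
    length(order) == n || error("dependency graph contains a cycle")
    return order
end

cbnames = ["Metrics", "MetricsPrinter", "ToGPU"]
deps = [3 => 1, 1 => 2]   # hypothetical: ToGPU before Metrics, Metrics before MetricsPrinter
order = toposort(length(cbnames), deps)
println(cbnames[order])   # ["ToGPU", "Metrics", "MetricsPrinter"]
```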


Return a vector of Edges representing dependencies defined by runafter.

epoch!(learner, phase[, dataiter])

Train learner for one epoch on dataiter. Iterates through dataiter and step!s for each batch/item.

If no data iterator is passed in, use learner.data[phasedataiter(phase)].


The default implementation iterates over every batch in dataiter and calls step! for each. This behavior can be overloaded by implementing epoch!(learner, ::MyPhase, dataiter).

If you're implementing a custom epoch! method, it is recommended you make use of runepoch to get begin and end events as well as proper handling of CancelEpochExceptions.

See the default implementation for reference.

fit!(learner, nepochs)
fit!(learner, nepochs, (trainiter, validiter))

Train learner for nepochs epochs, running one training and one validation epoch each time. If data iterators are passed in, use them; otherwise use learner.data.training and learner.data.validation.


fit!(learner, 10)
fit!(learner, 10, (traindl, valdl))

getcallback(learner, C)

Find callback of type C in learner's callbacks and return it. If there is none, return nothing.

init!(callback, learner)

Initialize a callback. Default is to do nothing.


To extend for a callback, implement init!(cb::MyCallback, learner). init! can set up internal state of a callback that depends on learner and can also initialize shared callback state in learner.cbstate. Just like on event handlers, the state access permissions must be correctly defined using stateaccess to do so.

init! must also be idempotent, i.e. running it twice on the same Learner should have the same effect as running it once.
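One way to keep init! idempotent is to create shared state only when it is missing. The sketch below uses Base's get!, with a plain Dict standing in for learner.cbstate; init_sketch! and the :history key are hypothetical.

```julia
# A plain Dict stands in for learner.cbstate (hypothetical).
cbstate = Dict{Symbol,Any}()

# Hypothetical init!-style function: get! only inserts the default when the
# key is missing, so repeated calls leave existing state untouched.
function init_sketch!(cbstate::Dict{Symbol,Any})
    get!(cbstate, :history, Float64[])
    return cbstate
end

init_sketch!(cbstate)
push!(cbstate[:history], 1.0)   # some state recorded during training
init_sketch!(cbstate)           # second call does not reset :history
println(cbstate[:history])      # [1.0]
```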


Iterates over the Cartesian product of a with itself, skipping any pairs (x, y) where x == y.
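A one-line generator is enough to express this. iterpairs_sketch is a hypothetical name, and the sketch assumes element equality (==) is what decides which pairs are skipped.

```julia
# Hypothetical sketch: all ordered pairs (x, y) from a × a with x != y.
iterpairs_sketch(a) = ((x, y) for (x, y) in Iterators.product(a, a) if x != y)

ps = collect(iterpairs_sketch([:a, :b, :c]))
println(length(ps))   # 6: the 9 pairs of the product minus the 3 with x == y
```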

log_to(backend, loggable, group, i)
log_to(backends, loggable, group, i)

Log loggable to backend (or to each of backends) under group at index i.

  • loggable is any Loggables.Loggable.
  • group can be a String or a tuple of Strings implying some grouping which can be used by a supporting backend.
  • i is a step counter and unique for every group.
on(event::Event, phase::Phase, callback::AbstractCallback, learner)

Handle event with Callback callback. By default, this event handler does nothing for a callback.

To see events which an AbstractCallback handles, use

methods(FluxTraining.on, (Any, Any, MyCallbackType, Any))


You can add event handlers to Callbacks by implementing a method for on. See also Callback and custom callbacks.

A method of on should always dispatch on the callback type, i.e. on(event, phase, cb::MyCallback, learner). It may also dispatch on specific Events and Phases. It should not dispatch on a specific type for learner.

onecycle(nsteps, max_val, [start_val, end_val; pct_start])

Creates a one-cycle Schedule over nsteps steps that rises from start_val to max_val and then decays to end_val.
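The shape of a one-cycle schedule can be sketched as a function of the step index. This is an illustration under assumptions, not FluxTraining's onecycle: it uses cosine interpolation and assumes defaults start_val = max_val / 25, end_val = max_val / 1e5 and pct_start = 0.25, which may not match the library's defaults.

```julia
# Cosine interpolation from a (t = 0) to b (t = 1).
cosinterp(a, b, t) = b + (a - b) * (1 + cos(π * t)) / 2

# Hypothetical one-cycle schedule: warm up from start_val to max_val over the
# first pct_start fraction of steps, then anneal down to end_val.
function onecycle_sketch(i, nsteps, max_val;
                         start_val = max_val / 25, end_val = max_val / 1e5,
                         pct_start = 0.25)
    warmup = pct_start * nsteps
    if i <= warmup
        return cosinterp(start_val, max_val, i / warmup)
    else
        return cosinterp(max_val, end_val, (i - warmup) / (nsteps - warmup))
    end
end

vals = [onecycle_sketch(i, 100, 0.01) for i in 0:100]
println(argmax(vals) - 1)   # 25: the peak is at the end of the warm-up
```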


epochlength = length(traindataiter)
cb = Scheduler(LearningRate => onecycle(10epochlength, 0.01))
learner = Learner(<args>...; callbacks=[cb])

removecallback!(learner, C)

Remove the first callback of type C from learner and return it. If there is none, return nothing.

replacecallback!(learner, callback::C)

Replace existing callback of type C on learner with callback. Return the replaced callback.

If learner doesn't have a callback of type C, add callback and return nothing.

resolveconflict(cb1, cb2)

Define a conflict resolution strategy for resolving a write/write conflict between two callbacks.

The default is NotDefined(), which will result in an error and a message to implement this method.

To implement, dispatch on the callback types that you wish to resolve (in any order) and return one of the following:

  • Unresolvable() if the callbacks must not be used together;
  • RunFirst(cb) if one of the callbacks needs to run first; or
  • NoConflict() if the callbacks may run together in any order.

runstep(stepfn, learner, phase) -> state

Run stepfn inside the context of a step. Calls stepfn(handle, state), where handle(e) can be called to dispatch events and state is a PropDict to which step data, gradients and losses can be written. Return state.

Takes care of dispatching StepBegin and StepEnd events as well as handling CancelStepExceptions.

FluxTraining.runtests(pattern...; kwargs...)

Equivalent to ReTest.retest(FluxTraining, pattern...; kwargs...). This function is defined automatically in any module containing a @testset, possibly nested within submodules.

setcallbacks!(learner, callbacks)

Set learner's callbacks to callbacks, removing all current callbacks.

sethyperparameter!(learner, H, value) -> learner

Sets hyperparameter H to value on learner, returning the modified learner.


stateaccess(callback) returns a named tuple determining what learner state callback can access. The default is (;), the empty named tuple, meaning no state can be accessed. Implementations of stateaccess should always return the least permissions possible.


For example, the ToGPU callback needs to write both the model and the batch data, so its stateaccess implementation is:

stateaccess(::ToGPU) = (
    model = Write(),
    params = Write(),
    step = (xs = Write(), ys = Write()),
)

When defining stateaccess, make sure that you actually return a NamedTuple: (x = Read(),) is one, but (x = Read()) (without the comma) is parsed as an assignment with value Read().
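The comma pitfall is easy to check in plain Julia; this snippet uses integers instead of Read() so it runs without FluxTraining.

```julia
nt = (x = 1,)      # trailing comma: a NamedTuple with one field, x
notnt = (x = 1)    # no comma: a parenthesized assignment, so notnt == 1
println(nt isa NamedTuple)      # true
println(notnt isa NamedTuple)   # false
```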


Defines what Learner state is accessed when calling sethyperparameter! and gethyperparameter. This is needed so that Scheduler can access the state.

step!(learner, phase::Phase, batch)

Run one step of training for learner on batch. Behavior is customized through phase.


This is a required method for custom Phases to implement. To implement step!, it is recommended you make use of runstep to get begin and end events as well as proper handling of CancelStepExceptions.

See the implementations of TrainingPhase and ValidationPhase for reference.

testlearner(callbacks...[; opt, nbatches, coeff, batchsize, kwargs...])

Construct a Learner with a simple optimization problem. This learner should be used in tests that require training a model, e.g. for callbacks.

throttle(callback, Event, freq = 1)
throttle(callback, Event, seconds = 1)

Throttle the Event type for callback so that it is triggered either only every freq-th time or at most once every seconds seconds.


If you only want to log metrics (LogMetrics) or images (LogVisualization) sporadically, throttle can be used as follows.

Every 10 steps:

callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, freq = 10)
learner = Learner(<args>; callbacks=[callback])

Or every 5 seconds:

callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, seconds = 5)
learner = Learner(<args>; callbacks=[callback])
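The freq variant of throttling boils down to a call counter. The sketch below wraps any function so that only every freq-th call goes through; Throttled is a hypothetical type, not FluxTraining's implementation, and the time-based seconds variant is not shown.

```julia
# Hypothetical frequency-based throttle: forwards only every freq-th call.
mutable struct Throttled{F}
    f::F
    freq::Int
    count::Int
end
Throttled(f, freq::Int) = Throttled(f, freq, 0)

function (t::Throttled)(args...)
    t.count += 1
    if t.count % t.freq == 0
        t.f(args...)
    end
    return nothing
end

calls = Int[]
logstep = Throttled(i -> push!(calls, i), 3)   # stands in for a logging callback
for i in 1:10
    logstep(i)
end
println(calls)   # [3, 6, 9]
```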