FluxTraining.CallbackType
abstract type Callback

Supertype of all callbacks. Callbacks add custom functionality to the training loop by hooking into different events (see Events.Event).

Any Callback can be used by passing it to Learner. See subtypes(FluxTraining.Callback) for implementations.

Extending

See Custom callbacks for a less succinct tutorial format.

  1. Create a struct MyCallback that subtypes FluxTraining.Callback.

  2. Add event handlers by implementing methods for on(event, phase, callback, learner). Methods should always dispatch on your callback, and may dispatch on specific phases (Phases.Phase) and events (Events.Event).

    For example, to implement an event handler that runs at the end of every step during training: on(::StepEnd, ::AbstractTrainingPhase, ::MyCallback, learner).

  3. Define what state the callback accesses and/or modifies by implementing stateaccess(::MyCallback). While learner is always passed as an argument to on event handlers, by default a callback cannot read or write to any of the learner's fields. See stateaccess for more detail.

    If a callback needs to write some state that other callbacks should be able to access, it can store it in learner.cbstate if you add a permission in stateaccess.

  4. If the callback needs some one-time initialization, you can implement init! which will be run at least once before any step is run.
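The first three steps above might look like the following minimal sketch. LossPrinter is a hypothetical callback name. The sketch assumes that the event and phase types are reachable as FluxTraining.Events.StepEnd and FluxTraining.Phases.AbstractTrainingPhase and that the read permission is reachable as FluxTraining.Read; check these paths against your installed version.

```julia
using FluxTraining

# Hypothetical callback that prints the loss after every training step.
struct LossPrinter <: FluxTraining.Callback end

# Step 3: declare read access to learner.step.loss and nothing else.
# Note the trailing commas: both levels must be NamedTuples.
FluxTraining.stateaccess(::LossPrinter) = (step = (loss = FluxTraining.Read(),),)

# Step 2: event handler that runs at the end of every step during training.
# The module paths Events and Phases are assumptions, see the note above.
function FluxTraining.on(
        ::FluxTraining.Events.StepEnd,
        ::FluxTraining.Phases.AbstractTrainingPhase,
        ::LossPrinter,
        learner)
    println("step loss: ", learner.step.loss)
end
```

Such a callback would then be passed to Learner through the callbacks keyword argument.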

FluxTraining.CancelEpochExceptionType
CancelEpochException(message)

Throw during fitting to cancel the currently running epoch. This prematurely ends the current epoch without throwing an error. Must be thrown inside the context of runepoch.

Examples

runepoch(learner, phase) do _
    for batch in batches
        step!(learner, phase, batch)
        if learner.step.loss < 1.
            throw(CancelEpochException("Reached target loss"))
        end
    end
end
FluxTraining.CancelStepExceptionType
CancelStepException(message)

Throw during fitting to cancel the currently running step. This prematurely ends the current step without throwing an error. Must be thrown inside the context of runstep.

Examples

runepoch(learner, phase) do _
    for (xs, ys) in batches
        runstep(learner, phase, (; xs, ys)) do _, state
            # training logic...
            if isnan(state.loss)
                throw(CancelStepException("Skipping NaN loss"))
            end
        end
    end
end
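
Both cancellation exceptions rely on the same control-flow pattern: the surrounding context catches one specific exception type and treats it as a normal early exit, while any other error still propagates. The sketch below illustrates that pattern in plain Julia. CancelSketch and run_cancellable are hypothetical names, and this is not how runepoch or runstep are implemented.

```julia
# Hypothetical exception type used only to signal "stop early, not an error".
struct CancelSketch <: Exception
    msg::String
end

# Run `f`; swallow CancelSketch, but let every other exception through.
function run_cancellable(f)
    try
        f()
        return :finished
    catch e
        e isa CancelSketch || rethrow()
        return :cancelled
    end
end

result = run_cancellable() do
    for loss in (2.0, 1.5, 0.9, 0.4)
        # Leave the loop as soon as the target loss is reached.
        loss < 1 && throw(CancelSketch("Reached target loss"))
    end
end
println(result)  # prints "cancelled"
```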
FluxTraining.CustomCallbackType
CustomCallback(f, Event, [TPhase = Phase, access = (;)])

A callback that runs f(learner) every time an event of type Event is triggered during a phase of type TPhase.

If f needs to access learner state, pass access, a named tuple in the same form as stateaccess.

Instead of using CustomCallback it is recommended to properly implement a Callback.

Examples

We can get a quick idea of when a new epoch starts as follows:

cb = CustomCallback(learner -> println("New epoch!"), EpochBegin)
FluxTraining.EarlyStoppingType
EarlyStopping(criteria...; kwargs...)
EarlyStopping(n)

Stop training early when criteria are met. See EarlyStopping.jl for available stopping criteria.

Passing an integer n uses the simple patience criterion: stop if the validation loss hasn't decreased for n epochs.

You can control which phases are taken to measure the out-of-sample loss and the training loss with keyword arguments trainphase (default AbstractTrainingPhase) and testphase (default AbstractValidationPhase).
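
As a rough illustration of the patience idea, the sketch below scans a sequence of per-epoch validation losses and reports the epoch at which training would stop. It reads "hasn't decreased for n epochs" as "no new best loss for n consecutive epochs", which is an assumption; the criterion in EarlyStopping.jl may count differently. stop_epoch is a hypothetical helper, not part of either package.

```julia
# Return the epoch at which a patience-n rule would stop, or nothing
# if the losses never go n consecutive epochs without a new best value.
function stop_epoch(losses, n)
    best = Inf
    stale = 0              # epochs since the last improvement
    for (epoch, loss) in enumerate(losses)
        if loss < best
            best = loss
            stale = 0
        else
            stale += 1
        end
        stale >= n && return epoch
    end
    return nothing
end

losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.7]
println(stop_epoch(losses, 3))  # prints "5"
```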

Examples

Learner(model, lossfn, callbacks=[EarlyStopping(3)])

import FluxTraining.ES: Disjunction, InvalidValue, TimeLimit

callback = EarlyStopping(Disjunction(InvalidValue(), TimeLimit(0.5)))
Learner(model, lossfn, callbacks=[callback])
FluxTraining.HyperParameterType
HyperParameter{T}

A hyperparameter is any state that influences the training and is not a parameter of the model.

Hyperparameters can be scheduled using the Scheduler callback.

FluxTraining.LearnerMethod
Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])

Holds and coordinates all state of the training. model is trained by optimizing lossfn with optimizer on data.

Arguments

Positional arguments:

  • model: A Flux.jl model or a NamedTuple of models.
  • lossfn: Loss function with signature lossfn(model(x), y) -> Number.

Keyword arguments (optional):

  • data = (): Data iterators. A 2-tuple will be treated as (trainingdataiter, validdataiter). You can also pass in an empty tuple () and use the epoch! method with a dataiter as third argument.

    A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple (xs, ys).

  • optimizer = ADAM(): The optimizer used to update the model's weights.

  • callbacks = []: A list of callbacks that should be used. If usedefaultcallbacks == true, this will be extended by the default callbacks.

  • usedefaultcallbacks = true: Whether to add some basic callbacks. Included are Metrics, Recorder, ProgressPrinter, StopOnNaNLoss, and MetricsPrinter.

  • cbrunner = LinearRunner(): Callback runner to use.

Fields

(Use this as a reference when implementing callbacks)

  • model, optimizer, and lossfn are stored as passed in
  • data is a PropDict of data iterators, usually :training and :validation.
  • params: An instance of model's parameters of type Flux.Params. If model is a NamedTuple, then params is a NamedTuple as well.
  • step::PropDict: State of the last step. Contents depend on the last run Phase.
  • cbstate::PropDict: Special state container that callbacks can save state to for other callbacks. Its keys depend on what callbacks are being used. See the custom callbacks guide for more info.
FluxTraining.LogHistogramsType
LogHistograms(backends...[; freq = 100]) <: Callback

Callback that logs histograms of model weights to LoggerBackends backends every freq steps.

If histograms should be logged every step, pass freq = nothing.

FluxTraining.LogVisualizationType
LogVisualization(visfn, backends...[; freq = 100])

Logs images created by visfn(learner.step) to backends every freq steps.

FluxTraining.MetricMethod
Metric(metricfn[; statistic, device, name, phase])

Implementation of AbstractMetric that can be used with the Metrics callback.

Arguments

Positional:

  • metricfn(ŷs, ys) should return a number.

Keyword:

  • statistic is an OnlineStats.Statistic that is updated after every step. The default is OnlineStats.Mean().
  • name is used for printing.
  • device is a function applied to ŷs and ys before passing them to metricfn. The default is Flux.cpu so that the callback works if metricfn doesn't support arrays from other device types. If, for example, metricfn works on CuArrays, you can pass device = Flux.gpu.
  • phase = Phase: a (sub)type of Phase that restricts for which phases the metric is computed.
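
To make the interplay of metricfn, device and statistic concrete, here is a minimal sketch in plain Julia. RunningMean and update! are hypothetical stand-ins for OnlineStats.Mean() and its update, and identity stands in for a device function such as Flux.cpu; none of this is the Metric implementation itself.

```julia
# Hypothetical running-mean statistic, standing in for OnlineStats.Mean().
mutable struct RunningMean
    n::Int
    value::Float64
end
RunningMean() = RunningMean(0, 0.0)

# Incremental mean update, so earlier values need not be stored.
function update!(m::RunningMean, x)
    m.n += 1
    m.value += (x - m.value) / m.n
    return m
end

# A metric function in the documented form metricfn(ŷs, ys) -> Number.
mse(ŷs, ys) = sum(abs2, ŷs .- ys) / length(ys)

device = identity   # placeholder for a function such as Flux.cpu
stat = RunningMean()
batches = [([1.0, 2.0], [1.0, 2.0]), ([0.0, 0.0], [2.0, 2.0])]
for (ŷs, ys) in batches
    # Move both arrays with `device`, compute the metric, update the statistic.
    update!(stat, mse(device(ŷs), device(ys)))
end
println(stat.value)  # prints "2.0"
```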

Examples

  • Metric(accuracy)
  • Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
  • Metric(Flux.mae, device = gpu)
cb = Metric(Flux.mse, device = gpu, name = "Mean Squared Error")

If a metric is expensive to compute and you don't want it to slow down the training phase, you can compute it on the validation phase only:

cb = Metric(expensivemetric, phase = ValidationPhase)
FluxTraining.MetricsType
Metrics(metrics...) <: Callback

Callback that tracks metrics during training.

You can pass any number of metrics, with every argument being either an AbstractMetric such as Metric or a metric function of the form metricfn(ŷs, ys).

This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false. A metric tracking the loss computed by learner.lossfn is included by default.

The computed metrics can be accessed in learner.cbstate.metricsstep and learner.cbstate.metricsepoch for steps and epochs, respectively.

Examples

Track accuracy:

cb = Metrics(accuracy)

Pass in Metric instances:

cb = Metrics(
    Metric(Flux.mse, device = gpu),
    Metric(Flux.mae, device = gpu)
)
FluxTraining.MetricsPrinterType
MetricsPrinter() <: Callback

Callback that prints metrics after every epoch. Relies on the metrics computed by Metrics, so will error if no Metrics callback is used.

This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false.

FluxTraining.NoConflictType
abstract type NoConflict <: ConflictResolution

Return from resolveconflict to indicate that, while the callbacks modify the same state, they can be used together without any problems.

FluxTraining.NotDefinedType
abstract type NotDefined <: ConflictResolution

The default implementation of resolveconflict. If a conflict is detected, this ensures an error message is printed.

FluxTraining.PropDictType
PropDict(dict)

Like a Dict{Symbol}, but attribute syntax can be used to access values.
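
The idea of attribute syntax over a symbol-keyed dictionary can be sketched in a few lines of plain Julia. MiniPropDict is a hypothetical stand-in written for illustration and is not FluxTraining's PropDict implementation.

```julia
# Hypothetical wrapper around a Dict{Symbol,Any}.
struct MiniPropDict
    d::Dict{Symbol,Any}
end

# Route `p.key` to the wrapped dictionary. getfield is used for the real
# field so that these methods do not call themselves recursively.
Base.getproperty(p::MiniPropDict, key::Symbol) = getfield(p, :d)[key]
Base.setproperty!(p::MiniPropDict, key::Symbol, value) = (getfield(p, :d)[key] = value)
Base.propertynames(p::MiniPropDict) = Tuple(keys(getfield(p, :d)))

state = MiniPropDict(Dict{Symbol,Any}(:loss => 0.5))
state.epoch = 3          # stored as d[:epoch]
println(state.loss)      # prints "0.5"
println(state.epoch)     # prints "3"
```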

FluxTraining.RunFirstType
abstract type RunFirst <: ConflictResolution

Return RunFirst(cb1/cb2) from resolveconflict(cb1, cb2) to indicate that one of the callbacks should always run before the other.

FluxTraining.SanityCheckType
SanityCheck([checks; usedefault = true])

Callback that runs sanity Checks when the Learner is initialized. If usedefault is true, it will run all checks in FluxTraining.CHECKS in addition to the ones you pass in.

FluxTraining.StopOnNaNLossType
StopOnNaNLoss()

Stops the training when a NaN loss is encountered.

This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false.

FluxTraining.ToDeviceType
ToDevice(movedatafn, movemodelfn) <: Callback

Moves model and step data to a device using movedatafn for step data and movemodelfn for the model. For example, ToDevice(Flux.gpu, Flux.gpu) moves them to a GPU if available. See ToGPU.

By default, only moves step.xs and step.ys, but this can be extended to other state by implementing on(::StepBegin, ::MyCustomPhase, ::ToDevice, learner).

FluxTraining.TracesType
Traces(preprocess[, phase])

Record a trace during phase by applying each pre-processing function in preprocess to the Learner to produce a trace value. The trace is recorded at the end of each learning step.

See LogTraces for logging of the trace values.

using Statistics: mean
using LinearAlgebra: norm

cb = Traces((loss2 = learner -> learner.step.loss^2,
             avg_gnorm = learner -> mean(map(((_, g),) -> norm(g), pairs(learner.step.grads)))),
            TrainingPhase)
FluxTraining.GarbageCollectFunction
GarbageCollect(nsteps)

Every nsteps steps, forces garbage collection. Use this if you get memory leaks from, for example, parallel data loading.

Performs an additional C-call on Linux systems that can sometimes help.

FluxTraining.ToGPUMethod
ToGPU()

Callback that moves model and batch data to the GPU during training. Convenience for ToDevice(Flux.gpu, Flux.gpu).

FluxTraining.accessesFunction
accesses()

Enumerate all valid state accesses of permissions of kind perm.

accesses((x = Read(),), Read()) == [(:x,)]
accesses((x = Read(),), Write()) == []

FluxTraining.addcallback!Method
addcallback!(learner, callback)

Adds callback to learner and updates the dependency graph.

FluxTraining.callbackgraphMethod
callbackgraph(callbacks) -> SimpleDiGraph

Creates a directed acyclic graph from a list of callbacks. Ordering is given through runafter and resolveconflict.

If a write conflict cannot be resolved (i.e. resolveconflict is not implemented), throws an error.

FluxTraining.edgesrunafterMethod
edgesrunafter(callbacks)

Return a vector of Edges representing dependencies defined by runafter.

FluxTraining.epoch!Function
epoch!(learner, phase[, dataiter])

Train learner for one epoch on dataiter. Iterates through dataiter and step!s for each batch/item.

If no data iterator is passed in, use learner.data[phasedataiter(phase)].

Extending

The default implementation iterates over every batch in dataiter and calls step! for each. This behavior can be overloaded by implementing epoch!(learner, ::MyPhase, dataiter).

If you're implementing a custom epoch! method, it is recommended you make use of runepoch to get begin and end events as well as proper handling of CancelEpochExceptions.

See the default implementation for reference.

FluxTraining.fit!Method
fit!(learner, nepochs)
fit!(learner, nepochs, (trainiter, validiter))

Train learner for nepochs of training and validation each. Use data iterators that are passed in. If none are given, use learner.data.training and learner.data.validation.

Examples

fit!(learner, 10)
fit!(learner, 10, (traindl, valdl))
FluxTraining.getcallbackMethod
getcallback(learner, C)

Find callback of type C in learner's callbacks and return it. If there is none, return nothing.

FluxTraining.init!Method
init!(callback, learner)

Initialize a callback. Default is to do nothing.

Extending

To extend for a callback, implement init!(cb::MyCallback, learner). init! can set up internal state of a callback that depends on learner and can also initialize shared callback state in learner.cbstate. Just like on event handlers, the state access permissions must be correctly defined using stateaccess to do so.

init! must also be idempotent, i.e. running it twice on the same Learner should have the same effect as running it once.

FluxTraining.iterpairsMethod
iterpairs(a)

Iterates over the Cartesian product of a with itself, skipping any pairs (x, y) where x == y.
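
One way to picture this behavior is a filtered generator over two loops. iterpairs_sketch is a hypothetical name, and the iteration order of the real function may differ from this sketch.

```julia
# Lazily yield every ordered pair of elements of `a`, skipping equal elements.
iterpairs_sketch(a) = ((x, y) for x in a, y in a if x != y)

pairs_found = collect(iterpairs_sketch([1, 2, 3]))
println(length(pairs_found))  # prints "6": 9 pairs minus the 3 with equal elements
```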

FluxTraining.log_toMethod
log_to(backend, loggable, group, i)
log_to(backends, loggable, group, i)

Log loggable to backend under group at index i.

  • loggable is any Loggables.Loggable
  • group can be a String or a tuple of Strings implying some grouping which can be used by a supporting backend.
  • i is a step counter and unique for every group.
FluxTraining.onMethod
on(event::Event, phase::Phase, callback::AbstractCallback, learner)

Handle event with Callback callback. By default, this event handler does nothing for a callback.

To see events which an AbstractCallback handles, use

methods(FluxTraining.on, (Any, Any, MyCallbackType, Any))

Extending

You can add event handlers to Callbacks by implementing a method for on. See also Callback and custom callbacks.

A method of on should always dispatch on the callback type, i.e. on(event, phase, cb::MyCallback, learner). It may also dispatch on specific Events and Phases. It should not dispatch on a specific type for learner.

FluxTraining.onecycleMethod
onecycle(nsteps, max_val, [start_val, end_val; pct_start])

Creates a one-cycle Schedule over nsteps steps from start_val over max_val to end_val.

Examples


epochlength = length(traindataiter)
cb = Scheduler(LearningRate => onecycle(10epochlength, 0.01))
learner = Learner(<args>...; callbacks=[cb])
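
To show the shape of such a schedule, the following sketch computes a one-cycle value for a given step using linear ramps. Both the linear interpolation and the pct_start = 0.25 default are assumptions made for illustration; the actual onecycle may use a different curve and different defaults. onecycle_sketch is a hypothetical name.

```julia
# Value of a one-cycle schedule at `step` (0-based), using linear ramps:
# start_val -> max_val over the first pct_start of the steps,
# then max_val -> end_val over the remainder.
function onecycle_sketch(step, nsteps, max_val, start_val, end_val; pct_start = 0.25)
    peak = pct_start * nsteps
    if step <= peak
        return start_val + (max_val - start_val) * step / peak
    else
        return max_val + (end_val - max_val) * (step - peak) / (nsteps - peak)
    end
end

# Sample the schedule at the start, the peak and the end of 100 steps.
for step in (0, 25, 100)
    println(step, " => ", onecycle_sketch(step, 100, 0.01, 0.001, 0.0001))
end
```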
FluxTraining.removecallback!Method
removecallback!(learner, C)

Remove the first callback of type C from learner and return it. If there is none, return nothing.

FluxTraining.replacecallback!Method
replacecallback!(learner, callback::C)

Replace existing callback of type C on learner with callback. Return the replaced callback.

If learner doesn't have a callback of type C, add callback and return nothing.
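
The lookup-by-type semantics described for getcallback and replacecallback! can be sketched on a plain vector. All names below (StubCallback, LoggerStub, SaverStub, find_callback, replace_callback!) are hypothetical, and the sketch ignores the dependency graph that the real functions maintain on a Learner.

```julia
abstract type StubCallback end
struct LoggerStub <: StubCallback
    name::String
end
struct SaverStub <: StubCallback end

# Return the first callback of type C, or nothing if there is none.
function find_callback(cbs, C)
    idx = findfirst(cb -> cb isa C, cbs)
    return idx === nothing ? nothing : cbs[idx]
end

# Replace the existing callback of the same type as `cb` and return the old
# one. If there is none, append `cb` and return nothing.
function replace_callback!(cbs, cb::C) where {C}
    idx = findfirst(x -> x isa C, cbs)
    if idx === nothing
        push!(cbs, cb)
        return nothing
    end
    old = cbs[idx]
    cbs[idx] = cb
    return old
end

cbs = StubCallback[LoggerStub("old")]
replaced = replace_callback!(cbs, LoggerStub("new"))  # returns LoggerStub("old")
added = replace_callback!(cbs, SaverStub())           # returns nothing, appends
```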

FluxTraining.resolveconflictMethod
resolveconflict(cb1, cb2)

Define a conflict resolution strategy for resolving a write/write conflict between two callbacks.

The default is NotDefined(), which will result in an error and a message to implement this method.

To implement, dispatch on the callback types that you wish to resolve (in any order) and return one of the following:

  • Unresolvable() if the callbacks must not be used together;
  • RunFirst(cb) if one of the callbacks needs to run first; or
  • NoConflict() if the callbacks may run together in any order
FluxTraining.runstepFunction
runstep(stepfn, learner, phase) -> state

Run stepfn inside the context of a step. Calls stepfn(handle, state) where handle(e) can be called to dispatch events and state is a PropDict which step data, gradients and losses can be written to. Return state.

Takes care of dispatching StepBegin and StepEnd events as well as handling CancelStepExceptions.

FluxTraining.runtestsMethod
FluxTraining.runtests(pattern...; kwargs...)

Equivalent to ReTest.retest(FluxTraining, pattern...; kwargs...). This function is defined automatically in any module containing a @testset, possibly nested within submodules.

FluxTraining.setcallbacks!Method
setcallbacks!(learner, callbacks)

Set learner's callbacks to callbacks, removing all current callbacks.

FluxTraining.sethyperparameter!Function
sethyperparameter!(learner, H, value) -> learner

Sets hyperparameter H to value on learner, returning the modified learner.

FluxTraining.stateaccessMethod
stateaccess(callback)

Return a named tuple determining what learner state callback can access. The default is (;), the empty named tuple, meaning no state can be accessed. Implementations of stateaccess should always return the least permissions possible.

Extending

For example, the ToGPU callback needs to write both the model and the batch data, so its stateaccess implementation is:

stateaccess(::ToGPU) = (
    model = Write(),
    params = Write(),
    step = (xs = Write(), ys = Write()),
)

When defining stateaccess, be careful that you do return a NamedTuple. (x = Read(),) is one but (x = Read()) (without the comma) is parsed as an assignment with value Read().
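
The difference is easy to check in plain Julia. ReadStub below is a hypothetical stand-in for the Read() permission so that the snippet runs without FluxTraining.

```julia
struct ReadStub end   # hypothetical stand-in for the Read() permission

with_comma = (x = ReadStub(),)     # one-element NamedTuple
without_comma = (x = ReadStub())   # assigns x and evaluates to ReadStub()

println(with_comma isa NamedTuple)     # prints "true"
println(without_comma isa NamedTuple)  # prints "false"
```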

FluxTraining.stateaccessMethod
stateaccess(::Type{HyperParameter})

Defines what Learner state is accessed when calling sethyperparameter! and gethyperparameter. This is needed so that Scheduler can access the state.

FluxTraining.step!Function
step!(learner, phase::Phase, batch)

Run one step of training for learner on batch. Behavior is customized through phase.

Extending

This is a required method for custom Phases to implement. To implement step!, it is recommended you make use of runstep to get begin and end events as well as proper handling of CancelStepExceptions.

See the implementations of TrainingPhase and ValidationPhase for reference.

FluxTraining.testlearnerMethod
testlearner(callbacks...[; opt, nbatches, coeff, batchsize, kwargs...])

Construct a Learner with a simple optimization problem. This learner should be used in tests that require training a model, e.g. for callbacks.

FluxTraining.throttleMethod
throttle(callback, Event; freq = 1)
throttle(callback, Event; seconds = 1)

Throttle Event type for callback so that it is triggered either only every freq'th time or every seconds seconds.

Examples

If you want to only sporadically log metrics (LogMetrics) or images (LogVisualization), throttle can be used as follows.

Every 10 steps:

callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, freq = 10)
learner = Learner(<args>; callbacks=[callback])

Or every 5 seconds:

callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, seconds = 5)
learner = Learner(<args>; callbacks=[callback])
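
The two throttling modes can be pictured as wrappers around an arbitrary function. The sketch below is plain Julia with hypothetical helper names (every_nth, at_most_every); it does not show how throttle is implemented, and whether the real callback fires on the first or on the freq'th event is not specified here.

```julia
# Wrap `f` so that it only runs on every `freq`-th call.
function every_nth(f, freq::Integer)
    calls = 0
    return function (args...)
        calls += 1
        return calls % freq == 0 ? f(args...) : nothing
    end
end

# Wrap `f` so that it runs at most once every `seconds` seconds.
function at_most_every(f, seconds::Real)
    lastrun = -Inf
    return function (args...)
        t = time()
        if t - lastrun >= seconds
            lastrun = t
            return f(args...)
        end
        return nothing
    end
end

hits = Int[]
logstep = every_nth(i -> push!(hits, i), 10)
foreach(logstep, 1:35)
println(hits)  # prints "[10, 20, 30]"
```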