FluxTraining.AbstractCallback — Type
abstract type AbstractCallback
Supertype of SafeCallback/Callback. When implementing callbacks, you should subtype SafeCallback instead.
FluxTraining.AbstractMetric — Type
abstract type AbstractMetric
Abstract type for metrics passed to Metrics.
For most use cases, you should use Metric, the standard implementation.
Interface
If Metric doesn't fit your use case, you can create a new subtype of AbstractMetric and implement the following methods to make it compatible with Metrics:
reset!(metric)
step!(metric, learner)
stepvalue(metric)
epochvalue(metric)
metricname(metric)
FluxTraining.Callback — Type
abstract type Callback
Supertype of all callbacks. Callbacks add custom functionality to the training loop by hooking into different Events.Events.
Any Callback can be used by passing it to Learner. See subtypes(FluxTraining.Callback) for implementations.
Extending
See Custom callbacks for a less succinct tutorial format.
- Create a struct MyCallback that subtypes FluxTraining.Callback.
- Add event handlers by implementing methods for on(event, phase, callback, learner). Methods should always dispatch on your callback, and may dispatch on specific Phases.Phases and Events.Events. For example, to implement an event handler that runs at the end of every step during training: on(::StepEnd, ::AbstractTrainingPhase, ::MyCallback, learner).
- Define what state the callback accesses and/or modifies by implementing stateaccess(::MyCallback). While learner is always passed as an argument to on event handlers, by default a callback cannot read or write to its fields. See stateaccess for more detail.
- If a callback needs to write some state that other callbacks should be able to access, it can store it in learner.cbstate if you add a permission in stateaccess.
- If the callback needs some one-time initialization, you can implement init!, which will be run at least once before any step is run. A minimal sketch combining these steps follows this list.
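The sketch below is illustrative only, assuming the interface described above; LossPrinter is a hypothetical name, not part of the library. It prints the loss after every training step.
import FluxTraining
import FluxTraining: Callback, Read
import FluxTraining.Events: StepEnd
import FluxTraining.Phases: AbstractTrainingPhase

struct LossPrinter <: Callback end

# Request read access to learner.step.loss; without this permission,
# accessing the field inside the handler would error.
FluxTraining.stateaccess(::LossPrinter) = (step = (loss = Read(),),)

# Runs at the end of every step during training phases.
function FluxTraining.on(::StepEnd, ::AbstractTrainingPhase, ::LossPrinter, learner)
    println("Step loss: ", learner.step.loss)
end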
FluxTraining.CallbackCondition — Type
abstract type CallbackCondition
Supertype for conditions to use with ConditionalCallback. To implement a CallbackCondition, implement shouldrun(::MyCondition, event, phase).
See FrequencyThrottle, TimeThrottle and throttle.
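As a hedged sketch of that interface (not a built-in condition): a condition that only lets the first n occurrences of an event through. FirstN is a hypothetical name, and shouldrun is extended via its qualified name on the assumption that it is not exported.
import FluxTraining
import FluxTraining: CallbackCondition

mutable struct FirstN <: CallbackCondition
    n::Int      # how many events to let through
    count::Int  # how many have been seen so far
end
FirstN(n) = FirstN(n, 0)

function FluxTraining.shouldrun(c::FirstN, event, phase)
    c.count += 1
    return c.count <= c.n
end
Use it by wrapping a callback, e.g. ConditionalCallback(LogMetrics(backend), FirstN(100)).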
FluxTraining.CancelEpochException — Type
CancelEpochException(message)
Throw during fitting to cancel the currently running epoch. This prematurely ends the current epoch without throwing an error. Must be thrown inside the context of runepoch.
Examples
runepoch(learner, phase) do _
    for batch in batches
        step!(learner, phase, batch)
        if learner.step.loss < 1.0
            throw(CancelEpochException("Reached target loss"))
        end
    end
end
FluxTraining.CancelFittingException — Type
CancelFittingException(msg)
Throw during fitting to cancel it.
FluxTraining.CancelStepException — Type
CancelStepException(message)
Throw during fitting to cancel the currently running step. This prematurely ends the current step without throwing an error. Must be thrown inside the context of runstep.
Examples
runepoch(learner, phase) do _
    for (xs, ys) in batches
        runstep(learner, phase, (; xs, ys)) do _, state
            # training logic...
            if isnan(state.loss)
                throw(CancelStepException("Skipping NaN loss"))
            end
        end
    end
end
FluxTraining.Checkpointer — Type
Checkpointer(folder[; keep_top_k])
Saves learner.model to folder after every AbstractTrainingPhase. If keep_top_k is provided, only the best k models (by smallest training loss) and the latest model are kept.
Use FluxTraining.loadmodel to load a model.
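A hedged usage sketch; the keyword name keep_top_k is taken from the description above:
checkpointer = Checkpointer("checkpoints/"; keep_top_k = 3)
learner = Learner(model, lossfn; callbacks = [checkpointer])
# later, restore a saved checkpoint from its path
model = FluxTraining.loadmodel(path)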
FluxTraining.ConditionalCallback — Type
ConditionalCallback(callback, condition) <: Callback
Wrapper callback that only forwards events to the wrapped callback if the CallbackCondition condition is met. See throttle.
FluxTraining.ConflictResolution — Type
abstract type ConflictResolution
A conflict resolution strategy for resolving write/write conflicts of two callbacks.
See resolveconflict.
FluxTraining.CustomCallback — Type
CustomCallback(f, Event, [TPhase = Phase, access = (;)])
A callback that runs f(learner) every time an event of type Event occurs during a phase of type TPhase.
If f needs to access learner state, pass access, a named tuple in the same form as stateaccess.
Instead of using CustomCallback, it is recommended to properly implement a Callback.
Examples
We can get a quick idea of when a new epoch starts as follows:
cb = CustomCallback(learner -> println("New epoch!"), EpochBegin)
FluxTraining.EarlyStopping — Type
EarlyStopping(criteria...; kwargs...)
EarlyStopping(n)
Stop training early when criteria are met. See EarlyStopping.jl for available stopping criteria.
Passing an integer n uses the simple patience criterion: stop if the validation loss hasn't decreased for n epochs.
You can control which phases are used to measure the training loss and the out-of-sample loss with the keyword arguments trainphase (default AbstractTrainingPhase) and testphase (default AbstractValidationPhase).
Examples
Learner(model, lossfn, callbacks=[EarlyStopping(3)])
import FluxTraining.ES: Disjunction, InvalidValue, TimeLimit
callback = EarlyStopping(Disjunction(InvalidValue(), TimeLimit(0.5)))
Learner(model, lossfn, callbacks=[callback])
FluxTraining.FitException — Type
abstract type FitException
Abstract type for exceptions that can be thrown during fitting to change its control flow.
See CancelStepException, CancelEpochException, CancelFittingException.
FluxTraining.HyperParameter — Type
HyperParameter{T}
A hyperparameter is any state that influences the training and is not a parameter of the model.
Hyperparameters can be scheduled using the Scheduler callback.
FluxTraining.Learner — Method
Learner(model, data, optimizer, lossfn, [callbacks...; kwargs...])
FluxTraining.Learner — Method
Learner(model, lossfn; [callbacks = [], optimizer = ADAM(), kwargs...])
Holds and coordinates all state of the training. model is trained by optimizing lossfn with optimizer on data.
Arguments
Positional arguments:
- model: A Flux.jl model or a NamedTuple of models.
- lossfn: Loss function with signature lossfn(model(x), y) -> Number.
Keyword arguments (optional):
- data = (): Data iterators. A 2-tuple will be treated as (trainingdataiter, validdataiter). You can also pass in an empty tuple () and use the epoch! method with a dataiter as third argument. A data iterator is an iterable over batches. For regular supervised training, each batch should be a tuple (xs, ys).
- optimizer = ADAM(): The optimizer used to update the model's weights.
- callbacks = []: A list of callbacks that should be used. If usedefaultcallbacks == true, this will be extended by the default callbacks.
- usedefaultcallbacks = true: Whether to add some basic callbacks. Included are Metrics, Recorder, ProgressPrinter, StopOnNaNLoss, and MetricsPrinter.
- cbrunner = LinearRunner(): Callback runner to use.
Fields
(Use this as a reference when implementing callbacks; a usage sketch follows the list.)
- model, optimizer, and lossfn are stored as passed in.
- data is a PropDict of data iterators, usually :training and :validation.
- params: An instance of model's parameters of type Flux.Params. If model is a NamedTuple, then params is a NamedTuple as well.
- step::PropDict: State of the last step. Contents depend on the last run Phase.
- cbstate::PropDict: Special state container that callbacks can save state to for other callbacks. Its keys depend on what callbacks are being used. See the custom callbacks guide for more info.
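A hedged usage sketch assembling a Learner from the arguments above; model, lossfn, traindl, and valdl are assumed to exist:
using Flux, FluxTraining
learner = Learner(model, lossfn;
    data = (traindl, valdl),
    optimizer = ADAM(0.001),
    callbacks = [Metrics(accuracy)])
fit!(learner, 5)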
FluxTraining.LearningRate — Type
abstract type LearningRate <: HyperParameter
Hyperparameter for the optimizer's learning rate.
See Scheduler and hyperparameter scheduling.
FluxTraining.LogHistograms — Type
LogHistograms(backends...[; freq = 100]) <: Callback
Callback that logs histograms of model weights to the LoggerBackends in backends every freq steps.
If histograms should be logged every step, pass freq = nothing.
FluxTraining.LogHyperParams — Type
LogHyperParams(backends...) <: Callback
Callback that logs hyperparameters to one or more LoggerBackends.
See also LoggerBackend, Loggables.Loggable, log_to, TensorBoardBackend.
Example
logcb = LogHyperParams(TensorBoardBackend("tblogs"))
schedule = ...
Learner(model, lossfn; callbacks=[Scheduler(LearningRate => schedule), logcb])
FluxTraining.LogMetrics — Type
LogMetrics(backends...) <: Callback
Callback that logs step and epoch metrics to one or more LoggerBackends.
See also LoggerBackend, Loggables.Loggable, log_to, TensorBoardBackend.
Example
logcb = LogMetrics(TensorBoardBackend("tblogs"))
Learner(model, lossfn; callbacks=[Metrics(accuracy), logcb])
FluxTraining.LogTraces — Type
LogTraces(backends...) <: Callback
Callback that logs step traces to one or more LoggerBackends.
See also LoggerBackend, Loggables.Loggable, log_to, TensorBoardBackend.
Example
logcb = LogTraces(TensorBoardBackend("tblogs"))
tracer = Traces((trace = learner -> learner.step.loss^2,), TrainingPhase)
Learner(model, lossfn; callbacks=[tracer, logcb])
FluxTraining.LogVisualization — Type
LogVisualization(visfn, backends...[; freq = 100])
Logs images created by visfn(learner.step) to backends every freq steps.
FluxTraining.LoggerBackend — Type
abstract type LoggerBackend
Backend for logging callbacks like LogMetrics and LogHyperParams.
To add support for logging a Loggables.Loggable L to a backend B, implement
log_to(backend::B, loggable::L, names, i)
See also LogMetrics, LogHyperParams, log_to.
FluxTraining.Metric — Method
Metric(metricfn[; statistic, device, name, phase])
Implementation of AbstractMetric that can be used with the Metrics callback.
Arguments
Positional:
- metricfn(ŷs, ys) should return a number.
Keyword:
- statistic is an OnlineStats.Statistic that is updated after every step. The default is OnlineStats.Mean().
- name is used for printing.
- device is a function applied to ŷs and ys before passing them to metricfn. The default is Flux.cpu so that the callback works if metricfn doesn't support arrays from other device types. If, for example, metricfn works on CuArrays, you can pass device = Flux.gpu.
- phase = Phase: a (sub)type of Phase that restricts for which phases the metric is computed.
Examples
Metric(accuracy)
Metric(Flux.mse, device = gpu, name = "Mean Squared Error")
Metric(Flux.mae, device = gpu)
If a metric is expensive to compute and you don't want it to slow down the training phase, you can compute it on the validation phase only:
cb = Metric(expensivemetric, phase = ValidationPhase)
FluxTraining.Metrics — Type
Metrics(metrics...) <: Callback
Callback that tracks metrics during training.
You can pass any number of metrics, with every argument being
- an AbstractMetric like Metric; or
- a function f(ŷs, ys) -> val
This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false. A Loss metric tracking learner.lossfn is included by default.
The computed metrics can be accessed in learner.cbstate.metricsstep and learner.cbstate.metricsepoch for steps and epochs, respectively.
Examples
Track accuracy:
cb = Metrics(accuracy)
Pass in Metric instances:
cb = Metrics(
    Metric(Flux.mse, device = gpu),
    Metric(Flux.mae, device = gpu)
)
FluxTraining.MetricsPrinter — Type
FluxTraining.NoConflict — Type
abstract type NoConflict <: ConflictResolution
Return from resolveconflict to indicate that, while the callbacks modify the same state, they can be used together without any problems.
FluxTraining.NotDefined — Type
abstract type NotDefined <: ConflictResolution
The default implementation of resolveconflict. If a conflict is detected, this ensures an error message is printed.
FluxTraining.ProgressPrinter — Type
ProgressPrinter()
Prints a progress bar of the currently running epoch.
FluxTraining.PropDict — Type
PropDict(dict)
Like a Dict{Symbol}, but attribute syntax can be used to access values.
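A hedged sketch of the attribute access described above (property assignment is an assumption, suggested by the step-state writes in runstep):
d = FluxTraining.PropDict(Dict{Symbol,Any}(:loss => 0.5))
d.loss       # 0.5, same as d[:loss]
d.epoch = 1  # assignment via attribute syntax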
FluxTraining.Recorder — Type
Recorder()
Maintains a History. It's stored in learner.cbstate.history.
FluxTraining.RunFirst — Type
abstract type RunFirst <: ConflictResolution
Return RunFirst(cb1/cb2) from resolveconflict(cb1, cb2) to indicate that one of the callbacks should always run before the other.
FluxTraining.SanityCheck — Type
SanityCheck([checks; usedefault = true])
Callback that runs sanity Checks when the Learner is initialized. If usedefault is true, it will run all checks in FluxTraining.CHECKS in addition to the ones you pass in.
FluxTraining.Scheduler — Type
Scheduler(schedules...)
Callback for hyperparameter scheduling. Takes pairs of HyperParameter types and ParameterSchedulers.jl schedules.
See the tutorial for more information.
Example
es = length(learner.data.training)
lrschedule = ParameterSchedulers.Step(; λ = 1.0, γ = 0.9, step_sizes = [10, 20])
scheduler = Scheduler(
    LearningRate => lrschedule
)
FluxTraining.StopOnNaNLoss — Type
StopOnNaNLoss()
Stops the training when a NaN loss is encountered.
This callback is added by default to every Learner unless you pass in usedefaultcallbacks = false.
FluxTraining.TensorBoardBackend — Type
TensorBoardBackend(logdir[, tb_overwrite];
    time = time(),
    purge_step = nothing,
    step_increment = 1,
    min_level = Logging.Info)
TensorBoard backend for logging callbacks. Takes the same arguments as TensorBoardLogger.TBLogger.
FluxTraining.ToDevice — Type
ToDevice(movedatafn, movemodelfn) <: Callback
Moves model and step data to a device, using movedatafn for step data and movemodelfn for the model. For example, ToDevice(Flux.gpu, Flux.gpu) moves them to a GPU if available. See ToGPU.
By default, only step.xs and step.ys are moved, but this can be extended to other state by implementing on(::StepBegin, ::MyCustomPhase, ::ToDevice, learner), as sketched below.
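A hedged sketch of that extension point: also move a step.weights field during a custom phase. MyCustomPhase and the weights field are hypothetical, and the field names movedatafn/movemodelfn are assumed from the constructor signature above.
import FluxTraining
import FluxTraining: ToDevice
import FluxTraining.Events: StepBegin

function FluxTraining.on(::StepBegin, ::MyCustomPhase, cb::ToDevice, learner)
    # move the standard fields plus the extra one
    learner.step.xs = cb.movedatafn(learner.step.xs)
    learner.step.ys = cb.movedatafn(learner.step.ys)
    learner.step.weights = cb.movedatafn(learner.step.weights)
end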
FluxTraining.Traces — Type
Traces(preprocess[, phase])
Record a trace during phase by applying each pre-processing function in preprocess to the Learner to produce a trace value. The trace is recorded at the end of each learning step.
See LogTraces for logging of the trace values.
cb = Traces((loss2 = learner -> learner.step.loss^2,
             avg_gnorm = learner -> mean(norm(g) for (_, g) in pairs(learner.step.grads))),
            TrainingPhase)
FluxTraining.Unresolvable — Type
abstract type Unresolvable <: ConflictResolution
Return from resolveconflict to indicate that two callbacks are incompatible and cannot be used together.
FluxTraining.GarbageCollect — Function
GarbageCollect(nsteps)
Forces garbage collection every nsteps steps. Use this if you get memory leaks from, for example, parallel data loading.
Performs an additional C-call on Linux systems that can sometimes help.
FluxTraining.ToGPU — Method
ToGPU()
Callback that moves model and batch data to the GPU during training. Convenience for ToDevice(Flux.gpu, Flux.gpu).
FluxTraining.accesses — Function
accesses(permissions, perm)
Enumerate all valid state accesses in permissions of kind perm.
accesses((x = Read(),), Read()) === [(:x,)]
accesses((x = Read(),), Write()) === []
FluxTraining.addcallback! — Method
addcallback!(learner, callback)
Adds callback to learner and updates the dependency graph.
FluxTraining.callbackgraph — Method
callbackgraph(callbacks) -> SimpleDiGraph
Creates a directed acyclic graph from a list of callbacks. Ordering is given through runafter and resolveconflict.
If a write conflict cannot be resolved (i.e. resolveconflict is not implemented), throws an error.
FluxTraining.edgesrunafter — Method
edgesrunafter(callbacks)
Return a vector of Edges representing dependencies defined by runafter.
FluxTraining.epoch! — Function
epoch!(learner, phase[, dataiter])
Train learner for one epoch on dataiter. Iterates through dataiter and calls step! for each batch/item.
If no data iterator is passed in, learner.data[phasedataiter(phase)] is used.
Extending
The default implementation iterates over every batch in dataiter and calls step! for each. This behavior can be overloaded by implementing epoch!(learner, ::MyPhase, dataiter).
If you're implementing a custom epoch! method, it is recommended you make use of runepoch to get begin and end events as well as proper handling of CancelEpochExceptions.
See the default implementation for reference.
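A hedged usage sketch; learner and traindl are assumed to exist:
epoch!(learner, TrainingPhase(), traindl)  # explicit data iterator
epoch!(learner, ValidationPhase())         # falls back to learner.data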
FluxTraining.fit! — Method
fit!(learner, nepochs)
fit!(learner, nepochs, (trainiter, validiter))
Train learner for nepochs of training and validation each, using the data iterators that are passed in. If none are given, learner.data.training and learner.data.validation are used.
Examples
fit!(learner, 10)
fit!(learner, 10, (traindl, valdl))
FluxTraining.getcallback — Method
getcallback(learner, C)
Find a callback of type C in learner's callbacks and return it. If there is none, return nothing.
FluxTraining.init! — Method
init!(callback, learner)
Initialize a callback. The default is to do nothing.
Extending
To extend for a callback, implement init!(cb::MyCallback, learner). init! can set up internal state of a callback that depends on learner and can also initialize shared callback state in learner.cbstate. Just like for on event handlers, the state access permissions must be correctly defined using stateaccess to do so.
init! must also be idempotent, i.e. running it twice on the same Learner should have the same effect as running it once.
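A hedged sketch of an idempotent init! that sets up shared state; MyLogger and the mylog key are hypothetical, and haskey is assumed to work on the cbstate PropDict:
import FluxTraining
import FluxTraining: Callback, Write

struct MyLogger <: Callback end

# Permission to create and write learner.cbstate.mylog.
FluxTraining.stateaccess(::MyLogger) = (cbstate = (mylog = Write(),),)

function FluxTraining.init!(cb::MyLogger, learner)
    # idempotent: only create the shared state if it doesn't exist yet
    if !haskey(learner.cbstate, :mylog)
        learner.cbstate.mylog = Float64[]
    end
end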
FluxTraining.iterpairs — Method
iterpairs(a)
Iterates over the Cartesian product of a with itself, skipping any pairs (a, b) where a == b.
FluxTraining.loadmodel — Method
loadmodel(path)
Loads a model that was saved to path using FluxTraining.savemodel.
FluxTraining.log_to — Method
log_to(backend, loggable, group, i)
log_to(backends, loggable, group, i)
Log loggable to backend with group at index i.
- loggable is any Loggables.Loggable.
- group can be a String or a tuple of Strings implying some grouping, which can be used by a supporting backend.
- i is a step counter and unique for every group.
A sketch of implementing log_to for a new backend follows.
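The sketch below is illustrative only: a backend that prints scalar values to stdout. PrintBackend is hypothetical, and a Loggables.Value type with a data field is assumed; check the actual Loggable subtypes before relying on this.
import FluxTraining
import FluxTraining: LoggerBackend, Loggables

struct PrintBackend <: LoggerBackend end

function FluxTraining.log_to(::PrintBackend, value::Loggables.Value, group, i)
    # group may be a String or a tuple of Strings
    name = group isa String ? group : join(group, "/")
    println(name, " @ step ", i, ": ", value.data)
end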
FluxTraining.on — Method
on(event::Event, phase::Phase, callback::AbstractCallback, learner)
Handle event with callback callback. By default, this event handler does nothing for a callback.
To see the events which an AbstractCallback handles, use
methods(FluxTraining.on, (Any, Any, MyCallbackType, Any))
Extending
You can add event handlers to Callbacks by implementing a method for on. See also Callback and custom callbacks.
A method of on should always dispatch on the callback type, i.e. on(event, phase, cb::MyCallback, learner). It may also dispatch on specific Events and Phases. It should not dispatch on a specific type for learner.
FluxTraining.onecycle — Method
onecycle(nsteps, max_val, [start_val, end_val; pct_start])
Creates a one-cycle schedule over nsteps steps, going from start_val over max_val to end_val.
Examples
epochlength = length(traindataiter)
cb = Scheduler(LearningRate => onecycle(10epochlength, 0.01))
learner = Learner(<args>...; callbacks=[cb])
FluxTraining.process_top_k_checkpoints — Method
Makes sure only the best k and the latest checkpoints are kept on disk.
FluxTraining.removecallback! — Method
removecallback!(learner, C)
Remove the first callback of type C from learner and return it. If there is none, return nothing.
FluxTraining.replacecallback! — Method
replacecallback!(learner, callback::C)
Replace an existing callback of type C on learner with callback. Return the replaced callback.
If learner doesn't have a callback of type C, add callback and return nothing.
FluxTraining.resolveconflict — Method
resolveconflict(cb1, cb2)
Define a conflict resolution strategy for resolving a write/write conflict between two callbacks.
The default is NotDefined(), which will result in an error and a message to implement this method.
To implement, dispatch on the callback types that you wish to resolve (in any order) and return one of the following:
- Unresolvable() if the callbacks must not be used together;
- RunFirst(cb) if one of the callbacks needs to run first; or
- NoConflict() if the callbacks may run together in any order.
A sketch follows this list.
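A hedged sketch resolving a conflict between two hypothetical callbacks that both write the model, where one should always run first:
import FluxTraining
import FluxTraining: Callback, RunFirst

struct MyPruner <: Callback end
struct MyQuantizer <: Callback end

# Prune before quantizing whenever both callbacks are used together.
FluxTraining.resolveconflict(p::MyPruner, ::MyQuantizer) = RunFirst(p)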
FluxTraining.runepoch — Method
runepoch(epochfn, learner, phase)
Run epochfn inside the context of an epoch. Calls epochfn(handle) where handle(e) can be called to dispatch events.
Takes care of dispatching EpochBegin and EpochEnd events as well as handling CancelEpochExceptions.
FluxTraining.runstep — Function
runstep(stepfn, learner, phase) -> state
Run stepfn inside the context of a step. Calls stepfn(handle, state) where handle(e) can be called to dispatch events and state is a PropDict which step data, gradients and losses can be written to. Returns state.
Takes care of dispatching StepBegin and StepEnd events as well as handling CancelStepExceptions.
FluxTraining.runtests — Method
FluxTraining.runtests(pattern...; kwargs...)
Equivalent to ReTest.retest(FluxTraining, pattern...; kwargs...). This function is defined automatically in any module containing a @testset, possibly nested within submodules.
FluxTraining.setcallbacks! — Method
setcallbacks!(learner, callbacks)
Set learner's callbacks to callbacks, removing all current callbacks.
FluxTraining.sethyperparameter! — Function
sethyperparameter!(learner, H, value) -> learner
Sets hyperparameter H to value on learner, returning the modified learner.
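A hedged usage sketch setting the learning rate hyperparameter directly:
sethyperparameter!(learner, LearningRate, 0.001)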
FluxTraining.stateaccess — Method
stateaccess(callback)
Return a named tuple determining what learner state callback can access. The default is (;), the empty named tuple, meaning no state can be accessed. Implementations of stateaccess should always return the least permissions possible.
Extending
For example, the ToGPU callback needs to write both the model and the batch data, so its stateaccess implementation is:
stateaccess(::ToGPU) = (
    model = Write(),
    params = Write(),
    step = (xs = Write(), ys = Write()),
)
When defining stateaccess, be careful that you do return a NamedTuple. (x = Read(),) is one, but (x = Read()) (without the comma) is parsed as an assignment with value Read().
FluxTraining.stateaccess — Method
stateaccess(::Type{HyperParameter})
Defines what Learner state is accessed when calling sethyperparameter! and gethyperparameter. This is needed so that Scheduler can access the state.
FluxTraining.step! — Function
step!(learner, phase::Phase, batch)
Run one step of training for learner on batch. Behavior is customized through phase.
Extending
This is a required method for custom Phases to implement. To implement step!, it is recommended you make use of runstep to get begin and end events as well as proper handling of CancelStepExceptions.
See the implementations of TrainingPhase and ValidationPhase for reference; a sketch of a validation-style step is given below.
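A hedged sketch of step! for a hypothetical validation-style phase, modeled on the ValidationPhase description later in this reference; MyEvalPhase is an assumption:
import FluxTraining
import FluxTraining: runstep
import FluxTraining.Events: LossBegin
import FluxTraining.Phases: AbstractValidationPhase

struct MyEvalPhase <: AbstractValidationPhase end

function FluxTraining.step!(learner, phase::MyEvalPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (; xs, ys)) do handle, state
        state.ŷs = learner.model(state.xs)  # forward pass
        handle(LossBegin())
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end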
FluxTraining.testlearner — Method
testlearner(callbacks...[; opt, nbatches, coeff, batchsize, kwargs...])
Construct a Learner with a simple optimization problem. This learner should be used in tests that require training a model, e.g. for callbacks.
FluxTraining.throttle — Method
throttle(callback, Event, freq = 1)
throttle(callback, Event, seconds = 1)
Throttle the Event type for callback so that it is triggered either only every freq'th time or only every seconds seconds.
Examples
If you want to only sporadically log metrics (LogMetrics) or images (LogVisualization), throttle can be used as follows.
Every 10 steps:
callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, freq = 10)
learner = Learner(<args>; callbacks=[callback])
Or every 5 seconds:
callback = throttle(LogMetrics(TensorBoardBackend()), StepEnd, seconds = 5)
learner = Learner(<args>; callbacks=[callback])
FluxTraining.Loggables — Module
module Loggables
Defines Loggables.Loggable and its subtypes.
FluxTraining.Loggables.Loggable — Type
abstract type Loggable
Abstract type for data that LoggerBackends can log. See subtypes(FluxTraining.Loggables.Loggable) and LoggerBackend.
FluxTraining.Phases.AbstractTrainingPhase — Type
abstract type AbstractTrainingPhase <: Phase
An abstract type for phases where parameter updates are being made. This exists so callbacks can dispatch on it and work with custom training phases.
The default implementation for supervised tasks is TrainingPhase.
FluxTraining.Phases.AbstractValidationPhase — Type
abstract type AbstractValidationPhase <: Phase
An abstract type for phases where no parameter updates are being made. This exists so callbacks can dispatch on it and work with custom validation phases.
The default implementation for supervised tasks is ValidationPhase.
FluxTraining.Phases.Phase — Type
FluxTraining.Phases.TrainingPhase — Type
TrainingPhase() <: AbstractTrainingPhase
A regular training phase for supervised learning. It iterates over batches in learner.data.training and updates the model parameters using learner.optim after calculating the gradients.
Throws the following events in this order:
- EpochBegin when an epoch starts,
- StepBegin when a step starts,
- LossBegin after the forward pass but before loss calculation,
- BackwardBegin after loss calculation but before the backward pass,
- BackwardEnd after the backward pass but before the optimization step,
- StepEnd when a step ends; and
- EpochEnd when an epoch ends.
It writes the following step state to learner.step, grouped by the event from which on it is available:
- StepBegin: xs and ys, the encoded input and target (batch)
- LossBegin: ŷs, the model output
- BackwardBegin: loss, the loss
- BackwardEnd: grads, the calculated gradients
FluxTraining.Phases.ValidationPhase — Type
ValidationPhase()
A regular validation phase. It iterates over batches in learner.data.validation and performs a forward pass.
Throws the following events in this order:
- EpochBegin when an epoch starts,
- StepBegin when a step starts,
- LossBegin after the forward pass but before loss calculation,
- StepEnd when a step ends; and
- EpochEnd when an epoch ends.
It writes the following step state to learner.step, grouped by the event from which on it is available:
- StepBegin: xs and ys, the encoded input and target (batch)
- LossBegin: ŷs, the model output
- StepEnd: loss, the loss
FluxTraining.Events — Module
module Events
Provides the abstract Event type and concrete event types.
Events in TrainingPhase and ValidationPhase:
- EpochBegin and EpochEnd, called at the beginning and end of each epoch.
- StepBegin and StepEnd, called at the beginning and end of each batch.
- LossBegin, called after the forward pass but before the loss calculation.
TrainingPhase only:
- BackwardBegin, called after forward pass and loss calculation but before gradient calculation.
- BackwardEnd, called after gradient calculation but before parameter update.
FluxTraining.Events.BackwardBegin — Type
BackwardBegin()
Event called between calculating the loss and calculating the gradients.
FluxTraining.Events.BackwardEnd — Type
BackwardEnd()
Event called between calculating the gradients and updating the parameters.
FluxTraining.Events.EpochBegin — Type
EpochBegin()
Event called at the beginning of an epoch.
FluxTraining.Events.EpochEnd — Type
EpochEnd()
Event called at the end of an epoch.
FluxTraining.Events.Event — Type
abstract type Event
Abstract type for events that callbacks can hook into.
FluxTraining.Events.LossBegin — Type
LossBegin()
Event called between calculating y_pred and calculating the loss.
FluxTraining.Events.StepBegin — Type
StepBegin()
Event called at the beginning of a batch.
FluxTraining.Events.StepEnd — Type
StepEnd()
Event called at the end of a batch.