Ignite.jl
Documentation for Ignite.jl.
Docstrings
Ignite.Ignite
Ignite.AbstractEvent
Ignite.AbstractFiringEvent
Ignite.AbstractLoopEvent
Ignite.AndEvent
Ignite.Engine
Ignite.EventHandler
Ignite.FilteredEvent
Ignite.OrEvent
Ignite.State
Ignite.add_event_handler!
Ignite.every_filter
Ignite.filter_event
Ignite.fire_event!
Ignite.fire_event_handler!
Ignite.getsomething!
Ignite.once_filter
Ignite.reset!
Ignite.run!
Ignite.terminate!
Ignite.throttle_filter
Ignite.timeout_filter
Ignite.@getsomething!
Ignite.Ignite — Module
Ignite.jl
Welcome to Ignite.jl, a Julia port of the Python library ignite for simplifying neural network training and validation loops using events and handlers.
Ignite.jl provides a simple yet flexible engine and event system, allowing for the easy composition of training pipelines with various events such as artifact saving, metric logging, and model validation. Event-based training abstracts away the training loop, replacing it with:
- An engine which wraps a process function that consumes a single batch of data,
- An iterable data loader which produces said batches of data, and
- Events and corresponding event handlers which are attached to the engine, configured to fire at specific points during training.
Event handlers are much more flexible than other approaches like callbacks: handlers can be any callable; multiple handlers can be attached to a single event; multiple events can trigger the same handler; and custom events can be defined to fire at user-specified points during training. This makes adding functionality to your training pipeline easy, minimizing the need to modify existing code.
Quick Start
The example below demonstrates how to use Ignite.jl to train a simple neural network. Key features to note:
- The training step is factored out of the training loop: the train_step process function takes a batch of training data, computes the training loss and gradients, and updates the model parameters.
- Data loaders can be any iterable collection. Here, we use a DataLoader from MLUtils.jl.
using Ignite
using Flux, Zygote, Optimisers, MLUtils # for training a neural network
# Build simple neural network and initialize Adam optimizer
model = Chain(Dense(1 => 32, tanh), Dense(32 => 1))
optim = Flux.setup(Optimisers.Adam(1.0f-3), model)
# Create mock data and data loaders
f(x) = 2x - x^3
xtrain, xtest = 2 * rand(Float32, 1, 10_000) .- 1, collect(reshape(range(-1.0f0, 1.0f0; length = 100), 1, :))
ytrain, ytest = f.(xtrain), f.(xtest)
train_data_loader = DataLoader((; x = xtrain, y = ytrain); batchsize = 64, shuffle = true, partial = false)
eval_data_loader = DataLoader((; x = xtest, y = ytest); batchsize = 10, shuffle = false)
# Create training engine:
# - `engine` is a reference to the parent `trainer` engine, created below
# - `batch` is a batch of training data, retrieved by iterating `train_data_loader`
# - (optional) return value is stored in `trainer.state.output`
function train_step(engine, batch)
    x, y = batch
    l, gs = Zygote.withgradient(m -> sum(abs2, m(x) .- y), model)
    Optimisers.update!(optim, model, gs[1])
    return Dict("loss" => l)
end
trainer = Engine(train_step)
# Start the training
Ignite.run!(trainer, train_data_loader; max_epochs = 25, epoch_length = 100)
Periodically evaluate model
The real power of Ignite.jl comes when adding event handlers to our training engine.
Let's evaluate our model after every 5th training epoch. This can be easily incorporated without needing to modify any of the above training code:
- Create an evaluator engine which consumes batches of evaluation data.
- Add event handlers to the evaluator engine which accumulate a running average of evaluation metrics over batches of evaluation data; we use OnlineStats.jl to make this easy.
- Add an event handler to the trainer which runs the evaluator on the evaluation data loader every 5 training epochs.
using OnlineStats: Mean, fit!, value # for tracking evaluation metrics
# Create an evaluation engine using `do` syntax:
evaluator = Engine() do engine, batch
    x, y = batch
    ypred = model(x) # evaluate model on a single batch of validation data
    return Dict("ytrue" => y, "ypred" => ypred) # result is stored in `evaluator.state.output`
end
# Add event handlers to the evaluation engine to track metrics:
add_event_handler!(evaluator, STARTED()) do engine
    # When `evaluator` starts, initialize the running mean
    engine.state.metrics = Dict("abs_err" => Mean()) # new fields can be added to `engine.state` dynamically
end
add_event_handler!(evaluator, ITERATION_COMPLETED()) do engine
    # Each iteration, compute eval metrics from predictions
    o = engine.state.output
    m = engine.state.metrics["abs_err"]
    fit!(m, abs.(o["ytrue"] .- o["ypred"]) |> vec)
end
# Add an event handler to `trainer` which runs `evaluator` every 5 epochs:
add_event_handler!(trainer, EPOCH_COMPLETED(; every = 5)) do engine
    Ignite.run!(evaluator, eval_data_loader)
    @info "Evaluation metrics: abs_err = $(evaluator.state.metrics["abs_err"])"
end
# Run the trainer with periodic evaluation
Ignite.run!(trainer, train_data_loader; max_epochs = 25, epoch_length = 100)
Terminating a run
There are several ways to stop a training run before it has completed:
- Throw an exception as usual. This will immediately stop training. An EXCEPTION_RAISED() event will subsequently be fired.
- Use a keyboard interrupt, i.e. throw an InterruptException via Ctrl+C or Cmd+C. Training will halt, and an INTERRUPT() event will be fired.
- Gracefully terminate via Ignite.terminate!(trainer), or equivalently, trainer.should_terminate = true. This will allow the current iteration to finish, but no further iterations will begin. Then, a TERMINATE() event will be fired, followed by a COMPLETED() event.
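As a rough illustration of the graceful route, the sketch below requests termination when a loss becomes non-finite and reacts to the resulting TERMINATE() event. The engine `toy_trainer`, its data, and the `terminated` flag are hypothetical stand-ins for this example; only `Engine`, `add_event_handler!`, `Ignite.terminate!`, `Ignite.run!`, and the event types come from the text above.

```julia
using Ignite

# Hypothetical toy engine: each "batch" is simply reported as the loss
toy_trainer = Engine() do engine, batch
    return Dict("loss" => batch)
end

# Request graceful termination as soon as the loss is not finite
add_event_handler!(toy_trainer, ITERATION_COMPLETED()) do engine
    if !isfinite(engine.state.output["loss"])
        Ignite.terminate!(engine) # sets `engine.should_terminate = true`
    end
end

# React to the graceful stop; per the list above, COMPLETED() fires afterwards
terminated = Ref(false)
add_event_handler!(toy_trainer, Ignite.TERMINATE()) do engine
    terminated[] = true
    @info "Terminated early" iteration = engine.state.iteration
end

# Any iterable works as a data loader; the third batch triggers termination
Ignite.run!(toy_trainer, Float32[1.0, 0.5, NaN, 0.25]; max_epochs = 1, epoch_length = 4)
```

`TERMINATE` is written with the `Ignite.` prefix here because the text does not say whether it is exported; the qualified name works either way.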
Early stopping
To implement early stopping, we can add an event handler to trainer which checks the evaluation metrics and gracefully terminates trainer if the metrics fail to improve. To do so, we first define a training termination trigger using Flux.early_stopping:
# Callback which returns `true` if the eval loss fails to decrease by
# at least `min_dist` for two consecutive evaluations
early_stop_trigger = Flux.early_stopping(2; init_score = Inf32, min_dist = 5.0f-3) do
    return value(evaluator.state.metrics["abs_err"])
end
Then, we add an event handler to trainer which checks the early stopping trigger and terminates training if the trigger returns true:
# This handler must fire every 5th epoch, the same as the evaluation event handler,
# to ensure new evaluation metrics are available
add_event_handler!(trainer, EPOCH_COMPLETED(; every = 5)) do engine
    if early_stop_trigger()
        @info "Stopping early"
        Ignite.terminate!(trainer)
    end
end
# Run the trainer with periodic evaluation and early stopping
Ignite.run!(trainer, train_data_loader; max_epochs = 25, epoch_length = 100)
Note: instead of adding a new event handler, the evaluation event handler from the previous section could have been modified to check early_stop_trigger() immediately after evaluator is run.
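A minimal sketch of that combined handler is shown here. It assumes `trainer`, `evaluator`, `eval_data_loader`, and `early_stop_trigger` are defined exactly as in the sections above, and it is meant to replace the two separate handlers rather than be added alongside them.

```julia
# Assumes `trainer`, `evaluator`, `eval_data_loader`, and `early_stop_trigger`
# from the previous sections. Evaluate first, then check the trigger, so the
# trigger always sees fresh metrics.
add_event_handler!(trainer, EPOCH_COMPLETED(; every = 5)) do engine
    Ignite.run!(evaluator, eval_data_loader)
    @info "Evaluation metrics: abs_err = $(evaluator.state.metrics["abs_err"])"
    if early_stop_trigger()
        @info "Stopping early"
        Ignite.terminate!(engine) # `engine` is `trainer` here
    end
end
```

Keeping both steps in one handler removes the need to keep two `every` settings in sync.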
Artifact saving
Logging artifacts can be easily added to the trainer, again without modifying the above code. For example, save the current model and optimizer state to disk every 10 epochs using JLD2.jl:
using JLD2
# Save model and optimizer state every 10 epochs
add_event_handler!(trainer, EPOCH_COMPLETED(; every = 10)) do engine
    model_state = Flux.state(model)
    jldsave("model_and_optim.jld2"; model_state, optim)
    @info "Saved model and optimizer state to disk"
end
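For completeness, here is a hedged sketch of how such a file might be read back, using `JLD2.load` and `Flux.loadmodel!`. It is not part of Ignite.jl. The temporary path and the `restored` names are assumptions made for this example, and the restored model must have the same architecture as the saved one.

```julia
using Flux, Optimisers, JLD2

# Same architecture as in the Quick Start
model = Chain(Dense(1 => 32, tanh), Dense(32 => 1))
optim = Flux.setup(Optimisers.Adam(1.0f-3), model)

# Save as in the handler above (to a temporary directory for this example)
path = joinpath(mktempdir(), "model_and_optim.jld2")
jldsave(path; model_state = Flux.state(model), optim)

# Restore into a freshly constructed model of the same architecture
restored = Chain(Dense(1 => 32, tanh), Dense(32 => 1))
Flux.loadmodel!(restored, JLD2.load(path, "model_state"))
restored_optim = JLD2.load(path, "optim")
```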
Trigger multiple functions per event
Multiple event handlers can be added to the same event:
add_event_handler!(trainer, COMPLETED()) do engine
    # Runs after training has completed
end
add_event_handler!(trainer, COMPLETED()) do engine
    # Also runs after training has completed, after the above function runs
end
Attach the same handler to multiple events
The boolean operators | and & can be used to combine events together:
add_event_handler!(trainer, EPOCH_COMPLETED(; every = 10) | COMPLETED()) do engine
    # Runs at the end of every 10th epoch, or when training is completed
end
throttled_event = EPOCH_COMPLETED(; every = 3) & EPOCH_COMPLETED(; event_filter = throttle_filter(30.0))
add_event_handler!(trainer, throttled_event) do engine
    # Runs at the end of every 3rd epoch if at least 30s has passed since the last firing
end
Define custom events
Custom events can be created and fired at user-defined stages in the training process.
For example, suppose we want to define events that fire at the start and finish of both the backward pass and the optimizer step. All we need to do is define new event types that subtype AbstractLoopEvent, and then fire them at appropriate points in the train_step process function using fire_event!:
struct BACKWARD_STARTED <: AbstractLoopEvent end
struct BACKWARD_COMPLETED <: AbstractLoopEvent end
struct OPTIM_STEP_STARTED <: AbstractLoopEvent end
struct OPTIM_STEP_COMPLETED <: AbstractLoopEvent end
function train_step(engine, batch)
    x, y = batch
    # Compute the gradients of the loss with respect to the model
    fire_event!(engine, BACKWARD_STARTED())
    l, gs = Zygote.withgradient(m -> sum(abs2, m(x) .- y), model)
    engine.state.gradients = gs # the engine state can be accessed by event handlers
    fire_event!(engine, BACKWARD_COMPLETED())
    # Update the model's parameters
    fire_event!(engine, OPTIM_STEP_STARTED())
    Optimisers.update!(optim, model, gs[1])
    fire_event!(engine, OPTIM_STEP_COMPLETED())
    return Dict("loss" => l)
end
trainer = Engine(train_step)
Then, add event handlers for these custom events as usual:
add_event_handler!(trainer, BACKWARD_COMPLETED(; every = 10)) do engine
    # This code runs after every 10th backward pass is completed
end
This page was generated using Literate.jl.
Ignite.AbstractEvent — Type
abstract type AbstractEvent
Abstract supertype for all events.
Ignite.AbstractFiringEvent — Type
abstract type AbstractFiringEvent <: AbstractEvent
Abstract supertype for all events which can trigger event handlers via fire_event!.
Ignite.AbstractLoopEvent — Type
abstract type AbstractLoopEvent <: AbstractFiringEvent
Abstract supertype for events fired during the normal execution of Ignite.run!.
A default convenience constructor (EVENT::Type{<:AbstractLoopEvent})(; kwargs...) is provided to allow for easy filtering of AbstractLoopEvents. For example, EPOCH_COMPLETED(every = 3) will build a FilteredEvent which is triggered every third time an EPOCH_COMPLETED() event is fired. See filter_event for allowed keywords.
By inheriting from AbstractLoopEvent, custom events will inherit these convenience constructors, too. If this is undesired, one can instead inherit from the supertype AbstractFiringEvent.
Ignite.AndEvent — Type
struct AndEvent{E1<:AbstractEvent, E2<:AbstractEvent} <: AbstractEvent
AndEvent(event1, event2) wraps two events and triggers if and only if both wrapped events are triggered by the same firing event.
AndEvents can be constructed via the & operator: event1 & event2.
Fields:
- event1::AbstractEvent: The first wrapped event that will be considered for triggering.
- event2::AbstractEvent: The second wrapped event that will be considered for triggering.
Ignite.Engine — Type
mutable struct Engine{P}
An Engine struct to be run using Ignite.run!. Can be constructed via engine = Engine(process_function; kwargs...), where the process function takes two arguments: the parent engine, and a batch of data.
Fields:
- process_function::Any: A function that processes a single batch of data and returns an output.
- state::State: An object that holds the current state of the engine.
- event_handlers::Vector{EventHandler}: A list of event handlers that are called at specific points when the engine is running.
- logger::Union{Nothing, Base.CoreLogging.AbstractLogger}: An optional logger; if nothing, then current_logger() will be used.
- timer::TimerOutputs.TimerOutput: Internal timer. Can be used with TimerOutputs to record event timings.
- should_terminate::Bool: A flag that indicates whether the engine should stop running.
- exception::Union{Nothing, Exception}: Exception thrown during training.
Ignite.EventHandler — Type
struct EventHandler{E<:AbstractEvent, H, A<:Tuple}
EventHandlers wrap an event and a corresponding handler!. The handler! is executed when event is triggered by a call to fire_event!. The output from handler! is ignored. Additional args for handler! may be stored in EventHandler at construction; see add_event_handler!.
When h::EventHandler is triggered, the event handler is called as h.handler!(engine::Engine, h.args...).
Fields:
- event::AbstractEvent: Event which triggers handler.
- handler!::Any: Event handler which executes when triggered by event.
- args::Tuple: Additional arguments passed to the event handler.
Ignite.FilteredEvent — Type
struct FilteredEvent{E<:AbstractEvent, F} <: AbstractEvent
FilteredEvent(event::E, event_filter::F) wraps an event and an event_filter function.
When a firing event e is fired, if event_filter(engine, e) returns true then the filtered event will be fired too.
Fields:
- event::AbstractEvent: The wrapped event that will be fired if the filter function returns true when applied to a firing event.
- event_filter::Any: The filter function (::Engine, ::AbstractFiringEvent) -> Bool returns true if the filtered event should be fired.
Ignite.OrEvent — Type
struct OrEvent{E1<:AbstractEvent, E2<:AbstractEvent} <: AbstractEvent
OrEvent(event1, event2) wraps two events and triggers if either of the wrapped events is triggered by a firing event.
OrEvents can be constructed via the | operator: event1 | event2.
Fields:
- event1::AbstractEvent: The first wrapped event that will be checked if it should be fired.
- event2::AbstractEvent: The second wrapped event that will be checked if it should be fired.
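The difference between the two combinators can be seen by counting how often each one triggers on a toy engine. This is only an illustrative sketch: the do-nothing process function, the `counts` dictionary, and the `either`/`both` names are assumptions made for the example.

```julia
using Ignite

engine = Engine((engine, batch) -> nothing) # toy process function
counts = Dict(:or => 0, :and => 0)

# Triggers on every 2nd epoch, and also when the run completes
either = EPOCH_COMPLETED(; every = 2) | COMPLETED()
# Triggers only when the same EPOCH_COMPLETED() firing passes both filters
both = EPOCH_COMPLETED(; every = 2) & EPOCH_COMPLETED(; every = 3)

add_event_handler!(engine, either) do engine
    counts[:or] += 1
end
add_event_handler!(engine, both) do engine
    counts[:and] += 1
end

Ignite.run!(engine, 1:3; max_epochs = 12, epoch_length = 3)
@show counts
```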
Ignite.State — Type
struct State <: AbstractDict{Symbol, Any}
Current state of the engine.
State is a light wrapper around a DefaultOrderedDict{Symbol, Any, Nothing} with the following keys:
- :iteration: the current iteration, initialized with 0, incremented immediately before ITERATION_STARTED() event is fired.
- :epoch: the current epoch, initialized with 0, set to iteration ÷ epoch_length + 1 immediately before EPOCH_STARTED() event is fired.
- :epoch_iteration: the current iteration within the current epoch, initialized with 0, set to mod1(iteration, epoch_length) immediately before ITERATION_STARTED() event is fired.
- :max_epochs: The number of epochs to run.
- :epoch_length: The number of batches processed per epoch.
- :output: The output of process_function after a single iteration.
- :last_event: The last event fired.
- :counters: A DefaultOrderedDict{AbstractFiringEvent, Int, Int}(0) with firing event firing counters.
- :times: An OrderedDict{AbstractFiringEvent, Float64}() with total and per-epoch times fetched on firing event keys.
Fields can be accessed and modified using getproperty and setproperty!. For example, engine.state.iteration can be used to access the current iteration, and engine.state.new_field = value can be used to store value for later use, e.g. by an event handler.
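The index arithmetic for :epoch and :epoch_iteration can be checked with plain Julia, without Ignite.jl. The helper names and the choice of `epoch_length = 100` below are illustrative only.

```julia
epoch_length = 100

# :epoch is computed from the iterations completed so far, just before EPOCH_STARTED()
epoch_at_epoch_start(iteration) = iteration ÷ epoch_length + 1
# :epoch_iteration is computed from the already-incremented iteration count
epoch_iteration(iteration) = mod1(iteration, epoch_length)

epoch_at_epoch_start(0)   # 1: nothing has run yet, so this is the first epoch
epoch_at_epoch_start(100) # 2: one full epoch has been completed
epoch_iteration(1)        # 1
epoch_iteration(100)      # 100: `mod1` maps multiples of `epoch_length` to `epoch_length`, not 0
epoch_iteration(101)      # 1: first iteration of the second epoch
```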
Ignite.add_event_handler! — Method
add_event_handler!(
    handler!,
    engine::Engine,
    event::AbstractEvent,
    handler_args...
) -> Engine
Add an event handler to an engine which is fired when event is triggered.
When fired, the event handler is called as handler!(engine::Engine, handler_args...).
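The sketch below shows how `handler_args` reach the handler, both with `do` syntax and with a named function passed as the first positional argument. The toy engine, the `messages` vector, and `record_output!` are hypothetical names introduced for this example.

```julia
using Ignite

engine = Engine((engine, batch) -> 2 * batch) # toy process function
messages = String[]

# With `do` syntax, extra arguments follow the event and arrive after `engine`
add_event_handler!(engine, EPOCH_COMPLETED(), messages, "epoch done") do engine, messages, label
    push!(messages, "$label: $(engine.state.epoch)")
end

# Equivalent form without `do`: the handler is the first positional argument
record_output!(engine, messages) = push!(messages, "last output: $(engine.state.output)")
add_event_handler!(record_output!, engine, COMPLETED(), messages)

Ignite.run!(engine, 1:4; max_epochs = 2, epoch_length = 4)
foreach(println, messages)
```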
Ignite.every_filter — Method
every_filter(
    every::Union{Int64, AbstractVector{Int64}}
) -> Union{Ignite.EveryFilter{Int64}, Ignite.EveryFilter{T} where T<:AbstractVector{Int64}}
Creates an event filter function for use in a FilteredEvent that returns true periodically depending on every:
- If every = n::Int, the filter will trigger every nth firing of the event.
- If every = Int[n₁, n₂, ...], the filter will trigger every n₁th firing, every n₂th firing, and so on.
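To see the vector form in action, the following sketch records the epochs at which a handler with `every = [2, 5]` runs. The toy engine and the `fired_at` vector are assumptions made for illustration.

```julia
using Ignite

engine = Engine((engine, batch) -> nothing) # toy process function
fired_at = Int[]

# Trigger on every 2nd and every 5th firing of EPOCH_COMPLETED()
add_event_handler!(engine, EPOCH_COMPLETED(; every = [2, 5])) do engine
    push!(fired_at, engine.state.epoch)
end

Ignite.run!(engine, 1:3; max_epochs = 10, epoch_length = 3)
@show fired_at
```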
Ignite.filter_event — Method
Filter the input event to fire conditionally:
Inputs:
- event::AbstractFiringEvent: event to be filtered.
- event_filter::Any: An event_filter function (::Engine, ::AbstractFiringEvent) -> Bool returning true if the filtered event should be fired.
- every::Union{Int, <:AbstractVector{Int}}: the period(s) in which the filtered event should be fired; see every_filter.
- once::Union{Int, <:AbstractVector{Int}}: the point(s) at which the filtered event should be fired; see once_filter.
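A custom event_filter is just a function of the engine and the firing event. The sketch below passes one through the convenience constructor described under AbstractLoopEvent; `is_even_output`, the toy engine, and `seen` are hypothetical names used only here.

```julia
using Ignite

engine = Engine((engine, batch) -> batch) # toy process function: output is the batch itself

# Hypothetical filter with the documented shape (::Engine, ::AbstractFiringEvent) -> Bool
is_even_output(engine, e) = iseven(engine.state.output)

seen = Int[]
add_event_handler!(engine, ITERATION_COMPLETED(; event_filter = is_even_output)) do engine
    push!(seen, engine.state.output)
end

Ignite.run!(engine, 1:6; max_epochs = 1, epoch_length = 6)
@show seen
```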
Ignite.fire_event! — Method
fire_event!(
    engine::Engine,
    e::AbstractFiringEvent
) -> Engine
Execute all event handlers triggered by the firing event e.
Ignite.fire_event_handler! — Method
fire_event_handler!(
    engine::Engine,
    event_handler!::EventHandler,
    e::AbstractFiringEvent
) -> Engine
Execute event_handler! if it is triggered by the firing event e.
Ignite.getsomething! — Method
getsomething!(f, s::State, key) -> Any
Get the field key from the state s if it exists and is not nothing, otherwise set it to f() and return that value.
Functional form of @getsomething!. For example, the following two expressions are equivalent:
@getsomething! engine.state.x = init
getsomething!(()->init, engine.state, :x)
Ignite.once_filter — Method
once_filter(
    once::Union{Int64, AbstractVector{Int64}}
) -> Union{Ignite.OnceFilter{Int64}, Ignite.OnceFilter{T} where T<:AbstractVector{Int64}}
Creates an event filter function for use in a FilteredEvent that returns true at specific points depending on once:
- If once = n::Int, the filter will trigger only on the nth firing of the event.
- If once = Int[n₁, n₂, ...], the filter will trigger only on the n₁th firing, the n₂th firing, and so on.
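One plausible use of `once` is a one-off sanity check after the first iteration, or a snapshot at a few chosen epochs. The sketch below assumes a toy engine; `checked` and `snapshots` are illustrative names.

```julia
using Ignite

engine = Engine((engine, batch) -> batch) # toy process function

# One-off check after the very first iteration
checked = Int[]
add_event_handler!(engine, ITERATION_COMPLETED(; once = 1)) do engine
    push!(checked, engine.state.iteration)
end

# Snapshot at the 2nd and 4th firing of EPOCH_COMPLETED() only
snapshots = Int[]
add_event_handler!(engine, EPOCH_COMPLETED(; once = [2, 4])) do engine
    push!(snapshots, engine.state.epoch)
end

Ignite.run!(engine, 1:3; max_epochs = 5, epoch_length = 3)
@show checked snapshots
```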
Ignite.reset! — Method
reset!(engine::Engine) -> Engine
Reset the engine state.
Ignite.run! — Method
run!(
    engine::Engine,
    dataloader;
    max_epochs,
    epoch_length,
    resume,
    resume_from_epoch,
    resume_from_iteration
) -> Engine
Run the engine. Data batches are retrieved by iterating dataloader. The data loader may be infinite; by default, it is restarted if it empties.
Inputs:
- engine::Engine: An instance of the Engine struct containing the process_function to run each iteration.
- dataloader: A data loader to iterate over.
- max_epochs::Int: the number of epochs to run. Defaults to 1.
- epoch_length::Int: the length of an epoch. If nothing, falls back to length(dataloader).
Conceptually, running the engine is roughly equivalent to the following:
- The engine state is initialized.
- The engine begins running for max_epochs epochs, or until engine.should_terminate == true.
- At the start of each epoch, EPOCH_STARTED() event is fired.
- An iteration loop is performed for epoch_length number of iterations, or until engine.should_terminate == true.
- At the start of each iteration, ITERATION_STARTED() event is fired, and a batch of data is loaded.
- The process_function is called on the loaded data batch.
- At the end of each iteration, ITERATION_COMPLETED() event is fired.
- At the end of each epoch, EPOCH_COMPLETED() event is fired.
- At the end of all the epochs, COMPLETED() event is fired.
If engine.should_terminate is set to true while running the engine, the engine will be terminated gracefully after the next completed iteration. This will subsequently trigger a TERMINATE() event to be fired followed by a COMPLETED() event.
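The conceptual steps above can be mimicked in plain Julia to make the event order concrete. This is a sketch of the description only, not Ignite.jl's actual implementation: `simulate_run` and its `terminate_at` keyword are hypothetical. The description also does not say whether EPOCH_COMPLETED() fires for an epoch that is cut short, so the sketch assumes it does not.

```julia
# Plain-Julia sketch of the event order described above (not Ignite.jl's implementation).
# `terminate_at` is a hypothetical knob: the iteration during which termination is requested.
function simulate_run(; max_epochs::Int = 1, epoch_length::Int, terminate_at = nothing)
    events = Symbol[]
    iteration = 0
    should_terminate = false
    for _ in 1:max_epochs
        push!(events, :EPOCH_STARTED)
        for _ in 1:epoch_length
            iteration += 1
            push!(events, :ITERATION_STARTED)
            # ... the process function would be called on the loaded batch here ...
            should_terminate = (iteration == terminate_at)
            push!(events, :ITERATION_COMPLETED)
            should_terminate && break
        end
        should_terminate && break # assumption: no EPOCH_COMPLETED for a cut-short epoch
        push!(events, :EPOCH_COMPLETED)
    end
    should_terminate && push!(events, :TERMINATE)
    push!(events, :COMPLETED)
    return events
end

simulate_run(; max_epochs = 1, epoch_length = 2)
simulate_run(; max_epochs = 2, epoch_length = 2, terminate_at = 1)
```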
Ignite.terminate! — Method
terminate!(engine::Engine) -> Engine
Terminate the engine by setting engine.should_terminate = true.
Ignite.throttle_filter — Function
throttle_filter(throttle::Real) -> Ignite.ThrottleFilter
throttle_filter(
    throttle::Real,
    last_fire::Real
) -> Ignite.ThrottleFilter
Creates an event filter function for use in a FilteredEvent that returns true if at least throttle seconds have passed since it was last fired.
Ignite.timeout_filter — Function
timeout_filter(timeout::Real) -> Ignite.TimeoutFilter
timeout_filter(
    timeout::Real,
    start_time::Real
) -> Ignite.TimeoutFilter
Creates an event filter function for use in a FilteredEvent that returns true if at least timeout seconds have passed since the filter function was created.
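One way a timeout filter might be used is as a wall-clock budget for a run: once the budget is spent, a handler requests graceful termination. The sketch below is illustrative only. The slow toy process function and the 0.1 s budget are assumptions, and `timeout_filter` is called with the `Ignite.` prefix because the text does not say whether it is exported.

```julia
using Ignite

# Toy process function that takes roughly 10 ms per batch
engine = Engine() do engine, batch
    sleep(0.01)
    return batch
end

# The clock starts when the filter is created
budget = Ignite.timeout_filter(0.1)

# Once the budget is exceeded, stop after the current iteration
add_event_handler!(engine, ITERATION_COMPLETED(; event_filter = budget)) do engine
    Ignite.terminate!(engine)
end

# Far more epochs than the budget allows; the run should stop long before the end
Ignite.run!(engine, 1:10; max_epochs = 1_000, epoch_length = 10)
@info "Stopped" iteration = engine.state.iteration
```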
Ignite.@getsomething! — Macro
y = @getsomething! engine.state.x = init
Get the field x from engine.state if it exists and is not nothing, otherwise set it to init and return that value. Equivalent to
y = engine.state.x
if y === nothing
    y = engine.state.x = init
end
Useful for initializing stateful fields. For example, to record all losses during training one could do the following:
losses = @getsomething! engine.state.losses = Float64[]
push!(losses, loss)
See also the functional form getsomething!.
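Placed inside an event handler, the loss-recording pattern might look like the sketch below. The toy engine and its "loss" output are assumptions for illustration, and the macro is called with the `Ignite.` prefix in case it is not exported.

```julia
using Ignite

# Toy process function: pretend the loss is 1 / batch
engine = Engine((engine, batch) -> Dict("loss" => 1 / batch))

add_event_handler!(engine, ITERATION_COMPLETED()) do engine
    # The first call creates `engine.state.losses`; later calls reuse the same vector
    losses = Ignite.@getsomething! engine.state.losses = Float64[]
    push!(losses, engine.state.output["loss"])
end

Ignite.run!(engine, 1:4; max_epochs = 1, epoch_length = 4)
@show engine.state.losses
```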