FluxPrune


FluxPrune.jl provides iterative pruning algorithms for Flux models. Pruning strategies can be unstructured or structured: unstructured strategies operate on individual weight arrays, while structured strategies operate on whole layers (for example, pruning the channels of a Conv layer).
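
As a rough, package-agnostic illustration of that distinction: unstructured pruning zeroes individual entries of a weight array wherever they fall, while structured pruning zeroes a whole slice, such as an entire output row of a Dense weight matrix, at once. The sketch below uses plain Julia arrays and is not FluxPrune's implementation:

W = randn(Float32, 4, 6)              # think of a Dense layer's weight matrix

# unstructured: mask individual small-magnitude entries
W_unstructured = ifelse.(abs.(W) .< 0.5f0, 0f0, W)

# structured: mask an entire output unit (row) at once
W_structured = copy(W)
W_structured[1, :] .= 0f0             # which row to drop is arbitrary in this sketch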

Examples

Unstructured edge pruning:

using Flux, FluxPrune

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), Dense(512, 10))
# prune all weights to 70% sparsity
m̄ = prune(LevelPrune(0.7), m)
# prune all weights whose magnitude is below 0.5
m̄ = prune(ThresholdPrune(0.5), m)
# prune each layer of the Chain at a different rate
# (the strategies are broadcast over the layers, then reassembled into a Chain)
m̄ = prune([LevelPrune(0.4), LevelPrune(0.6), LevelPrune(0.7)], m)
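
To make the two criteria above concrete, here is what they mean for a plain vector. This is a conceptual illustration in ordinary Julia, not FluxPrune's code, and details such as rounding or tie-breaking may differ from the package:

w = Float32[0.9, -0.1, 0.45, -0.8, 0.05, 0.6]

# LevelPrune(0.7): zero the smallest-magnitude entries until ~70% of them are zero
k = round(Int, 0.7 * length(w))           # 4 of the 6 entries
idx = sortperm(abs.(w))[1:k]              # indices of the k smallest magnitudes
w_level = copy(w)
w_level[idx] .= 0f0                       # -> [0.9, 0.0, 0.0, -0.8, 0.0, 0.0]

# ThresholdPrune(0.5): zero every entry with magnitude below 0.5
w_threshold = ifelse.(abs.(w) .< 0.5f0, 0f0, w)   # -> [0.9, 0.0, 0.0, -0.8, 0.0, 0.6]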

Structured channel pruning:

using Flux, FluxPrune

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), Dense(512, 10))
# prune all conv layer channels to 30% sparsity
m̄ = prune(ChannelPrune(0.3), m)
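
For a Conv layer, Flux lays the weight array out as (width, height, input channels, output channels), so channel pruning can be pictured as zeroing whole channel slices. The sketch below zeroes roughly 30% of the output-channel slices, choosing those with the smallest L1 norm; both the output-channel orientation and the selection rule are assumptions made here for illustration, not necessarily what ChannelPrune does:

Wc = randn(Float32, 3, 3, 3, 16)                 # weights of Conv((3, 3), 3 => 16)

nprune = round(Int, 0.3 * size(Wc, 4))           # ~30% of the 16 output channels
norms = [sum(abs, Wc[:, :, :, c]) for c in 1:size(Wc, 4)]
drop = sortperm(norms)[1:nprune]                 # channels with the smallest L1 norm

Wc_pruned = copy(Wc)
Wc_pruned[:, :, :, drop] .= 0f0                  # zero out those channels' filters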

Mixed pruning:

using Flux, FluxPrune

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), Dense(512, 10))
# apply channel and edge pruning
m̄ = prune([ChannelPrune(0.3), ChannelPrune(0.4), LevelPrune(0.8)], m)
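
As in the per-layer example above, the vector of strategies is matched elementwise with the layers of the Chain: here the two Conv layers receive channel pruning and the Dense layer receives level (edge) pruning.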

Iterative pruning:

using Flux, FluxPrune

m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), Dense(512, 10))
# target pruning levels step-by-step
# the do-block passed to iterativeprune fine-tunes the model at the current stage;
#  it returns true to indicate moving on to the next stage,
#  or false to indicate that fine-tuning should be run again
stages = [[ChannelPrune(0.1), ChannelPrune(0.1), LevelPrune(0.4)],
          [ChannelPrune(0.2), ChannelPrune(0.3), LevelPrune(0.7)],
          [ChannelPrune(0.3), ChannelPrune(0.5), LevelPrune(0.9)]]
m̄ = iterativeprune(stages, m) do m̄
    opt = Momentum()
    ps = Flux.params(m̄)
    Flux.@epochs 10 Flux.train!(loss, ps, data, opt)

    return accuracy(data, m̄) > target_accuracy
end
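
The iterative example assumes that loss, data, accuracy, and target_accuracy are defined by the user. A minimal, hypothetical setup with a single fake batch (placeholder names, shapes, and values, not part of FluxPrune) might look like the sketch below; note that loss has to close over the model currently being fine-tuned, so it is most naturally defined inside the do-block:

using Flux
using Flux: onehotbatch, onecold
using Statistics: mean

# one fake batch standing in for a real dataset
# (with 8×8 inputs the toy Chain yields 4*4*32 = 512 features, matching Dense(512, 10),
#  though a forward pass would also need a Flux.flatten before the Dense layer)
xs = rand(Float32, 8, 8, 3, 32)
ys = onehotbatch(rand(0:9, 32), 0:9)
data = [(xs, ys)]

# fraction of correct predictions over the data
accuracy(data, model) = mean(mean(onecold(model(x)) .== onecold(y)) for (x, y) in data)

target_accuracy = 0.9

# inside the do-block, the loss would be defined against the current stage's model, e.g.
#   loss(x, y) = Flux.logitcrossentropy(m̄(x), y)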