Model Reflection

Since constructing large models can be a tedious and error-prone process, Mill.jl provides the reflectinmodel function, which helps automate it. In its simplest form it accepts only one argument, a sample ds, and returns a compatible model:

julia> ds = BagNode(ProductNode((BagNode(ArrayNode(randn(4, 10)),
                                         [1:2, 3:4, 5:5, 6:7, 8:10]),
                                 ArrayNode(randn(3, 5)),
                                 BagNode(BagNode(ArrayNode(randn(2, 30)),
                                                 [i:i+1 for i in 1:2:30]),
                                         [1:3, 4:6, 7:9, 10:12, 13:15]),
                                 ArrayNode(randn(2, 5)))),
                    [1:1, 2:3, 4:5]);

julia> printtree(ds)
BagNode with 3 obs
  └── ProductNode with 5 obs
        ├── BagNode with 5 obs
        │     └── ArrayNode(4×10 Array with Float64 elements) with 10 obs
        ├── ArrayNode(3×5 Array with Float64 elements) with 5 obs
        ├── BagNode with 5 obs
        │     └── BagNode with 15 obs
        │           └── ArrayNode(2×30 Array with Float64 elements) with 30 obs
        └── ArrayNode(2×5 Array with Float64 elements) with 5 obs

julia> m = reflectinmodel(ds);

julia> printtree(m)
BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
  └── ProductModel … ↦ ArrayModel(Dense(40, 10))
        ├── BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
        │     └── ArrayModel(Dense(4, 10))
        ├── ArrayModel(Dense(3, 10))
        ├── BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
        │     └── BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
        │           └── ArrayModel(Dense(2, 10))
        └── ArrayModel(Dense(2, 10))

julia> m(ds)
10×3 ArrayNode{Array{Float32,2},Nothing}:
 -0.038918063  -1.0269539    -1.019303
  1.1405787     0.4443741    -0.22033617
 -0.41863015   -2.121142     -0.9783317
  0.35558683    0.25830314   -0.9468039
  0.034859784  -0.503356     -0.035982188
 -0.80761534   -0.35997766    0.39803472
 -0.6087157     0.06487247    1.0217441
 -1.7817953    -0.029188544   1.4289896
 -0.73254365    2.8453107    -1.0076041
 -0.48766953   -0.9546175     1.3383192

Here, the sample ds serves as a specimen that specifies the structure of the problem and from which the dimensions of all layers are calculated.
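To illustrate the specimen role, the sketch below reflects a model from one sample and applies it to another sample with the same schema but different numbers of bags and instances. It assumes the same Mill.jl version as the examples on this page (BagNode, ProductNode, ArrayNode, reflectinmodel) and that the model output is an ArrayNode whose matrix is stored in its data field; make_sample is a hypothetical helper, not part of Mill.jl.

```julia
using Mill, Flux

# Hypothetical helper: builds a sample with the same schema as `ds` above
# (a bag of products of {bag, array, bag of bags, array}) with `nbags` top-level bags.
function make_sample(nbags)
    n = 2nbags                       # two ProductNode observations per top-level bag
    BagNode(ProductNode((BagNode(ArrayNode(randn(4, n)), [i:i for i in 1:n]),
                         ArrayNode(randn(3, n)),
                         BagNode(BagNode(ArrayNode(randn(2, 2n)),
                                         [2i-1:2i for i in 1:n]),
                                 [i:i for i in 1:n]),
                         ArrayNode(randn(2, n)))),
            [2i-1:2i for i in 1:nbags])
end

m = reflectinmodel(make_sample(3))   # the specimen only fixes structure and dimensions
out = m(make_sample(5))              # the same model accepts another sample of this schema
size(out.data)                       # 10 rows (default Dense(d, 10)), one column per top-level bag
```

Only the schema and the feature dimensions of the specimen matter; its values are not stored in the model, whose weights are freshly initialized.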

Optional arguments

For better control over the topology, reflectinmodel accepts up to two more optional arguments and four keyword arguments. The optional arguments are:

  • The first optional argument is a function that returns a layer (or a set of layers) given an input dimension d (defaults to d -> Flux.Dense(d, 10)).
  • The second optional argument is a function that returns an aggregation function for BagModel nodes given a dimension d (defaults to d -> meanmax_aggregation(d), which is why every BagModel in the first example combines SegmentedMean and SegmentedMax).

Compare the following example to the previous one:

julia> using Flux

julia> m = reflectinmodel(ds, d -> Dense(d, 5, relu), d -> max_aggregation(d));

julia> printtree(m)
BagModel … ↦ ⟨SegmentedMax(5)⟩ ↦ ArrayModel(Dense(6, 5, relu))
  └── ProductModel … ↦ ArrayModel(Dense(20, 5, relu))
        ├── BagModel … ↦ ⟨SegmentedMax(5)⟩ ↦ ArrayModel(Dense(6, 5, relu))
        │     └── ArrayModel(Dense(4, 5, relu))
        ├── ArrayModel(Dense(3, 5, relu))
        ├── BagModel … ↦ ⟨SegmentedMax(5)⟩ ↦ ArrayModel(Dense(6, 5, relu))
        │     └── BagModel … ↦ ⟨SegmentedMax(5)⟩ ↦ ArrayModel(Dense(6, 5, relu))
        │           └── ArrayModel(Dense(2, 5, relu))
        └── ArrayModel(Dense(2, 5, relu))

julia> m(ds)
5×3 ArrayNode{Array{Float32,2},Nothing}:
 0.26268238  0.051359285  0.0
 0.0         0.0          0.0
 0.44163176  0.9266923    0.4310562
 0.08718503  0.0          0.0
 0.0         0.0          0.0
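The dimensions in both printouts follow a bottom-up rule: each ArrayNode gets a layer with as many inputs as it has rows, each ProductNode's layer takes the sum of its children's output sizes, and each BagModel's layer takes k·w + 1 inputs, where k is the number of aggregation operators and w is the output size of the model below. The pure-Julia sketch below reproduces those numbers without Mill.jl. Leaf, Bag, Product, layer_dims! and sample_schema are hypothetical names, and the +1 is taken from the printed layers (e.g. Dense(21, 10) after two 10-dimensional operators) rather than derived here.

```julia
# Hypothetical schema types mirroring the structure of `ds`; they are not Mill.jl types.
abstract type Schema end
struct Leaf <: Schema          # stands for an ArrayNode
    nrows::Int
end
struct Bag <: Schema           # stands for a BagNode
    child::Schema
    nops::Int                  # number of aggregation operators (2 for mean + max)
end
struct Product <: Schema       # stands for a ProductNode
    children::Vector{Schema}
end

# Walk the schema bottom-up, record the (input, output) size of every layer that a
# builder d -> Dense(d, width) would create, and return the subtree's output size.
function layer_dims!(layers, s::Leaf, width)
    push!(layers, (s.nrows, width))
    return width
end
function layer_dims!(layers, s::Bag, width)
    d = layer_dims!(layers, s.child, width)
    push!(layers, (s.nops * d + 1, width))   # "+1" reproduces the printed BagModel layers
    return width
end
function layer_dims!(layers, s::Product, width)
    d = sum(layer_dims!(layers, c, width) for c in s.children)
    push!(layers, (d, width))
    return width
end

# Same shape as `ds`: bag of products of {bag, array, bag of bags, array}.
sample_schema(nops) = Bag(Product(Schema[Bag(Leaf(4), nops), Leaf(3),
                                         Bag(Bag(Leaf(2), nops), nops), Leaf(2)]), nops)

default_layers = Tuple{Int,Int}[]
layer_dims!(default_layers, sample_schema(2), 10)   # mean + max, width 10
println(default_layers[end-1:end])                  # ProductModel and top BagModel layers

max_layers = Tuple{Int,Int}[]
layer_dims!(max_layers, sample_schema(1), 5)        # max only, width 5
println(max_layers[end-1:end])
```

The recorded pairs come out in post-order (children before parents), so the last two entries are the ProductModel layer followed by the top-level BagModel layer: (40, 10) and (21, 10) in the first example, (20, 5) and (6, 5) in the second.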

Keyword arguments

reflectinmodel allows even further customization. To index into the sample (or the model), we can use printtree(ds; trav=true) from HierarchicalUtils.jl, which prints the sample together with the identifiers of individual nodes:

julia> using HierarchicalUtils

julia> printtree(ds; trav=true)
BagNode with 3 obs [""]
  └── ProductNode with 5 obs ["U"]
        ├── BagNode with 5 obs ["Y"]
        │     └── ArrayNode(4×10 Array with Float64 elements) with 10 obs ["a"]
        ├── ArrayNode(3×5 Array with Float64 elements) with 5 obs ["c"]
        ├── BagNode with 5 obs ["g"]
        │     └── BagNode with 15 obs ["i"]
        │           └── ArrayNode(2×30 Array with Float64 elements) with 30 obs ["j"]
        └── ArrayNode(2×5 Array with Float64 elements) with 5 obs ["k"]

These identifiers can be used to override the default construction functions. Note that the output, i.e. the last feed-forward network of the whole model, is always tagged with an empty string "", which makes it easy to put a linear layer with an appropriate output dimension at the end. Dictionaries with these overrides can be passed in as keyword arguments:

  • fsm overrides the construction of feed-forward models,
  • fsa overrides the construction of aggregation functions.

For example, to specify just the last feed-forward neural network:

julia> reflectinmodel(ds, d -> Dense(d, 5, relu), d -> meanmax_aggregation(d);
           fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12)))) |> printtree
BagModel … ↦ ⟨SegmentedMean(5), SegmentedMax(5)⟩ ↦ ArrayModel(Chain(Dense(11, 20, relu), Dense(20, 12)))
  └── ProductModel … ↦ ArrayModel(Dense(20, 5, relu))
        ├── BagModel … ↦ ⟨SegmentedMean(5), SegmentedMax(5)⟩ ↦ ArrayModel(Dense(11, 5, relu))
        │     └── ArrayModel(Dense(4, 5, relu))
        ├── ArrayModel(Dense(3, 5, relu))
        ├── BagModel … ↦ ⟨SegmentedMean(5), SegmentedMax(5)⟩ ↦ ArrayModel(Dense(11, 5, relu))
        │     └── BagModel … ↦ ⟨SegmentedMean(5), SegmentedMax(5)⟩ ↦ ArrayModel(Dense(11, 5, relu))
        │           └── ArrayModel(Dense(2, 5, relu))
        └── ArrayModel(Dense(2, 5, relu))
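The "" override is a natural place for a task-specific head. As a sketch, assuming a hypothetical 2-class problem and a smaller, hypothetical sample, the output network below is a single linear layer whose scores are turned into class probabilities with Flux's softmax. The reflectinmodel call follows the form used on this page, and the model output is assumed to be an ArrayNode with its matrix in the data field.

```julia
using Mill, Flux

# Small hypothetical sample: 6 instances with 3 features, grouped into 3 bags.
ds = BagNode(ArrayNode(randn(3, 6)), [1:2, 3:3, 4:6])

# Custom builders everywhere, except the output network (identifier ""),
# which becomes a single linear layer with one score per class (2 classes assumed).
m = reflectinmodel(ds, d -> Dense(d, 5, relu), d -> meanmax_aggregation(d);
                   fsm = Dict("" => d -> Dense(d, 2)))

scores = m(ds).data      # 2×3 matrix: one column of class scores per bag
probs = softmax(scores)  # softmax over dims=1, so each column sums to 1
```

Keeping the head linear leaves the choice of output nonlinearity (here softmax) to the loss or the prediction code.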

Both keyword arguments in action:

julia> reflectinmodel(ds, d -> Dense(d, 5, relu), d -> meanmax_aggregation(d);
           fsm = Dict("" => d -> Chain(Dense(d, 20, relu), Dense(20, 12))),
           fsa = Dict("Y" => d -> mean_aggregation(d), "g" => d -> pnorm_aggregation(d))) |> printtree
BagModel … ↦ ⟨SegmentedMean(5), SegmentedMax(5)⟩ ↦ ArrayModel(Chain(Dense(11, 20, relu), Dense(20, 12)))
  └── ProductModel … ↦ ArrayModel(Dense(20, 5, relu))
        ├── BagModel … ↦ ⟨SegmentedMean(5)⟩ ↦ ArrayModel(Dense(6, 5, relu))
        │     └── ArrayModel(Dense(4, 5, relu))
        ├── ArrayModel(Dense(3, 5, relu))
        ├── BagModel … ↦ ⟨SegmentedPNorm(5)⟩ ↦ ArrayModel(Dense(6, 5, relu))
        │     └── BagModel … ↦ ⟨SegmentedMean(5), SegmentedMax(5)⟩ ↦ ArrayModel(Dense(11, 5, relu))
        │           └── ArrayModel(Dense(2, 5, relu))
        └── ArrayModel(Dense(2, 5, relu))
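Conceptually, each dictionary acts as a lookup keyed by node identifier, with the positional builder as the fallback for nodes that have no entry. The pure-Julia sketch below is not Mill.jl code (fsa_nops, default_nops and bag_input_dim are hypothetical names). Under that reading, it reproduces the input sizes of the BagModel layers in the last printout by counting one aggregation operator for mean_aggregation and pnorm_aggregation and two for meanmax_aggregation, reusing the +1 visible in the printed layers.

```julia
# Hypothetical illustration of override-with-fallback lookup, mirroring `fsa` above.
default_nops = 2                         # meanmax_aggregation: mean and max
fsa_nops = Dict("Y" => 1, "g" => 1)      # mean_aggregation / pnorm_aggregation: one operator each

# Input size of the layer after the aggregation of bag node `id`, when the model
# below it outputs `d` rows; the "+1" matches the printed BagModel layers.
bag_input_dim(id, d) = get(fsa_nops, id, default_nops) * d + 1

for id in ("", "Y", "g", "i")
    println(repr(id), " => ", bag_input_dim(id, 5))
end
```

The printed values, 11, 6, 6 and 11, match Dense(11, 20, relu) at "", Dense(6, 5, relu) at "Y" and "g", and Dense(11, 5, relu) at "i".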