More on nodes

Node nesting

The main advantage of the Mill library is that it allows one to arbitrarily nest and cross-product BagModels, as described in Theorem 5 in Tomáš Pevný, Vojtěch Kovařík (2019). In other words, instances themselves may be represented in a much more complex way than in the BagNode and BagModel example.

Let's start the demonstration by nesting two MIL problems. The outer MIL model contains three samples (outer-level bags), whose instances are (inner-level) bags themselves. The first outer-level bag contains one inner-level bag with two inner-level instances, the second outer-level bag contains two inner-level bags with a total of three inner-level instances, and finally the third outer-level bag contains two inner-level bags with a total of five inner-level instances:

julia> ds = BagNode(BagNode(ArrayNode(randn(4, 10)),
                            [1:2, 3:4, 5:5, 6:7, 8:10]),
                    [1:1, 2:3, 4:5])
BagNode with 3 obs
  └── BagNode with 5 obs
        └── ArrayNode(4×10 Array with Float64 elements) with 10 obs
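
The instance counts in the description above follow from the two lists of bag ranges alone. The plain-Julia sketch below (no Mill.jl required) recomputes them; the helper name `instances_per_outer_bag` is hypothetical and only illustrates the indexing, it is not part of the library.

```julia
# Hypothetical helper (not part of Mill.jl): given the bag ranges of the inner
# level and of the outer level, count how many inner-level instances end up
# in each outer-level bag.
function instances_per_outer_bag(inner_bags::Vector{UnitRange{Int}},
                                 outer_bags::Vector{UnitRange{Int}})
    # each outer bag is a range of inner bags; sum the sizes of those inner bags
    [sum(length, inner_bags[ob]; init = 0) for ob in outer_bags]
end

inner = [1:2, 3:4, 5:5, 6:7, 8:10]   # 5 inner bags over 10 instances
outer = [1:1, 2:3, 4:5]              # 3 outer bags over the 5 inner bags

instances_per_outer_bag(inner, outer)  # returns [2, 3, 5]
```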

Here is one example of a model appropriate for this hierarchy:

julia> using Flux: Dense, Chain, relu

julia> m = BagModel(
               BagModel(
                   ArrayModel(Dense(4, 3, relu)),
                   meanmax_aggregation(3),
                   ArrayModel(Dense(7, 3, relu))),
               meanmax_aggregation(3),
               ArrayModel(Chain(Dense(7, 3, relu), Dense(3, 2))))
BagModel … ↦ ⟨SegmentedMean(3), SegmentedMax(3)⟩ ↦ ArrayModel(Chain(Dense(7, 3, relu), Dense(3, 2)))
  └── BagModel … ↦ ⟨SegmentedMean(3), SegmentedMax(3)⟩ ↦ ArrayModel(Dense(7, 3, relu))
        └── ArrayModel(Dense(4, 3, relu))

The model can be applied directly to obtain a result:

julia> m(ds)
2×3 ArrayNode{Array{Float32,2},Nothing}:
 0.36179775  0.5450924   0.5656006
 0.42562664  0.58594716  0.6534422

Here we again make use of the property that applying the instance model im always yields a vector representation of each instance (here, one 3-dimensional column for each of the five inner-level bags), regardless of how complex im and Mill.data(ds) are:

julia> m.im(Mill.data(ds))
3×5 ArrayNode{Array{Float32,2},Nothing}:
 0.26454517  0.24892195   0.0         0.47699857  0.22300115
 0.0         0.0          0.24698043  0.0         0.0
 0.0         0.122486025  0.15422387  0.0         0.06687504
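
The aggregation of the outer BagModel then pools these five columns into one vector per outer-level bag, following the outer bag ranges [1:1, 2:3, 4:5]. The plain-Julia sketch below illustrates the mean-and-max idea on a small stand-in matrix; the helper `meanmax_pool` is hypothetical and is not Mill.jl's implementation. In particular, the Dense(7, 3) layer that follows meanmax_aggregation(3) shows that Mill's aggregation produces one more row than the 2 × 3 computed here, and that extra row is not reproduced.

```julia
using Statistics: mean

# Hypothetical sketch of segmented mean-and-max pooling (not Mill.jl's code).
# X has one column per instance; bags lists the columns of each bag.
# Assumes every bag is non-empty.
function meanmax_pool(X::AbstractMatrix, bags::Vector{UnitRange{Int}})
    d = size(X, 1)
    out = zeros(eltype(X), 2d, length(bags))
    for (j, b) in enumerate(bags)
        out[1:d, j] .= vec(mean(X[:, b]; dims = 2))         # per-feature mean over the bag
        out[d+1:2d, j] .= vec(maximum(X[:, b]; dims = 2))   # per-feature max over the bag
    end
    out
end

X = Float32[1 2 3 4 5;
            0 1 0 1 0;
            5 4 3 2 1]          # stand-in for the 3×5 instance embeddings
bags = [1:1, 2:3, 4:5]          # the outer-level bags of ds

P = meanmax_pool(X, bags)       # 6×3 matrix, one column per outer-level bag
```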

In a final example, we demonstrate a complex model consisting of all the node types introduced so far:

julia> ds = BagNode(ProductNode((BagNode(ArrayNode(randn(4, 10)),
                                         [1:2, 3:4, 5:5, 6:7, 8:10]),
                                 ArrayNode(randn(3, 5)),
                                 BagNode(BagNode(ArrayNode(randn(2, 30)),
                                                 [i:i+1 for i in 1:2:30]),
                                         [1:3, 4:6, 7:9, 10:12, 13:15]),
                                 ArrayNode(randn(2, 5)))),
                    [1:1, 2:3, 4:5])
BagNode with 3 obs
  └── ProductNode with 5 obs
        ├── BagNode with 5 obs
        │     ⋮
        ├── ArrayNode(3×5 Array with Float64 elements) with 5 obs
        ⋮
        └── ArrayNode(2×5 Array with Float64 elements) with 5 obs

Instead of defining a model manually, we make use of Model Reflection, another Mill.jl functionality, which simplifies model creation:

julia> m = reflectinmodel(ds)
BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
  └── ProductModel … ↦ ArrayModel(Dense(40, 10))
        ├── BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
        │     ⋮
        ├── ArrayModel(Dense(3, 10))
        ⋮
        └── ArrayModel(Dense(2, 10))

julia> m(ds)
10×3 ArrayNode{Array{Float32,2},Nothing}:
 -1.065589      -2.4036987   -1.2661439
  0.043658443    0.2956295    0.37168837
 -1.0958911     -1.4858038   -0.87121737
  1.1850482      1.0890555    1.1080894
  0.0103996815  -0.6553936   -0.5043688
 -0.1944636     -0.25665137  -0.114986196
  0.7850331      0.87222743   1.4043158
 -0.32910517     0.3159621   -0.27999502
 -0.12815368     1.2642355    0.332956
  0.010332675    0.20051233   0.088707946
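
The reflected model also shows how the ProductModel combines its children: each of the four children is mapped to a 10-dimensional vector, and the following Dense(40, 10) layer consumes their concatenation (4 × 10 = 40). The plain-Julia sketch below illustrates only that concatenation step, with random matrices standing in for the child embeddings; it also checks that all children describe the same observations, as the four children of the ProductNode above (5 observations each) do. The helper `product_concat` is hypothetical, not Mill.jl's code.

```julia
# Hypothetical sketch (not Mill.jl's code) of what a product model does
# before its final layer: stack the children's embeddings vertically.
function product_concat(children::AbstractMatrix...)
    n = size(first(children), 2)
    # every child must describe the same observations (same number of columns)
    all(c -> size(c, 2) == n, children) ||
        throw(DimensionMismatch("children have different numbers of observations"))
    reduce(vcat, children)
end

# stand-ins for the four 10×5 child embeddings of the ProductNode above
children = [randn(Float32, 10, 5) for _ in 1:4]

Z = product_concat(children...)   # 40×5, matching the input size of Dense(40, 10)
```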

Node conveniences

To make the handling of data and model hierarchies easier, Mill.jl provides several tools. Let's set up some data:

julia> AN = ArrayNode(Float32.([1 2 3 4; 5 6 7 8]))
2×4 ArrayNode{Array{Float32,2},Nothing}:
 1.0  2.0  3.0  4.0
 5.0  6.0  7.0  8.0

julia> AM = reflectinmodel(AN)
ArrayModel(Dense(2, 10))

julia> BN = BagNode(AN, [1:1, 2:3, 4:4])
BagNode with 3 obs
  └── ArrayNode(2×4 Array with Float32 elements) with 4 obs

julia> BM = reflectinmodel(BN)
BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
  └── ArrayModel(Dense(2, 10))

julia> PN = ProductNode((a=ArrayNode(Float32.([1 2 3; 4 5 6])), b=BN))
ProductNode with 3 obs
  ├── a: ArrayNode(2×3 Array with Float32 elements) with 3 obs
  └── b: BagNode with 3 obs
           └── ArrayNode(2×4 Array with Float32 elements) with 4 obs

julia> PM = reflectinmodel(PN)
ProductModel … ↦ ArrayModel(Dense(20, 10))
  ├── a: ArrayModel(Dense(2, 10))
  └── b: BagModel … ↦ ⟨SegmentedMean(10), SegmentedMax(10)⟩ ↦ ArrayModel(Dense(21, 10))
           └── ArrayModel(Dense(2, 10))

Function: nobs

The nobs function from StatsBase.jl returns the number of samples from the point of view of the current level. This number usually increases as we go down the tree when BagNodes are involved, because each bag may contain more than one instance.

julia> using StatsBase: nobs

julia> nobs(AN)
4

julia> nobs(BN)
3

julia> nobs(PN)
3
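
The jump from 3 observations at the bag level to 4 at the level below can be read off the bag ranges of BN. The plain-Julia sketch below uses two hypothetical helpers (not part of Mill.jl) and assumes, as in BN, that the bags are contiguous, non-overlapping, and together cover every instance.

```julia
# Hypothetical helpers (not part of Mill.jl) mimicking nobs on the two levels
# of a bag hierarchy described only by its bag ranges.
bag_level_nobs(bags::Vector{UnitRange{Int}}) = length(bags)                     # one obs per bag
instance_level_nobs(bags::Vector{UnitRange{Int}}) = sum(length, bags; init = 0) # instances covered

bags = [1:1, 2:3, 4:4]      # the bags of BN

bag_level_nobs(bags)        # returns 3, like nobs(BN)
instance_level_nobs(bags)   # returns 4, like nobs(AN)
```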

Indexing and Slicing

Indexing in Mill.jl operates on the level of observations:

julia> AN[1]
2×1 ArrayNode{Array{Float32,2},Nothing}:
 1.0
 5.0

julia> nobs(ans)
1

julia> BN[2]
BagNode with 1 obs
  └── ArrayNode(2×2 Array with Float32 elements) with 2 obs

julia> nobs(ans)
1

julia> PN[3]
ProductNode with 1 obs
  ├── a: ArrayNode(2×1 Array with Float32 elements) with 1 obs
  └── b: BagNode with 1 obs
           └── ArrayNode(2×1 Array with Float32 elements) with 1 obs

julia> nobs(ans)
1

julia> AN[[1, 4]]
2×2 ArrayNode{Array{Float32,2},Nothing}:
 1.0  4.0
 5.0  8.0

julia> nobs(ans)
2

julia> BN[1:2]
BagNode with 2 obs
  └── ArrayNode(2×3 Array with Float32 elements) with 3 obs

julia> nobs(ans)
2

julia> PN[[2, 3]]
ProductNode with 2 obs
  ├── a: ArrayNode(2×2 Array with Float32 elements) with 2 obs
  └── b: BagNode with 2 obs
           └── ArrayNode(2×3 Array with Float32 elements) with 3 obs

julia> nobs(ans)
2

julia> PN[Int[]]
ProductNode with 0 obs
  ├── a: ArrayNode(2×0 Array with Float32 elements) with 0 obs
  └── b: BagNode with 0 obs
           └── ArrayNode(2×0 Array with Float32 elements) with 0 obs

julia> nobs(ans)
0

This may be useful for creating minibatches and their permutations.
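
One way to produce such minibatches is to shuffle the observation indices and split them into chunks; each chunk can then be used as an index vector, for example PN[idx]. The sketch below uses only Julia's standard library; the helper `minibatch_indices` is hypothetical and is not a Mill.jl function.

```julia
using Random

# Hypothetical helper (not part of Mill.jl): shuffle the observation indices
# 1:n and split them into index vectors of at most batchsize elements.
function minibatch_indices(n::Integer, batchsize::Integer; rng = Random.default_rng())
    perm = randperm(rng, n)
    [collect(b) for b in Iterators.partition(perm, batchsize)]
end

batches = minibatch_indices(10, 4; rng = MersenneTwister(0))
# three index vectors of lengths 4, 4 and 2; each could be used as ds[idx]
```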

Note that apart from the (perhaps apparent) recursive effect, this operation requires other implicit actions, such as properly recomputing bag indices:

julia> BN.bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:3, 4:4])

julia> BN[[1, 3]].bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:2])
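
The bookkeeping behind this recomputation can be sketched in plain Julia: keep the selected bags, collect the instances they cover, and renumber the ranges so that they index into the new, smaller set of instances. The helper `select_bags` below is hypothetical and only mirrors the results shown above; it is not Mill.jl's implementation.

```julia
# Hypothetical sketch (not Mill.jl's code) of the bookkeeping behind BN[idx]:
# returns the renumbered bag ranges and the original indices of the kept instances.
function select_bags(bags::Vector{UnitRange{Int}}, idx::AbstractVector{<:Integer})
    new_bags = UnitRange{Int}[]
    instances = Int[]
    for i in idx
        b = bags[i]
        start = length(instances) + 1
        append!(instances, b)                         # instances of this bag, in order
        push!(new_bags, start:start + length(b) - 1)  # the same bag, renumbered
    end
    new_bags, instances
end

new_bags, instances = select_bags([1:1, 2:3, 4:4], [1, 3])
# new_bags == [1:1, 2:2] (matching BN[[1, 3]].bags), instances == [1, 4]
```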

Function: catobs

The catobs function concatenates several datasets (trees):

julia> catobs(AN[1], AN[4])
2×2 ArrayNode{Array{Float32,2},Nothing}:
 1.0  4.0
 5.0  8.0

julia> catobs(BN[3], BN[[2, 1]])
BagNode with 3 obs
  └── ArrayNode(2×4 Array with Float32 elements) with 4 obs

julia> catobs(PN[[1, 2]], PN[3:3]) == PN
true

Again, the effect is recursive and everything is appropriately recomputed:

julia> BN.bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:3, 4:4])

julia> catobs(BN[3], BN[[1]]).bags
AlignedBags{Int64}(UnitRange{Int64}[1:1, 2:2])
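
For bag indices, this recomputation amounts to shifting the ranges of each later part by the number of instances that precede it. The plain-Julia sketch below reproduces the two bag-level results above; the helper `cat_bags` is hypothetical (not Mill.jl's code) and assumes that the bags of each part cover all of that part's instances contiguously.

```julia
# Hypothetical sketch (not Mill.jl's code) of the bag bookkeeping behind catobs:
# ranges of each later part are shifted by the number of instances before it.
function cat_bags(parts::Vector{UnitRange{Int}}...)
    result = UnitRange{Int}[]
    offset = 0
    for bags in parts
        append!(result, [b .+ offset for b in bags])
        offset += sum(length, bags; init = 0)   # instances contributed by this part
    end
    result
end

cat_bags([1:1], [1:1])          # returns [1:1, 2:2], as catobs(BN[3], BN[[1]]).bags
cat_bags([1:1], [1:2, 3:3])     # returns [1:1, 2:3, 4:4], the 4 instances of catobs(BN[3], BN[[2, 1]])
```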

More tips

For more tips on handling datasets and models, see External tools.