Adding custom nodes

Mill.jl data nodes are lightweight wrappers around data, such as Array, DataFrame, and others. When implementing custom nodes, we recommend equipping them with the following functionality so that they fit better into the Mill.jl environment:

  • allow nesting (if needed)
  • implement getindex to obtain subsets of observations. To help with this, Mill.jl defines a Mill.subset function for common datatypes.
  • allow concatenation of nodes with catobs. Optionally, also implement reduce(catobs, ...) to avoid excessive compilation when the number of arguments varies a lot
  • define a specialized method for nobs
  • register the custom node with HierarchicalUtils.jl to obtain pretty printing, iterators and other functionality
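Taken together, the first points of the checklist form a small contract: counting observations, selecting a subset of them, and concatenating nodes must agree with each other. The plain-Julia sketch below illustrates that contract on a hypothetical `ToyNode` wrapper, which is not a Mill.jl type. It uses `length`, `getindex`, and `vcat` as stand-ins for Mill.jl's `nobs`, indexing, and `catobs`, so it runs without Mill.jl and shows only the invariants, not Mill.jl's API.

```julia
# Hypothetical wrapper used only to illustrate the contract; not a Mill.jl type.
struct ToyNode
    data::Vector{String}
end

# Stand-in for nobs: one observation per element of the wrapped vector
Base.length(x::ToyNode) = length(x.data)

# Stand-in for indexing: select a subset of observations by a vector or range
Base.getindex(x::ToyNode, i::AbstractVector{<:Integer}) = ToyNode(x.data[i])

# Stand-in for catobs: concatenate the observations of all nodes in order
Base.vcat(xs::ToyNode...) = ToyNode(reduce(vcat, [x.data for x in xs]))

x = ToyNode(["a", "b", "c", "d"])
parts = [x[1:1], x[2:4]]

# Concatenation adds up the numbers of observations
@assert length(vcat(parts...)) == sum(length, parts)
# Indexing returns exactly as many observations as were requested
@assert length(x[[1, 3]]) == 2
# Splitting and re-concatenating restores the original observations
@assert vcat(parts...).data == x.data
```

A custom Mill.jl node such as the one below should satisfy the same three properties with nobs, getindex, and catobs.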

Unix path example

Let's define one custom node type for representing Unix pathnames and one custom model type for processing them. We'll start by defining the structure holding the pathnames:

using Mill

struct PathNode{S <: AbstractString, C} <: AbstractNode
    data::Vector{S}
    metadata::C
end

PathNode(data::Vector{S}) where {S <: AbstractString} = PathNode(data, nothing)

We will support nobs:

import StatsBase: nobs
Base.ndims(x::PathNode) = Colon()
nobs(a::PathNode) = length(a.data)

concatenation:

function Base.reduce(::typeof(catobs), as::Vector{T}) where {T <: PathNode}
    # concatenate the pathnames and the metadata of all nodes, preserving their order
    PathNode(reduce(vcat, Mill.data.(as)), reduce(catobs, Mill.metadata.(as)))
end

and indexing:

function Base.getindex(x::PathNode, i::Mill.VecOrRange{<:Int})
    # select the observations i from both the pathnames and the metadata
    PathNode(Mill.subset(Mill.data(x), i), Mill.subset(Mill.metadata(x), i))
end

The last touch is to add the definitions needed by HierarchicalUtils.jl:

import HierarchicalUtils
HierarchicalUtils.NodeType(::Type{<:PathNode}) = HierarchicalUtils.LeafNode()
HierarchicalUtils.noderepr(n::PathNode) = "PathNode ($(nobs(n)) obs)"

Now, we are ready to create the first PathNode:

julia> ds = PathNode(["/etc/passwd", "/home/tonda/.bashrc"])
PathNode (2 obs)

Similarly, we define a model node type that serves as the counterpart for processing the data:

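import Flux

struct PathModel{T, F} <: AbstractMillModel
    m::T          # Mill.jl model applied to the converted data
    path2mill::F  # function converting pathnames to a Mill.jl structure
end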

Flux.@functor PathModel

Note that part of the model node is a function that converts pathname strings to a Mill.jl structure. For simplicity, we use a trivial NGramMatrix representation in this example and define path2mill as follows:

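# each path becomes one bag; its instances are the path components between slashes,
# each represented by its character 3-grams
function path2mill(s::String)
    ss = String.(split(s, "/"))
    BagNode(ArrayNode(Mill.NGramMatrix(ss, 3)), AlignedBags([1:length(ss)]))
end

# several paths are converted by concatenating their bags
path2mill(ss::Vector{S}) where {S <: AbstractString} = reduce(catobs, map(path2mill, ss))
path2mill(ds::PathNode) = path2mill(ds.data)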
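To see what path2mill produces, it helps to look at the tokenization on its own. The plain-Julia sketch below (no Mill.jl required) splits the two example paths the same way path2mill does and computes the consecutive instance ranges that, as we understand it, the bags occupy after concatenation with catobs. The helpers `path_tokens` and `bag_ranges` are hypothetical and not part of Mill.jl. Note that the leading `/` produces an empty first component, which therefore becomes an instance of its own.

```julia
# Hypothetical helper mirroring the tokenization in path2mill (not part of Mill.jl)
path_tokens(s::AbstractString) = String.(split(s, "/"))

# Hypothetical helper: consecutive instance ranges of bags after concatenation,
# given the number of instances in each bag
function bag_ranges(lengths::AbstractVector{<:Integer})
    ranges = UnitRange{Int}[]
    start = 1
    for n in lengths
        push!(ranges, start:(start + n - 1))
        start += n
    end
    return ranges
end

paths = ["/etc/passwd", "/home/tonda/.bashrc"]
tokens = map(path_tokens, paths)
# tokens == [["", "etc", "passwd"], ["", "home", "tonda", ".bashrc"]]
ranges = bag_ranges(length.(tokens))
# ranges == [1:3, 4:7]
```

Each path therefore contributes one bag, and the second bag's range is shifted by the three instances of the first.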

Now we define how the model node is applied:

(m::PathModel)(x::PathNode) = m.m(m.path2mill(x))

Again, we define everything needed by HierarchicalUtils.jl:

HierarchicalUtils.NodeType(::Type{<:PathModel}) = HierarchicalUtils.LeafNode()
HierarchicalUtils.noderepr(n::PathModel) = "PathModel"

Let's test that everything works:

julia> pm = PathModel(reflectinmodel(path2mill(ds)), path2mill)
PathModel

julia> pm(ds).data
10×2 Array{Float32,2}:
 -0.47714    -0.572013
 -0.0310436  -0.0704895
  0.391016    0.422294
  0.41583     0.480285
  0.126027    0.187971
 -0.30234    -0.398954
 -0.171723   -0.252697
 -0.483189   -0.674319
  0.167859    0.307745
  0.550983    0.667438
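The output has one column per path: path2mill turns each pathname into one bag, the bag's instances are embedded, and an aggregation collapses every bag into a single vector. The plain-Julia sketch below illustrates only that last step, using random stand-in embeddings and a plain mean over each bag's range. The aggregation that reflectinmodel actually builds is configurable and may differ, so treat this as an illustration of the shapes rather than of Mill.jl's internals; `instance_embeddings` and `aggregate` are hypothetical names.

```julia
using Statistics: mean

# Hypothetical stand-in for the instance embeddings: one 10-dimensional column per
# path component (3 components of the first path + 4 of the second = 7 instances)
instance_embeddings = randn(Float32, 10, 7)
bags = [1:3, 4:7]

# Aggregate each bag into one column by averaging the columns of its instances
aggregate(X, bags) = reduce(hcat, [vec(mean(X[:, b]; dims = 2)) for b in bags])

out = aggregate(instance_embeddings, bags)
@assert size(out) == (10, 2)   # one column per path, matching the shape of pm(ds).data
```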

The final touch is to overload reflectinmodel as follows:

function Mill.reflectinmodel(ds::PathNode, args...)
    # build a model for the converted data and wrap it together with the converter
    pm = reflectinmodel(path2mill(ds), args...)
    PathModel(pm, path2mill)
end

which makes constructing the model even easier:

julia> pm = reflectinmodel(ds)
PathModel

julia> pm(ds).data
10×2 Array{Float32,2}:
  0.66415     0.649883
  0.19207     0.208901
  0.0267312  -0.00317923
  0.0642215   0.0501003
 -0.105406   -0.167113
 -0.476674   -0.501406
 -0.281625   -0.395572
 -0.60416    -0.703313
 -0.0924248  -0.158902
  0.596379    0.542492