Adding custom nodes
Mill.jl data nodes are lightweight wrappers around data, such as Array, DataFrame, and others. When implementing custom nodes, it is recommended to equip them with the following functionality so that they fit better into the Mill.jl environment:
- allow nesting (if needed)
- implement getindex to obtain subsets of observations. For this purpose, Mill.jl defines a Mill.subset function for common datatypes, which can be used.
- allow concatenation of nodes with catobs. Optionally, implement reduce(catobs, ...) as well, to avoid excessive compilation when the number of arguments varies a lot.
- define a specialized method for nobs
- register the custom node with HierarchicalUtils.jl to obtain pretty printing, iterators and other functionality
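The remark about excessive compilation concerns varargs methods: Julia may compile a separate specialization of a method such as catobs(as...) for each distinct number of arguments, whereas a method taking a Vector covers every count with a single signature. The sketch below shows one common pattern, a thin varargs wrapper that forwards to the vector method. concat_obs is a hypothetical stand-in operating on plain vectors of strings, not a Mill.jl function.

```julia
# Hypothetical stand-ins (not Mill.jl functions) illustrating the dispatch pattern.

# Vector method: one signature covers any number of inputs.
concat_obs(as::Vector{Vector{String}}) = reduce(vcat, as)

# Varargs method: Julia may compile a new specialization for each distinct
# argument count, so keep it a thin wrapper that forwards to the vector method.
concat_obs(as::Vector{String}...) = concat_obs(collect(as))

concat_obs(["/etc/passwd"], ["/home/tonda/.bashrc", "/tmp"])
concat_obs([["/etc/passwd"], ["/home/tonda/.bashrc", "/tmp"]])
```

Both calls return the same three-element vector; only the thin wrapper depends on how many arguments are passed.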
Unix path example
Let's define one custom node type for representing Unix pathnames and one custom model type for processing them. We'll start by defining the structure holding the pathnames:
using Mill

# one pathname per observation, plus optional metadata (`nothing` by default)
struct PathNode{S <: AbstractString, C} <: AbstractNode
    data::Vector{S}
    metadata::C
end

PathNode(data::Vector{S}) where {S <: AbstractString} = PathNode(data, nothing)
We will support nobs:
import StatsBase: nobs
Base.ndims(x::PathNode) = Colon()
nobs(a::PathNode) = length(a.data)
concatenation:
function Base.reduce(::typeof(catobs), as::Vector{T}) where {T <: PathNode}
    # concatenate the pathnames and, separately, the metadata of all nodes
    PathNode(reduce(vcat, Mill.data.(as)), reduce(catobs, Mill.metadata.(as)))
end
and indexing:
function Base.getindex(x::PathNode, i::Mill.VecOrRange{<:Int})
    PathNode(Mill.subset(Mill.data(x), i), Mill.subset(Mill.metadata(x), i))
end
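Indexing and concatenation should keep data and metadata aligned and satisfy two simple invariants: the number of observations of a concatenation is the sum of its parts, and indexing with i yields length(i) observations. The sketch below checks these invariants on MiniPath, a hypothetical stand-in for PathNode written in plain Julia so that it runs without Mill.jl; the helpers subset_meta and cat_meta (also hypothetical) play the roles that Mill.subset and catobs play for the metadata in the code above.

```julia
# Hypothetical stand-in mirroring PathNode, with no Mill.jl dependency.
struct MiniPath{M}
    data::Vector{String}
    metadata::M   # `nothing` or one entry per observation
end
MiniPath(data::Vector{String}) = MiniPath(data, nothing)

n_obs(x::MiniPath) = length(x.data)

# Subset metadata in lockstep with the data; `nothing` passes through.
subset_meta(::Nothing, i) = nothing
subset_meta(m::AbstractVector, i) = m[i]
Base.getindex(x::MiniPath, i::AbstractVector{<:Integer}) =
    MiniPath(x.data[i], subset_meta(x.metadata, i))

# Concatenate metadata: all `nothing` stays `nothing`, otherwise vcat
# (assumes every node carries the same kind of metadata).
cat_meta(ms::Vector{Nothing}) = nothing
cat_meta(ms::Vector) = reduce(vcat, ms)
concat_paths(as::Vector{<:MiniPath}) =
    MiniPath(reduce(vcat, [a.data for a in as]), cat_meta([a.metadata for a in as]))

a = MiniPath(["/etc/passwd", "/home/tonda/.bashrc"])
b = MiniPath(["/tmp/x"])
c = concat_paths([a, b])
n_obs(c) == n_obs(a) + n_obs(b)   # true
n_obs(c[[1, 3]]) == 2             # true
```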
The last touch is to add the definitions needed by HierarchicalUtils.jl:
import HierarchicalUtils
HierarchicalUtils.NodeType(::Type{<:PathNode}) = HierarchicalUtils.LeafNode()
HierarchicalUtils.noderepr(n::PathNode) = "PathNode ($(nobs(n)) obs)"
Now we are ready to create the first PathNode:
julia> ds = PathNode(["/etc/passwd", "/home/tonda/.bashrc"])
PathNode (2 obs)
Similarly, we define a model node type, the counterpart that processes PathNode data:
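import Flux

struct PathModel{T, F} <: AbstractMillModel
    m::T            # Mill.jl model applied to the converted data
    path2mill::F    # converts pathnames to a Mill.jl structure
end

Flux.@functor PathModel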
Note that part of the model node is a function that converts pathname strings to a Mill.jl structure. For simplicity, we use a trivial NGramMatrix representation in this example and define path2mill as follows:
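function path2mill(s::AbstractString)
    # split the pathname into components; the leading "/" yields an empty first component
    ss = String.(split(s, "/"))
    # one bag holding all components, each represented by its 3-grams
    BagNode(ArrayNode(Mill.NGramMatrix(ss, 3)), AlignedBags([1:length(ss)]))
end

# a vector of pathnames becomes the concatenation of the per-path bags
path2mill(ss::Vector{S}) where {S <: AbstractString} = reduce(catobs, map(path2mill, ss))
path2mill(ds::PathNode) = path2mill(ds.data)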
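Because path2mill on a vector concatenates one single-bag BagNode per path with catobs, the resulting bags are consecutive ranges over all path components. The plain-Julia sketch below computes only that bag structure, without the n-gram features, for the two paths used in this example; path_bags is a hypothetical helper, not a Mill.jl function.

```julia
# Hypothetical helper, not part of Mill.jl: mimics the bag layout that
# concatenating the per-path bags produces.
function path_bags(paths::Vector{<:AbstractString})
    # one bag per path; each path component is one instance of that bag
    components = [String.(split(p, "/")) for p in paths]
    bags = UnitRange{Int}[]
    offset = 0
    for c in components
        # shift this bag's range past all instances of the preceding bags
        push!(bags, (offset + 1):(offset + length(c)))
        offset += length(c)
    end
    return reduce(vcat, components), bags
end

instances, bags = path_bags(["/etc/passwd", "/home/tonda/.bashrc"])
# instances == ["", "etc", "passwd", "", "home", "tonda", ".bashrc"]
# bags == [1:3, 4:7]
```

Each path therefore contributes one bag, and the empty first component produced by the leading "/" is an instance of its own.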
Now we define how the model node is applied:
(m::PathModel)(x::PathNode) = m.m(m.path2mill(x))
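Keeping the conversion inside the model means the data node stays a plain vector of strings, and the Mill.jl structure is built only when the model is called. The same callable-struct pattern is sketched below in plain Julia; ConvertThenApply and count_components are hypothetical stand-ins for PathModel and path2mill, and sum stands in for the inner model.

```julia
# Hypothetical stand-in for PathModel (no Mill.jl or Flux.jl needed): a callable
# struct that converts raw input first and then applies the wrapped model.
struct ConvertThenApply{T, F}
    m::T        # the wrapped "model"
    convert::F  # the conversion function, playing the role of path2mill
end

# Calling the struct mirrors (m::PathModel)(x) = m.m(m.path2mill(x)).
(c::ConvertThenApply)(x) = c.m(c.convert(x))

# Toy pieces (assumptions): count the components of each path, then sum the counts.
count_components(paths) = [length(split(p, "/")) for p in paths]
model = ConvertThenApply(sum, count_components)

model(["/etc/passwd", "/home/tonda/.bashrc"])  # 3 + 4 = 7
```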
And again, we define everything needed by HierarchicalUtils.jl:
HierarchicalUtils.NodeType(::Type{<:PathModel}) = HierarchicalUtils.LeafNode()
HierarchicalUtils.noderepr(n::PathModel) = "PathModel"
Let's test that everything works:
julia> pm = PathModel(reflectinmodel(path2mill(ds)), path2mill)
PathModel
julia> pm(ds).data
10×2 Array{Float32,2}:
-0.47714 -0.572013
-0.0310436 -0.0704895
0.391016 0.422294
0.41583 0.480285
0.126027 0.187971
-0.30234 -0.398954
-0.171723 -0.252697
-0.483189 -0.674319
0.167859 0.307745
0.550983 0.667438
The final touch is to overload reflectinmodel as follows:
function Mill.reflectinmodel(ds::PathNode, args...)
    # build the inner model from the converted data, passing extra arguments through
    pm = reflectinmodel(path2mill(ds), args...)
    PathModel(pm, path2mill)
end
which makes things even easier:
julia> pm = reflectinmodel(ds)
PathModel
julia> pm(ds).data
10×2 Array{Float32,2}:
0.66415 0.649883
0.19207 0.208901
0.0267312 -0.00317923
0.0642215 0.0501003
-0.105406 -0.167113
-0.476674 -0.501406
-0.281625 -0.395572
-0.60416 -0.703313
-0.0924248 -0.158902
0.596379 0.542492