GNNs in 16 lines

As has been mentioned in Šimon Mandlík and Tomáš Pevný (2020), multiple instance learning is an essential piece for implementing message passing inference over graphs, the main concept behind spatial Graph Neural Networks (GNNs). It is straightforward and quick to achieve this with Mill.jl. We begin with some dependencies:

using Mill, Flux, LightGraphs, Statistics

Let's assume a graph g, in this case created by the barabasi_albert function from LightGraphs.jl:

julia> g = barabasi_albert(10, 3, 2)
{10, 14} undirected simple Int64 graph

Furthermore, let's assume that each vertex is described by seven features stored in a matrix X:

julia> X = ArrayNode(randn(Float32, 7, 10))
7×10 ArrayNode{Array{Float32,2},Nothing}:
  0.029847916  -0.1444097   -0.77922016   …   0.6306408    -0.29290652
 -0.4743148     0.8359631    0.7541415        0.106086716   1.6031942
  1.2026051    -0.20845382  -0.65242165      -0.7265922    -0.4708613
 -0.6189464    -0.19995673  -1.1439679        0.35289836    1.344933
 -0.69497037   -1.4790846   -0.87187344      -2.1563978     1.3466836
 -2.179117      1.7636526    0.118887305  …  -0.15667026   -0.97331053
 -1.3534855     0.7487757   -2.0543125       -0.9947581    -0.67006004

We use ScatteredBags from Mill.jl to encode the neighbors of each vertex. In other words, each vertex is described by a bag of its neighbors. This information is conveniently stored in the fadjlist field of g, so the bags can be constructed as:

julia> b = ScatteredBags(g.fadjlist)
ScatteredBags{Int64}([[4, 5, 6, 8, 10], Int64[], [4, 5], [1, 3, 6, 7], [1, 3, 8], [1, 4, 7, 9], [4, 6], [1, 5, 9], [6, 8, 10], [1, 9]])
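To make the bag encoding concrete, here is a plain-Julia sketch (no Mill.jl needed) that rebuilds a forward adjacency list from an undirected edge list. The edge list is read off the fadjlist printed above, and the helper adjacency_list is hypothetical; it only mimics what g.fadjlist stores. The sketch also checks two facts visible in the output: the bag lengths sum to twice the number of edges (every undirected edge appears in two bags), and vertex 2 has an empty bag.

```julia
# Hypothetical helper: build a sorted forward adjacency list for an undirected graph,
# mimicking what LightGraphs stores in g.fadjlist.
function adjacency_list(nv::Int, edges::Vector{Tuple{Int,Int}})
    adj = [Int[] for _ in 1:nv]
    for (u, v) in edges
        push!(adj[u], v)   # an undirected edge is stored in both endpoints' lists
        push!(adj[v], u)
    end
    return sort!.(adj)
end

# The 14 edges of the graph printed above, read off from its fadjlist
edges = [(1, 4), (1, 5), (1, 6), (1, 8), (1, 10), (3, 4), (3, 5), (4, 6), (4, 7),
         (5, 8), (6, 7), (6, 9), (8, 9), (9, 10)]

bags = adjacency_list(10, edges)   # bags[i] = indices of the neighbors of vertex i
println(bags)
println(sum(length, bags))         # 2 * 14 = 28
println(isempty(bags[2]))          # vertex 2 has no neighbors
```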

Finally, we create two models. The first one, called lift, pre-processes the description of vertices into a latent space for message passing; the second one, which we call mp, realizes the message passing itself:

julia> lift = reflectinmodel(X, d -> Dense(d, 10), d -> mean_aggregation(d))
ArrayModel(Dense(7, 10))

julia> U = lift(X)
10×10 ArrayNode{Array{Float32,2},Nothing}:
  0.22535409   0.3822525    0.2969856   …   0.775619    -1.243016
  0.5277017   -0.2792255   -1.0729944      -0.33189312  -1.247246
  1.6393183   -0.18840593   0.5840895       0.24021442  -0.5072842
 -0.30440924  -0.46194744  -0.20234056     -1.0374451    1.0918131
 -1.5296241    0.22416803  -0.8389578      -1.196345     0.34388122
 -1.4859715   -0.06790199   0.57404274  …  -0.8074341    2.2101219
  0.61812794  -1.5480729   -0.6891072      -0.546315     0.19719586
 -1.3225273    0.00831057  -0.5432095      -1.2234465    1.0026237
 -1.4688778    1.3808706   -0.5251934      -0.08397776  -0.7011221
 -0.65222025  -0.24466255  -0.31357002     -0.16080739  -0.17337486
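Note that lift does not exchange any information between vertices: a Dense layer acts on every column (vertex) of the ArrayNode separately. The following minimal plain-Julia sketch shows this, assuming a Dense layer with identity activation computes W * x .+ b; the random W and b are stand-ins for the layer's parameters, not values taken from lift.

```julia
using Random

Random.seed!(0)
X = randn(Float32, 7, 10)   # 7 features for each of 10 vertices (columns)
W = randn(Float32, 10, 7)   # weights of a Dense(7, 10)-like layer (identity activation)
b = randn(Float32, 10)      # its bias

U = W * X .+ b                                                     # all vertices at once
U_cols = reduce(hcat, [W * X[:, i] .+ b for i in 1:size(X, 2)])    # vertex by vertex

println(size(U))
println(U ≈ U_cols)         # each vertex is transformed independently of the others
```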

julia> mp = reflectinmodel(BagNode(U, b), d -> Dense(d, 10), d -> mean_aggregation(d))
BagModel … ↦ ⟨SegmentedMean(10)⟩ ↦ ArrayModel(Dense(11, 10))
  └── ArrayModel(Dense(10, 10))

Notice that BagNode(U, b) now encodes both the vertex features and the adjacency structure of the graph. This also means that one step of the message passing algorithm can be realized as:

julia> Y = mp(BagNode(U, b))
10×10 ArrayNode{Array{Float32,2},Nothing}:
 -0.035713222  0.0  -0.22519352   0.1356892    …   0.06734029   0.4295932
  0.5615496    0.0   0.7537555    0.074028336      0.2551962   -0.4399602
  0.98289657   0.0   1.7433251    0.14271578       0.2615756   -0.7452709
 -0.55371916   0.0   0.46866804  -1.5699606       -0.94492227  -2.1067574
  0.5289915    0.0   0.65951484   1.1315223        0.08326764   1.4815167
  0.1337585    0.0  -0.90799093   1.2921826    …   0.49060202   2.0067813
  0.7641726    0.0   1.5191854   -0.2869706        0.23503898  -1.3377743
  0.34246188   0.0   1.6722084   -0.7982003       -0.26235634  -1.7717428
  0.78609854   0.0   0.5489987    0.31208873       0.8744348   -0.5674417
 -0.6528163    0.0  -0.7301485   -0.5214443       -0.3157166   -0.17945614

and it is differentiable, which can be verified by executing:

julia> gradient(() -> sum(sin.(mp(BagNode(U, b)) |> Mill.data)), params(mp))
Grads(...)
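What mp computes can be sketched in plain Julia: the instance model transforms every vertex embedding, each vertex then averages the transformed embeddings of its neighbors, and the bag-level Dense maps the result to the new embedding. Two details are assumptions about Mill.jl defaults rather than facts stated here: that the aggregation appends log(bag size + 1) as an extra feature (which would explain the 11 inputs of Dense(11, 10)), and that an empty bag falls back to a learned vector initialized to zeros. Under these assumptions, and with zero-initialized biases, vertex 2 (which has no neighbors) receives an all-zero input, consistent with the zero column in Y above. The function mp_step and its parameter names are hypothetical.

```julia
using Statistics, Random

# Hypothetical plain-Julia version of one message-passing step, mirroring
# BagModel(Dense(10, 10)) ↦ SegmentedMean ↦ Dense(11, 10) with identity activations.
function mp_step(U, bags, Wi, bi, Wb, bb; ψ = zeros(Float32, size(Wi, 1)))
    H = Wi * U .+ bi                                  # instance model: transform every vertex
    cols = map(bags) do bag
        # mean over the neighbors; the default vector ψ if there are none
        agg = isempty(bag) ? ψ : vec(mean(H[:, bag], dims = 2))
        vcat(agg, Float32(log(length(bag) + 1)))      # assumed extra feature: log(bag size + 1)
    end
    A = reduce(hcat, cols)                            # (d + 1) × number of vertices
    return Wb * A .+ bb                               # bag model: new vertex embeddings
end

Random.seed!(0)
bags = [[4, 5, 6, 8, 10], Int[], [4, 5], [1, 3, 6, 7], [1, 3, 8],
        [1, 4, 7, 9], [4, 6], [1, 5, 9], [6, 8, 10], [1, 9]]
U = randn(Float32, 10, 10)
Wi, bi = randn(Float32, 10, 10), zeros(Float32, 10)
Wb, bb = randn(Float32, 10, 11), zeros(Float32, 10)   # zero biases, as in a freshly initialized Dense

Y = mp_step(U, bags, Wi, bi, Wb, bb)
println(size(Y))
println(all(iszero, Y[:, 2]))   # vertex 2 has an empty bag, so its input is all zeros
```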

Putting everything together, the whole GNN can be implemented in the following 16 lines:

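struct GNN{L, M, R}
    lift::L   # maps vertex features to the latent space
    mp::M     # one message-passing step (a BagModel)
    m::R      # readout network applied to the graph-level embedding
end

Flux.@functor GNN

function mpstep(m::GNN, U::ArrayNode, bags, n)
    n == 0 && return U                                  # no steps left
    mpstep(m, m.mp(BagNode(U, bags)), bags, n - 1)      # aggregate over neighbors, then recurse
end

function (m::GNN)(g, X, n)
    U = m.lift(X)                                       # pre-process vertex descriptions
    bags = Mill.ScatteredBags(g.fadjlist)               # bag of neighbors for every vertex
    o = mpstep(m, U, bags, n)                           # n rounds of message passing
    m.m(vcat(mean(Mill.data(o), dims = 2), maximum(Mill.data(o), dims = 2)))  # mean and max over vertices
end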
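The recursion in mpstep and the readout can be illustrated without Mill.jl. In the sketch below the hypothetical function mp_mean_step stands in for m.mp: it simply averages each vertex's neighbors and has no learnable parameters. The recursion is written as an equivalent loop, and the readout concatenates the mean and the maximum over vertices, so the graph-level vector has 2d entries however many vertices the graph has.

```julia
using Statistics

# Hypothetical stand-in for m.mp: mean of the neighbors' embeddings (zeros for an empty bag)
neighbor_mean(U, b) = isempty(b) ? zeros(eltype(U), size(U, 1)) : vec(mean(U[:, b], dims = 2))
mp_mean_step(U, bags) = reduce(hcat, [neighbor_mean(U, b) for b in bags])

# Iterative equivalent of the recursive mpstep: apply the step n times
function propagate(U, bags, n)
    for _ in 1:n
        U = mp_mean_step(U, bags)
    end
    return U
end

# Readout: mean and maximum over vertices (columns), stacked into a 2d × 1 matrix
readout(O) = vcat(mean(O, dims = 2), maximum(O, dims = 2))

bags = [[2, 3], [1], [1]]       # a small star graph: vertex 1 joined to vertices 2 and 3
U = Float32[1 2 3; 0 1 0]       # 2 features for each of 3 vertices
O = propagate(U, bags, 2)
println(O)
println(readout(O))
```

Two steps on this star graph give O = [1 2.5 2.5; 0 0.5 0.5], and the readout is the 4 × 1 matrix [2.0; 0.33333334; 2.5; 0.5].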

As is the case with the whole of Mill.jl, this graph neural network is properly integrated with the Flux.jl ecosystem and supports automatic differentiation:

zd = 10
f(d) = Chain(Dense(d, zd, relu), Dense(zd, zd))
agg(d) = meanmax_aggregation(d)
gnn = GNN(reflectinmodel(X, f, agg),
          BagModel(f(zd), agg(zd), f(2zd + 1)),
          f(2zd))
julia> gnn(g, X, 5)
10×1 Array{Float32,2}:
  0.07108084
 -0.24861374
 -0.023377487
  0.25094596
  0.45109645
 -0.027826611
 -0.06066554
 -0.29686767
 -0.35696217
 -0.18162829

julia> gradient(() -> gnn(g, X, 5) |> sum, params(gnn))
Grads(...)

The above implementation is surprisingly general, as it supports an arbitrarily rich description of vertices. For simplicity, we used only vectors in X; however, any Mill.jl hierarchy is applicable.

To put different weights on edges, one can use WeightedBagNode instead of BagNode.
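As a rough illustration of what edge weights change, the plain-Julia sketch below replaces the neighbor mean with a weighted mean, sum(w .* x) / sum(w). That a WeightedBagNode with mean aggregation normalizes by the sum of weights in exactly this way is an assumption here, and weighted_neighbor_mean is a hypothetical helper, not part of Mill.jl. With all weights equal to one it reduces to the ordinary mean.

```julia
# Hypothetical weighted neighbor aggregation: each vertex takes the weighted mean
# of its neighbors' embeddings, sum(w .* x) / sum(w).
function weighted_neighbor_mean(U, bags, weights)
    cols = map(zip(bags, weights)) do (b, w)
        isempty(b) && return zeros(eltype(U), size(U, 1))   # no neighbors: zero vector
        vec(U[:, b] * w) ./ sum(w)                          # weighted sum, normalized
    end
    return reduce(hcat, cols)
end

U = Float32[1 2 3; 0 1 0]                          # 2 features for each of 3 vertices
bags = [[2, 3], [1], [1]]                          # vertex 1 is joined to vertices 2 and 3
weights = [Float32[3, 1], Float32[1], Float32[1]]  # edge (1, 2) weighs three times edge (1, 3)

println(weighted_neighbor_mean(U, bags, weights))
```

Here vertex 1 gets (3 * [2, 1] + 1 * [3, 0]) / 4 = [2.25, 0.75], pulled toward its more heavily weighted neighbor 2, whereas unit weights would give the plain mean [2.5, 0.5].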