GNNs in 16 lines
As mentioned in Šimon Mandlík and Tomáš Pevný (2020), multiple instance learning is an essential building block for message-passing inference over graphs, the main concept behind spatial Graph Neural Networks (GNNs). Achieving this with Mill.jl is straightforward and quick. We begin with some dependencies:
using Mill, Flux, LightGraphs, Statistics
Let's assume a graph g, in this case created by the barabasi_albert function from LightGraphs.jl:
julia> g = barabasi_albert(10, 3, 2)
{10, 14} undirected simple Int64 graph
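Here {10, 14} denotes 10 vertices and 14 edges. As a quick sanity check (using the standard nv, ne and neighbors accessors from LightGraphs.jl), we can inspect the graph; the neighbors of vertex 1 agree with the adjacency list we will extract below:
julia> nv(g), ne(g)
(10, 14)
julia> neighbors(g, 1)
5-element Array{Int64,1}:
  4
  5
  6
  8
 10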
Furthermore, let's assume that each vertex is described by seven features stored in a matrix X:
julia> X = ArrayNode(randn(Float32, 7, 10))
7×10 ArrayNode{Array{Float32,2},Nothing}:
0.029847916 -0.1444097 -0.77922016 … 0.6306408 -0.29290652
-0.4743148 0.8359631 0.7541415 0.106086716 1.6031942
1.2026051 -0.20845382 -0.65242165 -0.7265922 -0.4708613
-0.6189464 -0.19995673 -1.1439679 0.35289836 1.344933
-0.69497037 -1.4790846 -0.87187344 -2.1563978 1.3466836
-2.179117 1.7636526 0.118887305 … -0.15667026 -0.97331053
-1.3534855 0.7487757 -2.0543125 -0.9947581 -0.67006004
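An ArrayNode is a thin wrapper around the underlying array, here carrying no metadata (hence the Nothing type parameter); the raw matrix can be recovered with Mill.data, which the readout below also relies on:
julia> Mill.data(X) isa Matrix{Float32}
true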
We use ScatteredBags from Mill.jl to encode the neighbors of each vertex. In other words, each vertex is described by a bag of its neighbors. This information is conveniently stored in the fadjlist field of g, therefore the bags can be constructed as:
julia> b = ScatteredBags(g.fadjlist)
ScatteredBags{Int64}([[4, 5, 6, 8, 10], Int64[], [4, 5], [1, 3, 6, 7], [1, 3, 8], [1, 4, 7, 9], [4, 6], [1, 5, 9], [6, 8, 10], [1, 9]])
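Bag i thus contains exactly the neighbors of vertex i, which we can verify directly (assuming the wrapped field of ScatteredBags is called bags, as in Mill.jl):
julia> all(b.bags[v] == neighbors(g, v) for v in 1:nv(g))
true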
Finally, we create two models. The first model, called lift, pre-processes the description of vertices into a latent space for message passing; the second one, called mp, realizes the message passing itself:
julia> lift = reflectinmodel(X, d -> Dense(d, 10), d -> mean_aggregation(d))
ArrayModel(Dense(7, 10))
julia> U = lift(X)
10×10 ArrayNode{Array{Float32,2},Nothing}:
0.22535409 0.3822525 0.2969856 … 0.775619 -1.243016
0.5277017 -0.2792255 -1.0729944 -0.33189312 -1.247246
1.6393183 -0.18840593 0.5840895 0.24021442 -0.5072842
-0.30440924 -0.46194744 -0.20234056 -1.0374451 1.0918131
-1.5296241 0.22416803 -0.8389578 -1.196345 0.34388122
-1.4859715 -0.06790199 0.57404274 … -0.8074341 2.2101219
0.61812794 -1.5480729 -0.6891072 -0.546315 0.19719586
-1.3225273 0.00831057 -0.5432095 -1.2234465 1.0026237
-1.4688778 1.3808706 -0.5251934 -0.08397776 -0.7011221
-0.65222025 -0.24466255 -0.31357002 -0.16080739 -0.17337486
julia> mp = reflectinmodel(BagNode(U, b), d -> Dense(d, 10), d -> mean_aggregation(d))
BagModel … ↦ ⟨SegmentedMean(10)⟩ ↦ ArrayModel(Dense(11, 10))
└── ArrayModel(Dense(10, 10))
Notice that BagNode(U, b) now essentially encodes the vertex features as well as the adjacency matrix. Also notice the Dense(11, 10) above: the aggregation appends one extra element derived from the bag size to the 10 aggregated features, which is why the same pattern appears later as f(2zd + 1). This also means that one step of the message-passing algorithm can be realized as:
julia> Y = mp(BagNode(U, b))
10×10 ArrayNode{Array{Float32,2},Nothing}:
-0.035713222 0.0 -0.22519352 0.1356892 … 0.06734029 0.4295932
0.5615496 0.0 0.7537555 0.074028336 0.2551962 -0.4399602
0.98289657 0.0 1.7433251 0.14271578 0.2615756 -0.7452709
-0.55371916 0.0 0.46866804 -1.5699606 -0.94492227 -2.1067574
0.5289915 0.0 0.65951484 1.1315223 0.08326764 1.4815167
0.1337585 0.0 -0.90799093 1.2921826 … 0.49060202 2.0067813
0.7641726 0.0 1.5191854 -0.2869706 0.23503898 -1.3377743
0.34246188 0.0 1.6722084 -0.7982003 -0.26235634 -1.7717428
0.78609854 0.0 0.5489987 0.31208873 0.8744348 -0.5674417
-0.6528163 0.0 -0.7301485 -0.5214443 -0.3157166 -0.17945614
and it is differentiable, which can be verified by executing:
julia> gradient(() -> sum(sin.(mp(BagNode(U, b)) |> Mill.data)), params(mp))
Grads(...)
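Because mp maps a 10-dimensional ArrayNode to another 10-dimensional ArrayNode, it can be applied repeatedly to perform multiple rounds of message passing. A minimal sketch of three rounds, reusing U and b from above:
Y = U
for _ in 1:3
    # one round: each vertex aggregates the embeddings of its neighbors
    global Y = mp(BagNode(Y, b))
end
This recursion is exactly what the mpstep function below implements.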
Putting everything together, the whole GNN is implemented in the following 16 lines:
struct GNN{L, M, R}
    lift::L    # pre-processes vertex descriptions into the latent space
    mp::M      # realizes one step of message passing
    m::R       # final model applied to the graph-level readout
end
Flux.@functor GNN
function mpstep(m::GNN, U::ArrayNode, bags, n)  # n recursive message-passing steps
    n == 0 && return U
    mpstep(m, m.mp(BagNode(U, bags)), bags, n - 1)
end
function (m::GNN)(g, X, n)
    U = m.lift(X)
    bags = Mill.ScatteredBags(g.fadjlist)
    o = mpstep(m, U, bags, n)
    # readout: concatenate mean- and max-pooled vertex embeddings
    m.m(vcat(mean(Mill.data(o), dims = 2), maximum(Mill.data(o), dims = 2)))
end
Note that the readout concatenates the mean and the maximum of the vertex embeddings, so the final model m expects inputs of twice the embedding dimension. As is the case with the whole of Mill.jl, even this graph neural network is properly integrated with the Flux.jl ecosystem and supports automatic differentiation:
zd = 10
f(d) = Chain(Dense(d, zd, relu), Dense(zd, zd))
agg(d) = meanmax_aggregation(d)
gnn = GNN(reflectinmodel(X, f, agg),
          BagModel(f(zd), agg(zd), f(2zd + 1)),
          f(2zd))
julia> gnn(g, X, 5)
10×1 Array{Float32,2}:
0.07108084
-0.24861374
-0.023377487
0.25094596
0.45109645
-0.027826611
-0.06066554
-0.29686767
-0.35696217
-0.18162829
julia> gradient(() -> gnn(g, X, 5) |> sum, params(gnn))
Grads(...)
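Training then follows the standard Flux.jl recipe. A minimal sketch of a few gradient steps, where the regression target y and the mean-squared-error objective are hypothetical placeholders:
y = randn(Float32, 10, 1)              # hypothetical training target
loss() = Flux.mse(gnn(g, X, 5), y)
ps = params(gnn)
opt = ADAM()
for _ in 1:10
    gs = gradient(loss, ps)
    Flux.Optimise.update!(opt, ps, gs)
end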
The above implementation is surprisingly general, as it supports an arbitrarily rich description of vertices. For simplicity, we used only vectors in X; however, any Mill.jl hierarchy is applicable.
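For illustration, here is a sketch of one such richer description, where every vertex carries a feature vector together with a bag of three one-hot-encoded tokens (the keys features and tokens are hypothetical, and f, agg and zd are reused from above):
X2 = ProductNode((features = ArrayNode(randn(Float32, 7, 10)),
                  tokens = BagNode(ArrayNode(Flux.onehotbatch(rand(1:5, 30), 1:5)),
                                   Mill.length2bags(fill(3, 10)))))
gnn2 = GNN(reflectinmodel(X2, f, agg),
           BagModel(f(zd), agg(zd), f(2zd + 1)),
           f(2zd))
gnn2(g, X2, 5)
Since reflectinmodel builds the lift model from the sample itself, nothing else in the implementation needs to change.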
To put different weights on edges, one can use WeightedBagNodes instead.