Message Passing


Interface

GraphNeuralNetworks.apply_edges (Function)
apply_edges(f, g, xi, xj, e)
apply_edges(f, g; [xi, xj, e])

Returns the message from node j to node i for each edge j → i of g. In the message-passing scheme, the incoming messages from the neighborhood of i are later aggregated to update the features of node i.

The function operates on batches of edges: xi, xj, and e are tensors whose last dimension is the batch size, or named tuples of such tensors.

Arguments

  • g: A GNNGraph.
  • xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It will be appropriately materialized on the target node of each edge (see also edge_index).
  • xj: As xi, but now to be materialized on each edge's source node.
  • e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
  • f: A function that takes as inputs the edge-materialized xi, xj, and e. These are arrays (or named tuples of arrays) whose last dimension's size is the size of a batch of edges. The output of f must be an array (or a named tuple of arrays) with the same batch size.

See also propagate and aggregate_neighbors.
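To make the materialization step concrete, here is a minimal plain-Julia sketch of what apply_edges computes, assuming the graph is given as two index vectors s (sources) and t (targets), with edge k running from s[k] to t[k]. The helpers gather_last and my_apply_edges are hypothetical illustrations, not part of GraphNeuralNetworks.jl, whose internals may differ. Node features are indexed along their last dimension, field by field for named tuples, and f is then called once on the whole batch of edges.

```julia
# Hypothetical sketch (Base Julia only): edge k goes from s[k] to t[k].

# Select entries of `x` along its last dimension at positions `idx`.
gather_last(x::AbstractArray, idx) = x[ntuple(_ -> Colon(), ndims(x) - 1)..., idx]
# Named tuples are materialized field by field.
gather_last(x::NamedTuple, idx) = map(a -> gather_last(a, idx), x)
# Missing inputs stay missing.
gather_last(::Nothing, idx) = nothing

function my_apply_edges(f, s, t; xi = nothing, xj = nothing, e = nothing)
    xi_e = gather_last(xi, t)   # target-node features, one column per edge
    xj_e = gather_last(xj, s)   # source-node features, one column per edge
    return f(xi_e, xj_e, e)     # e is already defined per edge
end

s = [1, 2, 3, 1]              # source of each of the 4 edges
t = [2, 3, 1, 3]              # target of each of the 4 edges
x = Float32[1 2 3;            # 2 features × 3 nodes
            10 20 30]

# Array inputs: message = source features minus target features.
m = my_apply_edges((xi, xj, e) -> xj .- xi, s, t; xi = x, xj = x)

# Named-tuple input plus a per-edge weight (1 × num_edges).
nt = my_apply_edges((xi, xj, e) -> xj.h .* e, s, t;
                    xj = (h = x,), e = Float32[1 1 2 2])
```

Both results have one column per edge, so the batch size of the output matches the batch of edges, as the argument description requires.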

GraphNeuralNetworks.aggregate_neighbors (Function)
aggregate_neighbors(g::GNNGraph, aggr, m)

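Given a graph g, edge features m, and an aggregation operator aggr (e.g. +, min, max, mean), returns the new node features

\[\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i}\]

where $\square$ denotes the aggregation operator aggr and $\mathcal{N}(i)$ is the neighborhood of node i.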

Neighborhood aggregation is the second step of propagate, following apply_edges.
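The aggregation formula can be illustrated with a small plain-Julia sketch, assuming the edges are described by a vector t holding each edge's target node. The function my_aggregate_neighbors is hypothetical; it uses Julia's reducers sum, maximum, minimum, and mean in place of +, max, min, and mean, and it leaves nodes with no incoming edges at zero, which is a choice of this sketch rather than a documented library behavior.

```julia
using Statistics: mean

# Hypothetical sketch: reduce the messages m (one column per edge) onto the
# target nodes t, producing one column per node.
function my_aggregate_neighbors(t, aggr, m::AbstractMatrix, num_nodes::Int)
    out = zeros(eltype(m), size(m, 1), num_nodes)  # nodes without edges stay zero
    for i in 1:num_nodes
        incoming = findall(==(i), t)               # indices of edges j -> i
        isempty(incoming) && continue
        # reduce feature-wise over the incoming-edge columns
        out[:, i] = vec(aggr(m[:, incoming]; dims = 2))
    end
    return out
end

t = [2, 3, 1, 3]                   # target node of each of the 4 edges
m = Float32[1 2 3 4; 10 20 30 40]  # a 2-dimensional message per edge

x_sum  = my_aggregate_neighbors(t, sum, m, 3)
x_max  = my_aggregate_neighbors(t, maximum, m, 3)
x_mean = my_aggregate_neighbors(t, mean, m, 3)
```

Node 3 receives edges 2 and 4, so its column is their sum, maximum, or mean depending on aggr, while nodes 1 and 2 each receive a single message.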

GraphNeuralNetworks.propagate (Function)
propagate(f, g, aggr; xi, xj, e)  ->  m̄

Performs message passing on graph g. Takes care of materializing the node features on each edge, applying the message function, and returning an aggregated message $\bar{\mathbf{m}}$ (depending on the return value of f, an array or a named tuple of arrays with last dimension's size g.num_nodes).

It can be decomposed into two steps:

m = apply_edges(f, g, xi, xj, e)
m̄ = aggregate_neighbors(g, aggr, m)
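The same two-step decomposition can be sketched in plain Julia, restricted to matrix features and sum aggregation, with the graph given as hypothetical source and target index vectors s and t. The names sketch_apply_edges, sketch_aggregate_sum, and sketch_propagate are illustrative only and do not exist in GraphNeuralNetworks.jl.

```julia
# Hypothetical sketch: edge k goes from s[k] to t[k]; features are matrices
# with one column per node, and aggregation is a sum.
sketch_apply_edges(f, s, t, xi, xj, e) = f(xi[:, t], xj[:, s], e)

function sketch_aggregate_sum(t, m::AbstractMatrix, num_nodes::Int)
    out = zeros(eltype(m), size(m, 1), num_nodes)
    for (k, i) in enumerate(t)       # edge k delivers its message to node i
        out[:, i] .+= m[:, k]
    end
    return out
end

function sketch_propagate(f, s, t, num_nodes, xi, xj, e)
    m = sketch_apply_edges(f, s, t, xi, xj, e)    # step 1: one message per edge
    return sketch_aggregate_sum(t, m, num_nodes)  # step 2: combine per node
end

s = [1, 2, 3, 1]
t = [2, 3, 1, 3]
x = Float32[1 2 3; 10 20 30]   # 2 features × 3 nodes

# Message = source features, so each node sums its in-neighbors' features.
mbar = sketch_propagate((xi, xj, e) -> xj, s, t, 3, x, x, nothing)
```

The output has one column per node, matching the description of $\bar{\mathbf{m}}$ above.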

GNN layers typically call propagate in their forward pass, passing a closure as f.

Arguments

  • g: A GNNGraph.
  • xi: An array or a named tuple containing arrays whose last dimension's size is g.num_nodes. It will be appropriately materialized on the target node of each edge (see also edge_index).
  • xj: As xi, but to be materialized on each edge's source node.
  • e: An array or a named tuple containing arrays whose last dimension's size is g.num_edges.
  • f: A generic function that is passed on to apply_edges. It must take as inputs the edge-materialized xi, xj, and e (arrays or named tuples of arrays whose last dimension's size is the size of a batch of edges). Its output must be an array or a named tuple of arrays with the same batch size.
  • aggr: Neighborhood aggregation operator. Use +, mean, max, or min.

Examples

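using GraphNeuralNetworks, Flux

# Simple graph convolution: x_i' = σ(b + sum over in-neighbors j of W * x_j)
struct GNNConv <: GNNLayer
    W
    b
    σ
end

Flux.@functor GNNConv

function GNNConv(ch::Pair{Int,Int}, σ=identity)
    in, out = ch
    W = Flux.glorot_uniform(out, in)
    b = zeros(Float32, out)
    GNNConv(W, b, σ)
end

function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix)
    # message sent along each edge j -> i: the transformed source features
    message(xi, xj, e) = l.W * xj
    # sum the incoming messages at each target node
    m̄ = propagate(message, g, +, xj=x)
    return l.σ.(m̄ .+ l.b)
end

g = rand_graph(10, 30)               # random graph with 10 nodes and 30 edges
x = randn(Float32, 10, g.num_nodes)  # 10 input features per node
l = GNNConv(10 => 20)
y = l(g, x)                          # 20 × g.num_nodes output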

See also apply_edges and aggregate_neighbors.
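As a sanity check of what the GNNConv example computes, the following plain-Julia sketch (no GraphNeuralNetworks.jl needed) compares the edge-wise computation, messages W * x_j summed at each target node, with a dense form σ.(W * x * A .+ b), where A[j, i] counts the edges j → i. The edge lists, weights, and the matrix A are made-up illustrative values, and the dense form is an equivalent rewriting for sum aggregation, not how the library evaluates the layer.

```julia
# Hypothetical check: edge k goes from s[k] to t[k].
s = [1, 2, 3, 1]
t = [2, 3, 1, 3]
n = 3

# A[j, i] = number of edges j -> i
A = zeros(Float32, n, n)
for (j, i) in zip(s, t)
    A[j, i] += 1
end

W = Float32[1 0; 0 2; 1 1]    # 3 output × 2 input features
b = Float32[0, 1, -1]
x = Float32[1 2 3; 10 20 30]  # 2 features × 3 nodes
σ = identity

# Edge-wise form, mirroring `message(xi, xj, e) = l.W * xj` followed by `+`.
msgs = W * x[:, s]            # one message column per edge
agg = zeros(Float32, size(W, 1), n)
for (k, i) in enumerate(t)
    agg[:, i] .+= msgs[:, k]
end
y_edges = σ.(agg .+ b)

# Dense form of the same computation.
y_dense = σ.(W * x * A .+ b)
```

The dense form makes the linear-algebra view explicit, while the edge-wise form scales with the number of edges rather than with the square of the number of nodes.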

Built-in message functions

GraphNeuralNetworks.e_mul_xj (Function)
e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj

Reshapes e into a shape broadcast-compatible with xj (by prepending singleton dimensions), then performs a broadcasted multiplication.
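The reshape-then-broadcast idea can be sketched as follows. my_e_mul_xj is a hypothetical re-implementation written only to illustrate the description above; the library's own e_mul_xj may differ in details. It prepends as many singleton dimensions to e as needed to match ndims(xj), so each edge's weight scales that edge's slice of xj.

```julia
# Hypothetical illustration; not the library's code.
function my_e_mul_xj(xi, xj::AbstractArray, e::AbstractArray)
    # Prepend singleton dimensions so e has as many dimensions as xj,
    # keeping e's own dimensions (including the trailing edge dimension).
    ones_prefix = ntuple(_ -> 1, ndims(xj) - ndims(e))
    return reshape(e, (ones_prefix..., size(e)...)) .* xj
end

xj = Float32[1 2 3; 10 20 30]    # 2 features × 3 edges
e  = Float32[1, 2, 3]            # one scalar weight per edge
y2 = my_e_mul_xj(nothing, xj, e) # each edge's column scaled by its weight

xj3 = ones(Float32, 2, 2, 3)     # 2 × 2 features per edge, 3 edges
y3  = my_e_mul_xj(nothing, xj3, e)
```

Prepending (rather than appending) singleton dimensions keeps the edge dimension last, which is what makes e line up with the batch dimension of xj.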