Graph Neural Networks: the Hows and the Whys | Dasha.AI

Let’s chat about a popular class of neural networks - graph networks. This article should be easy to follow for a machine learning beginner, provided you are familiar with the domain-specific terms.

We assume the reader knows that a graph is a pair of sets, vertices (V) and edges (E), where each edge is a pair of vertices (see Fig. 1).

From here on, we consider undirected graphs, that is, we assume that the edges are unordered two-element sets. Graphs arise in many analysis problems, especially when there are many objects and an obvious set of relationships between them (for example, the graph of users of a social network, in which each edge corresponds to a friendship between two users).

Graph processing problem

Let’s start with a question: how do we feed a graph to the input of a neural network (Fig. 3)? A conventional feedforward network (FFN) operates on an input of fixed length (Fig. 2), so the first option that comes to mind is to compute a vector of fixed dimension (a representation, or embedding) for the graph. We’ll get to that eventually. First, however, let’s note that there are other techniques for obtaining such a vector that I don’t cover in this review, for instance, random walks on the graph.

Another approach is to build special neural networks, graph neural networks (GNNs for short), that can process graphs directly. Note that an image is also a graph: it is not just a set of pixels, because every pixel has neighboring pixels, and this is important information. The values of neighboring pixels are, as a rule, highly correlated, and a strong difference between them can indicate a contrast transition at that place in the image, such as the boundary of an object. Fig. 4 shows examples of neighborhood graphs built from a matrix of pixels (the so-called 4- and 8-neighborhoods).

Text is also a graph, albeit a very simple one: a text is a finite set of words (and punctuation marks, which are sometimes ignored in analysis) on which a linear order is defined (there is a first word, a second one, and so on). The corresponding graph is shown in Fig. 5.

In the general case, an arbitrary graph can be very different from the graphs that correspond to images and texts: it has no regular structure, its vertices can have arbitrary degrees, etc. Still, this gives us some hints on how to process graphs with neural networks. In the special cases when a graph represents a text or an image, such networks should resemble the classical ones for working with texts and images: convolutional, recurrent, transformers, etc.

Note that the order of vertices in a graph is often unimportant (as opposed to the graphs of images and texts). Therefore, the equivariance property is often crucial when processing graphs: if, for example, a network layer changes the labels of the vertices, then renumbering the vertices must lead to the corresponding renumbering of the output labels (as in Fig. 5). In Fig. 5 there are two vertices that are indistinguishable from the point of view of the graph structure, so if you swap their labels and pass the graph through the network, the output labels should be swapped as well.

That is why we cannot simply feed an adjacency matrix into an ordinary convolutional network: the matrix imposes an order on the vertices, and the output of a convolutional layer is not equivariant with respect to a simultaneous, identical permutation of the rows and columns of such a matrix.
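To make the equivariance property concrete, here is a minimal NumPy sketch. It assumes a toy layer that sums neighbor states and applies a shared linear map; the layer, the random graph and the names `sum_layer` and `perm` are illustrative assumptions, not something taken from a particular library.

```python
import numpy as np

def sum_layer(adj, states, weight):
    """Each vertex sums the states of its neighbors, then a shared linear map is applied."""
    return adj @ states @ weight

rng = np.random.default_rng(0)
n, d = 5, 3
upper = np.triu(rng.integers(0, 2, size=(n, n)), k=1)
adj = (upper + upper.T).astype(float)     # symmetric adjacency matrix, no self-loops
states = rng.normal(size=(n, d))          # one feature row per vertex
weight = rng.normal(size=(d, d))

perm = np.array([2, 0, 1, 4, 3])          # an arbitrary renumbering of the vertices
P = np.eye(n)[perm]                       # the corresponding permutation matrix

out = sum_layer(adj, states, weight)
out_renumbered = sum_layer(P @ adj @ P.T, P @ states, weight)

# Renumbering the input renumbers the output rows in exactly the same way.
is_equivariant = np.allclose(out_renumbered, P @ out)
print(is_equivariant)
```

The check works because the permutation is applied to rows and columns of the adjacency matrix simultaneously, which is exactly the operation an ordinary convolution over the matrix does not respect.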

Node representation learning

First, we will consider a rather narrow task, node representation learning. In practice, we don’t have just a graph: each vertex (and/or edge) also has a description (Fig. 6), and we will assume that this feature description is a real vector of fixed length. In the social network example, we may know each user's gender, age, etc. But such a representation doesn’t take the structure of the graph into account: it is local and stores no information about the neighbors of the vertex (i.e., the user's friends).

GNNs compute vertex representations in an iterative process. Different kinds of GNNs do it in different ways, and each iteration corresponds to a network layer. The simplest concept for such a computation is neural message passing. Message passing in general is a fairly well-known technique in graph analysis: each vertex has a certain state, which is refined in one iteration using the following formula:
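One common way to write this update, assuming that h_v^{(k)} denotes the state of vertex v after k iterations (this notation is an assumption, the text itself does not fix it):

```latex
h_v^{(k+1)} = \mathrm{UPDATE}\left( h_v^{(k)},\; \mathrm{AGG}\left( \left\{ h_u^{(k)} : u \in N(v) \right\} \right) \right)
```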

Here and below, N(v) is the neighborhood of the vertex v, AGG (AGGREGATE) is the aggregation function (it collects information about the neighbors, for example, by summing their states), and UPDATE is the function that updates the state of the vertex (taking into account the collected information about the neighbors). The only requirement imposed on these two functions is differentiability, so that they can be used in a computational graph and the network parameters can be trained with backpropagation. The second function is sometimes also called COMBINE, since it combines the current state of the vertex with the aggregate of its neighbors. At the first iteration, each vertex uses only information from its immediate neighborhood (its “friends”); at the second, information from the second-order neighborhood (friends of friends) seeps into its updated representation; and after the k-th iteration, information from the k-th order neighborhood.
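The way information spreads one hop per iteration can be seen in a small sketch. The path graph, the sum aggregation and the additive update below are illustrative choices (any differentiable AGG and UPDATE would do), and `message_passing_step` is a hypothetical helper name.

```python
import numpy as np

def message_passing_step(neighbors, states, aggregate, update):
    """One iteration: every vertex aggregates its neighbors' states, then updates its own."""
    new_states = {}
    for v, h_v in states.items():
        messages = [states[u] for u in neighbors[v]]
        new_states[v] = update(h_v, aggregate(messages))
    return new_states

# Toy path graph 0 - 1 - 2 - 3.
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
# Only vertex 0 carries a non-zero signal at the start.
states = {v: np.array([1.0 if v == 0 else 0.0]) for v in neighbors}

aggregate = lambda msgs: np.sum(msgs, axis=0)   # AGG: sum of neighbor states
update = lambda h, m: h + m                     # UPDATE: add the aggregated message

reached = []                                    # vertices with a non-zero state after each step
for k in range(3):
    states = message_passing_step(neighbors, states, aggregate, update)
    reached.append([v for v in neighbors if states[v][0] > 0])

print(reached)
```

After the first step the signal from vertex 0 has reached vertex 1, after the second it has reached vertex 2, and after the third it has reached vertex 3, which matches the "friends of friends" picture.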

In classical graph networks (original GNN, [1], [2]), the state change occurs according to a very simple formula:
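Written out, under the assumption that h_v^{(k)} is the state of vertex v after k iterations, that W_self and W_neigh are the two trainable matrices, b is the bias and sigma is the activation function (all symbol names are assumed here), the update looks like this:

```latex
h_v^{(k+1)} = \sigma\left( W_{\mathrm{self}}\, h_v^{(k)} + W_{\mathrm{neigh}} \sum_{u \in N(v)} h_u^{(k)} + b \right)
```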

That is, aggregation here is ordinary summation, and in the update, the previous state and the result of the aggregation are passed through linear layers (multiplied by matrices), after which their sum plus a bias term is passed through the activation function.

One of the most popular graph networks is GraphSAGE (Graph Sample and Aggregate), and it has an almost identical formula:
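A sketch of the GraphSAGE update in the same style, assuming that h_v^{(k)} is the state of vertex v after k iterations, W is a trainable matrix, sigma is the activation function, and the semicolon inside the square brackets stands for vertical concatenation (the symbol names are assumptions):

```latex
h_v^{(k+1)} = \sigma\left( W \left[\, h_v^{(k)} \;;\; \mathrm{AGG}\left( \left\{ h_u^{(k)} : u \in N(v) \right\} \right) \right] \right)
```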

Square brackets denote vertical concatenation (multiplying a matrix by a concatenation is equivalent to summing the products of the matrix blocks with the concatenated vectors). However, the original work [3] used several different aggregations:

  • Averaging - the mean,
  • LSTM aggregation (for which the neighbors were shuffled),
  • Pooling (linear layer + pooling):
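The mean and pooling aggregators, together with the concatenation-based update, can be sketched as follows. This is a minimal NumPy illustration with random, untrained weights; the function names, the ReLU activation and the elementwise max in the pooling aggregator are assumptions made for the example, and the LSTM aggregator is left out.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def mean_aggregate(neighbor_states):
    """Mean aggregator: elementwise average of the neighbor states."""
    return neighbor_states.mean(axis=0)

def pool_aggregate(neighbor_states, w_pool, b_pool):
    """Pooling aggregator: linear layer + nonlinearity per neighbor, then elementwise max."""
    return relu(neighbor_states @ w_pool.T + b_pool).max(axis=0)

def sage_update(h_v, aggregated, weight):
    """Update: activation of W times the vertical concatenation [h_v ; aggregated]."""
    return relu(weight @ np.concatenate([h_v, aggregated]))

rng = np.random.default_rng(0)
d = 4
h_v = rng.normal(size=d)
neighbor_states = rng.normal(size=(3, d))        # three neighbors, one row each
weight = rng.normal(size=(d, 2 * d))
w_pool, b_pool = rng.normal(size=(d, d)), rng.normal(size=d)

out_mean = sage_update(h_v, mean_aggregate(neighbor_states), weight)
out_pool = sage_update(h_v, pool_aggregate(neighbor_states, w_pool, b_pool), weight)

# W [h ; a] equals W_1 h + W_2 a when W = [W_1 W_2] is split column-wise.
w_1, w_2 = weight[:, :d], weight[:, d:]
split_form = relu(w_1 @ h_v + w_2 @ mean_aggregate(neighbor_states))
print(np.allclose(out_mean, split_form))
```

The last lines check the remark about concatenation: one matrix applied to the stacked vector gives the same result as two matrices applied to the parts and summed.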

With LSTM aggregation, the vertices of the neighborhood were put in some order and the sequence of their representations was fed into a recurrent LSTM network. GraphSAGE also added an unsupervised loss (I won’t describe it in detail here, but the idea is that neighboring vertices get similar representations and random ones get dissimilar ones). Also, in the original publication [3], neighbors were sampled from the neighborhood (Fig. 8 shows the resulting gain in computation speed), and the resulting state was normalized (although in practice this step is sometimes skipped).

Later, modifications of GraphSAGE were proposed, for example, PinSage [4], which samples the neighborhood differently (in order to process large graphs): it performs random walks and samples among the top-k most visited vertices. It also uses the ReLU activation function, and the states of the neighbors are passed through a separate module before aggregation.

In graph networks, the term self-loops means that the vertex itself is included in its own neighborhood, so its state does not appear separately in the formula, for example, as here [5]:

Generalized Neighborhood Aggregation

Most often, aggregation is done by summing the states of the neighbors, although you can also use averaging (neighborhood normalization):
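Two normalizations are commonly written this way: the plain mean over the neighborhood, and a symmetric variant that divides by the square root of the product of the two degrees. The symbols m_{N(v)} for the aggregated message and h_u for the state of neighbor u are assumed notation.

```latex
m_{N(v)} = \frac{1}{|N(v)|} \sum_{u \in N(v)} h_u
\qquad\text{and}\qquad
m_{N(v)} = \sum_{u \in N(v)} \frac{h_u}{\sqrt{|N(v)|\,|N(u)|}}
```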


The latter kind of averaging is used by graph convolutional networks (GCNs). It was shown in [6] that choosing the right aggregation is important, since aggregation can lose important information; for example, ordinary averaging loses the degree of a vertex.
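A tiny example shows what "losing the degree" means. The two neighborhoods below are hypothetical: every neighbor has the same state, only the number of neighbors differs.

```python
import numpy as np

# Vertex a has 2 neighbors, vertex b has 5; every neighbor has the state [1, 2].
neighbors_of_a = np.array([[1.0, 2.0]] * 2)
neighbors_of_b = np.array([[1.0, 2.0]] * 5)

mean_a, mean_b = neighbors_of_a.mean(axis=0), neighbors_of_b.mean(axis=0)
sum_a, sum_b = neighbors_of_a.sum(axis=0), neighbors_of_b.sum(axis=0)

print(mean_a, mean_b)   # the means coincide, so the degree is invisible
print(sum_a, sum_b)     # the sums differ, so the degree is preserved
```

With mean aggregation the two vertices receive identical messages, while sum aggregation still distinguishes them.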

Another popular aggregation is set pooling [7]:

Here the functions f and g are implemented by deep multi-layer perceptrons. Note that a network with such an aggregation has been proven to be a universal approximator.
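Below is a minimal sketch of set pooling, assuming the usual form f(sum of g(h_u) over the neighbors), i.e. g is applied to every neighbor state and f to the summed result. Which of the two letters is the inner function is an assumption here, and the MLP weights are random and untrained.

```python
import numpy as np

rng = np.random.default_rng(1)
d, hidden = 3, 8

def make_mlp(sizes, rng):
    """Random-weight MLP with ReLU between layers (untrained, for illustration only)."""
    params = [(rng.normal(size=(m, n)), rng.normal(size=n))
              for m, n in zip(sizes[:-1], sizes[1:])]
    def mlp(x):
        for i, (w, b) in enumerate(params):
            x = x @ w + b
            if i < len(params) - 1:
                x = np.maximum(x, 0.0)
        return x
    return mlp

g = make_mlp([d, hidden, hidden], rng)   # applied to every neighbor state
f = make_mlp([hidden, hidden, d], rng)   # applied to the summed result

def set_pooling(neighbor_states):
    """Aggregate a set of neighbor states as f(sum_u g(h_u))."""
    return f(g(neighbor_states).sum(axis=0))

neighbor_states = rng.normal(size=(4, d))
shuffled = neighbor_states[[2, 0, 3, 1]]   # the same neighbors in another order

print(np.allclose(set_pooling(neighbor_states), set_pooling(shuffled)))
```

Because the only interaction between neighbors is a sum, the result does not depend on the order in which the neighbors are listed.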

Also note Janossy pooling [8], in which the summation is performed over all permutations of the neighbors (in theory; in practice, over several random ones).

The GIN (Graph Isomorphism Network) uses a fairly simple formula for the state update (and the aggregation here is a simple summation) [9]:
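In the notation where h_v^{(k)} is the state of vertex v after k iterations and epsilon is a scalar that is either fixed or trained (the notation is assumed), the GIN update is usually written as:

```latex
h_v^{(k+1)} = \mathrm{MLP}^{(k+1)}\left( \left(1 + \epsilon^{(k+1)}\right) h_v^{(k)} + \sum_{u \in N(v)} h_u^{(k)} \right)
```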

There are also learnable aggregators [10]. The simplest option is to include masks in the formula (vectors that are multiplied elementwise by the states of the neighbors and are the outputs of a separate trainable module):

And, of course, you can use attention during aggregation; the states of the neighbors are then summed with attention coefficients:
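One way to write it, assuming GAT-style scoring: alpha_{v,u} are the attention coefficients, f is a scalar nonlinearity, a is a trainable vector, W is a trainable matrix, and the semicolon stands for concatenation (all symbol names are assumptions):

```latex
m_{N(v)} = \sum_{u \in N(v)} \alpha_{v,u}\, h_u,
\qquad
\alpha_{v,u} = \frac{\exp\left( e_{v,u} \right)}{\sum_{u' \in N(v)} \exp\left( e_{v,u'} \right)},
\qquad
e_{v,u} = f\left( a^{\top} \left[\, W h_v \;;\; W h_u \right] \right)
```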

Different attention mechanisms can be used, including multi-head ones (we will not list them all); the original work on neighborhood attention, Graph Attention Network (GAT) [11], used LeakyReLU as the function f. The attention mechanism has a natural interpretation here: we look at our neighbors and then choose the coefficients for the aggregation.
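Here is a single-head sketch of attention-based aggregation for one vertex. It assumes GAT-style scores, LeakyReLU applied to a^T [W h_v ; W h_u], followed by a softmax over the neighbors; the weights are random, and the slope 0.2 and the function names are choices made for the example.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def attention_aggregate(h_v, neighbor_states, weight, a):
    """Sum neighbor states with softmax-normalized attention coefficients."""
    z_v = weight @ h_v                                    # transformed state of v
    z_u = neighbor_states @ weight.T                      # transformed states, one row per neighbor
    pairs = np.concatenate([np.tile(z_v, (len(z_u), 1)), z_u], axis=1)
    scores = leaky_relu(pairs @ a)                        # unnormalized score per neighbor
    scores = scores - scores.max()                        # shift for numerical stability
    alpha = np.exp(scores) / np.exp(scores).sum()         # softmax over the neighborhood
    return alpha, alpha @ neighbor_states                 # coefficients and aggregated message

rng = np.random.default_rng(0)
d = 4
h_v = rng.normal(size=d)
neighbor_states = rng.normal(size=(3, d))
weight = rng.normal(size=(d, d))
a = rng.normal(size=2 * d)

alpha, message = attention_aggregate(h_v, neighbor_states, weight, a)
print(alpha, alpha.sum())
```

The coefficients are positive and sum to one, so the message is a weighted average of the neighbor states, with the weights chosen by looking at each neighbor.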

Adaptation problem

The main problem with the described neural message passing is so-called over-smoothing: after several iterations of recomputing the vertex states, the representations of neighboring vertices become similar, since they have similar neighborhoods. To combat this, you can do the following:

  • use fewer aggregation layers and more layers for "processing features",
  • use skip connections or concatenate states from previous layers,
  • use architectures that have a memory effect (e.g., GRU).
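Over-smoothing itself is easy to reproduce. The sketch below assumes the simplest possible layer, mean aggregation over the neighborhood with self-loops and no trainable weights, on a small hypothetical graph; real networks have weights and nonlinearities, so this only illustrates the tendency.

```python
import numpy as np

# Hypothetical connected graph on 5 vertices: a cycle 0-1-2-3-4-0 plus the chord 0-2.
edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (0, 2)]
n = 5
adj = np.zeros((n, n))
for u, v in edges:
    adj[u, v] = adj[v, u] = 1.0

adj_loops = adj + np.eye(n)                                     # add self-loops
propagate = adj_loops / adj_loops.sum(axis=1, keepdims=True)    # row-normalized: mean aggregation

rng = np.random.default_rng(0)
states = rng.normal(size=(n, 3))

def spread(h):
    """Largest deviation of any vertex state from the average state."""
    return np.abs(h - h.mean(axis=0)).max()

spread_before = spread(states)
for _ in range(100):                                            # 100 rounds of averaging
    states = propagate @ states
spread_after = spread(states)

print(spread_before, spread_after)
```

After many rounds all vertices end up with practically the same state, which is why stacking many aggregation layers without any of the remedies above tends to hurt.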

One way of concatenating states has its own term in the theory of graph networks, "Jumping Knowledge Connections" [12]:
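A possible way to write the concatenation variant, assuming that z_v is the final representation of vertex v, h_v^{(k)} is its state after layer k, K is the number of layers and f_JK is an arbitrary differentiable function (for example, the identity); the symbols are assumptions:

```latex
z_v = f_{\mathrm{JK}}\left( h_v^{(0)} \oplus h_v^{(1)} \oplus \dots \oplus h_v^{(K)} \right)
```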

The DeepGCNs network [13] also uses the idea of skip connections, as DenseNet does, and additionally uses dilated convolutions, by analogy with convolutional networks: the neighbors of a vertex are ordered and aggregation is performed over k of them, the first k for an ordinary convolution and, for a dilated one that skips one neighbor at each step, the 1st, 3rd, 5th, etc. (Fig. 10).
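The neighbor selection for a dilated aggregation fits in one line. This is a sketch under the assumption that the neighbors are already ordered (for example, by distance); the helper name and the `dilation` parameter are hypothetical.

```python
def dilated_neighbors(ordered_neighbors, k, dilation=1):
    """Pick k neighbors from an ordered list, taking every `dilation`-th one."""
    return ordered_neighbors[::dilation][:k]

ordered = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]   # positions of the neighbors in the ordering

plain = dilated_neighbors(ordered, k=3)                 # ordinary: the first k
dilated = dilated_neighbors(ordered, k=3, dilation=2)   # dilated: skip one each time
print(plain, dilated)
```

With `dilation=2` the selection is the 1st, 3rd and 5th neighbor, so the same number of neighbors covers a wider part of the ordering.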

Graph aggregation

So far, we have considered ways to iteratively refine the representations of vertices, but we started from the idea that it would be nice to represent a whole graph as a vector of fixed length. How do we do this given the vertex representations? The first obvious way is graph pooling, which simply sums the representations obtained:

Another approach is to treat the set of vertices as a sequence, so that their representations can be fed to an LSTM network, possibly with an attention mechanism. In the SortPool approach [14], the vertices are sorted (in some way, for example, by the last element of the representation vector), and the representations of the first k vertices are concatenated and fed to an ordinary neural network.
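Both readouts can be sketched in a few lines. The sum readout follows the description of graph pooling directly; the SortPool-like readout below sorts by the last channel in descending order and pads small graphs with zero rows, which are assumptions of this sketch rather than a faithful reimplementation of [14].

```python
import numpy as np

def sum_readout(vertex_states):
    """Graph pooling: the sum of all vertex representations."""
    return vertex_states.sum(axis=0)

def sort_pool(vertex_states, k):
    """Sort vertices by the last channel (descending), keep k rows, flatten."""
    order = np.argsort(-vertex_states[:, -1], kind="stable")
    top = vertex_states[order][:k]
    if len(top) < k:                                    # pad small graphs with zero rows
        pad = np.zeros((k - len(top), vertex_states.shape[1]))
        top = np.vstack([top, pad])
    return top.reshape(-1)

states = np.array([[1.0, 0.2],
                   [2.0, 0.9],
                   [3.0, 0.5]])
renumbered = states[[2, 0, 1]]                          # same graph, vertices listed differently

graph_vector = sum_readout(states)
top_two = sort_pool(states, k=2)
print(graph_vector, top_two, sort_pool(renumbered, k=2))
```

Both functions return a vector whose length does not depend on the number of vertices, and (as long as the sorting keys are distinct) neither depends on how the vertices are numbered.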

There is also a group of approaches in which the graph is iteratively simplified. In fact, pooling in convolutional networks for images and texts does exactly that, it simplifies the graph, so it is reasonable to have such an operation in GNNs as well. In SAGPool (Self-Attention Graph Pooling) [15], scores of the vertices are computed iteratively and only the top-scoring vertices are kept. The process continues until a single vertex is left or until the Readout layer (as is done in the original publication). As the name suggests, SAGPool uses the attention mechanism.

Another method based on graph simplification is DiffPool [16]: at each iteration, the graph vertices are clustered (in graph network terms, “community detection”), the vertices are aggregated within clusters, and we move on to a graph whose vertices correspond to the communities. In MinCutPool [17], clustering is performed by the spectral clustering method (k-means in a special space, which is obtained by finding the eigenvectors of a matrix derived from the graph adjacency matrix).
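One coarsening step of the DiffPool kind can be sketched as follows. In DiffPool the assignment scores come from a separate GNN; here they are simply given as a hypothetical `assignment_logits` matrix, and the function names are assumptions.

```python
import numpy as np

def softmax_rows(x):
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)

def coarsen(adj, states, assignment_logits):
    """Soft-assign vertices to clusters, then pool states and connectivity."""
    s = softmax_rows(assignment_logits)     # (n, c): row v holds the cluster memberships of v
    pooled_states = s.T @ states            # (c, d): cluster representations
    pooled_adj = s.T @ adj @ s              # (c, c): connectivity between clusters
    return pooled_adj, pooled_states, s

# Path graph 0 - 1 - 2 - 3, softly split into the clusters {0, 1} and {2, 3}.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
states = np.arange(8, dtype=float).reshape(4, 2)
assignment_logits = np.array([[5.0, 0.0], [5.0, 0.0], [0.0, 5.0], [0.0, 5.0]])

pooled_adj, pooled_states, s = coarsen(adj, states, assignment_logits)
print(pooled_adj.round(2))
print(pooled_states.round(2))
```

The result is a smaller graph with one vertex per cluster, and since every operation is differentiable, the assignment can be trained together with the rest of the network.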

Using GNN

Now let's consider which tasks are solved with graph networks, and how. From what we know so far:

  • the input of the network is a graph in which each vertex has a feature description; this description can be considered the initial state (representation) of the vertex (Fig. 14),
  • there may be layers that modify the representations independently (the representation of each vertex is passed through a small network) (1 in Fig. 14),
  • there may be layers that modify the representations of all vertices, taking into account the representations of the neighboring vertices (2 in Fig. 14),
  • there may be layers that simplify the graph (for example, reduce the number of vertices) (3 in Fig. 14),
  • there may be layers that produce a graph representation (a vector of fixed length) from the current graph with vertex representations (4 in Fig. 14).

Typically, such networks are used for the following tasks:

  • node-level predictions, such as node classification,
  • edge-level predictions, for example, predicting the probability that an edge appears in the graph (link prediction),
  • graph-level predictions, for example, graph classification, community detection or relation prediction.

For graph-level predictions, the graph representation is used (as is, or passed through a feedforward network). The loss function depends on the problem; for example, in a regression problem (where a real number is predicted from the graph), you can use the following formula:
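A typical choice is the squared error. In the sketch below, z_{G_i} is the representation of the i-th training graph, y_{G_i} is its target value, and MLP is the feedforward network applied to the graph representation (the symbols and the choice of squared error are assumptions):

```latex
\mathcal{L} = \sum_{i} \left\| \mathrm{MLP}\left( z_{G_i} \right) - y_{G_i} \right\|_2^2
```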

In vertex-level tasks, the same is done with the representations of the vertices:

In vertex clustering tasks (for example, community detection), clustering is carried out in the space of vertex representations. In edge-level tasks, one works with the concatenation of the representations of the edge's endpoints or with a function of the pair of representations, for example, their dot product:
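A dot-product edge score can be sketched like this. Passing the dot product through a sigmoid to get a probability is an assumption of the example (the text only mentions the dot product), and the three vertex representations are made up.

```python
import numpy as np

def edge_probability(z_u, z_v):
    """Score a candidate edge by the dot product of its endpoint representations."""
    return 1.0 / (1.0 + np.exp(-np.dot(z_u, z_v)))   # sigmoid of the dot product

z_a = np.array([1.0, 2.0])
z_b = np.array([2.0, 1.0])     # points in a similar direction to z_a
z_c = np.array([-1.0, -2.0])   # points in the opposite direction

p_ab = edge_probability(z_a, z_b)
p_ac = edge_probability(z_a, z_c)
print(p_ab, p_ac)
```

Vertices with aligned representations get a probability above 0.5, and the score is symmetric in its arguments, which suits undirected graphs.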

Graph networks can also be used in tasks without labels. In that case, the loss function measures how consistent the representations of the vertices are with the representation of the graph that contains them. For example, Deep Graph Infomax (DGI) uses a discriminator which, given a pair (vertex representation, graph representation), tries to determine whether the vertex belongs to the graph; negative examples are obtained by modifying the original graph [18].

Graph Augmentation

As in other areas of deep learning, graph neural networks use augmentation: methods that enlarge the training set by modifying the original data. In principle, the neighbor sampling that we discussed above already changes the original information in a sense. If the graph is sparse, edges are added to it according to the principle "my friend's friend is my friend." If it is dense and large, subgraphs are taken in some problems.
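The "friend of a friend" rule can be sketched with one matrix product. This is a minimal version that connects every pair of vertices sharing at least one neighbor; the helper name is hypothetical, and real pipelines may add such edges selectively.

```python
import numpy as np

def add_two_hop_edges(adj):
    """Connect every pair of vertices that share at least one neighbor."""
    two_hop = (adj @ adj) > 0                       # True where a path of length 2 exists
    augmented = ((adj > 0) | two_hop).astype(float)
    np.fill_diagonal(augmented, 0.0)                # do not introduce self-loops
    return augmented

# Path graph 0 - 1 - 2 - 3.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)

augmented = add_two_hop_edges(adj)
print(augmented)
```

On the path graph this adds the edges 0-2 and 1-3, while vertices 0 and 3, which are three hops apart, stay unconnected.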

What have we not covered?

The most important questions that we have not yet touched on:

  • application of spectral graph theory in graph networks, including obtaining representations of convolutions,
  • other robust graph analysis methods,
  • generative graph networks,
  • examples of graph network applications and libraries for GNN implementation.


There is some good reading material on graph networks. First of all, here are some great review books:

Secondly, a bunch of good reviews came out recently. A lot of information in this post came from these:


[1] C. Merkwirth, T. Lengauer «Automatic Generation of Complementary Descriptors with Molecular Graph Networks» // Journal of Chemical Information and Modeling, 2005, 45, pages 1159-68.
[2] F. Scarselli, M. Gori, A. C. Tsoi, M. Hagenbuchner, G. Monfardini «The Graph Neural Network Model» // IEEE Transactions on Neural Networks, vol. 20, no. 1, pages 61-80, Jan. 2009, doi: 10.1109/TNN.2008.2005605.
[3] W. L. Hamilton, R. Ying, J. Leskovec «Inductive Representation Learning on Large Graphs»
[4] R. Ying, R. He, K. Chen, P. Eksombatchai, W. L. Hamilton, J. Leskovec «Graph convolutional neural networks for web-scale recommender systems» // In Proc. of SIGKDD, 2018, doi: 10.1145/3219819.3219890.
[5] George T. Cantwell, M. E. J. Newman «Message passing on networks with loops»
[6] Thomas N. Kipf, Max Welling «Semi-supervised classification with graph convolutional networks» // ICLR, 2017.
[7] M. Zaheer, S. Kottur, S. Ravanbakhsh, B. Poczos, R. R. Salakhutdinov, A. J. Smola «Deep sets» // In Proc. of NIPS, 2017, pages 3391–3401.
[8] Ryan L. Murphy «Janossy pooling: learning deep permutation-invariant functions for variable-size inputs»
[9] K. Xu et al. «How Powerful are Graph Neural Networks?»
[10] Z. Li, H. Lu «Learnable Aggregator for GCN»
[11] P. Velickovic, G. Cucurull, A. Casanova, A. Romero, P. Lio, Y. Bengio «Graph attention networks» // In Proc. of ICLR, 2018.
[12] K. Xu, C. Li, Y. Tian, T. Sonobe, K. Kawarabayashi, S. Jegelka «Representation learning on graphs with jumping knowledge networks» // In Proc. of ICML, 2018, pages 5449–5458.
[13] G. Li, M. Muller, A. Thabet, B. Ghanem «DeepGCNs: Can GCNs go as deep as CNNs?» // In Proc. of ICCV, 2019.
[13b] C. Cangea, P. Velickovic, N. Jovanovic, T. Kipf, P. Lio «Towards sparse hierarchical graph classifiers» // arXiv:1811.01287, 2018.
[14] M. Zhang, Z. Cui, M. Neumann, Y. Chen «An End-to-End Deep Learning Architecture for Graph Classification»
[15] J. Lee, I. Lee, J. Kang «Self-Attention Graph Pooling»
[16] R. Ying et al. «Hierarchical Graph Representation Learning with Differentiable Pooling»
[17] Filippo Maria Bianchi «Spectral Clustering with Graph Neural Networks for Graph Pooling»
[18] Petar Veličković, William Fedus, William L. Hamilton, Pietro Liò, Yoshua Bengio, R Devon Hjelm «Deep Graph Infomax»
