Machine Learning on Graphs

By Matthias Bal
January 21st, 2021

Why Should You Care?

In 2020, graph neural networks were among the fastest-growing topics at both ICLR and NeurIPS, two of the largest annual machine learning conferences for research and industry. Although deep learning on graphs is still quite new and not yet widely known in the broader community, graph machine learning frameworks have already been deployed in production at large tech companies.

This post will show how graphs appear naturally in many different fields and how the deep learning toolbox can be applied to this kind of data. We will highlight recent applications in research and industry to better understand why interest in graph machine learning is likely to keep growing over the next few years.

Why Deep Learning on Graphs?

In the last decade, the advent of large-scale datasets and dedicated computing power has enabled deep learning approaches to achieve unprecedented performance on a broad range of well-defined tasks. So far, the main focus has been on data defined on regular grids, like images (2D grids of pixel values), text (1D grid of tokens), and sound (1D grids of audio samples). But not all data fits a regular grid.


In computer graphics, biology, physics, and network science, data is often defined on irregular grids, graphs or surfaces. Think of point clouds, 3D meshes, protein structures, molecules, and networks of epidemic spread, payments or traffic. Geometric deep learning sets out to apply deep learning methods to graphs and other irregular data structures.

What is a Graph?

Graphs represent relationships (edges) between entities (nodes) and serve as a powerful abstraction for encoding structure. Edges can arise naturally, as in social networks, or more abstractly, as in similarity graphs, where they reflect a measure of likeness or distance between nodes. Since learning from data is fundamentally about discovering relationships between entities, graphs offer a natural way to represent structured data.
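As a concrete example of a similarity graph, the sketch below connects each point to its k nearest neighbours under Euclidean distance, one common way to turn a point cloud or a set of feature vectors into a graph. The `knn_graph` helper, the distance measure and the value of k are illustrative assumptions rather than anything prescribed above.

```python
import math

def knn_graph(points, k):
    """Build an undirected k-nearest-neighbour graph as an adjacency dict.

    Hypothetical helper: each point is linked to its k closest points by
    Euclidean distance, and edges are symmetrised so the graph is undirected.
    """
    adj = {i: set() for i in range(len(points))}
    for i, p in enumerate(points):
        # Sort all other points by their distance to p and keep the k nearest.
        others = sorted(
            (math.dist(p, q), j) for j, q in enumerate(points) if j != i
        )
        for _, j in others[:k]:
            adj[i].add(j)
            adj[j].add(i)  # similarity is treated as mutual
    return adj

# Two well-separated pairs of 2D points.
points = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (5.1, 5.0)]
graph = knn_graph(points, k=1)
print(graph)  # {0: {1}, 1: {0}, 2: {3}, 3: {2}}
```

Thresholding the distance instead of fixing k is an equally common variant; the k-NN form simply guarantees every node has at least one neighbour.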


Graphs can also help to build up global information from local information. Let’s consider the example of a social network where nodes correspond to users and edges to connections. Looking at connected neighbours provides useful information about the local environment of users. But detecting clusters of malicious actors in a network often requires zooming out and looking at the global structure, which, in turn, is composed of the local environments of many individual nodes interacting with each other.
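The local-to-global idea can be made concrete with a small sketch: if every node only ever looks at its direct neighbours and repeatedly adopts the smallest label it sees, the labels eventually settle into one label per connected component, which is a global property of the graph. This min-label propagation is a textbook illustration chosen here, not a method used by any system mentioned in this post, and `connected_components` is a hypothetical helper that assumes an undirected graph with integer node ids.

```python
def connected_components(adj):
    """Label connected components using only local neighbour exchanges.

    Every node starts with its own id as its label and, in each synchronous
    round, adopts the smallest label among itself and its neighbours. When no
    label changes, all nodes in a component share that component's minimum id.
    Assumes adj is symmetric (undirected) and node ids are comparable.
    """
    labels = {v: v for v in adj}
    changed = True
    while changed:
        changed = False
        new_labels = {}
        for v, nbrs in adj.items():
            best = min([labels[v]] + [labels[u] for u in nbrs])
            new_labels[v] = best
            if best != labels[v]:
                changed = True
        labels = new_labels
    return labels

# A toy social network with two separate groups of users.
adj = {0: [1], 1: [0, 2], 2: [1], 3: [4], 4: [3]}
print(connected_components(adj))  # {0: 0, 1: 0, 2: 0, 3: 3, 4: 3}
```

Node 2 only learns the label of node 0 in the second round, via node 1: global information is assembled step by step from local views.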

Graph Neural Networks

A key concept in deep learning and neural networks is representation learning: turning the structure in raw data into representations that machines can work with. By extracting signals from very large and complex datasets, neural networks can learn remarkably rich representations.

Graph neural networks (GNNs) implement representation learning for graphs by converting graph data into useful, low-dimensional representations while trying to preserve structural information. Most current GNN algorithms rely on the idea of message passing. Nodes and edges are endowed with feature vectors, and messages are passed between connected nodes. Each node repeatedly updates its representation of its local environment by gathering features from its neighbours, so that after k rounds it reflects its k-hop neighbourhood and information propagates across the graph.
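Below is a minimal sketch of a single message-passing step in plain Python. It assumes mean aggregation over neighbours and an update that adds a transformed self-feature to a transformed neighbour average before a ReLU, which is one common design among many; the weight matrices are hand-picked stand-ins for learned parameters, and all function names are hypothetical.

```python
def relu(x):
    # Element-wise rectified linear unit.
    return [max(0.0, v) for v in x]

def matvec(W, x):
    # Multiply matrix W (a list of rows) by vector x.
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def message_passing_layer(adj, h, W_self, W_neigh):
    """One round of mean-aggregation message passing (hypothetical helper).

    adj maps node -> list of neighbours; h maps node -> feature vector.
    Each node averages its neighbours' features (the "messages"), then
    combines that average with its own features through two weight
    matrices followed by a ReLU. W_self and W_neigh are shared by all nodes.
    """
    out = {}
    for v, nbrs in adj.items():
        dim = len(h[v])
        if nbrs:
            agg = [sum(h[u][d] for u in nbrs) / len(nbrs) for d in range(dim)]
        else:
            agg = [0.0] * dim  # isolated node: no incoming messages
        combined = [a + b for a, b in zip(matvec(W_self, h[v]), matvec(W_neigh, agg))]
        out[v] = relu(combined)
    return out

# Toy graph: node 0 is linked to nodes 1 and 2.
adj = {0: [1, 2], 1: [0], 2: [0]}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [0.0, 1.0]}
W_self = [[1.0, 0.0], [0.0, 1.0]]   # stand-in for learned weights
W_neigh = [[0.5, 0.0], [0.0, 0.5]]  # stand-in for learned weights
out = message_passing_layer(adj, h, W_self, W_neigh)
print(out)  # {0: [1.0, 0.5], 1: [0.5, 1.0], 2: [0.5, 1.0]}
```

Feeding `out` back into the layer k times lets each node's representation depend on its k-hop neighbourhood, which is how information travels across the graph.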


The building blocks for aggregating and updating representations are parametrized by neural networks. Crucially, GNNs require only a fixed number of parameters that does not depend on the size of the graph, which makes learning scalable to very large graphs as long as node neighbourhoods can be sampled efficiently.
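The weights in a message-passing layer have shapes set by the feature dimensions, not by the number of nodes, so the same parameters apply to a ten-node graph or a billion-node one. What does grow with the graph is the neighbourhood each node must read. The sketch below shows fixed-fanout neighbour sampling in the style of GraphSAGE, a technique named here as an assumption since the post does not specify one; `sample_computation_graph` and the fanout values are illustrative.

```python
import random

def sample_neighbours(adj, v, fanout, rng):
    """Return at most `fanout` neighbours of v, sampled without replacement."""
    nbrs = list(adj[v])
    if len(nbrs) <= fanout:
        return nbrs
    return rng.sample(nbrs, fanout)

def sample_computation_graph(adj, seeds, fanouts, rng):
    """Sample a multi-hop neighbourhood hop by hop (hypothetical helper).

    fanouts[i] caps how many neighbours each node keeps at hop i + 1, so the
    number of sampled nodes per seed is bounded by the product of the fanouts,
    however large the full graph is.
    """
    layers = [set(seeds)]
    for fanout in fanouts:
        frontier = set()
        for v in layers[-1]:
            frontier.update(sample_neighbours(adj, v, fanout, rng))
        layers.append(frontier)
    return layers

# A "hub" node 0 connected to 1000 leaf nodes.
adj = {0: list(range(1, 1001))}
adj.update({i: [0] for i in range(1, 1001)})
layers = sample_computation_graph(adj, seeds=[0], fanouts=[10, 5], rng=random.Random(0))
print([len(layer) for layer in layers])  # [1, 10, 1]
```

Even though node 0 has 1000 neighbours, only 10 are read at the first hop, so the cost of computing its representation stays bounded.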

Applications in Research and Production

Having gotten a taste of what deep learning on graphs is all about, let us now highlight a diverse set of research and production applications to showcase the potential of graph neural networks.

Recommender Systems

PinSage is a scalable framework for product discovery and recommendation developed by Pinterest in 2018. Deployed in production, the framework's graph neural network components operated on a graph with around 3 billion nodes and 18 billion edges. Pinterest reported a relative improvement of around 30% in user engagement rates compared to previous systems.
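PinSage's paper describes selecting a node's neighbourhood by running short random walks from it and keeping the most frequently visited nodes, rather than using all direct neighbours. The following is a simplified sketch of that idea under our own assumptions (uniform random walks, raw visit counts, illustrative parameter values); `importance_neighbours` is a hypothetical helper, not Pinterest's implementation.

```python
import random
from collections import Counter

def importance_neighbours(adj, start, num_walks, walk_length, top_t, rng):
    """Pick the top_t nodes most often visited by short random walks from start.

    Visit counts act as importance scores: nodes the walks keep returning to
    are kept, while rarely reached nodes are dropped from the neighbourhood.
    """
    counts = Counter()
    for _ in range(num_walks):
        node = start
        for _ in range(walk_length):
            if not adj[node]:
                break  # dead end: stop this walk early
            node = rng.choice(adj[node])
            if node != start:
                counts[node] += 1
    return [node for node, _ in counts.most_common(top_t)]

adj = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2, 4], 4: [3]}
top = importance_neighbours(adj, start=0, num_walks=200, walk_length=3,
                            top_t=2, rng=random.Random(42))
print(top)
```

Because walks can pass through intermediate nodes, the selected neighbourhood may include nodes that are not directly connected to `start`, which is part of the appeal of this approach on very large, noisy graphs.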

Alibaba developed AliGraph in 2019 to capture rich and complex relationships among billions of elements in large graph datasets. The deployed system delivers speed and efficiency improvements and provides the backend for product recommendation and personalized search.

Uber incorporated graph neural networks into the food recommendation engine of Uber Eats and found that the graph-learned embeddings captured more information than any existing feature in their recommendation system, leading to significant improvements in performance, user engagement and click-through rates.

Anomaly Detection

Through acquiring the London startup Fabula AI in 2019, Twitter has been investing in graph deep learning technology to detect malicious behaviour, network manipulation, and fake news promotion in its social network graphs.

Computer Vision

Magic Leap's SuperGlue model uses graph neural networks to address the problem of feature matching, which is useful for reconstruction, recognition, localisation, and mapping in 3D space. The model outperformed other learned approaches and achieved state-of-the-art results on challenging indoor and outdoor pose estimation while running in real time.
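The SuperGlue paper pairs its attention-based graph neural network with a differentiable optimal-transport step, solved with the Sinkhorn algorithm, to turn pairwise keypoint scores into a soft assignment. The sketch below shows plain Sinkhorn normalisation on a small square score matrix; it omits SuperGlue's extra "dustbin" row and column for unmatched keypoints and its log-domain computation, and the scores are made up for illustration.

```python
import math

def sinkhorn(scores, num_iters=50):
    """Turn a square score matrix into an approximately doubly-stochastic
    soft assignment by alternately normalising the rows and columns of
    exp(scores). Simplified sketch: no dustbin, no log-domain stabilisation.
    """
    P = [[math.exp(s) for s in row] for row in scores]
    n_rows, n_cols = len(P), len(P[0])
    for _ in range(num_iters):
        # Normalise each row to sum to 1.
        for i in range(n_rows):
            r = sum(P[i])
            P[i] = [p / r for p in P[i]]
        # Normalise each column to sum to 1.
        for j in range(n_cols):
            c = sum(P[i][j] for i in range(n_rows))
            for i in range(n_rows):
                P[i][j] /= c
    return P

# Hypothetical similarity scores between 3 keypoints in image A (rows)
# and 3 keypoints in image B (columns).
scores = [[5.0, 0.1, 0.2],
          [0.3, 4.0, 0.1],
          [0.2, 0.1, 6.0]]
P = sinkhorn(scores)
matches = [max(range(len(row)), key=row.__getitem__) for row in P]
print(matches)  # [0, 1, 2]
```

Unlike independent per-row argmaxes on the raw scores, the row and column constraints discourage two keypoints from claiming the same match.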

Combinatorial Optimization

Problems in combinatorial optimization frequently appear in manufacturing and logistics. Graph neural networks can be integrated with existing solvers or serve as components in end-to-end optimization systems. A recent example is Google’s work on improving automatic chip design with deep reinforcement learning, where graph and edge embeddings are computed for the graph of a chip’s circuit components, serving as input to the model’s policy and value networks.
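To illustrate how node and edge embeddings can be pooled into a single graph embedding, here is a sketch loosely modelled on our reading of the edge-centric graph convolution in Google's chip placement work: each edge embeds the concatenation of its endpoints' features, each node is updated to the mean of its incident edge embeddings, and the graph embedding is the mean of all edge embeddings. Layer sizes, the ReLU, the number of rounds, the hand-picked weights and all function names are assumptions for illustration.

```python
def linear_relu(W, b, x):
    # Dense layer followed by ReLU: max(0, W x + b), element-wise.
    return [max(0.0, sum(w * v for w, v in zip(row, x)) + bi) for row, bi in zip(W, b)]

def mean(vectors):
    # Element-wise mean of a non-empty list of equal-length vectors.
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]

def edge_graph_embedding(nodes, edges, W, b, rounds=2):
    """Compute edge embeddings and a pooled graph embedding (hypothetical helper).

    nodes: node -> feature vector of size d; edges: list of (u, v) pairs.
    W maps a concatenated 2d-vector back to d dimensions. In each round, every
    edge embeds the concatenation of its endpoint features, then each node is
    replaced by the mean of its incident edge embeddings.
    """
    h = dict(nodes)
    for _ in range(rounds):
        e = {(u, v): linear_relu(W, b, h[u] + h[v]) for u, v in edges}
        incident = {n: [] for n in h}
        for (u, v), emb in e.items():
            incident[u].append(emb)
            incident[v].append(emb)
        h = {n: mean(embs) if embs else h[n] for n, embs in incident.items()}
    return e, mean(list(e.values()))

# Three circuit components wired in a chain a - b - c.
nodes = {"a": [1.0, 0.0], "b": [0.0, 1.0], "c": [1.0, 1.0]}
edges = [("a", "b"), ("b", "c")]
# Hand-picked weights that simply average the two endpoint feature vectors.
W = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]
b = [0.0, 0.0]
edge_embs, graph_emb = edge_graph_embedding(nodes, edges, W, b)
print(graph_emb)  # [0.5, 0.75]
```

A fixed-size vector like `graph_emb` is what a downstream policy or value network can consume, regardless of how many components the circuit has.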

Life Sciences

In life sciences, graph neural networks have recently been applied to antibiotic discovery to predict antibacterial activity from molecular structure. The model predicted an existing drug to have antibiotic activity, which was tested and verified in the lab by treating infections in mice, even though the drug looked very different from conventional antibiotics. An example of geometric deep learning can be found in recent work on predicting protein interactions, which defines an architecture of convolutions and filtering operations on point clouds of atoms to identify binding sites and predict protein-protein interactions.
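Predicting a property such as antibacterial activity from molecular structure typically means treating atoms as nodes and bonds as edges, running a few rounds of message passing, and pooling the atom representations into one molecule-level vector. The toy sketch below shows that pipeline with one-hot atom types, sum aggregation and a sum readout followed by a logistic output. It is not the model used in the antibiotic study; `predict_activity`, the featurisation and the weights are hypothetical and untrained.

```python
import math

# One-hot atom-type features (an illustrative choice of featurisation).
ATOM_TYPES = ["C", "N", "O"]

def one_hot(symbol):
    return [1.0 if symbol == t else 0.0 for t in ATOM_TYPES]

def predict_activity(atoms, bonds, weights, bias, rounds=2):
    """Score a molecule: sum-aggregate neighbour features, sum-pool, then logistic."""
    adj = {i: [] for i in range(len(atoms))}
    for u, v in bonds:
        adj[u].append(v)
        adj[v].append(u)
    h = [one_hot(a) for a in atoms]
    for _ in range(rounds):
        # Each atom adds its bonded neighbours' features to its own.
        h = [[h[i][d] + sum(h[j][d] for j in adj[i]) for d in range(len(ATOM_TYPES))]
             for i in range(len(atoms))]
    # Sum readout: one vector for the whole molecule.
    g = [sum(col) for col in zip(*h)]
    logit = sum(w * x for w, x in zip(weights, g)) + bias
    return 1.0 / (1.0 + math.exp(-logit))

# Toy molecule: the heavy atoms of ethanol, C-C-O (hydrogens omitted).
weights, bias = [0.1, 0.0, -0.2], 0.0
p = predict_activity(["C", "C", "O"], [(0, 1), (1, 2)], weights, bias)
# The same molecule with its atoms numbered differently (O first).
p_renumbered = predict_activity(["O", "C", "C"], [(1, 2), (2, 0)], weights, bias)
print(p == p_renumbered)  # True
```

Because both the aggregation and the readout are sums, the score does not depend on how the atoms happen to be numbered, a property that graph-level models generally need.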


Future advances of machine learning applications in the physical sciences and life sciences are likely to reinforce each other, with fast and accurate simulation tools from chemistry and physics making their way into bioinformatics software used across the life sciences.

Natural Language Processing: Transformers

Whether it's Google's BERT serving search queries, OpenAI's GPT-3 writing coherent paragraphs, or DeepMind's AlphaFold2 predicting protein structures, state-of-the-art AI systems today increasingly rely on a single kind of model architecture: attention-based Transformer models. Transformers are neural networks that operate on sets of inputs and dynamically attend to whichever parts of the context are relevant to the task at hand.

From a graph neural network perspective, Transformer models can be interpreted as acting on a fully-connected graph of nodes. By letting go of explicit, sparse graph structures and connecting all nodes to each other, Transformers rely on constructing implicit, latent graphs where the strength of connections between nodes is computed dynamically from attending to context. Combining this kind of open-ended architectural flexibility with massive amounts of data and compute leads to very powerful and adaptive classes of neural network models.
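The graph view of attention can be sketched directly: in single-head scaled dot-product self-attention, every token is a node linked to every token (itself included), and the softmax weights play the role of edge strengths computed afresh from the input. For brevity the sketch uses the raw token vectors as queries, keys and values, whereas a real Transformer first applies learned linear projections; the function names are hypothetical.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of scores.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(queries, keys, values):
    """Single-head scaled dot-product attention over n tokens.

    Read as a graph: each token is a node connected to all n nodes, and the
    softmax weights are the edge strengths, recomputed for every input.
    Returns the updated node vectors and the n x n matrix of edge weights.
    """
    d = len(queries[0])
    outputs, edge_weights = [], []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        w = softmax(scores)
        edge_weights.append(w)
        # Message passing: weighted sum of every node's value vector.
        outputs.append([sum(wj * v[c] for wj, v in zip(w, values))
                        for c in range(len(values[0]))])
    return outputs, edge_weights

x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, A = self_attention(x, x, x)
for row in A:
    print([round(w, 3) for w in row])
```

Restricting the inner loops to each token's neighbours in a sparse graph roughly recovers an attention-based GNN layer, which is the sense in which a Transformer acts on a fully connected graph.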


Geometric deep learning provides an attractive framework to lift deep learning successes to domains characterized by irregularly-structured data like graphs. Given the encouraging results obtained so far, its adoption rate is expected to grow fast in the next few years as research continues to mature into practical applications, offering lots of potential for early adopters in industry.

Further Reading

What 2021 holds for Graph ML? | by Michael Bronstein

Top Applications of Graph Neural Networks 2021 | by Sergei Ivanov

Geometric ML becomes real in fundamental sciences | by Michael Bronstein

Latent graph neural networks: Manifold learning 2.0? | by Michael Bronstein

Do we need deep graph neural networks? | by Michael Bronstein

Food Discovery with Uber Eats: Using Graph Learning to Power Recommendations

Towards understanding glasses with graph neural networks

Transformers are Graph Neural Networks

Paper Lists

nnzhan/Awesome-Graph-Neural-Networks: Paper Lists for Graph Neural Networks

thunlp/GNNPapers: Must-read papers on graph neural networks (GNN)

Learning Material

Stanford CS224W: Machine Learning with Graphs

Graph Representation Learning Book

Graph Convolutional Networks I · Deep Learning

Graph Convolutional Networks II · Deep Learning

Week 13 – Lecture: Graph Convolutional Networks (GCNs)

Xavier Bresson "Recent Developments of Graph Network Architectures"

Libraries

rusty1s/pytorch_geometric: Geometric Deep Learning Extension Library for PyTorch

dmlc/dgl: Python package built to ease deep learning on graph, on top of existing DL frameworks.

deepmind/graph_nets: Build Graph Nets in Tensorflow
