Tutorial: learn crystal energies with Nequip

mp-169 crystal

The goal of this tutorial is to show how to use Nequip to train a neural network to predict the energy of crystals.

What this tutorial will cover (and not cover)

Covered:

Create a Nequip model
Create a simple dataset
Train the model to predict the energy
Introduction to jraph to deal with graphs
Introduction to flax to create parameterized models
Introduction to optax to optimize the parameters

Not covered:

Training to predict the forces (by differentiating the energy)
Having more than one batch and padding them with jraph.pad_with_graphs
Adding support for different atom types (by embedding them)
Testing the model on unseen crystals

Import the libraries

For this tutorial we will need the following libraries:

import flax  # neural network modules for jax
import jax
import jax.numpy as jnp
import jraph  # graph neural networks in jax
import matplotlib.pyplot as plt
import numpy as np
import optax  # optimizers for jax
from matscipy.neighbours import neighbour_list  # fast neighbour list implementation

from nequip_jax import NEQUIPLayerFlax  # e3nn implementation of NEQUIP
import e3nn_jax as e3nn

If you want to use a GPU, check out the jax installation guide.

Here are the pip commands to install everything:

pip install -U pip
pip install -U "jax[cpu]"

pip install -U flax jraph optax
pip install -U matscipy
pip install -U matplotlib
pip install -U e3nn_jax

pip install git+https://github.com/mariogeiger/nequip-jax

Dataset

Let’s create a very simple dataset of a few crystals from the Materials Project, made only of carbon atoms. The Materials Project provides a predicted formation energy for each crystal, and we will use it as our target.

To compute the graph connectivity we use the matscipy library, which provides a very fast neighbour list implementation.

def compute_edges(positions, cell, cutoff):
    """Compute edges of the graph from positions and cell."""
    receivers, senders, senders_unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=np.array([True, True, True]),
        cell=cell,
        positions=positions,
        cutoff=cutoff,
    )

    num_edges = senders.shape[0]
    assert senders.shape == (num_edges,)
    assert receivers.shape == (num_edges,)
    assert senders_unit_shifts.shape == (num_edges, 3)
    return senders, receivers, senders_unit_shifts

This image shows the edges created by matscipy.neighbours.neighbour_list for an example crystal in 2D. Note that all the edges point (receivers side) to an atom in the central cell.

Graph Edges of a Periodic Crystal
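To get a feel for the output, here is a minimal sanity-check sketch on a hypothetical single-atom cubic cell (not one of the crystals used below): with a lattice constant of 2.0 angstroms and a cutoff of 2.5 angstroms, the atom is connected to its six nearest periodic images.

toy_positions = np.array([[0.0, 0.0, 0.0]])  # a single atom at the origin
toy_cell = 2.0 * np.eye(3)  # simple cubic cell, 2.0 angstrom lattice constant

toy_senders, toy_receivers, toy_shifts = compute_edges(toy_positions, toy_cell, cutoff=2.5)
print(toy_senders.shape)  # (6,): one edge to each of the six nearest periodic images
print(toy_shifts)  # unit shifts such as [1, 0, 0] or [0, -1, 0]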

Then we use jraph to create graph objects and (later) batch them together. jraph is a library for graph neural networks in jax developed by DeepMind. The following function, create_graph, creates a graph object from the given positions, cell and energy of a crystal. Each crystal is stored in a jraph.GraphsTuple, which is the cornerstone datatype of the jraph library: a named tuple that contains all the information about a graph. The documentation of jraph.GraphsTuple can be found here.

def create_graph(positions, cell, energy, cutoff):
    """Create a graph from positions, cell, and energy."""
    senders, receivers, senders_unit_shifts = compute_edges(positions, cell, cutoff)

    # In a jraph.GraphsTuple object, nodes, edges, and globals can be any
    # pytree. We will use dicts of arrays.
    # What matters is that the first axis of each array has length equal to
    # the number of nodes, edges, or graphs.
    num_nodes = positions.shape[0]
    num_edges = senders.shape[0]

    graph = jraph.GraphsTuple(
        # positions are per-node features:
        nodes=dict(positions=positions),
        # Unit shifts are per-edge features:
        edges=dict(shifts=senders_unit_shifts),
        # energy and cell are per-graph features:
        globals=dict(energies=np.array([energy]), cells=cell[None, :, :]),
        # The rest of the fields describe the connectivity and size of the graph.
        senders=senders,
        receivers=receivers,
        n_node=np.array([num_nodes]),
        n_edge=np.array([num_edges]),
    )
    return graph

We need to specify the cutoff for the neighbour list. This is the distance up to which two atoms are considered connected. All distances here are in angstroms. A cutoff of 2.0 angstroms is enough here, since it is larger than the typical carbon-carbon bond lengths (roughly 1.4 to 1.6 angstroms) in these structures.

cutoff = 2.0  # in angstroms

Now we can create the graphs for the crystals. The values of the positions, cells and energies are taken from the Materials Project website.

mp47 = create_graph(
    positions=np.array(
        [
            [-0.0, 1.44528, 0.26183],
            [1.25165, 0.72264, 2.34632],
            [1.25165, 0.72264, 3.90714],
            [-0.0, 1.44528, 1.82265],
        ]
    ),
    cell=np.array([[2.5033, 0.0, 0.0], [-1.25165, 2.16792, 0.0], [0.0, 0.0, 4.16897]]),
    energy=0.163,  # eV/atom
    cutoff=cutoff,
)
print(f"mp47 has {mp47.n_node} nodes and {mp47.n_edge} edges")

mp48 = create_graph(
    positions=np.array(
        [
            [0.0, 0.0, 1.95077],
            [0.0, 0.0, 5.8523],
            [-0.0, 1.42449, 1.95077],
            [1.23365, 0.71225, 5.8523],
        ]
    ),
    cell=np.array([[2.46729, 0.0, 0.0], [-1.23365, 2.13674, 0.0], [0.0, 0.0, 7.80307]]),
    energy=0.008,  # eV/atom
    cutoff=cutoff,
)
print(f"mp48 has {mp48.n_node} nodes and {mp48.n_edge} edges")

mp66 = create_graph(
    positions=np.array(
        [
            [0.0, 0.0, 1.78037],
            [0.89019, 0.89019, 2.67056],
            [0.0, 1.78037, 0.0],
            [0.89019, 2.67056, 0.89019],
            [1.78037, 0.0, 0.0],
            [2.67056, 0.89019, 0.89019],
            [1.78037, 1.78037, 1.78037],
            [2.67056, 2.67056, 2.67056],
        ]
    ),
    cell=np.array([[3.56075, 0.0, 0.0], [0.0, 3.56075, 0.0], [0.0, 0.0, 3.56075]]),
    energy=0.138,  # eV/atom
    cutoff=cutoff,
)
print(f"mp66 has {mp66.n_node} nodes and {mp66.n_edge} edges")

mp169 = create_graph(
    positions=np.array(
        [
            [-0.66993, 0.0, 3.5025],
            [3.5455, 0.0, 0.00033],
            [1.45739, 1.22828, 3.5025],
            [1.41818, 1.22828, 0.00033],
        ]
    ),
    cell=np.array([[4.25464, 0.0, 0.0], [0.0, 2.45656, 0.0], [-1.37907, 0.0, 3.50283]]),
    energy=0.003,  # eV/atom
    cutoff=cutoff,
)
print(f"mp169 has {mp169.n_node} nodes and {mp169.n_edge} edges")
mp47 has [4] nodes and [16] edges
mp48 has [4] nodes and [12] edges
mp66 has [8] nodes and [32] edges
mp169 has [4] nodes and [12] edges

Now that we have mp47, mp48, mp66 and mp169 as graphs, we can batch them together to create a dataset. Batching is an important concept in jraph: it merges several graphs into a single graph. The batch function takes a list of graphs and returns a single graph with the same fields as the input graphs. The n_node and n_edge fields keep track of the number of nodes and edges in each original graph.

dataset = jraph.batch([mp47, mp48, mp66, mp169])
print(f"dataset has {dataset.n_node} nodes and {dataset.n_edge} edges")

# Print the shapes of the fields of the dataset.
print(jax.tree_util.tree_map(jnp.shape, dataset))
dataset has [4 4 8 4] nodes and [16 12 32 12] edges
GraphsTuple(nodes={'positions': (20, 3)}, edges={'shifts': (72, 3)}, receivers=(72,), senders=(72,), globals={'cells': (4, 3, 3), 'energies': (4,)}, n_node=(4,), n_edge=(4,))
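If needed, a batched graph can be split back into its components with jraph.unbatch; a quick check (not needed for the rest of the tutorial):

graphs = jraph.unbatch(dataset)
print(len(graphs))  # 4: one graph per crystal
print(graphs[0].n_node)  # [4], same as mp47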

Model

Before defining the model, we need to make sure we properly take into account the periodic boundary conditions of the crystals. The model needs the relative vectors between connected atoms, but we only store the positions of the atoms inside the unit cell, so we have to reconstruct the relative vector of an edge even when its two atoms do not belong to the same cell. This is what the unit shifts are for: the effective position of a sender is its position inside the unit cell plus the unit shift multiplied by the cell matrix.

def get_relative_vectors(senders, receivers, n_edge, positions, cells, shifts):
    """Compute the relative vectors between the senders and receivers."""
    num_nodes = positions.shape[0]
    num_edges = senders.shape[0]
    num_graphs = n_edge.shape[0]

    assert positions.shape == (num_nodes, 3)
    assert cells.shape == (num_graphs, 3, 3)
    assert senders.shape == (num_edges,)
    assert receivers.shape == (num_edges,)
    assert shifts.shape == (num_edges, 3)

    # We need to repeat the cells for each edge.
    cells = jnp.repeat(cells, n_edge, axis=0, total_repeat_length=num_edges)

    # Compute the two ends of each edge.
    positions_receivers = positions[receivers]
    positions_senders = positions[senders] + jnp.einsum("ei,eij->ej", shifts, cells)

    vectors = e3nn.IrrepsArray("1o", positions_receivers - positions_senders)
    return vectors
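As a quick sanity check (a sketch using the mp47 graph created above), all relative vectors should be shorter than the cutoff:

vectors = get_relative_vectors(
    mp47.senders,
    mp47.receivers,
    mp47.n_edge,
    positions=mp47.nodes["positions"],
    cells=mp47.globals["cells"],
    shifts=mp47.edges["shifts"],
)
print(jnp.linalg.norm(vectors.array, axis=-1).max())  # should not exceed the cutoff of 2.0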

Now we define the model, based on the Nequip architecture.

Nequip architecture

For that we will use the implementation available at github.com/mariogeiger/nequip-jax. It provides a NEQUIPLayerFlax class that implements the Interaction Block of the Nequip architecture (the blue box in the figure above). The Embedding, Self-Interaction and Global Pooling parts of the architecture are not implemented in nequip-jax, so we implement them ourselves below.

The class below defines a flax Module. flax is a neural network library built on top of jax. It provides a Module class similar to nn.Module in PyTorch, as well as a compact decorator that lets us define the __call__ method of the Module in a more concise way. See the flax documentation for more details.
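To illustrate the pattern on its own, here is a minimal hypothetical flax Module (not used in the rest of the tutorial); with the compact decorator, the layers and their parameters are declared inline inside __call__ and created at initialization:

class TinyMLP(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, x):
        # Layers are declared inline; flax registers their parameters automatically.
        x = flax.linen.Dense(16)(x)
        x = flax.linen.relu(x)
        return flax.linen.Dense(1)(x)

tiny_params = TinyMLP().init(jax.random.PRNGKey(0), jnp.ones((4, 8)))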

class Model(flax.linen.Module):
    @flax.linen.compact
    def __call__(self, graphs):
        num_nodes = graphs.nodes["positions"].shape[0]
        senders = graphs.senders
        receivers = graphs.receivers

        vectors = get_relative_vectors(
            senders,
            receivers,
            graphs.n_edge,
            positions=graphs.nodes["positions"],
            cells=graphs.globals["cells"],
            shifts=graphs.edges["shifts"],
        )

        # We divide the relative vectors by the cutoff
        # because NEQUIPLayerFlax assumes a cutoff of 1.0
        vectors = vectors / cutoff

        # Embedding: since we have a single atom type, we don't need embedding
        # The node features are just ones and the species indices are all zeros
        features = e3nn.IrrepsArray("0e", jnp.ones((num_nodes, 1)))
        species = jnp.zeros((num_nodes,), dtype=jnp.int32)

        # Apply 3 Nequip layers with different internal representations
        for irreps in [
            "32x0e + 8x1o + 8x2e",
            "32x0e + 8x1o + 8x2e",
            "32x0e",
        ]:
            layer = NEQUIPLayerFlax(
                avg_num_neighbors=20.0,  # average number of neighbors to normalize by
                output_irreps=irreps,
            )
            features = layer(vectors, features, species, senders, receivers)

        # Self-Interaction layers
        features = e3nn.flax.Linear("16x0e")(features)
        features = e3nn.flax.Linear("0e")(features)

        # Global Pooling
        # Average the features (energy prediction) over the nodes of each graph
        return e3nn.scatter_sum(features, nel=graphs.n_node) / graphs.n_node[:, None]

Training

Now that we have defined the model, we need a loss function to train it. For this example we will use the mean squared error.

def loss_fn(preds, targets):
    assert preds.shape == targets.shape
    return jnp.mean(jnp.square(preds - targets))
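For instance, predictions of [1.0, 2.0] against targets of [1.5, 2.0] give a loss of 0.125:

print(loss_fn(jnp.array([1.0, 2.0]), jnp.array([1.5, 2.0])))  # mean([0.25, 0.0]) = 0.125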

Now let’s use the magic of flax to initialize the model, and the magic of optax to define the optimizer and initialize it as well. As the optimizer we will use Adam. This optimizer needs to keep track of running averages of the gradients and of the squared gradients, which is why it has a state. The state is initialized with opt.init.

random_key = jax.random.PRNGKey(0)  # change it to get different initializations

# Initialize the model
model = Model()
params = jax.jit(model.init)(random_key, dataset)

# Initialize the optimizer
opt = optax.adam(1e-3)
opt_state = opt.init(params)
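As a side note, we can count the trainable parameters by summing over the leaves of the params pytree (a quick sketch):

num_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"The model has {num_params} parameters")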

Let’s define the training step. We will use jax.jit to compile the function and make it faster. This function takes as input the model parameters, the optimizer state and the dataset and returns the updated optimizer state, the updated model parameters and the loss.

@jax.jit
def train_step(opt_state, params, dataset):
    """Perform a single training step."""
    num_graphs = dataset.n_node.shape[0]

    # Compute the loss as a function of the parameters
    def fun(w):
        preds = model.apply(w, dataset).array.squeeze(1)
        targets = dataset.globals["energies"]

        assert preds.shape == (num_graphs,)
        assert targets.shape == (num_graphs,)
        return loss_fn(preds, targets)

    # And take its gradient
    loss, grad = jax.value_and_grad(fun)(params)

    # Update the parameters and the optimizer state
    updates, opt_state = opt.update(grad, opt_state)
    params = optax.apply_updates(params, updates)

    return opt_state, params, loss

Finally, let’s train the model for 1000 iterations.

losses = []
for _ in range(1000):
    opt_state, params, loss = train_step(opt_state, params, dataset)
    losses.append(loss)

Did it work?

plt.plot(losses)
plt.xscale("log")
plt.yscale("log")
plt.xlabel("Iteration")
plt.ylabel("Loss")
Training loss curve
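Beyond the loss curve, here is a minimal sketch to compare the trained predictions with the targets (both in eV/atom):

preds = model.apply(params, dataset).array.squeeze(1)
targets = dataset.globals["energies"]
for name, pred, target in zip(["mp47", "mp48", "mp66", "mp169"], preds, targets):
    print(f"{name}: predicted {float(pred):.3f} eV/atom, target {float(target):.3f} eV/atom")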