Tutorial: learn crystal energies with Nequip
============================================

.. image:: mp-169.png
    :width: 400
    :alt: mp-169 crystal

The goal of this tutorial is to show how to use Nequip to train a neural network to predict the energy of crystals.

.. |check| raw:: html

.. |uncheck| raw:: html

What this tutorial will cover (and not cover)
---------------------------------------------

| |check| Create a Nequip model
| |check| Create a simple dataset
| |check| Train the model to predict the energy
| |check| Introduction to ``jraph`` to deal with graphs
| |check| Introduction to ``flax`` to create parameterized models
| |check| Introduction to ``optax`` to optimize the parameters
| |uncheck| Train to predict the forces (by differentiating the energy)
| |uncheck| Having more than one batch and pad them with ``jraph.pad_with_graphs``
| |uncheck| Add support for different atom types (by embedding them)
| |uncheck| Test the model on unseen crystals

Import the libraries
--------------------

For this tutorial we will need the following libraries:

.. jupyter-execute::

    import flax  # neural network modules for jax
    import jax
    import jax.numpy as jnp
    import jraph  # graph neural networks in jax
    import matplotlib.pyplot as plt
    import numpy as np
    import optax  # optimizers for jax
    from matscipy.neighbours import neighbour_list  # fast neighbour list implementation
    from nequip_jax import NEQUIPLayerFlax  # e3nn implementation of NEQUIP

    import e3nn_jax as e3nn

If you want to use a GPU, check out the `jax installation guide `_.

Here are the ``pip`` commands to install everything::

    pip install -U pip
    pip install -U "jax[cpu]"
    pip install -U flax jraph optax
    pip install -U matscipy
    pip install -U matplotlib
    pip install -U e3nn_jax
    pip install git+https://github.com/mariogeiger/nequip-jax

Dataset
-------

Let's create a **very** simple dataset of a few crystals from the `materials project `_ made only of carbon atoms.
Materials project provides a *Predicted Formation Energy* for each crystal, and we will use this as our target.

* `mp-47 `_
* `mp-48 `_
* `mp-66 `_
* `mp-169 `_

To compute the graph connectivity we use the `matscipy `_ library, which has a very fast neighbour list implementation.

.. jupyter-execute::

    def compute_edges(positions, cell, cutoff):
        """Compute edges of the graph from positions and cell."""
        receivers, senders, senders_unit_shifts = neighbour_list(
            quantities="ijS",
            pbc=np.array([True, True, True]),
            cell=cell,
            positions=positions,
            cutoff=cutoff,
        )

        num_edges = senders.shape[0]
        assert senders.shape == (num_edges,)
        assert receivers.shape == (num_edges,)
        assert senders_unit_shifts.shape == (num_edges, 3)
        return senders, receivers, senders_unit_shifts

This image shows the edges created by ``matscipy.neighbours.neighbour_list`` for an example crystal in 2D.
Note that all the edges point (``receivers`` side) to an atom in the central cell.

.. image:: graph.png
    :width: 400
    :alt: Graph Edges of a Periodic Crystal

Then we use `jraph `_ to create graph objects and (later) batch them together.
``jraph`` is a library for graph neural networks in jax developed by DeepMind.

The following function ``create_graph`` creates a graph object from the given positions, cell and energy of the crystal.
Each crystal is stored in a ``jraph.GraphsTuple``, which is the cornerstone datatype of the ``jraph`` library.
It is a named tuple that contains all the information about a graph.
The documentation of ``jraph.GraphsTuple`` can be found `here `_.

.. jupyter-execute::

    def create_graph(positions, cell, energy, cutoff):
        """Create a graph from positions, cell, and energy."""
        senders, receivers, senders_unit_shifts = compute_edges(positions, cell, cutoff)

        # In a jraph.GraphsTuple object, nodes, edges, and globals can be any
        # pytree. We will use dicts of arrays.
        # What matters is that the first axis of each array has length equal to
        # the number of nodes, edges, or graphs.
        num_nodes = positions.shape[0]
        num_edges = senders.shape[0]

        graph = jraph.GraphsTuple(
            # positions are per-node features:
            nodes=dict(positions=positions),
            # Unit shifts are per-edge features:
            edges=dict(shifts=senders_unit_shifts),
            # energy and cell are per-graph features:
            globals=dict(energies=np.array([energy]), cells=cell[None, :, :]),
            # The rest of the fields describe the connectivity and size of the graph.
            senders=senders,
            receivers=receivers,
            n_node=np.array([num_nodes]),
            n_edge=np.array([num_edges]),
        )
        return graph

We need to specify the cutoff for the neighbour list.
This is the distance up to which we consider two atoms to be connected.
All the distances here are in angstroms.

.. jupyter-execute::

    cutoff = 2.0  # in angstroms

Now we can create the graphs for the crystals.
The values of the positions, cell and energy are taken from the materials project website.

.. jupyter-execute::

    mp47 = create_graph(
        positions=np.array(
            [
                [-0.0, 1.44528, 0.26183],
                [1.25165, 0.72264, 2.34632],
                [1.25165, 0.72264, 3.90714],
                [-0.0, 1.44528, 1.82265],
            ]
        ),
        cell=np.array([[2.5033, 0.0, 0.0], [-1.25165, 2.16792, 0.0], [0.0, 0.0, 4.16897]]),
        energy=0.163,  # eV/atom
        cutoff=cutoff,
    )
    print(f"mp47 has {mp47.n_node} nodes and {mp47.n_edge} edges")

    mp48 = create_graph(
        positions=np.array(
            [
                [0.0, 0.0, 1.95077],
                [0.0, 0.0, 5.8523],
                [-0.0, 1.42449, 1.95077],
                [1.23365, 0.71225, 5.8523],
            ]
        ),
        cell=np.array([[2.46729, 0.0, 0.0], [-1.23365, 2.13674, 0.0], [0.0, 0.0, 7.80307]]),
        energy=0.008,  # eV/atom
        cutoff=cutoff,
    )
    print(f"mp48 has {mp48.n_node} nodes and {mp48.n_edge} edges")

    mp66 = create_graph(
        positions=np.array(
            [
                [0.0, 0.0, 1.78037],
                [0.89019, 0.89019, 2.67056],
                [0.0, 1.78037, 0.0],
                [0.89019, 2.67056, 0.89019],
                [1.78037, 0.0, 0.0],
                [2.67056, 0.89019, 0.89019],
                [1.78037, 1.78037, 1.78037],
                [2.67056, 2.67056, 2.67056],
            ]
        ),
        cell=np.array([[3.56075, 0.0, 0.0], [0.0, 3.56075, 0.0], [0.0, 0.0, 3.56075]]),
        energy=0.138,  # eV/atom
        cutoff=cutoff,
    )
    print(f"mp66 has {mp66.n_node} nodes and {mp66.n_edge} edges")

    mp169 = create_graph(
        positions=np.array(
            [
                [-0.66993, 0.0, 3.5025],
                [3.5455, 0.0, 0.00033],
                [1.45739, 1.22828, 3.5025],
                [1.41818, 1.22828, 0.00033],
            ]
        ),
        cell=np.array([[4.25464, 0.0, 0.0], [0.0, 2.45656, 0.0], [-1.37907, 0.0, 3.50283]]),
        energy=0.003,  # eV/atom
        cutoff=cutoff,
    )
    print(f"mp169 has {mp169.n_node} nodes and {mp169.n_edge} edges")

Now that we have ``mp47``, ``mp48``, ``mp66`` and ``mp169`` as graphs, we can batch them together to create a dataset.
`Batching `_ is an important concept in ``jraph``: it merges different graphs into a single graph.
The ``batch`` function takes a list of graphs and returns a single graph with the same fields as the input graphs.
The ``n_node`` and ``n_edge`` fields keep track of the number of nodes and edges in each graph.

.. jupyter-execute::

    dataset = jraph.batch([mp47, mp48, mp66, mp169])
    print(f"dataset has {dataset.n_node} nodes and {dataset.n_edge} edges")

    # Print the shapes of the fields of the dataset.
    print(jax.tree_util.tree_map(jnp.shape, dataset))

Model
-----

Before defining the model, we need to make sure we properly take into account the periodic boundary conditions of the crystals.
The model needs the relative vectors between the atoms in the crystal.
We know the positions of the atoms inside the unit cell, but we also need the relative vectors between atoms that do not belong to the same cell.

.. jupyter-execute::

    def get_relative_vectors(senders, receivers, n_edge, positions, cells, shifts):
        """Compute the relative vectors between the senders and receivers."""
        num_nodes = positions.shape[0]
        num_edges = senders.shape[0]
        num_graphs = n_edge.shape[0]

        assert positions.shape == (num_nodes, 3)
        assert cells.shape == (num_graphs, 3, 3)
        assert senders.shape == (num_edges,)
        assert receivers.shape == (num_edges,)
        assert shifts.shape == (num_edges, 3)

        # We need to repeat the cells for each edge.
        cells = jnp.repeat(cells, n_edge, axis=0, total_repeat_length=num_edges)

        # Compute the two ends of each edge.
        positions_receivers = positions[receivers]
        positions_senders = positions[senders] + jnp.einsum("ei,eij->ej", shifts, cells)

        vectors = e3nn.IrrepsArray("1o", positions_receivers - positions_senders)
        return vectors

Now we define the model layer based on the `Nequip architecture `_.

.. image:: nequip.png
    :width: 600
    :alt: Nequip architecture

For that we will use the implementation available at `github.com/mariogeiger/nequip-jax `_.
This implementation provides a ``NEQUIPLayerFlax`` class that implements the *Interaction Block* part of the Nequip architecture (the blue box in the figure above).
The *Embedding*, *Self-Interaction* and *Global Pooling* parts of the Nequip architecture are not implemented in ``nequip-jax``, so we implement them ourselves below.

The class below defines a ``flax`` module.
``flax`` is a neural network library built on top of ``jax``.
It provides a ``Module`` class that is similar to ``nn.Module`` in PyTorch.
It also provides a ``compact`` decorator that allows us to define the ``__call__`` method of the ``Module`` in a more concise way.
See `the flax documentation `_ for more details.

.. jupyter-execute::

    class Model(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, graphs):
            num_nodes = graphs.nodes["positions"].shape[0]
            senders = graphs.senders
            receivers = graphs.receivers

            vectors = get_relative_vectors(
                senders,
                receivers,
                graphs.n_edge,
                positions=graphs.nodes["positions"],
                cells=graphs.globals["cells"],
                shifts=graphs.edges["shifts"],
            )

            # We divide the relative vectors by the cutoff
            # because NEQUIPLayerFlax assumes a cutoff of 1.0
            vectors = vectors / cutoff

            # Embedding: since we have a single atom type, we don't need embedding
            # The node features are just ones and the species indices are all zeros
            features = e3nn.IrrepsArray("0e", jnp.ones((num_nodes, 1)))
            species = jnp.zeros((num_nodes,), dtype=jnp.int32)

            # Apply 3 Nequip layers with different internal representations
            for irreps in [
                "32x0e + 8x1o + 8x2e",
                "32x0e + 8x1o + 8x2e",
                "32x0e",
            ]:
                layer = NEQUIPLayerFlax(
                    avg_num_neighbors=20.0,  # average number of neighbors to normalize by
                    output_irreps=irreps,
                )
                features = layer(vectors, features, species, senders, receivers)

            # Self-Interaction layers
            features = e3nn.flax.Linear("16x0e")(features)
            features = e3nn.flax.Linear("0e")(features)

            # Global Pooling
            # Average the features (energy prediction) over the nodes of each graph
            return e3nn.scatter_sum(features, nel=graphs.n_node) / graphs.n_node[:, None]

Training
--------

Now that we have defined the model, we need to define the loss function to train it.
For this example we will use the mean squared error.

.. jupyter-execute::

    def loss_fn(preds, targets):
        assert preds.shape == targets.shape
        return jnp.mean(jnp.square(preds - targets))

Now let's use the magic of ``flax`` to initialize the model and the magic of ``optax`` to define the optimizer and initialize it as well.
As optimizer we will use Adam.
Adam keeps running averages of the gradients and of the squared gradients, which is why it has a state.
The state is initialized with ``opt.init``.

.. jupyter-execute::

    random_key = jax.random.PRNGKey(0)  # change it to get different initializations

    # Initialize the model
    model = Model()
    params = jax.jit(model.init)(random_key, dataset)

    # Initialize the optimizer
    opt = optax.adam(1e-3)
    opt_state = opt.init(params)

Let's define the training step.
We will use ``jax.jit`` to compile the function and make it faster.
This function takes as input the optimizer state, the model parameters and the dataset, and returns the updated optimizer state, the updated model parameters and the loss.

.. jupyter-execute::

    @jax.jit
    def train_step(opt_state, params, dataset):
        """Perform a single training step."""
        num_graphs = dataset.n_node.shape[0]

        # Compute the loss as a function of the parameters
        def fun(w):
            preds = model.apply(w, dataset).array.squeeze(1)
            targets = dataset.globals["energies"]

            assert preds.shape == (num_graphs,)
            assert targets.shape == (num_graphs,)
            return loss_fn(preds, targets)

        # And take its gradient
        loss, grad = jax.value_and_grad(fun)(params)

        # Update the parameters and the optimizer state
        updates, opt_state = opt.update(grad, opt_state)
        params = optax.apply_updates(params, updates)

        return opt_state, params, loss

Finally, let's train the model for 1000 iterations.

.. jupyter-execute::

    losses = []
    for _ in range(1000):
        opt_state, params, loss = train_step(opt_state, params, dataset)
        losses.append(loss)

Did it work?

.. jupyter-execute::

    plt.plot(losses)
    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("Iteration")
    plt.ylabel("Loss")
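
The loss curve only shows the aggregate error. As a rough follow-up check, one could also compare predictions with targets crystal by crystal. The sketch below is not part of the original tutorial: ``per_crystal_errors`` is a hypothetical helper, and the ``preds`` values are placeholders rather than real model outputs. In the notebook they would come from ``model.apply(params, dataset).array.squeeze(1)``.

```python
def per_crystal_errors(names, preds, targets):
    """Return ({name: |pred - target|}, mean absolute error), in eV/atom.

    Hypothetical helper: accepts any sequences of numbers
    (lists, numpy arrays or jax arrays).
    """
    names = list(names)
    preds = [float(p) for p in preds]
    targets = [float(t) for t in targets]
    if not names or not (len(names) == len(preds) == len(targets)):
        raise ValueError("names, preds and targets must be non-empty and of equal length")
    errors = {n: abs(p - t) for n, p, t in zip(names, preds, targets)}
    mae = sum(errors.values()) / len(errors)
    return errors, mae


names = ["mp-47", "mp-48", "mp-66", "mp-169"]
# Targets used in this tutorial (eV/atom).
targets = [0.163, 0.008, 0.138, 0.003]
# PLACEHOLDER predictions, for illustration only. In the notebook, use:
#   preds = model.apply(params, dataset).array.squeeze(1)
preds = [0.160, 0.010, 0.140, 0.000]

errors, mae = per_crystal_errors(names, preds, targets)
for name, err in errors.items():
    print(f"{name}: |error| = {err:.3f} eV/atom")
print(f"mean absolute error = {mae:.4f} eV/atom")
```

Because the model is evaluated on the same four crystals it was trained on, a small error here only indicates that the fit converged. It says nothing about unseen crystals, which this tutorial does not cover.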