Euclidean neural networks

What is e3nn-jax?

e3nn-jax is a Python library built on JAX to create \(O(3)\)-equivariant neural networks.

Amuse-bouche

e3nn-jax contains many tools to manipulate irreps of the group \(O(3)\).

import jax
import jax.numpy as jnp
import haiku as hk

import e3nn_jax as e3nn


# Create a neural network
@hk.without_apply_rng
@hk.transform
def net(x, f):
    # the inputs and outputs are all of type e3nn.IrrepsArray
    Y = e3nn.spherical_harmonics([0, 1, 2], x, False)
    f = e3nn.tensor_product(Y, f)
    return e3nn.haiku.Linear("0e + 0o + 1o")(f)

# Create some inputs
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
f = e3nn.normal("4x0e + 1o + 1e", jax.random.PRNGKey(0), (16,))
print(f"feature vector: {f.shape}")

# Initialize the neural network
w = net.init(jax.random.PRNGKey(0), x, f)
print(jax.tree_util.tree_map(jnp.shape, w))

# Evaluate the neural network
f = net.apply(w, x, f)
print(f"feature vector: {f.shape}")
feature vector: (16, 10)
{'linear': {'w[0,0] 5x0e,1x0e': (5, 1), 'w[1,1] 1x0o,1x0o': (1, 1), 'w[2,2] 7x1o,1x1o': (7, 1)}}
feature vector: (16, 5)

Why rewrite e3nn in JAX?

JAX has two beautiful function transformations: jax.grad, which differentiates a function, and jax.vmap, which vectorizes it over a batch axis.
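
As a small illustration of these two transformations, the sketch below differentiates a toy function with jax.grad and then batches the result with jax.vmap. The function name `energy` and the sample data are made up for this example.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # toy scalar function of a 3-vector (hypothetical, for illustration only)
    return jnp.sum(x**2)


# jax.grad turns a scalar-valued function into a function returning its gradient
grad_energy = jax.grad(energy)
g = grad_energy(jnp.array([1.0, 2.0, 0.0]))  # equals 2 * x

# jax.vmap maps a function over a leading batch axis without a Python loop
batch = jnp.arange(12.0).reshape(4, 3)
gs = jax.vmap(grad_energy)(batch)  # shape (4, 3)
```

The two transformations compose freely, which is why the gradient function can be batched without being rewritten.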

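On top of that, JAX compiles functions with XLA, which is very good at optimizing code. For instance, it removes dead code. In the example below, y is computed but never used, so the exponential does not appear in the compiled code (jit_code is a helper, not defined on this page, that prints the code after optimization):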

def f(x):
    y = jnp.exp(x)
    return x + 1

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
tmp_1 = constant(1)
ROOT tmp_2 = add(tmp_0, tmp_1)
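
The `jit_code` helper is not defined on this page. One possible sketch, assuming JAX's ahead-of-time API (`lower`, `compile`, `as_text`), is shown below. It returns the raw optimized HLO text, which is more verbose than the cleaned-up listings shown here and is not necessarily how the original helper is written.

```python
import jax
import jax.numpy as jnp


def jit_code(f, *args):
    """Return the optimized HLO text that XLA produces for f(*args).

    Hypothetical stand-in for the helper used on this page.
    """
    # lower() traces f for these arguments; compile() runs XLA's optimization passes
    return jax.jit(f).lower(*args).compile().as_text()


def f(x):
    y = jnp.exp(x)  # dead code: y is never used
    return x + 1


text = jit_code(f, 1.0)
print(text)
```

The printed module depends on the backend and the JAX version, so expect instruction names and layout annotations to differ from the simplified output above.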

It also reuses common subexpressions instead of recomputing them. The following code calls the exponential function twice, but the exponential is computed only once.

def f(x):
    return jnp.exp(x) + jnp.exp(x)

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
  tmp_0 = parameter(0)
  tmp_1 = exponential(tmp_0)
  ROOT tmp_2 = add(tmp_1, tmp_1)
}

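This mechanism is quite robust. It sees through stacking and indexing, nested jax.jit functions, jax.grad, and constant expressions such as cos(0). The function below compiles to the same code as the previous one.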

def f(x):
    x = jnp.stack([x, x])
    y1 = g(x[0])
    y2 = h(x[1])
    x = jnp.array([y1, y2])
    return jnp.sum(x)

@jax.jit
def g(x):
    return jax.grad(jnp.exp)((x + 1) - 1)

@jax.jit
def h(x):
    return jnp.exp(jnp.cos(0) * x)

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
  tmp_0 = parameter(0)
  tmp_1 = exponential(tmp_0)
  ROOT tmp_2 = add(tmp_1, tmp_1)
}

Irreps

In e3nn we have a notation to define direct sums of irreducible representations of \(O(3)\).

e3nn.Irreps("0e + 2x1o")
1x0e+2x1o

This means one scalar and two vectors. 0e stands for the even irrep L=0 and 1o stands for the odd irrep L=1. The suffixes e and o stand for even and odd: they specify how the irrep transforms under parity.

The class Irreps has many methods to manipulate the representations.

IrrepsArray

IrrepsArray contains an irreps attribute of class Irreps and an array attribute of class jax.numpy.ndarray.

x = e3nn.IrrepsArray("2x0e + 1o", jnp.array(
    [
        [1.0, 0.0,  0.0, 0.0, 0.0],
        [0.0, 1.0,  1.0, 0.0, 0.0],
        [0.0, 0.0,  0.0, 1.0, 0.0],
    ]
))
x
2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]
y = e3nn.IrrepsArray("0o + 2x0e", jnp.array(
    [
        [1.5,  0.0, 1.0],
        [0.5, -1.0, 2.0],
        [0.5,  1.0, 1.5],
    ]
))

The irrep index is always the last index.

assert x.irreps.dim == x.shape[-1]
x.shape
(3, 5)

IrrepsArray handles

  • Binary operations:

x + x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
 [0. 2. 2. 0. 0.]
 [0. 0. 0. 2. 0.]]
2.0 * x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
 [0. 2. 2. 0. 0.]
 [0. 0. 0. 2. 0.]]
x / 2.0
2x0e+1x1o
[[0.5 0.  0.  0.  0. ]
 [0.  0.5 0.5 0.  0. ]
 [0.  0.  0.  0.5 0. ]]
x * y
1x0o+1x0e+1x1o
[[ 1.5  0.   0.   0.   0. ]
 [ 0.  -1.   2.   0.   0. ]
 [ 0.   0.   0.   1.5  0. ]]
x / y
1x0o+1x0e+1x1o
[[ 0.667    nan  0.     0.     0.   ]
 [ 0.    -1.     0.5    0.     0.   ]
 [ 0.     0.     0.     0.667  0.   ]]
1.0 / y
1x0o+2x0e
[[ 0.667    inf  1.   ]
 [ 2.    -1.     0.5  ]
 [ 2.     1.     0.667]]
x == x
True

  • Indexing:

x[0]
2x0e+1x1o [1. 0. 0. 0. 0.]
x[1, "1o"]
1x1o [1. 0. 0.]
x[..., "1o"]
1x1o
[[0. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]]
x[..., "2x0e + 1o"]
2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]
x[..., 2:]
1x1o
[[0. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]]

  • Reductions:

e3nn.mean(y)
1x0o+1x0e [0.833 0.75 ]
e3nn.sum(x)
1x0e+1x1o [2. 1. 1. 0.]
e3nn.sum(x, axis=0)
2x0e+1x1o [1. 1. 1. 1. 0.]
e3nn.sum(x, axis=1)
1x0e+1x1o
[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [0. 0. 1. 0.]]

  • And other operations:

e3nn.concatenate([x, x], axis=0)
2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]
z = e3nn.concatenate([x, y], axis=1)
z
2x0e+1x1o+1x0o+2x0e
[[ 1.   0.   0.   0.   0.   1.5  0.   1. ]
 [ 0.   1.   1.   0.   0.   0.5 -1.   2. ]
 [ 0.   0.   0.   1.   0.   0.5  1.   1.5]]
z.sort().simplify()
4x0e+1x0o+1x1o
[[ 1.   0.   0.   1.   1.5  0.   0.   0. ]
 [ 0.   1.  -1.   2.   0.5  1.   0.   0. ]
 [ 0.   0.   1.   1.5  0.5  0.   1.   0. ]]
x.reshape((1, 3, -1))
2x0e+1x1o
[[[1. 0. 0. 0. 0.]
  [0. 1. 1. 0. 0.]
  [0. 0. 0. 1. 0.]]]
x1, x2, x3 = x
x1
2x0e+1x1o [1. 0. 0. 0. 0.]

Tensor product

Let’s create a list of 10 pseudovectors (1e irreps) and a list of 10 2e irreps and compute their tensor product.

x1 = e3nn.normal("1e", jax.random.PRNGKey(0), (10,))
x2 = e3nn.normal("2e", jax.random.PRNGKey(1), (10,))
e3nn.tensor_product(x1, x2)
1x1e+1x2e+1x3e
[[ 1.298  0.097 -0.336  0.303  0.331 -1.069 -0.206  0.028  0.392 -1.2
   0.746  0.279 -0.269  0.794 -0.689]
 [ 1.029  2.11  -1.19  -0.06  -1.281  1.413  2.249 -0.136 -0.969  0.509
   2.447  0.98   0.832 -2.133 -1.167]
 [-0.192 -0.66  -0.658  0.517 -0.033 -0.096  0.062 -0.971 -0.66   0.578
   0.494 -0.705  0.496  0.458  0.07 ]
 [ 0.012 -1.358 -2.218  2.539 -1.263  0.125  0.628  2.36   0.507 -1.482
  -0.664 -2.479 -1.348  2.419 -0.669]
 [ 0.319 -0.019 -0.215  0.581  0.041  0.969 -0.501  0.015 -0.04  -0.101
   0.169 -0.267 -0.497 -0.787  0.097]
 [ 0.017  0.067  1.882 -0.245  0.352  0.158 -0.081  0.1    0.24   0.084
  -0.113  0.411 -1.524 -0.454  1.631]
 [-1.379 -0.187 -0.364 -1.92   0.956 -0.722  1.131 -1.364  0.651 -0.047
  -1.72  -0.924  0.303 -1.579 -0.639]
 [ 1.874  1.268 -0.703 -1.074 -0.841 -1.442 -0.058 -0.063  1.063 -2.184
   0.989 -1.1   -1.189  0.784  0.534]
 [ 1.645  0.055 -0.504 -0.388 -0.211  1.101 -0.15  -0.652 -0.631  0.193
   0.322 -0.623 -0.856 -1.7   -0.908]
 [ 0.038 -0.129 -0.889 -0.651  0.862 -1.287  0.512 -0.852  0.119 -0.953
  -0.594  0.024  0.227 -0.163 -1.328]]

Learnable Modules

We use dm-haiku or flax to create parameterized modules. Here is an example using e3nn_jax.flax.Linear.

import flax

model = e3nn.flax.Linear("1e")

x = e3nn.IrrepsArray("2x1e", jnp.array([1.0, 0.0, 0.0, 0.0, 2.0, 0.0]))

# initialize the parameters randomly
w = model.init(jax.random.PRNGKey(0), x)

# apply the module
model.apply(w, x)
1x1e [-1.438 -2.562 -0.   ]

Spherical Harmonics

Let’s compute the spherical harmonics of degree \(L=2\) for \(\vec x = (0, 0, 1)\) using the function e3nn_jax.spherical_harmonics.

vector = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 1.0]))
e3nn.spherical_harmonics(2, vector, normalize=True)
1x2e [ 0.     0.    -1.118  0.     1.936]

Note the normalize option. If normalize is False, the function is a homogeneous polynomial of degree \(L\) in the components of the input, as the following check shows:

a = 3.5

assert jnp.allclose(
    e3nn.spherical_harmonics(2, a * vector, False).array,
    a**2 * e3nn.spherical_harmonics(2, vector, False).array,
)

Gradient

The gradient of an equivariant function is also equivariant. If a function takes an IrrepsArray as input and returns an IrrepsArray, we can compute its gradient using e3nn_jax.grad.

def cross_product(x, y):
    return e3nn.tensor_product(x, y)["1e"]

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
y = e3nn.IrrepsArray("1o", jnp.array([2.0, 2.0, 1.0]))
e3nn.grad(cross_product)(x, y)
1x0o+1x1o+1x2o [ 0.  2.  2.  1.  0. -0.  0. -0.  0.]

Here is a vector and its spherical harmonics of degree \(L=2\): \(Y^2\)

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2, 3]))
e3nn.spherical_harmonics("2e", x, False)
1x2e [11.619  7.746 -2.236 23.238 15.492]

We can verify that, up to a constant factor, it can also be obtained by taking the gradient of \(Y^3\):

f = e3nn.grad(lambda x: e3nn.spherical_harmonics("3o", x, False))
0.18443 * f(x)["2e"]
1x2e [11.619  7.746 -2.236 23.238 15.492]

Content