Euclidean neural networks

What is e3nn-jax?

e3nn-jax is a Python library built on JAX to create \(O(3)\)-equivariant neural networks.

Amuse-bouche

e3nn-jax contains many tools to manipulate irreps of the group \(O(3)\).

import jax
import jax.numpy as jnp
import haiku as hk

import e3nn_jax as e3nn


# Create a neural network
@hk.without_apply_rng
@hk.transform
def net(x, f):
    # the inputs and outputs are all of type e3nn.IrrepsArray
    Y = e3nn.spherical_harmonics([0, 1, 2], x, False)
    f = e3nn.tensor_product(Y, f)
    return e3nn.haiku.Linear("0e + 0o + 1o")(f)

# Create some inputs
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
f = e3nn.normal("4x0e + 1o + 1e", jax.random.PRNGKey(0), (16,))
print(f"feature vector: {f.shape}")

# Initialize the neural network
w = net.init(jax.random.PRNGKey(0), x, f)
print(jax.tree_util.tree_map(jnp.shape, w))

# Evaluate the neural network
f = net.apply(w, x, f)
print(f"feature vector: {f.shape}")
feature vector: (16, 10)
{'linear': {'w[0,0] 5x0e,1x0e': (5, 1), 'w[1,1] 1x0o,1x0o': (1, 1), 'w[2,2] 7x1o,1x1o': (7, 1)}}
feature vector: (16, 5)

Why rewrite e3nn in JAX?

JAX has two beautiful function transformations: jax.grad and jax.vmap.
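The sketch below shows how the two transformations compose, using only standard JAX. The function energy is a hypothetical toy example. jax.grad turns a scalar function into its gradient, and jax.vmap maps that gradient over a batch without an explicit loop.

```python
import jax
import jax.numpy as jnp

def energy(r):
    # toy scalar function of a 3D vector: its squared length
    return jnp.sum(r**2)

# gradient for a single vector: d/dr |r|^2 = 2 r
grad_energy = jax.grad(energy)

# vectorize the gradient over a leading batch axis
batched_grad = jax.vmap(grad_energy)

rs = jnp.array([[1.0, 2.0, 0.0],
                [0.0, -1.0, 3.0]])
print(batched_grad(rs))  # equals 2 * rs
```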

On top of that, JAX (through the XLA compiler) is very good at optimizing code. For instance, it removes dead code:

def f(x):
    y = jnp.exp(x)
    return x + 1

# jit_code (definition not shown here) prints the optimized XLA code compiled for f
print(jit_code(f, 1.0))
tmp_0 = parameter(0)
tmp_1 = constant(1)
ROOT tmp_2 = add(tmp_0, tmp_1)
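jit_code is not a JAX function. JAX's ahead-of-time API offers a comparable view: jax.jit(f).lower(*args).compile().as_text() returns the optimized HLO module as text. The sketch below uses a hypothetical helper, show_hlo, to check that the unused exponential is gone. The exact text varies with the JAX version and backend.

```python
import jax
import jax.numpy as jnp

def show_hlo(f, *args):
    # hypothetical helper: compile f for these arguments and return the optimized HLO text
    return jax.jit(f).lower(*args).compile().as_text()

def f(x):
    y = jnp.exp(x)  # dead code: y is never used
    return x + 1

hlo = show_hlo(f, 1.0)
print(hlo)
# check whether an exponential instruction survived compilation
print("exponential(" in hlo)
```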

It also reuses the result of an identical expression instead of computing it again. The following code calls the exponential function twice, but the compiled code computes it only once.

def f(x):
    return jnp.exp(x) + jnp.exp(x)

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
  tmp_0 = parameter(0)
  tmp_1 = exponential(tmp_0)
  ROOT tmp_2 = add(tmp_1, tmp_1)
}
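The same behaviour can be checked with JAX's ahead-of-time API by counting the exponential instructions left in the optimized HLO text. This is a sketch. It assumes the instruction is spelled exponential(...), as in the output above. The XLA pass responsible is common subexpression elimination.

```python
import jax
import jax.numpy as jnp

def f(x):
    # the same subexpression written twice
    return jnp.exp(x) + jnp.exp(x)

hlo = jax.jit(f).lower(1.0).compile().as_text()

# number of exponential instructions in the optimized module
print(hlo.count("exponential("))
```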

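This mechanism is quite robust. Below, the two exponentials are hidden behind stacking, indexing, nested jax.jit functions, a gradient and trivial arithmetic, yet the compiled code is the same as above.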

def f(x):
    x = jnp.stack([x, x])
    y1 = g(x[0])
    y2 = h(x[1])
    x = jnp.array([y1, y2])
    return jnp.sum(x)

@jax.jit
def g(x):
    return jax.grad(jnp.exp)((x + 1) - 1)

@jax.jit
def h(x):
    return jnp.exp(jnp.cos(0) * x)

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
  tmp_0 = parameter(0)
  tmp_1 = exponential(tmp_0)
  ROOT tmp_2 = add(tmp_1, tmp_1)
}
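Written out algebraically, the program above reduces to the same expression as the previous example, which is why the compiled code is identical:

```latex
\[
g(x) = \left.\frac{d}{du}\, e^{u}\right|_{u=(x+1)-1} = e^{x},
\qquad
h(x) = e^{\cos(0)\, x} = e^{x},
\qquad
f(x) = g(x) + h(x) = e^{x} + e^{x}.
\]
```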

Irreps

In e3nn we have a notation to define direct sums of irreducible representations of \(O(3)\).

e3nn.Irreps("0e + 2x1o")
1x0e+2x1o

This means one scalar and two vectors: 0e stands for the even irrep with \(L=0\) and 1o for the odd irrep with \(L=1\). The suffixes e and o stand for even and odd and specify how the irrep transforms under parity (inversion).

The class Irreps has many methods to manipulate the representations.

IrrepsArray

IrrepsArray contains an irreps attribute of class Irreps and an array attribute of class jax.numpy.ndarray.

x = e3nn.IrrepsArray("2x0e + 1o", jnp.array(
    [
        [1.0, 0.0,  0.0, 0.0, 0.0],
        [0.0, 1.0,  1.0, 0.0, 0.0],
        [0.0, 0.0,  0.0, 1.0, 0.0],
    ]
))
x
2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]
y = e3nn.IrrepsArray("0o + 2x0e", jnp.array(
    [
        [1.5,  0.0, 1.0],
        [0.5, -1.0, 2.0],
        [0.5,  1.0, 1.5],
    ]
))

The irrep index is always the last index.

assert x.irreps.dim == x.shape[-1]
x.shape
(3, 5)

IrrepsArray handles

  • Binary operations:

x + x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
 [0. 2. 2. 0. 0.]
 [0. 0. 0. 2. 0.]]
2.0 * x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
 [0. 2. 2. 0. 0.]
 [0. 0. 0. 2. 0.]]
x / 2.0
2x0e+1x1o
[[0.5 0.  0.  0.  0. ]
 [0.  0.5 0.5 0.  0. ]
 [0.  0.  0.  0.5 0. ]]
x * y
1x0o+1x0e+1x1o
[[ 1.5  0.   0.   0.   0. ]
 [ 0.  -1.   2.   0.   0. ]
 [ 0.   0.   0.   1.5  0. ]]
x / y
1x0o+1x0e+1x1o
[[ 0.667    nan  0.     0.     0.   ]
 [ 0.    -1.     0.5    0.     0.   ]
 [ 0.     0.     0.     0.667  0.   ]]
1.0 / y
1x0o+2x0e
[[ 0.667    inf  1.   ]
 [ 2.    -1.     0.5  ]
 [ 2.     1.     0.667]]
x == x
2x0e+1x0e
[[ True  True  True]
 [ True  True  True]
 [ True  True  True]]
  • Indexing:

x[0]
2x0e+1x1o [1. 0. 0. 0. 0.]
x[1, "1o"]
1x1o [1. 0. 0.]
x[..., "1o"]
1x1o
[[0. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]]
x[..., "2x0e + 1o"]
2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]
x[..., 2:]
1x1o
[[0. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]]
  • Reductions:

e3nn.mean(y)
1x0o+1x0e [0.833 0.75 ]
e3nn.sum(x)
1x0e+1x1o [2. 1. 1. 0.]
e3nn.sum(x, axis=0)
2x0e+1x1o [1. 1. 1. 1. 0.]
e3nn.sum(x, axis=1)
1x0e+1x1o
[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [0. 0. 1. 0.]]
  • And other operations:

e3nn.concatenate([x, x], axis=0)
2x0e+1x1o
[[1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]
 [1. 0. 0. 0. 0.]
 [0. 1. 1. 0. 0.]
 [0. 0. 0. 1. 0.]]
z = e3nn.concatenate([x, y], axis=1)
z
2x0e+1x1o+1x0o+2x0e
[[ 1.   0.   0.   0.   0.   1.5  0.   1. ]
 [ 0.   1.   1.   0.   0.   0.5 -1.   2. ]
 [ 0.   0.   0.   1.   0.   0.5  1.   1.5]]
z.sort().simplify()
4x0e+1x0o+1x1o
[[ 1.   0.   0.   1.   1.5  0.   0.   0. ]
 [ 0.   1.  -1.   2.   0.5  1.   0.   0. ]
 [ 0.   0.   1.   1.5  0.5  0.   1.   0. ]]
x.reshape((1, 3, -1))
2x0e+1x1o
[[[1. 0. 0. 0. 0.]
  [0. 1. 1. 0. 0.]
  [0. 0. 0. 1. 0.]]]
x1, x2, x3 = x
x1
2x0e+1x1o [1. 0. 0. 0. 0.]
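Plain array slicing reproduces the x * y result in the binary operations item above. The sketch is inferred from the printed output, not from the library source. The three scalars of y multiply, in order, the first 0e, the second 0e and the 1o block of x. The parity of each product is the product of the parities, so 0e times 0o gives 0o.

```python
import jax.numpy as jnp

x = jnp.array([
    [1.0, 0.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0, 0.0],
])  # columns: 0e | 0e | 1o (3 components)
y = jnp.array([
    [1.5, 0.0, 1.0],
    [0.5, -1.0, 2.0],
    [0.5, 1.0, 1.5],
])  # columns: 0o | 0e | 0e

product = jnp.concatenate([
    x[:, 0:1] * y[:, 0:1],  # 0e * 0o -> 0o
    x[:, 1:2] * y[:, 1:2],  # 0e * 0e -> 0e
    x[:, 2:5] * y[:, 2:3],  # 1o * 0e -> 1o (vector scaled by a scalar)
], axis=1)
print(product)
```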

Tensor product

Let’s create a list of 10 pseudovectors (1e irreps) and a list of 10 2e irreps and compute their tensor product.

x1 = e3nn.normal("1e", jax.random.PRNGKey(0), (10,))
x2 = e3nn.normal("2e", jax.random.PRNGKey(1), (10,))
e3nn.tensor_product(x1, x2)
1x1e+1x2e+1x3e
[[-0.925 -0.062 -0.629  2.123  0.157  0.152 -1.038  0.138  1.501 -0.347
   0.349 -0.305 -0.046  1.441 -0.211]
 [ 0.168  0.174  0.083 -0.132  0.815 -1.488 -0.257 -0.862 -0.122 -1.174
   0.247  0.086 -0.588  0.297 -0.238]
 [ 0.056  0.63   0.542  0.195 -0.642  0.16  -0.273  0.06  -0.644 -0.521
  -0.361  0.051  0.484  0.485  0.17 ]
 [-0.817 -3.901  1.362 -2.566 -2.797 -0.757 -0.718 -2.082  0.327 -0.871
   3.094 -1.871  3.11  -2.477  2.605]
 [ 0.274  0.055 -0.735 -0.774  1.836 -1.523  0.06  -0.683 -0.483 -0.982
  -0.497 -0.248 -0.336  0.632 -1.512]
 [ 2.06  -0.164  0.05  -0.315  0.269  0.768 -0.673 -0.468 -0.994  0.209
  -0.874 -1.26  -0.728 -1.465 -1.035]
 [-0.79  -2.676 -1.826  1.505 -0.915  0.402 -1.097  2.015 -1.464 -0.847
   0.851  0.934 -3.423 -2.184 -0.041]
 [ 0.787  2.846  2.817  1.569 -1.082  0.364 -1.598  1.441  2.436  2.213
  -1.315 -2.402 -1.448 -2.138 -2.046]
 [ 0.654  0.445  0.052  0.073 -1.818  0.055  0.661 -0.202  1.061  0.264
  -0.087  0.208  1.068  0.292  0.606]
 [-1.024 -0.674 -0.5    0.04   0.15  -0.994  0.27   0.464  0.144 -0.039
   0.126 -0.34  -1.091 -1.213 -0.682]]
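The output irreps follow the selection rule for the product of two irreps. The degrees run from \(|L_1 - L_2|\) to \(L_1 + L_2\), and the parity is the product of the parities. The plain-Python sketch below applies that rule. tensor_product_irreps is a hypothetical helper, not part of e3nn-jax. It reproduces 1x1e+1x2e+1x3e and the 15 columns above. With true vectors (1o) instead, every output irrep would be odd.

```python
def tensor_product_irreps(l1, p1, l2, p2):
    """Hypothetical helper: irreps in the product of irrep (l1, p1) with irrep (l2, p2).

    Selection rule |l1 - l2| <= l <= l1 + l2, parity p = p1 * p2
    (p = +1 for "e", -1 for "o").
    """
    parity = "e" if p1 * p2 == 1 else "o"
    return [f"{l}{parity}" for l in range(abs(l1 - l2), l1 + l2 + 1)]

out = tensor_product_irreps(1, +1, 2, +1)  # 1e x 2e
print(out)

# total dimension: sum of (2l + 1) over the output irreps
dim = sum(2 * int(ir[:-1]) + 1 for ir in out)
print(dim)

print(tensor_product_irreps(1, -1, 2, +1))  # 1o x 2e
```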

Learnable Modules

We use dm-haiku or flax to create parameterized modules. Here is an example using e3nn_jax.flax.Linear.

import flax

model = e3nn.flax.Linear("1e")

x = e3nn.IrrepsArray("2x1e", jnp.array([1.0, 0.0, 0.0, 0.0, 2.0, 0.0]))

# initialize the parameters randomly
w = model.init(jax.random.PRNGKey(0), x)

# apply the module
model.apply(w, x)
1x1e [ 0.054 -2.285  0.   ]

Spherical Harmonics

Let’s compute the spherical harmonics of degree \(L=2\) for \(\vec x = (0, 0, 1)\) using the function e3nn_jax.spherical_harmonics.

vector = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 1.0]))
e3nn.spherical_harmonics(2, vector, normalize=True)
1x2e [ 0.     0.    -1.118  0.     1.936]

Note the normalize option. If normalize is False, the function becomes a homogeneous polynomial of degree \(L\): scaling the input by \(a\) scales the output by \(a^L\), as checked below for \(L=2\):

a = 3.5

assert jnp.allclose(
    e3nn.spherical_harmonics(2, a * vector, False).array,
    a**2 * e3nn.spherical_harmonics(2, vector, False).array,
)

Gradient

The gradient of an equivariant function is also equivariant. If a function takes and returns IrrepsArray objects, we can compute its gradient using e3nn_jax.grad.

def cross_product(x, y):
    return e3nn.tensor_product(x, y)["1e"]

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
y = e3nn.IrrepsArray("1o", jnp.array([2.0, 2.0, 1.0]))
e3nn.grad(cross_product)(x, y)
1x0o+1x1o+1x2o [ 0.  2.  2.  1.  0. -0.  0. -0.  0.]
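The zero 0o and 2o components above are what one expects from the plain cross product. Its Jacobian \(J_{ij} = \epsilon_{ijk} y_k\) is antisymmetric, so its trace (scalar part) and symmetric traceless part both vanish. Only a vector part remains, and it equals \(y\). This matches the printed 1o part [2, 2, 1] here. The sketch below checks this with standard JAX only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 0.0])
y = jnp.array([2.0, 2.0, 1.0])

# Jacobian of the cross product with respect to its first argument:
# J[i, j] = d(x × y)_i / dx_j = eps_ijk y_k
J = jax.jacobian(jnp.cross, argnums=0)(x, y)

print(jnp.trace(J))           # scalar (L=0) part
print(jnp.allclose(J, -J.T))  # antisymmetric, so no symmetric traceless (L=2) part

# the antisymmetric part is encoded by a vector; extracting it recovers y
axial = 0.5 * jnp.array([J[1, 2] - J[2, 1], J[2, 0] - J[0, 2], J[0, 1] - J[1, 0]])
print(axial)
```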

Here is a vector \(\vec x\) and its unnormalized spherical harmonics of degree \(L=2\), \(Y^2(\vec x)\):

x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2, 3]))
e3nn.spherical_harmonics("2e", x, False)
1x2e [11.619  7.746 -2.236 23.238 15.492]

We can verify that, up to a constant factor, it can also be obtained from the 2e component of the gradient of \(Y^3\):

f = e3nn.grad(lambda x: e3nn.spherical_harmonics("3o", x, False))
0.18443 * f(x)["2e"]
1x2e [11.619  7.746 -2.236 23.238 15.492]

Content