Euclidean neural networks
What is e3nn-jax?
e3nn-jax is a Python library built on JAX to create \(O(3)\)-equivariant neural networks.
Amuse-bouche
e3nn-jax contains many tools to manipulate irreps of the group \(O(3)\).
import jax
import jax.numpy as jnp
import haiku as hk

import e3nn_jax as e3nn


# Create a neural network
@hk.without_apply_rng
@hk.transform
def net(x, f):
    # the inputs and outputs are all of type e3nn.IrrepsArray
    Y = e3nn.spherical_harmonics([0, 1, 2], x, False)
    f = e3nn.tensor_product(Y, f)
    return e3nn.haiku.Linear("0e + 0o + 1o")(f)

# Create some inputs
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
f = e3nn.normal("4x0e + 1o + 1e", jax.random.PRNGKey(0), (16,))
print(f"feature vector: {f.shape}")

# Initialize the neural network
w = net.init(jax.random.PRNGKey(0), x, f)
print(jax.tree_util.tree_map(jnp.shape, w))

# Evaluate the neural network
f = net.apply(w, x, f)
print(f"feature vector: {f.shape}")
feature vector: (16, 10)
{'linear': {'w[0,0] 5x0e,1x0e': (5, 1), 'w[1,1] 1x0o,1x0o': (1, 1), 'w[2,2] 7x1o,1x1o': (7, 1)}}
feature vector: (16, 5)
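To see the equivariance in action, we can check that rotating the inputs and then applying the network gives the same result as applying the network and then rotating its output. A minimal sketch, assuming IrrepsArray.transform_by_angles applies the Wigner D-matrices for the given Euler angles:
# Sketch of an equivariance check; transform_by_angles is assumed
# to rotate each irrep by the corresponding Wigner D-matrix.
x2 = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
f2 = e3nn.normal("4x0e + 1o + 1e", jax.random.PRNGKey(1), (16,))
a, b, c = 0.1, 0.2, 0.3  # Euler angles

out1 = net.apply(w, x2.transform_by_angles(a, b, c), f2.transform_by_angles(a, b, c))
out2 = net.apply(w, x2, f2).transform_by_angles(a, b, c)
assert jnp.allclose(out1.array, out2.array, atol=1e-5)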
Why rewrite e3nn in JAX?
JAX has two beautiful function transformations: jax.grad and jax.vmap.
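As a quick illustration (plain JAX, not part of e3nn-jax), jax.grad differentiates a scalar function and jax.vmap maps it over a batch axis:
df = jax.grad(jnp.sin)  # df(x) == cos(x)
batched_df = jax.vmap(df)  # apply df elementwise over a batch
print(batched_df(jnp.array([0.0, 1.0, 2.0])))  # ~[1. 0.54 -0.416]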
On top of that, it is very good at optimizing code. For instance, it can get rid of dead code:
def f(x):
    y = jnp.exp(x)  # dead code: y is never used
    return x + 1

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
tmp_1 = constant(1)
ROOT tmp_2 = add(tmp_0, tmp_1)
It also reuses the same expression instead of computing it again. The following code calls the exponential function twice, but it is computed only once.
def f(x):
    return jnp.exp(x) + jnp.exp(x)

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
tmp_0 = parameter(0)
tmp_1 = exponential(tmp_0)
ROOT tmp_2 = add(tmp_1, tmp_1)
}
This mechanism is quite robust.
def f(x):
    x = jnp.stack([x, x])
    y1 = g(x[0])
    y2 = h(x[1])
    x = jnp.array([y1, y2])
    return jnp.sum(x)

@jax.jit
def g(x):
    return jax.grad(jnp.exp)((x + 1) - 1)

@jax.jit
def h(x):
    return jnp.exp(jnp.cos(0) * x)

print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
tmp_0 = parameter(0)
tmp_1 = exponential(tmp_0)
ROOT tmp_2 = add(tmp_1, tmp_1)
}
Irreps
In e3nn we have a notation to define direct sums of irreducible representations of \(O(3)\).
e3nn.Irreps("0e + 2x1o")
1x0e+2x1o
This means one scalar and two vectors. 0e stands for the even irrep with \(L=0\) and 1o stands for the odd irrep with \(L=1\).
The suffixes e and o stand for even and odd, the behavior of the representation under parity.
The class Irreps has many methods to manipulate the representations.
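For instance, a few of its attributes (a small sketch; dim, num_irreps and lmax are the ones used here):
irreps = e3nn.Irreps("0e + 2x1o")
print(irreps.dim)  # 7, the total dimension: 1 + 2 * 3
print(irreps.num_irreps)  # 3, the number of irreps: 1 + 2
print(irreps.lmax)  # 1, the highest degree L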
IrrepsArray
IrrepsArray contains an irreps attribute of class Irreps and an array attribute of class jax.numpy.ndarray.
x = e3nn.IrrepsArray("2x0e + 1o", jnp.array(
    [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 0.0],
    ]
))
x
2x0e+1x1o
[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
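The two attributes can be read back directly (reusing x from above):
print(x.irreps)  # 2x0e+1x1o
print(x.array.shape)  # (3, 5), a plain jax.numpy array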
y = e3nn.IrrepsArray("0o + 2x0e", jnp.array(
    [
        [1.5, 0.0, 1.0],
        [0.5, -1.0, 2.0],
        [0.5, 1.0, 1.5],
    ]
))
The irrep index is always the last index.
assert x.irreps.dim == x.shape[-1]
x.shape
(3, 5)
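All leading axes are ordinary batch axes. For example (a small sketch using e3nn.normal as in the amuse-bouche):
batch = e3nn.normal("2x0e + 1o", jax.random.PRNGKey(0), (4, 3))
print(batch.shape)  # (4, 3, 5): two batch axes, then the irreps axis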
IrrepsArray handles binary operations:
x + x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
[0. 2. 2. 0. 0.]
[0. 0. 0. 2. 0.]]
2.0 * x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
[0. 2. 2. 0. 0.]
[0. 0. 0. 2. 0.]]
x / 2.0
2x0e+1x1o
[[0.5 0. 0. 0. 0. ]
[0. 0.5 0.5 0. 0. ]
[0. 0. 0. 0.5 0. ]]
x * y
1x0o+1x0e+1x1o
[[ 1.5 0. 0. 0. 0. ]
[ 0. -1. 2. 0. 0. ]
[ 0. 0. 0. 1.5 0. ]]
x / y
1x0o+1x0e+1x1o
[[ 0.667 nan 0. 0. 0. ]
[ 0. -1. 0.5 0. 0. ]
[ 0. 0. 0. 0.667 0. ]]
1.0 / y
1x0o+2x0e
[[ 0.667 inf 1. ]
[ 2. -1. 0.5 ]
[ 2. 1. 0.667]]
x == x
2x0e+1x0e
[[ True True True]
[ True True True]
[ True True True]]
Indexing:
x[0]
2x0e+1x1o [1. 0. 0. 0. 0.]
x[1, "1o"]
1x1o [1. 0. 0.]
x[..., "1o"]
1x1o
[[0. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]]
x[..., "2x0e + 1o"]
2x0e+1x1o
[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
x[..., 2:]
1x1o
[[0. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]]
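Plain slices have to fall on irrep boundaries; the complementary slice of the one above keeps the two scalars:
x[..., :2]  # keeps the 2x0e block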
Reductions:
e3nn.mean(y)
1x0o+1x0e [0.833 0.75 ]
e3nn.sum(x)
1x0e+1x1o [2. 1. 1. 0.]
e3nn.sum(x, axis=0)
2x0e+1x1o [1. 1. 1. 1. 0.]
e3nn.sum(x, axis=1)
1x0e+1x1o
[[1. 0. 0. 0.]
[1. 1. 0. 0.]
[0. 0. 1. 0.]]
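There is also a rotation-invariant reduction (a sketch, assuming e3nn.norm is available): the norm of each irrep.
e3nn.norm(x)  # one 0e scalar per irrep of x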
And other operations:
e3nn.concatenate([x, x], axis=0)
2x0e+1x1o
[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
z = e3nn.concatenate([x, y], axis=1)
z
2x0e+1x1o+1x0o+2x0e
[[ 1. 0. 0. 0. 0. 1.5 0. 1. ]
[ 0. 1. 1. 0. 0. 0.5 -1. 2. ]
[ 0. 0. 0. 1. 0. 0.5 1. 1.5]]
z.sort().simplify()
4x0e+1x0o+1x1o
[[ 1. 0. 0. 1. 1.5 0. 0. 0. ]
[ 0. 1. -1. 2. 0.5 1. 0. 0. ]
[ 0. 0. 1. 1.5 0.5 0. 1. 0. ]]
x.reshape((1, 3, -1))
2x0e+1x1o
[[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]]
x1, x2, x3 = x
x1
2x0e+1x1o [1. 0. 0. 0. 0.]
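Rotations act block by block: the scalars stay untouched and the 1o part is rotated. A small sketch, again assuming the transform_by_angles method used in the equivariance check of the amuse-bouche:
x.transform_by_angles(jnp.pi / 2, 0.0, 0.0)  # 2x0e unchanged, 1x1o rotated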
Tensor product
Let’s create a list of 10 1e irreps and a list of 10 2e irreps and compute their tensor product.
x1 = e3nn.normal("1e", jax.random.PRNGKey(0), (10,))
x2 = e3nn.normal("2e", jax.random.PRNGKey(1), (10,))
e3nn.tensor_product(x1, x2)
1x1e+1x2e+1x3e
[[-0.925 -0.062 -0.629 2.123 0.157 0.152 -1.038 0.138 1.501 -0.347
0.349 -0.305 -0.046 1.441 -0.211]
[ 0.168 0.174 0.083 -0.132 0.815 -1.488 -0.257 -0.862 -0.122 -1.174
0.247 0.086 -0.588 0.297 -0.238]
[ 0.056 0.63 0.542 0.195 -0.642 0.16 -0.273 0.06 -0.644 -0.521
-0.361 0.051 0.484 0.485 0.17 ]
[-0.817 -3.901 1.362 -2.566 -2.797 -0.757 -0.718 -2.082 0.327 -0.871
3.094 -1.871 3.11 -2.477 2.605]
[ 0.274 0.055 -0.735 -0.774 1.836 -1.523 0.06 -0.683 -0.483 -0.982
-0.497 -0.248 -0.336 0.632 -1.512]
[ 2.06 -0.164 0.05 -0.315 0.269 0.768 -0.673 -0.468 -0.994 0.209
-0.874 -1.26 -0.728 -1.465 -1.035]
[-0.79 -2.676 -1.826 1.505 -0.915 0.402 -1.097 2.015 -1.464 -0.847
0.851 0.934 -3.423 -2.184 -0.041]
[ 0.787 2.846 2.817 1.569 -1.082 0.364 -1.598 1.441 2.436 2.213
-1.315 -2.402 -1.448 -2.138 -2.046]
[ 0.654 0.445 0.052 0.073 -1.818 0.055 0.661 -0.202 1.061 0.264
-0.087 0.208 1.068 0.292 0.606]
[-1.024 -0.674 -0.5 0.04 0.15 -0.994 0.27 0.464 0.144 -0.039
0.126 -0.34 -1.091 -1.213 -0.682]]
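To keep only part of the product, index the result with the desired irreps, the same trick used in the Gradient section below:
e3nn.tensor_product(x1, x2)["1e"]  # keep only the 1e component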
Learnable Modules
We use dm-haiku or flax to create parameterized modules. Here is an example using e3nn_jax.flax.Linear.
import flax
model = e3nn.flax.Linear("1e")
x = e3nn.IrrepsArray("2x1e", jnp.array([1.0, 0.0, 0.0, 0.0, 2.0, 0.0]))
# initialize the parameters randomly
w = model.init(jax.random.PRNGKey(0), x)
# apply the module
model.apply(w, x)
1x1e [ 0.054 -2.285 0. ]
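The dm-haiku version looks similar; a sketch following the same pattern as the amuse-bouche:
import haiku as hk

@hk.without_apply_rng
@hk.transform
def linear(x):
    return e3nn.haiku.Linear("1e")(x)

w = linear.init(jax.random.PRNGKey(0), x)
linear.apply(w, x)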
Spherical Harmonics
Let’s compute the spherical harmonics of degree \(L=2\) for \(\vec x = (0, 0, 1)\) using the function e3nn_jax.spherical_harmonics.
vector = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 1.0]))
e3nn.spherical_harmonics(2, vector, normalize=True)
1x2e [ 0. 0. -1.118 0. 1.936]
Note the normalize option. If normalize is False, the function becomes a homogeneous polynomial, see below:
a = 3.5
assert jnp.allclose(
    e3nn.spherical_harmonics(2, a * vector, False).array,
    a**2 * e3nn.spherical_harmonics(2, vector, False).array,
)
Gradient
The gradient of an equivariant function is also equivariant.
If a function takes IrrepsArray as inputs and outputs, we can compute its gradient using e3nn_jax.grad.
def cross_product(x, y):
    return e3nn.tensor_product(x, y)["1e"]
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
y = e3nn.IrrepsArray("1o", jnp.array([2.0, 2.0, 1.0]))
e3nn.grad(cross_product)(x, y)
1x0o+1x1o+1x2o [ 0. 2. 2. 1. 0. -0. 0. -0. 0.]
Here is a vector and its spherical harmonics of degree \(L=2\), \(Y^2\):
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2, 3]))
e3nn.spherical_harmonics("2e", x, False)
1x2e [11.619 7.746 -2.236 23.238 15.492]
We can verify that it can also be obtained by taking the gradient of \(Y^3\):
f = e3nn.grad(lambda x: e3nn.spherical_harmonics("3o", x, False))
0.18443 * f(x)["2e"]
1x2e [11.619 7.746 -2.236 23.238 15.492]