Euclidean neural networks
What is e3nn-jax?
e3nn-jax is a Python library built on JAX to create \(O(3)\)-equivariant neural networks.
Amuse-bouche
e3nn-jax contains many tools to manipulate irreps of the group \(O(3)\).
import jax
import jax.numpy as jnp
import haiku as hk
import e3nn_jax as e3nn
# Create a neural network
@hk.without_apply_rng
@hk.transform
def net(x, f):
    # the inputs and outputs are all of type e3nn.IrrepsArray
    Y = e3nn.spherical_harmonics([0, 1, 2], x, False)
    f = e3nn.tensor_product(Y, f)
    return e3nn.haiku.Linear("0e + 0o + 1o")(f)
# Create some inputs
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
f = e3nn.normal("4x0e + 1o + 1e", jax.random.PRNGKey(0), (16,))
print(f"feature vector: {f.shape}")
# Initialize the neural network
w = net.init(jax.random.PRNGKey(0), x, f)
print(jax.tree_util.tree_map(jnp.shape, w))
# Evaluate the neural network
f = net.apply(w, x, f)
print(f"feature vector: {f.shape}")
feature vector: (16, 10)
{'linear': {'w[0,0] 5x0e,1x0e': (5, 1), 'w[1,1] 1x0o,1x0o': (1, 1), 'w[2,2] 7x1o,1x1o': (7, 1)}}
feature vector: (16, 5)
Why rewrite e3nn in jax?
JAX has two beautiful function transformations: jax.grad and jax.vmap.
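As a minimal sketch of how these two transformations compose (the function f below is a made-up example, not part of e3nn-jax): jax.grad turns a scalar function into its derivative, and jax.vmap maps that derivative over a batch of inputs.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A scalar function of a scalar input (hypothetical example)
    return jnp.sin(x) * x


# jax.grad returns a new function that computes df/dx
df = jax.grad(f)

# jax.vmap vectorizes df so that it accepts a batch of inputs
batched_df = jax.vmap(df)

xs = jnp.array([0.0, 1.0, 2.0])
ys = batched_df(xs)

# Analytic derivative for comparison: cos(x) * x + sin(x)
expected = jnp.cos(xs) * xs + jnp.sin(xs)
```

The two transformations are independent of each other, so the derivative is written once for a single input and batched afterwards.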
On top of that, JAX is very good at optimizing code when it compiles a function. For instance, it gets rid of dead code:
def f(x):
    y = jnp.exp(x)
    return x + 1
print(jit_code(f, 1.0))
tmp_0 = parameter(0)
tmp_1 = constant(1)
ROOT tmp_2 = add(tmp_0, tmp_1)
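The jit_code helper used in these examples is not defined on this page. One possible way to get a similar view, sketched here under the assumption that the helper simply prints the optimized HLO produced by XLA, is to lower and compile the jitted function and read back its text. The name jit_code_sketch is hypothetical, and the raw text is more verbose than the cleaned-up listings shown here.

```python
import jax
import jax.numpy as jnp


def jit_code_sketch(f, *args):
    """Return the optimized HLO text for f(*args).

    Hypothetical stand-in for the jit_code helper; it does not rename
    or shorten the instructions.
    """
    return jax.jit(f).lower(*args).compile().as_text()


def f(x):
    y = jnp.exp(x)  # never used: dead code
    return x + 1


hlo = jit_code_sketch(f, 1.0)
# print(hlo) shows the compiled module; the exponential should be gone
```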
The compiler also reuses a repeated expression instead of computing it again. The following code calls the exponential function twice, but the compiled code computes it only once.
def f(x):
    return jnp.exp(x) + jnp.exp(x)
print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
tmp_0 = parameter(0)
tmp_1 = exponential(tmp_0)
ROOT tmp_2 = add(tmp_1, tmp_1)
}
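This mechanism is quite robust. In the following example the compiler sees through the stacking and indexing, the gradient of the exponential, the (x + 1) - 1 round trip and the multiplication by cos(0), and ends up with the same compiled code as before.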
def f(x):
    x = jnp.stack([x, x])
    y1 = g(x[0])
    y2 = h(x[1])
    x = jnp.array([y1, y2])
    return jnp.sum(x)

@jax.jit
def g(x):
    return jax.grad(jnp.exp)((x + 1) - 1)

@jax.jit
def h(x):
    return jnp.exp(jnp.cos(0) * x)
print(jit_code(f, 1.0))
tmp_0 = parameter(0)
ROOT tmp_1 = fusion(tmp_0), kind=kLoop, calls=
(param_0: f32[]) -> f32[] {
tmp_0 = parameter(0)
tmp_1 = exponential(tmp_0)
ROOT tmp_2 = add(tmp_1, tmp_1)
}
Irreps
In e3nn we have a notation to define direct sums of irreducible representations of \(O(3)\).
e3nn.Irreps("0e + 2x1o")
1x0e+2x1o
This means one scalar and two vectors. 0e stands for the even irrep L=0 and 1o stands for the odd irrep L=1.
The suffixes e and o stand for even and odd, that is, how the irrep transforms under parity.
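To make the notation concrete, here is a small pure-Python sketch that parses such a string and computes its total dimension, using the fact that an irrep of degree L has dimension 2L + 1. The helpers parse_irreps and irreps_dim are hypothetical illustrations, not the e3nn-jax implementation.

```python
def parse_irreps(s):
    """Parse a string such as "0e + 2x1o" into (multiplicity, L, parity) tuples."""
    terms = []
    for term in s.split("+"):
        term = term.strip()
        # "2x1o" -> multiplicity "2", irrep "1o"; "0e" -> multiplicity "", irrep "0e"
        mul, _, ir = term.rpartition("x")
        mul = int(mul) if mul else 1
        terms.append((mul, int(ir[:-1]), ir[-1]))
    return terms


def irreps_dim(s):
    """Total dimension: each irrep of degree L contributes 2L + 1 components."""
    return sum(mul * (2 * l + 1) for mul, l, _ in parse_irreps(s))


print(parse_irreps("0e + 2x1o"))  # [(1, 0, 'e'), (2, 1, 'o')]
print(irreps_dim("0e + 2x1o"))    # 7
```

One scalar contributes 1 component and each of the two vectors contributes 3, which gives 7 in total.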
The class Irreps has many methods to manipulate the representations.
IrrepsArray
IrrepsArray contains an irreps attribute of class Irreps and an array attribute of class jax.numpy.ndarray.
x = e3nn.IrrepsArray("2x0e + 1o", jnp.array(
    [
        [1.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1.0, 0.0],
    ]
))
x
2x0e+1x1o
[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
y = e3nn.IrrepsArray("0o + 2x0e", jnp.array(
    [
        [1.5, 0.0, 1.0],
        [0.5, -1.0, 2.0],
        [0.5, 1.0, 1.5],
    ]
))
The irrep index is always the last index.
assert x.irreps.dim == x.shape[-1]
x.shape
(3, 5)
IrrepsArray handles binary operations:
x + x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
[0. 2. 2. 0. 0.]
[0. 0. 0. 2. 0.]]
2.0 * x
2x0e+1x1o
[[2. 0. 0. 0. 0.]
[0. 2. 2. 0. 0.]
[0. 0. 0. 2. 0.]]
x / 2.0
2x0e+1x1o
[[0.5 0. 0. 0. 0. ]
[0. 0.5 0.5 0. 0. ]
[0. 0. 0. 0.5 0. ]]
x * y
1x0o+1x0e+1x1o
[[ 1.5 0. 0. 0. 0. ]
[ 0. -1. 2. 0. 0. ]
[ 0. 0. 0. 1.5 0. ]]
x / y
1x0o+1x0e+1x1o
[[ 0.667 nan 0. 0. 0. ]
[ 0. -1. 0.5 0. 0. ]
[ 0. 0. 0. 0.667 0. ]]
1.0 / y
1x0o+2x0e
[[ 0.667 inf 1. ]
[ 2. -1. 0.5 ]
[ 2. 1. 0.667]]
x == x
True
Indexing:
x[0]
2x0e+1x1o [1. 0. 0. 0. 0.]
x[1, "1o"]
1x1o [1. 0. 0.]
x[..., "1o"]
1x1o
[[0. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]]
x[..., "2x0e + 1o"]
2x0e+1x1o
[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
x[..., 2:]
1x1o
[[0. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]]
Reductions:
e3nn.mean(y)
1x0o+1x0e [0.833 0.75 ]
e3nn.sum(x)
1x0e+1x1o [2. 1. 1. 0.]
e3nn.sum(x, axis=0)
2x0e+1x1o [1. 1. 1. 1. 0.]
e3nn.sum(x, axis=1)
1x0e+1x1o
[[1. 0. 0. 0.]
[1. 1. 0. 0.]
[0. 0. 1. 0.]]
And other operations:
e3nn.concatenate([x, x], axis=0)
2x0e+1x1o
[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]
[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]
z = e3nn.concatenate([x, y], axis=1)
z
2x0e+1x1o+1x0o+2x0e
[[ 1. 0. 0. 0. 0. 1.5 0. 1. ]
[ 0. 1. 1. 0. 0. 0.5 -1. 2. ]
[ 0. 0. 0. 1. 0. 0.5 1. 1.5]]
z.sort().simplify()
4x0e+1x0o+1x1o
[[ 1. 0. 0. 1. 1.5 0. 0. 0. ]
[ 0. 1. -1. 2. 0.5 1. 0. 0. ]
[ 0. 0. 1. 1.5 0.5 0. 1. 0. ]]
x.reshape((1, 3, -1))
2x0e+1x1o
[[[1. 0. 0. 0. 0.]
[0. 1. 1. 0. 0.]
[0. 0. 0. 1. 0.]]]
x1, x2, x3 = x
x1
2x0e+1x1o [1. 0. 0. 0. 0.]
Tensor product
Let’s create a list of 10 pseudovectors (1e irreps) and a list of 10 2e irreps and compute their tensor product.
x1 = e3nn.normal("1e", jax.random.PRNGKey(0), (10,))
x2 = e3nn.normal("2e", jax.random.PRNGKey(1), (10,))
e3nn.tensor_product(x1, x2)
1x1e+1x2e+1x3e
[[ 1.298 0.097 -0.336 0.303 0.331 -1.069 -0.206 0.028 0.392 -1.2
0.746 0.279 -0.269 0.794 -0.689]
[ 1.029 2.11 -1.19 -0.06 -1.281 1.413 2.249 -0.136 -0.969 0.509
2.447 0.98 0.832 -2.133 -1.167]
[-0.192 -0.66 -0.658 0.517 -0.033 -0.096 0.062 -0.971 -0.66 0.578
0.494 -0.705 0.496 0.458 0.07 ]
[ 0.012 -1.358 -2.218 2.539 -1.263 0.125 0.628 2.36 0.507 -1.482
-0.664 -2.479 -1.348 2.419 -0.669]
[ 0.319 -0.019 -0.215 0.581 0.041 0.969 -0.501 0.015 -0.04 -0.101
0.169 -0.267 -0.497 -0.787 0.097]
[ 0.017 0.067 1.882 -0.245 0.352 0.158 -0.081 0.1 0.24 0.084
-0.113 0.411 -1.524 -0.454 1.631]
[-1.379 -0.187 -0.364 -1.92 0.956 -0.722 1.131 -1.364 0.651 -0.047
-1.72 -0.924 0.303 -1.579 -0.639]
[ 1.874 1.268 -0.703 -1.074 -0.841 -1.442 -0.058 -0.063 1.063 -2.184
0.989 -1.1 -1.189 0.784 0.534]
[ 1.645 0.055 -0.504 -0.388 -0.211 1.101 -0.15 -0.652 -0.631 0.193
0.322 -0.623 -0.856 -1.7 -0.908]
[ 0.038 -0.129 -0.889 -0.651 0.862 -1.287 0.512 -0.852 0.119 -0.953
-0.594 0.024 0.227 -0.163 -1.328]]
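The irreps of the result follow the usual selection rule for products of \(O(3)\) irreps: the degrees run from |L1 - L2| to L1 + L2 and the parities multiply. The sketch below illustrates that rule in plain Python; tensor_product_irreps and fmt are hypothetical helpers, with parity encoded as +1 for even and -1 for odd.

```python
def tensor_product_irreps(ir1, ir2):
    """Irreps in the product of two irreps given as (L, parity) pairs."""
    l1, p1 = ir1
    l2, p2 = ir2
    parity = p1 * p2  # even * even = even, odd * odd = even, mixed = odd
    return [(l, parity) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def fmt(ir):
    """Format an (L, parity) pair in the e3nn notation, e.g. (1, -1) -> "1o"."""
    l, p = ir
    return f"{l}{'e' if p == 1 else 'o'}"


out = tensor_product_irreps((1, +1), (2, +1))  # 1e times 2e
print([fmt(ir) for ir in out])                 # ['1e', '2e', '3e']
print(sum(2 * l + 1 for l, _ in out))          # 15
```

The three output irreps have dimensions 3, 5 and 7, which accounts for the 15 entries in each row of the result above.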
Learnable Modules
We use dm-haiku or flax to create parameterized modules.
Here is an example using e3nn_jax.flax.Linear.
import flax
model = e3nn.flax.Linear("1e")
x = e3nn.IrrepsArray("2x1e", jnp.array([1.0, 0.0, 0.0, 0.0, 2.0, 0.0]))
# initialize the parameters randomly
w = model.init(jax.random.PRNGKey(0), x)
# apply the module
model.apply(w, x)
1x1e [-1.438 -2.562 -0. ]
Spherical Harmonics
Let’s compute the spherical harmonics of degree \(L=2\) for \(\vec x = (0, 0, 1)\) using the function e3nn_jax.spherical_harmonics.
vector = e3nn.IrrepsArray("1o", jnp.array([0.0, 0.0, 1.0]))
e3nn.spherical_harmonics(2, vector, normalize=True)
1x2e [ 0. 0. -1.118 0. 1.936]
Note the normalize option. If normalize is False, the function becomes a homogeneous polynomial of degree \(L\), as the following check shows:
a = 3.5
assert jnp.allclose(
    e3nn.spherical_harmonics(2, a * vector, False).array,
    a**2 * e3nn.spherical_harmonics(2, vector, False).array,
)
Gradient
The gradient of an equivariant function is also equivariant.
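This can be checked numerically in plain JAX for the simplest case, a rotation-invariant scalar function, whose gradient should rotate like a vector. The sketch below uses made-up names (energy, rotation_z) and does not rely on e3nn-jax.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Depends only on the squared norm of x, hence invariant under rotations
    return jnp.sum(x**2) ** 2


def rotation_z(angle):
    # Rotation matrix around the z axis
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


R = rotation_z(0.7)
x = jnp.array([1.0, 2.0, 0.5])
grad_energy = jax.grad(energy)

# Equivariance of the gradient: grad(R x) equals R grad(x)
lhs = grad_energy(R @ x)
rhs = R @ grad_energy(x)
assert jnp.allclose(lhs, rhs, atol=1e-4)
```

The tolerance is loose because the computation runs in single precision by default.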
If a function takes and returns IrrepsArray, we can compute its gradient using e3nn_jax.grad.
def cross_product(x, y):
    return e3nn.tensor_product(x, y)["1e"]
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2.0, 0.0]))
y = e3nn.IrrepsArray("1o", jnp.array([2.0, 2.0, 1.0]))
e3nn.grad(cross_product)(x, y)
1x0o+1x1o+1x2o [ 0. 2. 2. 1. 0. -0. 0. -0. 0.]
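One way to read this output, sketched under the assumption that the 1e part of the tensor product is proportional to the usual cross product (the normalization and sign conventions of e3nn-jax are not spelled out here), is to write the Jacobian with respect to \(\vec x\) in components:

```latex
(\vec x \times \vec y)_i = \sum_{j,k} \varepsilon_{ijk}\, x_j\, y_k
\quad\Longrightarrow\quad
\frac{\partial (\vec x \times \vec y)_i}{\partial x_j} = \sum_{k} \varepsilon_{ijk}\, y_k
```

This Jacobian is antisymmetric in \((i, j)\), so it has no trace part (\(L=0\)) and no symmetric traceless part (\(L=2\)). Only the \(L=1\) part survives, and it is determined by \(\vec y\). This agrees with the printed result, where the 0o and 2o blocks vanish and the 1o block is [2. 2. 1.], the components of y.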
Here is a vector and its spherical harmonics of degree \(L=2\): \(Y^2\)
x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2, 3]))
e3nn.spherical_harmonics("2e", x, False)
1x2e [11.619 7.746 -2.236 23.238 15.492]
We can verify that, up to a constant factor, it can also be obtained by taking the gradient of \(Y^3\):
f = e3nn.grad(lambda x: e3nn.spherical_harmonics("3o", x, False))
0.18443 * f(x)["2e"]
1x2e [11.619 7.746 -2.236 23.238 15.492]