Haiku Modules

class e3nn_jax.haiku.Linear(irreps_out: Irreps, channel_out: int = None, *, irreps_in: Irreps | None = None, biases: bool = False, path_normalization: str | float = None, gradient_normalization: str | float = None, get_parameter: Callable[[str, Tuple[int, ...], float, Any], Array] | None = None, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, name: str | None = None)[source]

Bases: Module

Equivariant Linear Haiku Module.

Parameters:
  • irreps_out (Irreps) – output representations, if allowed by Schur’s lemma.

  • channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.

  • irreps_in (optional Irreps) – input representations. If not specified, the input representations are inferred when the module is called.

  • biases (bool) – whether to add a bias to the output.

  • path_normalization (str or float) – normalization of the paths, either "element" or "path". 0/1 corresponds to a normalization where each element/path has an equal contribution to the forward pass.

  • gradient_normalization (str or float) – normalization of the gradients, either "element" or "path". 0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.

  • get_parameter (optional Callable) – function to get the parameters.

  • num_indexed_weights (optional int) – number of indexed weights. See example below.

  • weights_per_channel (bool) – whether to have one set of weights per channel.

  • force_irreps_out (bool) – whether to force the output irreps to be exactly the ones specified in irreps_out. See example below.

  • name (optional str) – name of the module.

Examples

Vanilla:

>>> import e3nn_jax as e3nn
>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> @hk.without_apply_rng
... @hk.transform
... def linear(x):
...     return e3nn.haiku.Linear("0e + 1o + 2e")(x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> params = linear.init(jax.random.PRNGKey(0), x)
>>> y = linear.apply(params, x)
>>> y.irreps  # Note that the 2e is discarded
1x0e+1x1o
>>> y.shape
(4,)

External weights:

>>> @hk.without_apply_rng
... @hk.transform
... def linear(w, x):
...     return e3nn.haiku.Linear("0e + 1o")(w, x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> w = jnp.array([1., 2., 3., 4.])
>>> params = linear.init(jax.random.PRNGKey(0), w, x)
>>> y = linear.apply(params, w, x)
>>> y.shape
(4,)

External indices:

>>> @hk.without_apply_rng
... @hk.transform
... def linear(i, x):
...     return e3nn.haiku.Linear("0e + 1o", num_indexed_weights=4)(i, x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones((2, 5)))
>>> i = jnp.array([2, 3])
>>> params = linear.init(jax.random.PRNGKey(0), i, x)
>>> y = linear.apply(params, i, x)
>>> y.shape
(2, 4)
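
Forced output irreps (a hedged sketch, not from the upstream docs: we assume force_irreps_out=True keeps irreps that are unreachable from the input, such as the 2e below, filling them with zeros):

>>> @hk.without_apply_rng
... @hk.transform
... def linear(x):
...     return e3nn.haiku.Linear("0e + 1o + 2e", force_irreps_out=True)(x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> params = linear.init(jax.random.PRNGKey(0), x)
>>> y = linear.apply(params, x)
>>> y.irreps  # assumption: the unreachable 2e is kept (zero-filled) instead of discarded
1x0e+1x1o+1x2e
>>> y.shape
(9,)
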
class e3nn_jax.haiku.MultiLayerPerceptron(list_neurons: Sequence[int], act: Callable | None, *, gradient_normalization: str | float = None, output_activation: Callable | bool = True, with_bias: bool = False, name: str | None = None)[source]

Bases: Module

Just a simple MLP for scalars. No equivariance here.

Parameters:
  • list_neurons (list of int) – number of neurons in each layer (excluding the input layer)

  • act (optional callable) – activation function

  • gradient_normalization (str or float) –

    normalization of the gradient

    • "element": normalization done via the initialization variance of the weights (the default in PyTorch);

      gives the same importance to each neuron, so a layer with more neurons has a higher importance than a layer with fewer neurons

    • "path" (default): normalization done explicitly in the forward pass;

      gives the same importance to every layer, independently of the number of neurons
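
A minimal usage sketch (our assumption: the module accepts a plain jnp.ndarray of scalar features; imports are the same as in the Linear examples above):

>>> @hk.without_apply_rng
... @hk.transform
... def mlp(x):
...     return e3nn.haiku.MultiLayerPerceptron([64, 32], jax.nn.gelu)(x)
>>> x = jnp.ones((16, 8))  # batch of 16 samples, 8 scalar features each (assumed input layout)
>>> params = mlp.init(jax.random.PRNGKey(0), x)
>>> y = mlp.apply(params, x)
>>> y.shape  # last entry of list_neurons
(16, 32)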

class e3nn_jax.haiku.BatchNorm(*, irreps: Irreps = None, eps: float = 0.0001, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, normalization: str = None)[source]

Bases: Module

Equivariant Batch Normalization.

It normalizes by the norm of the representations. Note that the norm is invariant only for orthonormal representations. Irreducible representations are orthonormal.

Parameters:
  • irreps – Irreducible representations of the input and output (unchanged)

  • eps (float) – epsilon for numerical stability, has to be between 0 and 1. The field norm is transformed to (1 - eps) * norm + eps, leading to a slower convergence toward norm 1.

  • momentum – momentum for moving average

  • affine – whether to include learnable weights and biases

  • reduce – reduce mode, either ‘mean’ or ‘max’

  • instance – whether to use instance normalization instead of batch normalization

  • normalization – normalization mode, either ‘norm’ or ‘component’
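
A minimal usage sketch (assumptions: the running statistics live in Haiku state, hence hk.transform_with_state, and the call accepts an is_training flag):

>>> @hk.without_apply_rng
... @hk.transform_with_state
... def bn(x, is_training=True):
...     return e3nn.haiku.BatchNorm(irreps="0e + 1o")(x, is_training)  # is_training: assumed argument name
>>> x = e3nn.IrrepsArray("0e + 1o", jnp.ones((16, 4)))
>>> params, state = bn.init(jax.random.PRNGKey(0), x)
>>> y, state = bn.apply(params, state, x)  # state carries the updated running statistics
>>> y.shape
(16, 4)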

class e3nn_jax.haiku.Dropout(p, *, irreps=None)[source]

Bases: Module

Equivariant Dropout.

\(A_{zai}\) is the input and \(B_{zai}\) is the output, where

  • z is the batch index

  • a is any non-batch and non-irrep index

  • i is the irrep index; for instance, if irreps="0e + 2x1e" then i=2 selects the second vector

\[B_{zai} = \frac{x_{zi}}{1-p} A_{zai}\]

where \(p\) is the dropout probability and \(x\) is a Bernoulli random variable with parameter \(1-p\).

Parameters:
  • irreps (Irreps) – the irrep string

  • p (float) – dropout probability

Returns:

the dropout module

Return type:

Dropout
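
A minimal usage sketch (assumption: the module takes an explicit rng key as its first call argument, so we draw one with hk.next_rng_key()):

>>> @hk.transform
... def dropout(x):
...     return e3nn.haiku.Dropout(p=0.5)(hk.next_rng_key(), x)  # rng-first call signature is an assumption
>>> x = e3nn.IrrepsArray("0e + 2x1e", jnp.ones((16, 7)))
>>> params = dropout.init(jax.random.PRNGKey(0), x)
>>> y = dropout.apply(params, jax.random.PRNGKey(1), x)
>>> y.shape
(16, 7)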

class e3nn_jax.haiku.SymmetricTensorProduct(orders: Tuple[int, ...], keep_irrep_out: Set[Irrep] | None = None, get_parameter: Callable[[str, Tuple[int, ...], Any], Array] | None = None)[source]

Bases: Module

Symmetric tensor product contraction with parameters.

Equivalent to the following code executed in parallel on the channel dimension:

e3nn.haiku.Linear(irreps_out)(
    e3nn.concatenate([
        x,
        e3nn.tensor_product(x, x),  # additionally keeping only the symmetric terms
        e3nn.tensor_product(e3nn.tensor_product(x, x), x),
        ...
    ])
)

Each channel has its own parameters.

Parameters:
  • orders (tuple of int) – orders of the tensor product

  • keep_irrep_out (optional, set of Irrep) – irreps to keep in the output

  • get_parameter (optional, callable) – function to get the parameters; by default it uses hk.get_parameter. It should have the signature get_parameter(name, shape) -> Array and return an array drawn from a normal distribution with variance 1.
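
A minimal usage sketch (assumptions: the input's leading axis is the channel dimension, consistent with "each channel has its own parameters", and orders=(1, 2) requests the linear and quadratic terms):

>>> @hk.without_apply_rng
... @hk.transform
... def stp(x):
...     return e3nn.haiku.SymmetricTensorProduct(orders=(1, 2))(x)
>>> x = e3nn.IrrepsArray("0e + 1o", jnp.ones((8, 4)))  # 8 channels (assumed leading axis)
>>> params = stp.init(jax.random.PRNGKey(0), x)
>>> y = stp.apply(params, x)  # output irreps depend on keep_irrep_out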