Flax Modules

class e3nn_jax.flax.Linear(irreps_out: ~e3nn_jax._src.irreps.Irreps, irreps_in: ~e3nn_jax._src.irreps.Irreps | None = None, channel_out: int | None = None, gradient_normalization: float | str | None = None, path_normalization: float | str | None = None, biases: bool = False, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Equivariant Linear Flax module.

Parameters:

  • irreps_out (Irreps) – output representations, kept only where allowed by Schur’s lemma.

  • channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.

  • irreps_in (optional Irreps) – input representations. If not specified, they are inferred from the input when the module is called.

  • biases (bool) – whether to add a bias to the output.

  • path_normalization (str or float) – normalization of the paths, either "element" or "path". The float values 0 and 1 correspond to "element" and "path": each element, or each path, respectively, has an equal contribution to the forward pass.

  • gradient_normalization (str or float) – normalization of the gradients, either "element" or "path". The float values 0 and 1 correspond to "element" and "path": each element, or each path, respectively, has an equal contribution to learning.

  • num_indexed_weights (optional int) – number of indexed weights. See example below.

  • weights_per_channel (bool) – whether to have one set of weights per channel.

  • force_irreps_out (bool) – whether to force the output irreps to be exactly those specified in irreps_out, instead of dropping the ones not allowed by Schur’s lemma.
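To make the two endpoints of path_normalization concrete, here is a minimal sketch, assuming the e3nn convention in which every path feeding one output irrep is multiplied by a scale factor. With "element" all paths share the factor 1/sqrt(total fan-in), so every input element counts equally; with "path" each path gets 1/sqrt(num_paths * fan_in of that path), so every path counts equally regardless of its multiplicity. The helpers path_scales and path_variances are hypothetical, and float values strictly between 0 and 1 are not modelled.

```python
import math


def path_scales(fan_ins, mode):
    """Scale factor of each path feeding one output irrep (hypothetical helper).

    fan_ins[i] is the input multiplicity of path i.
    """
    if mode == "element":
        # Every input element contributes equally: one shared factor.
        total = sum(fan_ins)
        return [1.0 / math.sqrt(total) for _ in fan_ins]
    if mode == "path":
        # Every path contributes equally, whatever its multiplicity.
        num_paths = len(fan_ins)
        return [1.0 / math.sqrt(num_paths * n) for n in fan_ins]
    raise ValueError(f"unknown mode: {mode!r}")


def path_variances(fan_ins, mode):
    """Output variance contributed by each path, for unit-variance inputs and weights."""
    return [n * a**2 for n, a in zip(fan_ins, path_scales(fan_ins, mode))]


# Two paths into the same output irrep, e.g. coming from 1x0e and 3x0e inputs.
fan_ins = [1, 3]
print(path_variances(fan_ins, "element"))                      # [0.25, 0.75]
print([round(v, 6) for v in path_variances(fan_ins, "path")])  # [0.5, 0.5]
```

Both choices keep the total output variance at 1; they only differ in how that variance is split between the paths.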

Example:

>>> import e3nn_jax as e3nn
>>> import jax
>>> import jax.numpy as jnp
>>> from e3nn_jax.flax import Linear
>>> linear = Linear("2x0e + 1o + 2e")
>>> x = e3nn.normal("0e + 1o")
>>> w = linear.init(jax.random.PRNGKey(0), x)
>>> linear.apply(w, x).irreps  # Note that the 2e is discarded
2x0e+1x1o
>>> linear.apply(w, x).shape
(5,)

External weights:

>>> linear = Linear("2x0e + 1o")
>>> e = jnp.array([1., 2., 3., 4.])
>>> w = linear.init(jax.random.PRNGKey(0), e, x)
>>> linear.apply(w, e, x).shape
(5,)

Indexed weights:

>>> linear = Linear("2x0e + 1o", num_indexed_weights=3)
>>> i = jnp.array(2)
>>> w = linear.init(jax.random.PRNGKey(0), i, x)
>>> linear.apply(w, i, x).shape
(5,)
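The note that the 2e is discarded follows from Schur’s lemma: an equivariant linear map can only connect copies of the same irrep (same l and same parity), so an output irrep survives only if it also occurs in the input. The pure-Python sketch below applies that rule; parse_irreps, reachable_output and dim are hypothetical helpers that handle only simple strings such as "2x0e + 1o" and are not e3nn_jax’s Irreps class.

```python
def parse_irreps(s):
    """Parse e.g. "2x0e + 1o" into [(mul, l, parity), ...] (simplified)."""
    out = []
    for term in s.replace(" ", "").split("+"):
        # A term without "x" has multiplicity 1, e.g. "1o".
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), ir[-1]))
    return out


def reachable_output(irreps_out, irreps_in):
    """Keep only the output irreps whose (l, parity) also occurs in the input."""
    present = {(l, p) for _, l, p in parse_irreps(irreps_in)}
    return [(m, l, p) for m, l, p in parse_irreps(irreps_out) if (l, p) in present]


def dim(irreps):
    """Total dimension: each copy of an irrep of degree l has 2l + 1 components."""
    return sum(m * (2 * l + 1) for m, l, _ in irreps)


kept = reachable_output("2x0e + 1o + 2e", "0e + 1o")
print(kept)       # [(2, 0, 'e'), (1, 1, 'o')]
print(dim(kept))  # 5
```

The surviving irreps 2x0e + 1x1o have dimension 2 * 1 + 1 * 3 = 5, which is the trailing axis of the output shape in the examples.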

class e3nn_jax.flax.MultiLayerPerceptron(list_neurons: ~typing.Tuple[int, ...], act: ~typing.Callable | None = None, gradient_normalization: str | float = None, output_activation: ~typing.Callable | bool = True, with_bias: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

A simple multilayer perceptron (MLP) for scalars. It provides no equivariance.

Parameters:

  • list_neurons (list of int) – number of neurons in each layer (excluding the input layer)

  • act (optional callable) – activation function

  • gradient_normalization (str or float) – normalization of the gradient:

    • "element": normalization is done through the initialization variance of the weights (the default in PyTorch). It gives the same importance to each neuron, so a layer with more neurons has more importance than a layer with fewer neurons.

    • "path" (default): normalization is done explicitly in the forward pass. It gives the same importance to every layer, independently of its number of neurons.
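The difference between the two gradient_normalization modes can be seen on a single dense layer. The plain-JAX sketch below is an illustration, not the module’s code: it assumes the factor 1/sqrt(fan_in) is placed either in the weight initialization ("element") or as an explicit factor in the forward pass ("path"). For the same random draw both give the same output, but the gradient with respect to the stored weights differs by a factor 1/sqrt(fan_in).

```python
import jax
import jax.numpy as jnp

fan_in, fan_out = 16, 4
z = jax.random.normal(jax.random.PRNGKey(0), (fan_in, fan_out))  # shared standard-normal draw
x = jax.random.normal(jax.random.PRNGKey(1), (fan_in,))

# "element": the scale is folded into the initialization variance.
w_element = z / jnp.sqrt(fan_in)


def forward_element(w, x):
    return x @ w


# "path": unit-variance weights, the scale is applied in the forward pass.
w_path = z


def forward_path(w, x):
    return (x @ w) / jnp.sqrt(fan_in)


y_element = forward_element(w_element, x)
y_path = forward_path(w_path, x)

# Gradient of a simple scalar loss with respect to the stored weights.
g_element = jax.grad(lambda w: forward_element(w, x).sum())(w_element)
g_path = jax.grad(lambda w: forward_path(w, x).sum())(w_path)

print(jnp.allclose(y_element, y_path, atol=1e-6))                    # True
print(jnp.allclose(g_path, g_element / jnp.sqrt(fan_in), atol=1e-6))  # True
```

Under plain SGD with a fixed learning rate, a step on the "element" layer therefore changes its output roughly in proportion to fan_in, while for "path" the change does not grow with the width of the layer, matching the behaviour described in the list above.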

class e3nn_jax.flax.BatchNorm(use_running_average: bool | None = None, eps: float = 0.0001, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, normalization: str = None, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Equivariant Batch Normalization.

It normalizes by the norm of the representations. Note that the norm is invariant only for orthonormal representations. Irreducible representations are orthonormal.
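The remark about orthonormal representations can be checked numerically. For a 1o field the representation matrices are rotations, which preserve the Euclidean norm, so statistics computed from norms do not change when the input is rotated and the normalization commutes with the rotation. The sketch below assumes a simplified form of the operation (each field divided by the root-mean-square of its norm over the batch); norm_batchnorm is a hypothetical helper, not the module’s implementation. Here x @ rotation.T applies the matrix to every 3-vector.

```python
import jax
import jax.numpy as jnp


def norm_batchnorm(x, eps=1e-5):
    """Simplified normalization of a batch of vector (1o) fields.

    x has shape (batch, mul, 3). Only the norms of the fields enter the
    statistics, one statistic per channel.
    """
    sq_norm = jnp.sum(x**2, axis=-1)   # (batch, mul), invariant under rotations
    scale = jnp.mean(sq_norm, axis=0)  # (mul,)
    return x / jnp.sqrt(scale + eps)[None, :, None]


theta = 0.7
c, s = jnp.cos(theta), jnp.sin(theta)
rotation = jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])  # orthogonal
skew = jnp.array([[1.0, 2.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])  # not orthogonal

x = jax.random.normal(jax.random.PRNGKey(0), (32, 2, 3))

# Holds for the orthogonal matrix: normalizing then rotating equals rotating then normalizing.
print(jnp.allclose(norm_batchnorm(x @ rotation.T), norm_batchnorm(x) @ rotation.T, atol=1e-5))  # True
# Fails for a non-orthogonal change of basis, because norms are no longer preserved.
print(jnp.allclose(norm_batchnorm(x @ skew.T), norm_batchnorm(x) @ skew.T, atol=1e-5))  # False
```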

Parameters:

  • use_running_average (bool) – if True, the statistics stored in batch_stats are used instead of computing the batch statistics on the input.

  • eps (float) – epsilon for numerical stability; must be between 0 and 1. The field norm is transformed to (1 - eps) * norm + eps, which leads to a slower convergence toward norm 1.

  • momentum (float) – momentum for the moving average.

  • affine (bool) – whether to include learnable weights and biases.

  • reduce (str) – reduce mode, either ‘mean’ or ‘max’.

  • instance (bool) – whether to use instance normalization instead of batch normalization.

  • normalization (str) – normalization mode, either ‘norm’ or ‘component’.
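The eps and momentum parameters can be illustrated on a single scalar statistic. This is a minimal sketch that reads the eps description literally (the field is divided by (1 - eps) * norm + eps) and assumes the PyTorch convention for the moving average, where momentum weights the new batch statistic. The helpers eps_transform, normalize_once and update_running are hypothetical.

```python
def eps_transform(norm, eps):
    """Transform of the field norm described above: (1 - eps) * norm + eps."""
    return (1.0 - eps) * norm + eps


def normalize_once(norm, eps):
    """Norm of a field after dividing it by its transformed norm."""
    return norm / eps_transform(norm, eps)


def update_running(running, batch_stat, momentum):
    """Moving average, assuming new statistics are weighted by momentum."""
    return (1.0 - momentum) * running + momentum * batch_stat


# With eps = 0 a field of norm 4 is brought to norm 1 in a single step;
# with eps = 0.5 it only moves part of the way, so convergence toward 1 is slower.
print(normalize_once(4.0, 0.0))  # 1.0
norm = 4.0
for _ in range(3):
    norm = normalize_once(norm, 0.5)
    print(round(norm, 4))  # 1.6, then 1.2308, then 1.1034

# The running statistic drifts from 1.0 toward a constant batch statistic of 4.0.
running = 1.0
for _ in range(3):
    running = update_running(running, 4.0, momentum=0.1)
print(round(running, 4))  # 1.813
```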