Flax Modules

class e3nn_jax.flax.Linear(irreps_out: ~e3nn_jax._src.irreps.Irreps, irreps_in: ~e3nn_jax._src.irreps.Irreps | None = None, channel_out: int | None = None, gradient_normalization: float | str | None = None, path_normalization: float | str | None = None, biases: bool = False, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Equivariant Linear Flax module

Parameters:
  • irreps_out (Irreps) – output representations, if allowed by Schur’s lemma.

  • channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.

  • irreps_in (optional Irreps) – input representations. If not specified, the input representations are inferred from the input when the module is called.

  • biases (bool) – whether to add a bias to the output.

  • path_normalization (str or float) – normalization of the paths: "element", "path", or a float, where 0 corresponds to "element" (each element has an equal contribution to the forward pass) and 1 corresponds to "path" (each path has an equal contribution to the forward pass).

  • gradient_normalization (str or float) – normalization of the gradients: "element", "path", or a float, where 0 corresponds to "element" (each element has an equal contribution to the learning) and 1 corresponds to "path" (each path has an equal contribution to the learning).

  • num_indexed_weights (optional int) – number of indexed weights. See example below.

  • weights_per_channel (bool) – whether to have one set of weights per channel.

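  • force_irreps_out (bool) – whether to force the output irreps to be exactly those specified in irreps_out, instead of discarding the ones not allowed by Schur’s lemma (such as the 2e in the first example below).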
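The 0/1 endpoints of path_normalization can be made concrete with a small plain-Python sketch. It assumes (for illustration only; this is not necessarily e3nn_jax's internal formula) that each path i feeding one output irrep has fan-in n_i and gets the variance coefficient 1 / (n_i**x * sum_j n_j**(1 - x)), where x is the float form of path_normalization. The helpers `path_coefficients` and `path_variances` are hypothetical.

```python
# Hypothetical sketch: per-path variance coefficients for ONE output irrep.
# fan_ins[i] is the fan-in (input multiplicity) of path i, and x is the float
# form of path_normalization (0.0 ~ "element", 1.0 ~ "path").
# The interpolation formula is an assumption made for illustration.
from typing import List


def path_coefficients(fan_ins: List[int], x: float) -> List[float]:
    """Variance multiplier applied to each path's contribution."""
    total = sum(n ** (1.0 - x) for n in fan_ins)
    return [1.0 / (n ** x * total) for n in fan_ins]


def path_variances(fan_ins: List[int], x: float) -> List[float]:
    """Output variance contributed by each path, for unit-variance inputs and weights."""
    # A path summing n unit-variance products has variance n before scaling.
    return [c * n for c, n in zip(path_coefficients(fan_ins, x), fan_ins)]


fan_ins = [2, 8]  # two paths with fan-in 2 and 8 feeding the same output irrep
for x in (0.0, 0.5, 1.0):
    contributions = path_variances(fan_ins, x)
    print(x, [round(v, 3) for v in contributions], round(sum(contributions), 6))
```

With x = 0 the larger path dominates in proportion to its fan-in (every element counts equally); with x = 1 both paths contribute the same variance. Under this assumed formula the total output variance stays 1 for every x.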

Examples

Vanilla:

>>> import e3nn_jax as e3nn
>>> import jax
>>> import jax.numpy as jnp
>>> from e3nn_jax.flax import Linear
>>>
>>> linear = Linear("2x0e + 1o + 2e")
>>> x = e3nn.normal("0e + 1o")
>>> w = linear.init(jax.random.PRNGKey(0), x)
>>> linear.apply(w, x).irreps  # Note that the 2e is discarded
2x0e+1x1o
>>> linear.apply(w, x).shape
(5,)
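The 2e is discarded because of Schur’s lemma: an equivariant linear map can only connect copies of the same irrep, so output irreps that are absent from the input cannot be produced. The plain-Python sketch below mimics that filtering on simplified irreps strings. The helpers `parse_irreps`, `reachable_irreps` and `dim` are hypothetical and only handle terms such as `2x0e` or `1o`.

```python
# Hypothetical sketch of the Schur's-lemma filter applied to irreps_out.
from typing import List, Tuple

Irrep = Tuple[int, str]       # (l, parity), parity is "e" (even) or "o" (odd)
MulIrrep = Tuple[int, Irrep]  # (multiplicity, irrep)


def parse_irreps(text: str) -> List[MulIrrep]:
    """Parse a simplified irreps string such as "2x0e + 1o + 2e"."""
    result = []
    for term in text.split("+"):
        term = term.strip()
        mul, ir = term.split("x") if "x" in term else ("1", term)
        result.append((int(mul), (int(ir[:-1]), ir[-1])))
    return result


def reachable_irreps(irreps_out: str, irreps_in: str) -> List[MulIrrep]:
    """Keep only the output irreps whose (l, parity) also appears in the input."""
    available = {ir for _, ir in parse_irreps(irreps_in)}
    return [(mul, ir) for mul, ir in parse_irreps(irreps_out) if ir in available]


def dim(irreps: List[MulIrrep]) -> int:
    """Total dimension: an irrep of degree l has 2l + 1 components."""
    return sum(mul * (2 * l + 1) for mul, (l, _) in irreps)


kept = reachable_irreps("2x0e + 1o + 2e", "0e + 1o")
print(kept, dim(kept))
```

The kept irreps are 2x0e and 1x1o, of total dimension 2 + 3 = 5, which matches the shape returned by the module above.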

External weights:

>>> linear = Linear("2x0e + 1o")
>>> e = jnp.array([1., 2., 3., 4.])
>>> w = linear.init(jax.random.PRNGKey(0), e, x)
>>> linear.apply(w, e, x).shape
(5,)
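In the external-weights form, the layer's weights depend on the extra input e. A minimal JAX sketch of that idea for scalar features follows, assuming the weights are a learned linear function of e. That mechanism, the normalization factors, and the names `mixed_linear` and `param` are assumptions for illustration, not e3nn_jax's implementation.

```python
# Hypothetical sketch: weights obtained by linearly mixing an external vector e.
import jax
import jax.numpy as jnp

d_e, d_in, d_out = 4, 3, 2
# Learned parameter tensor: one (d_in, d_out) matrix per component of e.
param = jax.random.normal(jax.random.PRNGKey(0), (d_e, d_in, d_out))


def mixed_linear(param: jax.Array, e: jax.Array, x: jax.Array) -> jax.Array:
    # Contract e with the parameter tensor to get a (d_in, d_out) weight matrix.
    w = jnp.einsum("k,kio->io", e, param) / jnp.sqrt(e.shape[-1])
    # Apply the resulting weights to the input features.
    return x @ w / jnp.sqrt(x.shape[-1])


e = jnp.array([1.0, 2.0, 3.0, 4.0])
x = jnp.ones(d_in)
print(mixed_linear(param, e, x).shape)
```

Because the weights are linear in e, the output is linear in e as well; in a model, e would typically come from another network or from an embedding of some auxiliary quantity.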

Indexed weights:

>>> linear = Linear("2x0e + 1o", num_indexed_weights=3)
>>> i = jnp.array(2)
>>> w = linear.init(jax.random.PRNGKey(0), i, x)
>>> linear.apply(w, i, x).shape
(5,)
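With num_indexed_weights=3 the module holds three independent sets of weights, and the integer i selects which set is used (here the third one, index 2). The scalar-only JAX sketch below illustrates that selection. The names `indexed_linear` and `params` and the 1/sqrt(fan-in) factor are hypothetical; this shows the idea rather than e3nn_jax's implementation.

```python
# Hypothetical sketch: select one of several weight sets with an integer index.
import jax
import jax.numpy as jnp

num_indexed_weights, d_in, d_out = 3, 4, 5
# One (d_in, d_out) weight matrix per index.
params = jax.random.normal(jax.random.PRNGKey(0), (num_indexed_weights, d_in, d_out))


def indexed_linear(params, i, x):
    # Pick the weight set selected by the integer index i, then apply it.
    return x @ params[i] / jnp.sqrt(x.shape[-1])


x = jnp.arange(4.0)
y = indexed_linear(params, jnp.array(2), x)

# A batch of inputs, each with its own index, handled with jax.vmap.
xs = jnp.stack([x, 2 * x])
idx = jnp.array([0, 2])
ys = jax.vmap(indexed_linear, in_axes=(None, 0, 0))(params, idx, xs)
```

Indexing the parameter tensor with a traced integer works under jax.vmap, so each element of a batch can use a different weight set.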
class e3nn_jax.flax.MultiLayerPerceptron(list_neurons: ~typing.Tuple[int, ...], act: ~typing.Callable | None = None, gradient_normalization: str | float = None, output_activation: ~typing.Callable | bool = True, with_bias: bool = False, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Just a simple MLP for scalars. It is not equivariant.

Parameters:
  • list_neurons (list of int) – number of neurons in each layer (excluding the input layer)

  • act (optional callable) – activation function

  • gradient_normalization (str or float) – normalization of the gradient:

    • "element": normalization done through the initialization variance of the weights (the default in PyTorch). It gives the same importance to each neuron, so a layer with more neurons has a higher importance than a layer with fewer neurons.

    • "path" (default): normalization done explicitly in the forward pass. It gives the same importance to every layer, independently of its number of neurons.
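The difference between the two gradient normalizations can be seen on a single dense layer. In this JAX sketch the factor alpha = 1/sqrt(fan_in) is split between the initialization scale and the forward pass; the float g interpolating between "element" (0) and "path" (1) is an assumption modeled on the float option in the signature, and `init_weight` and `dense` are hypothetical helpers. At initialization both choices compute the same function, but the gradient with respect to the stored weights is scaled by alpha**g, which changes how fast each layer learns.

```python
# Hypothetical sketch: splitting 1/sqrt(fan_in) between init and forward pass.
import jax
import jax.numpy as jnp


def init_weight(key, fan_in, fan_out, g):
    alpha = 1.0 / jnp.sqrt(fan_in)
    # g = 0 ("element"): the normalization lives in the initialization variance.
    # g = 1 ("path"): weights are standard normal; scaling happens in dense().
    return alpha ** (1.0 - g) * jax.random.normal(key, (fan_in, fan_out))


def dense(w, x, g):
    alpha = 1.0 / jnp.sqrt(w.shape[0])
    # Apply the remaining part of the normalization explicitly.
    return (x @ w) * alpha ** g


key = jax.random.PRNGKey(0)
x = jnp.linspace(-1.0, 1.0, 16)
w_element = init_weight(key, 16, 8, g=0.0)
w_path = init_weight(key, 16, 8, g=1.0)

# Same function at initialization ...
out_element = dense(w_element, x, 0.0)
out_path = dense(w_path, x, 1.0)

# ... but different gradients with respect to the stored weights.
loss = lambda w, g: jnp.sum(dense(w, x, g))
grad_element = jax.grad(loss)(w_element, 0.0)
grad_path = jax.grad(loss)(w_path, 1.0)
```

Here fan_in = 16, so alpha = 0.25 and the "path" gradient is a quarter of the "element" one. In a deeper MLP this scaling depends on each layer's fan-in, which is how the two options weigh layers differently.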

class e3nn_jax.flax.BatchNorm(use_running_average: bool | None = None, eps: float = 0.0001, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, normalization: str = None, parent: ~typing.Type[~flax.linen.module.Module] | ~flax.core.scope.Scope | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Bases: Module

Equivariant Batch Normalization.

It normalizes each representation by its norm. Note that the norm is invariant only for orthonormal representations. Irreducible representations are orthonormal, so dividing them by their norm preserves equivariance.
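Why normalizing by the norm keeps equivariance can be checked numerically. The JAX sketch below is a simplified stand-in for a field of l = 1 vectors (batch statistics only, mean over the batch, no eps, no affine weights, no running averages); `norm_batch_norm` and `rotation_z` are hypothetical helpers. Since a rotation leaves norms unchanged, rotating then normalizing gives the same result as normalizing then rotating.

```python
# Hypothetical sketch: norm-based batch normalization of l = 1 features.
import jax
import jax.numpy as jnp


def norm_batch_norm(x):
    # x has shape (batch, mul, 3): `mul` copies of an l = 1 vector per sample.
    sq_norm = jnp.sum(x ** 2, axis=-1)   # (batch, mul), rotation invariant
    mean_sq = jnp.mean(sq_norm, axis=0)  # (mul,), averaged over the batch
    return x / jnp.sqrt(mean_sq)[None, :, None]


def rotation_z(theta):
    # Rotation by angle theta around the z axis (an orthogonal matrix).
    c, s = jnp.cos(theta), jnp.sin(theta)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


x = jax.random.normal(jax.random.PRNGKey(0), (32, 4, 3))
R = rotation_z(0.7)

rotate_then_norm = norm_batch_norm(x @ R.T)
norm_then_rotate = norm_batch_norm(x) @ R.T
```

After normalization, each of the `mul` channels has a batch-averaged squared norm of 1, and the result commutes with the rotation.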

Parameters:
  • use_running_average (optional bool) – if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.

  • eps (float) – epsilon for numerical stability, has to be between 0 and 1. The field norm is transformed to (1 - eps) * norm + eps, leading to a slower convergence toward norm 1.

  • momentum (float) – momentum for the moving average.

  • affine (bool) – whether to include learnable weights and biases.

  • reduce (str) – reduce mode, either ‘mean’ or ‘max’.

  • instance (bool) – whether to use instance normalization instead of batch normalization.

  • normalization (str) – normalization mode, either ‘norm’ or ‘component’.
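The eps and momentum parameters can be illustrated with two small helpers. The eps transform follows the formula given for eps; the running-average update assumes the PyTorch-style convention new = (1 - momentum) * old + momentum * batch_value, which is an assumption here. Both `stabilized_norm` and `update_running` are hypothetical names.

```python
# Hypothetical sketch of the eps transform and of a running-average update.
import jax.numpy as jnp


def stabilized_norm(norm, eps=1e-4):
    # (1 - eps) * norm + eps: with 0 < eps < 1 the result moves a fraction eps
    # of the way toward 1 and stays at least eps for non-negative norms.
    return (1.0 - eps) * norm + eps


def update_running(running, batch_value, momentum=0.1):
    # Exponential moving average of a batch statistic (assumed convention).
    return (1.0 - momentum) * running + momentum * batch_value


norms = jnp.array([0.0, 0.5, 1.0, 4.0])
print(stabilized_norm(norms, eps=0.1))

running = jnp.array(1.0)
for batch_value in (3.0, 3.0, 3.0):
    running = update_running(running, batch_value)
print(running)
```

With momentum = 0.1 the running statistic moves only 10% of the way toward each new batch value, so it changes slowly and smooths out batch-to-batch noise.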