Flax Modules
- class e3nn_jax.flax.Linear(irreps_out: ~e3nn_jax._src.irreps.Irreps, irreps_in: ~e3nn_jax._src.irreps.Irreps | None = None, channel_out: int | None = None, gradient_normalization: float | str | None = None, path_normalization: float | str | None = None, biases: bool = False, parameter_initializer: ~typing.Callable[[], ~jax.nn.initializers.Initializer] | None = None, instructions: ~typing.List[~typing.Tuple[int, int]] | None = None, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, simplify_irreps_internally: bool = True, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]
Bases: Module
Equivariant Linear Flax module
- Parameters:
  - irreps_out (Irreps) – output representations, if allowed by Schur’s lemma.
  - channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.
  - irreps_in (optional Irreps) – input representations. If not specified, the input representations are obtained when calling the module.
  - biases (bool) – whether to add a bias to the output.
  - path_normalization (str or float) – normalization of the paths, "element" or "path". 0/1 corresponds to a normalization where each element/path has an equal contribution to the forward pass.
  - gradient_normalization (str or float) – normalization of the gradients, "element" or "path". 0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.
  - num_indexed_weights (optional int) – number of indexed weights. See example below.
  - weights_per_channel (bool) – whether to have one set of weights per channel.
  - force_irreps_out (bool) – whether to force the output irreps to be exactly the irreps specified in irreps_out.
Examples
Vanilla:
>>> import e3nn_jax as e3nn
>>> import jax
>>> from e3nn_jax.flax import Linear
>>>
>>> linear = Linear("2x0e + 1o + 2e")
>>> x = e3nn.normal("0e + 1o")
>>> w = linear.init(jax.random.PRNGKey(0), x)
>>> linear.apply(w, x).irreps  # Note that the 2e is discarded
2x0e+1x1o
>>> linear.apply(w, x).shape
(5,)
External weights:
>>> linear = Linear("2x0e + 1o") >>> e = jnp.array([1., 2., 3., 4.]) >>> w = linear.init(jax.random.PRNGKey(0), e, x) >>> linear.apply(w, e, x).shape (5,)
Indexed weights:
>>> linear = Linear("2x0e + 1o", num_indexed_weights=3) >>> i = jnp.array(2) >>> w = linear.init(jax.random.PRNGKey(0), i, x) >>> linear.apply(w, i, x).shape (5,)
- class e3nn_jax.flax.MultiLayerPerceptron(list_neurons: ~typing.Tuple[int, ...], act: ~typing.Callable | None = None, gradient_normalization: str | float = None, output_activation: ~typing.Callable | bool = True, with_bias: bool = False, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]
Bases: Module
Just a simple MLP for scalars. No equivariance here.
- Parameters:
  - list_neurons (list of int) – number of neurons in each layer (excluding the input layer)
  - act (optional callable) – activation function
  - gradient_normalization (str or float) – normalization of the gradient
    - "element": normalization done in the initialization variance of the weights (the default in PyTorch). Gives the same importance to each neuron; a layer with more neurons will have a higher importance than a layer with fewer neurons.
    - "path" (default): normalization done explicitly in the forward pass. Gives the same importance to every layer independently of the number of neurons.
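A usage sketch (not from the original docs; the neuron counts, activation, and input shape are arbitrary choices for illustration). The module is initialized and applied like any other Flax module:
>>> import e3nn_jax as e3nn
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> mlp = e3nn.flax.MultiLayerPerceptron((32, 32, 8), act=jax.nn.gelu)
>>> x = jnp.ones((10, 16))  # batch of 10 plain scalar feature vectors
>>> w = mlp.init(jax.random.PRNGKey(0), x)
>>> mlp.apply(w, x).shape  # last entry of list_neurons gives the output width
(10, 8)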
- class e3nn_jax.flax.BatchNorm(use_running_average: bool | None = None, eps: float = 0.0001, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, normalization: str = None, parent: ~flax.linen.module.Module | ~flax.core.scope.Scope | ~flax.linen.module._Sentinel | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]
Bases: Module
Equivariant Batch Normalization.
It normalizes by the norm of the representations. Note that the norm is invariant only for orthonormal representations. Irreducible representations are orthonormal.
- Parameters:
  - use_running_average – if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.
  - eps (float) – epsilon for numerical stability, has to be between 0 and 1. The field norm is transformed to (1 - eps) * norm + eps, leading to a slower convergence toward norm 1.
  - momentum – momentum for the moving average
  - affine – whether to include learnable weights and biases
  - reduce – reduce mode, either 'mean' or 'max'
  - instance – whether to use instance normalization instead of batch normalization
  - normalization – normalization mode, either 'norm' or 'component'
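A minimal usage sketch (not from the original docs; it assumes the running statistics live in a batch_stats collection, as the use_running_average description above suggests, and that they are updated when that collection is marked mutable, following the usual Flax pattern):
>>> import e3nn_jax as e3nn
>>> import jax
>>>
>>> bn = e3nn.flax.BatchNorm(use_running_average=False)  # training mode
>>> x = e3nn.normal("2x0e + 1o", jax.random.PRNGKey(0), (16,))  # batch of 16
>>> variables = bn.init(jax.random.PRNGKey(1), x)
>>> y, updates = bn.apply(variables, x, mutable=["batch_stats"])
>>> y.irreps  # the irreps are unchanged, only the norms are rescaled
2x0e+1x1o
>>> y.shape
(16, 5)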