Equinox Modules

class e3nn_jax.equinox.Linear(*, irreps_out: ~e3nn_jax._src.irreps.Irreps, irreps_in: ~e3nn_jax._src.irreps.Irreps, channel_out: int | None = None, channel_in: int | None = None, biases: bool = False, path_normalization: float | str | None = None, gradient_normalization: float | str | None = None, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, weights_dim: int | None = None, input_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, linear_type: str = 'vanilla', key: ~jax.Array)[source]

Bases: Module

Equivariant Linear Equinox module

Parameters:
  • irreps_out (Irreps) – output representations, if allowed by Schur’s lemma.

  • channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.

  • irreps_in (Irreps) – input representations; must be supplied at initialization (see the note below).

  • channel_in (optional int) – required when using ‘mixed_per_channel’ linear_type, indicating the size of the last axis before the irreps in the input.

  • biases (bool) – whether to add a bias to the output.

  • path_normalization (str or float) – Normalization of the paths, either ‘element’ or ‘path’. 0/1 corresponds to a normalization where each element/path has an equal contribution to the forward pass.

  • gradient_normalization (str or float) – Normalization of the gradients, either ‘element’ or ‘path’. 0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.

  • num_indexed_weights (optional int) – number of indexed weights. See example below.

  • weights_per_channel (bool) – whether to have one set of weights per channel.

  • force_irreps_out (bool) – whether to force the output irreps to be the one specified in irreps_out.

Due to how Equinox is implemented, the random key, irreps_in, and irreps_out must be supplied at initialization. The type of the linear layer must also be supplied at initialization: ‘vanilla’, ‘indexed’, ‘mixed’, or ‘mixed_per_channel’. Depending on the type of linear layer used, additional options (e.g. ‘num_indexed_weights’, ‘weights_per_channel’, ‘weights_dim’, ‘channel_in’) must also be supplied.

Examples

Vanilla:

>>> import e3nn_jax as e3nn
>>> import jax
>>> import jax.numpy as jnp

>>> x = e3nn.normal("0e + 1o")
>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o + 2e",
...     irreps_in=x.irreps,
...     key=jax.random.PRNGKey(0),
... )
>>> linear(x).irreps  # Note that the 2e is discarded. Avoid this by setting force_irreps_out=True.
2x0e+1x1o
>>> linear(x).shape
(5,)
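
With force_irreps_out=True, the same layer keeps the full 2x0e + 1o + 2e output (this is the option referred to in the comment above). A minimal sketch, assuming the 2e part, which cannot be produced from a 0e + 1o input, is filled with zeros:

>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o + 2e",
...     irreps_in=x.irreps,
...     force_irreps_out=True,
...     key=jax.random.PRNGKey(0),
... )
>>> linear(x).irreps  # assumed: the 2e block is kept (as zeros) instead of discarded
2x0e+1x1o+1x2e
>>> linear(x).shape
(10,)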

External weights:

>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o",
...     irreps_in=x.irreps,
...     linear_type="mixed",
...     weights_dim=4,
...     key=jax.random.PRNGKey(0),
... )
>>> e = jnp.array([1., 2., 3., 4.])
>>> linear(e, x).irreps
2x0e+1x1o
>>> linear(e, x).shape
(5,)

Indexed weights:

>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o + 2e",
...     irreps_in=x.irreps,
...     linear_type="indexed",
...     num_indexed_weights=3,
...     key=jax.random.PRNGKey(0),
... )
>>> i = jnp.array(2)
>>> linear(i, x).irreps
2x0e+1x1o
>>> linear(i, x).shape
(5,)
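
Biases:

A minimal sketch of a vanilla layer with biases=True (it is assumed, as elsewhere in e3nn, that the bias acts only on the scalar 0e outputs, so the output irreps and shape are unchanged):

>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o",
...     irreps_in=x.irreps,
...     biases=True,
...     key=jax.random.PRNGKey(0),
... )
>>> linear(x).irreps  # assumed: the bias does not change the output irreps
2x0e+1x1o
>>> linear(x).shape
(5,)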