Equinox Modules
- class e3nn_jax.equinox.Linear(*, irreps_out: ~e3nn_jax._src.irreps.Irreps, irreps_in: ~e3nn_jax._src.irreps.Irreps, channel_out: int | None = None, channel_in: int | None = None, biases: bool = False, path_normalization: float | str | None = None, gradient_normalization: float | str | None = None, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, weights_dim: int | None = None, input_dtype: ~numpy.dtype = <class 'jax.numpy.float32'>, linear_type: str = 'vanilla', key: ~jax.Array)
Bases: Module
Equivariant Linear Equinox module
- Parameters:
irreps_out (Irreps) – output representations, if allowed by Schur’s lemma.
channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.
irreps_in (Irreps) – input representations. In this Equinox module they must be supplied at initialization.
channel_in (optional int) – required when linear_type is ‘mixed_per_channel’; the size of the last axis before the irreps in the input.
biases (bool) – whether to add a bias to the output.
path_normalization (str or float) – normalization of the paths, either ‘element’ or ‘path’. 0/1 corresponds to a normalization where each element/path has an equal contribution to the forward pass.
gradient_normalization (str or float) – normalization of the gradients, either ‘element’ or ‘path’. 0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.
num_indexed_weights (optional int) – number of weight sets to index into when linear_type is ‘indexed’. See the example below.
weights_per_channel (bool) – whether to have one set of weights per channel.
force_irreps_out (bool) – whether to force the output irreps to be the ones specified in irreps_out.
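The description of path_normalization only names its two endpoints. The pure-Python sketch below shows one way a float between 0 (‘element’) and 1 (‘path’) can assign a variance scale to each path that feeds the same output irrep. The interpolation formula and the helper `path_variance_scale` are assumptions made for illustration, not e3nn-jax's internal code.

```python
def path_variance_scale(fan_ins, path_normalization):
    """Variance scale for each path that ends in the same output irrep.

    Hypothetical helper. fan_ins[p] is the input multiplicity of path p.
    ASSUMED formula: scale_p = 1 / (fan_ins[p]**a * sum_q fan_ins[q]**(1 - a)),
    with a = 0 for "element" and a = 1 for "path".
    """
    if isinstance(path_normalization, str):
        path_normalization = {"element": 0.0, "path": 1.0}[path_normalization]
    a = float(path_normalization)
    total = sum(f ** (1.0 - a) for f in fan_ins)
    return [1.0 / (f ** a * total) for f in fan_ins]


# Two paths feed one output irrep, with input multiplicities 1 and 3.
fan_ins = [1, 3]
for mode in ("element", "path", 0.5):
    scales = path_variance_scale(fan_ins, mode)
    # Variance contributed by each path, assuming unit-variance inputs and weights.
    contributions = [s * f for s, f in zip(scales, fan_ins)]
    print(mode, [round(c, 3) for c in contributions])
```

Under this assumed formula, ‘element’ gives the two paths 0.25 and 0.75 of the output variance (each of the four input elements counts equally), while ‘path’ gives each path 0.5 (each path counts equally). For every value of the parameter the contributions sum to 1.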
Due to how Equinox is implemented, the random key, irreps_in and irreps_out must be supplied at initialization. The type of the linear layer must also be supplied at initialization through linear_type: one of ‘vanilla’, ‘indexed’, ‘mixed’ or ‘mixed_per_channel’. Depending on which type is used, additional options (e.g. ‘num_indexed_weights’, ‘weights_per_channel’, ‘weights_dim’, ‘channel_in’) must also be supplied.
Examples
Vanilla:
>>> import e3nn_jax as e3nn
>>> import jax
>>> x = e3nn.normal("0e + 1o")
>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o + 2e",
...     irreps_in=x.irreps,
...     key=jax.random.PRNGKey(0),
... )
>>> linear(x).irreps  # Note that the 2e is discarded. Avoid this by setting force_irreps_out=True.
2x0e+1x1o
>>> linear(x).shape
(5,)
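The comment in the vanilla example follows from Schur’s lemma: an equivariant linear map can only connect input and output irreps of the same type, so an output irrep that never appears in irreps_in (here 2e) gets no weights and is dropped. The pure-Python sketch below reproduces the resulting irreps and dimension for this example and counts weights. The string parser and all helper names are hypothetical, and the weight count assumes one mul_in × mul_out matrix per matching irrep type, with no biases.

```python
def parse_irreps(s):
    """Parse a string such as "2x0e + 1o" into [(2, "0e"), (1, "1o")]. Hypothetical helper."""
    irreps = []
    for term in s.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        irreps.append((int(mul) if mul else 1, ir))
    return irreps


def dim(irreps):
    """Total dimension: an irrep of degree l has 2l + 1 components."""
    return sum(mul * (2 * int(ir[:-1]) + 1) for mul, ir in irreps)


def schur_filter(irreps_in, irreps_out):
    """Keep the output irreps whose type appears in the input, and count the weights."""
    ins = parse_irreps(irreps_in)
    in_types = {ir for _, ir in ins}
    kept = [(mul, ir) for mul, ir in parse_irreps(irreps_out) if ir in in_types]
    # ASSUMPTION: one (mul_in x mul_out) weight matrix per pair of matching irreps.
    num_weights = sum(
        m_in * m_out for m_in, ir_in in ins for m_out, ir_out in kept if ir_in == ir_out
    )
    return kept, num_weights


kept, num_weights = schur_filter("0e + 1o", "2x0e + 1o + 2e")
print("+".join(f"{mul}x{ir}" for mul, ir in kept), dim(kept), num_weights)
```

The printed irreps string and dimension match the doctest above (2x0e+1x1o and 5); the weight count of 3 holds only under the stated assumption.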
External weights:
>>> import jax.numpy as jnp
>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o",
...     irreps_in=x.irreps,
...     linear_type="mixed",
...     weights_dim=4,
...     key=jax.random.PRNGKey(0),
... )
>>> e = jnp.array([1., 2., 3., 4.])
>>> linear(e, x).irreps
2x0e+1x1o
>>> linear(e, x).shape
(5,)
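With linear_type="mixed", the example passes an extra vector e of size weights_dim alongside x. The NumPy sketch below illustrates the idea for a single irrep block, under the assumption that the block's weights are a linear function of e (normalization constants omitted). `mixed_linear_block` and the parameter tensor `W1` are hypothetical and not part of e3nn-jax. Because the weights only mix multiplicities and never touch the 2l + 1 component axis, the block stays equivariant for any e.

```python
import numpy as np


def mixed_linear_block(e, x, W):
    """One irrep block whose weights are computed from an external vector e.

    e: (weights_dim,)                  external input
    x: (mul_in, 2l + 1)                features of a single irrep type
    W: (weights_dim, mul_in, mul_out)  hypothetical parameter tensor
    returns (mul_out, 2l + 1)
    """
    w = np.einsum("d,dab->ab", e, W)     # weights for this particular e
    return np.einsum("am,ab->bm", x, w)  # mix multiplicities, leave the m axis alone


rng = np.random.default_rng(0)
e = np.array([1.0, 2.0, 3.0, 4.0])  # weights_dim = 4, as in the example
W1 = rng.normal(size=(4, 1, 1))     # 1o -> 1o block: mul_in = 1, mul_out = 1
x1 = rng.normal(size=(1, 3))        # one vector (l = 1), stored as a row

# Rotating the input by R rotates the output by R, whatever e is.
t = 0.3
R = np.array([[np.cos(t), -np.sin(t), 0.0], [np.sin(t), np.cos(t), 0.0], [0.0, 0.0, 1.0]])
y1 = mixed_linear_block(e, x1, W1)
y1_rot = mixed_linear_block(e, x1 @ R.T, W1)
print(np.allclose(y1_rot, y1 @ R.T))
```

In this sketch the output is also linear in e, so doubling e doubles the output; whether e3nn-jax adds further normalization or structure on top of this is not stated here.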
Indexed weights:
>>> linear = e3nn.equinox.Linear(
...     irreps_out="2x0e + 1o + 2e",
...     irreps_in=x.irreps,
...     linear_type="indexed",
...     num_indexed_weights=3,
...     key=jax.random.PRNGKey(0),
... )
>>> i = jnp.array(2)
>>> linear(i, x).irreps
2x0e+1x1o
>>> linear(i, x).shape
(5,)
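With linear_type="indexed", the integer i selects one of num_indexed_weights weight sets (i = 2 out of 3 in the example). The NumPy sketch below illustrates that reading for a single irrep block, including a batch of indices such as one might use to give each node of a graph its own type. `indexed_linear_block` and `W0` are hypothetical names, and normalization is omitted. In this sketch, indexing gives the same result as the mixed block with a one-hot e.

```python
import numpy as np


def indexed_linear_block(i, x, W):
    """One irrep block whose weights are selected by an integer index.

    i: int, or integer array of shape (batch,)
    x: (..., mul_in, 2l + 1) features of a single irrep type
    W: (num_indexed_weights, mul_in, mul_out) hypothetical parameter tensor
    returns (..., mul_out, 2l + 1)
    """
    return np.einsum("...am,...ab->...bm", x, W[i])


rng = np.random.default_rng(0)
W0 = rng.normal(size=(3, 1, 2))  # num_indexed_weights = 3; 0e block, mul_in = 1, mul_out = 2
x0 = rng.normal(size=(1, 1))     # one scalar (l = 0)

y = indexed_linear_block(2, x0, W0)  # i = 2, as in the example

# A batch of two inputs, each with its own index.
xb = rng.normal(size=(2, 1, 1))
yb = indexed_linear_block(np.array([0, 2]), xb, W0)

# Selecting weight set 2 equals mixing the weight sets with a one-hot vector.
one_hot = np.eye(3)[2]
y_mixed = np.einsum("am,ab->bm", x0, np.einsum("d,dab->ab", one_hot, W0))
print(y.shape, yb.shape, np.allclose(y, y_mixed))
```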