Haiku Modules
- class e3nn_jax.haiku.Linear(irreps_out: Irreps, channel_out: int = None, *, irreps_in: Irreps | None = None, biases: bool = False, path_normalization: str | float = None, gradient_normalization: str | float = None, parameter_initializer: Callable[[], Callable[[Sequence[int], Any], Array]] | None = None, instructions: List[Tuple[int, int]] | None = None, num_indexed_weights: int | None = None, weights_per_channel: bool = False, force_irreps_out: bool = False, simplify_irreps_internally: bool = True, name: str | None = None)[source]
Bases: Module
Equivariant Linear Haiku Module.
- Parameters:
irreps_out (Irreps) – output representations, if allowed by Schur's lemma.
channel_out (optional int) – if specified, the last axis before the irreps is assumed to be the channel axis and is mixed with the irreps.
irreps_in (optional Irreps) – input representations. If not specified, the input representations are obtained when calling the module.
biases (bool) – whether to add a bias to the output.
path_normalization (str or float) – normalization of the paths, "element" or "path". 0/1 corresponds to a normalization where each element/path has an equal contribution to the forward pass.
gradient_normalization (str or float) – normalization of the gradients, "element" or "path". 0/1 corresponds to a normalization where each element/path has an equal contribution to the learning.
parameter_initializer (optional Callable) – function used to initialize the parameters.
num_indexed_weights (optional int) – number of indexed weights. See example below.
weights_per_channel (bool) – whether to have one set of weights per channel.
force_irreps_out (bool) – whether to force the output irreps to be the ones specified in irreps_out.
name (optional str) – name of the module.
Examples
Vanilla:
>>> import e3nn_jax as e3nn
>>> import jax
>>> import jax.numpy as jnp
>>> import haiku as hk
>>>
>>> @hk.without_apply_rng
... @hk.transform
... def linear(x):
...     return e3nn.haiku.Linear("0e + 1o + 2e")(x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> params = linear.init(jax.random.PRNGKey(0), x)
>>> y = linear.apply(params, x)
>>> y.irreps  # Note that the 2e is discarded
1x0e+1x1o
>>> y.shape
(4,)
External weights:
>>> @hk.without_apply_rng
... @hk.transform
... def linear(w, x):
...     return e3nn.haiku.Linear("0e + 1o")(w, x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> w = jnp.array([1., 2., 3., 4.])
>>> params = linear.init(jax.random.PRNGKey(0), w, x)
>>> y = linear.apply(params, w, x)
>>> y.shape
(4,)
External indices:
>>> @hk.without_apply_rng
... @hk.transform
... def linear(i, x):
...     return e3nn.haiku.Linear("0e + 1o", num_indexed_weights=4)(i, x)
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones((2, 5)))
>>> i = jnp.array([2, 3])
>>> params = linear.init(jax.random.PRNGKey(0), i, x)
>>> y = linear.apply(params, i, x)
>>> y.shape
(2, 4)
- class e3nn_jax.haiku.MultiLayerPerceptron(list_neurons: Sequence[int], act: Callable | None, *, gradient_normalization: str | float = None, output_activation: Callable | bool = True, with_bias: bool = False, name: str | None = None)[source]
Bases: Module
Just a simple MLP for scalars. No equivariance here.
- Parameters:
list_neurons (list of int) – number of neurons in each layer (excluding the input layer)
act (optional callable) – activation function
gradient_normalization (str or float) – normalization of the gradient
- "element": normalization done in the initialization variance of the weights (the default in PyTorch); gives the same importance to each neuron, so a layer with more neurons has a higher importance than a layer with fewer neurons
- "path" (default): normalization done explicitly in the forward pass; gives the same importance to every layer, independently of the number of neurons
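A minimal usage sketch, assuming the module can be called on a plain jnp.ndarray of scalar features (the layer sizes and activation below are arbitrary choices for illustration, not taken from the docstring):
>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>>
>>> @hk.without_apply_rng
... @hk.transform
... def mlp(x):
...     # two hidden layers of 64 neurons, output layer of 8
...     return e3nn.haiku.MultiLayerPerceptron([64, 64, 8], act=jax.nn.gelu)(x)
>>> x = jnp.ones((16, 32))  # batch of 16 scalar feature vectors
>>> params = mlp.init(jax.random.PRNGKey(0), x)
>>> y = mlp.apply(params, x)
>>> y.shape
(16, 8)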
- class e3nn_jax.haiku.BatchNorm(*, irreps: Irreps = None, eps: float = 0.0001, momentum: float = 0.1, affine: bool = True, reduce: str = 'mean', instance: bool = False, normalization: str = None)[source]
Bases: Module
Equivariant Batch Normalization.
It normalizes by the norm of the representations. Note that the norm is invariant only for orthonormal representations. Irreducible representations are orthonormal.
- Parameters:
irreps – Irreducible representations of the input and output (unchanged)
eps (float) – epsilon for numerical stability, has to be between 0 and 1. The field norm is transformed to (1 - eps) * norm + eps, leading to a slower convergence toward norm 1.
momentum – momentum for the moving average
affine – whether to include learnable weights and biases
reduce – reduce mode, either 'mean' or 'max'
instance – whether to use instance normalization instead of batch normalization
normalization – normalization mode, either 'norm' or 'component'
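A minimal usage sketch, assuming the running statistics are kept as Haiku state (hence hk.transform_with_state) and that the module is called directly on an IrrepsArray:
>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>>
>>> @hk.without_apply_rng
... @hk.transform_with_state
... def batch_norm(x):
...     return e3nn.haiku.BatchNorm(irreps="0e + 1o")(x)
>>> x = e3nn.IrrepsArray("0e + 1o", jnp.ones((16, 4)))  # batch of 16
>>> params, state = batch_norm.init(jax.random.PRNGKey(0), x)
>>> y, state = batch_norm.apply(params, state, x)
>>> y.shape
(16, 4)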
- class e3nn_jax.haiku.Dropout(p, *, irreps=None)[source]
Bases: Module
Equivariant Dropout.
\(A_{zai}\) is the input and \(B_{zai}\) is the output, where
- \(z\) is the batch index,
- \(a\) is any non-batch and non-irrep index,
- \(i\) is the irrep index; for instance if irreps="0e + 2x1e", then \(i=2\) selects the second vector.
\[B_{zai} = \frac{x_{zi}}{1-p} A_{zai}\]
where \(p\) is the dropout probability and \(x\) is a Bernoulli random variable with parameter \(1-p\).
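To make the formula concrete, here is a plain-JAX illustration (not the module's implementation): one Bernoulli draw per batch element and per irrep, expanded to the components of that irrep and scaled by 1/(1-p):
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> p = 0.5
>>> A = jnp.ones((8, 7))  # batch of 8, irreps "0e + 2x1e" (dimension 1 + 3 + 3 = 7)
>>> mask = jax.random.bernoulli(jax.random.PRNGKey(0), 1.0 - p, (8, 3))  # x_{zi}: one draw per (batch, irrep)
>>> x = jnp.repeat(mask.astype(A.dtype), jnp.array([1, 3, 3]), axis=-1)  # expand to the 7 components
>>> B = x / (1.0 - p) * A
>>> B.shape
(8, 7)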
- class e3nn_jax.haiku.SymmetricTensorProduct(orders: Tuple[int, ...], keep_irrep_out: Set[Irrep] | None = None, get_parameter: Callable[[str, Tuple[int, ...], Any], Array] | None = None)[source]
Bases: Module
Symmetric tensor product contraction with parameters
Equivalent to the following code executed in parallel on the channel dimension:
e3nn.haiku.Linear(irreps_out)(
    e3nn.concatenate([
        x,
        tensor_product(x, x),  # additionally keeping only the symmetric terms
        tensor_product(tensor_product(x, x), x),
        ...
    ])
)
Each channel has its own parameters.
- Parameters: