Neural Network Functions
- e3nn_jax.gate(input: IrrepsArray, even_act: Callable[[float], float] = jax.nn.gelu, odd_act: Callable[[float], float] = soft_odd, even_gate_act: Callable[[float], float] = jax.nn.sigmoid, odd_gate_act: Callable[[float], float] = jax.nn.tanh, normalize_act: bool = True) → IrrepsArray [source]
Gate activation function.
The input is split into scalars that are activated separately, scalars that are used as gates, and non-scalars that are multiplied by the gates.
Assumption: the gate scalars sit to the right of the other scalars, i.e. they are the last scalar irreps in the input.
- Parameters:
input (IrrepsArray) – Input data.
even_act (Callable[[float], float]) – Activation function for even scalars. Default: jax.nn.gelu.
odd_act (Callable[[float], float]) – Activation function for odd scalars. Default: \((1 - \exp(-x^2)) x\).
even_gate_act (Callable[[float], float]) – Activation function for even gate scalars. Default: jax.nn.sigmoid.
odd_gate_act (Callable[[float], float]) – Activation function for odd gate scalars. Default: jax.nn.tanh.
normalize_act (bool) – If True, the activation functions are normalized using e3nn_jax.normalize_function. Default: True.
- Returns:
Output data.
- Return type:
IrrepsArray
Examples
The last 3 scalars are used as gates.
>>> gate("15x0e + 2x1e + 1x2e") 12x0e+2x1e+1x2e
Odd scalars used as gates change the parity of the gated quantities:
>>> gate("12x0e + 3x0o + 2x1e + 1x2e") 12x0e+2x1o+1x2o
Without anything to gate, all the scalars are activated:
>>> gate("12x0e + 3x0o") 12x0e+3x0o
- e3nn_jax.scalar_activation(input: IrrepsArray, acts: List[Callable[[float], float] | None] = None, *, even_act: Callable[[float], float] = jax.nn.gelu, odd_act: Callable[[float], float] = soft_odd, normalize_act: bool = True) → IrrepsArray [source]
Apply activation functions to the scalars of an IrrepsArray. The activation functions are normalized by default.
- Parameters:
input (IrrepsArray) – Input array.
acts (optional, list of functions) – List of activation functions, one for each chunk of the input.
even_act (Callable[[float], float]) – Activation function for even scalars. Default: jax.nn.gelu.
odd_act (Callable[[float], float]) – Activation function for odd scalars. Default: \((1 - \exp(-x^2)) x\).
normalize_act (bool) – If True, the activation functions are normalized using e3nn_jax.normalize_function. Default: True.
- Returns:
Output array.
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("0e + 0o + 1o", jnp.array([1.0, 0.0, 1.0, 1.0, 2.0])) >>> scalar_activation(x, [jnp.exp, jnp.sin, None]) 1x0e+1x0o+1x1o [1.0010498 0. 1. 1. 2. ]
>>> scalar_activation(x, [jnp.exp, jnp.cos, None]) 1x0e+1x0e+1x1o [1.0010498 1.3272501 1. 1. 2. ]
Note
The parity of the output depends on the parity of the activation function.
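The value 1.0010498 above, rather than exp(1.0) ≈ 2.7182817, is the effect of normalize_act: each activation is rescaled so that its output has unit second moment when the input is a standard normal. A sketch of the same call with normalization disabled (x as defined above; the exact repr may vary across JAX versions):
>>> y = scalar_activation(x, [jnp.exp, jnp.sin, None], normalize_act=False)
>>> y.array[0]
Array(2.7182817, dtype=float32)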
- e3nn_jax.norm_activation(input: IrrepsArray, acts: List[Callable[[float], float] | None], *, normalization: str = 'component') → IrrepsArray [source]
Apply activation functions to the norms of the vectors of an IrrepsArray.
- Parameters:
input (IrrepsArray) – Input array.
acts (list of functions) – List of activation functions, one for each chunk of the input.
normalization (str) – "component" or "norm". If "component", the norm is divided by the square root of the number of components. Default: "component".
- Returns:
Output array.
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 1.0, 1.0, 2.0])) >>> norm_activation(x, [None, jnp.tanh]) 1x0e+1x1o [1. 0.8883856 0.8883856 1.7767712]