Neural Network Functions

e3nn_jax.gate(input: IrrepsArray, even_act: Callable[[float], float] = jax.nn.gelu, odd_act: Callable[[float], float] = soft_odd, even_gate_act: Callable[[float], float] = jax.nn.sigmoid, odd_gate_act: Callable[[float], float] = jax.numpy.tanh, normalize_act: bool = True) → IrrepsArray

Gate activation function.

The input is split into three parts: scalars that are passed through an activation function, scalars that are used as gates (one gate per non-scalar irrep), and non-scalars that are multiplied by the gates.

List of assumptions:

  • The gate scalars come after (to the right of) the scalars that are activated.
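The split described above can be sketched in plain NumPy. This is a hypothetical, unnormalized re-implementation for illustration only: gate_sketch, gelu and sigmoid are defined here, not imported from e3nn_jax, and the sketch assumes even scalars only, with each copy of a non-scalar irrep stored as its own 1D array. Because every gate is a rotation-invariant scalar, multiplying an irrep by it rescales the irrep without changing its direction.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU (the form jax.nn.gelu uses by default)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gate_sketch(scalars, nonscalars, act=gelu, gate_act=sigmoid):
    """Hypothetical, unnormalized gate for even scalars only.

    scalars:    1D array of scalar values, gates last (the assumption above)
    nonscalars: list of 1D arrays, one per copy of a non-scalar irrep
    """
    n_gates = len(nonscalars)           # one gate per non-scalar irrep copy
    n_act = scalars.shape[0] - n_gates  # the remaining scalars are activated
    if n_act < 0:
        raise ValueError("not enough scalars to gate every non-scalar irrep")
    activated = act(scalars[:n_act])    # left part: ordinary activation
    gates = gate_act(scalars[n_act:])   # right part: gate values
    # Scale each irrep copy by its own gate.
    gated = [g * v for g, v in zip(gates, nonscalars)]
    return activated, gated


# Mirrors "15x0e + 2x1e + 1x2e": 15 scalars, two l=1 copies (dim 3), one l=2 copy (dim 5)
rng = np.random.default_rng(0)
scalars = rng.normal(size=15)
nonscalars = [rng.normal(size=3), rng.normal(size=3), rng.normal(size=5)]
activated, gated = gate_sketch(scalars, nonscalars)
print(activated.shape, [v.shape for v in gated])  # (12,) [(3,), (3,), (5,)]
```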

Parameters:
  • input (IrrepsArray) – Input data.

  • even_act (Callable[[float], float]) – Activation function for even scalars. Default: jax.nn.gelu().

  • odd_act (Callable[[float], float]) – Activation function for odd scalars. Default: \((1 - \exp(-x^2)) x\).

  • even_gate_act (Callable[[float], float]) – Activation function for even gate scalars. Default: jax.nn.sigmoid().

  • odd_gate_act (Callable[[float], float]) – Activation function for odd gate scalars. Default: jax.numpy.tanh().

  • normalize_act (bool) – If True, the activation functions are normalized using e3nn_jax.normalize_function.

Returns:

Output data.

Return type:

IrrepsArray
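The default odd_act, \((1 - \exp(-x^2)) x\), is an odd function, so an odd scalar passed through it stays odd. A quick numerical check, using a hypothetical NumPy helper named soft_odd that mirrors the formula above (it does not import the library function):

```python
import numpy as np


def soft_odd(x):
    # (1 - exp(-x^2)) * x, written with expm1 for accuracy near 0
    return -np.expm1(-x * x) * x


x = np.linspace(-3.0, 3.0, 13)
print(np.allclose(soft_odd(-x), -soft_odd(x)))  # True: odd, so an odd scalar stays odd
print(soft_odd(1e-3))   # about 1e-9: behaves like x**3 near zero
print(soft_odd(10.0))   # 10.0: approaches the identity for large |x|
```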

Examples

The last 3 scalars are used as gates, one for each of the 3 non-scalar irreps (2x1e + 1x2e); the remaining 12 scalars are activated.

>>> gate("15x0e + 2x1e + 1x2e")
12x0e+2x1e+1x2e

Odd scalars used as gates change the parity of the gated quantities:

>>> gate("12x0e + 3x0o + 2x1e + 1x2e")
12x0e+2x1o+1x2o

Without anything to gate, all the scalars are activated:

>>> gate("12x0e + 3x0o")
12x0e+3x0o
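The output irreps in these examples follow from simple bookkeeping, sketched below in pure Python. gate_output_irreps and its parsing helpers are hypothetical (not part of e3nn_jax), and the sketch assumes the default activations: even_act and odd_act keep the parity of the activated scalars, and since sigmoid and tanh keep the parity of the gates, each gated irrep's parity is multiplied by its gate's parity.

```python
def parse_irreps(s):
    """Parse a string like "12x0e + 3x0o + 2x1e" into a list of (mul, l, parity)."""
    out = []
    for term in s.replace(" ", "").split("+"):
        mul, ir = term.split("x") if "x" in term else ("1", term)
        out.append((int(mul), int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return out


def format_irreps(irreps):
    return "+".join(f"{mul}x{l}{'e' if p == 1 else 'o'}" for mul, l, p in irreps)


def merge(copies):
    """Group consecutive equal (l, parity) copies back into (mul, l, parity) chunks."""
    out = []
    for l, p in copies:
        if out and out[-1][1:] == (l, p):
            out[-1] = (out[-1][0] + 1, l, p)
        else:
            out.append((1, l, p))
    return out


def gate_output_irreps(s):
    """Hypothetical bookkeeping of the gate's output irreps (not the library's code)."""
    irreps = parse_irreps(s)
    scalars = [p for mul, l, p in irreps if l == 0 for _ in range(mul)]
    nonscalars = [(l, p) for mul, l, p in irreps if l > 0 for _ in range(mul)]
    n_act = len(scalars) - len(nonscalars)  # the rightmost scalars become gates
    if n_act < 0:
        raise ValueError("not enough scalars to gate every non-scalar irrep")
    # Activated scalars keep their parity (even act on even, odd act on odd scalars).
    out = [(0, p) for p in scalars[:n_act]]
    # A gated irrep's parity is multiplied by the parity of its gate.
    out += [(l, p * pg) for (l, p), pg in zip(nonscalars, scalars[n_act:])]
    return format_irreps(merge(out))


print(gate_output_irreps("15x0e + 2x1e + 1x2e"))        # 12x0e+2x1e+1x2e
print(gate_output_irreps("12x0e + 3x0o + 2x1e + 1x2e"))  # 12x0e+2x1o+1x2o
print(gate_output_irreps("12x0e + 3x0o"))                # 12x0e+3x0o
```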

e3nn_jax.scalar_activation(input: IrrepsArray, acts: List[Callable[[float], float] | None] = None, *, even_act: Callable[[float], float] = jax.nn.gelu, odd_act: Callable[[float], float] = soft_odd, normalize_act: bool = True) → IrrepsArray

Apply activation functions to the scalars of an IrrepsArray. The activation functions are by default normalized.

Parameters:
  • input (IrrepsArray) – input array

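  • acts (optional, list of functions) – list of activation functions, one for each chunk of the input. A None entry leaves its chunk unchanged and should be used for non-scalar chunks. If acts is None, even_act and odd_act are applied to the even and odd scalars, respectively.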

  • even_act (Callable[[float], float]) – Activation function for even scalars. Default: jax.nn.gelu().

  • odd_act (Callable[[float], float]) – Activation function for odd scalars. Default: \((1 - \exp(-x^2)) x\).

  • normalize_act (bool) – if True, normalize the activation functions using normalize_function

Returns:

output array

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("0e + 0o + 1o", jnp.array([1.0, 0.0, 1.0, 1.0, 2.0]))
>>> scalar_activation(x, [jnp.exp, jnp.sin, None])
1x0e+1x0o+1x1o [1.0010498 0.        1.        1.        2.       ]
>>> scalar_activation(x, [jnp.exp, jnp.cos, None])
1x0e+1x0e+1x1o [1.0010498 1.3272501 1.        1.        2.       ]
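These numbers can be reproduced by assuming that normalization divides an activation \(f\) by \(c = \sqrt{\mathbb{E}_{z \sim \mathcal{N}(0,1)}[f(z)^2]}\), so that the normalized activation has unit second moment on standard-normal input. The sketch below estimates \(c\) with Gauss-Hermite quadrature; normalize_sketch is a hypothetical stand-in for normalize_function, and the quadrature is an assumption rather than the library's method. For exp, \(c = e\) exactly, so the first scalar (1.0) maps to 1.0; the doctest's 1.0010498 suggests the library estimates the constant with a less precise numerical method.

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def second_moment_constant(f, n=60):
    """sqrt(E[f(z)^2]) for z ~ N(0, 1), via Gauss-Hermite quadrature."""
    z, w = hermegauss(n)          # nodes/weights for the weight exp(-z^2 / 2)
    w = w / np.sqrt(2.0 * np.pi)  # weights now sum to 1 (standard normal expectation)
    return np.sqrt(np.sum(w * f(z) ** 2))


def normalize_sketch(f):
    """Hypothetical stand-in for normalize_function: divide f by its constant."""
    c = second_moment_constant(f)
    return lambda x: f(x) / c


norm_exp = normalize_sketch(np.exp)
norm_cos = normalize_sketch(np.cos)
print(norm_exp(1.0))  # ~1.0, since E[exp(2z)] = e^2 gives c = e
print(norm_cos(0.0))  # ~1.32725, close to the doctest's 1.3272501
```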

Note

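The parity of an output scalar depends on the parity of the activation function. In the examples above, the odd input scalar (0o) stays odd under jnp.sin, an odd function, but becomes even (0e) under jnp.cos, an even function.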
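The parity rule can be sketched as follows: an even input scalar stays even under any activation, while an odd input scalar becomes even under an even function and stays odd under an odd function. The helpers below are hypothetical; they test evenness numerically on a grid, and raising an error when the function is neither even nor odd is this sketch's own choice.

```python
import numpy as np


def activation_parity(f, n=101):
    """Return +1 if f is even, -1 if f is odd, 0 if neither (checked on a grid)."""
    x = np.linspace(-3.0, 3.0, n)
    if np.allclose(f(-x), f(x)):
        return 1
    if np.allclose(f(-x), -f(x)):
        return -1
    return 0


def output_scalar_parity(input_parity, f):
    """Hypothetical parity rule for one scalar chunk (+1 = 0e, -1 = 0o)."""
    if input_parity == 1:
        return 1  # any function of an even scalar is even
    p = activation_parity(f)
    if p == 0:
        raise ValueError("an odd scalar needs an even or odd activation")
    return p  # even f -> 0e, odd f -> 0o


print(output_scalar_parity(-1, np.sin))  # -1, i.e. 0o
print(output_scalar_parity(-1, np.cos))  # 1, i.e. 0e
print(output_scalar_parity(1, np.exp))   # 1, i.e. 0e
```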

e3nn_jax.norm_activation(input: IrrepsArray, acts: List[Callable[[float], float] | None], *, normalization: str = 'component') → IrrepsArray

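Apply activation functions to the norms of the vectors (irreps) of an IrrepsArray: each one is multiplied by the activation function evaluated at its norm, which leaves its direction unchanged.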

Parameters:
  • input (IrrepsArray) – input array

  • acts (list of functions) – list of activation functions, one for each chunk of the input. A None entry leaves its chunk unchanged.

  • normalization (str) – “component” or “norm”. If “component”, the norm is divided by the square root of the number of components (2l+1) before the activation is applied.

Returns:

output array

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 1.0, 1.0, 2.0]))
>>> norm_activation(x, [None, jnp.tanh])
1x0e+1x1o [1.        0.8883856 0.8883856 1.7767712]
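The doctest can be checked by hand: with “component” normalization, the vector (1, 1, 2) has norm \(\sqrt{6}/\sqrt{3} = \sqrt{2}\), and \(\tanh(\sqrt{2}) \approx 0.8883856\) multiplies every component. A hypothetical NumPy sketch of that computation, assuming each chunk is stored as an array of shape (multiplicity, 2l+1):

```python
import numpy as np


def norm_activation_sketch(chunks, acts, normalization="component"):
    """Hypothetical re-implementation: each chunk has shape (mul, 2l+1)."""
    out = []
    for x, act in zip(chunks, acts):
        if act is None:
            out.append(x)  # chunk passes through unchanged
            continue
        n2 = np.sum(x**2, axis=-1, keepdims=True)  # squared norm of each irrep copy
        if normalization == "component":
            n2 = n2 / x.shape[-1]  # divide by the number of components
        n = np.sqrt(n2)
        out.append(x * act(n))  # rescale each irrep by act(norm); direction is kept
    return out


# Same input as the doctest: "0e + 1o" with data [1, 1, 1, 2]
chunks = [np.array([[1.0]]), np.array([[1.0, 1.0, 2.0]])]
scalar, vector = norm_activation_sketch(chunks, [None, np.tanh])
print(vector)  # approx [[0.8884 0.8884 1.7768]]
```

With normalization="norm", the same vector would instead be rescaled by \(\tanh(\sqrt{6})\).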