Equivariant Operations

e3nn_jax.grad(fun: Callable[[IrrepsArray], IrrepsArray], argnums: int = 0, has_aux: bool = False, regroup_output: bool = True) → IrrepsArray

Take the gradient of an equivariant function and reduce it into irreps.

Parameters:
  • fun – An equivariant function that takes IrrepsArray arguments and returns an IrrepsArray.

  • argnums – Index of the positional argument of fun to differentiate with respect to.

  • has_aux – If True, fun is expected to return a pair (output, auxiliary data), and only the output is differentiated.

  • regroup_output (bool, optional) – Regroup the outputs into irreps. Defaults to True.

Returns:

A function that computes the gradient of fun. Like fun, it is equivariant.
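The equivariance of the gradient can be checked numerically. The sketch below uses plain jax.grad on raw 3-vectors rather than e3nn.grad, and the helper rotation_matrix and the function f are hypothetical illustrations. For a rotation-invariant scalar function f, differentiating f(Rx) = f(x) gives ∇f(Rx) = R ∇f(x), so the gradient transforms like a vector under rotations.

```python
import jax
import jax.numpy as jnp


def rotation_matrix(alpha, beta):
    """Hypothetical helper: rotate about z by alpha, then about x by beta."""
    ca, sa = jnp.cos(alpha), jnp.sin(alpha)
    cb, sb = jnp.cos(beta), jnp.sin(beta)
    rz = jnp.array([[ca, -sa, 0.0], [sa, ca, 0.0], [0.0, 0.0, 1.0]])
    rx = jnp.array([[1.0, 0.0, 0.0], [0.0, cb, -sb], [0.0, sb, cb]])
    return rx @ rz


def f(x):
    # Depends on x only through its squared norm, so f(R x) == f(x) for any rotation R.
    return jnp.exp(-0.1 * jnp.dot(x, x))


grad_f = jax.grad(f)

x = jnp.array([1.0, 2.0, 3.0])
R = rotation_matrix(0.3, 1.1)

lhs = grad_f(R @ x)  # gradient evaluated at the rotated input
rhs = R @ grad_f(x)  # gradient at the original input, then rotated
print(jnp.allclose(lhs, rhs, atol=1e-5))  # True when the gradient is equivariant
```

e3nn.grad goes one step further than this check: it also packages the resulting Jacobian as an IrrepsArray with the appropriate irreps.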

Examples

>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> jnp.set_printoptions(precision=3, suppress=True)
>>> f = e3nn.grad(lambda x: 0.5 * e3nn.norm(x, squared=True))
>>> x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2, 3]))
>>> f(x)
1x1o [1. 2. 3.]