Equivariant Operations
- e3nn_jax.grad(fun: Callable[[IrrepsArray], IrrepsArray], argnums: int = 0, has_aux: bool = False, regroup_output: bool = True) → IrrepsArray
Take the gradient of an equivariant function and reduce it into irreps.
- Parameters:
fun – An equivariant function.
argnums – The argument number to differentiate with respect to.
has_aux – If True, the function returns a tuple of the output and an auxiliary value.
regroup_output (bool, optional) – Regroup the outputs into irreps. Defaults to True.
- Returns:
The gradient of the function. Also an equivariant function.
Examples
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> jnp.set_printoptions(precision=3, suppress=True)
>>> f = e3nn.grad(lambda x: 0.5 * e3nn.norm(x, squared=True))
>>> x = e3nn.IrrepsArray("1o", jnp.array([1.0, 2, 3]))
>>> f(x)
1x1o [1. 2. 3.]
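The claim that the gradient is "also an equivariant function" can be checked numerically. The following is a minimal sketch using plain `jax.grad` on raw arrays rather than `e3nn_jax.grad` on an `IrrepsArray`. The helper names `energy` and `rotation_z` are hypothetical and chosen only for illustration. It verifies that, for the rotation-invariant scalar 0.5 * |x|^2, rotating the input and then taking the gradient gives the same vector as taking the gradient and then rotating it.

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Rotation-invariant scalar: 0.5 * |x|^2 (same function as the doctest above,
    # written on a raw array instead of an IrrepsArray).
    return 0.5 * jnp.sum(x**2)


def rotation_z(angle):
    # Hypothetical helper: 3x3 rotation matrix about the z axis.
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


grad_energy = jax.grad(energy)

x = jnp.array([1.0, 2.0, 3.0])
R = rotation_z(0.7)

# Equivariance: grad f(R x) should equal R grad f(x).
rotate_then_grad = grad_energy(R @ x)
grad_then_rotate = R @ grad_energy(x)
is_equivariant = bool(jnp.allclose(rotate_then_grad, grad_then_rotate, atol=1e-5))
```

Under these assumptions `is_equivariant` should be true, which is the raw-array counterpart of the gradient being returned as a `1o` vector in the doctest.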