Utils

e3nn_jax.utils.vmap(fun: Callable[[...], Any], in_axes: int | None | Sequence[Any] = 0, out_axes: Any = 0)

Wrapper around jax.vmap() that handles e3nn_jax.IrrepsArray objects.

Parameters:
  • fun – Function to be mapped.

  • in_axes – Specifies which axes to map over for the input arguments. See jax.vmap() for details.

  • out_axes – Specifies where the mapped axis should appear in the outputs. See jax.vmap() for details.

Returns:

Batched/vectorized version of fun.

Example

>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> from e3nn_jax.utils import vmap
>>> x = e3nn.from_chunks("0e + 0e", [jnp.ones((100, 1, 1)), None], (100,))
>>> x.zero_flags
(False, True)
>>> y = vmap(e3nn.scalar_activation)(x)
>>> y.zero_flags
(False, True)
>>> assert y.chunks[1] is None
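The in_axes and out_axes parameters are described above by reference to jax.vmap(). As a minimal sketch of those semantics, the block below uses plain jax.vmap() on ordinary arrays rather than on IrrepsArray objects. The functions scaled_dot and scale are hypothetical names introduced only for illustration.

```python
import jax
import jax.numpy as jnp


def scaled_dot(x, w):
    # Hypothetical per-example function: two vectors of shape (3,) -> scalar.
    return jnp.dot(x, w)


def scale(x, w):
    # Hypothetical per-example function: elementwise product, shape (3,).
    return x * w


xs = jnp.ones((5, 3))  # batch of 5 vectors, batch axis is 0
w = jnp.arange(3.0)    # shared weights, not batched

# in_axes=(0, None): map over axis 0 of xs, pass w unchanged to every call.
dots = jax.vmap(scaled_dot, in_axes=(0, None))(xs, w)

# out_axes=1: place the mapped axis at position 1 of the output.
scaled = jax.vmap(scale, in_axes=(0, None), out_axes=1)(xs, w)

print(dots.shape, scaled.shape)
```

With out_axes=1 the batch dimension ends up last, so the result has shape (3, 5) instead of (5, 3).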

e3nn_jax.utils.equivariance_test(fun: Callable[[IrrepsArray], IrrepsArray], rng_key: Array, *args)

Test equivariance of a function.

Parameters:
  • fun – function to test

  • rng_key – random number generator key

  • *args – arguments to pass to fun; each can be an IrrepsArray or an Irreps. An Irreps argument is replaced by a random IrrepsArray with those irreps.

Returns:

The outputs of fun(R args) and R fun(args), where R is a random rotation combined with an inversion.

Return type:

out1, out2

Example

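>>> import jax
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> from e3nn_jax.utils import equivariance_test
>>> fun = e3nn.norm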
>>> rng = jax.random.PRNGKey(0)
>>> x = e3nn.IrrepsArray("1e", jnp.array([0.0, 4.0, 3.0]))
>>> equivariance_test(fun, rng, x)
(1x0e [5.], 1x0e [5.])
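To make the comparison of fun(R args) with R fun(args) concrete, here is a rough sketch in plain JAX on 3-vectors. It is not the library's implementation: random_rotation and equivariance_pair are hypothetical helpers, the caller states explicitly how R acts on the output, and the inversion is left out for brevity.

```python
import jax
import jax.numpy as jnp


def random_rotation(key):
    # QR of a Gaussian matrix yields an orthogonal Q; multiplying by
    # sign(det Q) makes the determinant +1 in three dimensions.
    q, _ = jnp.linalg.qr(jax.random.normal(key, (3, 3)))
    return q * jnp.sign(jnp.linalg.det(q))


def equivariance_pair(fun, key, x, act_on_output):
    # Hypothetical helper: returns (fun(R x), R fun(x)) for a random R.
    R = random_rotation(key)
    return fun(R @ x), act_on_output(R, fun(x))


key = jax.random.PRNGKey(0)
x = jnp.array([0.0, 4.0, 3.0])

# The norm is a scalar, so R acts trivially on the output.
n1, n2 = equivariance_pair(jnp.linalg.norm, key, x, lambda R, y: y)

# Adding a fixed vector singles out a direction and breaks equivariance.
shift = lambda v: v + jnp.array([1.0, 0.0, 0.0])
s1, s2 = equivariance_pair(shift, key, x, lambda R, y: R @ y)

norm_ok = bool(jnp.allclose(n1, n2, atol=1e-4))
shift_ok = bool(jnp.allclose(s1, s2, atol=1e-4))
print(norm_ok, shift_ok)
```

The two returned values agree for the norm and disagree for the shifted function, which is the kind of mismatch an equivariance test is meant to expose.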

e3nn_jax.utils.assert_equivariant(fun: Callable[[IrrepsArray], IrrepsArray], rng_key: Array, *args, atol: float = 1e-06, rtol: float = 1e-06)

Assert that a function is equivariant.

Parameters:
  • fun – function to test

  • rng_key – random number generator key

  • *args – arguments to pass to fun; each can be an IrrepsArray or an Irreps. An Irreps argument is replaced by a random IrrepsArray with those irreps.

  • atol – absolute tolerance

  • rtol – relative tolerance

Examples

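>>> import jax
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> from e3nn_jax.utils import assert_equivariant
>>> fun = e3nn.norm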
>>> rng = jax.random.PRNGKey(0)
>>> x = e3nn.IrrepsArray("1e", jnp.array([0.0, 4.0, 3.0]))
>>> assert_equivariant(fun, rng, x)

We can also pass the irreps of the inputs instead of the inputs themselves:

>>> assert_equivariant(fun, rng, "1e")
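This page does not spell out how atol and rtol are combined. Assuming they follow the numpy.testing.assert_allclose convention (an assumption, not something confirmed here), two values pass when their difference is at most atol + rtol * |desired|. The sketch below shows what the default tolerances would permit for a value near 5 under that assumption.

```python
import numpy as np

atol, rtol = 1e-6, 1e-6
desired = np.array([5.0])

# Allowed deviation under the assumed convention:
# atol + rtol * |desired| = 1e-6 + 5e-6 = 6e-6.
np.testing.assert_allclose(np.array([5.000004]), desired, atol=atol, rtol=rtol)

# A deviation of 1e-5 exceeds that bound and raises AssertionError.
try:
    np.testing.assert_allclose(np.array([5.00001]), desired, atol=atol, rtol=rtol)
    failed = False
except AssertionError:
    failed = True

print(failed)
```

If that convention applies, float32 round-off on larger outputs can exceed the default tolerances, and looser atol or rtol values may be needed.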

e3nn_jax.utils.assert_output_dtype_matches_input_dtype(fun: Callable, *args, **kwargs)

Checks that the dtype of fun(*args, **kwargs) matches that of the input (*args, **kwargs).

Parameters:
  • fun – function to test

  • *args – arguments to pass to fun

  • **kwargs – keyword arguments to pass to fun

Raises:

AssertionError – if the dtype of fun(*args, **kwargs) does not match that of the input (*args, **kwargs).
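The property being checked can be illustrated with plain JAX arrays. The sketch below is not the library's implementation: dtype_is_preserved is a hypothetical helper, and it handles only a single array argument. It contrasts a function that keeps the input dtype with one that silently upcasts.

```python
import jax.numpy as jnp


def dtype_is_preserved(fun, x):
    # Hypothetical helper: True if fun(x) has the same dtype as x.
    return fun(x).dtype == x.dtype


x16 = jnp.ones(3, dtype=jnp.float16)

# A Python scalar is weakly typed in JAX, so the result stays float16.
keeps = lambda x: x * 2.0

# An explicit cast changes the output dtype to float32.
promotes = lambda x: x.astype(jnp.float32) * 2.0

kept = bool(dtype_is_preserved(keeps, x16))
promoted = bool(dtype_is_preserved(promotes, x16))
print(kept, promoted)
```

A function like promotes is the kind of case the assertion is meant to catch.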