Utils
- e3nn_jax.utils.vmap(fun: Callable[[...], Any], in_axes: int | None | Sequence[Any] = 0, out_axes: Any = 0)
Wrapper around jax.vmap() that handles e3nn_jax.IrrepsArray objects.
- Parameters:
fun – Function to be mapped.
in_axes – Specifies which axes of the input arguments to map over. See jax.vmap() for details.
out_axes – Specifies where the mapped axis should appear in the outputs. See jax.vmap() for details.
- Returns:
Batched/vectorized version of fun.
Example
>>> import jax.numpy as jnp
>>> x = e3nn.from_chunks("0e + 0e", [jnp.ones((100, 1, 1)), None], (100,))
>>> x.zero_flags
(False, True)
>>> y = vmap(e3nn.scalar_activation)(x)
>>> y.zero_flags
(False, True)
>>> assert y.chunks[1] is None
- e3nn_jax.utils.equivariance_test(fun: Callable[[IrrepsArray], IrrepsArray], rng_key: Array, *args)
Test equivariance of a function.
- Parameters:
fun – function to test
rng_key – random number generator key
*args – arguments to pass to fun. Each can be an IrrepsArray or an Irreps; an Irreps argument is replaced by a random IrrepsArray with those irreps.
- Returns:
The outputs fun(R args) and R fun(args), where R is a random rotation and inversion. For an equivariant fun the two agree.
- Return type:
out1, out2
Example
>>> fun = e3nn.norm
>>> rng = jax.random.PRNGKey(0)
>>> x = e3nn.IrrepsArray("1e", jnp.array([0.0, 4.0, 3.0]))
>>> equivariance_test(fun, rng, x)
(1x0e [5.], 1x0e [5.])
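To make the comparison concrete, here is a conceptual sketch in plain JAX of the two quantities equivariance_test returns, restricted to functions of ordinary 3D vectors. It is not the library's implementation: e3nn_jax acts with the representation of each irrep rather than a single 3x3 matrix, and the helpers random_orthogonal, equivariance_outputs, scale_by_norm and add_offset are hypothetical names used only here.

```python
import jax
import jax.numpy as jnp

def random_orthogonal(key):
    # QR of a Gaussian matrix gives an orthogonal Q; det(Q) may be +1 or -1,
    # so Q is a rotation, possibly composed with inversion.
    q, _ = jnp.linalg.qr(jax.random.normal(key, (3, 3)))
    return q

def equivariance_outputs(f, key, x):
    # Rows of x are 3D vectors; x @ q.T applies q to each row.
    q = random_orthogonal(key)
    return f(x @ q.T), f(x) @ q.T

def scale_by_norm(x):
    # Equivariant: the norm is unchanged by any orthogonal q.
    return x * jnp.linalg.norm(x, axis=-1, keepdims=True)

def add_offset(x):
    # Not equivariant: the fixed offset is not transformed along with x.
    return x + jnp.array([1.0, 0.0, 0.0])

key = jax.random.PRNGKey(0)
x = jax.random.normal(jax.random.PRNGKey(1), (4, 3))

out1, out2 = equivariance_outputs(scale_by_norm, key, x)
equivariant_ok = jnp.allclose(out1, out2, atol=1e-5)
print(equivariant_ok)  # True

out1, out2 = equivariance_outputs(add_offset, key, x)
offset_ok = jnp.allclose(out1, out2, atol=1e-5)
print(offset_ok)  # False
```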
- e3nn_jax.utils.assert_equivariant(fun: Callable[[IrrepsArray], IrrepsArray], rng_key: Array, *args, atol: float = 1e-06, rtol: float = 1e-06)
Assert that a function is equivariant.
- Parameters:
fun – function to test
rng_key – random number generator key
*args – arguments to pass to fun. Each can be an IrrepsArray or an Irreps; an Irreps argument is replaced by a random IrrepsArray with those irreps.
atol – absolute tolerance
rtol – relative tolerance
Examples
>>> fun = e3nn.norm
>>> rng = jax.random.PRNGKey(0)
>>> x = e3nn.IrrepsArray("1e", jnp.array([0.0, 4.0, 3.0]))
>>> assert_equivariant(fun, rng, x)
We can also pass the irreps of the inputs instead of the inputs themselves:
>>> assert_equivariant(fun, rng, "1e")
- e3nn_jax.utils.assert_output_dtype_matches_input_dtype(fun: Callable, *args, **kwargs)
Check that the dtype of fun(*args, **kwargs) matches the dtype of the inputs (*args, **kwargs).
- Parameters:
fun – function to test
*args – arguments to pass to fun
**kwargs – keyword arguments to pass to fun
- Raises:
AssertionError – if the dtype of fun(*args, **kwargs) does not match the dtype of the inputs (*args, **kwargs).
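The idea behind this check can be sketched in plain JAX as a comparison of floating-point dtypes across the leaves of the input and output pytrees. output_dtypes_match_inputs is a hypothetical helper written for illustration only; it returns a boolean instead of raising, and the library function may perform steps not shown here.

```python
import jax
import jax.numpy as jnp

def output_dtypes_match_inputs(fun, *args, **kwargs):
    # Hypothetical helper, not the e3nn_jax implementation.
    def float_dtypes(tree):
        # Collect the dtypes of all floating-point leaves of a pytree.
        dtypes = set()
        for leaf in jax.tree_util.tree_leaves(tree):
            dtype = jnp.asarray(leaf).dtype
            if jnp.issubdtype(dtype, jnp.floating):
                dtypes.add(dtype)
        return dtypes

    in_dtypes = float_dtypes((args, kwargs))
    out_dtypes = float_dtypes(fun(*args, **kwargs))
    # Every floating-point output dtype must already occur among the inputs.
    return out_dtypes <= in_dtypes

x = jnp.ones((3,), dtype=jnp.float32)

# The Python scalar 2.0 is weakly typed, so the result stays float32.
keeps_dtype = output_dtypes_match_inputs(lambda a: 2.0 * a, x)
print(keeps_dtype)  # True

# An explicit cast to float16 changes the dtype.
cast_dtype = output_dtypes_match_inputs(lambda a: a.astype(jnp.float16), x)
print(cast_dtype)  # False
```

A check of this kind catches silent up- or down-casting, for example a hard-coded float32 constant inside a function that should follow the precision of its inputs.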