Extra Stuff
- e3nn_jax.sus(x)
Smooth Unit Step function.
Maps -inf -> 0, 0 -> 0, 2 -> exp(-1/2) ≈ 0.61, +inf -> 1.
\[\begin{split}\text{sus}(x) = \begin{cases} 0, & \text{if } x \leq 0 \\ \exp(-1/x), & \text{if } x > 0 \\ \end{cases}\end{split}\]
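A minimal sketch of the formula in JAX, written for illustration and not taken from the library source. `sus_sketch` is a hypothetical name. It evaluates `exp(-1/x)` only on a guarded argument (the "double `where`" pattern), so that both the value and its gradient stay finite for x ≤ 0 instead of producing `nan` from `-1/0`.

```python
import jax
import jax.numpy as jnp


def sus_sketch(x):
    """Smooth unit step: 0 for x <= 0, exp(-1/x) for x > 0 (hypothetical helper)."""
    # Replace non-positive inputs by a harmless value before dividing,
    # so that neither the forward pass nor the gradient sees -1/0
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.exp(-1.0 / safe_x), 0.0)


# Derivative of the sketch; finite everywhere, including x <= 0
sus_grad = jax.grad(sus_sketch)
```

For x > 0 the derivative is exp(-1/x) / x², which tends to 0 as x -> 0+, so the step is smooth at the origin.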
- e3nn_jax.normalize_function(phi: Callable[[float], float]) -> Callable[[float], float]
Normalize a function: return \(\psi(x)=\phi(x)/c\), where \(c\) is the normalization constant such that
\[\int_{-\infty}^{\infty} \psi(x)^2 \frac{e^{-x^2/2}}{\sqrt{2\pi}} dx = 1\]
Equivalently, \(c = \sqrt{\mathbb{E}[\phi(X)^2]}\) with \(X \sim \mathcal{N}(0, 1)\).
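Because the condition fixes \(c^2\) as the Gaussian-weighted integral of \(\phi^2\), \(c\) can be computed numerically. The sketch below uses Gauss-Hermite quadrature (`numpy.polynomial.hermite_e.hermegauss`, whose weight function is \(e^{-x^2/2}\)). This is an illustrative choice and not necessarily how `e3nn_jax.normalize_function` computes \(c\); `normalize_function_sketch` and its `deg` parameter are hypothetical.

```python
import math

import jax.numpy as jnp
import numpy as np


def normalize_function_sketch(phi, deg=64):
    """Return psi = phi / c with E[psi(X)^2] = 1 for X ~ N(0, 1) (hypothetical helper)."""
    # Gauss-Hermite nodes and weights for the weight function exp(-x^2 / 2);
    # the weights sum to sqrt(2 * pi)
    x, w = np.polynomial.hermite_e.hermegauss(deg)
    values = np.asarray(phi(jnp.asarray(x)), dtype=np.float64)
    # c^2 = integral of phi(x)^2 * exp(-x^2 / 2) / sqrt(2 * pi) dx
    c = math.sqrt(np.sum(w * values**2) / math.sqrt(2.0 * math.pi))

    def psi(x):
        return phi(x) / c

    return psi


# Example: phi(x) = x^2 has E[X^4] = 3, so c = sqrt(3)
psi = normalize_function_sketch(lambda x: x**2)
```

With `deg` nodes the rule is exact for polynomial integrands up to degree `2 * deg - 1` and accurate for smooth functions that do not grow too fast.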
- e3nn_jax.scatter_sum(data: Array | IrrepsArray, *, dst: Array | None = None, nel: Array | None = None, output_size: int | None = None, map_back: bool = False, mode: str = 'promise_in_bounds') -> Array | IrrepsArray
Scatter sum of data.
Performs either of the following two operations:
output[dst[i]] += data[i]
or:
output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])])
- Parameters:
data (jax.Array or IrrepsArray) – array of shape (n1, ..., nd, ...)
dst (optional, jax.Array) – array of shape (n1, ..., nd). If not specified, nel must be specified.
nel (optional, jax.Array) – array of shape (output_size,). If not specified, dst must be specified.
output_size (optional, int) – size of the output array. If not specified, nel must be specified or map_back must be True.
map_back (bool) – whether to map the result back to the input positions
- Returns:
output array of shape (output_size, ...)
- Return type:
jax.Array or IrrepsArray
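Both operations can be reproduced with core JAX, which may help when reading the definitions. This sketch uses `.at[...].add` and `jax.ops.segment_sum` directly rather than `e3nn_jax.scatter_sum`; the toy `data`, `dst` and `nel` values are made up for illustration.

```python
import jax.numpy as jnp
from jax import ops

data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

# dst form: output[dst[i]] += data[i]
dst = jnp.array([0, 2, 0, 1, 2])
out_dst = jnp.zeros(3).at[dst].add(data)  # [4., 4., 7.]
# The same reduction with a built-in segment op
out_seg = ops.segment_sum(data, dst, num_segments=3)  # [4., 4., 7.]

# nel form: consecutive chunks of lengths nel = [2, 0, 3] are summed
nel = jnp.array([2, 0, 3])
# Expand nel into one group id per element: [0, 0, 2, 2, 2]
seg_ids = jnp.repeat(jnp.arange(nel.shape[0]), nel, total_repeat_length=data.shape[0])
out_nel = ops.segment_sum(data, seg_ids, num_segments=nel.shape[0])  # [3., 0., 12.]
```

Passing `total_repeat_length` keeps the shape of `seg_ids` static, which is what makes the `nel` form usable under `jax.jit`.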
- e3nn_jax.scatter_mean(data: Array | IrrepsArray, *, dst: Array | None = None, nel: Array | None = None, output_size: int | None = None, map_back: bool = False, mode: str = 'promise_in_bounds') -> Array | IrrepsArray
Scatter mean of data.
Performs either of the following two operations (the counts n are accumulated over all i first):
n[dst[i]] += 1
output[dst[i]] += data[i] / n[dst[i]]
or:
output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])]) / nel[i]
- Parameters:
data (jax.Array or IrrepsArray) – array of shape (n1, ..., nd, ...)
dst (optional, jax.Array) – array of shape (n1, ..., nd). If not specified, nel must be specified.
nel (optional, jax.Array) – array of shape (output_size,). If not specified, dst must be specified.
output_size (optional, int) – size of the output array. If not specified, nel must be specified or map_back must be True.
map_back (bool) – whether to map the result back to the input positions
- Returns:
output array of shape (output_size, ...)
- Return type:
jax.Array or IrrepsArray
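Dividing each group total once by its count is equivalent to accumulating `data[i] / n[dst[i]]`, so the mean can be sketched with two segment sums. `scatter_mean_sketch` is a hypothetical helper built on `jax.ops.segment_sum`, not the library function, and returning 0 for empty groups is an assumption of this sketch rather than documented behavior.

```python
import jax.numpy as jnp
from jax import ops


def scatter_mean_sketch(data, dst, num_segments):
    """Mean of data over the groups given by dst (hypothetical helper)."""
    total = ops.segment_sum(data, dst, num_segments=num_segments)
    # n[dst[i]] += 1: number of elements in each group
    n = ops.segment_sum(jnp.ones_like(data), dst, num_segments=num_segments)
    # Assumption of this sketch: empty groups give 0 rather than nan
    return total / jnp.maximum(n, 1.0)


out = scatter_mean_sketch(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]), jnp.array([0, 2, 0, 1, 2]), 4)
# out == [2., 4., 3.5, 0.]
```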
- e3nn_jax.scatter_max(data: Array | IrrepsArray, *, dst: Array | None = None, nel: Array | None = None, initial: float = -inf, output_size: int | None = None, map_back: bool = False, mode: str = 'promise_in_bounds') -> Array | IrrepsArray
Scatter max of data.
Performs either of the following two operations:
output[i] = max(initial, *(x for j, x in zip(dst, data) if j == i))
or:
output[i] = max(initial, *data[sum(nel[:i]):sum(nel[:i+1])])
- Parameters:
data (jax.Array or IrrepsArray) – array of shape (n, ...)
dst (optional, jax.Array) – array of shape (n,). If not specified, nel must be specified.
nel (optional, jax.Array) – array of shape (output_size,). If not specified, dst must be specified.
initial (float) – initial value to compare to; it is also the result for outputs that receive no elements
output_size (optional, int) – size of the output array. If not specified, nel must be specified or map_back must be True.
map_back (bool) – whether to map the result back to the input positions
- Returns:
output array of shape (output_size, ...)
- Return type:
jax.Array or IrrepsArray
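The `dst` form of the definition maps onto `jax.ops.segment_max` followed by an elementwise maximum with `initial`. `scatter_max_sketch` is a hypothetical helper for illustration, not the library implementation; it relies on `segment_max` leaving empty groups at `-inf` for floating-point data.

```python
import jax.numpy as jnp
from jax import ops


def scatter_max_sketch(data, dst, num_segments, initial=-jnp.inf):
    """output[i] = max(initial, *(x for j, x in zip(dst, data) if j == i)) (hypothetical helper)."""
    # Empty groups are left at the identity of max, -inf for floating dtypes
    out = ops.segment_max(data, dst, num_segments=num_segments)
    # Fold in `initial`, which is also the result for empty groups
    return jnp.maximum(out, initial)


data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
dst = jnp.array([0, 2, 0, 1, 2])
out = scatter_max_sketch(data, dst, 4)  # [3., 4., 5., -inf]
out0 = scatter_max_sketch(data, dst, 4, initial=0.0)  # [3., 4., 5., 0.]
```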
- e3nn_jax.radius_graph(pos: IrrepsArray | Array, r_max: float, *, batch: Array = None, size: int = None, loop: bool = False, fill_src: int = -1, fill_dst: int = -1)
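Try to use matscipy.neighbours.neighbour_list instead.
- Parameters:
pos (IrrepsArray or jax.Array) – node positions, array of shape (n, 3)
r_max (float) – cutoff radius
batch (optional, jax.Array) – batch assignment of each node; only nodes in the same batch are connected
size (optional, int) – size of the output
loop (bool) – whether to include self-loops
fill_src (int), fill_dst (int) – fill values for the source and destination indices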
- Returns:
tuple containing:
jax.Array: source indices
jax.Array: destination indices
- Return type:
(tuple)
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> from e3nn_jax import radius_graph
>>> key = jax.random.PRNGKey(0)
>>> pos = jax.random.normal(key, (20, 3))
>>> batch = jnp.arange(20) < 10
>>> radius_graph(pos, 0.8, batch=batch)
(Array([ 3,  7, 10, 11, 12, 18], dtype=int32), Array([ 7,  3, 11, 10, 18, 12], dtype=int32))
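For small inputs, a radius graph can be built by brute force from all pairwise distances, which makes the roles of `r_max`, `batch` and `loop` concrete. This is an O(n²)-memory sketch, not the library's algorithm: `radius_graph_sketch` and its single `fill` argument are hypothetical, the strict `< r_max` comparison is an assumption, and reading `size` as "pad the edge list to a fixed length" is our interpretation, implemented with `jnp.nonzero(..., size=..., fill_value=...)`.

```python
import jax.numpy as jnp


def radius_graph_sketch(pos, r_max, batch=None, loop=False, size=None, fill=-1):
    """Brute-force O(n^2) radius graph (hypothetical helper, illustration only)."""
    n = pos.shape[0]
    # Pairwise squared distances, shape (n, n)
    diff = pos[:, None, :] - pos[None, :, :]
    dist2 = jnp.sum(diff**2, axis=-1)
    # Assumption: an edge exists when the distance is strictly below r_max
    mask = dist2 < r_max**2
    if batch is not None:
        # Only connect nodes that belong to the same batch
        mask = mask & (batch[:, None] == batch[None, :])
    if not loop:
        # Drop self-loops (i, i)
        mask = mask & ~jnp.eye(n, dtype=bool)
    if size is None:
        src, dst = jnp.nonzero(mask)
    else:
        # Static output shape: pad with `fill` up to `size` edges
        src, dst = jnp.nonzero(mask, size=size, fill_value=fill)
    return src, dst
```

For example, with `pos = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.0, 0.0], [3.0, 0.0, 0.0]])` and `r_max = 1.0`, only nodes 0 and 1 are neighbors, giving source indices `[0, 1]` and destination indices `[1, 0]`.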