Extra Stuff

e3nn_jax.sus(x)

Smooth Unit Step function.

Limits and sample values: sus(-inf) = 0, sus(0) = 0, sus(2) ≈ 0.6, sus(+inf) = 1.

\[\begin{split}\text{sus}(x) = \begin{cases} 0, & \text{if } x \leq 0 \\ \exp(-1/x), & \text{if } x > 0 \\ \end{cases}\end{split}\]
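
As a quick check of the values listed above, here is a minimal pure-Python sketch of the same formula. The name sus_reference is hypothetical and this is not the library's implementation; it treats x = 0 as 0, which matches the listed value at 0.

```python
import math


def sus_reference(x: float) -> float:
    """Smooth unit step: 0 for x <= 0, exp(-1/x) for x > 0 (illustrative only)."""
    # Guarding on x > 0 avoids the division by zero at x = 0.
    return math.exp(-1.0 / x) if x > 0.0 else 0.0


print(sus_reference(-1.0))           # 0.0
print(sus_reference(0.0))            # 0.0
print(round(sus_reference(2.0), 4))  # 0.6065
```
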
e3nn_jax.normalize_function(phi: Callable[[float], float]) -> Callable[[float], float]

Normalize a function, \(\psi(x)=\phi(x)/c\) where \(c\) is the normalization constant such that

\[\int_{-\infty}^{\infty} \psi(x)^2 \frac{e^{-x^2/2}}{\sqrt{2\pi}} dx = 1\]
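
The condition above fixes \(c\) as the square root of the second moment of \(\phi\) under a standard Gaussian. The following is a rough numerical sketch of that definition using a trapezoid rule on a truncated interval. The name normalize_reference, the integration bounds and the step count are all assumptions made for illustration; how the library actually evaluates the integral is not stated here.

```python
import math
from typing import Callable


def normalize_reference(
    phi: Callable[[float], float],
    lo: float = -12.0,
    hi: float = 12.0,
    steps: int = 20_000,
) -> Callable[[float], float]:
    """Return psi = phi / c with c^2 = E[phi(x)^2] for x ~ N(0, 1) (illustrative)."""
    h = (hi - lo) / steps
    total = 0.0
    for k in range(steps + 1):
        x = lo + k * h
        weight = 0.5 if k in (0, steps) else 1.0  # trapezoid end-point weights
        gauss = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
        total += weight * phi(x) ** 2 * gauss
    c = math.sqrt(total * h)
    return lambda x: phi(x) / c


# phi(x) = 2x has second moment 4 under N(0, 1), so c = 2 and psi(x) = x.
psi = normalize_reference(lambda x: 2.0 * x)
print(round(psi(3.0), 6))  # 3.0
```
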
e3nn_jax.scatter_sum(data: Array | IrrepsArray, *, dst: Array | None = None, nel: Array | None = None, output_size: int | None = None, map_back: bool = False, mode: str = 'promise_in_bounds') -> Array | IrrepsArray

Scatter sum of data.

Performs either of the following two operations:

output[dst[i]] += data[i]

or:

output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])])
Parameters:
  • data (jax.Array or IrrepsArray) – array of shape (n1, ..., nd, ...)

  • dst (optional, jax.Array) – array of shape (n1, ..., nd). If not specified, nel must be specified.

  • nel (optional, jax.Array) – array of shape (output_size,). If not specified, dst must be specified.

  • output_size (optional, int) – size of output array. If not specified, nel must be specified or map_back must be True.

  • map_back (bool) – whether to map back to the input position

Returns:

output array of shape (output_size, ...)

Return type:

jax.Array or IrrepsArray
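
The two operations described for scatter_sum can be sketched in pure Python on flat lists. The helper names scatter_sum_dst and scatter_sum_nel are hypothetical, and the sketch ignores map_back, mode, IrrepsArray inputs and multi-dimensional dst; it only illustrates the indexing.

```python
def scatter_sum_dst(data, dst, output_size):
    """output[dst[i]] += data[i] (illustrative, 1D only)."""
    output = [0.0] * output_size
    for i, j in enumerate(dst):
        output[j] += data[i]
    return output


def scatter_sum_nel(data, nel):
    """output[i] = sum of the i-th contiguous segment, of length nel[i]."""
    output, start = [], 0
    for n in nel:
        output.append(sum(data[start:start + n]))
        start += n
    return output


data = [1.0, 2.0, 3.0, 4.0]
print(scatter_sum_dst(data, dst=[0, 2, 2, 0], output_size=3))  # [5.0, 0.0, 5.0]
print(scatter_sum_nel(data, nel=[1, 3]))                       # [1.0, 9.0]
```

In the dst form the destinations may appear in any order; in the nel form the data must already be grouped into contiguous segments.
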

e3nn_jax.scatter_mean(data: Array | IrrepsArray, *, dst: Array | None = None, nel: Array | None = None, output_size: int | None = None, map_back: bool = False, mode: str = 'promise_in_bounds') -> Array | IrrepsArray

Scatter mean of data.

Performs either of the following two operations:

n[dst[i]] += 1  # first pass over all i
output[dst[i]] += data[i] / n[dst[i]]  # second pass, once n is complete

or:

output[i] = sum(data[sum(nel[:i]):sum(nel[:i+1])]) / nel[i]
Parameters:
  • data (jax.Array or IrrepsArray) – array of shape (n1, ..., nd, ...)

  • dst (optional, jax.Array) – array of shape (n1, ..., nd). If not specified, nel must be specified.

  • nel (optional, jax.Array) – array of shape (output_size,). If not specified, dst must be specified.

  • output_size (optional, int) – size of output array. If not specified, nel must be specified or map_back must be True.

  • map_back (bool) – whether to map back to the input position

Returns:

output array of shape (output_size, ...)

Return type:

jax.Array or IrrepsArray
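
A pure-Python sketch of the two scatter_mean operations on flat lists follows. The helper names are hypothetical. Returning 0.0 for a destination that receives no element is an assumption of this sketch, since the behaviour for empty groups is not specified above.

```python
def scatter_mean_dst(data, dst, output_size):
    """Mean of all data[i] that share the same dst[i] (illustrative, 1D only)."""
    counts = [0] * output_size
    totals = [0.0] * output_size
    for i, j in enumerate(dst):
        counts[j] += 1
        totals[j] += data[i]
    # Assumption: an empty group yields 0.0 instead of dividing by zero.
    return [t / c if c > 0 else 0.0 for t, c in zip(totals, counts)]


def scatter_mean_nel(data, nel):
    """Mean of each contiguous segment of length nel[i]."""
    output, start = [], 0
    for n in nel:
        segment = data[start:start + n]
        output.append(sum(segment) / n if n > 0 else 0.0)  # same assumption
        start += n
    return output


data = [1.0, 2.0, 3.0, 4.0]
print(scatter_mean_dst(data, dst=[0, 2, 2, 0], output_size=3))  # [2.5, 0.0, 2.5]
print(scatter_mean_nel(data, nel=[1, 3]))                       # [1.0, 3.0]
```
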

e3nn_jax.scatter_max(data: Array | IrrepsArray, *, dst: Array | None = None, nel: Array | None = None, initial: float = -inf, output_size: int | None = None, map_back: bool = False, mode: str = 'promise_in_bounds') -> Array | IrrepsArray

Scatter max of data.

Performs either of the following two operations:

output[i] = max(initial, *(x for j, x in zip(dst, data) if j == i))

or:

output[i] = max(initial, *data[sum(nel[:i]):sum(nel[:i+1])])
Parameters:
  • data (jax.Array or IrrepsArray) – array of shape (n, ...)

  • dst (optional, jax.Array) – array of shape (n,). If not specified, nel must be specified.

  • nel (optional, jax.Array) – array of shape (output_size,). If not specified, dst must be specified.

  • initial (float) – initial value to compare to

  • output_size (optional, int) – size of output array. If not specified, nel must be specified or map_back must be True.

  • map_back (bool) – whether to map back to the input position

Returns:

output array of shape (output_size, ...)

Return type:

jax.Array or IrrepsArray
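
The role of initial is easiest to see on a small case: a destination that receives no element keeps the value of initial. The sketch below restates the two scatter_max formulas in pure Python on flat lists; the helper names are hypothetical and map_back and mode are not modelled.

```python
import math


def scatter_max_dst(data, dst, output_size, initial=-math.inf):
    """output[i] = max(initial, all data[j] with dst[j] == i) (illustrative)."""
    output = [initial] * output_size
    for i, j in enumerate(dst):
        output[j] = max(output[j], data[i])
    return output


def scatter_max_nel(data, nel, initial=-math.inf):
    """output[i] = max(initial, *segment_i) for contiguous segments of length nel[i]."""
    output, start = [], 0
    for n in nel:
        output.append(max(initial, *data[start:start + n]))
        start += n
    return output


data = [1.0, 5.0, 3.0, 4.0]
print(scatter_max_dst(data, dst=[0, 2, 2, 0], output_size=3))               # [4.0, -inf, 5.0]
print(scatter_max_dst(data, dst=[0, 2, 2, 0], output_size=3, initial=0.0))  # [4.0, 0.0, 5.0]
print(scatter_max_nel(data, nel=[1, 3]))                                    # [1.0, 5.0]
```
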

e3nn_jax.radius_graph(pos: IrrepsArray | Array, r_max: float, *, batch: Array = None, size: int = None, loop: bool = False, fill_src: int = -1, fill_dst: int = -1)

Try to use matscipy.neighbours.neighbour_list instead.

Parameters:
  • pos (jax.Array) – array of shape (n, 3)

  • r_max (float) – cutoff radius

  • batch (jax.Array) – batch indices, one per point

  • size (int) – size of the output

  • loop (bool) – whether to include self-loops

Returns:

tuple containing:

  • jax.Array – source indices

  • jax.Array – destination indices

Return type:

(tuple)

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from e3nn_jax import radius_graph
>>> key = jax.random.PRNGKey(0)
>>> pos = jax.random.normal(key, (20, 3))
>>> batch = jnp.arange(20) < 10
>>> radius_graph(pos, 0.8, batch=batch)
(Array([ 3,  7, 10, 11, 12, 18], dtype=int32), Array([ 7,  3, 11, 10, 18, 12], dtype=int32))
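
For intuition about what such an edge list contains, here is a brute-force pure-Python sketch. It is not the library's implementation: the name radius_graph_reference is hypothetical, and the strict inequality, the restriction of edges to points with equal batch value, and the ordering of the returned edges are assumptions of this sketch. The size, fill_src and fill_dst arguments are not modelled.

```python
import math


def radius_graph_reference(pos, r_max, batch=None, loop=False):
    """Return (src, dst) index lists of all ordered pairs closer than r_max."""
    src, dst = [], []
    for i in range(len(pos)):
        for j in range(len(pos)):
            if i == j and not loop:
                continue  # skip self-loops unless requested
            if batch is not None and batch[i] != batch[j]:
                continue  # assumption: no edges across different batch values
            if math.dist(pos[i], pos[j]) < r_max:  # assumption: strict cutoff
                src.append(i)
                dst.append(j)
    return src, dst


pos = [(0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (2.0, 0.0, 0.0), (2.3, 0.0, 0.0)]
print(radius_graph_reference(pos, 0.8))                      # ([0, 1, 2, 3], [1, 0, 3, 2])
print(radius_graph_reference(pos, 0.8, batch=[0, 0, 0, 1]))  # ([0, 1], [1, 0])
```

Each undirected pair appears twice, once per direction, which is also the pattern visible in the example output above.
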