IrrepsArray

class e3nn_jax.IrrepsArray(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], array: Array, *, zero_flags=None, chunks: List[Array | None] | None = None)

Bases: object

Array with a representation of rotations.

The IrrepsArray class enforces equivariance by storing an array of data (.array) along with its representation (.irreps).

The data is stored as a single array of shape (..., irreps.dim).

The data can be accessed as a list of arrays (.chunks) matching each item of the .irreps.
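
To make this layout concrete, here is a minimal sketch in plain NumPy of how the flat array splits into .chunks. It is not the library's implementation. The helper split_into_chunks is hypothetical, and irreps are written as a list of (mul, l) pairs. Each item takes mul * (2l + 1) consecutive entries, and its chunk has shape (..., mul, 2l + 1).

```python
import numpy as np

def split_into_chunks(array, irreps):
    """Split (..., dim) data into one (..., mul, 2l + 1) array per (mul, l) item."""
    chunks = []
    start = 0
    for mul, l in irreps:
        ir_dim = 2 * l + 1
        stop = start + mul * ir_dim
        # Each item occupies mul * (2l + 1) consecutive entries of the last axis.
        block = array[..., start:stop]
        chunks.append(block.reshape(array.shape[:-1] + (mul, ir_dim)))
        start = stop
    assert start == array.shape[-1], "irreps.dim must equal the size of the last axis"
    return chunks

# "1o + 2x0e" written as (mul, l) pairs, with data of shape (5,)
chunks = split_into_chunks(np.arange(5), [(1, 1), (2, 0)])
print([c.shape for c in chunks])  # [(1, 3), (2, 1)]
```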

Parameters:
  • irreps (Irreps) – representation of the data

  • array (jax.Array) – the data, an array of shape (..., irreps.dim)

  • zero_flags (tuple of bool, optional) – whether each chunk of the data is zero

Examples

>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> y = e3nn.from_chunks("1o + 2x0e", [None, jnp.ones((2, 1))], ())
>>> x + y
1x1o+2x0e [1. 1. 1. 2. 2.]

Example of indexing:

>>> x = IrrepsArray("0e + 1o", jnp.arange(2 * 4).reshape(2, 4))
>>> x[0]
1x0e+1x1o [0 1 2 3]
>>> x[1, "0e"]
1x0e [4]
>>> x[:, 1:]
1x1o
[[1 2 3]
 [5 6 7]]
>>> IrrepsArray("5x0e", jnp.arange(5))[1:3]
2x0e [1 2]

Methods:

astype(dtype)

Change the dtype of the array.

axis_to_irreps([axis])

Move an axis of the array into the irreps dimension by repeating the irreps.

axis_to_mul([axis])

Merge an axis of the array into the multiplicities.

broadcast_to(shape)

Broadcast the array to a new shape.

extend_with_zeros(new_irreps)

Extend the array with zeros.

filter([keep, drop, lmax])

Filter the irreps.

mul_to_axis([factor, axis])

Create a new axis by factoring the multiplicities.

rechunk(irreps)

Rechunk the array with new (equivalent) irreps.

regroup()

Regroup the same irreps together.

remove_zero_chunks()

Remove all zero chunks.

repeat_irreps_by_last_axis([axis])

Same as axis_to_irreps().

reshape(shape)

Reshape the array.

simplify()

Simplify the irreps.

sort()

Sort the irreps.

transform_by_angles(alpha, beta, gamma[, k, ...])

Rotate the data by angles according to the irreps.

transform_by_axis_angle(axis, angle[, k])

Rotate data by a rotation given by an axis and an angle.

transform_by_log_coordinates(log_coordinates)

Rotate data by a rotation given by log coordinates.

transform_by_matrix(R)

Rotate data by a rotation given by a matrix.

transform_by_quaternion(q[, k])

Rotate data by a rotation given by a quaternion.

unify()

Unify the irreps.

Attributes:

chunks

List of arrays matching each item of the .irreps.

dtype

dtype.

ndim

Number of dimensions.

shape

Shape.

slice_by_chunk

Return the slice with respect to the chunks.

slice_by_dim

Same as __getitem__ in the irreps dimension.

slice_by_mul

Return the slice with respect to the multiplicities.

astype(dtype) IrrepsArray

Change the dtype of the array.

Parameters:

dtype (dtype) – new dtype

Returns:

new IrrepsArray

Return type:

IrrepsArray

axis_to_irreps(axis: int = -2) IrrepsArray

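Move an axis of the array (by default the second-to-last one) into the irreps dimension, repeating the irreps once per entry along that axis.

Decrease the number of dimensions of the array by 1.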

Examples

>>> x = IrrepsArray("0e + 1e", jnp.arange(2 * 4).reshape(2, 4))
>>> x.axis_to_irreps()
1x0e+1x1e+1x0e+1x1e [0 1 2 3 4 5 6 7]
axis_to_mul(axis: int = -2) IrrepsArray

Merge an axis of the array (by default the second-to-last one) into the multiplicities.

Decrease the number of dimensions of the array by 1.

Parameters:

axis (int) – axis to convert into multiplicity

Examples

>>> x = IrrepsArray("0e + 1e", jnp.arange(2 * 4).reshape(2, 4))
>>> x.axis_to_mul()
2x0e+2x1e [0 4 1 2 3 5 6 7]
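
The reshuffling in this example can be reproduced with a short NumPy sketch. The helper axis_to_mul_sketch is hypothetical. It handles only the default axis=-2 and takes irreps as (mul, l) pairs. It flattens each chunk of shape (..., n, mul, 2l + 1), so the n entries along the merged axis become consecutive copies of the irrep.

```python
import numpy as np

def axis_to_mul_sketch(array, irreps):
    """Merge the second-to-last axis of `array` into the multiplicities of `irreps`."""
    n = array.shape[-2]
    leading = array.shape[:-2]
    out, start = [], 0
    for mul, l in irreps:
        d = 2 * l + 1
        stop = start + mul * d
        # (..., n, mul * d) -> (..., n * mul * d); row-major order puts the copies
        # coming from entry 0 of the merged axis first, then entry 1, and so on.
        out.append(array[..., start:stop].reshape(leading + (n * mul * d,)))
        start = stop
    new_irreps = [(n * mul, l) for mul, l in irreps]
    return new_irreps, np.concatenate(out, axis=-1)

irreps, data = axis_to_mul_sketch(np.arange(2 * 4).reshape(2, 4), [(1, 0), (1, 1)])
print(irreps, data)  # [(2, 0), (2, 1)] [0 4 1 2 3 5 6 7]
```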
broadcast_to(shape) IrrepsArray

Broadcast the array to a new shape.

property chunks: List[Array | None]

List of arrays matching each item of the .irreps.

Examples

>>> x = IrrepsArray("2x0e + 0e", jnp.arange(3))
>>> len(x.chunks)
2
>>> x.chunks[0]
Array([[0],
       [1]], dtype=int32)
>>> x.chunks[1]
Array([[2]], dtype=int32)

The following is always true (chunks known to be zero are stored as None):

>>> all(e is None or e.shape == x.shape[:-1] + (mul, ir.dim) for (mul, ir), e in zip(x.irreps, x.chunks))
True
property dtype

dtype. Equivalent to self.array.dtype.

extend_with_zeros(new_irreps: Irreps) IrrepsArray

Extend the array with zeros.

Parameters:

new_irreps (Irreps) – new irreps, must be a superset of the current irreps

Examples

>>> IrrepsArray("0e + 1o", jnp.array([1, 3, 3, 3])).extend_with_zeros("0e + 0e + 1o + 2x0e")
1x0e+1x0e+1x1o+2x0e [1 0 3 3 3 0 0]
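
One way to picture extend_with_zeros is as a scatter of the existing chunks into the larger irreps. The NumPy sketch below is hypothetical. It matches the current items to identical (mul, label) items of the new irreps greedily and in order, and fills every unmatched item with zeros. This matching rule is an assumption about how the library pairs items, but it is enough to reproduce the example above.

```python
import numpy as np

def ir_dim(label):
    """Dimension 2l + 1 of an irrep label such as "0e" or "1o"."""
    return 2 * int(label[:-1]) + 1

def extend_with_zeros_sketch(array, irreps, new_irreps):
    """Copy each (mul, label) item of `irreps` into the next equal item of `new_irreps`.

    The greedy, in-order matching is an assumption made for this illustration.
    """
    out = []
    i, start = 0, 0
    for mul, label in new_irreps:
        size = mul * ir_dim(label)
        if i < len(irreps) and irreps[i] == (mul, label):
            out.append(array[..., start:start + size])  # existing data
            start += size
            i += 1
        else:
            out.append(np.zeros(array.shape[:-1] + (size,), dtype=array.dtype))  # padding
    assert i == len(irreps), "new_irreps must contain the current irreps, in order"
    return np.concatenate(out, axis=-1)

x = np.array([1, 3, 3, 3])
print(extend_with_zeros_sketch(x, [(1, "0e"), (1, "1o")],
                               [(1, "0e"), (1, "0e"), (1, "1o"), (2, "0e")]))
# [1 0 3 3 3 0 0]
```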
filter(keep: Irreps | List[Irrep] | Callable[[MulIrrep], bool] = None, *, drop: Irreps | List[Irrep] | Callable[[MulIrrep], bool] = None, lmax: int = None) IrrepsArray

Filter the irreps.

Parameters:
  • keep (Irreps or list of Irrep or function) – irreps to keep

  • drop (Irreps or list of Irrep or function) – irreps to drop

  • lmax (int) – maximum l to keep

Examples

>>> IrrepsArray("0e + 2x1o + 2x0e", jnp.arange(9)).filter(["1o"])
2x1o [1 2 3 4 5 6]
mul_to_axis(factor: int | None = None, axis: int = -2) IrrepsArray

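Create a new axis by factoring the multiplicities: each item mul x ir becomes (mul / factor) x ir, and the new axis has length factor.

Increase the number of dimensions of the array by 1.

Parameters:
  • factor (int or None) – factor the multiplicities by this number; if None, the greatest common divisor of all multiplicities is used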

  • axis (int) – the new axis will be placed before this axis

Examples

>>> x = IrrepsArray("6x0e + 3x1e", jnp.arange(15))
>>> x.mul_to_axis()
2x0e+1x1e
[[ 0  1  6  7  8]
 [ 2  3  9 10 11]
 [ 4  5 12 13 14]]
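
Here is a NumPy sketch of the same factoring. The helper mul_to_axis_sketch is hypothetical. It takes irreps as (mul, l) pairs and always places the new axis second-to-last. Each chunk of mul * (2l + 1) entries is reshaped to (factor, (mul / factor) * (2l + 1)), and the chunks are then concatenated along the last axis. With factor=None the sketch uses the greatest common divisor of the multiplicities, which gives 3 for the example above.

```python
import math
import numpy as np

def mul_to_axis_sketch(array, irreps, factor=None):
    """Split each multiplicity mul into (factor, mul // factor) and move factor to a new axis."""
    if factor is None:
        # Largest factor that divides every multiplicity.
        factor = math.gcd(*[mul for mul, _ in irreps])
    leading = array.shape[:-1]
    out, start = [], 0
    for mul, l in irreps:
        d = 2 * l + 1
        stop = start + mul * d
        assert mul % factor == 0, "factor must divide every multiplicity"
        # (..., mul * d) -> (..., factor, (mul // factor) * d)
        out.append(array[..., start:stop].reshape(leading + (factor, (mul // factor) * d)))
        start = stop
    new_irreps = [(mul // factor, l) for mul, l in irreps]
    return new_irreps, np.concatenate(out, axis=-1)

irreps, data = mul_to_axis_sketch(np.arange(15), [(6, 0), (3, 1)])
print(irreps)  # [(2, 0), (1, 1)]
print(data)    # same 3 x 5 layout as the example above
```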
property ndim

Number of dimensions. Equivalent to self.array.ndim.

rechunk(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]]) IrrepsArray

Rechunk the array with new (equivalent) irreps.

Parameters:

irreps (Irreps) – new irreps

Returns:

new IrrepsArray

Return type:

IrrepsArray

Examples

>>> x = e3nn.from_chunks("6x0e + 4x0e", [None, jnp.ones((4, 1))], ())
>>> x.rechunk("5x0e + 5x0e").chunks
[None, Array([[0.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)]
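
In this example the first new chunk lies entirely inside the old zero chunk, so it stays None. The second chunk straddles the boundary and has to be materialized. The NumPy sketch below reproduces that bookkeeping. It uses a hypothetical helper restricted to scalar irreps (l = 0), with the old and new irreps given as lists of multiplicities.

```python
import numpy as np

def rechunk_scalars_sketch(chunks, old_muls, new_muls):
    """Rechunk l = 0 data given as (mul, 1) arrays or None (known zero)."""
    assert sum(old_muls) == sum(new_muls), "the total dimension must not change"
    # Flatten the data, materializing zeros, and remember which entries were known zeros.
    data = np.concatenate([np.zeros((m, 1)) if c is None else c for c, m in zip(chunks, old_muls)])
    is_zero = np.concatenate([np.full(m, c is None) for c, m in zip(chunks, old_muls)])
    new_chunks, start = [], 0
    for m in new_muls:
        stop = start + m
        # A new chunk can stay None only if all of its entries came from zero chunks.
        new_chunks.append(None if is_zero[start:stop].all() else data[start:stop])
        start = stop
    return new_chunks

new = rechunk_scalars_sketch([None, np.ones((4, 1))], [6, 4], [5, 5])
print(new[0], new[1].ravel())  # None [0. 1. 1. 1. 1.]
```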
regroup() IrrepsArray

Regroup the same irreps together.

Equivalent to sort() followed by simplify().

Examples

>>> IrrepsArray("0e + 1o + 2x0e", jnp.arange(6)).regroup()
3x0e+1x1o [0 4 5 1 2 3]
remove_zero_chunks() IrrepsArray

Remove all zero chunks.

repeat_irreps_by_last_axis(axis: int = -2) IrrepsArray

Same as axis_to_irreps(): move an axis of the array (by default the second-to-last one) into the irreps dimension by repeating the irreps.

Examples

>>> x = IrrepsArray("0e + 1e", jnp.arange(2 * 4).reshape(2, 4))
>>> x.axis_to_irreps()
1x0e+1x1e+1x0e+1x1e [0 1 2 3 4 5 6 7]
reshape(shape) IrrepsArray

Reshape the array.

Parameters:

shape (tuple) – new shape

Returns:

new IrrepsArray

Return type:

IrrepsArray

Examples

>>> IrrepsArray("2x0e + 1o", jnp.ones((6, 5))).reshape((2, 3, 5))
2x0e+1x1o
[[[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]

 [[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]]
property shape

Shape. Equivalent to self.array.shape.

simplify() IrrepsArray

Simplify the irreps: remove items with zero multiplicity and merge consecutive items with the same irrep.

Examples

>>> IrrepsArray("0e + 0e + 0e", jnp.ones(3)).simplify()
3x0e [1. 1. 1.]
>>> IrrepsArray("0e + 0x1e + 0e", jnp.ones(2)).simplify()
2x0e [1. 1.]
property slice_by_chunk

Return the slice with respect to the chunks.

property slice_by_dim

Same as __getitem__ in the irreps dimension.

property slice_by_mul

Return the slice with respect to the multiplicities.

sort() IrrepsArray

Sort the irreps.

Examples

>>> IrrepsArray("0e + 1o + 2x0e", jnp.arange(6)).sort()
1x0e+2x0e+1x1o [0 4 5 1 2 3]
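
Sorting only permutes whole items, so the data moves together with its irreps. Below is a NumPy sketch with a hypothetical helper. It takes irreps as (mul, l) pairs and uses l as the only sort key. This is a simplification, since the library's ordering of irreps also distinguishes parity.

```python
import numpy as np

def sort_sketch(array, irreps):
    """Stable-sort (mul, l) items by l and permute the data blocks accordingly."""
    sizes = [mul * (2 * l + 1) for mul, l in irreps]
    offsets = np.cumsum([0] + sizes)  # start of each item in the last axis
    # Sort key: l only, a simplification of the library's ordering of irreps.
    order = sorted(range(len(irreps)), key=lambda i: irreps[i][1])
    new_irreps = [irreps[i] for i in order]
    blocks = [array[..., offsets[i]:offsets[i + 1]] for i in order]
    return new_irreps, np.concatenate(blocks, axis=-1)

irreps, data = sort_sketch(np.arange(6), [(1, 0), (1, 1), (2, 0)])
print(irreps, data)  # [(1, 0), (2, 0), (1, 1)] [0 4 5 1 2 3]
```

regroup() would additionally merge the two leading scalar items into 3x0e, without moving any data.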
transform_by_angles(alpha: float, beta: float, gamma: float, k: int = 0, inverse: bool = False) IrrepsArray

Rotate the data by angles according to the irreps.

Parameters:
  • alpha (float) – third rotation angle around the second axis (in radians)

  • beta (float) – second rotation angle around the first axis (in radians)

  • gamma (float) – first rotation angle around the second axis (in radians)

  • k (int) – number of times the parity (inversion) operation is applied

  • inverse (bool) – if True, apply the inverse rotation

Returns:

rotated data

Return type:

IrrepsArray

Examples

>>> np.set_printoptions(precision=3, suppress=True)
>>> x = IrrepsArray("2e", jnp.array([0.1, 2, 1.0, 1, 1]))
>>> x.transform_by_angles(jnp.pi, 0, 0)
1x2e [ 0.1 -2.   1.  -1.   1. ]
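
For l = 1, the three angles described above compose into a single rotation matrix R = Y(alpha) X(beta) Y(gamma), with gamma applied first. The NumPy sketch below builds that matrix. matrix_x, matrix_y and angles_to_matrix_sketch are illustrative helpers, and treating an l = 1 irrep as a Cartesian (x, y, z) vector is an assumption of the sketch. Rotating by pi around y flips the x and z components.

```python
import numpy as np

def matrix_x(angle):
    """Rotation by `angle` (radians) around the x axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def matrix_y(angle):
    """Rotation by `angle` (radians) around the y axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def angles_to_matrix_sketch(alpha, beta, gamma):
    """Y-X-Y angles: gamma (around y) first, then beta (around x), then alpha (around y)."""
    return matrix_y(alpha) @ matrix_x(beta) @ matrix_y(gamma)

R = angles_to_matrix_sketch(np.pi, 0.0, 0.0)
print(np.round(R @ np.array([1.0, 2.0, 3.0]), 6))  # [-1.  2. -3.]
```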
transform_by_axis_angle(axis: Array, angle: float, k: int = 0) IrrepsArray

Rotate data by a rotation given by an axis and an angle.

Parameters:
  • axis (jax.Array) – axis

  • angle (float) – angle (in radians)

  • k (int) – number of times the parity (inversion) operation is applied

Returns:

rotated data

Return type:

IrrepsArray

transform_by_log_coordinates(log_coordinates: Array, k: int = 0) IrrepsArray

Rotate data by a rotation given by log coordinates.

Parameters:
  • log_coordinates (jax.Array) – log coordinates

  • k (int) – number of times the parity (inversion) operation is applied

Returns:

rotated data

Return type:

IrrepsArray

transform_by_matrix(R: Array) IrrepsArray

Rotate data by a rotation given by a matrix.

Parameters:

R (jax.Array) – rotation matrix

Returns:

rotated data

Return type:

IrrepsArray

transform_by_quaternion(q: Array, k: int = 0) IrrepsArray

Rotate data by a rotation given by a quaternion.

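Parameters:
  • q (jax.Array) – quaternion

  • k (int) – number of times the parity (inversion) operation is applied

Returns: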

rotated data

Return type:

IrrepsArray

unify() IrrepsArray

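Unify the irreps: merge consecutive items with the same irrep. Unlike simplify(), items with zero multiplicity are kept.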

Examples

>>> IrrepsArray("0e + 0x1e + 0e", jnp.ones(2)).unify()
1x0e+0x1e+1x0e [1. 1.]
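
The difference between simplify() and unify() is easiest to see on the irreps alone. Merging adjacent items never moves any data, because mul1 copies followed by mul2 copies of the same irrep already have the layout of (mul1 + mul2) copies. Here is a pure-Python sketch with hypothetical helpers and irreps as (mul, label) pairs.

```python
def unify_sketch(irreps):
    """Merge consecutive items with the same irrep; zero multiplicities are kept."""
    out = []
    for mul, ir in irreps:
        if out and out[-1][1] == ir:
            out[-1] = (out[-1][0] + mul, ir)
        else:
            out.append((mul, ir))
    return out

def simplify_sketch(irreps):
    """Drop zero multiplicities, then merge consecutive identical irreps."""
    return unify_sketch([(mul, ir) for mul, ir in irreps if mul > 0])

irreps = [(1, "0e"), (0, "1e"), (1, "0e")]
print(unify_sketch(irreps))     # [(1, '0e'), (0, '1e'), (1, '0e')]
print(simplify_sketch(irreps))  # [(2, '0e')]
```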
e3nn_jax.from_chunks(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], chunks: List[Array | None], leading_shape: Tuple[int, ...], dtype=None, *, backend=None) IrrepsArray

Create an IrrepsArray from a list of arrays.

Parameters:
  • irreps (Irreps) – irreps

  • chunks (list of optional jax.Array) – list of arrays

  • leading_shape (tuple of int) – leading shape of the arrays (without the irreps)

Returns:

IrrepsArray

e3nn_jax.as_irreps_array(array: Array | IrrepsArray, *, backend=None)

Convert an array to an IrrepsArray.

Parameters:

array (jax.Array or IrrepsArray) – array to convert

Returns:

IrrepsArray

e3nn_jax.zeros(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], leading_shape: Tuple = (), dtype: dtype = None) IrrepsArray

Create an IrrepsArray of zeros.

e3nn_jax.zeros_like(irreps_array: IrrepsArray) IrrepsArray

Create an IrrepsArray of zeros with the same shape as another IrrepsArray.

e3nn_jax.concatenate(arrays: List[IrrepsArray], axis: int = -1) IrrepsArray

Concatenate a list of IrrepsArray.

Parameters:
  • arrays (list of IrrepsArray) – list of data to concatenate

  • axis (int) – axis to concatenate on

Returns:

concatenated array

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("3x0e + 2x0o", jnp.arange(2 * 5).reshape(2, 5))
>>> y = e3nn.IrrepsArray("3x0e + 2x0o", jnp.arange(2 * 5).reshape(2, 5) + 10)
>>> e3nn.concatenate([x, y], axis=0)
3x0e+2x0o
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]]
>>> e3nn.concatenate([x, y], axis=1)
3x0e+2x0o+3x0e+2x0o
[[ 0  1  2  3  4 10 11 12 13 14]
 [ 5  6  7  8  9 15 16 17 18 19]]
e3nn_jax.mean(array: IrrepsArray, axis: None | int | Tuple[int, ...] = None, keepdims: bool = False) IrrepsArray

Mean of IrrepsArray along the specified axis.

Parameters:
  • array (IrrepsArray) – input array

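  • axis (optional int or tuple of ints) – axis along which the mean is computed; if None, the mean is taken over all axes. When the last (irreps) axis is reduced, each item mul x ir is averaged over its copies and becomes 1 x ir.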

Returns:

mean of the input array

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("3x0e + 2x0e", jnp.arange(2 * 5).reshape(2, 5))
>>> e3nn.mean(x, axis=0)
3x0e+2x0e [2.5 3.5 4.5 5.5 6.5]
>>> e3nn.mean(x, axis=1)
1x0e+1x0e
[[1.  3.5]
 [6.  8.5]]
>>> e3nn.mean(x)
1x0e+1x0e [3.5 6. ]
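
When the reduced axes include the last (irreps) axis, each item mul x ir collapses to 1 x ir by averaging over its copies. That is why axis=1 above turns 3x0e+2x0e into 1x0e+1x0e. Below is a NumPy sketch of this rule for the irreps axis only, using a hypothetical helper with irreps as (mul, l) pairs.

```python
import numpy as np

def mean_over_irreps_axis_sketch(array, irreps):
    """Average every (mul, l) item over its mul copies, giving one copy per item."""
    out, start = [], 0
    for mul, l in irreps:
        d = 2 * l + 1
        stop = start + mul * d
        chunk = array[..., start:stop].reshape(array.shape[:-1] + (mul, d))
        out.append(chunk.mean(axis=-2))  # (..., mul, d) -> (..., d)
        start = stop
    return [(1, l) for _, l in irreps], np.concatenate(out, axis=-1)

x = np.arange(2 * 5).reshape(2, 5)
irreps, m = mean_over_irreps_axis_sketch(x, [(3, 0), (2, 0)])
print(irreps)  # [(1, 0), (1, 0)]
print(m)       # rows [1. 3.5] and [6. 8.5]
```

Replacing mean with sum in the sketch gives the corresponding behaviour of e3nn.sum on the irreps axis.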
e3nn_jax.norm(array: IrrepsArray, *, squared: bool = False, per_irrep: bool = True) IrrepsArray

Norm of IrrepsArray.

Parameters:
  • array (IrrepsArray) – input array

  • squared (bool) – if True, return the squared norm

  • per_irrep (bool) – if True, return the norm of each irrep individually; if False, return the norm of the whole array as a single scalar

Returns:

norm of the input array

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("2x0e + 1e + 2e", jnp.arange(10.0))
>>> e3nn.norm(x)
2x0e+1x0e+1x0e [ 0.         1.         5.3851647 15.9687195]
>>> e3nn.norm(x, squared=True)
2x0e+1x0e+1x0e [  0.   1.  29. 255.]
>>> e3nn.norm(x, per_irrep=False)
1x0e [16.881943]
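
The values above can be checked with a NumPy sketch that computes one Euclidean norm per irrep copy. The helper norm_sketch is hypothetical and takes irreps as (mul, l) pairs. The 2x0e item yields two values, and the 1e and 2e items yield one each.

```python
import numpy as np

def norm_sketch(array, irreps, squared=False, per_irrep=True):
    """Norm of every irrep copy (per_irrep=True) or of the whole last axis."""
    if not per_irrep:
        sq = np.sum(array**2, axis=-1, keepdims=True)
    else:
        parts, start = [], 0
        for mul, l in irreps:
            d = 2 * l + 1
            stop = start + mul * d
            chunk = array[..., start:stop].reshape(array.shape[:-1] + (mul, d))
            parts.append(np.sum(chunk**2, axis=-1))  # one value per copy: (..., mul)
            start = stop
        sq = np.concatenate(parts, axis=-1)
    return sq if squared else np.sqrt(sq)

x = np.arange(10.0)
irreps = [(2, 0), (1, 1), (1, 2)]
print(norm_sketch(x, irreps, squared=True))     # [  0.   1.  29. 255.]
print(norm_sketch(x, irreps, per_irrep=False))  # sqrt(285), about 16.88
```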
e3nn_jax.dot(a: IrrepsArray, b: IrrepsArray, per_irrep: bool = False) IrrepsArray

Dot product of two IrrepsArray.

Parameters:
  • a (IrrepsArray) – first array (this array gets complex conjugated)

  • b (IrrepsArray) – second array

  • per_irrep (bool) – if True, return the dot product of each irrep individually

Returns:

dot product of the two input arrays, as a scalar (or one scalar per irrep if per_irrep is True)

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("0e + 1e", jnp.array([1.0j, 1.0, 0.0, 0.0]))
>>> y = e3nn.IrrepsArray("0e + 1e", jnp.array([1.0, 2.0, 1.0, 1.0]))
>>> e3nn.dot(x, y)
1x0e [2.-1.j]
>>> e3nn.dot(x, y, per_irrep=True)
1x0e+1x0e [0.-1.j 2.+0.j]
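
Here is a NumPy sketch of the same computation, using a hypothetical helper with irreps as (mul, l) pairs. With per_irrep=True it returns one scalar per irrep copy. For multiplicities above 1 this is an assumption modelled on norm, because the example above only uses single copies.

```python
import numpy as np

def dot_sketch(a, b, irreps, per_irrep=False):
    """Dot product conj(a) . b, in total or per irrep copy; irreps are (mul, l) pairs."""
    prod = np.conj(a) * b
    if not per_irrep:
        return prod.sum(axis=-1, keepdims=True)  # a single scalar
    out, start = [], 0
    for mul, l in irreps:
        d = 2 * l + 1
        stop = start + mul * d
        # One scalar per copy: shape (..., mul)
        out.append(prod[..., start:stop].reshape(prod.shape[:-1] + (mul, d)).sum(axis=-1))
        start = stop
    return np.concatenate(out, axis=-1)

x = np.array([1.0j, 1.0, 0.0, 0.0])
y = np.array([1.0, 2.0, 1.0, 1.0])
print(dot_sketch(x, y, [(1, 0), (1, 1)]))                  # one value: 2 - 1j
print(dot_sketch(x, y, [(1, 0), (1, 1)], per_irrep=True))  # -1j for 0e, 2 for 1e
```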
e3nn_jax.cross(a: IrrepsArray, b: IrrepsArray) IrrepsArray

Cross product of two IrrepsArray.

Parameters:
  • a (IrrepsArray) – first array

  • b (IrrepsArray) – second array

Returns:

cross product of the two input arrays

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0]))
>>> y = e3nn.IrrepsArray("1e", jnp.array([0.0, 1.0, 0.0]))
>>> e3nn.cross(x, y)
1x1o [0. 0. 1.]
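
Only the parity needs care here. The components combine as an ordinary cross product, and the output parity is the product of the input parities. So two 1o vectors give a 1e pseudovector, while 1o x 1e gives 1o as above. Below is a minimal NumPy sketch with a hypothetical helper and parities written as +1 (e) or -1 (o). It assumes that the components of an l = 1 irrep form a right-handed Cartesian frame.

```python
import numpy as np

def cross_sketch(a, parity_a, b, parity_b):
    """Cross product of two l = 1 irreps; parities are +1 (even, "e") or -1 (odd, "o")."""
    # Components combine like ordinary 3D vectors; the parities multiply.
    return np.cross(a, b), parity_a * parity_b

v, p = cross_sketch(np.array([1.0, 0.0, 0.0]), -1, np.array([0.0, 1.0, 0.0]), +1)
print(v, "1o" if p == -1 else "1e")  # [0. 0. 1.] 1o
```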
e3nn_jax.normal(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], key: Array = None, leading_shape: Tuple[int, ...] = (), *, normalize: bool = False, normalization: str | None = None, dtype: dtype | None = None) IrrepsArray

Random array with normal distribution.

Parameters:
  • irreps (Irreps) – irreps of the output array

  • key (jax.Array) – random key (if not provided, the hash of the irreps is used as seed, which is useful for debugging)

  • leading_shape (tuple of int) – shape of the leading dimensions

  • normalize (bool) – if True, normalize the output array

  • normalization (str) – normalization of the output array, "component" or "norm". This parameter is ignored if normalize=True, and it only affects the variance of the distribution.

Returns:

random array

Return type:

IrrepsArray

Examples

>>> jnp.set_printoptions(precision=2, suppress=True)
>>> e3nn.normal("1o").shape
(3,)

Generate a random array with normalization "component"

>>> x = e3nn.normal("0e + 5e", jax.random.PRNGKey(0), (), normalization="component")
>>> x
1x0e+1x5e [ 1.19 -1.1   0.44  0.6  -0.39  0.69  0.46 -2.07 -0.21 -0.99 -0.68  0.27]
>>> e3nn.norm(x, squared=True)
1x0e+1x0e [1.42 8.45]

Generate a random array with normalization "norm"

>>> x = e3nn.normal("0e + 5e", jax.random.PRNGKey(0), (), normalization="norm")
>>> x
1x0e+1x5e [-1.25  0.11 -0.24 -0.4   0.37  0.07  0.15 -0.38  0.35 -0.4   0.03 -0.18]
>>> e3nn.norm(x, squared=True)
1x0e+1x0e [1.57 0.85]

Generate normalized random array

>>> x = e3nn.normal("0e + 5e", jax.random.PRNGKey(0), (), normalize=True)
>>> x
1x0e+1x5e [-1.    0.12 -0.26 -0.43  0.4   0.08  0.16 -0.41  0.37 -0.44  0.03 -0.19]
>>> e3nn.norm(x, squared=True)
1x0e+1x0e [1. 1.]
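
The three modes can be summarised as follows. "component" draws every component with unit variance. "norm" rescales each irrep by 1 / sqrt(2l + 1), so its expected squared norm is 1. normalize=True divides each irrep by its own norm. This reading is consistent with the squared norms printed above. The NumPy sketch below (hypothetical helper normal_sketch, one irrep per entry of ls) is an illustration, not the library's code.

```python
import numpy as np

def normal_sketch(rng, ls, normalization="component", normalize=False):
    """Draw one irrep of degree l for each entry of `ls` and concatenate them."""
    parts = []
    for l in ls:
        x = rng.standard_normal(2 * l + 1)  # "component": unit variance per component
        if normalize:
            x = x / np.linalg.norm(x)  # every irrep gets norm exactly 1
        elif normalization == "norm":
            x = x / np.sqrt(2 * l + 1)  # expected squared norm 1 per irrep
        elif normalization != "component":
            raise ValueError(f"unknown normalization {normalization!r}")
        parts.append(x)
    return np.concatenate(parts)

rng = np.random.default_rng(0)
x = normal_sketch(rng, [0, 5], normalize=True)
print(np.sum(x[:1] ** 2), np.sum(x[1:] ** 2))  # both equal to 1 up to rounding
```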
e3nn_jax.sum(array: IrrepsArray, axis: None | int | Tuple[int, ...] = None, keepdims: bool = False) IrrepsArray

Sum of IrrepsArray along the specified axis.

Parameters:
  • array (IrrepsArray) – input array

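  • axis (optional int or tuple of ints) – axis along which the sum is computed; if None, the sum is taken over all axes. When the last (irreps) axis is reduced, each item mul x ir is summed over its copies and becomes 1 x ir.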

Returns:

sum of the input array

Return type:

IrrepsArray

Examples

>>> x = e3nn.IrrepsArray("3x0e + 2x0e", jnp.arange(2 * 5).reshape(2, 5))
>>> e3nn.sum(x, axis=0)
3x0e+2x0e [ 5  7  9 11 13]
>>> e3nn.sum(x, axis=1)
1x0e+1x0e
[[ 3  7]
 [18 17]]
>>> e3nn.sum(x)
1x0e+1x0e [21 24]
>>> e3nn.sum(x.regroup())
1x0e [45]
e3nn_jax.where(mask: Array, x: IrrepsArray, y: IrrepsArray)

Selects elements from x or y, depending on mask.

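For a mask of shape (..., 1), this is equivalent to:

>>> e3nn.IrrepsArray(x.irreps, jnp.where(mask, x.array, y.array))

Parameters:
  • mask – Boolean array of shape (..., num_irreps) or (..., 1). With num_irreps entries, each entry selects one whole irrep (all of its components).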

  • x – IrrepsArray of shape (..., irreps.dim).

  • y – IrrepsArray of shape (..., irreps.dim).

Returns:

IrrepsArray of shape (..., irreps.dim).
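
With a mask of shape (..., num_irreps), each entry selects a whole irrep copy. The mask therefore has to be expanded to the irreps dimension before an element-wise select. The NumPy sketch below follows that reading. It uses a hypothetical helper with irreps as (mul, l) pairs, and num_irreps is the sum of the multiplicities.

```python
import numpy as np

def where_sketch(mask, x, y, irreps):
    """Select whole irrep copies from x or y; mask has shape (..., num_irreps) or (..., 1)."""
    if mask.shape[-1] != 1:
        # One flag per irrep copy: repeat each flag over that copy's 2l + 1 components.
        repeats = [2 * l + 1 for mul, l in irreps for _ in range(mul)]
        mask = np.repeat(mask, repeats, axis=-1)
    return np.where(mask, x, y)

irreps = [(2, 0), (1, 1)]  # num_irreps = 3, dim = 5
x = np.arange(5)
y = -np.arange(5)
print(where_sketch(np.array([True, False, True]), x, y, irreps))  # [ 0 -1  2  3  4]
```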