IrrepsArray
- class e3nn_jax.IrrepsArray(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], array: Array, *, zero_flags=None, chunks: List[Array | None] | None = None)[source]
Bases: object
Array with a representation of rotations.
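The IrrepsArray class enforces equivariance by storing an array of data (.array) along with its representation (.irreps).
The data is stored as a single array of shape (..., irreps.dim).
The data can also be accessed as a list of arrays (.chunks), one for each item of .irreps.
- Parameters:
irreps (Irreps) – representation of the data
array (jax.Array) – the data, an array of shape (..., irreps.dim)
zero_flags (optional) – flags marking which chunks are known to be zero
chunks (list of jax.Array or None, optional) – the data given chunk by chunk; None stands for a chunk of zeros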
Examples
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("1o + 2x0e", jnp.ones(5))
>>> y = e3nn.from_chunks("1o + 2x0e", [None, jnp.ones((2, 1))], ())
>>> x + y
1x1o+2x0e [1. 1. 1. 2. 2.]
Example of indexing:
>>> x = IrrepsArray("0e + 1o", jnp.arange(2 * 4).reshape(2, 4))
>>> x[0]
1x0e+1x1o [0 1 2 3]
>>> x[1, "0e"]
1x0e [4]
>>> x[:, 1:]
1x1o [[1 2 3]
 [5 6 7]]
>>> IrrepsArray("5x0e", jnp.arange(5))[1:3]
2x0e [1 2]
Methods:
astype(dtype) – Change the dtype of the array.
axis_to_irreps([axis]) – Repeat the irreps by the last axis of the array.
axis_to_mul([axis]) – Repeat the multiplicity by the previous last axis of the array.
broadcast_to(shape) – Broadcast the array to a new shape.
extend_with_zeros(new_irreps) – Extend the array with zeros.
filter([keep, drop, lmax]) – Filter the irreps.
mul_to_axis([factor, axis]) – Create a new axis in the previous last position by factoring the multiplicities.
rechunk(irreps) – Rechunk the array with new (equivalent) irreps.
regroup() – Regroup the same irreps together.
remove_zero_chunks() – Remove all zero chunks.
repeat_irreps_by_last_axis([axis]) – Repeat the irreps by the last axis of the array.
reshape(shape) – Reshape the array.
simplify() – Simplify the irreps.
sort() – Sort the irreps.
transform_by_angles(alpha, beta, gamma[, k, inverse]) – Rotate the data by angles according to the irreps.
transform_by_axis_angle(axis, angle[, k]) – Rotate data by a rotation given by an axis and an angle.
transform_by_log_coordinates(log_coordinates[, k]) – Rotate data by a rotation given by log coordinates.
transform_by_matrix(R) – Rotate data by a rotation given by a matrix.
transform_by_quaternion(q[, k]) – Rotate data by a rotation given by a quaternion.
unify() – Unify the irreps.
Attributes:
chunks – List of arrays matching each item of the .irreps.
dtype – dtype.
Number of dimensions.
shape – Shape.
Return the slice with respect to the chunks.
Same as __getitem__ in the irreps dimension.
Return the slice with respect to the multiplicities.
- astype(dtype) IrrepsArray [source]
Change the dtype of the array.
- Parameters:
dtype (dtype) – new dtype
- Returns:
new IrrepsArray
- Return type:
IrrepsArray
- axis_to_irreps(axis: int = -2) IrrepsArray [source]
Repeat the irreps by the last axis of the array.
Examples
>>> x = IrrepsArray("0e + 1e", jnp.arange(2 * 4).reshape(2, 4))
>>> x.axis_to_irreps()
1x0e+1x1e+1x0e+1x1e [0 1 2 3 4 5 6 7]
- axis_to_mul(axis: int = -2) IrrepsArray [source]
Repeat the multiplicity by the previous last axis of the array, i.e. merge that axis (axis=-2 by default) into the multiplicities.
Decrease the number of dimensions of the array by 1.
- Parameters:
axis (int) – axis to convert into multiplicity
Examples
>>> x = IrrepsArray("0e + 1e", jnp.arange(2 * 4).reshape(2, 4))
>>> x.axis_to_mul()
2x0e+2x1e [0 4 1 2 3 5 6 7]
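The reordering in the example above can be reproduced with plain NumPy. This is a minimal sketch of the data layout only, not the library's implementation: the helper axis_to_mul_sketch and the representation of irreps as (mul, dim) pairs are hypothetical, and only the default axis=-2 is handled.

```python
import numpy as np


def axis_to_mul_sketch(array, irreps):
    """Merge axis -2 into the multiplicities (hypothetical helper).

    array: shape (..., A, irreps_dim); irreps: list of (mul, dim) pairs.
    Returns the flat array of shape (..., A * irreps_dim) and the new irreps.
    """
    *leading, a, total = array.shape
    assert total == sum(mul * dim for mul, dim in irreps), "irreps do not match the last axis"
    out, new_irreps, start = [], [], 0
    for mul, dim in irreps:
        chunk = array[..., start:start + mul * dim]  # shape (..., A, mul * dim)
        # Row-major flattening: all copies from row 0 come first, then row 1, ...
        out.append(chunk.reshape(*leading, a * mul * dim))
        new_irreps.append((a * mul, dim))
        start += mul * dim
    return np.concatenate(out, axis=-1), new_irreps


x = np.arange(2 * 4).reshape(2, 4)  # "0e + 1e" -> [(1, 1), (1, 3)]
flat, irreps = axis_to_mul_sketch(x, [(1, 1), (1, 3)])
print(flat, irreps)  # [0 4 1 2 3 5 6 7] [(2, 1), (2, 3)]
```

Each chunk is flattened row by row, so the copies coming from row 0 precede those from row 1, which is why the 2x0e part reads [0 4].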
- broadcast_to(shape) IrrepsArray [source]
Broadcast the array to a new shape.
- property chunks: List[Array | None][source]
List of arrays matching each item of the .irreps.
Examples
>>> x = IrrepsArray("2x0e + 0e", jnp.arange(3))
>>> len(x.chunks)
2
>>> x.chunks[0]
Array([[0],
       [1]], dtype=int32)
>>> x.chunks[1]
Array([[2]], dtype=int32)
The following is always true (a chunk that is None stands for zeros):
>>> all(e is None or e.shape == x.shape[:-1] + (mul, ir.dim) for (mul, ir), e in zip(x.irreps, x.chunks))
True
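To make the relation between .array and .chunks concrete, the following NumPy sketch splits a flat array into one (..., mul, ir.dim) block per irreps item. It assumes irreps are given as hypothetical (mul, dim) pairs and ignores zero (None) chunks, which the library tracks separately.

```python
import numpy as np


def split_into_chunks(array, irreps):
    """Split a flat (..., irreps_dim) array into one (..., mul, dim) array per irreps item.

    irreps is a hypothetical list of (mul, dim) pairs, e.g. "2x0e + 0e" -> [(2, 1), (1, 1)].
    """
    chunks, start = [], 0
    for mul, dim in irreps:
        block = array[..., start:start + mul * dim]
        chunks.append(block.reshape(array.shape[:-1] + (mul, dim)))
        start += mul * dim
    assert start == array.shape[-1], "irreps do not match the last axis"
    return chunks


chunks = split_into_chunks(np.arange(3), [(2, 1), (1, 1)])
print(chunks[0].tolist())  # [[0], [1]]
print(chunks[1].tolist())  # [[2]]
```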
- extend_with_zeros(new_irreps: Irreps) IrrepsArray [source]
Extend the array with zeros.
- Parameters:
new_irreps (Irreps) – new irreps, must be a superset of the current irreps
Examples
>>> IrrepsArray("0e + 1o", jnp.array([1, 3, 3, 3])).extend_with_zeros("0e + 0e + 1o + 2x0e")
1x0e+1x0e+1x1o+2x0e [1 0 3 3 3 0 0]
- filter(keep: Irreps | List[Irrep] | Callable[[MulIrrep], bool] = None, *, drop: Irreps | List[Irrep] | Callable[[MulIrrep], bool] = None, lmax: int = None) IrrepsArray [source]
Filter the irreps.
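- Parameters:
keep (Irreps, list of Irrep, or callable) – irreps to keep; a callable receives each MulIrrep and returns True to keep it
drop (Irreps, list of Irrep, or callable) – irreps to drop
lmax (int) – drop irreps with l greater than lmax
Examples
>>> IrrepsArray("0e + 2x1o + 2x0e", jnp.arange(9)).filter(["1o"])
2x1o [1 2 3 4 5 6]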
- mul_to_axis(factor: int | None = None, axis: int = -2) IrrepsArray [source]
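Create a new axis in the previous last position by factoring the multiplicities.
Increase the number of dimensions of the array by 1.
- Parameters:
factor (int, optional) – size of the new axis; every multiplicity must be divisible by it
axis (int) – position of the new axis (default -2)
Examples
>>> x = IrrepsArray("6x0e + 3x1e", jnp.arange(15))
>>> x.mul_to_axis()
2x0e+1x1e [[ 0 1 6 7 8]
 [ 2 3 9 10 11]
 [ 4 5 12 13 14]]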
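A NumPy sketch reproducing the example above. The helper is hypothetical and only handles axis=-2; defaulting factor to the greatest common divisor of the multiplicities is an assumption that happens to match the example (gcd(6, 3) = 3).

```python
import math
from functools import reduce

import numpy as np


def mul_to_axis_sketch(array, irreps, factor=None):
    """Factor the multiplicities into a new axis at position -2 (hypothetical helper).

    array: shape (..., irreps_dim); irreps: list of (mul, dim) pairs.
    factor: size of the new axis; defaults here (an assumption) to the gcd of the multiplicities.
    """
    if factor is None:
        factor = reduce(math.gcd, (mul for mul, _ in irreps))
    leading = array.shape[:-1]
    out, new_irreps, start = [], [], 0
    for mul, dim in irreps:
        assert mul % factor == 0, "every multiplicity must be divisible by factor"
        chunk = array[..., start:start + mul * dim]
        # Copy c (0 <= c < mul) lands in row c // (mul // factor) of the new axis.
        out.append(chunk.reshape(leading + (factor, (mul // factor) * dim)))
        new_irreps.append((mul // factor, dim))
        start += mul * dim
    return np.concatenate(out, axis=-1), new_irreps


arr, irreps = mul_to_axis_sketch(np.arange(15), [(6, 1), (3, 3)])  # "6x0e + 3x1e"
print(irreps)  # [(2, 1), (1, 3)]
print(arr)
```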
- rechunk(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]]) IrrepsArray [source]
Rechunk the array with new (equivalent) irreps.
- Parameters:
irreps (Irreps) – new irreps
- Returns:
new IrrepsArray
- Return type:
IrrepsArray
Examples
>>> x = e3nn.from_chunks("6x0e + 4x0e", [None, jnp.ones((4, 1))], ())
>>> x.rechunk("5x0e + 5x0e").chunks
[None, Array([[0.],
       [1.],
       [1.],
       [1.],
       [1.]], dtype=float32)]
- regroup() IrrepsArray [source]
Regroup the same irreps together.
Equivalent to sort() followed by simplify().
Examples
>>> IrrepsArray("0e + 1o + 2x0e", jnp.arange(6)).regroup()
3x0e+1x1o [0 4 5 1 2 3]
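A NumPy sketch of regroup as a stable sort followed by merging neighbours, which reproduces the permutation [0 4 5 1 2 3] above. The (mul, l, p) representation and the sort key (l first, then the parity (-1)**l before the other one) are assumptions made for illustration; for this example any ordering by l gives the same result.

```python
import numpy as np


def regroup_sketch(array, irreps):
    """Stable-sort the irreps items, then merge neighbours with the same irrep.

    irreps: hypothetical list of (mul, l, p) with p = +1 ("e") or -1 ("o").
    Ordering key (an assumption): l first, then natural parity (-1)**l before the other.
    """
    dims = [mul * (2 * l + 1) for mul, l, p in irreps]
    starts = np.cumsum([0] + dims[:-1])
    # sorted() is stable, so equal irreps keep their relative order.
    order = sorted(range(len(irreps)), key=lambda i: (irreps[i][1], -irreps[i][2] * (-1) ** irreps[i][1]))
    sorted_array = np.concatenate([array[..., starts[i]:starts[i] + dims[i]] for i in order], axis=-1)
    merged = []
    for i in order:
        mul, l, p = irreps[i]
        if merged and merged[-1][1:] == (l, p):
            merged[-1] = (merged[-1][0] + mul, l, p)
        else:
            merged.append((mul, l, p))
    return sorted_array, merged


# "0e + 1o + 2x0e"
arr, irreps = regroup_sketch(np.arange(6), [(1, 0, 1), (1, 1, -1), (2, 0, 1)])
print(arr, irreps)  # [0 4 5 1 2 3] [(3, 0, 1), (1, 1, -1)]
```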
- remove_zero_chunks() IrrepsArray [source]
Remove all zero chunks.
- repeat_irreps_by_last_axis(axis: int = -2) IrrepsArray [source]
Repeat the irreps by the last axis of the array.
Examples
>>> x = IrrepsArray("0e + 1e", jnp.arange(2 * 4).reshape(2, 4))
>>> x.axis_to_irreps()
1x0e+1x1e+1x0e+1x1e [0 1 2 3 4 5 6 7]
- reshape(shape) IrrepsArray [source]
Reshape the array.
- Parameters:
shape (tuple) – new shape
- Returns:
new IrrepsArray
- Return type:
IrrepsArray
Examples
>>> IrrepsArray("2x0e + 1o", jnp.ones((6, 5))).reshape((2, 3, 5))
2x0e+1x1o [[[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]
 [[1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]
  [1. 1. 1. 1. 1.]]]
- simplify() IrrepsArray [source]
Simplify the irreps.
Examples
>>> IrrepsArray("0e + 0e + 0e", jnp.ones(3)).simplify()
3x0e [1. 1. 1.]
>>> IrrepsArray("0e + 0x1e + 0e", jnp.ones(2)).simplify()
2x0e [1. 1.]
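As the examples suggest, simplify does not reorder items: it drops zero multiplicities and merges neighbouring equal irreps, so the flat data is unchanged and only the irreps description shrinks. A sketch of that rule on a hypothetical list of (mul, irrep-name) pairs:

```python
def simplify_sketch(irreps):
    """Drop zero multiplicities, then merge consecutive equal irreps (hypothetical helper)."""
    out = []
    for mul, ir in irreps:
        if mul == 0:
            continue  # a 0xir item contributes no data
        if out and out[-1][1] == ir:
            out[-1] = (out[-1][0] + mul, ir)  # merge with the previous item
        else:
            out.append((mul, ir))
    return out


print(simplify_sketch([(1, "0e"), (1, "0e"), (1, "0e")]))  # [(3, '0e')]
print(simplify_sketch([(1, "0e"), (0, "1e"), (1, "0e")]))  # [(2, '0e')]
```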
- sort() IrrepsArray [source]
Sort the irreps.
Examples
>>> IrrepsArray("0e + 1o + 2x0e", jnp.arange(6)).sort()
1x0e+2x0e+1x1o [0 4 5 1 2 3]
- transform_by_angles(alpha: float, beta: float, gamma: float, k: int = 0, inverse: bool = False) IrrepsArray [source]
Rotate the data by angles according to the irreps.
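- Parameters:
alpha (float) – first Euler angle, in radians
beta (float) – second Euler angle, in radians
gamma (float) – third Euler angle, in radians
k (int) – number of times the parity is applied
inverse (bool) – if True, apply the inverse rotation
- Returns:
rotated data
- Return type:
IrrepsArray
Examples
>>> import numpy as np
>>> np.set_printoptions(precision=3, suppress=True)
>>> x = IrrepsArray("2e", jnp.array([0.1, 2, 1.0, 1, 1]))
>>> x.transform_by_angles(jnp.pi, 0, 0)
1x2e [ 0.1 -2. 1. -1. 1. ]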
- transform_by_axis_angle(axis: Array, angle: float, k: int = 0) IrrepsArray [source]
Rotate data by a rotation given by an axis and an angle.
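- Parameters:
axis (jax.Array) – rotation axis
angle (float) – rotation angle, in radians
k (int) – number of times the parity is applied
- Returns:
rotated data
- Return type:
IrrepsArray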
- transform_by_log_coordinates(log_coordinates: Array, k: int = 0) IrrepsArray [source]
Rotate data by a rotation given by log coordinates.
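- Parameters:
log_coordinates (jax.Array) – log coordinates of the rotation
k (int) – number of times the parity is applied
- Returns:
rotated data
- Return type:
IrrepsArray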
- transform_by_matrix(R: Array) IrrepsArray [source]
Rotate data by a rotation given by a matrix.
- Parameters:
R (jax.Array) – rotation matrix
- Returns:
rotated data
- Return type:
IrrepsArray
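For l = 0 the effect of a matrix is easy to state: a scalar (0e) is invariant, while a pseudoscalar (0o) is multiplied by det(R), i.e. it flips sign under an improper rotation such as the inversion -I. A NumPy sketch of just this case; the helper name is hypothetical, and higher l require Wigner D-matrices, which are not reproduced here.

```python
import numpy as np


def transform_scalars_by_matrix(values, parities, R):
    """Apply a rotation or improper rotation R to l = 0 data (hypothetical helper).

    values: shape (..., n), one entry per 0e/0o copy; parities: +1 for 0e, -1 for 0o.
    A 0e scalar is invariant; a 0o pseudoscalar picks up the sign of det(R).
    """
    det_sign = np.sign(np.linalg.det(R))  # +1 for proper rotations, -1 for improper ones
    factors = np.where(np.asarray(parities) == 1, 1.0, det_sign)
    return values * factors


values = np.array([1.0, 2.0, 3.0])  # "2x0e + 0o"
print(transform_scalars_by_matrix(values, [1, 1, -1], -np.eye(3)))  # [ 1.  2. -3.]
```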
- transform_by_quaternion(q: Array, k: int = 0) IrrepsArray [source]
Rotate data by a rotation given by a quaternion.
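- Parameters:
q (jax.Array) – quaternion describing the rotation
k (int) – number of times the parity is applied
- Returns:
rotated data
- Return type:
IrrepsArray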
- unify() IrrepsArray [source]
Unify the irreps.
Examples
>>> IrrepsArray("0e + 0x1e + 0e", jnp.ones(2)).unify()
1x0e+0x1e+1x0e [1. 1.]
- e3nn_jax.from_chunks(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], chunks: List[Array | None], leading_shape: Tuple[int, ...], dtype=None, *, backend=None) IrrepsArray [source]
Create an IrrepsArray from a list of arrays.
- e3nn_jax.as_irreps_array(array: Array | IrrepsArray, *, backend=None)[source]
Convert an array to an IrrepsArray.
- Parameters:
array (jax.Array or IrrepsArray) – array to convert
- Returns:
IrrepsArray
- e3nn_jax.zeros(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], leading_shape: Tuple = (), dtype: dtype = None) IrrepsArray [source]
Create an IrrepsArray of zeros.
- e3nn_jax.zeros_like(irreps_array: IrrepsArray) IrrepsArray [source]
Create an IrrepsArray of zeros with the same shape as another IrrepsArray.
- e3nn_jax.concatenate(arrays: List[IrrepsArray], axis: int = -1) IrrepsArray [source]
Concatenate a list of IrrepsArray.
- Parameters:
arrays (list of IrrepsArray) – list of data to concatenate
axis (int) – axis to concatenate on
- Returns:
concatenated array
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("3x0e + 2x0o", jnp.arange(2 * 5).reshape(2, 5))
>>> y = e3nn.IrrepsArray("3x0e + 2x0o", jnp.arange(2 * 5).reshape(2, 5) + 10)
>>> e3nn.concatenate([x, y], axis=0)
3x0e+2x0o [[ 0 1 2 3 4]
 [ 5 6 7 8 9]
 [10 11 12 13 14]
 [15 16 17 18 19]]
>>> e3nn.concatenate([x, y], axis=1)
3x0e+2x0o+3x0e+2x0o [[ 0 1 2 3 4 10 11 12 13 14]
 [ 5 6 7 8 9 15 16 17 18 19]]
- e3nn_jax.mean(array: IrrepsArray, axis: None | int | Tuple[int, ...] = None, keepdims: bool = False) IrrepsArray [source]
Mean of IrrepsArray along the specified axis.
- Parameters:
array (IrrepsArray) – input array
axis (optional int or tuple of ints) – axis along which the mean is computed
- Returns:
mean of the input array
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("3x0e + 2x0e", jnp.arange(2 * 5).reshape(2, 5))
>>> e3nn.mean(x, axis=0)
3x0e+2x0e [2.5 3.5 4.5 5.5 6.5]
>>> e3nn.mean(x, axis=1)
1x0e+1x0e [[1. 3.5]
 [6. 8.5]]
>>> e3nn.mean(x)
1x0e+1x0e [3.5 6. ]
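The axis=1 example averages each irreps item over its multiplicity, leaving one copy per item. A NumPy sketch of that reduction, assuming irreps given as hypothetical (mul, dim) pairs; the helper name is not part of the library.

```python
import numpy as np


def mean_over_irreps_axis(array, irreps):
    """Average each irreps item over its multiplicity (hypothetical helper).

    irreps: list of (mul, dim) pairs; each item is reduced to a single copy of shape (..., dim).
    """
    out, start = [], 0
    for mul, dim in irreps:
        chunk = array[..., start:start + mul * dim].reshape(array.shape[:-1] + (mul, dim))
        out.append(chunk.mean(axis=-2))  # mean over the mul copies
        start += mul * dim
    return np.concatenate(out, axis=-1)


x = np.arange(2 * 5).reshape(2, 5)  # "3x0e + 2x0e"
r = mean_over_irreps_axis(x, [(3, 1), (2, 1)])
print(r)                # [[1.  3.5] [6.  8.5]]
print(r.mean(axis=0))   # [3.5 6. ], matching e3nn.mean(x) above
```

Averaging copies of the same irrep is a linear combination of equivariant quantities, so the result stays equivariant; mixing components of different irreps would not.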
- e3nn_jax.norm(array: IrrepsArray, *, squared: bool = False, per_irrep: bool = True) IrrepsArray [source]
Norm of IrrepsArray.
- Parameters:
array (IrrepsArray) – input array
squared (bool) – if True, return the squared norm
per_irrep (bool) – if True, return the norm of each irrep individually
- Returns:
norm of the input array
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("2x0e + 1e + 2e", jnp.arange(10.0))
>>> e3nn.norm(x)
2x0e+1x0e+1x0e [ 0. 1. 5.3851647 15.9687195]
>>> e3nn.norm(x, squared=True)
2x0e+1x0e+1x0e [ 0. 1. 29. 255.]
>>> e3nn.norm(x, per_irrep=False)
1x0e [16.881943]
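The numbers above can be checked by hand: the 1e block is [2, 3, 4] with squared norm 29, and the 2e block is [5, 6, 7, 8, 9] with squared norm 255. A NumPy sketch for real-valued data, with irreps as hypothetical (mul, dim) pairs and a helper name that is not part of the library:

```python
import numpy as np


def norms_per_irrep(array, irreps, squared=False):
    """One (squared) norm per irrep copy for real data (hypothetical helper)."""
    out, start = [], 0
    for mul, dim in irreps:
        chunk = array[..., start:start + mul * dim].reshape(array.shape[:-1] + (mul, dim))
        sq = (chunk ** 2).sum(axis=-1)  # shape (..., mul)
        out.append(sq if squared else np.sqrt(sq))
        start += mul * dim
    return np.concatenate(out, axis=-1)


x = np.arange(10.0)  # "2x0e + 1e + 2e"
irreps = [(2, 1), (1, 3), (1, 5)]
print(norms_per_irrep(x, irreps, squared=True))             # [  0.   1.  29. 255.]
print(np.sqrt(norms_per_irrep(x, irreps, squared=True).sum()))  # sqrt(285), the per_irrep=False value
```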
- e3nn_jax.dot(a: IrrepsArray, b: IrrepsArray, per_irrep: bool = False) IrrepsArray [source]
Dot product of two IrrepsArray.
- Parameters:
a (IrrepsArray) – first array (this array gets complex conjugated)
b (IrrepsArray) – second array
per_irrep (bool) – if True, return the dot product of each irrep individually
- Returns:
dot product of the two input arrays, as scalars: a single one, or one per irrep if per_irrep=True
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("0e + 1e", jnp.array([1.0j, 1.0, 0.0, 0.0]))
>>> y = e3nn.IrrepsArray("0e + 1e", jnp.array([1.0, 2.0, 1.0, 1.0]))
>>> e3nn.dot(x, y)
1x0e [2.-1.j]
>>> e3nn.dot(x, y, per_irrep=True)
1x0e+1x0e [0.-1.j 2.+0.j]
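The conjugation convention matches numpy.vdot, which conjugates its first argument. A NumPy check of the example above, splitting the 0e + 1e layout by hand (entry 0 is the scalar, entries 1 to 3 the vector):

```python
import numpy as np

a = np.array([1.0j, 1.0, 0.0, 0.0])  # "0e + 1e", gets conjugated
b = np.array([1.0, 2.0, 1.0, 1.0])   # "0e + 1e"

# One dot product per irrep: conj(1j) * 1 = -1j for 0e, 1 * 2 = 2 for 1e.
per_irrep = np.array([np.vdot(a[0:1], b[0:1]), np.vdot(a[1:4], b[1:4])])
total = np.vdot(a, b)  # sum of the per-irrep values: 2 - 1j
print(per_irrep, total)
```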
- e3nn_jax.cross(a: IrrepsArray, b: IrrepsArray) IrrepsArray [source]
Cross product of two IrrepsArray.
- Parameters:
a (IrrepsArray) – first array of vectors
b (IrrepsArray) – second array of vectors
- Returns:
cross product of the two input arrays
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("1o", jnp.array([1.0, 0.0, 0.0]))
>>> y = e3nn.IrrepsArray("1e", jnp.array([0.0, 1.0, 0.0]))
>>> e3nn.cross(x, y)
1x1o [0. 0. 1.]
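The output irreps in the example follow from the parities: the cross product of two l = 1 arrays is again l = 1, with parity equal to the product of the input parities (odd times even gives odd, hence 1o). A NumPy sketch; the helper and its parity encoding (+1 for e, -1 for o) are hypothetical.

```python
import numpy as np


def cross_sketch(a, pa, b, pb):
    """Cross product of two l = 1 arrays with parities pa, pb (+1 for "e", -1 for "o").

    Returns the vector part and the parity of the result, pa * pb; two polar
    vectors (1o, 1o) therefore give an axial vector (1e).
    """
    return np.cross(a, b), pa * pb


v, p = cross_sketch([1.0, 0.0, 0.0], -1, [0.0, 1.0, 0.0], 1)
print(v, "1o" if p == -1 else "1e")  # [0. 0. 1.] 1o
```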
- e3nn_jax.normal(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]], key: Array = None, leading_shape: Tuple[int, ...] = (), *, normalize: bool = False, normalization: str | None = None, dtype: dtype | None = None) IrrepsArray [source]
Random array with normal distribution.
- Parameters:
irreps (Irreps) – irreps of the output array
key (jax.Array) – random key (if not provided, the hash of the irreps is used as the seed; useful for debugging)
leading_shape (tuple of int) – shape of the leading dimensions
normalize (bool) – if True, normalize the output array
normalization (str) – normalization of the output array, "component" or "norm". This parameter is ignored if normalize=True. It only affects the variance of the distribution.
- Returns:
random array
- Return type:
IrrepsArray
Examples
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> e3nn.normal("1o").shape
(3,)
Generate a random array with normalization "component":
>>> x = e3nn.normal("0e + 5e", jax.random.PRNGKey(0), (), normalization="component")
>>> x
1x0e+1x5e [ 1.19 -1.1 0.44 0.6 -0.39 0.69 0.46 -2.07 -0.21 -0.99 -0.68 0.27]
>>> e3nn.norm(x, squared=True)
1x0e+1x0e [1.42 8.45]
Generate a random array with normalization "norm":
>>> x = e3nn.normal("0e + 5e", jax.random.PRNGKey(0), (), normalization="norm")
>>> x
1x0e+1x5e [-1.25 0.11 -0.24 -0.4 0.37 0.07 0.15 -0.38 0.35 -0.4 0.03 -0.18]
>>> e3nn.norm(x, squared=True)
1x0e+1x0e [1.57 0.85]
Generate a normalized random array:
>>> x = e3nn.normal("0e + 5e", jax.random.PRNGKey(0), (), normalize=True)
>>> x
1x0e+1x5e [-1. 0.12 -0.26 -0.43 0.4 0.08 0.16 -0.41 0.37 -0.44 0.03 -0.19]
>>> e3nn.norm(x, squared=True)
1x0e+1x0e [1. 1.]
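The two normalizations differ only in the variance per component. The sketch below assumes, consistently with the squared norms printed above (about 8.45 for the 11-dimensional 5e block with "component", about 0.85 with "norm"), that "component" gives unit variance per component and "norm" gives unit expected squared norm per irrep. normal_sketch is a hypothetical NumPy helper, not the library's sampler.

```python
import numpy as np


def normal_sketch(rng, dims, normalization="component", normalize=False):
    """Draw one Gaussian block per irrep of dimension d (hypothetical helper).

    "component": every component has variance 1, so E[|x|^2] = d.
    "norm": components have variance 1/d, so E[|x|^2] = 1.
    normalize=True: each block is rescaled to unit norm and `normalization` is irrelevant.
    """
    if normalization not in ("component", "norm"):
        raise ValueError(f"unknown normalization {normalization!r}")
    blocks = []
    for d in dims:
        x = rng.standard_normal(d)
        if normalize:
            x = x / np.linalg.norm(x)
        elif normalization == "norm":
            x = x / np.sqrt(d)
        blocks.append(x)
    return np.concatenate(blocks)


rng = np.random.default_rng(0)
x = normal_sketch(rng, [1, 11], normalize=True)  # "0e + 5e"
print(x[0] ** 2, np.sum(x[1:] ** 2))  # both 1 up to rounding
```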
- e3nn_jax.sum(array: IrrepsArray, axis: None | int | Tuple[int, ...] = None, keepdims: bool = False) IrrepsArray [source]
Sum of IrrepsArray along the specified axis.
- Parameters:
array (IrrepsArray) – input array
axis (optional int or tuple of ints) – axis along which the sum is computed
- Returns:
sum of the input array
- Return type:
IrrepsArray
Examples
>>> x = e3nn.IrrepsArray("3x0e + 2x0e", jnp.arange(2 * 5).reshape(2, 5))
>>> e3nn.sum(x, axis=0)
3x0e+2x0e [ 5 7 9 11 13]
>>> e3nn.sum(x, axis=1)
1x0e+1x0e [[ 3 7]
 [18 17]]
>>> e3nn.sum(x)
1x0e+1x0e [21 24]
>>> e3nn.sum(x.regroup())
1x0e [45]
- e3nn_jax.where(mask: Array, x: IrrepsArray, y: IrrepsArray) IrrepsArray [source]
Selects elements from x or y, depending on mask.
Equivalent to:
>>> e3nn.IrrepsArray(x.irreps, jnp.where(mask, x.array, y.array))
- Parameters:
mask – Boolean array of shape (..., num_irreps) or (..., 1).
x – IrrepsArray of shape (..., irreps.dim).
y – IrrepsArray of shape (..., irreps.dim).
- Returns:
IrrepsArray of shape (..., irreps.dim).
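One way to read the (..., num_irreps) mask shape is one boolean per irrep copy, expanded to every component of that copy. The NumPy sketch below follows that reading, which is an assumption; the helper name and its dims argument are hypothetical.

```python
import numpy as np


def where_sketch(mask, x, y, dims):
    """Select per irrep copy between x and y (hypothetical helper).

    mask: shape (..., num_irreps) with one entry per irrep copy, or (..., 1);
    dims: dimension of each irrep copy, e.g. "2x0e + 1o" -> [1, 1, 3].
    """
    mask = np.asarray(mask, dtype=bool)
    if mask.shape[-1] != 1:
        mask = np.repeat(mask, dims, axis=-1)  # expand to shape (..., irreps_dim)
    return np.where(mask, x, y)


x = np.arange(5.0)   # "2x0e + 1o"
y = -np.arange(5.0)
print(where_sketch([True, False, True], x, y, [1, 1, 3]))  # [ 0. -1.  2.  3.  4.]
```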