Irreps

class e3nn_jax.Irrep(l: int | Irrep | MulIrrep | Tuple[int, int], p=None)

Bases: object

Irreducible representation of \(O(3)\).

This class does not contain any data; it is a structure that describes the representation. It is typically used as an argument to other classes of the library to define the input and output representations of functions.

Parameters:
  • l – non-negative integer, the degree of the representation, \(l = 0, 1, \dots\)

  • p – {1, -1}, the parity of the representation

Examples

Create a scalar representation (\(l=0\)) of even parity.

>>> Irrep(0, 1)
0e

Create a pseudotensor representation (\(l=2\)) of odd parity.

>>> Irrep(2, -1)
2o

Create a vector representation (\(l=1\)) with the parity of the spherical harmonics (\((-1)^l\), which is odd for \(l=1\)).

>>> Irrep("1y")
1o
>>> Irrep("2o").dim
5
>>> Irrep("2e") in Irrep("1o") * Irrep("1o")
True
>>> Irrep("1o") + Irrep("2o")
1x1o+1x2o
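
The notation in these examples can be mirrored in a few lines of plain Python. This is a sketch, not the e3nn_jax implementation: irreps are modeled as (l, p) pairs, and parse_irrep, irrep_dim, product_irreps and fmt are hypothetical helpers. The product rule is the standard one for \(O(3)\): \(l_1 \otimes l_2\) contains each \(l\) from \(|l_1 - l_2|\) to \(l_1 + l_2\) once, with parity \(p_1 p_2\). This is why 2e appears in 1o * 1o.

```python
from typing import List, Tuple

Irrep = Tuple[int, int]  # (l, p) with p in {1, -1}


def parse_irrep(s: str) -> Irrep:
    """Parse notation like '0e', '2o' or '1y' into (l, p)."""
    l, suffix = int(s[:-1]), s[-1]
    if suffix == "e":
        p = 1
    elif suffix == "o":
        p = -1
    elif suffix == "y":
        # 'y' means the parity of the spherical harmonics: (-1)^l
        p = (-1) ** l
    else:
        raise ValueError(f"unknown parity suffix {suffix!r}")
    return l, p


def irrep_dim(ir: Irrep) -> int:
    """Dimension 2l + 1 of the irrep."""
    return 2 * ir[0] + 1


def product_irreps(a: Irrep, b: Irrep) -> List[Irrep]:
    """Irreps contained in the tensor product a x b, each once."""
    (l1, p1), (l2, p2) = a, b
    return [(l, p1 * p2) for l in range(abs(l1 - l2), l1 + l2 + 1)]


def fmt(ir: Irrep) -> str:
    """Format (l, p) back into the 'le' / 'lo' notation."""
    return f"{ir[0]}{'e' if ir[1] == 1 else 'o'}"


one_o = parse_irrep("1o")
print([fmt(ir) for ir in product_irreps(one_o, one_o)])  # ['0e', '1e', '2e']
print(fmt(parse_irrep("1y")), irrep_dim(parse_irrep("2o")))  # 1o 5
```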

Methods:

D_from_angles(alpha, beta, gamma[, k])

Matrix \(p^k D^l(\alpha, \beta, \gamma)\).

D_from_log_coordinates(log_coordinates[, k])

Matrix \(p^k D^l(\alpha)\).

D_from_matrix(R)

Matrix of the representation.

D_from_quaternion(q[, k])

Matrix of the representation, see Irrep.D_from_angles.

generators()

Generators of the representation of \(SO(3)\).

is_scalar()

Equivalent to l == 0 and p == 1.

iterator([lmax])

Iterator through all the irreps of \(O(3)\).

Attributes:

dim

The dimension of the representation, \(2 l + 1\).

D_from_angles(alpha, beta, gamma, k=0)

Matrix \(p^k D^l(\alpha, \beta, \gamma)\).

Matrix of the representation of \(O(3)\): \(D^l\) is the representation of \(SO(3)\), and \(p^k\) applies the parity \(k\) times.

Parameters:
  • alpha (jax.Array) – of shape \((...)\). Rotation \(\alpha\) around the Y axis, applied third.

  • beta (jax.Array) – of shape \((...)\). Rotation \(\beta\) around the X axis, applied second.

  • gamma (jax.Array) – of shape \((...)\). Rotation \(\gamma\) around the Y axis, applied first.

  • k (optional jax.Array) – of shape \((...)\). Number of times the parity is applied.

Returns:

of shape \((..., 2l+1, 2l+1)\)

Return type:

jax.Array

See also

Irreps.D_from_angles
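
Because gamma acts first (about Y), then beta (about X), then alpha (about Y), the underlying rotation is \(R = R_y(\alpha) R_x(\beta) R_y(\gamma)\). The plain-Python sketch below builds that 3×3 matrix from standard right-handed rotation matrices. It only illustrates the angle convention. The helper names are hypothetical, and the library's \(l = 1\) matrix may differ from this \(R\) by a change of basis (an assumption worth checking against Irrep(1, -1).D_from_angles).

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Product of two 3x3 matrices."""
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def rot_x(t: float) -> Matrix:
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]


def rot_y(t: float) -> Matrix:
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]


def angles_to_matrix(alpha: float, beta: float, gamma: float) -> Matrix:
    """R = R_y(alpha) @ R_x(beta) @ R_y(gamma); gamma acts first on a column vector."""
    return matmul(rot_y(alpha), matmul(rot_x(beta), rot_y(gamma)))


R = angles_to_matrix(0.3, 1.1, -0.7)
# A rotation matrix is orthogonal: R @ R^T is the identity up to rounding.
RRt = matmul(R, [list(row) for row in zip(*R)])
print(all(abs(RRt[i][j] - (i == j)) < 1e-12 for i in range(3) for j in range(3)))  # True
```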

D_from_log_coordinates(log_coordinates, k=0)

Matrix \(p^k D^l(\alpha)\).

Matrix of the representation of \(O(3)\): \(D^l\) is the representation of \(SO(3)\), and \(p^k\) applies the parity \(k\) times.

Parameters:
  • log_coordinates (jax.Array) – of shape \((..., 3)\). Log coordinates \(\alpha\) of the rotation (the rotation vector, i.e. axis times angle).

  • k (optional jax.Array) – of shape \((...)\). Number of times the parity is applied.

Returns:

of shape \((..., 2l+1, 2l+1)\)

Return type:

jax.Array

See also

Irreps.D_from_log_coordinates
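
Log coordinates are coordinates in the Lie algebra of \(SO(3)\). This sketch assumes the common choice of the rotation vector \(\omega = \theta \hat n\) (axis times angle) and converts it to a 3×3 matrix with Rodrigues' formula \(R = I + \sin\theta\,K + (1 - \cos\theta)K^2\), where \(K\) is the cross-product matrix of \(\hat n\). The function name log_coordinates_to_matrix is hypothetical, and this is not claimed to be how e3nn_jax computes the matrix.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def log_coordinates_to_matrix(omega: Sequence[float]) -> Matrix:
    """Rotation matrix for the rotation vector omega (axis * angle), via Rodrigues' formula."""
    theta = math.sqrt(sum(w * w for w in omega))
    identity = [[float(i == j) for j in range(3)] for i in range(3)]
    if theta < 1e-12:
        return identity  # (numerically) no rotation
    x, y, z = (w / theta for w in omega)  # unit rotation axis
    K = [[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]]  # cross-product matrix of the axis
    K2 = [[sum(K[i][k] * K[k][j] for k in range(3)) for j in range(3)] for i in range(3)]
    s, c = math.sin(theta), math.cos(theta)
    return [[identity[i][j] + s * K[i][j] + (1 - c) * K2[i][j] for j in range(3)] for i in range(3)]


# A quarter turn about +z sends the x axis to the y axis (first column of R).
R = log_coordinates_to_matrix([0.0, 0.0, math.pi / 2])
print(all(abs(R[i][0] - e) < 1e-12 for i, e in enumerate([0.0, 1.0, 0.0])))  # True
```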

D_from_matrix(R)

Matrix of the representation.

Parameters:

R (jax.Array) – array of shape \((..., 3, 3)\); if \(\det R = -1\) (improper rotation), the parity is applied once.

Returns:

array of shape \((..., 2l+1, 2l+1)\)

Return type:

jax.Array

Examples

>>> m = Irrep(1, -1).D_from_matrix(-jnp.eye(3))
>>> m + 0.0
Array([[-1.,  0.,  0.],
       [ 0., -1.,  0.],
       [ 0.,  0., -1.]], dtype=float32)
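
In the example, the parity comes from R itself. \(-I\) has determinant \(-1\), so the result for the odd irrep 1o is \(p^1 D^1(I) = -I\). The plain-Python sketch below shows that bookkeeping under the assumption that \(R\) is split as \(R = d\,R'\) with \(d = \det R = \pm 1\), \(R' = dR\) a proper rotation, and \(k = (1 - d)/2\). The helper names det3, split_improper and parity_factor are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def det3(m: Matrix) -> float:
    """Determinant of a 3x3 matrix (cofactor expansion along the first row)."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


def split_improper(R: Matrix) -> Tuple[Matrix, int]:
    """Return (R', k) with R = (-1)^k R', where R' is proper and k is 0 or 1."""
    d = 1 if det3(R) > 0 else -1
    k = (1 - d) // 2
    return [[d * x for x in row] for row in R], k


def parity_factor(R: Matrix, p: int) -> int:
    """Scalar p^k multiplying D^l(R') for an irrep of parity p."""
    _, k = split_improper(R)
    return p ** k


minus_identity = [[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]
print(parity_factor(minus_identity, -1), parity_factor(minus_identity, 1))  # -1 1
```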

D_from_quaternion(q, k=0)

Matrix of the representation, see Irrep.D_from_angles.

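Parameters:
  • q (jax.Array) – of shape \((..., 4)\)

  • k (optional jax.Array) – of shape \((...)\). Number of times the parity is applied.

Returns:

of shape \((..., 2l+1, 2l+1)\)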

Return type:

jax.Array
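
For reference, a unit quaternion \(q = (w, x, y, z)\) corresponds to the rotation matrix computed below. The component order (real part first) is an assumption of this sketch; check the library's own quaternion convention before reusing it. quaternion_to_matrix is a hypothetical name.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def quaternion_to_matrix(q: Sequence[float]) -> Matrix:
    """Rotation matrix of q = (w, x, y, z); q is normalized first."""
    n = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / n for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


# Quarter turn about +z: q = (cos(pi/4), 0, 0, sin(pi/4)) sends x to y.
h = math.pi / 4
R = quaternion_to_matrix([math.cos(h), 0.0, 0.0, math.sin(h)])
print(all(abs(R[i][0] - e) < 1e-12 for i, e in enumerate([0.0, 1.0, 0.0])))  # True
```

Since every entry is quadratic in q, q and -q give the same matrix, so the quaternion representation covers each rotation twice.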

property dim: int

The dimension of the representation, \(2 l + 1\).

generators()

Generators of the representation of \(SO(3)\).

Returns:

array of shape \((3, 2l+1, 2l+1)\)

Return type:

jax.Array

See also

generators
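
As an illustration for \(l = 1\), one common choice of generators is the three antisymmetric matrices \((L_i)_{jk} = -\epsilon_{ijk}\) acting on Cartesian vectors. They satisfy \([L_x, L_y] = L_z\) and its cyclic permutations, and \(\exp(\theta L_z)\) is a rotation by \(\theta\) about z. The Cartesian basis and sign convention are assumptions of this sketch, so the library may order the basis or choose signs differently. The pure-Python code checks the commutation relations, and the helper names are hypothetical.

```python
from typing import List

Matrix = List[List[int]]


def levi_civita(i: int, j: int, k: int) -> int:
    """Totally antisymmetric symbol on {0, 1, 2} with eps(0, 1, 2) = +1."""
    return (i - j) * (j - k) * (k - i) // 2


def so3_generators() -> List[Matrix]:
    """Generators L_x, L_y, L_z with (L_i)_{jk} = -eps_{ijk}."""
    return [[[-levi_civita(i, j, k) for k in range(3)] for j in range(3)] for i in range(3)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(3)) for j in range(3)] for i in range(3)]


def commutator(a: Matrix, b: Matrix) -> Matrix:
    """[a, b] = ab - ba."""
    ab, ba = matmul(a, b), matmul(b, a)
    return [[ab[i][j] - ba[i][j] for j in range(3)] for i in range(3)]


Lx, Ly, Lz = so3_generators()
print(commutator(Lx, Ly) == Lz, commutator(Ly, Lz) == Lx, commutator(Lz, Lx) == Ly)  # True True True
```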

is_scalar() → bool

Equivalent to l == 0 and p == 1.

classmethod iterator(lmax=None)

Iterator through all the irreps of \(O(3)\).

Examples

>>> it = Irrep.iterator()
>>> next(it), next(it), next(it), next(it)
(0e, 0o, 1o, 1e)
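
The order in the example is consistent with visiting \(l = 0, 1, 2, \dots\) and, for each \(l\), yielding the spherical-harmonics parity \((-1)^l\) before the opposite one. Here is a pure-Python sketch of that ordering. The name iterate_irreps is hypothetical, and treating lmax as inclusive is an assumption.

```python
from itertools import count
from typing import Iterator, Optional


def iterate_irreps(lmax: Optional[int] = None) -> Iterator[str]:
    """Yield irreps as strings; for each l, parity (-1)^l first, then the opposite parity."""
    for l in count():
        p = (-1) ** l
        for parity in (p, -p):
            yield f"{l}{'e' if parity == 1 else 'o'}"
        if lmax is not None and l == lmax:
            return  # lmax is inclusive in this sketch


print(list(iterate_irreps(lmax=2)))  # ['0e', '0o', '1o', '1e', '2e', '2o']
```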

class e3nn_jax.Irreps(irreps: None | Irrep | MulIrrep | str | Irreps | List[str | Irrep | MulIrrep | Tuple[int, int | Irrep | MulIrrep | Tuple[int, int]]] = None)

Bases: tuple

Direct sum of irreducible representations of \(O(3)\).

This class does not contain any data; it is a structure that describes the representation. It is typically used as an argument to other classes of the library to define the input and output representations of functions.

dim

the total dimension of the representation

Type:

int

num_irreps

Number of irreps, i.e. the sum of the multiplicities.

Type:

int

ls

list of \(l\) values, with each \(l\) repeated according to its multiplicity

Type:

list of int

lmax

maximum \(l\) value

Type:

int

Examples

>>> x = Irreps([(100, (0, 1)), (50, (1, 1))])
>>> x
100x0e+50x1e
>>> x.dim
250
>>> Irreps("100x0e + 50x1e")
100x0e+50x1e
>>> Irreps("100x0e + 50x1e + 0x2e")
100x0e+50x1e+0x2e
>>> Irreps("100x0e + 50x1e + 0x2e").lmax
1
>>> Irrep("2e") in Irreps("0e + 2e")
True

Empty Irreps

>>> Irreps(), Irreps("")
(Irreps(), Irreps())
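
The bookkeeping behind these attributes can be sketched in pure Python by modeling an Irreps as a list of (multiplicity, l, parity) triples. This is an assumption made for illustration. parse_irreps, dim and ls are hypothetical helpers, and they handle only the e/o suffixes. The sketch reproduces the numbers in the examples, including the fact that 0x2e does not raise lmax.

```python
from typing import List, Tuple

MulIrrep = Tuple[int, int, int]  # (mul, l, p)


def parse_irreps(s: str) -> List[MulIrrep]:
    """Parse '100x0e + 50x1e' into [(100, 0, 1), (50, 1, 1)]."""
    out = []
    for term in s.replace(" ", "").split("+"):
        if not term:
            continue  # the empty string describes an empty representation
        mul, ir = term.split("x") if "x" in term else ("1", term)
        l, parity = int(ir[:-1]), ir[-1]
        out.append((int(mul), l, {"e": 1, "o": -1}[parity]))
    return out


def dim(irreps: List[MulIrrep]) -> int:
    """Total dimension: sum of mul * (2l + 1)."""
    return sum(mul * (2 * l + 1) for mul, l, _ in irreps)


def ls(irreps: List[MulIrrep]) -> List[int]:
    # Each l appears once per copy, so zero-multiplicity entries contribute nothing.
    return [l for mul, l, _ in irreps for _ in range(mul)]


x = parse_irreps("100x0e + 50x1e + 0x2e")
print(dim(x), sum(m for m, _, _ in x), max(ls(x)))  # 250 150 1
```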

Methods:

D_from_angles(alpha, beta, gamma[, k])

Compute the D matrix from the angles.

D_from_log_coordinates(log_coordinates[, k])

Matrix of the representation.

D_from_matrix(R)

Matrix of the representation.

D_from_quaternion(q[, k])

Matrix of the representation.

count(ir)

Multiplicity of ir.

filter([keep, drop, lmax])

Filter the irreps.

generators()

Generators of the representation.

index(_object)

Return first index of value.

is_scalar()

Check if the representation is scalar.

regroup()

Regroup the same irreps together.

remove_zero_multiplicities()

Remove any irreps with multiplicities of zero.

repeat(n)

Repeat the representation n times.

set_mul(mul)

Set all multiplicities to mul.

simplify()

Simplify the representations by removing irreps with multiplicity zero and merging adjacent copies of the same irrep.

slices()

List of slices corresponding to indices for each irrep.

sort()

Sort the representations.

spherical_harmonics(lmax[, p])

Representation of the spherical harmonics.

unify()

Regroup the same irreps together by merging adjacent entries.

Attributes:

dim

Dimension of the irreps.

lmax

Maximum l value.

ls

List of the l values.

mul_gcd

Greatest common divisor of the multiplicities.

num_irreps

Sum of the multiplicities.

slice_by_chunk

Return the slice with respect to the chunks.

slice_by_dim

Return the slice with respect to the dimensions.

slice_by_mul

Return the slice with respect to the multiplicities.

D_from_angles(alpha, beta, gamma, k=0)

Compute the D matrix from the angles.

Parameters:
  • alpha (float) – third rotation angle, around the second axis (Y), in radians

  • beta (float) – second rotation angle, around the first axis (X), in radians

  • gamma (float) – first rotation angle, around the second axis (Y), in radians

  • k (int) – number of times the parity is applied

Returns:

array of shape \((..., \mathrm{dim}, \mathrm{dim})\)

Return type:

jax.Array

D_from_log_coordinates(log_coordinates, k=0)

Matrix of the representation.

Parameters:
  • log_coordinates (jax.Array) – array of shape \((..., 3)\)

  • k (jax.Array, optional) – array of shape \((...)\)

Returns:

array of shape \((..., \mathrm{dim}, \mathrm{dim})\)

Return type:

jax.Array

D_from_matrix(R)

Matrix of the representation.

Parameters:

R (jax.Array) – array of shape \((..., 3, 3)\)

Returns:

array of shape \((..., \mathrm{dim}, \mathrm{dim})\)

Return type:

jax.Array

D_from_quaternion(q, k=0)

Matrix of the representation.

Parameters:
  • q (jax.Array) – array of shape \((..., 4)\)

  • k (jax.Array, optional) – array of shape \((...)\)

Returns:

array of shape \((..., \mathrm{dim}, \mathrm{dim})\)

Return type:

jax.Array

count(ir: int | Irrep | MulIrrep | Tuple[int, int]) → int

Multiplicity of ir.

Parameters:

ir (Irrep) – the irrep whose multiplicity is counted

Returns:

total multiplicity of ir

Return type:

int

Examples

>>> Irreps("2x0e + 3x1o").count("1o")
3

property dim: int

Dimension of the irreps.

Examples

>>> Irreps("3x0e + 2x1e").dim
9

filter(keep: Irreps | List[Irrep] | Callable[[MulIrrep], bool] = None, *, drop: Irreps | List[Irrep] | Callable[[MulIrrep], bool] = None, lmax: int = None) → Irreps

Filter the irreps.

Parameters:
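  • keep (Irreps or list of Irrep or function) – irreps to keep; only the irreps are compared, and multiplicities are ignored. A function receives each MulIrrep and returns whether to keep it.

  • drop (Irreps or list of Irrep or function) – irreps to drop, with the same conventions as keep

  • lmax (int) – maximum \(l\) value to keep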

Returns:

filtered irreps

Return type:

Irreps

Examples

>>> Irreps("1e + 2e + 0e").filter(keep=["0e", "1e"])
1x1e+1x0e
>>> Irreps("1e + 2e + 0e").filter(keep="2e + 2x1e")
1x1e+1x2e
>>> Irreps("1e + 2e + 0e").filter(drop="2e + 2x1e")
1x0e
>>> Irreps("1e + 2e + 0e").filter(lmax=1)
1x1e+1x0e
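
The examples suggest that keep and drop compare irreps only (the 2 in "2x1e" has no effect) and that the original order is preserved. The pure-Python sketch below implements those semantics on (mul, irrep-string) pairs. filter_irreps is a hypothetical name, and the callable form of keep/drop is omitted.

```python
from typing import Iterable, List, Optional, Tuple

MulIrrep = Tuple[int, str]  # (mul, irrep string such as "1e")


def filter_irreps(
    irreps: List[MulIrrep],
    keep: Optional[Iterable[str]] = None,
    drop: Optional[Iterable[str]] = None,
    lmax: Optional[int] = None,
) -> List[MulIrrep]:
    """Keep entries whose irrep is in `keep`, remove those in `drop`, then cap l at lmax.

    Only irreps are compared, so multiplicities are ignored, and the order of
    `irreps` is preserved.
    """
    out = list(irreps)
    if keep is not None:
        keep_set = set(keep)
        out = [(m, ir) for m, ir in out if ir in keep_set]
    if drop is not None:
        drop_set = set(drop)
        out = [(m, ir) for m, ir in out if ir not in drop_set]
    if lmax is not None:
        out = [(m, ir) for m, ir in out if int(ir[:-1]) <= lmax]
    return out


x = [(1, "1e"), (1, "2e"), (1, "0e")]
print(filter_irreps(x, keep=["2e", "1e"]))  # [(1, '1e'), (1, '2e')]
print(filter_irreps(x, drop=["2e", "1e"]))  # [(1, '0e')]
print(filter_irreps(x, lmax=1))  # [(1, '1e'), (1, '0e')]
```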

generators() → Array

Generators of the representation.

Returns:

array of shape \((3, \mathrm{dim}, \mathrm{dim})\)

Return type:

jax.Array

index(_object)

Return first index of value.

Raises ValueError if the value is not present.

is_scalar() → bool

Check whether the representation contains only scalars (0e irreps).

Returns:

True if the representation is scalar

Return type:

bool

Examples

>>> Irreps("2x0e + 3x1o").is_scalar()
False
>>> Irreps("2x0e + 2x0e").is_scalar()
True
>>> Irreps("0o").is_scalar()
False

property lmax: int

Maximum l value.

Examples

>>> Irreps("3x0e + 2x1e").lmax
1

property ls: List[int]

List of the l values.

Examples

>>> Irreps("3x0e + 2x1e").ls
[0, 0, 0, 1, 1]

property mul_gcd: int

Greatest common divisor of the multiplicities.

Examples

>>> Irreps("3x0e + 2x1e").mul_gcd
1

property num_irreps: int

Sum of the multiplicities.

Examples

>>> Irreps("3x0e + 2x1e").num_irreps
5

regroup() → Irreps

Regroup the same irreps together.

Equivalent to sort() followed by simplify().

Returns:

regrouped irreps

Return type:

Irreps

Examples

>>> Irreps("1e + 0e + 1e + 0x2e").regroup()
1x0e+2x1e

remove_zero_multiplicities() → Irreps

Remove any irreps with multiplicities of zero.

Examples

>>> Irreps("4x0e + 0x1o + 2x3e").remove_zero_multiplicities()
4x0e+2x3e

repeat(n: int) → Irreps

Repeat the representation n times.

Examples

>>> Irreps('0e + 1e').repeat(2)
1x0e+1x1e+1x0e+1x1e

set_mul(mul: int) → Irreps

Set all multiplicities to mul.

Examples

>>> Irreps("2x0e + 1x1e").set_mul(1)
1x0e+1x1e

simplify() → Irreps

Simplify the representations by removing irreps with multiplicity zero and merging adjacent copies of the same irrep.

Examples

Note that simplify does not sort the representations.

>>> Irreps("1e + 1e + 0e").simplify()
2x1e+1x0e

Equivalent representations which are separated from each other are not combined.

>>> Irreps("1e + 1e + 0e + 1e").simplify()
2x1e+1x0e+1x1e

They are combined, however, if only irreps with multiplicity zero separate them, because those are removed.

>>> Irreps("1e + 0x0e + 1e").simplify()
2x1e
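
All three simplify examples, and the unify examples further down, are consistent with a two-step rule: drop zero multiplicities, then merge adjacent equal irreps. The pure-Python sketch below applies that rule to (mul, irrep-string) pairs. The function names mirror the documented methods, but the code is only an illustration and not the library's source.

```python
from typing import List, Tuple

MulIrrep = Tuple[int, str]  # (mul, irrep string such as "1e")


def remove_zero_multiplicities(irreps: List[MulIrrep]) -> List[MulIrrep]:
    return [(m, ir) for m, ir in irreps if m > 0]


def unify(irreps: List[MulIrrep]) -> List[MulIrrep]:
    """Merge adjacent entries with the same irrep; no sorting, zero multiplicities kept."""
    out: List[MulIrrep] = []
    for m, ir in irreps:
        if out and out[-1][1] == ir:
            out[-1] = (out[-1][0] + m, ir)
        else:
            out.append((m, ir))
    return out


def simplify(irreps: List[MulIrrep]) -> List[MulIrrep]:
    # Dropping zeros first lets equal irreps separated only by a 0x entry merge.
    return unify(remove_zero_multiplicities(irreps))


print(simplify([(1, "1e"), (1, "1e"), (1, "0e"), (1, "1e")]))  # [(2, '1e'), (1, '0e'), (1, '1e')]
print(simplify([(1, "1e"), (0, "0e"), (1, "1e")]))  # [(2, '1e')]
print(unify([(1, "0e"), (0, "1e"), (1, "0e")]))  # [(1, '0e'), (0, '1e'), (1, '0e')]
```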

property slice_by_chunk

Return the slice with respect to the chunks, i.e. each (mul, irrep) entry counts as one index.

Examples

>>> Irreps("2x1e + 2e + 3x0e").slice_by_chunk[:1]
2x1e
>>> Irreps("1e + 2e + 3x0e").slice_by_chunk[1:]
1x2e+3x0e

property slice_by_dim

Return the slice with respect to the dimensions, i.e. indices refer to components of the flattened representation of size dim.

Examples

>>> Irreps("1e + 2e + 3x0e").slice_by_dim[:3]
1x1e
>>> Irreps("1e + 2e + 3x0e").slice_by_dim[3:8]
1x2e

property slice_by_mul

Return the slice with respect to the multiplicities, i.e. each of the mul copies of an irrep counts as one index.

Examples

>>> Irreps("2x1e + 2e").slice_by_mul[2:]
1x2e
>>> Irreps("1e + 2e + 3x0e").slice_by_mul[1:3]
1x2e+1x0e
>>> Irreps("1e + 2e + 3x0e").slice_by_mul[1:]
1x2e+3x0e

slices() → List[slice]

List of slices corresponding to indices for each irrep.

Examples

>>> Irreps('2x0e + 1e').slices()
[slice(0, 2, None), slice(2, 5, None)]
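
Each slice covers \(\mathrm{mul} \cdot (2l+1)\) consecutive components, and the slices follow one another in order. Below is a minimal pure-Python sketch of that offset computation on (mul, l) pairs. It assumes the feature array is the concatenation implied by the example, and irreps_slices is a hypothetical name.

```python
from typing import List, Tuple


def irreps_slices(irreps: List[Tuple[int, int]]) -> List[slice]:
    """One slice per (mul, l) entry into a flat array of length sum(mul * (2l + 1))."""
    slices, start = [], 0
    for mul, l in irreps:
        stop = start + mul * (2 * l + 1)
        slices.append(slice(start, stop))
        start = stop
    return slices


print(irreps_slices([(2, 0), (1, 1)]))  # [slice(0, 2, None), slice(2, 5, None)]
```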

sort() → Sort

Sort the representations.

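Returns:

named tuple with fields:

  • irreps (Irreps) – the sorted irreps

  • p (tuple of int) – permutation of the indices: p[i] is the position, in the sorted irreps, of the i-th original irrep

  • inv (tuple of int) – inverse permutation: inv[j] is the original index of the j-th sorted irrep

Return type:

Sort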

Examples

>>> Irreps("1e + 0e + 1e").sort().irreps
1x0e+1x1e+1x1e
>>> Irreps("2o + 1e + 0e + 1e").sort().p
(3, 1, 0, 2)
>>> Irreps("2o + 1e + 0e + 1e").sort().inv
(2, 1, 3, 0)
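
The p and inv fields can be reproduced with a stable argsort. inv lists the original indices in sorted order, and p is its inverse. The pure-Python sketch below works on (l, parity) pairs. SortResult, sort_key and sort_irreps are hypothetical names. The relative order of the two parities at the same \(l\) is an assumption, because the example does not pin it down.

```python
from typing import List, NamedTuple, Tuple

Irrep = Tuple[int, int]  # (l, p)


class SortResult(NamedTuple):
    irreps: List[Irrep]
    p: Tuple[int, ...]  # p[i]: position of original entry i after sorting
    inv: Tuple[int, ...]  # inv[j]: original index of sorted entry j


def sort_key(ir: Irrep) -> Tuple[int, int]:
    # Order by l; within one l, put parity (-1)^l first (assumed ordering).
    l, p = ir
    return (l, 0 if p == (-1) ** l else 1)


def sort_irreps(irreps: List[Irrep]) -> SortResult:
    # Ties broken by original index, so equal irreps keep their relative order.
    inv = tuple(sorted(range(len(irreps)), key=lambda i: (sort_key(irreps[i]), i)))
    p = [0] * len(inv)
    for new, old in enumerate(inv):
        p[old] = new
    return SortResult([irreps[i] for i in inv], tuple(p), inv)


s = sort_irreps([(2, -1), (1, 1), (0, 1), (1, 1)])  # 2o + 1e + 0e + 1e
print(s.p, s.inv)  # (3, 1, 0, 2) (2, 1, 3, 0)
```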

static spherical_harmonics(lmax, p=-1)

Representation of the spherical harmonics.

Parameters:
  • lmax (int) – maximum \(l\)

  • p (optional {1, -1}) – parity convention: with -1 (the default), \(Y^l\) has parity \((-1)^l\); with 1, all irreps are even

Returns:

representation of \((Y^0, Y^1, \dots, Y^{\mathrm{lmax}})\)

Return type:

Irreps

Examples

>>> Irreps.spherical_harmonics(3)
1x0e+1x1o+1x2e+1x3o
>>> Irreps.spherical_harmonics(4, p=1)
1x0e+1x1e+1x2e+1x3e+1x4e
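
Both examples follow from giving \(Y^l\) the parity \(p^l\). With p = -1 this is \((-1)^l\), the parity of the spherical harmonics under inversion, and with p = 1 every term is even. Here is a pure-Python sketch that produces the same strings. spherical_harmonics_irreps is a hypothetical name.

```python
def spherical_harmonics_irreps(lmax: int, p: int = -1) -> str:
    """Irreps string of (Y^0, ..., Y^lmax), where Y^l has parity p**l."""
    terms = []
    for l in range(lmax + 1):
        parity = p ** l  # p = -1 gives (-1)^l; p = 1 is always even
        terms.append(f"1x{l}{'e' if parity == 1 else 'o'}")
    return "+".join(terms)


print(spherical_harmonics_irreps(3))  # 1x0e+1x1o+1x2e+1x3o
print(spherical_harmonics_irreps(4, p=1))  # 1x0e+1x1e+1x2e+1x3e+1x4e
```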

unify() → Irreps

Regroup the same irreps together by merging adjacent entries. It does not sort, and it keeps irreps with multiplicity zero.

Returns:

new Irreps object

Return type:

Irreps

Examples

>>> Irreps('0e + 1e').unify()
1x0e+1x1e
>>> Irreps('0e + 1e + 1e').unify()
1x0e+2x1e
>>> Irreps('0e + 0x1e + 0e').unify()
1x0e+0x1e+1x0e