Basic functions for Rotations

e3nn_jax.rand_matrix(key, shape=(), dtype=jax.numpy.float32)

Random rotation matrix.

Parameters:
  • key – a PRNGKey used as the random key.

  • shape – a tuple of nonnegative integers representing the result shape.

  • dtype – the floating-point dtype of the result.

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.identity_angles(shape=(), dtype=jax.numpy.float32)

Angles of the identity rotation.

Parameters:
  • shape – a tuple of nonnegative integers representing the result shape.

  • dtype – the floating-point dtype of the returned angles.

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.rand_angles(key, shape=(), dtype=jax.numpy.float32)

Random rotation angles.

Parameters:
  • key – a PRNGKey used as the random key.

  • shape – a tuple of nonnegative integers representing the result shape.

  • dtype – the floating-point dtype of the returned angles.

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.compose_angles(a1, b1, c1, a2, b2, c2)

Compose angles.

Computes \((a, b, c)\) such that \(R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)\)

Parameters:
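  • a1 (jax.Array) – array of shape \((...)\), angle \(\alpha\) of the first rotation

  • b1 (jax.Array) – array of shape \((...)\), angle \(\beta\) of the first rotation

  • c1 (jax.Array) – array of shape \((...)\), angle \(\gamma\) of the first rotation

  • a2 (jax.Array) – array of shape \((...)\), angle \(\alpha\) of the second rotation

  • b2 (jax.Array) – array of shape \((...)\), angle \(\beta\) of the second rotation

  • c2 (jax.Array) – array of shape \((...)\), angle \(\gamma\) of the second rotation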

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.inverse_angles(a, b, c)

Angles of the inverse rotation.

Parameters:
  • a (jax.Array) – array of shape \((...)\), angle \(\alpha\)

  • b (jax.Array) – array of shape \((...)\), angle \(\beta\)

  • c (jax.Array) – array of shape \((...)\), angle \(\gamma\)

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.identity_quaternion(shape=(), dtype=jax.numpy.float32)

Quaternion of the identity rotation.

Parameters:
  • shape – a tuple of nonnegative integers representing the result shape.

  • dtype – the floating-point dtype of the result.

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.rand_quaternion(key, shape=(), dtype=jax.numpy.float32)

Generate a random quaternion.

Parameters:
  • key – a PRNGKey used as the random key.

  • shape – a tuple of nonnegative integers representing the result shape.

  • dtype – the floating-point dtype of the result.

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.compose_quaternion(q1, q2)

Compose two quaternions: \(q_1 \circ q_2\).

Parameters:
  • q1 (jax.Array) – array of shape \((..., 4)\)

  • q2 (jax.Array) – array of shape \((..., 4)\)

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.inverse_quaternion(q)

Inverse of a quaternion.

Works only for unit quaternions.

Parameters:

q (jax.Array) – array of shape \((..., 4)\)

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.rand_axis_angle(key, shape=(), dtype=jax.numpy.float32)

Generate a random rotation in axis-angle form.

Parameters:
  • key – a PRNGKey used as the random key.

  • shape – a tuple of nonnegative integers representing the result shape.

  • dtype – the floating-point dtype of the result.

Returns:

tuple containing:

  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.compose_axis_angle(axis1, angle1, axis2, angle2)

Compose \((\vec x_1, \alpha_1)\) with \((\vec x_2, \alpha_2)\).

Parameters:
  • axis1 (jax.Array) – array of shape \((..., 3)\)

  • angle1 (jax.Array) – array of shape \((...)\)

  • axis2 (jax.Array) – array of shape \((..., 3)\)

  • angle2 (jax.Array) – array of shape \((...)\)

Returns:

tuple containing:

  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.matrix_x(angle)

Matrix of the rotation around the X axis.

Parameters:

angle (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.matrix_y(angle)

Matrix of the rotation around the Y axis.

Parameters:

angle (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.matrix_z(angle)

Matrix of the rotation around the Z axis.

Parameters:

angle (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.angles_to_matrix(alpha, beta, gamma)

Conversion from angles to matrix.

Parameters:
  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.matrix_to_angles(R)

Conversion from matrix to angles. Warning: this function is not differentiable for rotations by an angle of \(\pi\).

Parameters:

R (jax.Array) – array of shape \((..., 3, 3)\)

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.angles_to_quaternion(alpha, beta, gamma)

Conversion from angles to quaternion.

Parameters:
  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.matrix_to_quaternion(R)

Conversion from matrix \(R\) to quaternion \(q\).

Parameters:

R (jax.Array) – array of shape \((..., 3, 3)\)

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.axis_angle_to_quaternion(xyz, angle)

Conversion from axis-angle to quaternion.

Parameters:
  • xyz (jax.Array) – array of shape \((..., 3)\), the rotation axis

  • angle (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 4)\)

Return type:

jax.Array

e3nn_jax.quaternion_to_axis_angle(q)

Conversion from quaternion to axis-angle.

Parameters:

q (jax.Array) – array of shape \((..., 4)\)

Returns:

tuple containing:

  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.matrix_to_axis_angle(R)

Conversion from matrix to axis-angle. Warning: this function is not differentiable for rotations by an angle of \(\pi\).

Parameters:

R (jax.Array) – array of shape \((..., 3, 3)\)

Returns:

tuple containing:

  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.angles_to_axis_angle(alpha, beta, gamma)

Conversion from angles to axis-angle.

Parameters:
  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Returns:

tuple containing:

  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.axis_angle_to_matrix(axis, angle)

Conversion from axis-angle to matrix.

Parameters:
  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.quaternion_to_matrix(q)

Conversion from quaternion to matrix.

Parameters:

q (jax.Array) – array of shape \((..., 4)\)

Returns:

array of shape \((..., 3, 3)\)

Return type:

jax.Array

e3nn_jax.quaternion_to_angles(q)

Conversion from quaternion to angles.

Parameters:

q (jax.Array) – array of shape \((..., 4)\)

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.axis_angle_to_angles(axis, angle)

Conversion from axis-angle to angles.

Parameters:
  • axis (jax.Array) – array of shape \((..., 3)\)

  • angle (jax.Array) – array of shape \((...)\)

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

  • gamma (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.angles_to_xyz(alpha, beta)

Convert \((\alpha, \beta)\) into a point \((x, y, z)\) on the sphere.

Parameters:
  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

Returns:

array of shape \((..., 3)\)

Return type:

jax.Array

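Examples

With \(\beta = 0\) the point is \(\vec e_y\) for any \(\alpha\); adding 0.0 turns any negative zero (-0.0) into 0.0 in the printed output.

>>> import e3nn_jax as e3nn
>>> e3nn.angles_to_xyz(1.7, 0.0) + 0.0
Array([0., 1., 0.], dtype=float32, weak_type=True)
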
e3nn_jax.xyz_to_angles(xyz)

The angles \((\alpha, \beta)\) of the rotation \(R(\alpha, \beta, 0)\) such that \(\vec r = R \vec e_y\).

\[ \begin{aligned} \vec r &= R(\alpha, \beta, 0) \vec e_y \\ \alpha &= \arctan(x/z) \\ \beta &= \arccos(y) \end{aligned} \]

Parameters:

xyz (jax.Array) – array of shape \((..., 3)\)

Returns:

tuple containing:

  • alpha (jax.Array) – array of shape \((...)\)

  • beta (jax.Array) – array of shape \((...)\)

Return type:

(tuple)

e3nn_jax.clebsch_gordan(l1: int, l2: int, l3: int) → ndarray

The Clebsch-Gordan coefficients of the real irreducible representations of \(SO(3)\).

Parameters:
  • l1 (int) – the representation order of the first irrep

  • l2 (int) – the representation order of the second irrep

  • l3 (int) – the representation order of the third irrep

Returns:

the Clebsch-Gordan coefficients

Return type:

np.ndarray

e3nn_jax.generators(l: int) → ndarray

The generators of the real irreducible representations of \(SO(3)\).

Parameters:

l (int) – the representation order of the irrep

Returns:

the generators

Return type:

np.ndarray