Basic functions for Rotations
- e3nn_jax.rand_matrix(key, shape=(), dtype=jax.numpy.float32)
Random rotation matrix.
- Parameters:
key – a PRNGKey used as the random key.
shape – a tuple of nonnegative integers representing the result shape.
- Returns:
array of shape \((..., 3, 3)\)
- Return type:
jax.Array
- e3nn_jax.identity_angles(shape=(), dtype=jax.numpy.float32)
Angles of the identity rotation.
- Parameters:
shape – a tuple of nonnegative integers representing the result shape.
- Returns:
tuple containing the three angles of the identity rotation, each an array of shape \((...)\)
- Return type:
(tuple)
- e3nn_jax.rand_angles(key, shape=(), dtype=jax.numpy.float32)
Random rotation angles.
- Parameters:
key – a PRNGKey used as the random key.
shape – a tuple of nonnegative integers representing the result shape.
- Returns:
tuple containing the three random rotation angles, each an array of shape \((...)\)
- Return type:
(tuple)
- e3nn_jax.compose_angles(a1, b1, c1, a2, b2, c2)
Compose angles.
Computes \((a, b, c)\) such that \(R(a, b, c) = R(a_1, b_1, c_1) \circ R(a_2, b_2, c_2)\)
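- Parameters:
a1, b1, c1 – angles of the left rotation \(R(a_1, b_1, c_1)\), which is applied second.
a2, b2, c2 – angles of the right rotation \(R(a_2, b_2, c_2)\), which is applied first.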
- Returns:
tuple containing the composed angles \((a, b, c)\)
- Return type:
(tuple)
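The composition rule above can be sketched in pure Python. This is an illustration only, not the library's implementation: it assumes the Y-X-Y Euler convention \(R(a, b, c) = R_y(a) R_x(b) R_y(c)\), which is consistent with the `angles_to_xyz` example in this reference but is not stated in it, and every helper name below is hypothetical.

```python
import math

def matmul(A, B):
    """Product of two 3x3 matrices stored as nested lists."""
    return [[sum(A[i][k] * B[k][j] for k in range(3)) for j in range(3)]
            for i in range(3)]

def transpose(A):
    return [[A[j][i] for j in range(3)] for i in range(3)]

def rot_x(t):
    c, s = math.cos(t), math.sin(t)
    return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]

def rot_y(t):
    c, s = math.cos(t), math.sin(t)
    return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]

def angles_to_matrix_sketch(a, b, c):
    # Assumed convention: R(a, b, c) = R_y(a) R_x(b) R_y(c)
    return matmul(rot_y(a), matmul(rot_x(b), rot_y(c)))

def matrix_to_angles_sketch(R):
    # The image of the y axis (second column of R) determines a and b
    x, y, z = R[0][1], R[1][1], R[2][1]
    b = math.acos(max(-1.0, min(1.0, y)))
    a = math.atan2(x, z)
    # Undo R_y(a) R_x(b); what remains is R_y(c)
    Rc = matmul(transpose(angles_to_matrix_sketch(a, b, 0.0)), R)
    c = math.atan2(Rc[0][2], Rc[0][0])
    return a, b, c

def compose_angles_sketch(a1, b1, c1, a2, b2, c2):
    # Multiply the two matrices, then read the angles back off the product
    R = matmul(angles_to_matrix_sketch(a1, b1, c1),
               angles_to_matrix_sketch(a2, b2, c2))
    return matrix_to_angles_sketch(R)

a, b, c = compose_angles_sketch(0.3, 1.1, -0.4, 2.0, 0.7, 0.9)
```

Going through matrices keeps the sketch short. The recovered angles are not unique when \(b\) is 0 or \(\pi\), so the sketch is only meant for generic inputs.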
- e3nn_jax.identity_quaternion(shape=(), dtype=jax.numpy.float32)
Quaternion of identity rotation.
- Parameters:
shape – a tuple of nonnegative integers representing the result shape.
- Returns:
array of shape \((..., 4)\)
- Return type:
jax.Array
- e3nn_jax.rand_quaternion(key, shape=(), dtype=jax.numpy.float32)
Generate random quaternion.
- Parameters:
key – a PRNGKey used as the random key.
shape – a tuple of nonnegative integers representing the result shape.
- Returns:
array of shape \((..., 4)\)
- Return type:
jax.Array
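For intuition, one standard way to draw a random unit quaternion is to normalize four independent standard normal samples, which gives a point distributed uniformly on the 3-sphere. The sketch below uses Python's `random` module in place of a JAX PRNG key. Whether `rand_quaternion` uses this method is not stated in this reference, and `rand_quaternion_sketch` is a hypothetical name.

```python
import math
import random

def rand_quaternion_sketch(rng):
    """Unit quaternion drawn uniformly from the 3-sphere."""
    q = [rng.gauss(0.0, 1.0) for _ in range(4)]
    norm = math.sqrt(sum(x * x for x in q))
    return [x / norm for x in q]

rng = random.Random(0)  # seeded generator, playing the role of the PRNG key
q = rand_quaternion_sketch(rng)
norm_sq = sum(x * x for x in q)  # equals 1 up to rounding
```

Note that \(q\) and \(-q\) describe the same rotation, so any quaternion sampler covers the rotation group twice.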
- e3nn_jax.rand_axis_angle(key, shape=(), dtype=jax.numpy.float32)
Generate random rotation as axis-angle.
- Parameters:
key – a PRNGKey used as the random key.
shape – a tuple of nonnegative integers representing the result shape.
- Returns:
tuple containing the rotation axis and the rotation angle
- Return type:
(tuple)
- e3nn_jax.compose_axis_angle(axis1, angle1, axis2, angle2)
Compose \((\vec x_1, \alpha_1)\) with \((\vec x_2, \alpha_2)\).
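A minimal sketch of axis-angle composition, going through unit quaternions. It assumes that the second rotation is applied first, by analogy with `compose_angles`, and that quaternions are stored scalar-first; neither is confirmed by this reference. All function names below are hypothetical and this is not the library's code.

```python
import math

def axis_angle_to_quaternion(axis, angle):
    """Scalar-first unit quaternion for a rotation by `angle` about `axis`."""
    n = math.sqrt(sum(x * x for x in axis))
    s = math.sin(angle / 2.0) / n
    return (math.cos(angle / 2.0), axis[0] * s, axis[1] * s, axis[2] * s)

def quaternion_product(p, q):
    """Hamilton product p * q (rotation q is applied first, then p)."""
    pw, px, py, pz = p
    qw, qx, qy, qz = q
    return (
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    )

def quaternion_to_axis_angle(q):
    w, x, y, z = q
    s = math.sqrt(x * x + y * y + z * z)  # sin(angle / 2), never negative
    if s < 1e-12:
        return (0.0, 0.0, 1.0), 0.0  # identity rotation: the axis is arbitrary
    return (x / s, y / s, z / s), 2.0 * math.atan2(s, w)

def compose_axis_angle_sketch(axis1, angle1, axis2, angle2):
    q = quaternion_product(axis_angle_to_quaternion(axis1, angle1),
                           axis_angle_to_quaternion(axis2, angle2))
    return quaternion_to_axis_angle(q)

# Two rotations about the same axis simply add their angles
axis, angle = compose_axis_angle_sketch((0.0, 0.0, 1.0), 0.3, (0.0, 0.0, 1.0), 0.4)
# axis is close to (0, 0, 1) and angle is close to 0.7
```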
- e3nn_jax.matrix_to_angles(R)
Conversion from matrix to angles. Warning: this function is not differentiable at rotation angles \(\pi\).
- e3nn_jax.matrix_to_axis_angle(R)
Conversion from matrix to axis-angle. Warning: this function is not differentiable at rotation angles \(\pi\).
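To see why angle \(\pi\) is delicate, consider one common closed-form recipe, which may or may not match the library's implementation. The angle comes from the trace through `acos`, whose derivative is unbounded at \(-1\), and the axis comes from the antisymmetric part of the matrix divided by \(2 \sin(\text{angle})\), which vanishes at \(\pi\). The function name is hypothetical.

```python
import math

def matrix_to_axis_angle_sketch(R):
    """Axis and angle of a 3x3 rotation matrix given as nested lists."""
    trace = R[0][0] + R[1][1] + R[2][2]
    cos_angle = max(-1.0, min(1.0, (trace - 1.0) / 2.0))
    angle = math.acos(cos_angle)
    # R - R^T equals 2 sin(angle) times the cross-product matrix of the axis
    v = (R[2][1] - R[1][2], R[0][2] - R[2][0], R[1][0] - R[0][1])
    two_sin = 2.0 * math.sin(angle)
    if abs(two_sin) < 1e-12:
        raise ValueError("axis is ill-conditioned near angle 0 or pi in this sketch")
    return tuple(x / two_sin for x in v), angle

# Rotation by 0.5 rad about the z axis
c, s = math.cos(0.5), math.sin(0.5)
axis, angle = matrix_to_axis_angle_sketch([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
```

In this sketch the same division also fails at angle 0, where the axis is undefined anyway.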
- e3nn_jax.angles_to_xyz(alpha, beta)
Convert \((\alpha, \beta)\) into a point \((x, y, z)\) on the sphere.
- Parameters:
alpha – array of shape \((...)\)
beta – array of shape \((...)\)
- Returns:
array of shape \((..., 3)\)
- Return type:
jax.Array
Examples
>>> angles_to_xyz(1.7, 0.0) + 0.0
Array([0., 1., 0.], dtype=float32, weak_type=True)
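The mapping can be sketched with the standard library alone. The formula below is an assumption: it treats \(\beta\) as the polar angle measured from the y axis and \(\alpha\) as the azimuth around it. That agrees with the example above, where \(\beta = 0\) gives \((0, 1, 0)\) for any \(\alpha\), but it is not spelled out in this reference. The function name is hypothetical.

```python
import math

def angles_to_xyz_sketch(alpha, beta):
    # Assumed convention: beta is the polar angle from the y axis,
    # alpha is the azimuth around the y axis
    x = math.sin(beta) * math.sin(alpha)
    y = math.cos(beta)
    z = math.sin(beta) * math.cos(alpha)
    return (x, y, z)

pole = angles_to_xyz_sketch(1.7, 0.0)   # the y axis, whatever alpha is
point = angles_to_xyz_sketch(0.4, 1.2)  # a generic point on the unit sphere
```

With these inputs the z component of `pole` is a negative zero, because `cos(1.7)` is negative. Adding `0.0`, as the example above does, turns it into a positive zero for display.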