Signal on the Sphere

e3nn_jax.s2_irreps(lmax: int, p_val: int = 1, p_arg: int = -1) → Irreps

The irreps of the coefficients of a spherical harmonics expansion:

\[f(\vec x) = \sum_{l=0}^{L} \sum_{m=-l}^{l} c_l^m Y_{l,m}(\vec x)\]

When the inversion operator is applied to the signal, the new function \(I f\) is given by

\[[I f](\vec x) = p_{\text{val}} f(p_{\text{arg}} \vec x)\]
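As a quick numerical check of this formula (a plain-Python sketch, not e3nn_jax code): a real spherical harmonic of degree \(l\) is the restriction of a homogeneous harmonic polynomial of degree \(l\), so replacing \(\vec x\) by \(p_{\text{arg}} \vec x\) multiplies it by \(p_{\text{arg}}^l\). The degree-\(l\) part of the expansion therefore transforms with parity \(p_{\text{val}} p_{\text{arg}}^l\). The polynomials below are unnormalized, illustrative stand-ins, not the library's basis, and the helper names are hypothetical.

```python
import random

# Illustrative (unnormalized) harmonic polynomials, one per degree l.
# Each is homogeneous of degree l, which is all the parity argument needs.
HARMONICS = {
    0: lambda x, y, z: 1.0,
    1: lambda x, y, z: x,
    2: lambda x, y, z: x * y,
    3: lambda x, y, z: x * y * z,
    4: lambda x, y, z: x * y * (x * x - y * y),
}


def inverted(f, p_val, p_arg):
    """Return I f, defined by [I f](v) = p_val * f(p_arg * v)."""
    return lambda x, y, z: p_val * f(p_arg * x, p_arg * y, p_arg * z)


def parity_of_degree(l, p_val, p_arg):
    """Parity p = p_val * p_arg**l picked up by the degree-l part."""
    return p_val * p_arg**l


random.seed(0)
for l, Y in HARMONICS.items():
    v = [random.uniform(-1, 1) for _ in range(3)]
    for p_val in (1, -1):
        for p_arg in (1, -1):
            lhs = inverted(Y, p_val, p_arg)(*v)
            rhs = parity_of_degree(l, p_val, p_arg) * Y(*v)
            assert abs(lhs - rhs) < 1e-12
print("all parity checks passed")
```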
Parameters:
  • lmax (int) – maximum degree of the expansion

  • p_val (int) – parity of the value of the signal on the sphere (1 or -1)

  • p_arg (int) – parity of the argument of the signal on the sphere (1 or -1)
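The irreps returned here can be spelled out by hand: one copy of each degree \(l = 0, \dots, \text{lmax}\), with parity \(p_{\text{val}} p_{\text{arg}}^l\) (the rule stated for from_s2grid() below). The hypothetical helper below builds the same string in plain Python; it is a sketch of the rule, not the library's implementation. For lmax = 4 with the default parities it reproduces the irreps printed in the s2_dirac note, and it shows why that output has (lmax + 1)^2 = 25 entries.

```python
def s2_irreps_string(lmax: int, p_val: int = 1, p_arg: int = -1) -> str:
    """Hypothetical helper: one copy of each degree l = 0..lmax,
    with parity p_val * p_arg**l ("e" for +1, "o" for -1)."""
    if p_val not in (1, -1) or p_arg not in (1, -1):
        raise ValueError("p_val and p_arg must be 1 or -1")
    parts = []
    for l in range(lmax + 1):
        p = p_val * p_arg**l
        parts.append(f"1x{l}{'e' if p == 1 else 'o'}")
    return "+".join(parts)


def s2_dim(lmax: int) -> int:
    # Each degree l contributes 2l + 1 coefficients, so the total is (lmax + 1)**2.
    return sum(2 * l + 1 for l in range(lmax + 1))


print(s2_irreps_string(4))          # 1x0e+1x1o+1x2e+1x3o+1x4e
print(s2_irreps_string(2, -1, -1))  # 1x0o+1x1e+1x2o
print(s2_dim(4))                    # 25
```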

e3nn_jax.s2_dirac(position: Array | IrrepsArray, lmax: int, *, p_val: int = 1, p_arg: int = -1) → IrrepsArray

Spherical harmonics expansion of a Dirac delta on the sphere.

The integral of the Dirac delta is 1.
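Why the integral stays 1 at any lmax can be seen from the addition theorem for orthonormal real spherical harmonics, \(\sum_m Y_{l,m}(\vec x) Y_{l,m}(\vec x_0) = \frac{2l+1}{4\pi} P_l(\vec x \cdot \vec x_0)\): only the \(l = 0\) term has a nonzero integral, and it contributes exactly 1. The NumPy sketch below checks this for a delta at a pole, written as a function of \(t = \cos\) (angle to the pole); it is independent of the normalization conventions e3nn_jax uses for its coefficients, and the helper names are hypothetical. It also shows what increasing lmax (as in the examples below) does: the peak value \((\text{lmax}+1)^2 / 4\pi\) grows while the integral is unchanged.

```python
import numpy as np
from numpy.polynomial import legendre


def truncated_dirac(t, lmax):
    """Band-limited delta at a pole, as a function of t = cos(angle to the pole).

    Uses the addition theorem: sum_m Y_lm(x) Y_lm(x0) = (2l + 1) / (4 pi) P_l(x . x0).
    """
    weights = np.array([(2 * l + 1) / (4 * np.pi) for l in range(lmax + 1)])
    return legendre.legval(t, weights)


def integrate_zonal(g, n=64):
    """Integrate a function of t over the unit sphere: 2 pi * int_{-1}^{1} g(t) dt."""
    t, w = legendre.leggauss(n)
    return 2 * np.pi * np.sum(w * g(t))


for lmax in (3, 6, 9):
    total = integrate_zonal(lambda t: truncated_dirac(t, lmax))
    peak = truncated_dirac(1.0, lmax)
    print(lmax, round(float(total), 6), round(float(peak), 3))
```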

Parameters:
  • position (jax.Array or IrrepsArray) – position of the delta, shape (..., 3) (a single position has shape (3,)). It is normalized to unit norm.

  • lmax (int) – maximum degree of the spherical harmonics expansion

  • p_val (int) – parity of the value of the signal on the sphere (1 or -1)

  • p_arg (int) – parity of the argument of the signal on the sphere (1 or -1)

Returns:

Spherical harmonics coefficients

Return type:

IrrepsArray

Examples:

import jax.numpy as jnp
import e3nn_jax as e3nn

position = jnp.array([0.0, 0.0, 1.0])

# The same delta, expanded up to lmax = 3, 6 and 9
coeffs_3 = e3nn.s2_dirac(position, 3, p_val=1, p_arg=-1)
coeffs_6 = e3nn.s2_dirac(position, 6, p_val=1, p_arg=-1)
coeffs_9 = e3nn.s2_dirac(position, 9, p_val=1, p_arg=-1)

Note

To compute a sum of weighted Dirac deltas, use:

positions = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.0]])
weights = jnp.array([1, 1, -1, -1.0])

e3nn.sum(e3nn.s2_dirac(positions, 4, p_val=1, p_arg=-1) * weights[:, None], axis=0)
1x0e+1x1o+1x2e+1x3o+1x4e
[ 0.          0.13783222 -0.13783222 -0.13783222  0.          0.
 -0.17794065  0.         -0.30820224 -0.16644822  0.         -0.12893024
 -0.21054222  0.12893024  0.         -0.16644822  0.          0.
  0.          0.         -0.23873241  0.          0.26691097  0.
  0.        ]
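The 1o and 2e blocks of this output can be reproduced by hand, which also makes the linearity explicit: the coefficients of a weighted sum of deltas are the weighted sum of the per-delta coefficients. The per-delta formulas in the NumPy sketch below (\(\sqrt{3}\,\hat n / 4\pi\) for 1o, and the listed \(l = 2\) polynomials divided by \(4\pi\) for 2e, in the component order of the printed array) are assumptions inferred from the printed numbers, which they match to about eight digits; this page does not state them. The [0, 0, 0] row is left out because the printed values are consistent with it contributing nothing to these blocks.

```python
import numpy as np

# Positions and weights from the example above, minus the [0, 0, 0] row.
positions = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
weights = np.array([1.0, -1.0, -1.0])


def l1_block(n):
    # Assumed per-delta 1o coefficients for a unit vector n.
    return np.sqrt(3.0) * n / (4 * np.pi)


def l2_block(n):
    # Assumed per-delta 2e coefficients, in the order of the printed output.
    x, y, z = n
    return np.array([
        np.sqrt(15.0) * x * z,
        np.sqrt(15.0) * x * y,
        np.sqrt(5.0) * (y**2 - 0.5 * (x**2 + z**2)),
        np.sqrt(15.0) * y * z,
        np.sqrt(15.0) / 2 * (z**2 - x**2),
    ]) / (4 * np.pi)


# Linearity: weight each per-delta coefficient vector, then sum over deltas.
c1 = sum(w * l1_block(n) for n, w in zip(positions, weights))
c2 = sum(w * l2_block(n) for n, w in zip(positions, weights))
print(np.round(c1, 8))  # compare with entries 1-3 of the printed array
print(np.round(c2, 8))  # compare with entries 4-8
```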
e3nn_jax.to_s2point(coeffs: IrrepsArray, point: IrrepsArray, *, normalization: str = 'integral') → IrrepsArray

Evaluate a signal on the sphere, given by its coefficients in the spherical harmonics basis, at arbitrary points.

It computes the same thing as to_s2grid(), but at the given points instead of on a grid.
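The shape rule (*shape1, *shape2) means every coefficient set is evaluated at every point. The hypothetical NumPy evaluator below illustrates this for a 0e + 1o expansion only. It assumes the signal is \(f(\vec v) = c_0 + \sqrt{3}\, \vec c \cdot \vec v\) with the 1o components ordered (x, y, z), which is consistent with the to_s2grid example on this page (coefficients [1, 0, 2, 0] give values close to \(1 \pm 2\sqrt{3}\) near the poles); it is not the library's implementation, and it drops the trailing irreps axis that to_s2point keeps.

```python
import numpy as np


def eval_degree1(coeffs, points):
    """Hypothetical evaluator for a 0e + 1o expansion at arbitrary points.

    coeffs: (*shape1, 4) holding [c0, cx, cy, cz]; points: (*shape2, 3), unit vectors.
    Assumes f(v) = c0 + sqrt(3) * (c_vec . v). Returns shape (*shape1, *shape2).
    """
    coeffs = np.asarray(coeffs, dtype=float)
    points = np.asarray(points, dtype=float)
    shape1, shape2 = coeffs.shape[:-1], points.shape[:-1]
    c0 = coeffs[..., 0].reshape(shape1 + (1,) * len(shape2))
    c1 = coeffs[..., 1:].reshape(-1, 3)        # (prod(shape1), 3)
    p = points.reshape(-1, 3)                  # (prod(shape2), 3)
    lin = (c1 @ p.T).reshape(shape1 + shape2)  # c_vec . v for every (coeffs, point) pair
    return c0 + np.sqrt(3.0) * lin


coeffs = np.array([[1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 0.0]])              # shape1 = (2,)
points = np.array([[[0.0, 1.0, 0.0], [0.0, -1.0, 0.0], [1.0, 0.0, 0.0]]])  # shape2 = (1, 3)
values = eval_degree1(coeffs, points)
print(values.shape)  # (2, 1, 3)
```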

Parameters:
  • coeffs (IrrepsArray) – coefficient array of shape (*shape1, irreps)

  • point (jax.Array) – point on the sphere of shape (*shape2, 3)

  • normalization ({'norm', 'component', 'integral'}) – normalization of the basis

Returns:

signal on the sphere of shape (*shape1, *shape2, irreps)

Return type:

IrrepsArray

e3nn_jax.to_s2grid(coeffs: IrrepsArray, res_beta: int, res_alpha: int, *, quadrature: str, normalization: str = 'integral', fft: bool = True, p_val: int | None = None, p_arg: int | None = None, use_s2fft: bool = False) → SphericalSignal

Sample a signal on the sphere, given by its coefficients in the spherical harmonics basis, on a grid.

This is the inverse transformation of from_s2grid().

Parameters:
  • coeffs (IrrepsArray) – coefficient array

  • res_beta (int) – number of points on the sphere in the \(\beta\) direction

  • res_alpha (int) – number of points on the sphere in the \(\alpha\) direction

  • normalization ({'norm', 'component', 'integral'}) – normalization of the basis

  • quadrature (str) – “soft” or “gausslegendre”

  • fft (bool) – if True, use the FFT; otherwise use the naive implementation

  • p_val (int, optional) – parity of the value of the signal

  • p_arg (int, optional) – parity of the argument of the signal

Returns:

signal on the sphere of shape (..., y/beta, alpha)

Return type:

SphericalSignal

Note

We use a rectangular grid for the \(\beta\) and \(\alpha\) angles. The grid is uniform in the \(\alpha\) angle while for \(\beta\), two different quadratures are available:

  • The soft quadrature is a uniform sampling of the beta angle.

  • The Gauss-Legendre quadrature is exact for polynomials in \(y\) of degree up to 2 * res_beta - 1. On the sphere it is exact only for polynomials of \(y\).
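The exactness claim for the Gauss-Legendre rule can be checked directly on monomials \(y^k\) over \([-1, 1]\), the \(y\) range of the grid. The sketch below uses NumPy's leggauss (a standard Gauss-Legendre implementation, not e3nn_jax's code); with n = 5 points the error vanishes up to \(k = 2n - 1 = 9\) and becomes visible at \(k = 10\).

```python
import numpy as np
from numpy.polynomial import legendre


def gauss_legendre(f, n):
    """Approximate int_{-1}^{1} f(y) dy with the n-point Gauss-Legendre rule."""
    y, w = legendre.leggauss(n)
    return np.sum(w * f(y))


def exact_monomial(k):
    """Exact value of int_{-1}^{1} y**k dy."""
    return 0.0 if k % 2 else 2.0 / (k + 1)


n = 5  # exact up to degree 2 * n - 1 = 9
for k in range(12):
    err = abs(gauss_legendre(lambda y: y**k, n) - exact_monomial(k))
    print(k, f"{err:.1e}")
```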

e3nn_jax.from_s2grid(x: SphericalSignal, irreps: Irreps, *, normalization: str = 'integral', lmax_in: int | None = None, fft: bool = True, use_s2fft: bool = False) → IrrepsArray

Transform a signal on the sphere into spherical harmonics coefficients.

The output has degree \(l\) between 0 and lmax, and parity \(p = p_{\text{val}} p_{\text{arg}}^l\).

This is the inverse transformation of to_s2grid().

Parameters:
  • x (SphericalSignal) – signal on the sphere of shape (..., y/beta, alpha)

  • irreps (Irreps) – irreps of the coefficients

  • normalization ({'norm', 'component', 'integral'}) – normalization of the spherical harmonics basis

  • lmax_in (int, optional) – maximum degree of the input signal, only used for normalization purposes

  • fft (bool) – if True, use the FFT; otherwise use the naive implementation

Returns:

coefficient array of shape (..., (lmax+1)^2)

Return type:

IrrepsArray

class e3nn_jax.SphericalSignal(grid_values: Array, quadrature: str, *, p_val: int = 1, p_arg: int = -1, _perform_checks: bool = True)

Bases: object

Representation of a signal on the sphere.

Parameters:
  • grid_values – values of the signal on a grid, shape (res_beta, res_alpha)

  • quadrature – quadrature used to create the grid, either "soft" or "gausslegendre"

  • p_val – parity of the value of the signal, either +1 or -1

  • p_arg – parity of the argument of the signal, either +1 or -1

Examples

Create a signal from a function defined on the sphere:

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn

def f(coords):
    x, y, z = coords
    return x**2 - y**2

signal = e3nn.SphericalSignal.from_function(f, 50, 49, quadrature="soft")
signal
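What from_function does can be sketched in NumPy: build the grid of unit vectors and evaluate the function at each one. The "soft" rows below use \(y = -\cos\beta\) with \(\beta_j = \pi (j + 1/2) / \text{res\_beta}\); this is an assumption, but it reproduces the first rows printed in the to_s2grid example that follows. The placement of x and z as functions of \(\alpha\) (and \(\alpha\) uniform on \([0, 2\pi)\)) is a further, unverified assumption, and all helper names are hypothetical.

```python
import numpy as np


def soft_grid_y(res_beta):
    # Uniform in beta, offset by half a step; y = -cos(beta) runs from near -1 to near +1.
    beta = np.pi * (np.arange(res_beta) + 0.5) / res_beta
    return -np.cos(beta)


def grid_vectors(y, res_alpha):
    # Assumed layout: y is the polar axis, alpha is uniform in [0, 2 pi).
    alpha = 2 * np.pi * np.arange(res_alpha) / res_alpha
    r = np.sqrt(1 - y**2)[:, None]
    return np.stack(
        [r * np.sin(alpha), np.broadcast_to(y[:, None], (len(y), res_alpha)), r * np.cos(alpha)],
        axis=-1,
    )  # (res_beta, res_alpha, 3)


def from_function_sketch(func, res_beta, res_alpha):
    vecs = grid_vectors(soft_grid_y(res_beta), res_alpha)
    return np.apply_along_axis(func, -1, vecs)  # (res_beta, res_alpha)


def f(coords):
    x, y, z = coords
    return x**2 - y**2


values = from_function_sketch(f, 50, 49)
print(values.shape)  # (50, 49)

# 1 + 2*sqrt(3)*y on this grid matches the first rows printed for to_s2grid.
y = soft_grid_y(50)
print(np.round(1 + 2 * np.sqrt(3) * y[:3], 3))  # [-2.462 -2.449 -2.421]
```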

Create a signal of zeros:

e3nn.SphericalSignal.zeros(50, 49, quadrature="soft")
SphericalSignal(shape=(50, 49), res_beta=50, res_alpha=49, quadrature=soft, p_val=1, p_arg=-1)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]

Create a signal from a spherical harmonic expansion:

coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 0.0, 2.0, 0.0]))
signal = e3nn.to_s2grid(coeffs, 50, 49, quadrature="soft")
signal
SphericalSignal(shape=(50, 49), res_beta=50, res_alpha=49, quadrature=soft, p_val=1, p_arg=-1)
[[-2.462 -2.462 -2.462 ... -2.462 -2.462 -2.462]
 [-2.449 -2.449 -2.449 ... -2.449 -2.449 -2.449]
 [-2.421 -2.421 -2.421 ... -2.421 -2.421 -2.421]
 ...
 [ 4.421  4.421  4.421 ...  4.421  4.421  4.421]
 [ 4.449  4.449  4.449 ...  4.449  4.449  4.449]
 [ 4.462  4.462  4.462 ...  4.462  4.462  4.462]]

Apply a function to the signal:

signal = signal.apply(jnp.exp)
signal
SphericalSignal(shape=(50, 49), res_beta=50, res_alpha=49, quadrature=soft, p_val=1, p_arg=-1)
[[ 0.085  0.085  0.085 ...  0.085  0.085  0.085]
 [ 0.086  0.086  0.086 ...  0.086  0.086  0.086]
 [ 0.089  0.089  0.089 ...  0.089  0.089  0.089]
 ...
 [83.217 83.217 83.217 ... 83.217 83.217 83.217]
 [85.518 85.518 85.518 ... 85.518 85.518 85.518]
 [86.695 86.695 86.695 ... 86.695 86.695 86.695]]

Convert the signal back to a spherical harmonic expansion:

irreps = e3nn.s2_irreps(4)
coeffs = e3nn.from_s2grid(signal, irreps)
coeffs["4e"]
1x4e [0.  0.  0.  0.  2.1 0.  0.  0.  0. ]

Resample the signal to a different grid resolution:

signal = signal.resample(100, 99, lmax=5)
signal
SphericalSignal(shape=(100, 99), res_beta=100, res_alpha=99, quadrature=soft, p_val=1, p_arg=-1)
[[-0.443 -0.443 -0.443 ... -0.443 -0.443 -0.443]
 [-0.433 -0.433 -0.433 ... -0.433 -0.433 -0.433]
 [-0.412 -0.412 -0.412 ... -0.412 -0.412 -0.412]
 ...
 [85.097 85.097 85.097 ... 85.097 85.097 85.097]
 [85.647 85.647 85.647 ... 85.647 85.647 85.647]
 [85.923 85.923 85.923 ... 85.923 85.923 85.923]]

Compute the integral of the signal:

signal.integrate()
1x0e [157.361]

Rotate the signal (lmax must be given because the rotation is applied to the spherical harmonics coefficients):

signal = signal.transform_by_angles(jnp.pi / 2, jnp.pi / 3, 0.0, lmax=5)

Sample a point on the sphere, using the signal as a density function:

indices = signal.sample(jax.random.PRNGKey(0))
signal.grid_vectors[indices], signal.grid_values[indices]
(Array([ 0.958,  0.11 , -0.265], dtype=float32), Array(58.511, dtype=float32))

Plot the signal:

import plotly.graph_objects as go
go.Figure([go.Surface(signal.plotly_surface())])

Methods:

apply(func)

Applies a function pointwise on the grid.

find_peaks(lmax)

Locate peaks on the signal on the sphere.

from_function(func, res_beta, res_alpha, ...)

Create a signal on the sphere from a function of the coordinates.

integrate()

Integrate the signal on the sphere.

pad_to_plot(*[, translation, radius, ...])

Postprocess the borders of a given signal so that it can be plotted with plotly.

plotly_surface([translation, radius, ...])

Returns a dictionary that can be plotted with plotly.

replace_values(grid_values)

Replace the grid values of the signal.

resample(res_beta, res_alpha, lmax[, quadrature])

Resamples a signal via the spherical harmonic coefficients.

sample(key)

Sample a point on the sphere using the signal as a probability distribution.

transform_by_angles(alpha, beta, gamma, lmax)

Rotate the signal by the given Euler angles.

transform_by_axis_angle(axis, angle, lmax)

Rotate the signal by the given angle around an axis.

transform_by_matrix(R, lmax)

Rotate the signal by the given rotation matrix.

transform_by_quaternion(q, lmax)

Rotate the signal by the given quaternion.

zeros(res_beta, res_alpha, quadrature, *[, ...])

Create a null signal on a grid.

Attributes:

dtype

Returns the dtype of this signal.

grid_alpha

Returns alpha values on the grid for this signal.

grid_resolution

Grid resolution for (beta, alpha).

grid_vectors

Returns the coordinates of the points on the sphere, shape (res_beta, res_alpha, 3).

grid_y

Returns y-values on the grid for this signal.

ndim

Returns the number of dimensions of this signal.

quadrature_weights

Returns quadrature weights along the y-coordinates.

res_alpha

Grid resolution for alpha.

res_beta

Grid resolution for beta.

shape

Returns the shape of this signal.

apply(func: Callable[[Array], Array]) → SphericalSignal

Applies a function pointwise on the grid.

property dtype: dtype

Returns the dtype of this signal.

find_peaks(lmax: int) → Tuple[ndarray, ndarray]

Locate peaks on the signal on the sphere.

Currently cannot be wrapped with jax.jit().

static from_function(func: Callable[[Array], float], res_beta: int, res_alpha: int, quadrature: str, *, p_val: int = 1, p_arg: int = -1, dtype: numpy.dtype = jax.numpy.float32) → SphericalSignal

Create a signal on the sphere from a function of the coordinates.

Parameters:
  • func (Callable) – function on the sphere that maps a 3-dimensional array (x, y, z) to a number

  • res_beta – resolution for beta

  • res_alpha – resolution for alpha

  • quadrature – quadrature to use

  • p_val – parity of the value of the signal, either +1 or -1

  • p_arg – parity of the argument of the signal, either +1 or -1

  • dtype – dtype of the signal

Returns:

signal on the sphere

Return type:

SphericalSignal

property grid_alpha: Array

Returns alpha values on the grid for this signal.

property grid_resolution: Tuple[int, int]

Grid resolution for (beta, alpha).

property grid_vectors: Array

Returns the coordinates of the points on the sphere, shape (res_beta, res_alpha, 3).

property grid_y: Array

Returns y-values on the grid for this signal.

integrate() → IrrepsArray

Integrate the signal on the sphere.

The integral of a constant signal of value 1 is \(4\pi\).
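On a grid, the integral is a weighted sum of the grid values. The NumPy sketch below assumes a Gauss-Legendre grid (rows at Gauss-Legendre nodes in \(y\), columns uniform in \(\alpha\)), so the surface element is \(dy\, d\alpha\) and each cell gets weight \(w_j \cdot 2\pi / \text{res\_alpha}\). It is an illustration, not the library's code (the "soft" weights are different). A constant 1 integrates to \(4\pi\), and \(\exp(1 + 2\sqrt{3}\, y)\), the signal built in the examples above, integrates to about 157.36, close to the printed 157.361 (which comes from the resampled soft-grid signal, so exact agreement is not expected).

```python
import numpy as np
from numpy.polynomial import legendre


def integrate_gl_grid(grid_values):
    """Integrate values on a (res_beta, res_alpha) grid whose rows are Gauss-Legendre
    nodes in y and whose columns are uniform in alpha (dOmega = dy dalpha)."""
    res_beta, res_alpha = grid_values.shape
    _, w = legendre.leggauss(res_beta)
    return np.sum(w[:, None] * grid_values) * (2 * np.pi / res_alpha)


res_beta, res_alpha = 50, 49
y, _ = legendre.leggauss(res_beta)

ones = np.ones((res_beta, res_alpha))
print(integrate_gl_grid(ones), 4 * np.pi)  # both approximately 12.566

# exp(1 + 2 sqrt(3) y) does not depend on alpha; compare with the closed form.
signal = np.exp(1 + 2 * np.sqrt(3) * y)[:, None] * np.ones((1, res_alpha))
a = 2 * np.sqrt(3)
exact = 2 * np.pi * np.e * (np.exp(a) - np.exp(-a)) / a
print(integrate_gl_grid(signal), exact)
```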

Returns:

integral of the signal

Return type:

IrrepsArray

property ndim: int

Returns the number of dimensions of this signal.

pad_to_plot(*, translation: Array | None = None, radius: float = 1.0, scale_radius_by_amplitude: bool = False, normalize_radius_by_max_amplitude: bool = False) → Tuple[Array, Array]

Postprocess the borders of a given signal so that it can be plotted with plotly.

Parameters:
  • translation (optional) – translation vector

  • radius (float) – radius of the sphere

  • scale_radius_by_amplitude (bool) – whether to rescale the output vectors by the amplitude of the signal

  • normalize_radius_by_max_amplitude (bool) – when scale_radius_by_amplitude is True, rescales the surface so that the maximum amplitude is equal to the radius

Returns:

r (jax.Array): vectors on the sphere, shape (res_beta + 2, res_alpha + 1, 3)
f (jax.Array): padded signal, shape (res_beta + 2, res_alpha + 1)

Return type:

Tuple[Array, Array]

plotly_surface(translation: Array | None = None, radius: float = 1.0, scale_radius_by_amplitude: bool = False, normalize_radius_by_max_amplitude: bool = False)

Returns a dictionary that can be plotted with plotly.

Parameters:
  • translation (optional) – translation vector

  • radius (float) – radius of the sphere

  • scale_radius_by_amplitude (bool) – whether to rescale the output vectors by the amplitude of the signal

  • normalize_radius_by_max_amplitude (bool) – when scale_radius_by_amplitude is True, rescales the surface so that the maximum amplitude is equal to the radius

Returns:

dictionary that can be plotted with plotly

Return type:

dict

Examples:

import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
coeffs = e3nn.normal(e3nn.s2_irreps(5), jax.random.PRNGKey(0))
signal = e3nn.to_s2grid(coeffs, 70, 141, quadrature="gausslegendre")

import plotly.graph_objects as go
go.Figure([go.Surface(signal.plotly_surface())])

One can also scale the radius of the sphere by the amplitude of the signal:

go.Figure([go.Surface(signal.plotly_surface(scale_radius_by_amplitude=True))])
property quadrature_weights: Array

Returns quadrature weights along the y-coordinates.

replace_values(grid_values: Array) → SphericalSignal

Replace the grid values of the signal.

property res_alpha: int

Grid resolution for alpha.

property res_beta: int

Grid resolution for beta.

resample(res_beta: int, res_alpha: int, lmax: int, quadrature: str | None = None) → SphericalSignal

Resamples a signal via the spherical harmonic coefficients.

Parameters:
  • res_beta – New resolution for beta.

  • res_alpha – New resolution for alpha.

  • lmax – Maximum l for the spherical harmonics.

  • quadrature – Quadrature to use. Defaults to reusing the current quadrature.

Returns:

A new SphericalSignal with the new resolution.
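The idea of resampling through harmonic coefficients can be illustrated in the simplest case, a signal that does not depend on \(\alpha\), where only the \(m = 0\) harmonics (Legendre polynomials in \(y\), up to normalization) contribute. The NumPy sketch below projects onto \(P_0, \dots, P_{\text{lmax}}\) with a Gauss-Legendre rule and evaluates the truncated expansion at new \(y\) values. It is a hypothetical helper for illustration, not how the library implements resample. Anything above lmax is discarded, which is why a band-limited input survives exactly while higher degrees are lost.

```python
import numpy as np
from numpy.polynomial import legendre


def resample_zonal(values, new_y, lmax):
    """Resample an alpha-independent signal given at Gauss-Legendre nodes in y.

    Projects onto P_0..P_lmax using int P_l P_k dy = 2 / (2l + 1) * delta_lk,
    then evaluates the truncated Legendre series at new_y.
    """
    y, w = legendre.leggauss(len(values))
    coeffs = np.array([
        (2 * l + 1) / 2 * np.sum(w * values * legendre.legval(y, np.eye(lmax + 1)[l]))
        for l in range(lmax + 1)
    ])
    return legendre.legval(new_y, coeffs)


old_y, _ = legendre.leggauss(10)
new_y = np.linspace(-1.0, 1.0, 7)
cubic = old_y**3 - old_y
print(np.allclose(resample_zonal(cubic, new_y, lmax=5), new_y**3 - new_y))  # True
print(np.allclose(resample_zonal(cubic, new_y, lmax=1), -0.4 * new_y))     # True: only P_1 survives
```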

sample(key: Array) → Tuple[Array, Array]

Sample a point on the sphere using the signal as a probability distribution.

The probability distribution does not need to be normalized.
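A minimal NumPy sketch of sampling a grid cell with probability proportional to the signal. It assumes, as seems natural but is not stated here, that each cell is also weighted by the quadrature weight of its row so that cell area is accounted for; the normalization happens inside the function, which is why the input need not be normalized. The function name and the NumPy random generator are stand-ins for illustration, not the library's implementation.

```python
import numpy as np


def sample_grid_index(grid_values, quadrature_weights, rng):
    """Draw (beta_index, alpha_index) with probability proportional to
    grid_values * quadrature_weights (weights account for cell area)."""
    p = grid_values * quadrature_weights[:, None]
    if np.any(p < 0):
        raise ValueError("the signal must be non-negative to act as a density")
    p = p / p.sum()  # normalize here, so the input signal need not be normalized
    flat = rng.choice(p.size, p=p.ravel())
    return np.unravel_index(flat, p.shape)


rng = np.random.default_rng(0)
values = np.zeros((4, 5))
values[2, 3] = 7.0  # all (unnormalized) mass in a single cell
weights = np.array([0.1, 0.4, 0.4, 0.1])
b, a = sample_grid_index(values, weights, rng)
print(int(b), int(a))  # 2 3
```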

Parameters:

key (jax.Array) – random key

Returns:

tuple containing:

beta_index (jax.Array): index of the sampled beta
alpha_index (jax.Array): index of the sampled alpha

Return type:

Tuple[Array, Array]

Examples:

coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 2.0, 0.0, 0.0]))
signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")
signal = signal.apply(jnp.exp)

beta_index, alpha_index = signal.sample(jax.random.PRNGKey(0))
print(beta_index, alpha_index)
print(signal.grid_vectors[beta_index, alpha_index])
19 20
[ 0.913 -0.336 -0.233]
property shape: Tuple[int, ...]

Returns the shape of this signal.

transform_by_angles(alpha: float, beta: float, gamma: float, lmax: int) → SphericalSignal

Rotate the signal by the given Euler angles.

transform_by_axis_angle(axis: Array, angle: float, lmax: int) → SphericalSignal

Rotate the signal by the given angle around an axis.

transform_by_matrix(R: Array, lmax: int) → SphericalSignal

Rotate the signal by the given rotation matrix.
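Why rotation goes through the coefficients (and needs lmax) is easiest to see for degree 1. Assuming the convention \([R f](\vec v) = f(R^{-1} \vec v)\) and 1o components ordered like (x, y, z), a signal \(c_0 + \vec c \cdot \vec v\) rotates by replacing \(\vec c\) with \(R \vec c\); higher degrees need \((2l+1) \times (2l+1)\) Wigner D-matrices, one per degree up to lmax. The NumPy sketch below checks the degree-1 case with a rotation built from Rodrigues' formula; the helper names are hypothetical and normalization constants are omitted.

```python
import numpy as np


def axis_angle_matrix(axis, angle):
    """Rodrigues' formula: rotation by `angle` (radians) about `axis`."""
    k = np.asarray(axis, dtype=float)
    k = k / np.linalg.norm(k)
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)


def signal_l1(c0, c1, v):
    # Degree <= 1 signal f(v) = c0 + c1 . v, evaluated row-wise on v of shape (N, 3).
    return c0 + v @ c1


rng = np.random.default_rng(0)
R = axis_angle_matrix([1.0, 2.0, 3.0], 0.8)
c0, c1 = 0.5, np.array([1.0, -2.0, 0.3])
points = rng.normal(size=(10, 3))
points /= np.linalg.norm(points, axis=1, keepdims=True)

rotated_on_points = signal_l1(c0, c1, points @ R)  # f(R^T v) for each row v
rotated_coeffs = signal_l1(c0, R @ c1, points)     # same signal, 1o coefficients rotated
print(np.allclose(rotated_on_points, rotated_coeffs))  # True
```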

transform_by_quaternion(q: Array, lmax: int) → SphericalSignal

Rotate the signal by the given quaternion.

static zeros(res_beta: int, res_alpha: int, quadrature: str, *, p_val: int = 1, p_arg: int = -1, dtype: numpy.dtype = jax.numpy.float32) → SphericalSignal

Create a null signal on a grid.