Signal on the Sphere
- e3nn_jax.s2_irreps(lmax: int, p_val: int = 1, p_arg: int = -1) → Irreps
The irreps of the coefficients of a spherical harmonics expansion up to degree \(L\) = lmax:
\[f(\vec x) = \sum_{l=0}^{L} \sum_{m=-l}^{l} c_l^m Y_{l,m}(\vec x)\]
When the inversion operator is applied to the signal, the new function \(I f\) is given by
\[[I f](\vec x) = p_{\text{val}} f(p_{\text{arg}} \vec x)\]
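Because \(Y_{l,m}(-\vec x) = (-1)^l Y_{l,m}(\vec x)\), the coefficients of degree \(l\) pick up the factor \(p_{\text{val}} p_{\text{arg}}^l\) under inversion, and that factor decides the parity letter of each irrep. The sketch below spells the rule out in plain Python; s2_irreps_string is a hypothetical helper written for illustration, not part of e3nn_jax, and it only mimics the printed form of the returned Irreps.

```python
def s2_irreps_string(lmax: int, p_val: int = 1, p_arg: int = -1) -> str:
    """Hypothetical helper: spell out the irreps of an S2 expansion up to lmax."""
    if p_val not in (1, -1) or p_arg not in (1, -1):
        raise ValueError("p_val and p_arg must be +1 or -1")
    terms = []
    for l in range(lmax + 1):
        # Y_l(-x) = (-1)^l Y_l(x), so degree l gets parity p_val * p_arg**l
        p = p_val * p_arg**l
        terms.append(f"1x{l}{'e' if p == 1 else 'o'}")
    return "+".join(terms)


print(s2_irreps_string(4))                     # 1x0e+1x1o+1x2e+1x3o+1x4e
print(s2_irreps_string(2, p_val=-1, p_arg=1))  # 1x0o+1x1o+1x2o
```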
- e3nn_jax.s2_dirac(position: Array | IrrepsArray, lmax: int, *, p_val: int = 1, p_arg: int = -1) → IrrepsArray
Spherical harmonics expansion of a Dirac delta on the sphere.
The integral of the Dirac delta is 1.
- Parameters:
position (jax.Array or IrrepsArray) – position of the delta, shape (3,). It will be normalized to have a norm of 1.
lmax (int) – maximum degree of the spherical harmonics expansion
p_val (int) – parity of the value of the signal on the sphere (1 or -1)
p_arg (int) – parity of the argument of the signal on the sphere (1 or -1)
- Returns:
Spherical harmonics coefficients
- Return type:
IrrepsArray
Examples:
import jax.numpy as jnp
import e3nn_jax as e3nn
position = jnp.array([0.0, 0.0, 1.0])
coeffs_3 = e3nn.s2_dirac(position, 3, p_val=1, p_arg=-1)
coeffs_6 = e3nn.s2_dirac(position, 6, p_val=1, p_arg=-1)
coeffs_9 = e3nn.s2_dirac(position, 9, p_val=1, p_arg=-1)
Note
To compute a sum of weighted Dirac deltas, use:
positions = jnp.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1.0]])
weights = jnp.array([1, 1, -1, -1.0])
e3nn.sum(e3nn.s2_dirac(positions, 4, p_val=1, p_arg=-1) * weights[:, None], axis=0)
1x0e+1x1o+1x2e+1x3o+1x4e [ 0. 0.13783222 -0.13783222 -0.13783222 0. 0. -0.17794065 0. -0.30820224 -0.16644822 0. -0.12893024 -0.21054222 0.12893024 0. -0.16644822 0. 0. 0. 0. -0.23873241 0. 0.26691097 0. 0. ]
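Why does a larger lmax give a sharper delta while its integral stays 1? For orthonormal spherical harmonics, the addition theorem gives \(\sum_m Y_{l,m}(\vec x_0) Y_{l,m}(\vec x) = \frac{2l+1}{4\pi} P_l(\cos\gamma)\), where \(\gamma\) is the angle between \(\vec x\) and \(\vec x_0\). The NumPy sketch below evaluates the truncated delta as a function of \(\cos\gamma\) under that assumption; e3nn_jax's own coefficient normalization may differ by constant factors, so the numbers are illustrative only. Only \(l = 0\) contributes to the integral, while the peak grows as \((L+1)^2/(4\pi)\).

```python
import numpy as np
from numpy.polynomial import legendre


def truncated_dirac_profile(cos_gamma, lmax):
    """Delta of maximum degree lmax, as a function of cos(gamma).

    Assumes orthonormal harmonics, for which the addition theorem gives
    sum_m Y_lm(x0) Y_lm(x) = (2l + 1) / (4 pi) * P_l(cos gamma).
    """
    # Legendre-series coefficients (2l + 1) / (4 pi) for l = 0, ..., lmax
    c = (2 * np.arange(lmax + 1) + 1) / (4 * np.pi)
    return legendre.legval(cos_gamma, c)


# Gauss-Legendre nodes/weights in t = cos(gamma); for a function of gamma alone,
# the integral over the sphere is 2 pi * integral_{-1}^{1} f(t) dt.
t, w = legendre.leggauss(32)
for lmax in (3, 6, 9):
    integral = 2 * np.pi * np.sum(w * truncated_dirac_profile(t, lmax))
    peak = truncated_dirac_profile(1.0, lmax)  # P_l(1) = 1, so this is (lmax + 1)**2 / (4 pi)
    print(lmax, round(float(integral), 6), round(float(peak), 4))
```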
- e3nn_jax.to_s2point(coeffs: IrrepsArray, point: IrrepsArray, *, normalization: str = 'integral') → IrrepsArray
Evaluate a signal on the sphere given by the coefficients in the spherical harmonics basis.
It computes the same thing as to_s2grid() but at individual points instead of on a grid.
- Parameters:
coeffs (IrrepsArray) – coefficient array of shape (*shape1, irreps)
point (jax.Array) – point on the sphere of shape (*shape2, 3)
normalization ({'norm', 'component', 'integral'}) – normalization of the basis
- Returns:
signal on the sphere of shape (*shape1, *shape2, irreps)
- Return type:
IrrepsArray
- e3nn_jax.to_s2grid(coeffs: IrrepsArray, res_beta: int, res_alpha: int, *, quadrature: str, normalization: str = 'integral', fft: bool = True, p_val: int | None = None, p_arg: int | None = None, use_s2fft: bool = False) → SphericalSignal
Sample a signal on the sphere given by the coefficients in the spherical harmonics basis.
The inverse transformation of from_s2grid().
- Parameters:
coeffs (IrrepsArray) – coefficient array
res_beta (int) – number of points on the sphere in the \(\theta\) direction
res_alpha (int) – number of points on the sphere in the \(\phi\) direction
normalization ({'norm', 'component', 'integral'}) – normalization of the basis
quadrature (str) – “soft” or “gausslegendre”
fft (bool) – True if we use FFT, False if we use the naive implementation
p_val (int, optional) – parity of the value of the signal
p_arg (int, optional) – parity of the argument of the signal
- Returns:
signal on the sphere of shape (..., y/beta, alpha)
- Return type:
SphericalSignal
Note
We use a rectangular grid for the \(\beta\) and \(\alpha\) angles. The grid is uniform in the \(\alpha\) angle, while for \(\beta\) two different quadratures are available:
The soft quadrature is a uniform sampling of the beta angle.
The Gauss-Legendre quadrature is a quadrature rule that is exact for polynomials of degree 2 res_beta - 1. On the sphere it is exact only for polynomials of \(y\).
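The exactness claim for the Gauss-Legendre rule can be checked numerically in \(y = \cos\beta\) alone. The sketch uses NumPy's numpy.polynomial.legendre.leggauss rather than e3nn_jax's internal grid construction; with res_beta nodes, the monomials \(y^k\) should be integrated exactly up to degree 2 res_beta - 1.

```python
import numpy as np

res_beta = 5
# Gauss-Legendre nodes y_i in [-1, 1] (y = cos(beta)) and their weights w_i
y, w = np.polynomial.legendre.leggauss(res_beta)


def exact_integral(k):
    """Exact value of the integral of y**k over [-1, 1]."""
    return 0.0 if k % 2 else 2.0 / (k + 1)


# Quadrature error for the monomials y**0, ..., y**(2 * res_beta + 1)
for k in range(2 * res_beta + 2):
    error = abs(np.sum(w * y**k) - exact_integral(k))
    print(k, f"{error:.1e}")
```

Odd powers vanish by symmetry for any symmetric rule, so the first visible error appears at degree 2 res_beta.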
- e3nn_jax.from_s2grid(x: SphericalSignal, irreps: Irreps, *, normalization: str = 'integral', lmax_in: int | None = None, fft: bool = True, use_s2fft: bool = False) → IrrepsArray
Transform a signal on the sphere into spherical harmonics coefficients.
The output has degree \(l\) between 0 and lmax, and parity \(p = p_{val}p_{arg}^l\).
The inverse transformation of to_s2grid().
- Parameters:
x (SphericalSignal) – signal on the sphere of shape (..., y/beta, alpha)
irreps (Irreps) – irreps of the coefficients
normalization ({'norm', 'component', 'integral'}) – normalization of the spherical harmonics basis
lmax_in (int, optional) – maximum degree of the input signal, only used for normalization purposes
fft (bool) – True if we use FFT, False if we use the naive implementation
- Returns:
coefficient array of shape (..., (lmax+1)^2)
- Return type:
IrrepsArray
- class e3nn_jax.SphericalSignal(grid_values: Array, quadrature: str, *, p_val: int = 1, p_arg: int = -1, _perform_checks: bool = True)
Bases: object
Representation of a signal on the sphere.
- Parameters:
grid_values – values of the signal on a grid, shape (res_beta, res_alpha)
quadrature – quadrature used to create the grid, either "soft" or "gausslegendre"
p_val – parity of the signal, either +1 or -1
p_arg – parity of the argument of the signal, either +1 or -1
Examples
Create a signal from a function defined on the sphere:
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
def f(coords):
    x, y, z = coords
    return x**2 - y**2
signal = e3nn.SphericalSignal.from_function(f, 50, 49, quadrature="soft")
signal
Create a signal of zeros:
e3nn.SphericalSignal.zeros(50, 49, quadrature="soft")
SphericalSignal(shape=(50, 49), res_beta=50, res_alpha=49, quadrature=soft, p_val=1, p_arg=-1) [[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]
Create a signal from a spherical harmonic expansion:
coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 0.0, 2.0, 0.0]))
signal = e3nn.to_s2grid(coeffs, 50, 49, quadrature="soft")
signal
SphericalSignal(shape=(50, 49), res_beta=50, res_alpha=49, quadrature=soft, p_val=1, p_arg=-1) [[-2.462 -2.462 -2.462 ... -2.462 -2.462 -2.462] [-2.449 -2.449 -2.449 ... -2.449 -2.449 -2.449] [-2.421 -2.421 -2.421 ... -2.421 -2.421 -2.421] ... [ 4.421 4.421 4.421 ... 4.421 4.421 4.421] [ 4.449 4.449 4.449 ... 4.449 4.449 4.449] [ 4.462 4.462 4.462 ... 4.462 4.462 4.462]]
Apply a function to the signal:
signal = signal.apply(jnp.exp)
signal
SphericalSignal(shape=(50, 49), res_beta=50, res_alpha=49, quadrature=soft, p_val=1, p_arg=-1) [[ 0.085 0.085 0.085 ... 0.085 0.085 0.085] [ 0.086 0.086 0.086 ... 0.086 0.086 0.086] [ 0.089 0.089 0.089 ... 0.089 0.089 0.089] ... [83.217 83.217 83.217 ... 83.217 83.217 83.217] [85.518 85.518 85.518 ... 85.518 85.518 85.518] [86.695 86.695 86.695 ... 86.695 86.695 86.695]]
Convert the signal back to a spherical harmonic expansion:
irreps = e3nn.s2_irreps(4)
coeffs = e3nn.from_s2grid(signal, irreps)
coeffs["4e"]
1x4e [0. 0. 0. 0. 2.1 0. 0. 0. 0. ]
Resample the signal to a different grid resolution:
signal = signal.resample(100, 99, lmax=5)
signal
SphericalSignal(shape=(100, 99), res_beta=100, res_alpha=99, quadrature=soft, p_val=1, p_arg=-1) [[-0.443 -0.443 -0.443 ... -0.443 -0.443 -0.443] [-0.433 -0.433 -0.433 ... -0.433 -0.433 -0.433] [-0.412 -0.412 -0.412 ... -0.412 -0.412 -0.412] ... [85.097 85.097 85.097 ... 85.097 85.097 85.097] [85.647 85.647 85.647 ... 85.647 85.647 85.647] [85.923 85.923 85.923 ... 85.923 85.923 85.923]]
Compute the integral of the signal:
signal.integrate()
1x0e [157.361]
Rotate the signal (lmax must be specified because the rotation is done in the Fourier domain):
signal = signal.transform_by_angles(jnp.pi / 2, jnp.pi / 3, 0.0, lmax=5)
Sample a point on the sphere, using the signal as a density function:
indices = signal.sample(jax.random.PRNGKey(0))
signal.grid_vectors[indices], signal.grid_values[indices]
(Array([ 0.958, 0.11 , -0.265], dtype=float32), Array(58.511, dtype=float32))
Plot the signal:
import plotly.graph_objects as go
go.Figure([go.Surface(signal.plotly_surface())])
Methods:
apply(func) – Applies a function pointwise on the grid.
find_peaks(lmax) – Locate peaks on the signal on the sphere.
from_function(func, res_beta, res_alpha, ...) – Create a signal on the sphere from a function of the coordinates.
integrate() – Integrate the signal on the sphere.
pad_to_plot(*[, translation, radius, ...]) – Postprocess the borders of a given signal so that it can be plotted with plotly.
plotly_surface([translation, radius, ...]) – Returns a dictionary that can be plotted with plotly.
replace_values(grid_values) – Replace the grid values of the signal.
resample(res_beta, res_alpha, lmax[, quadrature]) – Resamples a signal via the spherical harmonic coefficients.
sample(key) – Sample a point on the sphere using the signal as a probability distribution.
transform_by_angles(alpha, beta, gamma, lmax) – Rotate the signal by the given Euler angles.
transform_by_axis_angle(axis, angle, lmax) – Rotate the signal by the given angle around an axis.
transform_by_matrix(R, lmax) – Rotate the signal by the given rotation matrix.
transform_by_quaternion(q, lmax) – Rotate the signal by the given quaternion.
zeros(res_beta, res_alpha, quadrature, *[, ...]) – Create a null signal on a grid.
Attributes:
dtype – Returns the dtype of this signal.
grid_alpha – Returns alpha values on the grid for this signal.
grid_resolution – Grid resolution for (beta, alpha).
grid_vectors – Returns the coordinates of the points on the sphere. Shape (res_beta, res_alpha, 3).
grid_y – Returns y-values on the grid for this signal.
ndim – Returns the number of dimensions of this signal.
quadrature_weights – Returns quadrature weights along the y-coordinates.
res_alpha – Grid resolution for alpha.
res_beta – Grid resolution for beta.
shape – Returns the shape of this signal.
- apply(func: Callable[[Array], Array]) → SphericalSignal
Applies a function pointwise on the grid.
- find_peaks(lmax: int) → Tuple[ndarray, ndarray]
Locate peaks on the signal on the sphere.
Currently cannot be wrapped with jax.jit().
- static from_function(func: Callable[[Array], float], res_beta: int, res_alpha: int, quadrature: str, *, p_val: int = 1, p_arg: int = -1, dtype: dtype = jax.numpy.float32) → SphericalSignal
Create a signal on the sphere from a function of the coordinates.
- Parameters:
func (Callable) – function on the sphere that maps a 3-dimensional array (x, y, z) to a number
res_beta – resolution for beta
res_alpha – resolution for alpha
quadrature – quadrature to use
p_val – parity of the signal, either +1 or -1
p_arg – parity of the argument of the signal, either +1 or -1
dtype – dtype of the signal
- Returns:
signal on the sphere
- Return type:
SphericalSignal
- property grid_vectors: Array
Returns the coordinates of the points on the sphere. Shape (res_beta, res_alpha, 3).
- Type:
Array
- integrate() → IrrepsArray
Integrate the signal on the sphere.
The integral of a constant signal of value 1 is \(4\pi\).
- Returns:
integral of the signal
- Return type:
IrrepsArray
- pad_to_plot(*, translation: Array | None = None, radius: float = 1.0, scale_radius_by_amplitude: bool = False, normalize_radius_by_max_amplitude: bool = False) → Tuple[Array, Array]
Postprocess the borders of a given signal so that it can be plotted with plotly.
- Parameters:
translation (optional) – translation vector
radius (float) – radius of the sphere
scale_radius_by_amplitude (bool) – to rescale the output vectors with the amplitude of the signal
normalize_radius_by_max_amplitude (bool) – when scale_radius_by_amplitude is True, rescales the surface so that the maximum amplitude is equal to the radius
- Returns:
r (jax.Array): vectors on the sphere, shape (res_beta + 2, res_alpha + 1, 3)
f (jax.Array): padded signal, shape (res_beta + 2, res_alpha + 1)
- Return type:
Tuple[Array, Array]
- plotly_surface(translation: Array | None = None, radius: float = 1.0, scale_radius_by_amplitude: bool = False, normalize_radius_by_max_amplitude: bool = False)
Returns a dictionary that can be plotted with plotly.
- Parameters:
translation (optional) – translation vector
radius (float) – radius of the sphere
scale_radius_by_amplitude (bool) – to rescale the output vectors with the amplitude of the signal
normalize_radius_by_max_amplitude (bool) – when scale_radius_by_amplitude is True, rescales the surface so that the maximum amplitude is equal to the radius
- Returns:
dictionary that can be plotted with plotly
- Return type:
dict
Examples:
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
coeffs = e3nn.normal(e3nn.s2_irreps(5), jax.random.PRNGKey(0))
signal = e3nn.to_s2grid(coeffs, 70, 141, quadrature="gausslegendre")
import plotly.graph_objects as go
go.Figure([go.Surface(signal.plotly_surface())])
One can also scale the radius of the sphere by the amplitude of the signal:
go.Figure([go.Surface(signal.plotly_surface(scale_radius_by_amplitude=True))])
- replace_values(grid_values: Array) → SphericalSignal
Replace the grid values of the signal.
- resample(res_beta: int, res_alpha: int, lmax: int, quadrature: str | None = None) → SphericalSignal
Resamples a signal via the spherical harmonic coefficients.
- Parameters:
res_beta – New resolution for beta.
res_alpha – New resolution for alpha.
lmax – Maximum l for the spherical harmonics.
quadrature – Quadrature to use. Defaults to reusing the current quadrature.
- Returns:
A new SphericalSignal with the new resolution.
- sample(key: Array) → Tuple[Array, Array]
Sample a point on the sphere using the signal as a probability distribution.
The probability distribution does not need to be normalized.
Examples:
import jax
import jax.numpy as jnp
import e3nn_jax as e3nn
coeffs = e3nn.IrrepsArray("0e + 1o", jnp.array([1.0, 2.0, 0.0, 0.0]))
signal = e3nn.to_s2grid(coeffs, 50, 69, quadrature="gausslegendre")
signal = signal.apply(jnp.exp)
beta_index, alpha_index = signal.sample(jax.random.PRNGKey(0))
print(beta_index, alpha_index)
print(signal.grid_vectors[beta_index, alpha_index])
19 20
[ 0.913 -0.336 -0.233]
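The docstring does not spell out how the grid point is drawn. One plausible approach, shown here as a hypothetical sketch and not as e3nn_jax's implementation, weights each grid cell by its value times its quadrature weight along beta and draws a flat index with jax.random.categorical, which accepts unnormalized log-probabilities. That is why normalization is unnecessary; it does require non-negative values, which is presumably why the example above applies jnp.exp first. The function name sample_grid_index is illustrative only.

```python
import jax
import jax.numpy as jnp


def sample_grid_index(key, grid_values, quadrature_weights):
    """Hypothetical sketch: draw (beta_index, alpha_index) on a grid.

    Each cell is chosen with probability proportional to
    grid_values[i, j] * quadrature_weights[i]; grid_values must be >= 0.
    The alpha direction is uniform, so it needs no extra weight.
    """
    res_beta, res_alpha = grid_values.shape
    mass = grid_values * quadrature_weights[:, None]
    # categorical() takes unnormalized log-probabilities; zero mass gives -inf
    flat_index = jax.random.categorical(key, jnp.log(mass.ravel()))
    return jnp.unravel_index(flat_index, (res_beta, res_alpha))


# Toy check: all the mass sits in a single cell, so that cell is always drawn
values = jnp.zeros((8, 15)).at[3, 5].set(1.0)
beta_index, alpha_index = sample_grid_index(jax.random.PRNGKey(0), values, jnp.ones(8))
print(int(beta_index), int(alpha_index))  # 3 5
```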
- transform_by_angles(alpha: float, beta: float, gamma: float, lmax: int) → SphericalSignal
Rotate the signal by the given Euler angles.
- transform_by_axis_angle(axis: Array, angle: float, lmax: int) → SphericalSignal
Rotate the signal by the given angle around an axis.
- transform_by_matrix(R: Array, lmax: int) → SphericalSignal
Rotate the signal by the given rotation matrix.
- transform_by_quaternion(q: Array, lmax: int) → SphericalSignal
Rotate the signal by the given quaternion.
- static zeros(res_beta: int, res_alpha: int, quadrature: str, *, p_val: int = 1, p_arg: int = -1, dtype: dtype = jax.numpy.float32) → SphericalSignal
Create a null signal on a grid.