Functions for radial parameterization
- e3nn_jax.soft_one_hot_linspace(input: Array, *, start: float, end: float, number: int, basis: str = None, cutoff: bool = None, start_zero: bool = None, end_zero: bool = None)
Projection on a basis of functions.
Returns a set of \(\{y_i(x)\}_{i=1}^N\),
\[y_i(x) = \frac{1}{Z} f_i(x)\]
where \(x\) is the input and \(f_i\) is the \(i\)-th basis function. \(Z\) is a constant defined such that
\[\langle \sum_{i=1}^N y_i(x)^2 \rangle_x \approx 1\]
The second example below plots \(\sum_{i=1}^N y_i(x)^2\) for each basis to illustrate this normalization.
- Parameters:
input (jax.Array) – input of shape [...]
start (float) – minimum value spanned by the basis
end (float) – maximum value spanned by the basis
number (int) – number of basis functions \(N\)
basis (str) – type of basis functions, one of "gaussian", "cosine", "smooth_finite", "fourier"
cutoff (bool) – if True, the basis functions are cut off at the start and end of the interval
start_zero (bool) – if True, the first basis function is zero at the start of the interval
end_zero (bool) – if True, the last basis function is zero at the end of the interval
- Returns:
basis functions of shape [..., number]
- Return type:
jax.Array
Examples
import numpy as np
import matplotlib.pyplot as plt
import e3nn_jax as e3nn

bases = ["gaussian", "cosine", "smooth_finite", "fourier"]
x = np.linspace(-1.0, 2.0, 200)
# One row per basis type; left column cutoff=True, right column cutoff=False.
fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True)
for axs, b in zip(axss, bases):
    for ax, c in zip(axs, [True, False]):
        y = e3nn.soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c)
        plt.sca(ax)
        plt.plot(x, y)
        # Dash-dotted vertical lines mark start=-0.5 and end=1.5
        plt.plot([-0.5] * 2, [-2, 2], "k-.")
        plt.plot([1.5] * 2, [-2, 2], "k-.")
        plt.title(f"{b}" + (" with cutoff" if c else ""))
        plt.ylim(-1, 1.5)
plt.tight_layout()
# Same grid, now plotting sum_i y_i(x)^2 to visualize the normalization.
fig, axss = plt.subplots(len(bases), 2, figsize=(9, 6), sharex=True, sharey=True)
for axs, b in zip(axss, bases):
    for ax, c in zip(axs, [True, False]):
        y = e3nn.soft_one_hot_linspace(x, start=-0.5, end=1.5, number=4, basis=b, cutoff=c)
        plt.sca(ax)
        plt.plot(x, np.sum(y**2, axis=-1))
        plt.plot([-0.5] * 2, [-2, 2], "k-.")
        plt.plot([1.5] * 2, [-2, 2], "k-.")
        plt.title(f"{b}" + (" with cutoff" if c else ""))
        plt.ylim(0, 2)
plt.tight_layout()
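The role of \(Z\) in \(y_i(x) = f_i(x)/Z\) can be made concrete with a small sketch. The helper `gaussian_basis` below is hypothetical and is not the e3nn_jax implementation: it assumes unnormalized Gaussian bumps centered on a uniform grid over [start, end], and it estimates \(Z\) from uniformly sampled inputs so that the average of \(\sum_i y_i(x)^2\) equals 1 on those samples. The library may place its centers and choose \(Z\) differently.

```python
import jax.numpy as jnp


def gaussian_basis(x, start, end, number):
    # Hypothetical helper: unnormalized Gaussian bumps f_i(x) centered on a
    # uniform grid of `number` points spanning [start, end].
    centers = jnp.linspace(start, end, number)
    step = centers[1] - centers[0]
    diff = (x[..., None] - centers) / step  # shape [..., number]
    return jnp.exp(-diff**2)


start, end, number = -0.5, 1.5, 16
x = jnp.linspace(start, end, 1000)          # samples used for the average <.>_x
f = gaussian_basis(x, start, end, number)   # shape [1000, 16]

# Pick Z so that <sum_i y_i(x)^2>_x = 1 on these samples, then y_i = f_i / Z.
Z = jnp.sqrt(jnp.mean(jnp.sum(f**2, axis=-1)))
y = f / Z

norm = jnp.mean(jnp.sum(y**2, axis=-1))
print(y.shape, float(norm))  # shape (1000, 16); norm equals 1.0 up to float32 rounding
```

Estimating \(Z\) from samples makes the average exactly 1 on that sample; a fixed constant, as in the definition above, only guarantees \(\approx 1\).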
- e3nn_jax.bessel(x: Array, n: int, x_max: float = 1.0) → Array
Bessel basis functions.
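They obey the following normalization:
\[\int_0^c r^2 B_n(r, c) B_m(r, c) \, dr = \delta_{nm}\]
- Parameters:
x (Array) – input of shape [...]
n (int) – number of basis functions
x_max (float) – upper end of the interval, the \(c\) in the normalization above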
- Returns:
basis functions of shape [..., n]
- Return type:
Array
Reference: Klicpera, J.; Groß, J.; Günnemann, S. Directional Message Passing for Molecular Graphs; ICLR 2020, Equation (7).
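The normalization integral can be checked numerically. This sketch assumes the functional form of the cited Equation (7), \(B_n(r, c) = \sqrt{2/c}\,\sin(n\pi r/c)/r\), with \(c\) taken to be x_max; the helper `bessel_basis` is hypothetical and is not the e3nn_jax source. Since \(r^2 B_n B_m = \frac{2}{c}\sin(n\pi r/c)\sin(m\pi r/c)\), the integral reduces to sine orthogonality on \([0, c]\). The \(r \to 0\) limit \(n\pi/c\) is handled explicitly, and a midpoint rule avoids sampling \(r = 0\).

```python
import jax.numpy as jnp


def bessel_basis(r, n, c=1.0):
    # Hypothetical helper: B_k(r, c) = sqrt(2/c) * sin(k*pi*r/c) / r for k = 1..n.
    k = jnp.arange(1, n + 1, dtype=r.dtype)
    r = r[..., None]                          # broadcast r against k -> [..., n]
    safe_r = jnp.where(r == 0.0, 1.0, r)      # avoid 0/0 at r = 0
    value = jnp.sin(k * jnp.pi * safe_r / c) / safe_r
    limit = k * jnp.pi / c                    # sin(k*pi*r/c)/r -> k*pi/c as r -> 0
    return jnp.sqrt(2.0 / c) * jnp.where(r == 0.0, limit, value)


c, n, K = 2.0, 6, 4000
dr = c / K
r = (jnp.arange(K) + 0.5) * dr                # midpoints of K cells covering (0, c)
B = bessel_basis(r, n, c)                     # shape [K, n]

# Midpoint-rule approximation of int_0^c r^2 B_n(r, c) B_m(r, c) dr
gram = jnp.einsum("k,ki,kj->ij", r**2 * dr, B, B)
print(jnp.round(gram, 3))                     # approximately the n x n identity matrix
```

On this midpoint grid the discrete sine sums are themselves orthogonal, so the computed matrix matches the identity up to floating-point rounding rather than quadrature error.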
- e3nn_jax.poly_envelope(n0: int, n1: int, x_max: float = 1.0) → Callable[[float], float]
Polynomial envelope function with n0 and n1 derivatives equal to 0 at x=0 and x=1, respectively. A short write-up is available at https://mariogeiger.ch/polynomial_envelope_for_gnn.pdf. This is a generalization of \(u_p(x)\): it is equivalent to \(u_p(x)\) when n0 = p-1 and n1 = 2.

import jax.numpy as jnp
import matplotlib.pyplot as plt
import e3nn_jax as e3nn

# Each curve is labelled with its (n0, n1) pair
x = jnp.linspace(0.0, 1.0, 100)
plt.plot(x, e3nn.poly_envelope(10, 5)(x), label="10, 5")
plt.plot(x, e3nn.poly_envelope(4, 4)(x), label="4, 4")
plt.plot(x, e3nn.poly_envelope(1, 2)(x), label="1, 2")
plt.legend()
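The derivative conditions can be turned into a linear system. The helper `hypothetical_poly_envelope` below is an assumption, not the e3nn_jax implementation. It reads the description as a polynomial \(f\) of degree n0 + n1 + 1 on [0, 1] with \(f(0) = 1\) and its first n0 derivatives zero at 0, and with \(f(1) = 0\) and its first n1 derivatives zero at 1. It ignores x_max. It also assumes that \(u_p\) is the DimeNet envelope \(u_p(x) = 1 - \frac{(p+1)(p+2)}{2}x^p + p(p+2)x^{p+1} - \frac{p(p+1)}{2}x^{p+2}\) and checks the stated equivalence for n0 = p-1, n1 = 2.

```python
import math

import jax.numpy as jnp
import numpy as np


def hypothetical_poly_envelope(n0, n1):
    """Polynomial f of degree n0 + n1 + 1 on [0, 1] with f(0) = 1,
    f^(j)(0) = 0 for j = 1..n0 and f^(j)(1) = 0 for j = 0..n1."""
    degree = n0 + n1 + 1
    rows, rhs = [], []
    # At x = 0, the j-th derivative of sum_k a_k x^k equals j! * a_j.
    for j in range(n0 + 1):
        row = np.zeros(degree + 1)
        row[j] = math.factorial(j)
        rows.append(row)
        rhs.append(1.0 if j == 0 else 0.0)
    # At x = 1, the j-th derivative equals sum_k k!/(k-j)! * a_k (math.perm gives 0 for k < j).
    for j in range(n1 + 1):
        rows.append(np.array([math.perm(k, j) for k in range(degree + 1)], dtype=float))
        rhs.append(0.0)
    coeffs = np.linalg.solve(np.stack(rows), np.array(rhs))  # a_0, ..., a_degree
    poly = jnp.asarray(coeffs[::-1], dtype=jnp.float32)      # polyval wants highest degree first
    return lambda x: jnp.polyval(poly, x)


p = 6
x = jnp.linspace(0.0, 1.0, 101)
# DimeNet-style envelope u_p(x), assumed to be the u_p referred to above
u_p = 1 - (p + 1) * (p + 2) / 2 * x**p + p * (p + 2) * x ** (p + 1) - p * (p + 1) / 2 * x ** (p + 2)
f = hypothetical_poly_envelope(p - 1, 2)
err = float(jnp.max(jnp.abs(f(x) - u_p)))
print(err)  # small, limited by float32 rounding
```

Two-point Hermite interpolation conditions of this kind always have a unique solution of degree n0 + n1 + 1, which is why the system is square and solvable for any n0 and n1.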