Tensor Products
- e3nn_jax.spherical_harmonics(irreps_out: Irreps | int | Sequence[int], input: IrrepsArray | Array, normalize: bool, normalization: str = None, *, algorithm: Tuple[str, ...] = None) → IrrepsArray
Spherical harmonics.
Polynomials defined on the 3d space \(Y^l: \mathbb{R}^3 \longrightarrow \mathbb{R}^{2l+1}\), usually restricted to the sphere (with normalize=True) \(Y^l: S^2 \longrightarrow \mathbb{R}^{2l+1}\). They satisfy the following properties:
- they are polynomials of the Cartesian coordinates x, y, z
- they are equivariant: \(Y^l(R x) = D^l(R) Y^l(x)\)
- they are orthogonal: \(\int_{S^2} Y^l_m(x) Y^j_n(x) dx = \text{cste} \; \delta_{lj} \delta_{mn}\)
The value of the constant depends on the choice of normalization.
They obey the following recursion relations:
\[ \begin{aligned} Y^{l+1}_i(x) &= \text{cste}(l) \; C_{ijk} Y^l_j(x) x_k \\ \partial_k Y^{l+1}_i(x) &= \text{cste}(l) \; (l+1) C_{ijk} Y^l_j(x) \end{aligned} \]
where \(C\) are the Clebsch-Gordan coefficients (see clebsch_gordan).
Note
This function matches the table of standard real spherical harmonics from Wikipedia when normalize=True and normalization='integral', and when it is called with the argument in the order y, z, x (instead of x, y, z).
- Parameters:
irreps_out (Irreps or list of int or int) – output irreps
input (IrrepsArray or jax.Array) – Cartesian coordinates
normalize (bool) – if True, the polynomials are restricted to the sphere
normalization (str) – normalization of the constant \(\text{cste}\). Default is ‘component’
algorithm (Tuple[str]) – algorithm to use for the computation. (legendre|recursive, dense|sparse, [custom_jvp])
- Returns:
polynomials of the spherical harmonics
- Return type:
IrrepsArray
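The recursion and derivative relations above can be checked numerically for the step from l = 1 to l = 2. The NumPy sketch below is an illustration under stated assumptions, not the e3nn_jax implementation. It takes \(Y^1(x) = x\) and stores \(Y^2\) as a symmetric traceless 3×3 matrix, which is a Cartesian stand-in for its 5 components. It sets \(\text{cste}(l) = 1\). In place of the clebsch_gordan coefficients it uses the symmetric-traceless projector `C`, which is the same 1 ⊗ 1 → 2 coupling written in the 9-dimensional Cartesian basis. The helpers `y1`, `y2`, `grad_y2` and `random_rotation` are hypothetical.

```python
import numpy as np

# Cartesian stand-ins (hypothetical, NOT the e3nn_jax basis): Y^1(x) = x, and
# Y^2(x) is stored as a symmetric traceless 3x3 matrix instead of 5 components.
EYE = np.eye(3)

# C[a, b, c, d] projects a 3x3 tensor T[c, d] onto its symmetric traceless part;
# it is the 1 x 1 -> 2 coupling written in the 9-dimensional Cartesian basis.
C = (0.5 * (np.einsum("ac,bd->abcd", EYE, EYE) + np.einsum("ad,bc->abcd", EYE, EYE))
     - np.einsum("ab,cd->abcd", EYE, EYE) / 3.0)


def y1(x):
    return x


def y2(x):
    # Recursion Y^{l+1} = cste(l) C Y^l x with l = 1 and cste(1) = 1.
    return np.einsum("abcd,c,d->ab", C, y1(x), x)


def grad_y2(x):
    # Derivative relation d_k Y^{l+1} = cste(l) (l + 1) C Y^l, a factor 2 for l = 1.
    # Output index order: [a, b, k].
    return 2.0 * np.einsum("abck,c->abk", C, y1(x))


def random_rotation(rng):
    # QR of a Gaussian matrix gives an orthogonal matrix; fix signs so det = +1.
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q


rng = np.random.default_rng(0)
x = rng.normal(size=3)
R = random_rotation(rng)

# Closed form of the recursion result: x x^T - |x|^2 I / 3.
assert np.allclose(y2(x), np.outer(x, x) - EYE * (x @ x) / 3.0)
# Equivariance: D^2(R) acts on the Cartesian form as Y -> R Y R^T.
assert np.allclose(y2(R @ x), R @ y2(x) @ R.T)
# Central finite differences agree with the derivative relation.
h = 1e-6
fd = np.stack([(y2(x + h * EYE[k]) - y2(x - h * EYE[k])) / (2 * h) for k in range(3)], axis=-1)
assert np.allclose(fd, grad_y2(x), atol=1e-6)
```

Because `C` is symmetric in its last two indices, differentiating \(C_{ijk} Y^1_j x_k\) produces the factor l + 1 = 2. The finite-difference check confirms that factor.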
- e3nn_jax.tensor_product(input1: IrrepsArray, input2: IrrepsArray, *, filter_ir_out: List[Irrep] | None = None, irrep_normalization: str | None = None, regroup_output: bool = True) → IrrepsArray
Tensor product reduced into irreps.
- Parameters:
input1 (IrrepsArray) – First input
input2 (IrrepsArray) – Second input
filter_ir_out (list of Irrep, optional) – Filter the output irreps. Defaults to None.
irrep_normalization (str, optional) – Irrep normalization, "component" or "norm". Defaults to "component".
regroup_output (bool, optional) – Regroup the outputs into irreps. Defaults to True.
- Returns:
Tensor product of the two inputs.
- Return type:
IrrepsArray
Examples
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("2x0e + 1o", jnp.arange(5))
>>> y = e3nn.IrrepsArray("0o + 2o", jnp.arange(6))
>>> e3nn.tensor_product(x, y)
2x0o+2x1e+1x2e+2x2o+1x3e [ 0. 0. 0. 0. 0. -1.9 16.65 14.83 7.35 -12.57 0. -0.66 4.08 0. 0. 0. 0. 0. 1. 2. 3. 4. 5. 9.9 10.97 9.27 -1.97 12.34 15.59 12.73]
Usage in combination with haiku.Linear or flax.Linear:
>>> import jax
>>> import flax.linen as nn
>>> linear = e3nn.flax.Linear("3x1e")
>>> params = linear.init(jax.random.PRNGKey(0), e3nn.tensor_product(x, y))
>>> jax.tree_util.tree_structure(params)
PyTreeDef({'params': {'w[1,0] 2x1e,3x1e': *}})
>>> z = linear.apply(params, e3nn.tensor_product(x, y))
The irreps can be determined without providing input data:
>>> e3nn.tensor_product("2x1e + 2e", "2e")
1x0e+3x1e+3x2e+3x3e+1x4e
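The irreps printed in these examples follow the usual selection rule. Coupling \(l_1\) with \(l_2\) gives every \(l\) from \(|l_1 - l_2|\) to \(l_1 + l_2\), with parity \(p_1 p_2\), and the multiplicities multiply. Below is a minimal pure-Python sketch of that bookkeeping. It assumes irreps are written as strings like "2x1e + 2e" and uses an ordering convention inferred from the printed outputs. `parse_irreps`, `sort_key` and `tensor_product_irreps` are hypothetical helpers. This `filter_ir_out` takes `(l, parity)` tuples instead of `Irrep` objects. Only the output irreps are computed, not the values.

```python
from collections import Counter


def parse_irreps(text):
    # Hypothetical parser: "2x1e + 2e" -> [(2, 1, +1), (1, 2, +1)] as (mul, l, parity).
    out = []
    for term in text.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        out.append((int(mul) if mul else 1, int(ir[:-1]), 1 if ir[-1] == "e" else -1))
    return out


def sort_key(ir):
    # Assumed order: by l, then the parity (-1)**l first (0e, 0o, 1o, 1e, 2e, 2o, ...).
    l, p = ir
    return (l, -p * (-1) ** l)


def tensor_product_irreps(irreps1, irreps2, filter_ir_out=None):
    # Selection rule: l1 x l2 -> |l1 - l2|, ..., l1 + l2, with parity p1 * p2.
    # Multiplicities multiply; "regrouping" sums equal irreps and sorts them.
    counter = Counter()
    for mul1, l1, p1 in parse_irreps(irreps1):
        for mul2, l2, p2 in parse_irreps(irreps2):
            for l in range(abs(l1 - l2), l1 + l2 + 1):
                if filter_ir_out is None or (l, p1 * p2) in filter_ir_out:
                    counter[(l, p1 * p2)] += mul1 * mul2
    return "+".join(f"{counter[ir]}x{ir[0]}{'e' if ir[1] == 1 else 'o'}"
                    for ir in sorted(counter, key=sort_key))


print(tensor_product_irreps("2x1e + 2e", "2e"))
print(tensor_product_irreps("2x0e + 1o", "0o + 2o"))
```

Both printed strings reproduce the irreps of the two examples above. This is a check on the ordering assumption, not a guarantee for every input.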
- e3nn_jax.tensor_square(input: IrrepsArray, *, irrep_normalization: str | None = None, normalized_input: bool = False, regroup_output: bool = True) → IrrepsArray
Tensor product of an IrrepsArray with itself.
- Parameters:
input (IrrepsArray) – Input to be squared
irrep_normalization (str, optional) – Irrep normalization, "component" or "norm".
normalized_input (bool, optional) – If True, the input is assumed to be strictly normalized. Note that this is different from irrep_normalization="norm", for which the input has norm 1 on average. Defaults to False.
regroup_output (bool, optional) – If True, the output irreps are regrouped. Defaults to True.
- Returns:
Tensor product of the input with itself.
- Return type:
IrrepsArray
Examples
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("0e + 1o", jnp.array([10, 1, 2, 3.0]))
>>> e3nn.tensor_square(x)
2x0e+1x1o+1x2e [57.74 3.61 10. 20. 30. 3. 2. -0.58 6. 4. ]
>>> e3nn.tensor_square(x, normalized_input=True)
2x0e+1x1o+1x2e [100. 14. 17.32 34.64 51.96 11.62 7.75 -2.24 23.24 15.49]
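Only the symmetric part of the product survives in a tensor square. That is why the example above has no 1e output even though 1o ⊗ 1o contains one, and why the 0e ⊗ 1o cross term appears once rather than twice. Below is a pure-Python sketch of the resulting irreps. It assumes the decomposition \(\mathrm{Sym}^2(V_1 \oplus \dots \oplus V_n) = \bigoplus_i \mathrm{Sym}^2(V_i) \oplus \bigoplus_{i<j} V_i \otimes V_j\). It also assumes that the symmetric square of a single degree-l irrep keeps \(L = 2l, 2l-2, \dots, 0\). The helpers `expand`, `sort_key` and `tensor_square_irreps` are hypothetical. Values and the normalized_input option are not modelled.

```python
from collections import Counter


def expand(text):
    # Hypothetical parser: "2x1o + 0e" -> [(1, -1), (1, -1), (0, 1)], one entry per copy.
    copies = []
    for term in text.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        copies += [(int(ir[:-1]), 1 if ir[-1] == "e" else -1)] * (int(mul) if mul else 1)
    return copies


def sort_key(ir):
    # Ordering convention read off the printed outputs: by l, then parity (-1)**l first.
    l, p = ir
    return (l, -p * (-1) ** l)


def tensor_square_irreps(text):
    # Sym^2(V_1 + ... + V_n) = sum_i Sym^2(V_i) + sum_{i<j} V_i x V_j.
    # For a single copy of degree l, Sym^2 keeps L = 2l, 2l - 2, ..., 0.
    copies = expand(text)
    counter = Counter()
    for i, (l1, p1) in enumerate(copies):
        for j in range(i, len(copies)):
            l2, p2 = copies[j]
            for L in range(abs(l1 - l2), l1 + l2 + 1):
                if i == j and (l1 + l2 - L) % 2 == 1:
                    continue  # antisymmetric part of V_i x V_i, absent from the square
                counter[(L, p1 * p2)] += 1
    return "+".join(f"{counter[ir]}x{ir[0]}{'e' if ir[1] == 1 else 'o'}"
                    for ir in sorted(counter, key=sort_key))


print(tensor_square_irreps("0e + 1o"))
print(tensor_square_irreps("2x1o"))
```

A quick consistency check on the second case: the symmetric square of a 6-dimensional space has dimension 21, and 3·1 + 1·3 + 3·5 = 21.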
- e3nn_jax.reduced_tensor_product_basis(formula_or_irreps_list: str | List[Irreps], *, epsilon: float = 1e-05, keep_ir: Irreps | List[Irrep] | None = None, _use_optimized_implementation: bool = True, **irreps_dict) → IrrepsArray
Reduce a tensor product of multiple irreps subject to some permutation symmetry given by a formula.
- Parameters:
formula_or_irreps_list (str or list of Irreps) – a formula of the form ijk=jik=ikj or ijk=-jki. The left-hand side is the original formula and the right-hand side lists the signed permutations. If no index symmetry is present, a list of irreps can be given instead.
epsilon (float) – the tolerance for the Gram-Schmidt orthogonalization. Default: 1e-5
keep_ir (list of Irrep) – irreps to keep in the output. Default: keep all irreps
irreps_dict (dict) – the irreps of each index of the formula. For instance i="1x1o".
- Returns:
- The change of basis
The shape is (d1, ..., dn, irreps_out.dim), where di is the dimension of the index i and n is the number of indices in the formula.
- Return type:
IrrepsArray
Examples
>>> import numpy as np
>>> from e3nn_jax import reduced_tensor_product_basis
>>> np.set_printoptions(precision=3, suppress=True)
>>> reduced_tensor_product_basis("ij=-ji", i="1x1o")
1x1e
[[[ 0.     0.     0.   ]
  [ 0.     0.     0.707]
  [ 0.    -0.707  0.   ]]
 [[ 0.     0.    -0.707]
  [ 0.     0.     0.   ]
  [ 0.707  0.     0.   ]]
 [[ 0.     0.707  0.   ]
  [-0.707  0.     0.   ]
  [ 0.     0.     0.   ]]]
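To three decimals, the printed basis is the Levi-Civita symbol divided by \(\sqrt{2}\). The NumPy sketch below rebuilds it from that closed form. The closed form is an assumption read off the output above, not how the library computes the basis. The sketch then checks three properties a reduced basis should have: the ij=-ji antisymmetry, orthonormal columns, and equivariance with \(D^{1e}(R) = R\) for proper rotations. The helper `rot` is hypothetical.

```python
import numpy as np

# Assumed closed form for reduced_tensor_product_basis("ij=-ji", i="1x1o"):
# the Levi-Civita symbol divided by sqrt(2), indexed as Q[i, j, k] where k is
# the component of the single output irrep 1e.
eps = np.zeros((3, 3, 3))
for (a, b, c), sign in {(0, 1, 2): 1, (1, 2, 0): 1, (2, 0, 1): 1,
                        (0, 2, 1): -1, (2, 1, 0): -1, (1, 0, 2): -1}.items():
    eps[a, b, c] = sign
Q = eps / np.sqrt(2.0)


def rot(axis, theta):
    # Proper rotation by theta about coordinate axis 0 (x) or 2 (z).
    c, s = np.cos(theta), np.sin(theta)
    i, j = [k for k in range(3) if k != axis]
    R = np.eye(3)
    R[i, i], R[i, j], R[j, i], R[j, j] = c, -s, s, c
    return R


# Antisymmetric in the first two indices, as the formula ij=-ji requires.
assert np.allclose(Q, -Q.transpose(1, 0, 2))
# Orthonormal change of basis: flattened to (9, 3), the columns are orthonormal.
M = Q.reshape(9, 3)
assert np.allclose(M.T @ M, np.eye(3))
# Equivariance for a proper rotation: R_ia R_jb Q_abk = Q_ijm R_mk.
R = rot(2, 0.3) @ rot(0, 1.1)
assert np.allclose(np.einsum("ia,jb,abk->ijk", R, R, Q), np.einsum("ijm,mk->ijk", Q, R))

print(np.round(Q, 3))
```

The overall sign of each output component is a convention. The sketch matches the printout entry by entry, but other conventions would be equally valid reduced bases.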
- e3nn_jax.reduced_symmetric_tensor_product_basis(irreps: Irreps, degree: int, *, epsilon: float = 1e-05, keep_ir: Irreps | List[Irrep] | None = None, _use_optimized_implementation: bool = True) → IrrepsArray
Reduce a symmetric tensor product, usually called for a single irrep.
- Parameters:
- Returns:
- The change of basis
The shape is (d, ..., d, irreps_out.dim), where d is the dimension of irreps.
- Return type:
IrrepsArray
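For irreps="1o", the symmetric powers of a vector decompose as 0e + 2e at degree 2 and as 1o + 3o at degree 3. With keep_ir left at its default, one would therefore expect irreps_out.dim to be 6 and 10 respectively. The NumPy sketch below measures the dimension of the symmetric subspace as the rank of the symmetrizer, i.e. the average of all index permutations. It uses a hypothetical helper `symmetric_projector` and is not the library's algorithm.

```python
import itertools
import math

import numpy as np


def symmetric_projector(d, degree):
    # Average over all permutations of the tensor indices, as a (d**degree, d**degree)
    # matrix: row r is the symmetrization of the r-th basis tensor.
    n = d ** degree
    basis = np.eye(n).reshape((n,) + (d,) * degree)
    perms = list(itertools.permutations(range(degree)))
    P = sum(basis.transpose((0,) + tuple(1 + q for q in perm)) for perm in perms)
    return (P / len(perms)).reshape(n, n)


for degree in (2, 3):
    P = symmetric_projector(3, degree)
    rank = np.linalg.matrix_rank(P)
    # Dimension of the symmetric subspace: binomial(d + degree - 1, degree).
    print(degree, rank, math.comb(3 + degree - 1, degree))
```

The projector is symmetric and idempotent. Its rank is the number of independent symmetric tensors that a complete reduced basis has to cover.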
- e3nn_jax.reduced_antisymmetric_tensor_product_basis(irreps: Irreps, degree: int, *, epsilon: float = 1e-05, keep_ir: Irreps | List[Irrep] | None = None, _use_optimized_implementation: bool = True) → IrrepsArray
Reduce an antisymmetric tensor product.
- Parameters:
- Returns:
- The change of basis
The shape is (d, ..., d, irreps_out.dim), where d is the dimension of irreps.
- Return type:
IrrepsArray
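The smallest nontrivial case is irreps="1o" with degree=3. There the totally antisymmetric subspace is one-dimensional, spanned by \(\epsilon_{ijk}/\sqrt{6}\). The NumPy sketch below builds that tensor by antisymmetrization and checks how it transforms. It assumes every index transforms like an ordinary vector (D = g, reflections included). The helpers `perm_sign` and `act` are hypothetical. Invariance under rotations combined with a sign flip under a reflection is what one would expect from a single 0o output.

```python
import itertools

import numpy as np


def perm_sign(perm):
    # Parity of a permutation, from its number of inversions.
    inversions = sum(1 for a, b in itertools.combinations(range(len(perm)), 2) if perm[a] > perm[b])
    return -1 if inversions % 2 else 1


# Antisymmetrize one basis tensor of degree 3 over R^3: the result spans the
# whole antisymmetric subspace, which is one-dimensional here.
T = np.zeros((3, 3, 3))
T[0, 1, 2] = 1.0
A = sum(perm_sign(p) * T.transpose(p) for p in itertools.permutations(range(3)))
basis = A / np.linalg.norm(A)  # epsilon_ijk / sqrt(6)

theta = 0.4
rotation = np.array([[np.cos(theta), -np.sin(theta), 0.0],
                     [np.sin(theta), np.cos(theta), 0.0],
                     [0.0, 0.0, 1.0]])
reflection = np.diag([-1.0, 1.0, 1.0])


def act(g, t):
    # Action on a degree-3 tensor when every index transforms like 1o (D = g).
    return np.einsum("ia,jb,kc,abc->ijk", g, g, g, t)


print(np.allclose(act(rotation, basis), basis))     # True: invariant under rotations
print(np.allclose(act(reflection, basis), -basis))  # True: sign flip, a pseudoscalar
```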
- e3nn_jax.elementwise_tensor_product(input1: IrrepsArray, input2: IrrepsArray, *, filter_ir_out: List[Irrep] | None = None, irrep_normalization: str | None = None) → IrrepsArray
Elementwise tensor product of two IrrepsArray objects.
- Parameters:
input1 (IrrepsArray) – First input
input2 (IrrepsArray) – Second input with the same number of irreps as input1: input1.irreps.num_irreps == input2.irreps.num_irreps.
filter_ir_out (list of Irrep, optional) – Filter the output irreps. Defaults to None.
irrep_normalization (str, optional) – Irrep normalization, "component" or "norm". Defaults to "component".
- Returns:
- Elementwise tensor product of the two inputs.
The irreps are not sorted and not simplified.
- Return type:
IrrepsArray
Examples
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("2x0e + 1o", jnp.arange(5))
>>> y = e3nn.IrrepsArray("1e + 0o + 0o", jnp.arange(5))
>>> e3nn.elementwise_tensor_product(x, y)
1x1e+1x0o+1x1e [ 0. 0. 0. 3. 8. 12. 16.]
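The irreps in this example come from splitting multiplicities and then pairing the k-th irrep of input1 with the k-th irrep of input2: 0e·1e, 0e·0o and 1o·0o. The NumPy sketch below reproduces that pairing with hypothetical helpers. Its value computation is restricted to pairs that include a scalar factor. For those pairs it reduces to a scalar-times-irrep product with an assumed factor of 1, which is consistent with the printed values. Any other pair raises NotImplementedError.

```python
import numpy as np


def expand(text):
    # Hypothetical parser: "2x0e + 1o" -> [(0, "e"), (0, "e"), (1, "o")], one entry per copy.
    copies = []
    for term in text.replace(" ", "").split("+"):
        mul, _, ir = term.rpartition("x")
        copies += [(int(ir[:-1]), ir[-1])] * (int(mul) if mul else 1)
    return copies


def elementwise_irreps(irreps1, irreps2):
    # The k-th irrep of input1 is coupled only with the k-th irrep of input2;
    # outputs are concatenated in that order, without sorting or simplification.
    a, b = expand(irreps1), expand(irreps2)
    assert len(a) == len(b), "num_irreps must match"
    terms = []
    for (l1, p1), (l2, p2) in zip(a, b):
        parity = "e" if p1 == p2 else "o"
        terms += [f"1x{l}{parity}" for l in range(abs(l1 - l2), l1 + l2 + 1)]
    return "+".join(terms)


def elementwise_values_scalar_pairs(irreps1, x, irreps2, y):
    # Only pairs in which one factor is a scalar (l = 0) are handled: the product is
    # then a plain scalar-times-irrep multiplication (assumed normalization factor 1).
    a, b = expand(irreps1), expand(irreps2)
    i = j = 0
    out = []
    for (l1, _), (l2, _) in zip(a, b):
        u, v = x[i:i + 2 * l1 + 1], y[j:j + 2 * l2 + 1]
        i, j = i + 2 * l1 + 1, j + 2 * l2 + 1
        if min(l1, l2) != 0:
            raise NotImplementedError("l1, l2 > 0 needs Clebsch-Gordan coefficients")
        out.append(u * v)  # broadcasting: shape (1,) times shape (2l + 1,)
    return np.concatenate(out)


print(elementwise_irreps("2x0e + 1o", "1e + 0o + 0o"))
print(elementwise_values_scalar_pairs("2x0e + 1o", np.arange(5.0), "1e + 0o + 0o", np.arange(5.0)))
```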
- e3nn_jax.tensor_product_with_spherical_harmonics(input: IrrepsArray, vector: IrrepsArray, degree: int) → IrrepsArray
Tensor product of an input with the spherical harmonics of a vector.
The idea of this optimization comes from the paper:
Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs
- Parameters:
input (IrrepsArray) – input
vector (IrrepsArray) – vector, irreps must be “1o” or “1e”
degree (int) – the maximum degree of the spherical harmonics
- Returns:
tensor product
- Return type:
IrrepsArray
Notes
This function is equivalent to:
tensor_product(input, spherical_harmonics(range(degree + 1), vector, True))
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> input = e3nn.normal("3x0e + 2x1o", jax.random.PRNGKey(0))
>>> vector = e3nn.normal("1e", jax.random.PRNGKey(1))
>>> degree = 2
>>> output1 = e3nn.tensor_product_with_spherical_harmonics(input, vector, degree)
>>> output2 = e3nn.tensor_product(input, e3nn.spherical_harmonics(range(degree + 1), vector, True))
>>> assert output1.irreps == output2.irreps
>>> assert jnp.allclose(output1.array, output2.array, atol=1e-6)
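The optimization from the cited paper relies on equivariance. If the frame is rotated so that the vector points along z, its spherical harmonics keep only their m = 0 component, so the product with them touches far fewer terms. The result is then rotated back. The NumPy sketch below illustrates only this frame trick and is not the e3nn_jax implementation. It uses the cross product (the 1 ⊗ 1 → 1 coupling up to a constant) as a stand-in for one tensor product path. `rotation_to_z` and `cross_with_direction` are hypothetical helpers.

```python
import numpy as np

E_Z = np.array([0.0, 0.0, 1.0])


def rotation_to_z(v):
    # Rodrigues' formula: a proper rotation R with R @ (v / |v|) = e_z.
    u = v / np.linalg.norm(v)
    axis = np.cross(u, E_Z)
    s, c = np.linalg.norm(axis), u @ E_Z  # sin and cos of the rotation angle
    if s < 1e-12:
        # u is parallel (identity) or antiparallel (half-turn about x) to e_z.
        return np.eye(3) if c > 0 else np.diag([1.0, -1.0, -1.0])
    k = axis / s
    K = np.array([[0.0, -k[2], k[1]], [k[2], 0.0, -k[0]], [-k[1], k[0], 0.0]])
    return np.eye(3) + s * K + (1.0 - c) * (K @ K)


def cross_with_direction(x, v):
    # Stand-in for one tensor-product path (1 x 1 -> 1, the cross product up to a
    # constant): rotate into the frame where v points along e_z, where the product
    # with e_z only needs two components of the input, then rotate back.
    R = rotation_to_z(v)
    a = R @ x
    in_frame = np.array([a[1], -a[0], 0.0])  # equals np.cross(a, E_Z)
    return R.T @ in_frame


rng = np.random.default_rng(0)
x, v = rng.normal(size=3), rng.normal(size=3)
print(np.allclose(cross_with_direction(x, v), np.cross(x, v / np.linalg.norm(v))))
```

The identity used is \(x \times \hat v = R^\top \big( (R x) \times \hat z \big)\) for any proper rotation R with \(R \hat v = \hat z\). The same argument applies to every equivariant path, which is what makes the aligned-frame computation exact rather than approximate.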
- e3nn_jax.sh(irreps_out: Irreps | int | Sequence[int], input: Array, normalize: bool, normalization: str = None, *, algorithm: Tuple[str, ...] = None) → Array
Spherical harmonics.
Same function as e3nn_jax.spherical_harmonics() but with a simple interface.
- Parameters:
irreps_out (Irreps or int or Sequence[int]) – the output irreps
input (jax.Array) – Cartesian coordinates, shape (…, 3)
normalize (bool) – if True, the polynomials are restricted to the sphere
normalization (str) – normalization of the constant \(\text{cste}\). Default is ‘component’
algorithm (Tuple[str]) – algorithm to use for the computation. (legendre|recursive, dense|sparse, [custom_jvp])
- Returns:
polynomials of the spherical harmonics
- Return type:
jax.Array
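The NumPy sketch below makes the normalize flag and the default 'component' constant concrete. It is limited to l ≤ 1. It assumes 'component' means every component has unit mean square over the sphere, so \(Y^0 = 1\) and \(Y^1 = \sqrt{3}\,x/|x|\), up to the order of the three components. `sh_sketch` is hypothetical and not a substitute for e3nn_jax.sh. It returns the concatenated output with shape (..., (lmax + 1)**2), mirroring the plain-array interface described above.

```python
import numpy as np


def sh_sketch(lmax, x, normalize):
    # Hypothetical stand-in for sh covering l <= 1 only, assuming the "component"
    # normalization: Y^0 = 1 and Y^1 = sqrt(3) x (x scaled to unit length if normalize).
    if lmax > 1:
        raise NotImplementedError("this sketch only covers l <= 1")
    x = np.asarray(x, dtype=float)
    if normalize:
        x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    parts = [np.ones(x.shape[:-1] + (1,))]
    if lmax == 1:
        parts.append(np.sqrt(3.0) * x)
    return np.concatenate(parts, axis=-1)  # shape (..., (lmax + 1) ** 2)


rng = np.random.default_rng(0)
points = rng.normal(size=(200_000, 3))
Y = sh_sketch(1, points, normalize=True)  # shape (200000, 4)

# Orthogonality with the "component" constant: the sphere average of Y_m Y_n is delta_mn
# (estimated here by Monte Carlo over uniformly distributed directions).
gram = Y.T @ Y / len(Y)
print(np.round(gram, 2))

# normalize=False keeps Y^l homogeneous of degree l; normalize=True is scale invariant.
x = np.array([[0.3, -1.0, 2.0]])
print(np.allclose(sh_sketch(1, 2.0 * x, False)[:, 1:], 2.0 * sh_sketch(1, x, False)[:, 1:]))
print(np.allclose(sh_sketch(1, 2.0 * x, True), sh_sketch(1, x, True)))
```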