Tensor Products

e3nn_jax.spherical_harmonics(irreps_out: Irreps | int | Sequence[int], input: IrrepsArray | Array, normalize: bool, normalization: str = None, *, algorithm: Tuple[str, ...] = None) → IrrepsArray

Spherical harmonics.

https://user-images.githubusercontent.com/333780/79220728-dbe82c00-7e54-11ea-82c7-b3acbd9b2246.gif
Polynomials defined on 3D space, \(Y^l: \mathbb{R}^3 \longrightarrow \mathbb{R}^{2l+1}\),
usually restricted to the sphere (with normalize=True), \(Y^l: S^2 \longrightarrow \mathbb{R}^{2l+1}\),
that satisfy the following properties:
  • they are polynomials of the Cartesian coordinates x, y, z

  • they are equivariant: \(Y^l(R x) = D^l(R) Y^l(x)\)

  • they are orthogonal: \(\int_{S^2} Y^l_m(x) Y^j_n(x) dx = \text{cste} \; \delta_{lj} \delta_{mn}\)

The value of the constant depends on the choice of normalization.

They obey the following recurrence relations:

\[ \begin{aligned} Y^{l+1}_i(x) &= \text{cste}(l) \; C_{ijk} Y^l_j(x) x_k \\ \partial_k Y^{l+1}_i(x) &= \text{cste}(l) \; (l+1) C_{ijk} Y^l_j(x) \end{aligned} \]

where \(C\) denotes the Clebsch-Gordan coefficients, with implicit summation over indices that appear only on the right-hand side.
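
As a quick sanity check of the first relation, take \(l = 0\). The index \(j\) of \(Y^0\) takes a single value, so \(C_{i0k}\) is an equivariant map between two copies of the \(l = 1\) representation; by Schur's lemma it must be \(C_{i0k} = c \, \delta_{ik}\) for some constant \(c\) that depends on the normalization convention. Substituting:

\[ \begin{aligned} Y^{1}_i(x) &= \text{cste}(0) \; c \, \delta_{ik} \, Y^0 \, x_k = \text{cste}(0) \; c \, Y^0 \, x_i \\ \partial_k Y^{1}_i(x) &= \text{cste}(0) \; c \, \delta_{ik} \, Y^0 \end{aligned} \]

so \(Y^1\) is proportional to the input vector, and its derivative agrees with the second relation for \(l + 1 = 1\).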

Note

This function matches the table of standard real spherical harmonics on Wikipedia when normalize=True and normalization='integral', provided the argument is given in the order y, z, x (instead of x, y, z).

Parameters:
  • irreps_out (Irreps or list of int or int) – output irreps

  • input (IrrepsArray or jax.Array) – Cartesian coordinates

  • normalize (bool) – if True, the polynomials are restricted to the sphere

  • normalization (str) – normalization of the constant \(\text{cste}\): one of ‘integral’, ‘component’ or ‘norm’. Default is ‘component’

  • algorithm (Tuple[str]) – algorithm to use for the computation: legendre or recursive, dense or sparse, and optionally custom_jvp

Returns:

polynomials of the spherical harmonics

Return type:

IrrepsArray

e3nn_jax.tensor_product(input1: IrrepsArray, input2: IrrepsArray, *, filter_ir_out: List[Irrep] | None = None, irrep_normalization: str | None = None, regroup_output: bool = True) → IrrepsArray

Tensor product reduced into irreps.

Parameters:
  • input1 (IrrepsArray) – First input

  • input2 (IrrepsArray) – Second input

  • filter_ir_out (list of Irrep, optional) – If given, only these output irreps are kept. Defaults to None (all output irreps are kept).

  • irrep_normalization (str, optional) – Irrep normalization, "component" or "norm". Defaults to "component".

  • regroup_output (bool, optional) – If True, identical output irreps are grouped together (the output irreps are sorted and simplified). Defaults to True.

Returns:

Tensor product of the two inputs.

Return type:

IrrepsArray

Examples

>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("2x0e + 1o", jnp.arange(5))
>>> y = e3nn.IrrepsArray("0o + 2o", jnp.arange(6))
>>> e3nn.tensor_product(x, y)
2x0o+2x1e+1x2e+2x2o+1x3e
[  0.     0.     0.     0.     0.    -1.9   16.65  14.83   7.35 -12.57
   0.    -0.66   4.08   0.     0.     0.     0.     0.     1.     2.
   3.     4.     5.     9.9   10.97   9.27  -1.97  12.34  15.59  12.73]

Usage in combination with e3nn.haiku.Linear or e3nn.flax.Linear:

>>> import jax
>>> import flax.linen as nn
>>> linear = e3nn.flax.Linear("3x1e")
>>> params = linear.init(jax.random.PRNGKey(0), e3nn.tensor_product(x, y))
>>> jax.tree_util.tree_structure(params)
PyTreeDef({'params': {'w[1,0] 2x1e,3x1e': *}})
>>> z = linear.apply(params, e3nn.tensor_product(x, y))

The irreps can be determined without providing input data:

>>> e3nn.tensor_product("2x1e + 2e", "2e")
1x0e+3x1e+3x2e+3x3e+1x4e

e3nn_jax.tensor_square(input: IrrepsArray, *, irrep_normalization: str | None = None, normalized_input: bool = False, regroup_output: bool = True) → IrrepsArray

Tensor product of an IrrepsArray with itself.

Parameters:
  • input (IrrepsArray) – Input to be squared

  • irrep_normalization (str, optional) – Irrep normalization, "component" or "norm".

  • normalized_input (bool, optional) – If True, the input is assumed to be strictly normalized. Note that this is different from irrep_normalization="norm", for which the input has norm 1 on average. Defaults to False.

  • regroup_output (bool, optional) – If True, the output irreps are regrouped. Defaults to True.

Returns:

Tensor product of the input with itself.

Return type:

IrrepsArray

Examples

>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("0e + 1o", jnp.array([10, 1, 2, 3.0]))
>>> e3nn.tensor_square(x)
2x0e+1x1o+1x2e [57.74  3.61 10.   20.   30.    3.    2.   -0.58  6.    4.  ]
>>> e3nn.tensor_square(x, normalized_input=True)
2x0e+1x1o+1x2e [100.    14.    17.32  34.64  51.96  11.62   7.75  -2.24  23.24  15.49]

e3nn_jax.reduced_tensor_product_basis(formula_or_irreps_list: str | List[Irreps], *, epsilon: float = 1e-05, keep_ir: Irreps | List[Irrep] | None = None, _use_optimized_implementation: bool = True, **irreps_dict) → IrrepsArray

Reduce a tensor product of multiple irreps subject to some permutation symmetry given by a formula.

Parameters:
  • formula_or_irreps_list (str or list of Irreps) – a formula of the form ijk=jik=ikj or ijk=-jki. The left-hand side names the indices; each subsequent term is a signed permutation of them, and the tensor must equal itself under that permutation (negated for a minus sign). If no index symmetry is present, a list of irreps can be given instead.

  • epsilon (float) – the tolerance for the Gram-Schmidt orthogonalization. Default: 1e-5

  • keep_ir (list of Irrep) – irreps to keep in the output. Default: keep all irreps

  • irreps_dict (dict) – the irreps of each index of the formula. For instance i="1x1o".

Returns:

The change of basis

The shape is (d1, ..., dn, irreps_out.dim) where di is the dimension of the index i and n is the number of indices in the formula.

Return type:

IrrepsArray

Examples

>>> import numpy as np
>>> import e3nn_jax as e3nn
>>> np.set_printoptions(precision=3, suppress=True)
>>> e3nn.reduced_tensor_product_basis("ij=-ji", i="1x1o")
1x1e
[[[ 0.     0.     0.   ]
  [ 0.     0.     0.707]
  [ 0.    -0.707  0.   ]]

 [[ 0.     0.    -0.707]
  [ 0.     0.     0.   ]
  [ 0.707  0.     0.   ]]

 [[ 0.     0.707  0.   ]
  [-0.707  0.     0.   ]
  [ 0.     0.     0.   ]]]

e3nn_jax.reduced_symmetric_tensor_product_basis(irreps: Irreps, degree: int, *, epsilon: float = 1e-05, keep_ir: Irreps | List[Irrep] | None = None, _use_optimized_implementation: bool = True) → IrrepsArray

Reduce a fully symmetric tensor product of degree copies of irreps (usually a single irrep).

Parameters:
  • irreps (Irreps) – the irreps of each index.

  • degree (int) – the degree of the tensor product, i.e. the number of indices.

  • epsilon (float) – the tolerance for the Gram-Schmidt orthogonalization. Default: 1e-5

  • keep_ir (list of Irrep) – irreps to keep in the output. Default: keep all irreps

Returns:

The change of basis

The shape is (d, ..., d, irreps_out.dim), with degree leading axes of size d, where d is the dimension of irreps.

Return type:

IrrepsArray

e3nn_jax.reduced_antisymmetric_tensor_product_basis(irreps: Irreps, degree: int, *, epsilon: float = 1e-05, keep_ir: Irreps | List[Irrep] | None = None, _use_optimized_implementation: bool = True) → IrrepsArray

Reduce a fully antisymmetric tensor product of degree copies of irreps.

Parameters:
  • irreps (Irreps) – the irreps of each index.

  • degree (int) – the degree of the tensor product, i.e. the number of indices.

  • epsilon (float) – the tolerance for the Gram-Schmidt orthogonalization. Default: 1e-5

  • keep_ir (list of Irrep) – irreps to keep in the output. Default: keep all irreps

Returns:

The change of basis

The shape is (d, ..., d, irreps_out.dim), with degree leading axes of size d, where d is the dimension of irreps.

Return type:

IrrepsArray

e3nn_jax.elementwise_tensor_product(input1: IrrepsArray, input2: IrrepsArray, *, filter_ir_out: List[Irrep] | None = None, irrep_normalization: str | None = None) → IrrepsArray

Elementwise tensor product of two IrrepsArrays: the i-th irrep of input1 is multiplied only with the i-th irrep of input2.

Parameters:
  • input1 (IrrepsArray) – First input

  • input2 (IrrepsArray) – Second input with the same number of irreps as input1, input1.irreps.num_irreps == input2.irreps.num_irreps.

  • filter_ir_out (list of Irrep, optional) – If given, only these output irreps are kept. Defaults to None (all output irreps are kept).

  • irrep_normalization (str, optional) – Irrep normalization, "component" or "norm". Defaults to "component".

Returns:

Elementwise tensor product of the two inputs.

The irreps are not sorted and not simplified.

Return type:

IrrepsArray

Examples

>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=2, suppress=True)
>>> import e3nn_jax as e3nn
>>> x = e3nn.IrrepsArray("2x0e + 1o", jnp.arange(5))
>>> y = e3nn.IrrepsArray("1e + 0o + 0o", jnp.arange(5))
>>> e3nn.elementwise_tensor_product(x, y)
1x1e+1x0o+1x1e [ 0.  0.  0.  3.  8. 12. 16.]

e3nn_jax.tensor_product_with_spherical_harmonics(input: IrrepsArray, vector: IrrepsArray, degree: int) → IrrepsArray

Tensor product of an input with the spherical harmonics of a vector.

The idea of this optimization comes from the paper "Reducing SO(3) Convolutions to SO(2) for Efficient Equivariant GNNs".

Parameters:
  • input (IrrepsArray) – input

  • vector (IrrepsArray) – vector, irreps must be “1o” or “1e”

  • degree (int) – the maximum degree of the spherical harmonics

Returns:

tensor product

Return type:

IrrepsArray

Notes

This function is equivalent to:

tensor_product(input, spherical_harmonics(range(degree + 1), vector, True))

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import e3nn_jax as e3nn
>>> input = e3nn.normal("3x0e + 2x1o", jax.random.PRNGKey(0))
>>> vector = e3nn.normal("1e", jax.random.PRNGKey(1))
>>> degree = 2
>>> output1 = e3nn.tensor_product_with_spherical_harmonics(input, vector, degree)
>>> output2 = e3nn.tensor_product(input, e3nn.spherical_harmonics(range(degree + 1), vector, True))
>>> assert output1.irreps == output2.irreps
>>> assert jnp.allclose(output1.array, output2.array, atol=1e-6)

e3nn_jax.sh(irreps_out: Irreps | int | Sequence[int], input: Array, normalize: bool, normalization: str = None, *, algorithm: Tuple[str, ...] = None) → Array

Spherical harmonics.

Same as e3nn_jax.spherical_harmonics(), but takes and returns plain jax.Array objects instead of IrrepsArray.

Parameters:
  • irreps_out (Irreps or int or Sequence[int]) – the output irreps

  • input (jax.Array) – Cartesian coordinates, shape (…, 3)

  • normalize (bool) – if True, the polynomials are restricted to the sphere

  • normalization (str) – normalization of the constant \(\text{cste}\): one of ‘integral’, ‘component’ or ‘norm’. Default is ‘component’

  • algorithm (Tuple[str]) – algorithm to use for the computation: legendre or recursive, dense or sparse, and optionally custom_jvp

Returns:

polynomials of the spherical harmonics

Return type:

jax.Array