Legacy

class e3nn_jax.legacy.FunctionalTensorProduct(irreps_in1: Irreps, irreps_in2: Irreps, irreps_out: Irreps, instructions: List[Tuple[int, int, int, str, bool, float | None]], in1_var: List[float] | None = None, in2_var: List[float] | None = None, out_var: List[float] | None = None, irrep_normalization: str = None, path_normalization: str | float = None, gradient_normalization: str | float = None)

Bases: object

Tensor product of two tensors.

Parameters:
  • irreps_in1 – Irreps of the first tensor.

  • irreps_in2 – Irreps of the second tensor.

  • irreps_out – Irreps of the output tensor.

  • instructions – List of instructions, each of the form (i_in1, i_in2, i_out, connection_mode, has_weight[, path_weight]):
    - i_in1, i_in2 and i_out are indices into irreps_in1, irreps_in2 and irreps_out, respectively.
    - connection_mode is one of uvw, uvu, uvv, uuw, uuu or uvuv.
    - has_weight is a boolean indicating whether the instruction has a weight.
    - path_weight (optional, 1.0 by default) is the weight of the path.

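  • in1_var – Variance of each irrep of the first tensor (one value per entry of irreps_in1).

  • in2_var – Variance of each irrep of the second tensor (one value per entry of irreps_in2).

  • out_var – Variance of each irrep of the output tensor (one value per entry of irreps_out).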

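  • irrep_normalization – Normalization of the tensors, either component (each component has unit variance, so an irrep of degree l has expected squared norm 2l+1) or norm (each irrep has expected squared norm 1).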

  • path_normalization (str or float) – Normalization of the paths, either element or path. It can also be given as a float: 0 corresponds to element, where each element contributes equally to the forward propagation, and 1 corresponds to path, where each path contributes equally.

  • gradient_normalization (str or float) – Normalization of the gradients, either element or path. It can also be given as a float: 0 corresponds to element, where each element contributes equally to learning, and 1 corresponds to path, where each path contributes equally.
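The instruction format above can be checked by hand. The sketch below is a plain-Python approximation, not e3nn_jax code: MulIr and check_instruction are hypothetical names, and the multiplicity constraints and weight shapes per connection mode are assumptions modelled on e3nn's conventions. It applies the usual selection rules (|l1 - l2| <= l_out <= l1 + l2, and the input parities must multiply to the output parity) and returns the weight shape a weighted instruction would expect.

```python
from typing import NamedTuple, Tuple


# Hypothetical stand-in for one entry of an Irreps: multiplicity, degree l, parity p.
class MulIr(NamedTuple):
    mul: int
    l: int
    p: int  # +1 (even) or -1 (odd)


def check_instruction(in1, in2, out, ins) -> Tuple[int, ...]:
    """Validate one instruction tuple; return its weight shape, or () if it has no weight."""
    i1, i2, io, mode, has_weight = ins[:5]  # an optional 6th entry (path_weight) is ignored
    a, b, c = in1[i1], in2[i2], out[io]
    # Selection rules for coupling l1 x l2 -> l_out with parity.
    if not abs(a.l - b.l) <= c.l <= a.l + b.l:
        raise ValueError(f"l={c.l} is not contained in {a.l} x {b.l}")
    if a.p * b.p != c.p:
        raise ValueError("parity mismatch")
    # Multiplicity constraints implied by each connection mode (assumed, following e3nn).
    allowed = {
        "uvw": True,
        "uvu": c.mul == a.mul,
        "uvv": c.mul == b.mul,
        "uuw": a.mul == b.mul,
        "uuu": a.mul == b.mul == c.mul,
        "uvuv": c.mul == a.mul * b.mul,
    }
    if mode not in allowed:
        raise ValueError(f"unknown connection mode {mode!r}")
    if not allowed[mode]:
        raise ValueError(f"multiplicities {a.mul}, {b.mul}, {c.mul} incompatible with {mode}")
    if not has_weight:
        return ()
    # Weight shape per connection mode (assumed, following e3nn).
    return {
        "uvw": (a.mul, b.mul, c.mul),
        "uvu": (a.mul, b.mul),
        "uvv": (a.mul, b.mul),
        "uuw": (a.mul, c.mul),
        "uuu": (a.mul,),
        "uvuv": (a.mul, b.mul),
    }[mode]


# Example: irreps_in1 = 4x0e + 2x1o, irreps_in2 = 1x1o, irreps_out = 8x1o + 2x0e
in1 = [MulIr(4, 0, 1), MulIr(2, 1, -1)]
in2 = [MulIr(1, 1, -1)]
out = [MulIr(8, 1, -1), MulIr(2, 0, 1)]
print(check_instruction(in1, in2, out, (0, 0, 0, "uvw", True)))       # (4, 1, 8)
print(check_instruction(in1, in2, out, (1, 0, 1, "uvu", True, 0.5)))  # (2, 1)
```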
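The 0/1 endpoints of path_normalization can be illustrated with a per-path scale factor. The following is a sketch under an assumed formula, not necessarily the exact computation in e3nn_jax; path_weights is a hypothetical name, and treating intermediate floats as an exponent p is an assumption. For the paths feeding one output irrep, with fan-in n_i (input variances times the number of summed elements), it sets alpha_i^2 = out_var / (n_i^p * sum_j n_j^(1 - p)). For any p the output variance sum_i alpha_i^2 n_i equals out_var; p = 0 gives every element the same weight, and p = 1 gives every path the same share of the variance.

```python
import math
from typing import List


def path_weights(fan_in: List[float], out_var: float = 1.0, p: float = 0.0) -> List[float]:
    """Hypothetical per-path scale factors for all paths that feed one output irrep.

    fan_in[i] = (in1 variance) * (in2 variance) * (number of elements summed by path i).
    p = 0 mimics "element" normalization and p = 1 mimics "path" normalization.
    """
    denom = sum(n ** (1.0 - p) for n in fan_in)
    # alpha_i^2 = out_var / (n_i^p * sum_j n_j^(1-p)), so sum_i alpha_i^2 * n_i == out_var.
    return [math.sqrt(out_var / (n ** p * denom)) for n in fan_in]


# Two paths into the same output: e.g. a uvw path summing 4 x 4 = 16 elements
# and a uvu path summing 2 elements, all inputs with unit variance.
fan_in = [16.0, 2.0]
element = path_weights(fan_in, p=0.0)  # same weight on every element
per_path = path_weights(fan_in, p=1.0)  # each path contributes half of the output variance
```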

Methods:

left_right(weights, input1[, input2, ...])

Compute the tensor product of two input tensors.

right(weights[, input2, custom_einsum_jvp])

Compute the right contraction of the tensor product.

left_right(weights: List[Array] | Array, input1: IrrepsArray, input2: IrrepsArray = None, *, custom_einsum_jvp: bool = None, fused: bool = None, sparse: bool = None) → IrrepsArray

Compute the tensor product of two input tensors.

Parameters:
  • weights (array or list of arrays) – The weights of the tensor product.

  • input1 (IrrepsArray) – The first input tensor.

  • input2 (IrrepsArray) – The second input tensor.

  • custom_einsum_jvp (bool) – If True, use a custom JVP (Jacobian-vector product) rule for the einsum code.

  • fused (bool) – If True, fuse all the einsums.

Returns:

The output tensor.

Return type:

IrrepsArray
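To make the contraction concrete, here is a NumPy sketch of a single uvw path, the kind of term left_right combines over its instructions. The names cg_1o_1o_to_0e and uvw_path are hypothetical. Only the 1o ⊗ 1o → 0e coupling is used, because its coefficient tensor is proportional to the identity in any orthonormal basis, so no e3nn basis conventions are assumed. Inputs are kept as (mul, 2l+1) arrays rather than as flattened IrrepsArray data.

```python
import numpy as np


def cg_1o_1o_to_0e() -> np.ndarray:
    # For 1 x 1 -> 0 the coupling tensor is proportional to the identity (a dot product),
    # scaled here to unit Frobenius norm. Shape (3, 3, 1).
    return np.eye(3)[:, :, None] / np.sqrt(3.0)


def uvw_path(w, x1, x2, cg, alpha=1.0):
    """One hypothetical 'uvw' path: out[w, k] = alpha * sum W[u,v,w] x1[u,i] x2[v,j] C[i,j,k].

    w: (mul1, mul2, mul_out), x1: (mul1, 2*l1+1), x2: (mul2, 2*l2+1),
    cg: (2*l1+1, 2*l2+1, 2*l_out+1). Returns an array of shape (mul_out, 2*l_out+1).
    """
    return alpha * np.einsum("uvw,ui,vj,ijk->wk", w, x1, x2, cg)


rng = np.random.default_rng(0)
x1 = rng.normal(size=(4, 3))    # 4x1o
x2 = rng.normal(size=(2, 3))    # 2x1o
w = rng.normal(size=(4, 2, 3))  # uvw weights for a 3x0e output
out = uvw_path(w, x1, x2, cg_1o_1o_to_0e())

# The 0e output is invariant: rotating both inputs by the same orthogonal matrix leaves it unchanged.
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
out_rot = uvw_path(w, x1 @ q.T, x2 @ q.T, cg_1o_1o_to_0e())
print(out.shape, np.allclose(out, out_rot))  # (3, 1) True
```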

right(weights: List[Array], input2: IrrepsArray = None, *, custom_einsum_jvp=None) → Array

Compute the right contraction of the tensor product.

Parameters:
  • weights (array or list of arrays) – The weights of the tensor product.

  • input2 (IrrepsArray) – The second input tensor.

  • custom_einsum_jvp (bool) – If True, use a custom JVP (Jacobian-vector product) rule for the einsum code.

Returns:

A matrix of shape (irreps_in1.dim, irreps_out.dim).
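Because the tensor product is linear in input1, fixing weights and input2 leaves a linear map from input1 to the output, which a matrix of shape (irreps_in1.dim, irreps_out.dim) can represent. The sketch below uses a single hypothetical uvw path (1o ⊗ 1o → 0e with 4x1o and 2x1o inputs, so irreps_in1.dim = 12 and irreps_out.dim = 3) and assumes, without checking against e3nn_jax, that the flattened input1 multiplied by this matrix reproduces the flattened output; right_matrix is a hypothetical name.

```python
import numpy as np


def right_matrix(w, x2, cg, alpha=1.0):
    """Hypothetical analogue of right(): contract everything except input1, then flatten.

    Returns M of shape (mul1 * d1, mul_out * d_out) such that x1.reshape(-1) @ M equals the
    flattened output of alpha * einsum("uvw,ui,vj,ijk->wk", w, x1, x2, cg).
    """
    m = alpha * np.einsum("uvw,vj,ijk->uiwk", w, x2, cg)
    mul1, d1, mul_out, d_out = m.shape
    return m.reshape(mul1 * d1, mul_out * d_out)


cg = np.eye(3)[:, :, None] / np.sqrt(3.0)  # 1o x 1o -> 0e coupling (a scaled dot product)
rng = np.random.default_rng(0)
x1 = rng.normal(size=(4, 3))    # 4x1o, flattened dim 12
x2 = rng.normal(size=(2, 3))    # 2x1o
w = rng.normal(size=(4, 2, 3))  # uvw weights for a 3x0e output

m = right_matrix(w, x2, cg)
direct = np.einsum("uvw,ui,vj,ijk->wk", w, x1, x2, cg)
print(m.shape, np.allclose(x1.reshape(-1) @ m, direct.reshape(-1)))  # (12, 3) True
```

Precomputing this matrix can pay off when the same weights and input2 are applied to many different input1 values.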