Benchmarks
TP Task
Compute `d/dw sum(tanh(TP(w, x1, x2)))`, the gradient with respect to the weights `w`, where `x1` and `x2` have batch size 64, `TP = FullyConnected(irreps, irreps, irreps)` and `irreps = 128x0e + 128x1e + 128x2e`.
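To make the benchmarked quantity concrete, here is a minimal NumPy sketch restricted to the scalar path only (`0e x 0e -> 0e`). On that path a fully connected tensor product reduces to a bilinear map `z[b, k] = sum_ij w[i, j, k] * x1[b, i] * x2[b, j]`. The sketch makes several simplifications. It ignores the Clebsch-Gordan coefficients of the `l > 0` paths and e3nn's normalization constants, it uses a multiplicity of 8 instead of 128, and it computes the gradient with respect to `w` by hand. The helper names (`tp_scalar`, `loss`, `grad_w`) are hypothetical and not part of e3nn.

```python
import numpy as np

def tp_scalar(w, x1, x2):
    # Hypothetical scalar-only "fully connected" product (no normalization):
    # z[b, k] = sum_{i,j} w[i, j, k] * x1[b, i] * x2[b, j]
    return np.einsum("ijk,bi,bj->bk", w, x1, x2)

def loss(w, x1, x2):
    # The benchmarked scalar: sum over batch and channels of tanh(TP(w, x1, x2))
    return np.tanh(tp_scalar(w, x1, x2)).sum()

def grad_w(w, x1, x2):
    # Chain rule: d/dw[i,j,k] = sum_b (1 - tanh(z[b,k])**2) * x1[b,i] * x2[b,j]
    z = tp_scalar(w, x1, x2)
    dz = 1.0 - np.tanh(z) ** 2
    return np.einsum("bk,bi,bj->ijk", dz, x1, x2)

rng = np.random.default_rng(0)
batch, mul = 64, 8  # the benchmark uses mul = 128 per irrep; 8 keeps this cheap
w = rng.normal(size=(mul, mul, mul)) / mul
x1 = rng.normal(size=(batch, mul))
x2 = rng.normal(size=(batch, mul))
g = grad_w(w, x1, x2)
print(g.shape)
```

In the actual benchmark this gradient is presumably obtained by automatic differentiation rather than written out by hand. The explicit formula just shows what is being computed.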
Records on NVIDIA RTX A5500 using CUDA 12.0

| jax | e3nn | time | implementation |
|---|---|---|---|
| 0.4.11 | 0.19.2 | 3.52 ms | `tensor_product` -> `Linear` |
| 0.4.11 | 0.19.2 | 3.57 ms | `FullyConnectedTensorProduct` |
| 0.4.11 | 0.19.1 | 5.87 ms | `tensor_product` -> `Linear` |
| 0.4.11 | 0.19.1 | 4.58 ms | `FullyConnectedTensorProduct` |

Records on NVIDIA RTX A5000 using CUDA 12.0

| jax | e3nn | time | implementation |
|---|---|---|---|
| 0.4.11 | 0.19.1 | 6.55 ms | `FullyConnectedTensorProduct` |

Records on NVIDIA RTX A5000 using CUDA 11.7

| jax | e3nn | time | implementation |
|---|---|---|---|
| 0.3.25 | 0.12.0 | 5.2 ms | `FullyConnectedTensorProduct` |
| 0.3.24 | 0.12.0 | 6.8 ms | `tensor_product` -> `Linear` |
| 0.3.24 | 0.12.0 | 5.2 ms | `FullyConnectedTensorProduct` |
| 0.3.24 | 0.7.0 | 5.2 ms | `FullyConnectedTensorProduct` |
| 0.3.24 | 0.6.0 | 5.2 ms | `FullyConnectedTensorProduct` |
| 0.3.24 | 0.4.0 | 5.2 ms | `FullyConnectedTensorProduct` |
| 0.3.15 | 0.12.0 | 5.2 ms | `FullyConnectedTensorProduct` |

Records on NVIDIA RTX A5000 using CUDA 11.6

| framework | e3nn | time |
|---|---|---|
| PyTorch 1.11.0 | 0.5.0 | 13 to 14 ms |
| JAX 0.3.15 | 0.7.0 | 1.7 ms |

x8 speedup of JAX over PyTorch.
Records on GTX 1080

On PyTorch it takes 140 ms.
```
python examples/tensor_product_benchmark.py --irreps "128x0e + 128x1e + 128x2e" --extrachannels f --specialized-code f --fused f --lists t --custom-einsum-jvp f --batch 64 -n 10
```

```
======= Benchmark with settings: ======
jit : True
irreps : 128x0e + 128x1e + 128x2e
irreps_in1 : None
irreps_in2 : None
irreps_out : None
cuda : True
backward : True
opt_ein : True
custom_einsum_jvp : False
specialized_code : False
elementwise : False
extrachannels : False
fused : False
lists : True
n : 10
batch : 64
========================================
31457280 parameters
starting...
12.9 ms
======= Benchmark with settings: ======
specialized_code : True
lists : False
========================================
12 ms
```
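The `31457280 parameters` line in the log above can be reproduced by counting paths. Assume that a fully connected tensor product allocates one weight per multiplicity triple (`mul1 * mul2 * mul3`) on every path `(l1, l2) -> l3` allowed by the triangle rule `|l1 - l2| <= l3 <= l1 + l2` and by parity. For `128x0e + 128x1e + 128x2e` in all three slots, 15 paths are allowed, which gives 15 × 128³ = 31457280. The function name below is hypothetical, not an e3nn API.

```python
# Benchmark irreps 128x0e + 128x1e + 128x2e as (multiplicity, l, parity);
# "e" (even) parity is encoded as +1.
IRREPS = [(128, 0, 1), (128, 1, 1), (128, 2, 1)]

def fctp_num_weights(irreps_in1, irreps_in2, irreps_out):
    """Count weights of a fully connected tensor product, assuming one weight
    per (mul1, mul2, mul3) triple on each allowed path."""
    total = 0
    for mul1, l1, p1 in irreps_in1:
        for mul2, l2, p2 in irreps_in2:
            for mul3, l3, p3 in irreps_out:
                # Selection rules: triangle inequality on l, parity must multiply.
                if abs(l1 - l2) <= l3 <= l1 + l2 and p1 * p2 == p3:
                    total += mul1 * mul2 * mul3
    return total

print(fctp_num_weights(IRREPS, IRREPS, IRREPS))  # 31457280
```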
QM9
Record on V100
6 Conv + Gate FC [64, 64]

| lmax | mul | PyTorch | JAX | speedup |
|---|---|---|---|---|
| 2 | 512 | 430 ms | 91 ms | x5 |
| 2 | 128 | 230 ms | 42 ms | x5 |
| 1 | 128 | 100 ms | 20 ms | x5 |
| 1 | 256 | 130 ms | 34 ms | x4 |
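The speedup column appears to be the PyTorch/JAX time ratio rounded to the nearest integer. That reading is an assumption, and this short check confirms it is consistent with all four QM9 rows. Note that the `lmax=2, mul=128` ratio (about 5.48) sits close to the rounding boundary.

```python
# QM9 timings on V100 from the table: (lmax, mul, pytorch_ms, jax_ms)
RUNS = [(2, 512, 430, 91), (2, 128, 230, 42), (1, 128, 100, 20), (1, 256, 130, 34)]

speedups = []
for lmax, mul, t_torch, t_jax in RUNS:
    ratio = t_torch / t_jax
    speedups.append(round(ratio))
    print(f"lmax={lmax}, mul={mul}: {ratio:.2f} -> x{round(ratio)}")
```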