I'm looking for a tool that calculates the FLOPs of a computation given its XLA HLO computational graph. Does anyone know of an HLO cost model or analytical model that can print the FLOPs of each operator node in the graph? I need either its source code or a sample usage of such a tool. Thanks :)
You can estimate FLOPs and other computational characteristics before running any code by using JAX's built-in ahead-of-time (AOT) lowering utilities. For example:
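import jax

def f(M, x):
    return jax.lax.exp(M @ x + 1).sum()

# Abstract shapes and dtypes only: no real arrays are allocated.
M = jax.ShapeDtypeStruct((10000, 10000), 'float32')
x = jax.ShapeDtypeStruct((10000,), 'float32')

# Lower f to XLA without running it, then query XLA's cost analysis.
lowered = jax.jit(f).lower(M, x)
print(lowered.cost_analysis())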
{'flops': 200020000.0,
'transcendentals': 10000.0,
'bytes accessed': 400360000.0,
'utilization0{}': 5.0,
'utilization1{}': 3.0,
'bytes accessed0{}': 400120000.0,
'bytes accessed1{}': 80004.0,
'bytes accessedout{}': 160004.0}
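As a sanity check on these numbers, and as a starting point for the per-node breakdown the question asks for, the totals can be reproduced by hand. The sketch below is a hypothetical analytical model (the names `per_op_costs` and `totals` are illustrative and not part of JAX or XLA). It assumes that a multiply-add counts as 2 FLOPs, that an elementwise add or a reduction costs 1 FLOP per element, and that `exp` is tallied separately as a transcendental. Under those assumptions it reproduces the `flops` and `transcendentals` values printed above.

```python
# Hypothetical hand-derived (analytical) cost model for
# f(M, x) = sum(exp(M @ x + 1)), with M of shape (n, n) and x of shape (n,).
# Assumed convention: a multiply-add = 2 FLOPs, an elementwise add = 1 FLOP
# per element, a reduction = 1 FLOP per input element (strictly, summing
# n values needs n - 1 adds), and exp is counted as a transcendental.

def per_op_costs(n):
    """Return a list of (op_name, flops, transcendentals) for each node."""
    return [
        ("dot (M @ x)", 2 * n * n, 0),  # n outputs, each needing n multiply-adds
        ("add (+ 1)", n, 0),            # one add per element of the length-n vector
        ("exp", 0, n),                  # one transcendental per element
        ("reduce (sum)", n, 0),         # assumed 1 FLOP per input element
    ]


def totals(costs):
    """Sum the per-node FLOPs and transcendentals."""
    flops = sum(c[1] for c in costs)
    trans = sum(c[2] for c in costs)
    return flops, trans


costs = per_op_costs(10_000)
for name, flops, trans in costs:
    print(f"{name:15s} flops={flops:>12,d} transcendentals={trans:,d}")
print(totals(costs))  # (200020000, 10000)
```

Note that `cost_analysis()` reports whole-program totals rather than a per-node breakdown, so a hand model like this is only an approximation of how XLA attributes costs to individual HLO instructions; XLA's exact counting rules are implemented in its `HloCostAnalysis` class.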