Sandy Yu

Reputation: 21

Looking for a tool to calculate the FLOPs of an XLA-HLO computational graph

I'm looking for a tool that calculates FLOPs given an XLA-HLO computational graph. Does anyone know of an HLO cost model or analytical model that prints the FLOPs of each operator node in the graph? I need its source code or a sample tool showing how to use it. Thanks :)

Upvotes: 2

Views: 256

Answers (1)

jakevdp

Reputation: 86513

You can estimate FLOPs and other computational characteristics before running the code using JAX's built-in ahead-of-time lowering utilities. For example:

import jax

def f(M, x):
  return jax.lax.exp(M @ x + 1).sum()

M = jax.ShapeDtypeStruct((10000, 10000), 'float32')
x = jax.ShapeDtypeStruct((10000,), 'float32')

lowered = jax.jit(f).lower(M, x)
print(lowered.cost_analysis())

Output:

{'flops': 200020000.0,
 'transcendentals': 10000.0,
 'bytes accessed': 400360000.0,
 'utilization0{}': 5.0,
 'utilization1{}': 3.0,
 'bytes accessed0{}': 400120000.0,
 'bytes accessed1{}': 80004.0,
 'bytes accessedout{}': 160004.0}
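
As a sanity check, the reported 'flops' and 'transcendentals' values can be reproduced by hand. The sketch below is only my reading of the numbers, assuming a multiply-add counts as 2 FLOPs and each elementwise or reduction step counts as 1 FLOP per element. It is not a statement of how XLA's cost model is implemented, and the variable names are made up for illustration.

```python
# Back-of-the-envelope count for f(M, x) = exp(M @ x + 1).sum()
# with M of shape (n, n) and x of shape (n,).
# Assumption: 2 FLOPs per multiply-add, 1 FLOP per element for the
# elementwise add and for the reduction.
n = 10_000

matvec_flops = 2 * n * n   # M @ x: n*n multiplies plus n*n adds
add_flops = n              # "+ 1" applied to each of the n elements
sum_flops = n              # .sum() over the n elements
estimated_flops = matvec_flops + add_flops + sum_flops

transcendentals = n        # exp applied to each of the n elements

print(estimated_flops)     # 200020000
print(transcendentals)     # 10000
```

Both figures match the cost_analysis() output above for this function. The exact keys and values can differ between JAX versions and backends.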

Upvotes: 0
