flops
FLOPs estimation and Model FLOPs Utilization (MFU) computation.
Functions
estimate_flops
Estimate FLOPs for an inference pass (assumes KV caching).
Uses the 2 * P * T approximation where P = params, T = total tokens. Returns (total_flops, flops_per_token).
input_tokens must include system-prompt tokens and must not be reduced for KV-cache reuse. It should represent the full prompt size that was sent to the engine.
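As a rough illustration of the 2 * P * T approximation, the sketch below takes the parameter count directly instead of resolving it from a model name. The function name `sketch_estimate_flops` and its `num_params` argument are hypothetical and not part of the module, and the sketch assumes that flops_per_token is simply the total divided by the token count (that is, 2 * P).

```python
def sketch_estimate_flops(
    num_params: float, input_tokens: int, output_tokens: int
) -> tuple[float, float]:
    """Hypothetical stand-in for estimate_flops; takes P directly (assumption)."""
    # T counts every token the engine handled: the full prompt plus the output.
    total_tokens = input_tokens + output_tokens
    # Roughly 2 FLOPs (one multiply, one add) per parameter per token.
    flops_per_token = 2.0 * num_params
    total_flops = flops_per_token * total_tokens
    return total_flops, flops_per_token


# Example: a 7B-parameter model, 1000 prompt tokens, 200 generated tokens.
total, per_token = sketch_estimate_flops(7e9, 1000, 200)
print(f"total={total:.3e} FLOPs, per token={per_token:.3e} FLOPs")
```

Note that the prompt count passed in is the full prompt size, including any system prompt, even if the engine reused cached prefixes.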
Source code in src/openjarvis/telemetry/flops.py
estimate_flops_no_kv_cache
estimate_flops_no_kv_cache(model: str, input_tokens: int, output_tokens: int) -> tuple[float, float]
Estimate FLOPs without KV caching (full recompute per token).
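Without a KV cache, each token is re-processed for every subsequent token. Under the 2 * P per-token approximation, step i costs about 2 * P * i FLOPs, and summing over i = 1..N gives FLOPs = P * N * (N + 1), where P = params and N = total_tokens. Returns (total_flops, flops_per_token_avg).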
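The closed form can be checked against a step-by-step sum. This is a minimal sketch under stated assumptions: `sketch_flops_no_kv_cache` and `num_params` are hypothetical names, the parameter count is passed in directly rather than looked up from `model`, N is taken as input plus output tokens, and the average is assumed to be the total divided by N.

```python
def sketch_flops_no_kv_cache(
    num_params: float, input_tokens: int, output_tokens: int
) -> tuple[float, float]:
    """Hypothetical stand-in for estimate_flops_no_kv_cache (P passed directly)."""
    n = input_tokens + output_tokens  # assumed definition of total_tokens
    total_flops = float(num_params) * n * (n + 1)
    # Average cost per token; guard against an empty request.
    flops_per_token_avg = total_flops / n if n > 0 else 0.0
    return total_flops, flops_per_token_avg


def stepwise_total(num_params: float, n: int) -> float:
    """Sum 2 * P * i for i = 1..n, i.e. re-processing i tokens at step i."""
    return sum(2.0 * num_params * i for i in range(1, n + 1))


total, avg = sketch_flops_no_kv_cache(10, 3, 2)
# The closed form and the explicit sum agree for this small case.
assert total == stepwise_total(10, 5)
```

The quadratic growth in N is what makes this estimate diverge so quickly from the cached 2 * P * T figure on long sequences.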
Source code in src/openjarvis/telemetry/flops.py
compute_mfu
Compute Model FLOPs Utilization.
MFU = actual_tflops / (peak_tflops * num_gpus)
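A minimal sketch of this ratio follows. The signature of `compute_mfu` is not shown here, so `sketch_compute_mfu`, its argument order, and the zero-denominator guard are assumptions. The helper that derives achieved TFLOPS from a FLOPs estimate and a wall-clock time is likewise hypothetical, and the result is treated as a fraction rather than a percentage.

```python
def sketch_achieved_tflops(total_flops: float, elapsed_seconds: float) -> float:
    """Hypothetical helper: convert a FLOPs estimate and duration to TFLOPS."""
    if elapsed_seconds <= 0:
        return 0.0
    return total_flops / elapsed_seconds / 1e12


def sketch_compute_mfu(
    actual_tflops: float, peak_tflops: float, num_gpus: int = 1
) -> float:
    """Hypothetical stand-in: MFU = actual_tflops / (peak_tflops * num_gpus)."""
    aggregate_peak = peak_tflops * num_gpus
    if aggregate_peak <= 0:
        # Assumed behavior: report 0.0 when peak throughput is unknown.
        return 0.0
    return actual_tflops / aggregate_peak


# Example: 100 TFLOPS achieved on 2 GPUs rated at 400 TFLOPS each.
mfu = sketch_compute_mfu(100.0, 400.0, num_gpus=2)
print(f"MFU = {mfu:.3f}")
```

The peak figure should match the numeric precision actually used for inference, since vendor peak TFLOPS differ by data type.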