
flops

FLOPs estimation and Model FLOPs Utilization (MFU) computation.

Functions

estimate_flops

estimate_flops(model: str, input_tokens: int, output_tokens: int) -> tuple[float, float]

Estimate FLOPs for an inference pass (assumes KV caching).

Uses the 2 * P * T approximation, where P is the parameter count and T = input_tokens + output_tokens. Returns (total_flops, flops_per_token); flops_per_token is 2 * P, or 0.0 when T is 0.

input_tokens must include system-prompt tokens and must not be reduced for KV-cache reuse. It should count the full prompt that was sent to the engine.

Source code in src/openjarvis/telemetry/flops.py
def estimate_flops(
    model: str, input_tokens: int, output_tokens: int
) -> tuple[float, float]:
    """Estimate FLOPs for an inference pass (assumes KV caching).

    Uses the 2 * P * T approximation where P = params, T = total tokens.
    Returns (total_flops, flops_per_token).

    ``input_tokens`` must include system-prompt tokens and must *not*
    be reduced for KV-cache reuse — it should represent the full prompt
    size that was sent to the engine.
    """
    params_b = _get_params_b(model)
    total_tokens = input_tokens + output_tokens
    params = params_b * 1e9
    total_flops = 2.0 * params * total_tokens
    flops_per_token = 2.0 * params if total_tokens > 0 else 0.0
    return (total_flops, flops_per_token)
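
The following is a minimal sketch of the 2 * P * T arithmetic and of the input_tokens rule above. It assumes the parameter count is passed in directly: estimate_flops_sketch is a hypothetical stand-in, because _get_params_b and its model table are not shown on this page. The token counts are made-up example values.

```python
# Hypothetical stand-in for estimate_flops: takes the parameter count in
# billions directly instead of resolving a model name via _get_params_b.
def estimate_flops_sketch(
    params_b: float, input_tokens: int, output_tokens: int
) -> tuple[float, float]:
    total_tokens = input_tokens + output_tokens
    params = params_b * 1e9
    total_flops = 2.0 * params * total_tokens
    flops_per_token = 2.0 * params if total_tokens > 0 else 0.0
    return total_flops, flops_per_token


# Hypothetical request against a 7B-parameter model.
system_prompt_tokens = 300
user_prompt_tokens = 150
cached_prefix_tokens = 300  # served from the engine's prefix cache
output_tokens = 250

# Count the full prompt sent to the engine; do NOT subtract cached_prefix_tokens.
input_tokens = system_prompt_tokens + user_prompt_tokens

total, per_token = estimate_flops_sketch(7.0, input_tokens, output_tokens)
print(f"{total:.2e} FLOPs total, {per_token:.2e} FLOPs/token")
```

With 450 prompt tokens and 250 output tokens, T = 700, so the estimate is 2 * 7e9 * 700 = 9.8e12 FLOPs, or 1.4e10 FLOPs per token. Subtracting the 300 cached tokens would understate the total by about 43%.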

estimate_flops_no_kv_cache

estimate_flops_no_kv_cache(model: str, input_tokens: int, output_tokens: int) -> tuple[float, float]

Estimate FLOPs without KV caching (full recompute per token).

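Without a KV cache, each token is re-processed for every subsequent token: step i runs a forward pass over all i tokens so far, costing about 2 * P * i FLOPs. Summing over i = 1..N gives FLOPs = P * N * (N + 1), where P is the parameter count and N = input_tokens + output_tokens. Returns (total_flops, flops_per_token_avg), where the average is total_flops / N; both are 0.0 when N is 0.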

Source code in src/openjarvis/telemetry/flops.py
def estimate_flops_no_kv_cache(
    model: str, input_tokens: int, output_tokens: int
) -> tuple[float, float]:
    """Estimate FLOPs without KV caching (full recompute per token).

    Without KV cache, each token is re-processed for every subsequent token.
    FLOPs = P * N * (N + 1) where P = params, N = total_tokens.
    Returns (total_flops, flops_per_token_avg).
    """
    params_b = _get_params_b(model)
    total_tokens = input_tokens + output_tokens
    if total_tokens == 0:
        return (0.0, 0.0)
    params = params_b * 1e9
    total_flops = params * total_tokens * (total_tokens + 1)
    flops_per_token = total_flops / total_tokens
    return (total_flops, flops_per_token)
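
As a sanity check on the closed form, the sketch below adds up the per-step cost 2 * P * i explicitly and compares it with P * N * (N + 1), then with the KV-cached 2 * P * N estimate. The helper names and the 7B / 700-token figures are hypothetical illustration values, not part of the module.

```python
# Hypothetical helpers for illustration; they take a raw parameter count.
def no_kv_flops_stepwise(params: float, total_tokens: int) -> float:
    """Add up 2 * P * i for i = 1..N: step i re-processes all i tokens."""
    return sum(2.0 * params * i for i in range(1, total_tokens + 1))


def no_kv_flops_closed_form(params: float, total_tokens: int) -> float:
    """Closed form used by estimate_flops_no_kv_cache: P * N * (N + 1)."""
    return params * total_tokens * (total_tokens + 1)


def kv_flops(params: float, total_tokens: int) -> float:
    """KV-cached estimate used by estimate_flops: 2 * P * N."""
    return 2.0 * params * total_tokens


P = 7e9  # hypothetical 7B-parameter model
N = 700  # 450 prompt tokens + 250 output tokens

stepwise = no_kv_flops_stepwise(P, N)
closed = no_kv_flops_closed_form(P, N)
print(f"stepwise={stepwise:.4e} closed_form={closed:.4e}")
# The ratio simplifies to (N + 1) / 2.
print(f"no-KV / KV ratio = {closed / kv_flops(P, N)}")
```

The two no-KV totals agree, and the ratio to the cached estimate is (N + 1) / 2, so it grows linearly with sequence length: 350.5 times at N = 700. Note that the closed form counts every token, prompt tokens included, as its own forward pass.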

compute_mfu

compute_mfu(flops: float, duration_s: float, gpu_name: str, num_gpus: int = 1) -> float

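Compute Model FLOPs Utilization (MFU) as a percentage.

MFU = 100 * actual_tflops / (peak_tflops * num_gpus), where actual_tflops = flops / (duration_s * 1e12) and peak_tflops is the per-GPU BF16 peak from GPU_PEAK_TFLOPS_BF16. The peak is looked up by exact gpu_name first, then by case-insensitive substring match. Returns 0.0 if no peak is found or duration_s <= 0.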

Source code in src/openjarvis/telemetry/flops.py
def compute_mfu(
    flops: float, duration_s: float, gpu_name: str, num_gpus: int = 1
) -> float:
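    """Compute Model FLOPs Utilization as a percentage.

    MFU = 100 * actual_tflops / (peak_tflops * num_gpus). Returns 0.0 if
    no peak BF16 TFLOPS is known for ``gpu_name`` or ``duration_s <= 0``.
    """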
    peak = GPU_PEAK_TFLOPS_BF16.get(gpu_name, 0.0)
    if peak == 0.0:
        # Try substring matching
        for key, val in GPU_PEAK_TFLOPS_BF16.items():
            if key.lower() in gpu_name.lower():
                peak = val
                break
    if peak <= 0 or duration_s <= 0:
        return 0.0
    actual_tflops = flops / (duration_s * 1e12)
    return (actual_tflops / (peak * num_gpus)) * 100.0
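
The sketch below mirrors the lookup-and-divide logic of compute_mfu so the percentage can be checked by hand. The PEAK_TFLOPS_BF16 table, its 1000.0 and 300.0 values, and the helper names are placeholders invented for this example; they are not the contents of GPU_PEAK_TFLOPS_BF16.

```python
# Placeholder peak table: NOT the real GPU_PEAK_TFLOPS_BF16 values.
PEAK_TFLOPS_BF16: dict[str, float] = {"H100": 1000.0, "A100": 300.0}


def lookup_peak(gpu_name: str, table: dict[str, float]) -> float:
    """Exact match first, then the first case-insensitive substring match."""
    peak = table.get(gpu_name, 0.0)
    if peak == 0.0:
        for key, val in table.items():
            if key.lower() in gpu_name.lower():
                return val
    return peak


def mfu_percent(
    flops: float,
    duration_s: float,
    gpu_name: str,
    num_gpus: int = 1,
    table: dict[str, float] = PEAK_TFLOPS_BF16,
) -> float:
    peak = lookup_peak(gpu_name, table)
    if peak <= 0 or duration_s <= 0:
        return 0.0
    actual_tflops = flops / (duration_s * 1e12)  # achieved TFLOP/s
    return (actual_tflops / (peak * num_gpus)) * 100.0


# 9.8e12 FLOPs (7B params, 700 tokens) completed in 0.5 s on one GPU.
mfu = mfu_percent(9.8e12, 0.5, "NVIDIA H100 80GB HBM3")
print(f"MFU = {mfu:.2f}%")
```

Here the achieved rate is 19.6 TFLOP/s against a 1000 TFLOP/s placeholder peak, about 1.96%; passing num_gpus=2 halves it. Because the fallback returns the first key that appears as a substring of gpu_name, a table holding both "A10" and "A100" could map an A100 to the A10 entry if "A10" is listed first.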