Skip to content

flops

flops

FLOPs estimation and Model FLOPs Utilization (MFU) computation.

Functions

estimate_flops

estimate_flops(model: str, input_tokens: int, output_tokens: int) -> tuple[float, float]

Estimate FLOPs for an inference pass.

Uses the 2 * P * T approximation, where P is the model's parameter count and T is the total number of tokens (input + output). Returns (total_flops, flops_per_token).

Source code in src/openjarvis/telemetry/flops.py
def estimate_flops(
    model: str, input_tokens: int, output_tokens: int
) -> tuple[float, float]:
    """Estimate FLOPs for an inference pass.

    Uses the 2 * P * T approximation where P = parameter count and
    T = input_tokens + output_tokens.

    The parameter count (in billions) is taken from ``MODEL_PARAMS_B``:
    first by exact name. If that yields 0.0, it falls back to the most
    specific key whose base name (the text before ``:``) is a prefix of
    *model*. A model matching no key is treated as having 0 parameters.

    Args:
        model: Model identifier to look up in ``MODEL_PARAMS_B``.
        input_tokens: Number of prompt tokens processed.
        output_tokens: Number of tokens generated.

    Returns:
        ``(total_flops, flops_per_token)``. ``flops_per_token`` is 0.0 when
        ``input_tokens + output_tokens`` is not positive.
    """
    params_b = MODEL_PARAMS_B.get(model, 0.0)
    if params_b == 0.0:
        # No exact entry: fall back to prefix matching. Among all keys whose
        # base name is a prefix of *model*, pick the longest matched prefix.
        # When the full key (with its ":tag") also matches, its full length
        # counts. Taking the first match in table order would let a short key
        # such as "llama3" shadow "llama3.1:8b" for "llama3.1:8b-instruct".
        # Ties keep table order, as before.
        best_len = 0
        for key, val in MODEL_PARAMS_B.items():
            base = key.split(":")[0]
            if not base or not model.startswith(base):
                # An empty base name would otherwise match every model name.
                continue
            matched = len(key) if model.startswith(key) else len(base)
            if matched > best_len:
                params_b, best_len = val, matched

    total_tokens = input_tokens + output_tokens
    params = params_b * 1e9
    total_flops = 2.0 * params * total_tokens
    # Per-token cost does not depend on T; report 0.0 when nothing was run.
    flops_per_token = 2.0 * params if total_tokens > 0 else 0.0
    return (total_flops, flops_per_token)

compute_mfu

compute_mfu(flops: float, duration_s: float, gpu_name: str, num_gpus: int = 1) -> float

Compute Model FLOPs Utilization.

MFU = actual_tflops / (peak_tflops * num_gpus), returned as a percentage (the ratio multiplied by 100).

Source code in src/openjarvis/telemetry/flops.py
def compute_mfu(
    flops: float, duration_s: float, gpu_name: str, num_gpus: int = 1
) -> float:
    """Compute Model FLOPs Utilization.

    MFU = actual_tflops / (peak_tflops * num_gpus), returned as a
    percentage (the ratio multiplied by 100).

    The per-GPU peak is taken from ``GPU_PEAK_TFLOPS_BF16``: first by exact
    name. If that yields 0.0, it falls back to the longest key that occurs
    (case-insensitively) as a substring of *gpu_name*.

    Args:
        flops: Total floating-point operations performed.
        duration_s: Wall-clock duration in seconds.
        gpu_name: GPU name used to look up the peak throughput.
        num_gpus: Number of GPUs the work was spread across.

    Returns:
        MFU in percent. Returns 0.0 when the GPU's peak is unknown or not
        positive, or when ``duration_s`` or ``num_gpus`` is not positive.
    """
    peak = GPU_PEAK_TFLOPS_BF16.get(gpu_name, 0.0)
    if peak == 0.0:
        # Substring fallback. Prefer the longest matching key: with plain
        # table order a key such as "A10" would match "NVIDIA A100" (or "L4"
        # match "L40S") whenever it happened to come first. Ties keep table
        # order.
        name = gpu_name.lower()
        best_len = 0
        for key, val in GPU_PEAK_TFLOPS_BF16.items():
            # Skip empty keys, which would match any name.
            if key and len(key) > best_len and key.lower() in name:
                peak, best_len = val, len(key)
    if peak <= 0 or duration_s <= 0 or num_gpus <= 0:
        # num_gpus == 0 would otherwise raise ZeroDivisionError below.
        return 0.0
    actual_tflops = flops / (duration_s * 1e12)
    return (actual_tflops / (peak * num_gpus)) * 100.0