Skip to content

bandit_router

bandit_router

Bandit router — Thompson Sampling / UCB for query→model selection.

Classes

ArmStats dataclass

ArmStats(successes: int = 0, failures: int = 0, total_reward: float = 0.0, pulls: int = 0)

Statistics for a single arm (model).

BanditRouterPolicy

BanditRouterPolicy(*, strategy: Literal['thompson', 'ucb'] = 'thompson', exploration_factor: float = 2.0, min_pulls: int = 3, reward_threshold: float = 0.5)

Multi-armed bandit router using Thompson Sampling or UCB.

Each (query_class, model) pair is an arm. Rewards come from trace outcomes.

Source code in src/openjarvis/learning/bandit_router.py
def __init__(
    self,
    *,
    strategy: Literal["thompson", "ucb"] = "thompson",
    exploration_factor: float = 2.0,  # UCB exploration constant
    min_pulls: int = 3,  # minimum pulls before trusting estimates
    reward_threshold: float = 0.5,  # reward at or above this = success
) -> None:
    """Configure the bandit router.

    Args:
        strategy: Arm-selection strategy, either ``"thompson"`` or ``"ucb"``.
        exploration_factor: UCB exploration constant, stored as
            ``self._exploration``.
        min_pulls: Models with fewer pulls than this (for the current query
            class) are explored uniformly at random before the configured
            strategy is consulted.
        reward_threshold: Rewards greater than or equal to this value are
            counted as successes by ``update``; lower rewards are failures.

    Raises:
        ValueError: If *strategy* is not ``"thompson"`` or ``"ucb"``.
    """
    # Validate eagerly: route() dispatches any non-"thompson" value to UCB,
    # so a typo such as "Thompson" would otherwise silently switch strategies.
    if strategy not in ("thompson", "ucb"):
        raise ValueError(
            f"Unknown bandit strategy {strategy!r}; expected 'thompson' or 'ucb'"
        )
    self._strategy = strategy
    self._exploration = exploration_factor
    self._min_pulls = min_pulls
    self._reward_threshold = reward_threshold
    # query_class -> model -> ArmStats; both levels auto-create on first access.
    self._arms: Dict[str, Dict[str, ArmStats]] = defaultdict(
        lambda: defaultdict(ArmStats)
    )
    # Router-wide count of update() calls across all query classes and models.
    self._total_pulls = 0
Functions
route
route(context: RoutingContext, models: List[str]) -> str

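Select a model using the configured bandit strategy. Any model with fewer than `min_pulls` observations for the query class is explored uniformly at random first.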

Source code in src/openjarvis/learning/bandit_router.py
def route(self, context: RoutingContext, models: List[str]) -> str:
    """Choose one of *models* for the query described by *context*.

    The query class is derived from *context*, and every candidate model gets
    an arm in that class if it lacks one. While any candidate has been pulled
    fewer than ``min_pulls`` times, one of those under-sampled candidates is
    returned uniformly at random. Otherwise the choice is delegated to the
    Thompson or UCB selector, depending on the configured strategy.

    Raises:
        ValueError: If *models* is empty.
    """
    if not models:
        raise ValueError("No models available")

    class_arms = self._arms[_derive_query_class(context)]
    for name in models:
        if name not in class_arms:
            class_arms[name] = ArmStats()

    # Warm-up phase: sample each arm a few times before trusting estimates.
    cold = [name for name in models if class_arms[name].pulls < self._min_pulls]
    if cold:
        return random.choice(cold)

    selector = (
        self._thompson_select if self._strategy == "thompson" else self._ucb_select
    )
    return selector(models, class_arms)
update
update(query_class: str, model: str, reward: float) -> None

Update arm statistics with observed reward.

Source code in src/openjarvis/learning/bandit_router.py
def update(self, query_class: str, model: str, reward: float) -> None:
    """Record an observed *reward* for *model* within *query_class*.

    Increments the arm's pull count and accumulated reward. The observation
    counts as a success when ``reward >= reward_threshold`` and as a failure
    otherwise. The router-wide pull counter is also incremented. Arms are
    created on first use.

    Raises:
        ValueError: If *reward* is NaN or infinite. Such a value would be added
            to ``total_reward`` and could never be washed out of the running
            sum, so it is rejected before any statistics are modified.
    """
    import math

    if not math.isfinite(reward):
        raise ValueError(f"reward must be a finite number, got {reward!r}")

    stats = self._arms[query_class][model]
    stats.pulls += 1
    stats.total_reward += reward
    if reward >= self._reward_threshold:
        stats.successes += 1
    else:
        stats.failures += 1
    self._total_pulls += 1
get_stats
get_stats(query_class: Optional[str] = None) -> Dict[str, Any]

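Return per-model arm statistics for one query class, or a pulls/mean-reward summary for every query class when no class is given.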

Source code in src/openjarvis/learning/bandit_router.py
def get_stats(self, query_class: Optional[str] = None) -> Dict[str, Any]:
    """Return arm statistics.

    Args:
        query_class: When given, return ``{model: {"pulls", "mean_reward",
            "successes", "failures"}}`` for that class only. An unknown class
            yields ``{}``. When ``None`` (the default), return
            ``{query_class: {model: {"pulls", "mean_reward"}}}`` for every
            class seen so far.
    """
    # Compare against None rather than truthiness so that "" is treated as a
    # real query-class key instead of silently meaning "all classes".
    if query_class is not None:
        # .get() does not trigger the defaultdict factory, so querying an
        # unknown class does not insert an empty entry as a side effect.
        arms = self._arms.get(query_class, {})
        return {
            m: {
                "pulls": s.pulls,
                "mean_reward": s.mean_reward,
                "successes": s.successes,
                "failures": s.failures,
            }
            for m, s in arms.items()
        }
    return {
        qc: {
            m: {"pulls": s.pulls, "mean_reward": s.mean_reward}
            for m, s in arms.items()
        }
        for qc, arms in self._arms.items()
    }
reset
reset() -> None

Reset all state: clear every arm's statistics and set the total pull count back to zero.

Source code in src/openjarvis/learning/bandit_router.py
def reset(self) -> None:
    """Forget everything learned so far.

    Zeroes the router-wide pull counter and empties the per-query-class arm
    table in place. The same mapping object is kept, only cleared.
    """
    self._total_pulls = 0
    self._arms.clear()

Functions

ensure_registered

ensure_registered() -> None

Register BanditRouterPolicy if not already present.

Source code in src/openjarvis/learning/bandit_router.py
def ensure_registered() -> None:
    """Add ``BanditRouterPolicy`` to the router registry under ``"bandit"``.

    Safe to call repeatedly: if the key is already registered, nothing happens.
    """
    if RouterPolicyRegistry.contains("bandit"):
        return
    RouterPolicyRegistry.register_value("bandit", BanditRouterPolicy)