grpo_trainer

General-purpose GRPO trainer -- Group Relative Policy Optimization.

Fine-tunes any local model by sampling N responses per prompt, computing group-relative advantages, and applying a clipped policy-gradient update with a KL penalty against a frozen reference model.
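
The two core quantities named above can be sketched in plain Python. This is a minimal illustration, not the trainer's own code: the helper names `group_relative_advantages` and `clipped_term`, the mean/standard-deviation normalization, and the `clip_eps=0.2` default are all assumptions, since this page does not show how `_run_grpo` computes them. The KL penalty is omitted.

```python
from statistics import mean, pstdev
from typing import List


def group_relative_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Normalize rewards within one prompt's group of N sampled responses.

    Assumed form: (reward - group mean) / (group std + eps).
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


def clipped_term(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """Clipped surrogate for one sample.

    `ratio` is the new-policy / old-policy probability ratio. The objective
    takes the smaller of the unclipped and clipped products, which limits
    how far a single update can move the policy.
    """
    clipped_ratio = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    return min(ratio * advantage, clipped_ratio * advantage)


# Four responses to the same prompt, scored by a reward function.
advantages = group_relative_advantages([0.9, 0.5, 0.5, 0.1])
# The best response gets a positive advantage, the worst a negative one,
# and the advantages sum to roughly zero within the group.
```

Because advantages are computed relative to the group, a prompt whose N responses all receive the same reward contributes no learning signal under this normalization.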

Classes

RewardFn

Bases: Protocol

Protocol for reward functions used by GRPOTrainer.
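
A custom reward function only needs to match the protocol structurally. The sketch below assumes the protocol consists of a single `score` method with the same signature as `DefaultRewardFn.score`; the protocol body is not shown on this page, so it is redeclared locally here to keep the example standalone. `KeywordRewardFn` is a hypothetical class used only for illustration.

```python
from typing import Optional, Protocol


class RewardFn(Protocol):
    """Local stand-in for the library's protocol (assumed shape)."""

    def score(self, prompt: str, response: str, ground_truth: Optional[str]) -> float:
        ...


class KeywordRewardFn:
    """Hypothetical reward: 1.0 if the response mentions a keyword, else 0.0."""

    def __init__(self, keyword: str) -> None:
        self.keyword = keyword.lower()

    def score(self, prompt: str, response: str, ground_truth: Optional[str]) -> float:
        # ground_truth is ignored here; the protocol still requires the argument.
        return 1.0 if self.keyword in response.lower() else 0.0


# No inheritance is needed: a matching `score` method satisfies the protocol.
reward_fn: RewardFn = KeywordRewardFn("paris")
value = reward_fn.score("Capital of France?", "It is Paris.", None)
```

An instance like this would be passed as the `reward_fn` argument of `GRPOTrainer`.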

DefaultRewardFn

Default reward function based on simple heuristics: a response-length penalty plus an optional ground-truth match bonus.

Functions
score
score(prompt: str, response: str, ground_truth: str | None) -> float

Score a response. Higher is better, range [0, 1].

Source code in src/openjarvis/learning/intelligence/grpo_trainer.py
def score(self, prompt: str, response: str, ground_truth: str | None) -> float:
    """Score a response. Higher is better, range [0, 1]."""
    score = 0.5  # baseline

    # Length heuristic: prefer non-empty, not-too-long responses
    if not response.strip():
        return 0.0
    resp_len = len(response)
    if resp_len < 10:
        score -= 0.1
    elif resp_len > 5000:
        score -= 0.05

    # Ground truth matching
    if ground_truth is not None:
        gt_lower = ground_truth.strip().lower()
        resp_lower = response.strip().lower()
        if gt_lower == resp_lower:
            score += 0.4
        elif gt_lower in resp_lower:
            score += 0.2

    return max(0.0, min(1.0, score))
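
To make the heuristic concrete, the listing above is reproduced below as a standalone function (`score_response` is a name chosen for this example, not part of the module) and applied to a few inputs. Note that the highest score this logic can produce is 0.9 (baseline 0.5 plus the 0.4 exact-match bonus), even though the documented range is [0, 1].

```python
from typing import Optional


def score_response(prompt: str, response: str, ground_truth: Optional[str]) -> float:
    """Standalone copy of the DefaultRewardFn.score logic shown above."""
    if not response.strip():
        return 0.0
    score = 0.5  # baseline
    resp_len = len(response)
    if resp_len < 10:
        score -= 0.1  # very short response
    elif resp_len > 5000:
        score -= 0.05  # very long response
    if ground_truth is not None:
        gt_lower = ground_truth.strip().lower()
        resp_lower = response.strip().lower()
        if gt_lower == resp_lower:
            score += 0.4  # exact match, case-insensitive
        elif gt_lower in resp_lower:
            score += 0.2  # ground truth appears inside the response
    return max(0.0, min(1.0, score))


empty = score_response("q", "   ", "Paris")                        # 0.0
short_exact = score_response("q", "Paris", "paris")                # about 0.8
contains = score_response("q", "The capital of France is Paris.", "Paris")  # about 0.7
no_truth = score_response("q", "A reasonably long answer.", None)  # 0.5
```

The short exact answer loses 0.1 to the length penalty before gaining the 0.4 match bonus, so a terse but correct response scores lower than a correct response of ten or more characters would.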

GRPOTrainer

GRPOTrainer(config: GRPOConfig, reward_fn: RewardFn | None = None)

General-purpose GRPO trainer.

Parameters:

config (GRPOConfig): Configuration controlling model, sampling, and optimization parameters.

reward_fn (RewardFn | None, default: None): Pluggable reward function. Defaults to DefaultRewardFn.

Source code in src/openjarvis/learning/intelligence/grpo_trainer.py
def __init__(
    self,
    config: GRPOConfig,
    reward_fn: RewardFn | None = None,
) -> None:
    self.config = config
    self.reward_fn: RewardFn = reward_fn or DefaultRewardFn()
Functions
train
train(trace_store: Any) -> Dict[str, Any]

End-to-end: mine prompts from traces, then train.

Parameters:

trace_store (Any): Object with list_traces() returning trace objects.

Source code in src/openjarvis/learning/intelligence/grpo_trainer.py
def train(self, trace_store: Any) -> Dict[str, Any]:
    """End-to-end: mine prompts from traces, then train.

    Parameters
    ----------
    trace_store:
        Object with ``list_traces()`` returning trace objects.
    """
    prompts = self._mine_prompts(trace_store)
    return self.train_on_prompts(prompts)
train_on_prompts
train_on_prompts(prompts: List[str], ground_truths: List[str | None] | None = None) -> Dict[str, Any]

Run GRPO training on a set of prompts.

Parameters:

prompts (List[str]): List of prompt strings to train on.

ground_truths (List[str | None] | None, default: None): Optional parallel list of ground-truth answers for reward scoring.

Source code in src/openjarvis/learning/intelligence/grpo_trainer.py
def train_on_prompts(
    self,
    prompts: List[str],
    ground_truths: List[str | None] | None = None,
) -> Dict[str, Any]:
    """Run GRPO training on a set of prompts.

    Parameters
    ----------
    prompts:
        List of prompt strings to train on.
    ground_truths:
        Optional parallel list of ground-truth answers for reward scoring.
    """
    if not prompts:
        return {"status": "skipped", "reason": "no training data"}

    if len(prompts) < self.config.min_prompts:
        return {
            "status": "skipped",
            "reason": (
                f"only {len(prompts)} prompts, "
                f"min_prompts={self.config.min_prompts}"
            ),
        }

    if not HAS_TORCH:
        return {"status": "error", "reason": "torch not available"}

    if not HAS_TRANSFORMERS:
        return {"status": "error", "reason": "transformers not available"}

    try:
        return self._run_grpo(prompts, ground_truths)
    except Exception as exc:
        logger.warning("GRPO training failed: %s", exc)
        return {"status": "error", "reason": str(exc)}
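
Since `train_on_prompts` reports problems through the returned dict rather than by raising, callers need to inspect `status`. The helper below is a hypothetical sketch of that check (`describe_result` is not part of the module). It relies only on the `"skipped"` and `"error"` statuses and the `reason` key visible in the listing above; the keys returned on success are not shown on this page, so any other status is passed through as-is.

```python
from typing import Any, Dict


def describe_result(result: Dict[str, Any]) -> str:
    """Turn a train_on_prompts() result dict into a one-line summary."""
    status = result.get("status")
    reason = result.get("reason", "unknown")
    if status == "skipped":
        # Empty prompt list, or fewer prompts than config.min_prompts.
        return f"skipped: {reason}"
    if status == "error":
        # Missing torch/transformers, or an exception during training.
        return f"failed: {reason}"
    # Success-path keys are not documented here, so report the status only.
    return f"finished with status {status!r}"


summary = describe_result({"status": "skipped", "reason": "no training data"})
```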