grpo_trainer

GRPO (Group Relative Policy Optimization) trainer for orchestrator.

Adapted from IPW's trainer.py. GRPO is simpler than PPO because it doesn't require a separate critic model — instead, it uses group-relative advantages: for each problem, sample N candidate trajectories, compute rewards, normalise within the group, and update the policy to increase the probability of better solutions.
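The group-relative advantage computation described above can be sketched as follows. This is an illustrative helper, not the trainer's actual implementation; the function name `group_relative_advantages` and the epsilon value are assumptions:

```python
import statistics

def group_relative_advantages(rewards):
    """Normalise rewards within one group of sampled trajectories.

    Each trajectory's advantage is its reward minus the group mean,
    divided by the group standard deviation (a small epsilon guards
    against a zero-variance group).
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + 1e-8) for r in rewards]

# For one problem with N = 4 sampled candidates:
advantages = group_relative_advantages([0.0, 1.0, 1.0, 0.0])
```

Because advantages are centred within each group, roughly half the samples per problem receive positive sign and half negative, which is what lets GRPO skip the critic model entirely.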

All torch/transformers imports are guarded so the module can be imported without GPU dependencies.
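The guard usually takes the following shape (a sketch of the pattern; the exact guard in grpo_trainer.py may differ, though the `HAS_TORCH` flag and `torch` name match the source shown below):

```python
# Guarded-import pattern: the module stays importable on machines
# without GPU dependencies; only GPU-dependent code paths raise.
try:
    import torch
    HAS_TORCH = True
except ImportError:  # torch not installed
    torch = None
    HAS_TORCH = False
```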

Classes

OrchestratorGRPOConfig dataclass

OrchestratorGRPOConfig(model_name: str = 'Qwen/Qwen3-1.7B', max_prompt_length: int = 24000, max_response_length: int = 8768, num_epochs: int = 10, batch_size: int = 16, learning_rate: float = 1e-06, max_grad_norm: float = 1.0, num_samples_per_prompt: int = 8, temperature: float = 1.0, kl_coef: float = 0.0001, clip_ratio: float = 0.2, available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])(), max_turns: int = 10, checkpoint_dir: str = 'checkpoints/orchestrator_grpo', save_every_n_epochs: int = 1, keep_last_n: int = 3, gradient_checkpointing: bool = True, use_8bit_ref: bool = True, use_8bit_optimizer: bool = False)

Configuration for orchestrator GRPO training.
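The `(lambda: [...])()` default for `available_tools` in the signature above is most likely how the doc generator renders a dataclass `field(default_factory=...)`. A minimal stand-in illustrating how such a config composes (field names and defaults are taken from the signature; the class name `GRPOConfigSketch` is hypothetical):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class GRPOConfigSketch:
    # A few representative fields from OrchestratorGRPOConfig.
    model_name: str = "Qwen/Qwen3-1.7B"
    num_samples_per_prompt: int = 8   # N candidates per problem (the "group")
    kl_coef: float = 1e-4             # weight of the KL penalty
    available_tools: List[str] = field(
        default_factory=lambda: ["calculator", "think", "code_interpreter", "web_search"]
    )

cfg = GRPOConfigSketch(num_samples_per_prompt=4)
```

Using `default_factory` gives each config instance its own list, avoiding the shared-mutable-default pitfall.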

OrchestratorGRPOTrainer

OrchestratorGRPOTrainer(config: OrchestratorGRPOConfig)

GRPO trainer for orchestrator policy.

torch must be installed to call `train()`.

Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
```python
def __init__(self, config: OrchestratorGRPOConfig) -> None:
    self.config = config
    self.device = None
    self.global_step = 0

    if HAS_TORCH and torch is not None:
        self.device = _select_torch_device()

    self._init_model()
    self._init_optimizer()
```
Functions

train

train() -> None

Run the GRPO training loop.

Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
```python
def train(self) -> None:
    """Run the GRPO training loop."""
    if not HAS_TORCH:
        raise RuntimeError(
            "PyTorch is required for training. "
            "Install with: pip install torch transformers"
        )

    for epoch in range(self.config.num_epochs):
        self._train_epoch(epoch)

        if (epoch + 1) % self.config.save_every_n_epochs == 0:
            self._save_checkpoint(epoch)
```
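The per-epoch update presumably applies a PPO-style clipped surrogate, which the `clip_ratio` config field suggests. A per-token sketch under that assumption (the function name `grpo_token_loss` is hypothetical; the trainer's actual loss is not shown on this page):

```python
import math

def grpo_token_loss(logp_new, logp_old, advantage, clip_ratio=0.2):
    """Clipped surrogate loss for a single token.

    ratio is the new policy's probability over the old one; clipping
    keeps any single update within [1 - clip_ratio, 1 + clip_ratio].
    """
    ratio = math.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1 + clip_ratio), 1 - clip_ratio) * advantage
    # Maximise the surrogate objective => minimise its negation.
    return -min(unclipped, clipped)
```

With an unchanged policy (`logp_new == logp_old`) the ratio is 1 and clipping is inactive; once the ratio drifts past `1 + clip_ratio` for a positive advantage, the clipped branch caps the gradient signal.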