grpo_trainer
¶
GRPO (Group Relative Policy Optimization) trainer for the orchestrator.
Adapted from IPW's trainer.py. GRPO is simpler than PPO because it
does not require a separate critic model. Instead, it uses
group-relative advantages: for each problem, sample N candidate
trajectories, compute their rewards, normalise the rewards within the
group, and update the policy to increase the probability of the
better-scoring solutions.
All torch/transformers imports are guarded so the module can be
imported without GPU dependencies.
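The guard typically looks like the following sketch; the module's actual guard and error message may differ.

```python
# Illustrative import-guard pattern: the heavy dependency is optional at
# import time and only required when a training method is actually called.
try:
    import torch  # heavy GPU dependency
    HAS_TORCH = True
except ImportError:
    torch = None
    HAS_TORCH = False

def require_torch():
    """Raise a helpful error when a torch-only code path is entered."""
    if not HAS_TORCH:
        raise ImportError("torch is required for training; install it first")
```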
Classes¶
OrchestratorGRPOConfig
dataclass
¶
OrchestratorGRPOConfig(
    model_name: str = 'Qwen/Qwen3-1.7B',
    max_prompt_length: int = 24000,
    max_response_length: int = 8768,
    num_epochs: int = 10,
    batch_size: int = 16,
    learning_rate: float = 1e-06,
    max_grad_norm: float = 1.0,
    num_samples_per_prompt: int = 8,
    temperature: float = 1.0,
    kl_coef: float = 0.0001,
    clip_ratio: float = 0.2,
    available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])(),
    max_turns: int = 10,
    checkpoint_dir: str = 'checkpoints/orchestrator_grpo',
    save_every_n_epochs: int = 1,
    keep_last_n: int = 3,
    gradient_checkpointing: bool = True,
    use_8bit_ref: bool = True,
    use_8bit_optimizer: bool = False
)
Configuration for orchestrator GRPO training.
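The `(lambda: [...])()` default shown in the signature is how documentation renderers display a `default_factory` for a mutable list. The following abbreviated re-statement (a subset of the fields, for illustration only) shows the underlying idiom and typical usage:

```python
from dataclasses import dataclass, field
from typing import List

# Trimmed illustration of the config: only a few of the documented fields,
# with the rendered lambda default written as an explicit default_factory.
@dataclass
class OrchestratorGRPOConfig:
    model_name: str = "Qwen/Qwen3-1.7B"
    num_samples_per_prompt: int = 8
    learning_rate: float = 1e-06
    kl_coef: float = 0.0001
    clip_ratio: float = 0.2
    available_tools: List[str] = field(
        default_factory=lambda: ["calculator", "think", "code_interpreter", "web_search"]
    )

# Override only what you need; every other field keeps its default.
cfg = OrchestratorGRPOConfig(num_samples_per_prompt=4)
```

Using `default_factory` gives each instance its own list, so mutating one config's `available_tools` never leaks into another.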
OrchestratorGRPOTrainer
¶
OrchestratorGRPOTrainer(config: OrchestratorGRPOConfig)
GRPO trainer for orchestrator policy.
torch must be installed to call train().
Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
Functions¶
train
¶
Run the GRPO training loop.
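The core update inside such a loop is a PPO-style clipped surrogate driven by the group-relative advantage, plus a KL penalty toward a frozen reference policy. The scalar sketch below reuses the `clip_ratio` and `kl_coef` names from the config above; the module's actual loss operates on batched token tensors and may differ in detail.

```python
import math

def grpo_loss(logp_new, logp_old, logp_ref, advantage,
              clip_ratio=0.2, kl_coef=0.0001):
    """Scalar sketch of the per-token GRPO objective (illustrative).

    Clipped surrogate with a group-relative advantage, plus a KL
    penalty toward a frozen reference policy.
    """
    ratio = math.exp(logp_new - logp_old)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1 + clip_ratio), 1 - clip_ratio) * advantage
    surrogate = min(unclipped, clipped)  # pessimistic (lower) bound
    # Unbiased non-negative KL estimator: r - log r - 1 with r = p_ref / p_new.
    r = math.exp(logp_ref - logp_new)
    kl = r - (logp_ref - logp_new) - 1
    return -(surrogate - kl_coef * kl)  # negate: trainers minimise
```

When the new policy equals the old and the reference (`ratio == 1`, `kl == 0`), the loss reduces to minus the advantage; when the ratio drifts outside `[1 - clip_ratio, 1 + clip_ratio]`, clipping caps the incentive to move further.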