sft_trainer
SFT (Supervised Fine-Tuning) trainer for the orchestrator.
Adapted from IPW's sft_trainer.py. Trains the orchestrator policy
with supervised learning on trajectories. All torch/transformers
imports are guarded, so the module can be imported without GPU dependencies installed.
Classes
OrchestratorSFTConfig (dataclass)
OrchestratorSFTConfig(
    model_name: str = 'Qwen/Qwen3-1.7B',
    max_seq_length: int = 4096,
    num_epochs: int = 3,
    batch_size: int = 8,
    learning_rate: float = 2e-05,
    weight_decay: float = 0.01,
    warmup_ratio: float = 0.1,
    max_grad_norm: float = 1.0,
    teacher_engine_key: str = '',
    teacher_model: str = '',
    traces_per_query: int = 2,
    max_attempts_per_trace: int = 3,
    generation_temperature: float = 0.7,
    trace_cache_path: str = 'data/orchestrator_sft_traces.jsonl',
    regenerate_traces: bool = False,
    checkpoint_dir: str = 'checkpoints/orchestrator_sft',
    save_every_n_epochs: int = 1,
    log_dir: str = 'logs/orchestrator_sft',
    log_every_n_steps: int = 10,
    use_wandb: bool = False,
    gradient_checkpointing: bool = True,
    available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])()
)
Configuration for orchestrator SFT training.
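The `(lambda: [...])()` default for `available_tools` is how the documentation generator renders a default produced by a factory; the sketch below assumes the source uses `dataclasses.field(default_factory=...)`, which gives each config instance its own list. It also shows one common way that `batch_size`, `num_epochs`, and `warmup_ratio` combine into a step count. `SFTConfigSketch` and its two methods are hypothetical; this page does not document how the trainer actually derives its schedule.

```python
import math
from dataclasses import dataclass, field
from typing import List


@dataclass
class SFTConfigSketch:
    """Hypothetical subset of OrchestratorSFTConfig, for illustration only."""

    num_epochs: int = 3
    batch_size: int = 8
    warmup_ratio: float = 0.1
    # A mutable default must come from a factory so instances do not share one list.
    available_tools: List[str] = field(
        default_factory=lambda: ["calculator", "think", "code_interpreter", "web_search"]
    )

    def total_steps(self, num_examples: int) -> int:
        # Assumes one optimizer step per batch (no gradient accumulation).
        return math.ceil(num_examples / self.batch_size) * self.num_epochs

    def warmup_steps(self, num_examples: int) -> int:
        # Warmup covers the first warmup_ratio fraction of all steps.
        return int(self.warmup_ratio * self.total_steps(num_examples))
```

With the defaults and 1,000 examples this gives 125 batches per epoch, 375 total steps, and 37 warmup steps.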
OrchestratorSFTDataset
Dataset of SFT training examples loaded from generated trace JSONL files.
Source code in src/openjarvis/learning/intelligence/orchestrator/sft_trainer.py
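A JSONL trace file such as the default `trace_cache_path` (`data/orchestrator_sft_traces.jsonl`) holds one JSON object per line. The minimal sketch below reads such a file into an indexable dataset. The class name `JsonlTraceDataset` and the `success` field used for filtering are assumptions made for illustration; this page does not document the record schema or whether filtering happens in the dataset or elsewhere.

```python
import json
from pathlib import Path
from typing import Any, Dict, List, Union


class JsonlTraceDataset:
    """Hypothetical stand-in: loads one trace per non-empty JSONL line."""

    def __init__(self, path: Union[str, Path], require_success: bool = True) -> None:
        self.records: List[Dict[str, Any]] = []
        with open(path, encoding="utf-8") as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue  # skip blank lines
                record = json.loads(line)
                # "success" is an assumed field name, not taken from the source.
                if require_success and not record.get("success", False):
                    continue
                self.records.append(record)

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.records[idx]
```

Exposing `__len__` and `__getitem__` is enough for a map-style `torch.utils.data.Dataset` or `DataLoader` to consume the records once tokenization is added.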
OrchestratorSFTTrainer
OrchestratorSFTTrainer(config: OrchestratorSFTConfig)
SFT trainer for the orchestrator policy.
Computes the standard next-token cross-entropy loss on successful
trajectories. torch must be installed to call train().
Source code in src/openjarvis/learning/intelligence/orchestrator/sft_trainer.py
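Next-token cross-entropy shifts the sequence by one: the scores at position t are compared against the token at position t+1, and the loss is the mean negative log-probability of those targets. The dependency-free sketch below spells that out; the trainer itself presumably relies on torch. The optional `loss_mask` (for example, to skip prompt tokens) is an assumption, not something this page documents.

```python
import math
from typing import Optional, Sequence


def next_token_cross_entropy(
    logits: Sequence[Sequence[float]],
    token_ids: Sequence[int],
    loss_mask: Optional[Sequence[int]] = None,
) -> float:
    """Mean cross-entropy of predicting token t+1 from the scores at position t.

    logits[t] holds unnormalized scores over the vocabulary at position t.
    A zero in loss_mask[t + 1] excludes that target (hypothetical masking).
    """
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):
        if loss_mask is not None and not loss_mask[t + 1]:
            continue
        target = token_ids[t + 1]  # shift: position t predicts the next token
        row = logits[t]
        # log-sum-exp with max subtraction for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # equals -log softmax(row)[target]
        count += 1
    return total / count if count else 0.0
```

With uniform scores over a vocabulary of 4, every prediction costs log(4), so the mean loss is log(4) regardless of sequence length.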
Functions
train
Run the SFT training loop.
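The skeleton below sketches how the documented `num_epochs`, `log_every_n_steps`, and `save_every_n_epochs` settings could drive a training loop. It is a hypothetical, framework-free outline: `run_sft_loop`, `step_fn`, `log`, and `save` are illustrative names, and in a real torch loop `step_fn` would run the forward pass, backward pass, gradient clipping (`max_grad_norm`), and optimizer step.

```python
from typing import Callable, List, Sequence


def run_sft_loop(
    batches: Sequence[object],
    step_fn: Callable[[object], float],
    num_epochs: int = 3,
    log_every_n_steps: int = 10,
    save_every_n_epochs: int = 1,
    log: Callable[[str], None] = print,
    save: Callable[[int], None] = lambda epoch: None,
) -> List[float]:
    """Hypothetical loop skeleton; step_fn performs one update and returns its loss."""
    losses: List[float] = []
    global_step = 0
    for epoch in range(num_epochs):
        for batch in batches:
            loss = step_fn(batch)
            losses.append(loss)
            global_step += 1
            # Log on a global-step cadence, independent of epoch boundaries.
            if global_step % log_every_n_steps == 0:
                log(f"epoch={epoch} step={global_step} loss={loss:.4f}")
        # Checkpoint on an epoch cadence (e.g. under checkpoint_dir).
        if (epoch + 1) % save_every_n_epochs == 0:
            save(epoch + 1)
    return losses
```

Keeping logging on a step counter and checkpointing on an epoch counter mirrors the split between `log_every_n_steps` and `save_every_n_epochs` in the config.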