sft_trainer

SFT (Supervised Fine-Tuning) trainer for orchestrator.

Adapted from IPW's sft_trainer.py. Trains the orchestrator policy using supervised learning on trajectories. All torch/transformers imports are guarded so the module can be imported without GPU dependencies.
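The guarded-import pattern described above might look like the following minimal sketch (illustrative only; the module's actual guard and helper names may differ):

```python
# Guarded import: the module stays importable even when the heavy
# GPU-capable dependencies are absent.
try:
    import torch  # optional dependency
    HAS_TORCH = True
except ImportError:
    torch = None  # type: ignore[assignment]
    HAS_TORCH = False


def require_torch() -> None:
    """Raise a clear error when a torch-dependent method is called."""
    if not HAS_TORCH:
        raise RuntimeError(
            "PyTorch is required for training. "
            "Install with: pip install torch transformers"
        )
```

Callers such as train() invoke the guard up front, so import-time failures are deferred until a GPU-dependent code path is actually used.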

Classes

OrchestratorSFTConfig dataclass

OrchestratorSFTConfig(model_name: str = 'Qwen/Qwen3-1.7B', max_seq_length: int = 4096, num_epochs: int = 3, batch_size: int = 8, learning_rate: float = 2e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_grad_norm: float = 1.0, teacher_engine_key: str = '', teacher_model: str = '', traces_per_query: int = 2, max_attempts_per_trace: int = 3, generation_temperature: float = 0.7, trace_cache_path: str = 'data/orchestrator_sft_traces.jsonl', regenerate_traces: bool = False, checkpoint_dir: str = 'checkpoints/orchestrator_sft', save_every_n_epochs: int = 1, log_dir: str = 'logs/orchestrator_sft', log_every_n_steps: int = 10, use_wandb: bool = False, gradient_checkpointing: bool = True, available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])())

Configuration for orchestrator SFT training.
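The lambda default for available_tools in the signature above exists because dataclasses forbid mutable defaults. A standalone sketch mirroring a few of the fields (the real class lives in the openjarvis package; this mirror is illustrative):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class SFTConfigSketch:
    """Illustrative mirror of a subset of OrchestratorSFTConfig fields."""

    model_name: str = "Qwen/Qwen3-1.7B"
    max_seq_length: int = 4096
    num_epochs: int = 3
    learning_rate: float = 2e-05
    # Mutable defaults need default_factory; the lambda in the real
    # signature serves the same purpose.
    available_tools: List[str] = field(
        default_factory=lambda: [
            "calculator", "think", "code_interpreter", "web_search",
        ]
    )


cfg = SFTConfigSketch(num_epochs=1)
```

Each instance gets its own available_tools list, so mutating one config cannot leak into another.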

OrchestratorSFTDataset

OrchestratorSFTDataset(trace_path: str, tokenizer: Any, max_seq_length: int = 4096)

Dataset for SFT training from generated trace JSONL files.

Source code in src/openjarvis/learning/intelligence/orchestrator/sft_trainer.py
def __init__(
    self,
    trace_path: str,
    tokenizer: Any,
    max_seq_length: int = 4096,
) -> None:
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.traces: List[Dict[str, Any]] = []
    self._load_traces(trace_path)
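The _load_traces call above reads one JSON object per line from the trace file. A minimal sketch of that kind of loader (the real method may validate or filter traces further):

```python
import json
from typing import Any, Dict, List


def load_traces(trace_path: str) -> List[Dict[str, Any]]:
    """Load one JSON object per line, skipping blank or malformed lines.

    Illustrative sketch of a JSONL trace loader; the dataset's actual
    _load_traces may differ in its error handling.
    """
    traces: List[Dict[str, Any]] = []
    with open(trace_path, "r", encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                traces.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # skip corrupt lines rather than abort the run
    return traces
```

Skipping malformed lines keeps a long trace-generation run usable even when a few records were truncated mid-write.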

OrchestratorSFTTrainer

OrchestratorSFTTrainer(config: OrchestratorSFTConfig)

SFT trainer for orchestrator policy.

Trains with standard next-token cross-entropy loss on successful trajectories. torch must be installed to call train().

Source code in src/openjarvis/learning/intelligence/orchestrator/sft_trainer.py
def __init__(self, config: OrchestratorSFTConfig) -> None:
    self.config = config
    self.device = None
    self.global_step = 0

    if HAS_TORCH and torch is not None:
        self.device = _select_torch_device()

    self._init_model()
    self._init_data()
    self._init_optimizer()
Functions

train

train() -> None

Run the SFT training loop.

Source code in src/openjarvis/learning/intelligence/orchestrator/sft_trainer.py
def train(self) -> None:
    """Run the SFT training loop."""
    if not HAS_TORCH:
        raise RuntimeError(
            "PyTorch is required for training. "
            "Install with: pip install torch transformers"
        )

    for epoch in range(self.config.num_epochs):
        self._train_epoch(epoch)

        if (epoch + 1) % self.config.save_every_n_epochs == 0:
            self._save_checkpoint(epoch)
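The loss this loop optimizes is the standard shifted next-token cross-entropy: the logits at position t are scored against the token at position t+1. A pure-Python sketch on toy logits (no torch; the trainer computes the same quantity with tensors):

```python
import math
from typing import List, Sequence


def next_token_cross_entropy(
    logits: List[List[float]], token_ids: Sequence[int]
) -> float:
    """Mean cross-entropy of predicting token t+1 from position t.

    logits[t] scores the next token after position t; the targets are
    the input shifted left by one, as in standard causal-LM SFT.
    """
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):
        row = logits[t]
        target = token_ids[t + 1]
        # log-sum-exp with the max subtracted for numerical stability
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count
```

With uniform logits over a vocabulary of size V, the loss is exactly log(V), a useful sanity check that label shifting is wired up correctly.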