lora

LoRATrainer — fine-tune local models via LoRA/QLoRA from trace-derived SFT pairs.

All torch, transformers, and peft imports are guarded so the module can be imported without GPU dependencies. The LoRATrainingConfig dataclass works without any of the optional dependencies; LoRATrainer raises ImportError at construction time when torch is unavailable.
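
The guard follows the usual optional-dependency pattern, sketched here purely for illustration (the actual module may differ in detail):

try:
    import torch  # noqa: F401
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False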

Classes

LoRATrainingConfig dataclass

LoRATrainingConfig(lora_rank: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: List[str] = ['q_proj', 'v_proj'], num_epochs: int = 3, batch_size: int = 4, learning_rate: float = 2e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_grad_norm: float = 1.0, max_seq_length: int = 2048, use_4bit: bool = False, output_dir: str = 'checkpoints/lora', save_every_n_epochs: int = 1, gradient_checkpointing: bool = True)

Configuration for LoRA / QLoRA fine-tuning.
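
A typical QLoRA-style configuration might look like this (values are illustrative; the import path assumes the source location shown below):

from openjarvis.learning.training.lora import LoRATrainingConfig

config = LoRATrainingConfig(
    lora_rank=8,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj"],
    use_4bit=True,            # QLoRA: 4-bit quantized base weights plus LoRA adapters
    output_dir="checkpoints/lora-qlora",
)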

LoRATrainer

LoRATrainer(config: LoRATrainingConfig, *, model_name: str = 'Qwen/Qwen3-0.6B', device: Optional[str] = None)

Fine-tune a local causal LM with LoRA (or QLoRA) adapters.

PARAMETERS

    config : LoRATrainingConfig
        LoRA training configuration.

    model_name : str, default 'Qwen/Qwen3-0.6B'
        HuggingFace model identifier or local path.

    device : Optional[str], default None
        PyTorch device string. None auto-detects (cuda > mps > cpu).

RAISES

    ImportError
        If torch is not installed.

Source code in src/openjarvis/learning/training/lora.py
def __init__(
    self,
    config: LoRATrainingConfig,
    *,
    model_name: str = "Qwen/Qwen3-0.6B",
    device: Optional[str] = None,
) -> None:
    if not HAS_TORCH:
        raise ImportError(
            "torch is required for LoRATrainer. "
            "Install with: pip install torch transformers peft"
        )

    self.config = config
    self.model_name = model_name
    self.device = _select_device(device)
    self.tokenizer: Any = None
    self.model: Any = None
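
For example, a trainer can be constructed like this (assuming torch, transformers, and peft are installed):

from openjarvis.learning.training.lora import LoRATrainer, LoRATrainingConfig

trainer = LoRATrainer(
    LoRATrainingConfig(),
    model_name="Qwen/Qwen3-0.6B",
    device=None,  # auto-detects cuda > mps > cpu
)
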
Functions
prepare_dataset
prepare_dataset(pairs: List[Dict[str, Any]]) -> List[Dict[str, Any]]

Convert SFT pairs to tokenized examples.

Each returned dict contains input_ids, attention_mask, and text (the raw formatted string before tokenization).

PARAMETERS

    pairs : List[Dict[str, Any]]
        List of dicts with at least input and output keys, as produced by
        TrainingDataMiner.extract_sft_pairs.

Source code in src/openjarvis/learning/training/lora.py
def prepare_dataset(
    self, pairs: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Convert SFT pairs to tokenized examples.

    Each returned dict contains ``input_ids``, ``attention_mask``,
    and ``text`` (the raw formatted string before tokenization).

    Parameters
    ----------
    pairs:
        List of dicts with at least ``input`` and ``output`` keys,
        as produced by :class:`TrainingDataMiner.extract_sft_pairs`.
    """
    self._ensure_tokenizer()

    dataset: List[Dict[str, Any]] = []
    for pair in pairs:
        text = self._format_pair(pair)
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
            return_tensors="pt",
        )
        dataset.append({
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "text": text,
        })

    return dataset
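
Continuing from the trainer constructed above, a minimal usage sketch (the pair contents here are hypothetical; real pairs come from TrainingDataMiner.extract_sft_pairs):

pairs = [
    {"input": "Summarize the open tasks.", "output": "There are 3 open tasks: ..."},
]
dataset = trainer.prepare_dataset(pairs)
print(dataset[0]["input_ids"].shape)       # torch.Size([max_seq_length])
print(dataset[0]["attention_mask"].shape)  # torch.Size([max_seq_length])
print(dataset[0]["text"][:60])             # formatted prompt before tokenization
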
train
train(pairs: List[Dict[str, Any]]) -> Dict[str, Any]

Run LoRA fine-tuning on the given SFT pairs.

PARAMETERS

    pairs : List[Dict[str, Any]]
        List of dicts with at least input and output keys.

RETURNS

    dict
        Training summary with keys: status, epochs, total_steps, avg_loss,
        adapter_path, training_samples.

Source code in src/openjarvis/learning/training/lora.py
def train(self, pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Run LoRA fine-tuning on the given SFT pairs.

    Parameters
    ----------
    pairs:
        List of dicts with at least ``input`` and ``output`` keys.

    Returns
    -------
    dict
        Training summary with keys: ``status``, ``epochs``,
        ``total_steps``, ``avg_loss``, ``adapter_path``,
        ``training_samples``.
    """
    if not pairs:
        return {"status": "skipped", "reason": "no training data"}

    dataset = self.prepare_dataset(pairs)
    self._load_model()
    self._apply_lora()

    optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=self.config.learning_rate,
        weight_decay=self.config.weight_decay,
    )

    total_steps = 0
    cumulative_loss = 0.0
    num_loss_steps = 0

    self.model.train()

    for epoch in range(self.config.num_epochs):
        epoch_loss = self._train_epoch(dataset, optimizer)
        steps_in_epoch = max(
            1, (len(dataset) + self.config.batch_size - 1) // self.config.batch_size
        )
        total_steps += steps_in_epoch
        cumulative_loss += epoch_loss * steps_in_epoch
        num_loss_steps += steps_in_epoch

        logger.info(
            "Epoch %d/%d  loss=%.4f",
            epoch + 1,
            self.config.num_epochs,
            epoch_loss,
        )

        if (epoch + 1) % self.config.save_every_n_epochs == 0:
            self._save_adapter(epoch + 1)

    avg_loss = cumulative_loss / num_loss_steps if num_loss_steps else 0.0
    adapter_path = str(Path(self.config.output_dir) / "final")
    self._save_adapter_to(adapter_path)

    return {
        "status": "completed",
        "epochs": self.config.num_epochs,
        "total_steps": total_steps,
        "avg_loss": avg_loss,
        "adapter_path": adapter_path,
        "training_samples": len(pairs),
    }
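
Putting it together, a training run and its summary might be handled like this (illustrative sketch):

summary = trainer.train(pairs)
if summary["status"] == "completed":
    print(f"avg loss {summary['avg_loss']:.4f}, adapter at {summary['adapter_path']}")
else:
    print("Training skipped:", summary.get("reason"))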