Skip to content

Index

training

Training data extraction and fine-tuning pipelines for trace-driven learning.

Classes

TrainingDataMiner

TrainingDataMiner(trace_store: Any, *, min_quality: float = 0.7, min_samples_per_class: int = 1)

Extract supervised training pairs from stored traces.

PARAMETER DESCRIPTION
trace_store

Any object with a list_traces(limit=...) method returning List[Trace] (typically a `TraceStore`).

TYPE: Any

min_quality

Minimum feedback score for a trace to be included.

TYPE: float DEFAULT: 0.7

min_samples_per_class

Minimum number of samples a query class must have to appear in routing/agent-config results.

TYPE: int DEFAULT: 1

Source code in src/openjarvis/learning/training/data.py
def __init__(
    self,
    trace_store: Any,
    *,
    min_quality: float = 0.7,
    min_samples_per_class: int = 1,
) -> None:
    """Record the trace source and mining thresholds for later extraction."""
    # Traces whose feedback score falls below this are excluded by the miners.
    self._min_quality = min_quality
    # Query classes with fewer qualifying samples are dropped from
    # routing / agent-config results.
    self._min_samples_per_class = min_samples_per_class
    # Any object exposing list_traces(limit=...) works; typically a TraceStore.
    self._store = trace_store
Functions
extract_sft_pairs
extract_sft_pairs(*, agent: str | None = None) -> List[Dict[str, Any]]

Return SFT training pairs from high-quality traces.

Each entry is a dict with keys: input, output, query_class, model, feedback.

Duplicate (input, output) pairs are collapsed; the first occurrence is kept.

Source code in src/openjarvis/learning/training/data.py
def extract_sft_pairs(self, *, agent: str | None = None) -> List[Dict[str, Any]]:
    """Return SFT training pairs from high-quality traces.

    Each entry is a dict with keys ``input``, ``output``,
    ``query_class``, ``model`` and ``feedback``.  Duplicate
    ``(input, output)`` pairs are collapsed, keeping only the first
    occurrence.
    """
    pairs: List[Dict[str, Any]] = []
    emitted: set[tuple[str, str]] = set()

    for trace in self._quality_traces(agent=agent):
        pair_key = (trace.query, trace.result)
        if pair_key in emitted:
            # Already produced an identical (input, output) pair.
            continue
        emitted.add(pair_key)
        pairs.append(
            {
                "input": trace.query,
                "output": trace.result,
                "query_class": classify_query(trace.query),
                "model": trace.model,
                "feedback": trace.feedback,
            }
        )

    return pairs
extract_routing_pairs
extract_routing_pairs(*, agent: str | None = None) -> Dict[str, Dict[str, Any]]

Return per-query-class routing recommendations.

Returns a dict mapping query class to:

  • best_model — model with highest average feedback for the class.
  • avg_feedback — average feedback across all models for the class.
  • sample_count — total number of qualifying traces in the class.
  • all_models — dict of {model: {"avg_feedback": float, "count": int}}.
Source code in src/openjarvis/learning/training/data.py
def extract_routing_pairs(
    self, *, agent: str | None = None
) -> Dict[str, Dict[str, Any]]:
    """Return per-query-class routing recommendations.

    Returns a dict mapping query class to:

    * ``best_model`` — model with highest average feedback for the class.
    * ``avg_feedback`` — average feedback across all models for the class.
    * ``sample_count`` — total number of qualifying traces in the class.
    * ``all_models`` — dict of ``{model: {"avg_feedback": float, "count": int}}``.
    """
    # Bucket feedback scores first by query class, then by model.
    scores_by_class: Dict[str, Dict[str, List[float]]] = defaultdict(
        lambda: defaultdict(list)
    )
    for trace in self._quality_traces(agent=agent):
        bucket = scores_by_class[classify_query(trace.query)]
        bucket[trace.model].append(trace.feedback)  # type: ignore[arg-type]

    recommendations: Dict[str, Dict[str, Any]] = {}
    for query_class, per_model in scores_by_class.items():
        sample_count = sum(len(v) for v in per_model.values())
        if sample_count < self._min_samples_per_class:
            # Not enough evidence for this class; skip it entirely.
            continue

        all_models: Dict[str, Dict[str, Any]] = {
            model: {
                "avg_feedback": sum(scores) / len(scores),
                "count": len(scores),
            }
            for model, scores in per_model.items()
        }

        # Scan for the highest average; strict '>' keeps the first-seen
        # model on ties, matching insertion order.
        best_model = ""
        best_avg = -1.0
        for model, stats in all_models.items():
            if stats["avg_feedback"] > best_avg:
                best_avg = stats["avg_feedback"]
                best_model = model

        flat_scores = [s for scores in per_model.values() for s in scores]
        overall_avg = sum(flat_scores) / len(flat_scores) if flat_scores else 0.0

        recommendations[query_class] = {
            "best_model": best_model,
            "avg_feedback": overall_avg,
            "sample_count": sample_count,
            "all_models": all_models,
        }

    return recommendations
extract_agent_config_pairs
extract_agent_config_pairs(*, agent: str | None = None) -> Dict[str, Dict[str, Any]]

Return per-query-class agent and tool recommendations.

Returns a dict mapping query class to:

  • best_agent — agent with the highest average feedback.
  • best_tools — most frequently used tools by the best agent.
  • avg_feedback — average feedback across all agents for the class.
  • sample_count — total number of qualifying traces in the class.
Source code in src/openjarvis/learning/training/data.py
def extract_agent_config_pairs(
    self, *, agent: str | None = None
) -> Dict[str, Dict[str, Any]]:
    """Return per-query-class agent and tool recommendations.

    Returns a dict mapping query class to:

    * ``best_agent`` — agent with the highest average feedback.
    * ``best_tools`` — most frequently used tools by the best agent.
    * ``avg_feedback`` — average feedback across all agents for the class.
    * ``sample_count`` — total number of qualifying traces in the class.
    """
    traces = self._quality_traces(agent=agent)

    # Accumulate per (query_class, agent) feedback and tools
    class_agent_scores: Dict[str, Dict[str, List[float]]] = defaultdict(
        lambda: defaultdict(list)
    )
    class_agent_tools: Dict[str, Dict[str, List[List[str]]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for t in traces:
        qc = classify_query(t.query)
        class_agent_scores[qc][t.agent].append(t.feedback)  # type: ignore[arg-type]
        tools = self._tools_from_trace(t)
        class_agent_tools[qc][t.agent].append(tools)

    result: Dict[str, Dict[str, Any]] = {}
    for qc, agent_scores in class_agent_scores.items():
        total_count = sum(len(scores) for scores in agent_scores.values())
        if total_count < self._min_samples_per_class:
            continue

        # Loop variable renamed to ``candidate`` so it does not shadow the
        # ``agent`` keyword parameter of this method.
        best_agent = ""
        best_avg = -1.0
        for candidate, scores in agent_scores.items():
            avg = sum(scores) / len(scores)
            if avg > best_avg:
                best_avg = avg
                best_agent = candidate

        # Collect tool frequency for the best agent in this class.
        tool_freq: Dict[str, int] = defaultdict(int)
        for tool_list in class_agent_tools[qc].get(best_agent, []):
            for tool in tool_list:
                tool_freq[tool] += 1

        # Most frequent first; sorted() is stable, so ties keep the
        # first-seen order.  A lambda key avoids the dict.get type: ignore.
        best_tools = sorted(
            tool_freq, key=lambda name: tool_freq[name], reverse=True
        )

        total_scores = [s for scores in agent_scores.values() for s in scores]
        overall_avg = sum(total_scores) / len(total_scores) if total_scores else 0.0

        result[qc] = {
            "best_agent": best_agent,
            "best_tools": best_tools,
            "avg_feedback": overall_avg,
            "sample_count": total_count,
        }

    return result

LoRATrainer

LoRATrainer(config: LoRATrainingConfig, *, model_name: str = 'Qwen/Qwen3-0.6B', device: Optional[str] = None)

Fine-tune a local causal LM with LoRA (or QLoRA) adapters.

PARAMETER DESCRIPTION
config

LoRA training configuration.

TYPE: LoRATrainingConfig

model_name

HuggingFace model identifier or local path.

TYPE: str DEFAULT: 'Qwen/Qwen3-0.6B'

device

PyTorch device string. None auto-detects (cuda > mps > cpu).

TYPE: Optional[str] DEFAULT: None

RAISES DESCRIPTION
ImportError

If torch is not installed.

Source code in src/openjarvis/learning/training/lora.py
def __init__(
    self,
    config: LoRATrainingConfig,
    *,
    model_name: str = "Qwen/Qwen3-0.6B",
    device: Optional[str] = None,
) -> None:
    """Validate that torch is importable and record trainer settings."""
    # Fail fast: nothing in this trainer works without torch installed.
    if not HAS_TORCH:
        raise ImportError(
            "torch is required for LoRATrainer. "
            "Install with: pip install torch transformers peft"
        )

    self.config = config
    self.model_name = model_name
    # None means auto-detect; _select_device resolves the final device.
    self.device = _select_device(device)
    # Both are populated lazily elsewhere (e.g. prepare_dataset calls
    # _ensure_tokenizer; train calls _load_model).
    self.tokenizer: Any = None
    self.model: Any = None
Functions
prepare_dataset
prepare_dataset(pairs: List[Dict[str, Any]]) -> List[Dict[str, Any]]

Convert SFT pairs to tokenized examples.

Each returned dict contains input_ids, attention_mask, and text (the raw formatted string before tokenization).

PARAMETER DESCRIPTION
pairs

List of dicts with at least input and output keys, as produced by `TrainingDataMiner.extract_sft_pairs`.

TYPE: List[Dict[str, Any]]

Source code in src/openjarvis/learning/training/lora.py
def prepare_dataset(
    self, pairs: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Convert SFT pairs to tokenized examples.

    Each returned dict contains ``input_ids``, ``attention_mask``,
    and ``text`` (the raw formatted string before tokenization).

    Parameters
    ----------
    pairs:
        List of dicts with at least ``input`` and ``output`` keys,
        as produced by :class:`TrainingDataMiner.extract_sft_pairs`.
    """
    self._ensure_tokenizer()

    examples: List[Dict[str, Any]] = []
    for pair in pairs:
        formatted = self._format_pair(pair)
        # Pad/truncate every example to a fixed length so batches stack.
        enc = self.tokenizer(
            formatted,
            truncation=True,
            max_length=self.config.max_seq_length,
            padding="max_length",
            return_tensors="pt",
        )
        example = {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "text": formatted,
        }
        examples.append(example)

    return examples
train
train(pairs: List[Dict[str, Any]]) -> Dict[str, Any]

Run LoRA fine-tuning on the given SFT pairs.

PARAMETER DESCRIPTION
pairs

List of dicts with at least input and output keys.

TYPE: List[Dict[str, Any]]

RETURNS DESCRIPTION
dict

Training summary with keys: status, epochs, total_steps, avg_loss, adapter_path, training_samples.

Source code in src/openjarvis/learning/training/lora.py
def train(self, pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Run LoRA fine-tuning on the given SFT pairs.

    Parameters
    ----------
    pairs:
        List of dicts with at least ``input`` and ``output`` keys.

    Returns
    -------
    dict
        Training summary with keys: ``status``, ``epochs``,
        ``total_steps``, ``avg_loss``, ``adapter_path``,
        ``training_samples``.  When ``pairs`` is empty, returns
        ``{"status": "skipped", "reason": "no training data"}`` instead.
    """
    if not pairs:
        return {"status": "skipped", "reason": "no training data"}

    dataset = self.prepare_dataset(pairs)
    self._load_model()
    self._apply_lora()

    optimizer = torch.optim.AdamW(
        self.model.parameters(),
        lr=self.config.learning_rate,
        weight_decay=self.config.weight_decay,
    )

    # ``total_steps`` doubles as the loss-step counter; the original kept a
    # separate ``num_loss_steps`` that was always identical to it.
    total_steps = 0
    cumulative_loss = 0.0

    self.model.train()

    for epoch in range(self.config.num_epochs):
        epoch_loss = self._train_epoch(dataset, optimizer)
        # Ceil division; at least one step even for tiny datasets.
        steps_in_epoch = max(
            1, (len(dataset) + self.config.batch_size - 1) // self.config.batch_size
        )
        total_steps += steps_in_epoch
        # Weight each epoch's loss by its step count so avg_loss is a
        # per-step average (assumes _train_epoch returns the epoch's mean
        # loss — consistent with this weighting; confirm in lora.py).
        cumulative_loss += epoch_loss * steps_in_epoch

        logger.info(
            "Epoch %d/%d  loss=%.4f",
            epoch + 1,
            self.config.num_epochs,
            epoch_loss,
        )

        # Periodic adapter checkpointing.
        if (epoch + 1) % self.config.save_every_n_epochs == 0:
            self._save_adapter(epoch + 1)

    avg_loss = cumulative_loss / total_steps if total_steps else 0.0
    adapter_path = str(Path(self.config.output_dir) / "final")
    self._save_adapter_to(adapter_path)

    return {
        "status": "completed",
        "epochs": self.config.num_epochs,
        "total_steps": total_steps,
        "avg_loss": avg_loss,
        "adapter_path": adapter_path,
        "training_samples": len(pairs),
    }

LoRATrainingConfig dataclass

LoRATrainingConfig(lora_rank: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05, target_modules: List[str] = (lambda: ['q_proj', 'v_proj'])(), num_epochs: int = 3, batch_size: int = 4, learning_rate: float = 2e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_grad_norm: float = 1.0, max_seq_length: int = 2048, use_4bit: bool = False, output_dir: str = 'checkpoints/lora', save_every_n_epochs: int = 1, gradient_checkpointing: bool = True)

Configuration for LoRA / QLoRA fine-tuning.