orchestrator

Orchestrator training infrastructure — SFT and GRPO pipelines.

Provides structured-mode training for the OrchestratorAgent with:

  • Episode types: Action, Observation, Episode, EpisodeState
  • Reward: Multi-objective reward balancing accuracy, cost, energy, latency, power
  • Prompt registry: Canonical system prompts for structured mode
  • Policy model: HuggingFace LM wrapper for action prediction
  • Environment: RL environment using OpenJarvis ToolExecutor
  • SFT trainer: Supervised fine-tuning on successful trajectories
  • GRPO trainer: Group Relative Policy Optimization

Importing this module registers ``orchestrator_sft`` and ``orchestrator_grpo`` in :class:`~openjarvis.core.registry.LearningRegistry`.

Classes

OrchestratorEnvironment

OrchestratorEnvironment(tools: List[BaseTool], max_turns: int = 10)

RL environment that executes tools via OpenJarvis ToolExecutor.

PARAMETER DESCRIPTION
tools

List of :class:`BaseTool` instances available to the agent.

TYPE: List[BaseTool]

max_turns

Maximum number of turns per episode.

TYPE: int DEFAULT: 10

Source code in src/openjarvis/learning/orchestrator/environment.py
def __init__(
    self,
    tools: List[BaseTool],
    max_turns: int = 10,
) -> None:
    self._tools = tools
    self._executor = ToolExecutor(tools)
    self._max_turns = max_turns
Functions
reset
reset(task: str) -> EpisodeState

Reset the environment for a new episode.

Args:
    task: The initial task/question.

Returns:
    A fresh :class:`EpisodeState`.

Source code in src/openjarvis/learning/orchestrator/environment.py
def reset(self, task: str) -> EpisodeState:
    """Reset the environment for a new episode.

    Args:
        task: The initial task/question.

    Returns:
        A fresh :class:`EpisodeState`.
    """
    return EpisodeState(initial_prompt=task)
step
step(state: EpisodeState, action: OrchestratorAction) -> Tuple[EpisodeState, OrchestratorObservation]

Execute one step: dispatch the tool and observe the result.

Raises:
    ValueError: If the tool is not available or max turns exceeded.

Source code in src/openjarvis/learning/orchestrator/environment.py
def step(
    self,
    state: EpisodeState,
    action: OrchestratorAction,
) -> Tuple[EpisodeState, OrchestratorObservation]:
    """Execute one step: dispatch the tool and observe the result.

    Raises:
        ValueError: If the tool is not available or max turns exceeded.
    """
    available = self.get_available_tools()

    if action.tool_name not in available:
        raise ValueError(
            f"Tool '{action.tool_name}' not available. "
            f"Available: {available}"
        )

    if state.num_turns() >= self._max_turns:
        raise ValueError(
            f"Max turns ({self._max_turns}) exceeded"
        )

    # Execute tool via ToolExecutor
    tool_call = ToolCall(
        id=f"orch_{state.num_turns()}",
        name=action.tool_name,
        arguments=action.tool_input
        if action.tool_input.startswith("{")
        else f'{{"expression": {repr(action.tool_input)}}}',
    )

    t0 = time.time()
    result = self._executor.execute(tool_call)
    latency = time.time() - t0

    observation = OrchestratorObservation(
        content=result.content,
        latency_seconds=latency,
        cost_usd=result.cost_usd,
        energy_joules=0.0,
        power_watts=0.0,
        tokens=result.usage.get("total_tokens", 0),
    )

    state.add_turn(action, observation)
    return state, observation
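The step above wraps raw tool input into a JSON argument object before dispatch. A minimal sketch of that wrapping rule, using the same `"expression"` key as the source (note the source builds the string with `repr()`, which yields Python-style quoting; `json.dumps` is used here so the result is strictly valid JSON):

```python
import json

# Sketch of the argument-wrapping rule in step() above: raw text that does not
# already look like a JSON object is wrapped under an "expression" key.
def wrap_tool_input(tool_input: str) -> str:
    if tool_input.startswith("{"):
        return tool_input  # assume it is already a JSON object string
    return json.dumps({"expression": tool_input})

assert wrap_tool_input('{"a": 1}') == '{"a": 1}'          # passed through unchanged
assert json.loads(wrap_tool_input("2 + 2"))["expression"] == "2 + 2"
```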
is_done
is_done(state: EpisodeState) -> bool

Check if the episode is complete.

Source code in src/openjarvis/learning/orchestrator/environment.py
def is_done(self, state: EpisodeState) -> bool:
    """Check if the episode is complete."""
    if state.final_answer is not None:
        return True
    if state.num_turns() >= self._max_turns:
        return True
    return False
get_available_tools
get_available_tools() -> List[str]

Return names of all available tools.

Source code in src/openjarvis/learning/orchestrator/environment.py
def get_available_tools(self) -> List[str]:
    """Return names of all available tools."""
    return [t.spec.name for t in self._tools]

OrchestratorGRPOConfig dataclass

OrchestratorGRPOConfig(model_name: str = 'Qwen/Qwen3-1.7B', max_prompt_length: int = 24000, max_response_length: int = 8768, num_epochs: int = 10, batch_size: int = 16, learning_rate: float = 1e-06, max_grad_norm: float = 1.0, num_samples_per_prompt: int = 8, temperature: float = 1.0, kl_coef: float = 0.0001, clip_ratio: float = 0.2, available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])(), max_turns: int = 10, checkpoint_dir: str = 'checkpoints/orchestrator_grpo', save_every_n_epochs: int = 1, keep_last_n: int = 3, gradient_checkpointing: bool = True, use_8bit_ref: bool = True, use_8bit_optimizer: bool = False)

Configuration for orchestrator GRPO training.

OrchestratorGRPOTrainer

OrchestratorGRPOTrainer(config: OrchestratorGRPOConfig)

GRPO trainer for orchestrator policy.

``torch`` must be installed to call :meth:`train`.

Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
def __init__(self, config: OrchestratorGRPOConfig) -> None:
    self.config = config
    self.device = None
    self.global_step = 0

    if HAS_TORCH and torch is not None:
        self.device = _select_torch_device()

    self._init_model()
    self._init_optimizer()
Functions
train
train() -> None

Run the GRPO training loop.

Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
def train(self) -> None:
    """Run the GRPO training loop."""
    if not HAS_TORCH:
        raise RuntimeError(
            "PyTorch is required for training. "
            "Install with: pip install torch transformers"
        )

    for epoch in range(self.config.num_epochs):
        self._train_epoch(epoch)

        if (epoch + 1) % self.config.save_every_n_epochs == 0:
            self._save_checkpoint(epoch)

OrchestratorPolicyModel

OrchestratorPolicyModel(model: Any = None, tokenizer: Any = None, max_tokens: int = 256, temperature: float = 0.7)

Wrapper around a causal LM for orchestrator policy prediction.

Input format (prompt)::

Task: {initial_prompt}

Available tools: calculator, think, ...

History:
Turn 1:
  Thought: ...
  Tool: ...
  Observation: ...

What should you do next?
Format your response as:
THOUGHT: [your reasoning]
TOOL: [tool_name]
INPUT: [input for tool]

Output format (from model)::

THOUGHT: [reasoning]
TOOL: [tool_name]
INPUT: [input]
--- or ---
FINAL_ANSWER: [answer]
Source code in src/openjarvis/learning/orchestrator/policy_model.py
def __init__(
    self,
    model: Any = None,
    tokenizer: Any = None,
    max_tokens: int = 256,
    temperature: float = 0.7,
) -> None:
    self.model = model
    self.tokenizer = tokenizer
    self.max_tokens = max_tokens
    self.temperature = temperature
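The output format above can be parsed with a few line-anchored matches. This is an illustrative sketch only; the real `_parse_output` in `policy_model.py` is not shown on this page and may handle more edge cases:

```python
import re
from typing import Tuple

# Illustrative parser for the THOUGHT/TOOL/INPUT | FINAL_ANSWER format above.
def parse_policy_output(text: str) -> Tuple[str, str, str, bool]:
    """Return (thought, tool_name, tool_input, is_final_answer)."""
    final = re.search(r"FINAL_ANSWER:\s*(.*)", text, re.DOTALL)
    if final:
        return "", "", final.group(1).strip(), True
    thought = re.search(r"THOUGHT:\s*(.*?)(?=\nTOOL:|\Z)", text, re.DOTALL)
    tool = re.search(r"TOOL:\s*(\S+)", text)
    tool_input = re.search(r"INPUT:\s*(.*)", text, re.DOTALL)
    return (
        thought.group(1).strip() if thought else "",
        tool.group(1).strip() if tool else "",
        tool_input.group(1).strip() if tool_input else "",
        False,
    )

text = "THOUGHT: need arithmetic\nTOOL: calculator\nINPUT: 2 + 2"
assert parse_policy_output(text) == ("need arithmetic", "calculator", "2 + 2", False)
assert parse_policy_output("FINAL_ANSWER: 4")[3] is True
```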
Functions
from_pretrained classmethod
from_pretrained(model_name: str = 'Qwen/Qwen3-1.7B', gradient_checkpointing: bool = False, load_in_8bit: bool = False, device: Optional[str] = None, **kwargs: Any) -> 'OrchestratorPolicyModel'

Load model from a HuggingFace checkpoint.

Raises ``ImportError`` if ``transformers`` is not installed.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
@classmethod
def from_pretrained(
    cls,
    model_name: str = "Qwen/Qwen3-1.7B",
    gradient_checkpointing: bool = False,
    load_in_8bit: bool = False,
    device: Optional[str] = None,
    **kwargs: Any,
) -> "OrchestratorPolicyModel":
    """Load model from a HuggingFace checkpoint.

    Raises ``ImportError`` if ``transformers`` is not installed.
    """
    import torch as _torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model_kwargs: dict[str, Any] = {"torch_dtype": _torch.bfloat16}

    if load_in_8bit:
        try:
            from transformers import BitsAndBytesConfig

            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True
            )
        except ImportError as exc:
            logger.debug("8-bit quantization not available, falling back to BF16: %s", exc)

    if device is not None:
        if device == "auto":
            model_kwargs["device_map"] = "auto"
        else:
            model_kwargs["device_map"] = {"": device}

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    if gradient_checkpointing and hasattr(
        model, "gradient_checkpointing_enable"
    ):
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

    return cls(model=model, tokenizer=tokenizer, **kwargs)
from_checkpoint classmethod
from_checkpoint(checkpoint_path: str, **kwargs: Any) -> 'OrchestratorPolicyModel'

Load from a previously saved checkpoint directory.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
@classmethod
def from_checkpoint(
    cls, checkpoint_path: str, **kwargs: Any
) -> "OrchestratorPolicyModel":
    """Load from a previously saved checkpoint directory."""
    return cls.from_pretrained(checkpoint_path, **kwargs)
predict_action
predict_action(state: EpisodeState, available_tools: List[str]) -> OrchestratorAction

Predict the next action given current state.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
def predict_action(
    self,
    state: EpisodeState,
    available_tools: List[str],
) -> OrchestratorAction:
    """Predict the next action given current state."""
    prompt = self._build_prompt(state, available_tools)

    if self.model is None:
        raise RuntimeError(
            "Cannot generate actions without a loaded model. "
            "Load with OrchestratorPolicyModel.from_pretrained() first."
        )

    output_text = self._generate(prompt)
    policy_output = self._parse_output(output_text, available_tools)
    return OrchestratorAction(
        thought=policy_output.thought,
        tool_name=policy_output.tool_name,
        tool_input=policy_output.tool_input,
        is_final_answer=policy_output.is_final_answer,
    )
save
save(path: str) -> None

Save model and tokenizer to path.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
def save(self, path: str) -> None:
    """Save model and tokenizer to *path*."""
    if self.model is not None:
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)

AdaptiveRewardWeights

AdaptiveRewardWeights(initial_alpha: float = 0.6, final_alpha: float = 0.3, initial_beta_cost: float = 0.1, final_beta_cost: float = 0.15, initial_beta_energy: float = 0.1, final_beta_energy: float = 0.2, initial_gamma_latency: float = 0.1, final_gamma_latency: float = 0.15, initial_gamma_power: float = 0.1, final_gamma_power: float = 0.2, total_steps: int = 10000)

Adaptive reward weights that shift during training.

Early training focuses on accuracy (higher alpha). Late training optimises efficiency (higher cost/energy/power weights).

Source code in src/openjarvis/learning/orchestrator/reward.py
def __init__(
    self,
    initial_alpha: float = 0.6,
    final_alpha: float = 0.3,
    initial_beta_cost: float = 0.1,
    final_beta_cost: float = 0.15,
    initial_beta_energy: float = 0.1,
    final_beta_energy: float = 0.2,
    initial_gamma_latency: float = 0.1,
    final_gamma_latency: float = 0.15,
    initial_gamma_power: float = 0.1,
    final_gamma_power: float = 0.2,
    total_steps: int = 10000,
) -> None:
    self.initial_alpha = initial_alpha
    self.final_alpha = final_alpha
    self.initial_beta_cost = initial_beta_cost
    self.final_beta_cost = final_beta_cost
    self.initial_beta_energy = initial_beta_energy
    self.final_beta_energy = final_beta_energy
    self.initial_gamma_latency = initial_gamma_latency
    self.final_gamma_latency = final_gamma_latency
    self.initial_gamma_power = initial_gamma_power
    self.final_gamma_power = final_gamma_power
    self.total_steps = total_steps
Functions
get_weights
get_weights(current_step: int) -> RewardWeights

Get weights for current_step via linear interpolation.

Source code in src/openjarvis/learning/orchestrator/reward.py
def get_weights(self, current_step: int) -> RewardWeights:
    """Get weights for *current_step* via linear interpolation."""
    progress = min(1.0, current_step / self.total_steps)

    alpha = self.initial_alpha + (self.final_alpha - self.initial_alpha) * progress
    beta_cost = (
        self.initial_beta_cost
        + (self.final_beta_cost - self.initial_beta_cost) * progress
    )
    beta_energy = (
        self.initial_beta_energy
        + (self.final_beta_energy - self.initial_beta_energy) * progress
    )
    gamma_latency = (
        self.initial_gamma_latency
        + (self.final_gamma_latency - self.initial_gamma_latency) * progress
    )
    gamma_power = (
        self.initial_gamma_power
        + (self.final_gamma_power - self.initial_gamma_power) * progress
    )

    # Normalize to sum to 1.0
    total = alpha + beta_cost + beta_energy + gamma_latency + gamma_power
    return RewardWeights(
        alpha=alpha / total,
        beta_cost=beta_cost / total,
        beta_energy=beta_energy / total,
        gamma_latency=gamma_latency / total,
        gamma_power=gamma_power / total,
    )
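With the class defaults above, the schedule can be sanity-checked at both endpoints. This is a self-contained reimplementation of the interpolation for illustration, not the library API:

```python
# Linear interpolation between initial and final weights, then normalization,
# mirroring get_weights() above with the AdaptiveRewardWeights defaults.
def interp(initial: float, final: float, progress: float) -> float:
    return initial + (final - initial) * progress

def weights_at(step: int, total_steps: int = 10000) -> dict:
    p = min(1.0, step / total_steps)
    raw = {
        "alpha": interp(0.6, 0.3, p),
        "beta_cost": interp(0.1, 0.15, p),
        "beta_energy": interp(0.1, 0.2, p),
        "gamma_latency": interp(0.1, 0.15, p),
        "gamma_power": interp(0.1, 0.2, p),
    }
    total = sum(raw.values())
    return {k: v / total for k, v in raw.items()}

w0, w_end = weights_at(0), weights_at(10000)
assert abs(sum(w0.values()) - 1.0) < 1e-9       # always normalized to 1.0
assert w0["alpha"] > w_end["alpha"]             # accuracy weight decays
assert w0["beta_energy"] < w_end["beta_energy"] # efficiency weights grow
```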

MultiObjectiveReward

MultiObjectiveReward(weights: RewardWeights, normalizers: Normalizers)

Multi-objective reward combining accuracy, cost, energy, latency, power.

Formula::

reward = alpha * accuracy
         - beta_cost  * (cost   / cost_scale)
         - beta_energy * (energy / energy_scale)
         - gamma_latency * (latency / latency_scale)
         - gamma_power   * (power   / power_scale)
Source code in src/openjarvis/learning/orchestrator/reward.py
def __init__(
    self,
    weights: RewardWeights,
    normalizers: Normalizers,
) -> None:
    self.weights = weights
    self.normalizers = normalizers
Functions
compute
compute(episode: Episode) -> float

Compute scalar reward for an episode.

Source code in src/openjarvis/learning/orchestrator/reward.py
def compute(self, episode: Episode) -> float:
    """Compute scalar reward for an episode."""
    accuracy_reward = 1.0 if episode.correct else 0.0

    cost_penalty = episode.total_cost_usd / self.normalizers.cost_scale
    energy_penalty = (
        episode.total_energy_joules / self.normalizers.energy_scale
    )
    latency_penalty = (
        episode.total_latency_seconds / self.normalizers.latency_scale
    )
    power_penalty = episode.max_power_watts / self.normalizers.power_scale

    return (
        self.weights.alpha * accuracy_reward
        - self.weights.beta_cost * cost_penalty
        - self.weights.beta_energy * energy_penalty
        - self.weights.gamma_latency * latency_penalty
        - self.weights.gamma_power * power_penalty
    )
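A worked numeric instance of the formula above, plugging in the RewardWeights and Normalizers defaults documented below (alpha=0.4, all beta/gamma=0.15; cost_scale=0.1, energy_scale=100.0, latency_scale=30.0, power_scale=200.0):

```python
# Worked example of MultiObjectiveReward.compute() with the documented defaults.
alpha, beta_cost, beta_energy, gamma_latency, gamma_power = 0.4, 0.15, 0.15, 0.15, 0.15
cost_scale, energy_scale, latency_scale, power_scale = 0.1, 100.0, 30.0, 200.0

# A correct episode that cost $0.01, used 50 J, took 6 s, and peaked at 100 W.
accuracy = 1.0
reward = (
    alpha * accuracy                              # 0.400
    - beta_cost * (0.01 / cost_scale)             # -0.015
    - beta_energy * (50.0 / energy_scale)         # -0.075
    - gamma_latency * (6.0 / latency_scale)       # -0.030
    - gamma_power * (100.0 / power_scale)         # -0.075
)
assert abs(reward - 0.205) < 1e-9
```

An incorrect episode with the same telemetry would score -0.195: the penalties apply regardless of correctness.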
compute_with_breakdown
compute_with_breakdown(episode: Episode) -> Dict[str, float]

Compute reward with detailed per-component breakdown.

Source code in src/openjarvis/learning/orchestrator/reward.py
def compute_with_breakdown(self, episode: Episode) -> Dict[str, float]:
    """Compute reward with detailed per-component breakdown."""
    accuracy_reward = 1.0 if episode.correct else 0.0

    cost_penalty = episode.total_cost_usd / self.normalizers.cost_scale
    energy_penalty = (
        episode.total_energy_joules / self.normalizers.energy_scale
    )
    latency_penalty = (
        episode.total_latency_seconds / self.normalizers.latency_scale
    )
    power_penalty = episode.max_power_watts / self.normalizers.power_scale

    accuracy_component = self.weights.alpha * accuracy_reward
    cost_component = -self.weights.beta_cost * cost_penalty
    energy_component = -self.weights.beta_energy * energy_penalty
    latency_component = -self.weights.gamma_latency * latency_penalty
    power_component = -self.weights.gamma_power * power_penalty

    total_reward = (
        accuracy_component
        + cost_component
        + energy_component
        + latency_component
        + power_component
    )

    ipj = episode.compute_ipj()

    return {
        "total_reward": total_reward,
        "accuracy_reward": accuracy_reward,
        "accuracy_component": accuracy_component,
        "cost_penalty": cost_penalty,
        "cost_component": cost_component,
        "energy_penalty": energy_penalty,
        "energy_component": energy_component,
        "latency_penalty": latency_penalty,
        "latency_component": latency_component,
        "power_penalty": power_penalty,
        "power_component": power_component,
        "ipj": ipj,
        "total_energy_joules": episode.total_energy_joules,
        "total_cost_usd": episode.total_cost_usd,
        "total_latency_seconds": episode.total_latency_seconds,
    }
compute_batch
compute_batch(episodes: List[Episode]) -> List[float]

Compute rewards for a batch of episodes.

Source code in src/openjarvis/learning/orchestrator/reward.py
def compute_batch(self, episodes: List[Episode]) -> List[float]:
    """Compute rewards for a batch of episodes."""
    return [self.compute(ep) for ep in episodes]

Normalizers dataclass

Normalizers(energy_scale: float = 100.0, cost_scale: float = 0.1, latency_scale: float = 30.0, power_scale: float = 200.0)

Normalization constants for reward scaling.

These are typical values used to scale metrics to similar ranges. Tune based on your specific tools and tasks.

RewardWeights dataclass

RewardWeights(alpha: float = 0.4, beta_cost: float = 0.15, beta_energy: float = 0.15, gamma_latency: float = 0.15, gamma_power: float = 0.15)

Weights for multi-objective reward function.

Each metric has its own coefficient:

  • alpha: Accuracy (correctness of answer)
  • beta_cost: API/cloud cost in USD
  • beta_energy: Energy consumption in joules
  • gamma_latency: Response time in seconds
  • gamma_power: Peak power usage in watts

OrchestratorSFTConfig dataclass

OrchestratorSFTConfig(model_name: str = 'Qwen/Qwen3-1.7B', max_seq_length: int = 4096, num_epochs: int = 3, batch_size: int = 8, learning_rate: float = 2e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_grad_norm: float = 1.0, teacher_engine_key: str = '', teacher_model: str = '', traces_per_query: int = 2, max_attempts_per_trace: int = 3, generation_temperature: float = 0.7, trace_cache_path: str = 'data/orchestrator_sft_traces.jsonl', regenerate_traces: bool = False, checkpoint_dir: str = 'checkpoints/orchestrator_sft', save_every_n_epochs: int = 1, log_dir: str = 'logs/orchestrator_sft', log_every_n_steps: int = 10, use_wandb: bool = False, gradient_checkpointing: bool = True, available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])())

Configuration for orchestrator SFT training.

OrchestratorSFTDataset

OrchestratorSFTDataset(trace_path: str, tokenizer: Any, max_seq_length: int = 4096)

Dataset for SFT training from generated trace JSONL files.

Source code in src/openjarvis/learning/orchestrator/sft_trainer.py
def __init__(
    self,
    trace_path: str,
    tokenizer: Any,
    max_seq_length: int = 4096,
) -> None:
    self.tokenizer = tokenizer
    self.max_seq_length = max_seq_length
    self.traces: List[Dict[str, Any]] = []
    self._load_traces(trace_path)

OrchestratorSFTTrainer

OrchestratorSFTTrainer(config: OrchestratorSFTConfig)

SFT trainer for orchestrator policy.

Trains with standard next-token cross-entropy loss on successful trajectories. ``torch`` must be installed to call :meth:`train`.

Source code in src/openjarvis/learning/orchestrator/sft_trainer.py
def __init__(self, config: OrchestratorSFTConfig) -> None:
    self.config = config
    self.device = None
    self.global_step = 0

    if HAS_TORCH and torch is not None:
        self.device = _select_torch_device()

    self._init_model()
    self._init_data()
    self._init_optimizer()
Functions
train
train() -> None

Run the SFT training loop.

Source code in src/openjarvis/learning/orchestrator/sft_trainer.py
def train(self) -> None:
    """Run the SFT training loop."""
    if not HAS_TORCH:
        raise RuntimeError(
            "PyTorch is required for training. "
            "Install with: pip install torch transformers"
        )

    for epoch in range(self.config.num_epochs):
        self._train_epoch(epoch)

        if (epoch + 1) % self.config.save_every_n_epochs == 0:
            self._save_checkpoint(epoch)

Episode dataclass

Episode(task_id: str, initial_prompt: str, steps: List[EpisodeStep] = list(), final_answer: str = '', ground_truth: str = '', correct: bool = False, total_energy_joules: float = 0.0, total_cost_usd: float = 0.0, total_latency_seconds: float = 0.0, total_tokens: int = 0, max_power_watts: float = 0.0, metadata: Dict[str, Any] = dict())

Complete RL episode with aggregate metrics.

Attributes
task_id instance-attribute
task_id: str

Unique task identifier.

initial_prompt instance-attribute
initial_prompt: str

Initial question/task.

steps class-attribute instance-attribute
steps: List[EpisodeStep] = field(default_factory=list)

Sequence of (action, observation) pairs.

final_answer class-attribute instance-attribute
final_answer: str = ''

Final answer produced by orchestrator.

ground_truth class-attribute instance-attribute
ground_truth: str = ''

Ground truth answer.

correct class-attribute instance-attribute
correct: bool = False

Whether final answer matches ground truth.

Functions
add_step
add_step(action: OrchestratorAction, observation: OrchestratorObservation) -> None

Add a step to the episode and update aggregate metrics.

Source code in src/openjarvis/learning/orchestrator/types.py
def add_step(
    self, action: OrchestratorAction, observation: OrchestratorObservation
) -> None:
    """Add a step to the episode and update aggregate metrics."""
    step = EpisodeStep(
        turn=len(self.steps),
        action=action,
        observation=observation,
    )
    self.steps.append(step)

    self.total_energy_joules += observation.energy_joules
    self.total_latency_seconds += observation.latency_seconds
    self.total_cost_usd += observation.cost_usd
    self.total_tokens += observation.tokens
    self.max_power_watts = max(self.max_power_watts, observation.power_watts)

    if action.is_final_answer:
        self.final_answer = observation.content
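The aggregation in add_step() sums energy, latency, cost, and tokens, but takes a running maximum for power (peak power, not total). An illustrative check with plain dicts standing in for OrchestratorObservation:

```python
# Two observations standing in for a two-step episode.
observations = [
    {"energy_joules": 12.0, "latency_seconds": 1.5, "cost_usd": 0.002,
     "tokens": 150, "power_watts": 45.0},
    {"energy_joules": 30.0, "latency_seconds": 4.0, "cost_usd": 0.010,
     "tokens": 800, "power_watts": 120.0},
]

total_energy = sum(o["energy_joules"] for o in observations)
total_latency = sum(o["latency_seconds"] for o in observations)
total_tokens = sum(o["tokens"] for o in observations)
max_power = max(o["power_watts"] for o in observations)

assert total_energy == 42.0
assert total_tokens == 950
assert max_power == 120.0  # running max, not a sum: the episode's peak power
```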
num_turns
num_turns() -> int

Return number of turns in episode.

Source code in src/openjarvis/learning/orchestrator/types.py
def num_turns(self) -> int:
    """Return number of turns in episode."""
    return len(self.steps)
compute_ipj
compute_ipj() -> float

Compute Intelligence Per Joule (IPJ).

Returns:
    IPJ score (higher is better). 0.0 if energy is zero or the answer is incorrect.

Source code in src/openjarvis/learning/orchestrator/types.py
def compute_ipj(self) -> float:
    """Compute Intelligence Per Joule (IPJ).

    Returns:
        IPJ score (higher is better).  0.0 if energy is zero or
        the answer is incorrect.
    """
    if self.total_energy_joules <= 0:
        return 0.0
    accuracy_score = 1.0 if self.correct else 0.0
    return accuracy_score / self.total_energy_joules
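Numerically, IPJ is just accuracy divided by total energy: a correct answer that consumed 50 J scores 0.02, while an incorrect one scores 0.0 regardless of energy spent. A standalone restatement of the logic above:

```python
# Intelligence Per Joule, mirroring compute_ipj() above.
def ipj(correct: bool, total_energy_joules: float) -> float:
    if total_energy_joules <= 0:
        return 0.0
    return (1.0 if correct else 0.0) / total_energy_joules

assert ipj(True, 50.0) == 0.02
assert ipj(False, 50.0) == 0.0
assert ipj(True, 0.0) == 0.0  # guard against division by zero
```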
to_dict
to_dict() -> Dict[str, Any]

Convert episode to dictionary for serialization.

Source code in src/openjarvis/learning/orchestrator/types.py
def to_dict(self) -> Dict[str, Any]:
    """Convert episode to dictionary for serialization."""
    return {
        "task_id": self.task_id,
        "initial_prompt": self.initial_prompt,
        "steps": [
            {
                "turn": step.turn,
                "thought": step.action.thought,
                "tool": step.action.tool_name,
                "tool_input": step.action.tool_input,
                "observation": step.observation.content[:200],
                "energy_joules": step.observation.energy_joules,
                "latency_seconds": step.observation.latency_seconds,
                "cost_usd": step.observation.cost_usd,
            }
            for step in self.steps
        ],
        "final_answer": self.final_answer,
        "ground_truth": self.ground_truth,
        "correct": self.correct,
        "total_energy_joules": self.total_energy_joules,
        "total_latency_seconds": self.total_latency_seconds,
        "total_cost_usd": self.total_cost_usd,
        "total_tokens": self.total_tokens,
        "num_turns": self.num_turns(),
        "ipj": self.compute_ipj(),
    }

EpisodeState dataclass

EpisodeState(initial_prompt: str, history: List[Tuple[OrchestratorAction, OrchestratorObservation]] = list(), final_answer: Optional[str] = None)

Mutable state during episode execution.

Attributes
initial_prompt instance-attribute
initial_prompt: str

Initial task/question.

history class-attribute instance-attribute
history: List[Tuple[OrchestratorAction, OrchestratorObservation]] = field(default_factory=list)

History of (action, observation) pairs.

final_answer class-attribute instance-attribute
final_answer: Optional[str] = None

Final answer (set when is_final_answer action is taken).

Functions
add_turn
add_turn(action: OrchestratorAction, observation: OrchestratorObservation) -> None

Add a turn to the episode history.

Source code in src/openjarvis/learning/orchestrator/types.py
def add_turn(
    self,
    action: OrchestratorAction,
    observation: OrchestratorObservation,
) -> None:
    """Add a turn to the episode history."""
    self.history.append((action, observation))
    if action.is_final_answer:
        self.final_answer = observation.content
num_turns
num_turns() -> int

Return number of turns so far.

Source code in src/openjarvis/learning/orchestrator/types.py
def num_turns(self) -> int:
    """Return number of turns so far."""
    return len(self.history)
to_episode
to_episode(task_id: str, ground_truth: str, correct: bool) -> Episode

Convert state to Episode for reward computation.

Source code in src/openjarvis/learning/orchestrator/types.py
def to_episode(
    self, task_id: str, ground_truth: str, correct: bool
) -> Episode:
    """Convert state to Episode for reward computation."""
    episode = Episode(
        task_id=task_id,
        initial_prompt=self.initial_prompt,
        ground_truth=ground_truth,
        final_answer=self.final_answer or "",
        correct=correct,
    )
    for action, observation in self.history:
        episode.add_step(action, observation)
    return episode

EpisodeStep dataclass

EpisodeStep(turn: int, action: OrchestratorAction, observation: OrchestratorObservation)

Single step in an episode.

Attributes
turn instance-attribute
turn: int

Step number (0-indexed).

action instance-attribute

Action taken.

observation instance-attribute

Result of action.

OrchestratorAction dataclass

OrchestratorAction(thought: str, tool_name: str, tool_input: str, is_final_answer: bool = False)

Orchestrator action: thought + tool selection + tool input.

Attributes
thought instance-attribute
thought: str

Reasoning about what to do next.

tool_name instance-attribute
tool_name: str

Selected tool name (e.g., 'calculator', 'think').

tool_input instance-attribute
tool_input: str

Input/prompt to send to the tool.

is_final_answer class-attribute instance-attribute
is_final_answer: bool = False

Whether this action provides the final answer.

OrchestratorObservation dataclass

OrchestratorObservation(content: str, latency_seconds: float = 0.0, cost_usd: float = 0.0, energy_joules: float = 0.0, power_watts: float = 0.0, tokens: int = 0)

Result from executing an action, with flat telemetry fields.

Attributes
content instance-attribute
content: str

Tool response content.

latency_seconds class-attribute instance-attribute
latency_seconds: float = 0.0

Latency in seconds.

cost_usd class-attribute instance-attribute
cost_usd: float = 0.0

Cost in USD.

energy_joules class-attribute instance-attribute
energy_joules: float = 0.0

Energy consumed in joules.

power_watts class-attribute instance-attribute
power_watts: float = 0.0

Power usage in watts.

tokens class-attribute instance-attribute
tokens: int = 0

Tokens consumed.

PolicyOutput dataclass

PolicyOutput(thought: str, tool_name: str, tool_input: str, is_final_answer: bool = False, raw_text: str = '', confidence: float = 1.0)

Output from policy model prediction.

Attributes
thought instance-attribute
thought: str

Reasoning about what to do.

tool_name instance-attribute
tool_name: str

Selected tool.

tool_input instance-attribute
tool_input: str

Input for the tool.

is_final_answer class-attribute instance-attribute
is_final_answer: bool = False

Whether this provides the final answer.

raw_text class-attribute instance-attribute
raw_text: str = ''

Raw model output.

confidence class-attribute instance-attribute
confidence: float = 1.0

Confidence score (if available).

Functions

build_system_prompt

build_system_prompt(tool_names: Optional[List[str]] = None, *, tools: Optional[List['BaseTool']] = None) -> str

Build the complete system prompt for the given tools.

Args:
    tool_names: Tool names to include. If None, uses all tools from :data:`TOOL_DESCRIPTIONS`. This path is kept for backward compatibility with training pipelines.
    tools: Optional list of BaseTool instances. When provided, rich descriptions are auto-generated from ToolSpec, replacing the hardcoded :data:`TOOL_DESCRIPTIONS` lookup. Unknown/MCP tools get full descriptions instead of "Tool: {name}".

Returns:
    Complete system prompt string.

Source code in src/openjarvis/learning/orchestrator/prompt_registry.py
def build_system_prompt(
    tool_names: Optional[List[str]] = None,
    *,
    tools: Optional[List["BaseTool"]] = None,
) -> str:
    """Build the complete system prompt for the given tools.

    Args:
        tool_names: Tool names to include.  If ``None``, uses all
            tools from :data:`TOOL_DESCRIPTIONS`.  This path is kept for
            backward compatibility with training pipelines.
        tools: Optional list of ``BaseTool`` instances.  When provided,
            rich descriptions are auto-generated from ``ToolSpec``,
            replacing the hardcoded :data:`TOOL_DESCRIPTIONS` lookup.
            Unknown / MCP tools get full descriptions instead of
            ``"Tool: {name}"``.

    Returns:
        Complete system prompt string.
    """
    # When BaseTool instances are provided, generate descriptions from spec
    if tools is not None:
        from openjarvis.tools._stubs import build_tool_descriptions

        desc_text = build_tool_descriptions(tools, include_cost=True)

        # Auto-generate tool selection guide by grouping tools by category
        by_cat: Dict[str, List[str]] = {}
        for t in tools:
            cat = t.spec.category or "llm"
            by_cat.setdefault(cat, []).append(t.spec.name)

        guide: list[str] = ["Choose tools based on task type:\n"]
        for cat, names in by_cat.items():
            label = _CAT_LABELS.get(cat, cat.upper())
            guide.append(f"{label}:")
            for n in names:
                guide.append(f"- {n}")
            guide.append("")

        return SYSTEM_PROMPT_TEMPLATE.format(
            tools_description=desc_text,
            tool_selection_guide="\n".join(guide),
        )

    # Backward-compat: tool_names-only path (used by training pipelines)
    if tool_names is None:
        tool_names = list(TOOL_DESCRIPTIONS)

    # Tool descriptions
    desc_lines: list[str] = []
    for name in tool_names:
        if name in TOOL_DESCRIPTIONS:
            desc = TOOL_DESCRIPTIONS[name]["description"]
        else:
            desc = f"Tool: {name}"
        desc_lines.append(f"- {name}: {desc}")

    # Group tools by category
    by_cat_names: Dict[str, List[str]] = {}
    for name in tool_names:
        cat = (
            TOOL_DESCRIPTIONS[name]["category"]
            if name in TOOL_DESCRIPTIONS
            else "llm"
        )
        by_cat_names.setdefault(cat, []).append(name)

    guide = [
        "Choose tools based on task type:\n",
    ]

    # Math
    math_lines: list[str] = []
    if "calculator" in tool_names:
        math_lines.append(
            "- Simple arithmetic/algebra -> calculator (instant, accurate)"
        )
    if "code_interpreter" in tool_names:
        math_lines.append(
            "- Numerical algorithms -> code_interpreter (programmable)"
        )
    if math_lines:
        guide.append("MATH PROBLEMS:")
        guide.extend(math_lines)
        guide.append("")

    # Coding
    code_lines: list[str] = []
    if "code_interpreter" in tool_names:
        code_lines.append(
            "- Algorithm implementation/execution -> code_interpreter"
        )
    if code_lines:
        guide.append("CODING TASKS:")
        guide.extend(code_lines)
        guide.append("")

    # Reasoning
    reasoning_lines: list[str] = []
    if "think" in tool_names:
        reasoning_lines.append(
            "- Step-by-step analysis -> think (organize thoughts first)"
        )
    llm_tools = by_cat_names.get("llm", [])
    if llm_tools:
        reasoning_lines.append(
            f"- Complex reasoning -> {', '.join(llm_tools)}"
        )
    if reasoning_lines:
        guide.append("REASONING/LOGIC:")
        guide.extend(reasoning_lines)
        guide.append("")

    # General Q&A
    general_lines: list[str] = []
    if "web_search" in tool_names:
        general_lines.append("- Current events/recent info -> web_search")
    memory_tools = by_cat_names.get("memory", [])
    if memory_tools:
        general_lines.append(
            f"- Stored knowledge -> {', '.join(memory_tools)}"
        )
    if general_lines:
        guide.append("GENERAL Q&A / FACTUAL:")
        guide.extend(general_lines)
        guide.append("")

    return SYSTEM_PROMPT_TEMPLATE.format(
        tools_description="\n".join(desc_lines),
        tool_selection_guide="\n".join(guide),
    )
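The category-grouping step in the `tool_names`-only path can be illustrated in isolation. The `TOOL_DESCRIPTIONS` subset below is hypothetical (the real table lives in this module); tools missing from the table fall back to the `"llm"` category, as in the code above:

```python
from typing import Dict, List

# Hypothetical subset of TOOL_DESCRIPTIONS, for illustration only.
TOOL_DESCRIPTIONS = {
    "calculator": {"description": "Evaluate arithmetic.", "category": "math"},
    "think": {"description": "Scratchpad reasoning.", "category": "reasoning"},
}


def group_by_category(tool_names: List[str]) -> Dict[str, List[str]]:
    """Group tool names by category, defaulting unknown tools to 'llm'."""
    by_cat: Dict[str, List[str]] = {}
    for name in tool_names:
        cat = (
            TOOL_DESCRIPTIONS[name]["category"]
            if name in TOOL_DESCRIPTIONS
            else "llm"
        )
        by_cat.setdefault(cat, []).append(name)
    return by_cat


grouped = group_by_category(["calculator", "think", "mcp_custom"])
# {"math": ["calculator"], "reasoning": ["think"], "llm": ["mcp_custom"]}
```

The guide sections ("MATH PROBLEMS:", "REASONING/LOGIC:", …) are then emitted only for categories that actually have tools, which keeps the prompt short when the tool set is small.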

extract_answer

extract_answer(text: str) -> str

Extract the core answer from a potentially verbose response.

Handles patterns like:

- "The answer is 4"
- "Result: 4.0"
- "4" (unchanged)
- "Therefore, the answer is approximately 4"

Source code in src/openjarvis/learning/orchestrator/types.py
def extract_answer(text: str) -> str:
    """Extract the core answer from a potentially verbose response.

    Handles patterns like:
    - "The answer is 4"
    - "Result: 4.0"
    - "4" (unchanged)
    - "Therefore, the answer is approximately 4"
    """
    text = text.strip()

    patterns = [
        r"(?:the\s+)?answer\s+is[:\s]+(.+?)(?:\.|$)",
        r"result[:\s]+(.+?)(?:\.|$)",
        r"=\s*(.+?)(?:\.|$)",
        r"therefore[,\s]+(?:the\s+)?(?:answer\s+is\s+)?(.+?)(?:\.|$)",
    ]

    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).strip()

    return text
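To make the extraction behavior concrete, here is a short self-contained usage sketch (the function body is copied from `types.py` above, with the docstring elided):

```python
import re


def extract_answer(text: str) -> str:
    # Copied from types.py above; docstring elided.
    text = text.strip()
    patterns = [
        r"(?:the\s+)?answer\s+is[:\s]+(.+?)(?:\.|$)",
        r"result[:\s]+(.+?)(?:\.|$)",
        r"=\s*(.+?)(?:\.|$)",
        r"therefore[,\s]+(?:the\s+)?(?:answer\s+is\s+)?(.+?)(?:\.|$)",
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).strip()
    return text


print(extract_answer("The answer is 42."))  # -> 42
print(extract_answer("x = 7"))              # -> 7
print(extract_answer("42"))                 # -> 42 (unchanged)
```

One subtlety: the patterns match lazily up to the first period, so `extract_answer("Result: 4.0")` returns `"4"`, not `"4.0"`. This is harmless in practice because `grade_answer` falls back to numeric comparison, where `4` and `4.0` are equal.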

grade_answer

grade_answer(predicted: str, expected: str, tolerance: float = 1e-06) -> bool

Grade an answer against expected, with smart matching.

Handles:

- Exact string match (case-insensitive)
- Numeric comparison with tolerance
- Answer extraction from verbose responses

Args:

- predicted: The model's answer.
- expected: Ground truth answer.
- tolerance: Tolerance for numeric comparisons.

Returns: True if the answer is correct.

Source code in src/openjarvis/learning/orchestrator/types.py
def grade_answer(
    predicted: str, expected: str, tolerance: float = 1e-6
) -> bool:
    """Grade an answer against expected, with smart matching.

    Handles:
    - Exact string match (case-insensitive)
    - Numeric comparison with tolerance
    - Answer extraction from verbose responses

    Args:
        predicted: The model's answer.
        expected: Ground truth answer.
        tolerance: Tolerance for numeric comparisons.

    Returns:
        True if answer is correct.
    """
    predicted = predicted.strip()
    expected = expected.strip()

    # Exact match (case-insensitive)
    if predicted.lower() == expected.lower():
        return True

    # Try extracting core answer
    pred_extracted = extract_answer(predicted)
    exp_extracted = extract_answer(expected)

    if pred_extracted.lower() == exp_extracted.lower():
        return True

    # Try numeric comparison
    pred_num = normalize_number(pred_extracted)
    exp_num = normalize_number(exp_extracted)

    if pred_num is not None and exp_num is not None:
        if exp_num == 0:
            return abs(pred_num) < tolerance
        return abs(pred_num - exp_num) / abs(exp_num) < tolerance

    return False

normalize_number

normalize_number(s: str) -> Optional[float]

Try to parse a string as a number.

Returns None if not a valid number.

Source code in src/openjarvis/learning/orchestrator/types.py
def normalize_number(s: str) -> Optional[float]:
    """Try to parse a string as a number.

    Returns None if not a valid number.
    """
    s = s.strip().lower()
    s = re.sub(r"[,\s]", "", s)  # Remove commas and spaces
    s = re.sub(r"\.0+$", "", s)  # Remove trailing .0

    try:
        return float(s)
    except ValueError:
        return None
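Putting the three helpers together, here is an end-to-end grading sketch (function bodies copied from `types.py` above, docstrings elided):

```python
import re
from typing import Optional


def extract_answer(text: str) -> str:
    text = text.strip()
    patterns = [
        r"(?:the\s+)?answer\s+is[:\s]+(.+?)(?:\.|$)",
        r"result[:\s]+(.+?)(?:\.|$)",
        r"=\s*(.+?)(?:\.|$)",
        r"therefore[,\s]+(?:the\s+)?(?:answer\s+is\s+)?(.+?)(?:\.|$)",
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.IGNORECASE)
        if match:
            return match.group(1).strip()
    return text


def normalize_number(s: str) -> Optional[float]:
    s = s.strip().lower()
    s = re.sub(r"[,\s]", "", s)   # Remove commas and spaces
    s = re.sub(r"\.0+$", "", s)   # Remove trailing .0
    try:
        return float(s)
    except ValueError:
        return None


def grade_answer(predicted: str, expected: str, tolerance: float = 1e-6) -> bool:
    predicted = predicted.strip()
    expected = expected.strip()
    if predicted.lower() == expected.lower():
        return True
    pred_extracted = extract_answer(predicted)
    exp_extracted = extract_answer(expected)
    if pred_extracted.lower() == exp_extracted.lower():
        return True
    pred_num = normalize_number(pred_extracted)
    exp_num = normalize_number(exp_extracted)
    if pred_num is not None and exp_num is not None:
        if exp_num == 0:
            return abs(pred_num) < tolerance
        return abs(pred_num - exp_num) / abs(exp_num) < tolerance
    return False


print(grade_answer("The answer is 1,000", "1000"))  # -> True (numeric match)
print(grade_answer("Result: 4.0", "4"))             # -> True (extraction + match)
print(grade_answer("5", "4"))                       # -> False
```

Note that the numeric branch uses a *relative* tolerance (scaled by `abs(exp_num)`), with a special absolute-tolerance case when the expected value is exactly zero to avoid dividing by zero.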