orchestrator
¶
Orchestrator training infrastructure — SFT and GRPO pipelines.
Provides structured-mode training for the OrchestratorAgent with:
- Episode types: Action, Observation, Episode, EpisodeState
- Reward: Multi-objective reward balancing accuracy, cost, energy, latency, power
- Prompt registry: Canonical system prompts for structured mode
- Policy model: HuggingFace LM wrapper for action prediction
- Environment: RL environment using OpenJarvis ToolExecutor
- SFT trainer: Supervised fine-tuning on successful trajectories
- GRPO trainer: Group Relative Policy Optimization
Importing this module triggers registration of orchestrator_sft and
orchestrator_grpo in openjarvis.core.registry.LearningRegistry.
Classes¶
OrchestratorEnvironment
¶
OrchestratorEnvironment(tools: List[BaseTool], max_turns: int = 10)
RL environment that executes tools via OpenJarvis ToolExecutor.
| PARAMETER | DESCRIPTION |
|---|---|
| tools | List of BaseTool instances available to the environment. TYPE: List[BaseTool] |
| max_turns | Maximum number of turns per episode. TYPE: int DEFAULT: 10 |
Source code in src/openjarvis/learning/orchestrator/environment.py
Functions¶
reset
¶
reset(task: str) -> EpisodeState
Reset the environment for a new episode.
Args:
    task: The initial task/question.
Returns: A fresh EpisodeState.
Source code in src/openjarvis/learning/orchestrator/environment.py
step
¶
step(state: EpisodeState, action: OrchestratorAction) -> Tuple[EpisodeState, OrchestratorObservation]
Execute one step: dispatch the tool and observe the result.
Raises: ValueError: If the tool is not available or max turns exceeded.
Source code in src/openjarvis/learning/orchestrator/environment.py
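The reset/step cycle can be sketched with stand-in stubs (ToyAction, ToyState, and ToyEnvironment below are hypothetical, mirroring OrchestratorAction, EpisodeState, and OrchestratorEnvironment; the real environment dispatches through the OpenJarvis ToolExecutor rather than evaluating expressions inline):

```python
from dataclasses import dataclass, field
from typing import List, Tuple

@dataclass
class ToyAction:
    tool_name: str
    tool_input: str
    is_final_answer: bool = False

@dataclass
class ToyState:
    initial_prompt: str
    history: List[Tuple[ToyAction, str]] = field(default_factory=list)
    final_answer: str = ""

class ToyEnvironment:
    """Mimics the reset/step API with a calculator-only toolset."""

    def __init__(self, max_turns: int = 10):
        self.max_turns = max_turns

    def reset(self, task: str) -> ToyState:
        # A fresh state for a new episode.
        return ToyState(initial_prompt=task)

    def step(self, state: ToyState, action: ToyAction) -> Tuple[ToyState, str]:
        if len(state.history) >= self.max_turns:
            raise ValueError("max turns exceeded")
        if action.is_final_answer:
            state.final_answer = action.tool_input
            observation = "episode finished"
        elif action.tool_name == "calculator":
            # Toy stand-in for real tool dispatch; never eval untrusted input.
            observation = str(eval(action.tool_input, {"__builtins__": {}}))
        else:
            raise ValueError(f"tool not available: {action.tool_name}")
        state.history.append((action, observation))
        return state, observation

env = ToyEnvironment()
state = env.reset("What is 2 + 2?")
state, obs = env.step(state, ToyAction("calculator", "2 + 2"))
state, _ = env.step(state, ToyAction("final", "4", is_final_answer=True))
print(obs, state.final_answer)  # → 4 4
```

An unknown tool name or an over-long episode raises ValueError, matching the behaviour documented for step above.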
get_available_tools
¶
OrchestratorGRPOConfig
dataclass
¶
OrchestratorGRPOConfig(model_name: str = 'Qwen/Qwen3-1.7B', max_prompt_length: int = 24000, max_response_length: int = 8768, num_epochs: int = 10, batch_size: int = 16, learning_rate: float = 1e-06, max_grad_norm: float = 1.0, num_samples_per_prompt: int = 8, temperature: float = 1.0, kl_coef: float = 0.0001, clip_ratio: float = 0.2, available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])(), max_turns: int = 10, checkpoint_dir: str = 'checkpoints/orchestrator_grpo', save_every_n_epochs: int = 1, keep_last_n: int = 3, gradient_checkpointing: bool = True, use_8bit_ref: bool = True, use_8bit_optimizer: bool = False)
Configuration for orchestrator GRPO training.
OrchestratorGRPOTrainer
¶
OrchestratorGRPOTrainer(config: OrchestratorGRPOConfig)
GRPO trainer for orchestrator policy.
torch must be installed to call train().
Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
Functions¶
train
¶
Run the GRPO training loop.
Source code in src/openjarvis/learning/orchestrator/grpo_trainer.py
OrchestratorPolicyModel
¶
OrchestratorPolicyModel(model: Any = None, tokenizer: Any = None, max_tokens: int = 256, temperature: float = 0.7)
Wrapper around a causal LM for orchestrator policy prediction.
Input format (prompt)::

    Task: {initial_prompt}
    Available tools: calculator, think, ...
    History:
    Turn 1:
      Thought: ...
      Tool: ...
      Observation: ...

    What should you do next?
    Format your response as:
    THOUGHT: [your reasoning]
    TOOL: [tool_name]
    INPUT: [input for tool]

Output format (from model)::

    THOUGHT: [reasoning]
    TOOL: [tool_name]
    INPUT: [input]

    --- or ---

    FINAL_ANSWER: [answer]
Source code in src/openjarvis/learning/orchestrator/policy_model.py
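A minimal parser for the output format above might look like this (parse_policy_output is a hypothetical re-implementation for illustration; the real parsing lives inside OrchestratorPolicyModel):

```python
import re

def parse_policy_output(text: str):
    """Return (thought, tool_name, tool_input, is_final_answer)."""
    # A FINAL_ANSWER line short-circuits the THOUGHT/TOOL/INPUT path.
    final = re.search(r"FINAL_ANSWER:\s*(.+)", text, re.DOTALL)
    if final:
        return "", "", final.group(1).strip(), True
    # Otherwise pull the three labelled fields.
    thought = re.search(r"THOUGHT:\s*(.*?)(?=\n[A-Z_]+:|\Z)", text, re.DOTALL)
    tool = re.search(r"TOOL:\s*(\S+)", text)
    tool_input = re.search(r"INPUT:\s*(.+)", text, re.DOTALL)
    return (
        thought.group(1).strip() if thought else "",
        tool.group(1) if tool else "",
        tool_input.group(1).strip() if tool_input else "",
        False,
    )

print(parse_policy_output("THOUGHT: add the numbers\nTOOL: calculator\nINPUT: 2 + 2"))
```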
Functions¶
from_pretrained
classmethod
¶
from_pretrained(model_name: str = 'Qwen/Qwen3-1.7B', gradient_checkpointing: bool = False, load_in_8bit: bool = False, device: Optional[str] = None, **kwargs: Any) -> 'OrchestratorPolicyModel'
Load model from a HuggingFace checkpoint.
Raises ImportError if transformers is not installed.
Source code in src/openjarvis/learning/orchestrator/policy_model.py
from_checkpoint
classmethod
¶
Load from a previously saved checkpoint directory.
Source code in src/openjarvis/learning/orchestrator/policy_model.py
predict_action
¶
predict_action(state: EpisodeState, available_tools: List[str]) -> OrchestratorAction
Predict the next action given current state.
Source code in src/openjarvis/learning/orchestrator/policy_model.py
save
¶
AdaptiveRewardWeights
¶
AdaptiveRewardWeights(initial_alpha: float = 0.6, final_alpha: float = 0.3, initial_beta_cost: float = 0.1, final_beta_cost: float = 0.15, initial_beta_energy: float = 0.1, final_beta_energy: float = 0.2, initial_gamma_latency: float = 0.1, final_gamma_latency: float = 0.15, initial_gamma_power: float = 0.1, final_gamma_power: float = 0.2, total_steps: int = 10000)
Adaptive reward weights that shift during training.
Early training focuses on accuracy (higher alpha). Late training optimises efficiency (higher cost/energy/power weights).
Source code in src/openjarvis/learning/orchestrator/reward.py
Functions¶
get_weights
¶
get_weights(current_step: int) -> RewardWeights
Get weights for current_step via linear interpolation.
Source code in src/openjarvis/learning/orchestrator/reward.py
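The per-weight interpolation is a one-liner; a sketch, assuming a simple clamped linear schedule over total_steps:

```python
def interpolate(initial: float, final: float, current_step: int, total_steps: int) -> float:
    # Clamp so weights stay at their final values past the end of the schedule.
    t = min(current_step / total_steps, 1.0)
    return initial + t * (final - initial)

# alpha decays from 0.6 to 0.3 over 10_000 steps:
print(interpolate(0.6, 0.3, 0, 10_000))       # → 0.6
print(interpolate(0.6, 0.3, 5_000, 10_000))   # halfway: 0.45
print(interpolate(0.6, 0.3, 20_000, 10_000))  # past the end: clamped to 0.3
```

get_weights would apply this to each (initial_*, final_*) pair and return the resulting RewardWeights.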
MultiObjectiveReward
¶
MultiObjectiveReward(weights: RewardWeights, normalizers: Normalizers)
Multi-objective reward combining accuracy, cost, energy, latency, power.
Formula::

    reward = alpha * accuracy
             - beta_cost * (cost / cost_scale)
             - beta_energy * (energy / energy_scale)
             - gamma_latency * (latency / latency_scale)
             - gamma_power * (power / power_scale)
Source code in src/openjarvis/learning/orchestrator/reward.py
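The formula transcribes directly into plain Python; the default coefficients and scales below mirror the RewardWeights and Normalizers defaults documented in this module:

```python
def multi_objective_reward(
    accuracy: float, cost: float, energy: float, latency: float, power: float,
    alpha: float = 0.4, beta_cost: float = 0.15, beta_energy: float = 0.15,
    gamma_latency: float = 0.15, gamma_power: float = 0.15,
    cost_scale: float = 0.1, energy_scale: float = 100.0,
    latency_scale: float = 30.0, power_scale: float = 200.0,
) -> float:
    # Accuracy is rewarded; every efficiency metric is a normalized penalty.
    return (
        alpha * accuracy
        - beta_cost * (cost / cost_scale)
        - beta_energy * (energy / energy_scale)
        - gamma_latency * (latency / latency_scale)
        - gamma_power * (power / power_scale)
    )

# A correct answer that used $0.01, 50 J, 3 s, and 100 W peak power:
print(multi_objective_reward(1.0, 0.01, 50.0, 3.0, 100.0))  # ≈ 0.22
```

Because each metric is divided by its scale before weighting, the penalties stay in comparable ranges regardless of the raw units.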
Functions¶
compute
¶
compute(episode: Episode) -> float
Compute scalar reward for an episode.
Source code in src/openjarvis/learning/orchestrator/reward.py
compute_with_breakdown
¶
compute_with_breakdown(episode: Episode) -> Dict[str, float]
Compute reward with detailed per-component breakdown.
Source code in src/openjarvis/learning/orchestrator/reward.py
Normalizers
dataclass
¶
Normalizers(energy_scale: float = 100.0, cost_scale: float = 0.1, latency_scale: float = 30.0, power_scale: float = 200.0)
Normalization constants for reward scaling.
These are typical values used to scale metrics to similar ranges. Tune based on your specific tools and tasks.
RewardWeights
dataclass
¶
RewardWeights(alpha: float = 0.4, beta_cost: float = 0.15, beta_energy: float = 0.15, gamma_latency: float = 0.15, gamma_power: float = 0.15)
Weights for multi-objective reward function.
Each metric has its own coefficient:

- alpha: Accuracy (correctness of answer)
- beta_cost: API/cloud cost in USD
- beta_energy: Energy consumption in joules
- gamma_latency: Response time in seconds
- gamma_power: Peak power usage in watts
OrchestratorSFTConfig
dataclass
¶
OrchestratorSFTConfig(model_name: str = 'Qwen/Qwen3-1.7B', max_seq_length: int = 4096, num_epochs: int = 3, batch_size: int = 8, learning_rate: float = 2e-05, weight_decay: float = 0.01, warmup_ratio: float = 0.1, max_grad_norm: float = 1.0, teacher_engine_key: str = '', teacher_model: str = '', traces_per_query: int = 2, max_attempts_per_trace: int = 3, generation_temperature: float = 0.7, trace_cache_path: str = 'data/orchestrator_sft_traces.jsonl', regenerate_traces: bool = False, checkpoint_dir: str = 'checkpoints/orchestrator_sft', save_every_n_epochs: int = 1, log_dir: str = 'logs/orchestrator_sft', log_every_n_steps: int = 10, use_wandb: bool = False, gradient_checkpointing: bool = True, available_tools: List[str] = (lambda: ['calculator', 'think', 'code_interpreter', 'web_search'])())
Configuration for orchestrator SFT training.
OrchestratorSFTDataset
¶
Dataset for SFT training from generated trace JSONL files.
Source code in src/openjarvis/learning/orchestrator/sft_trainer.py
OrchestratorSFTTrainer
¶
OrchestratorSFTTrainer(config: OrchestratorSFTConfig)
SFT trainer for orchestrator policy.
Performs supervised fine-tuning with a standard next-token cross-entropy
loss on successful trajectories. torch must be installed to call train().
Source code in src/openjarvis/learning/orchestrator/sft_trainer.py
Functions¶
train
¶
Run the SFT training loop.
Source code in src/openjarvis/learning/orchestrator/sft_trainer.py
Episode
dataclass
¶
Episode(task_id: str, initial_prompt: str, steps: List[EpisodeStep] = list(), final_answer: str = '', ground_truth: str = '', correct: bool = False, total_energy_joules: float = 0.0, total_cost_usd: float = 0.0, total_latency_seconds: float = 0.0, total_tokens: int = 0, max_power_watts: float = 0.0, metadata: Dict[str, Any] = dict())
Complete RL episode with aggregate metrics.
Attributes¶
steps
class-attribute
instance-attribute
¶
steps: List[EpisodeStep] = field(default_factory=list)
Sequence of (action, observation) pairs.
final_answer
class-attribute
instance-attribute
¶
Final answer produced by orchestrator.
correct
class-attribute
instance-attribute
¶
Whether final answer matches ground truth.
Functions¶
add_step
¶
add_step(action: OrchestratorAction, observation: OrchestratorObservation) -> None
Add a step to the episode and update aggregate metrics.
Source code in src/openjarvis/learning/orchestrator/types.py
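The aggregate-metric bookkeeping that add_step performs can be sketched as follows (Obs and Totals are hypothetical stand-ins for OrchestratorObservation and Episode):

```python
from dataclasses import dataclass

@dataclass
class Obs:
    latency_seconds: float
    cost_usd: float
    energy_joules: float
    power_watts: float
    tokens: int

class Totals:
    def __init__(self):
        self.total_energy_joules = 0.0
        self.total_cost_usd = 0.0
        self.total_latency_seconds = 0.0
        self.total_tokens = 0
        self.max_power_watts = 0.0

    def add_step(self, obs: Obs) -> None:
        # Energy, cost, latency, and tokens accumulate across steps;
        # power tracks the peak rather than a sum.
        self.total_energy_joules += obs.energy_joules
        self.total_cost_usd += obs.cost_usd
        self.total_latency_seconds += obs.latency_seconds
        self.total_tokens += obs.tokens
        self.max_power_watts = max(self.max_power_watts, obs.power_watts)

t = Totals()
t.add_step(Obs(1.0, 0.01, 20.0, 120.0, 50))
t.add_step(Obs(2.0, 0.02, 30.0, 90.0, 80))
print(t.total_energy_joules, t.max_power_watts, t.total_tokens)  # → 50.0 120.0 130
```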
num_turns
¶
compute_ipj
¶
Compute Intelligence Per Joule (IPJ).
Returns: IPJ score (higher is better). 0.0 if energy is zero or the answer is incorrect.
Source code in src/openjarvis/learning/orchestrator/types.py
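The docstring pins down only the edge cases. One plausible definition consistent with them, purely an assumption here and not necessarily the library's actual formula, divides correctness by total energy:

```python
def compute_ipj(correct: bool, total_energy_joules: float) -> float:
    # ASSUMED formula: correctness per joule. The real implementation may use
    # a richer task score in the numerator; only the zero cases are documented.
    if not correct or total_energy_joules <= 0.0:
        return 0.0
    return 1.0 / total_energy_joules

print(compute_ipj(True, 50.0))   # → 0.02
print(compute_ipj(False, 50.0))  # → 0.0
```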
to_dict
¶
Convert episode to dictionary for serialization.
Source code in src/openjarvis/learning/orchestrator/types.py
EpisodeState
dataclass
¶
EpisodeState(initial_prompt: str, history: List[Tuple[OrchestratorAction, OrchestratorObservation]] = list(), final_answer: Optional[str] = None)
Mutable state during episode execution.
Attributes¶
history
class-attribute
instance-attribute
¶
history: List[Tuple[OrchestratorAction, OrchestratorObservation]] = field(default_factory=list)
History of (action, observation) pairs.
final_answer
class-attribute
instance-attribute
¶
Final answer (set when is_final_answer action is taken).
Functions¶
add_turn
¶
add_turn(action: OrchestratorAction, observation: OrchestratorObservation) -> None
Add a turn to the episode history.
Source code in src/openjarvis/learning/orchestrator/types.py
num_turns
¶
to_episode
¶
to_episode(task_id: str, ground_truth: str, correct: bool) -> Episode
Convert state to Episode for reward computation.
Source code in src/openjarvis/learning/orchestrator/types.py
EpisodeStep
dataclass
¶
EpisodeStep(turn: int, action: OrchestratorAction, observation: OrchestratorObservation)
Single step in an episode.
Attributes¶
OrchestratorAction
dataclass
¶
OrchestratorObservation
dataclass
¶
OrchestratorObservation(content: str, latency_seconds: float = 0.0, cost_usd: float = 0.0, energy_joules: float = 0.0, power_watts: float = 0.0, tokens: int = 0)
PolicyOutput
dataclass
¶
PolicyOutput(thought: str, tool_name: str, tool_input: str, is_final_answer: bool = False, raw_text: str = '', confidence: float = 1.0)
Functions¶
build_system_prompt
¶
build_system_prompt(tool_names: Optional[List[str]] = None, *, tools: Optional[List['BaseTool']] = None) -> str
Build the complete system prompt for the given tools.
Args:
    tool_names: Tool names to include. If None, uses all tools from
        TOOL_DESCRIPTIONS. This path is kept for backward compatibility
        with training pipelines.
    tools: Optional list of BaseTool instances. When provided, rich
        descriptions are auto-generated from ToolSpec, replacing the
        hardcoded TOOL_DESCRIPTIONS lookup. Unknown / MCP tools get full
        descriptions instead of "Tool: {name}".
Returns: Complete system prompt string.
Source code in src/openjarvis/learning/orchestrator/prompt_registry.py
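Prompt assembly via the legacy tool_names path can be sketched like this (the TOOL_DESCRIPTIONS entries and prompt wording below are invented placeholders, not the real registry contents):

```python
# Hypothetical stand-in for the module's TOOL_DESCRIPTIONS mapping.
TOOL_DESCRIPTIONS = {
    "calculator": "Evaluate arithmetic expressions.",
    "think": "Reason privately before acting.",
}

def build_system_prompt(tool_names=None):
    # None means "all registered tools", matching the documented default.
    names = tool_names if tool_names is not None else list(TOOL_DESCRIPTIONS)
    lines = ["You are an orchestrator. Available tools:"]
    for name in names:
        # Unknown tools fall back to the bare "Tool: {name}" listing that the
        # tools= path replaces with ToolSpec-derived descriptions.
        lines.append(f"- {name}: {TOOL_DESCRIPTIONS.get(name, 'Tool: ' + name)}")
    return "\n".join(lines)

print(build_system_prompt(["calculator", "my_mcp_tool"]))
```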
extract_answer
¶
Extract the core answer from a potentially verbose response.
Handles patterns like:

- "The answer is 4"
- "Result: 4.0"
- "4" (unchanged)
- "Therefore, the answer is approximately 4"
Source code in src/openjarvis/learning/orchestrator/types.py
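A regex-based sketch covering the listed patterns (a hypothetical re-implementation, not the library's exact rules):

```python
import re

def extract_answer(response: str) -> str:
    # Ordered from most to least specific; fall through to the raw string.
    patterns = [
        r"answer is approximately\s+(.+)",
        r"answer is\s+(.+)",
        r"Result:\s*(.+)",
    ]
    for pat in patterns:
        m = re.search(pat, response, re.IGNORECASE)
        if m:
            return m.group(1).strip().rstrip(".")
    return response.strip()

print(extract_answer("The answer is 4"))                            # → 4
print(extract_answer("Result: 4.0"))                                # → 4.0
print(extract_answer("4"))                                          # → 4
print(extract_answer("Therefore, the answer is approximately 4."))  # → 4
```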
grade_answer
¶
Grade an answer against expected, with smart matching.
Handles:

- Exact string match (case-insensitive)
- Numeric comparison with tolerance
- Answer extraction from verbose responses

Args:
    predicted: The model's answer.
    expected: Ground truth answer.
    tolerance: Tolerance for numeric comparisons.
Returns: True if the answer is correct.
Source code in src/openjarvis/learning/orchestrator/types.py
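A self-contained sketch of the matching rules (hypothetical: _to_number plays the role of normalize_number, and the comma handling is an assumption of this sketch):

```python
from typing import Optional

def _to_number(s: str) -> Optional[float]:
    # Returns None when the string is not a valid number.
    try:
        return float(s.replace(",", ""))
    except ValueError:
        return None

def grade_answer(predicted: str, expected: str, tolerance: float = 1e-6) -> bool:
    p, e = predicted.strip(), expected.strip()
    # Rule 1: case-insensitive exact string match.
    if p.lower() == e.lower():
        return True
    # Rule 2: numeric comparison within tolerance.
    pn, en = _to_number(p), _to_number(e)
    if pn is not None and en is not None:
        return abs(pn - en) <= tolerance
    return False

print(grade_answer("4.0", "4"))        # → True
print(grade_answer("Paris", "paris"))  # → True
print(grade_answer("5", "4"))          # → False
```

The real grade_answer additionally runs extract_answer on verbose responses before applying these rules.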
normalize_number
¶
Try to parse a string as a number.
Returns None if not a valid number.