policy_model

Policy model wrapper for orchestrator training.

Adapted from IPW's policy.py. Wraps a HuggingFace causal LM (e.g. Qwen3-1.7B) to predict structured actions in the orchestrator environment. All torch/transformers imports are guarded so the module can be imported without GPU dependencies.

Classes

OrchestratorPolicyModel

OrchestratorPolicyModel(model: Any = None, tokenizer: Any = None, max_tokens: int = 256, temperature: float = 0.7)

Wrapper around a causal LM for orchestrator policy prediction.

Input format (prompt)::

Task: {initial_prompt}

Available tools: calculator, think, ...

History:
Turn 1:
  Thought: ...
  Tool: ...
  Observation: ...

What should you do next?
Format your response as:
THOUGHT: [your reasoning]
TOOL: [tool_name]
INPUT: [input for tool]

Output format (from model)::

THOUGHT: [reasoning]
TOOL: [tool_name]
INPUT: [input]
--- or ---
FINAL_ANSWER: [answer]
Source code in src/openjarvis/learning/orchestrator/policy_model.py
def __init__(
    self,
    model: Any = None,
    tokenizer: Any = None,
    max_tokens: int = 256,
    temperature: float = 0.7,
) -> None:
    """Store the wrapped model/tokenizer pair and generation settings.

    Args:
        model: A loaded causal LM, or ``None`` for a model-less wrapper.
        tokenizer: Tokenizer matching *model* (may also be ``None``).
        max_tokens: Generation budget per prediction.
        temperature: Sampling temperature used during generation.
    """
    # Generation settings first, then the (possibly absent) model pair.
    self.max_tokens = max_tokens
    self.temperature = temperature
    self.model = model
    self.tokenizer = tokenizer
Functions
from_pretrained classmethod
from_pretrained(model_name: str = 'Qwen/Qwen3-1.7B', gradient_checkpointing: bool = False, load_in_8bit: bool = False, device: Optional[str] = None, **kwargs: Any) -> 'OrchestratorPolicyModel'

Load model from a HuggingFace checkpoint.

Raises ImportError if transformers is not installed.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
@classmethod
def from_pretrained(
    cls,
    model_name: str = "Qwen/Qwen3-1.7B",
    gradient_checkpointing: bool = False,
    load_in_8bit: bool = False,
    device: Optional[str] = None,
    **kwargs: Any,
) -> "OrchestratorPolicyModel":
    """Load model from a HuggingFace checkpoint.

    Args:
        model_name: Hub id or local path of the causal LM to load.
        gradient_checkpointing: Enable non-reentrant activation
            checkpointing to trade compute for memory during training.
        load_in_8bit: Quantize weights to 8-bit via bitsandbytes when
            available; falls back to bf16 (with a warning) if not.
        device: ``None`` to use the library default placement, ``"auto"``
            for automatic sharding, or an explicit device string
            (e.g. ``"cuda:0"``) to place the whole model on one device.
        **kwargs: Forwarded to the class constructor
            (e.g. ``max_tokens``, ``temperature``).

    Raises:
        ImportError: If ``transformers`` is not installed.
    """
    import torch as _torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # bf16 keeps memory reasonable for small (~1-2B param) models.
    model_kwargs: dict[str, Any] = {"torch_dtype": _torch.bfloat16}

    if load_in_8bit:
        try:
            from transformers import BitsAndBytesConfig

            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_8bit=True
            )
        except ImportError as exc:
            # Fix: the old message said "FP8", but this path configures
            # INT8 quantization via bitsandbytes. Warn (not debug) so the
            # requested-but-skipped quantization is visible to operators.
            logger.warning(
                "8-bit quantization (bitsandbytes) not available, "
                "falling back to bf16: %s",
                exc,
            )

    if device is not None:
        if device == "auto":
            # Let accelerate shard the model across available devices.
            model_kwargs["device_map"] = "auto"
        else:
            # An empty-string key maps the entire model to one device.
            model_kwargs["device_map"] = {"": device}

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)

    if gradient_checkpointing and hasattr(
        model, "gradient_checkpointing_enable"
    ):
        # use_reentrant=False avoids the reentrant-checkpoint
        # deprecation path in newer torch autograd.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

    return cls(model=model, tokenizer=tokenizer, **kwargs)
from_checkpoint classmethod
from_checkpoint(checkpoint_path: str, **kwargs: Any) -> 'OrchestratorPolicyModel'

Load from a previously saved checkpoint directory.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
@classmethod
def from_checkpoint(
    cls, checkpoint_path: str, **kwargs: Any
) -> "OrchestratorPolicyModel":
    """Restore a policy from a previously saved checkpoint directory.

    Saved checkpoints use the standard HuggingFace layout, so loading
    is simply delegated to :meth:`from_pretrained`.
    """
    loader = cls.from_pretrained
    return loader(checkpoint_path, **kwargs)
predict_action
predict_action(state: EpisodeState, available_tools: List[str]) -> OrchestratorAction

Predict the next action given current state.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
def predict_action(
    self,
    state: EpisodeState,
    available_tools: List[str],
) -> OrchestratorAction:
    """Predict the next action given the current episode state.

    Args:
        state: Current episode state, rendered into the model prompt.
        available_tools: Tool names the model may choose between.

    Returns:
        The parsed action: either a tool invocation or a final answer.

    Raises:
        RuntimeError: If no model has been loaded.
    """
    # Fail fast: check the precondition before spending any work on
    # prompt construction (the original built the prompt first).
    if self.model is None:
        raise RuntimeError(
            "Cannot generate actions without a loaded model. "
            "Load with OrchestratorPolicyModel.from_pretrained() first."
        )

    prompt = self._build_prompt(state, available_tools)
    output_text = self._generate(prompt)
    policy_output = self._parse_output(output_text, available_tools)
    return OrchestratorAction(
        thought=policy_output.thought,
        tool_name=policy_output.tool_name,
        tool_input=policy_output.tool_input,
        is_final_answer=policy_output.is_final_answer,
    )
save
save(path: str) -> None

Save model and tokenizer to path.

Source code in src/openjarvis/learning/orchestrator/policy_model.py
def save(self, path: str) -> None:
    """Persist the model and tokenizer to *path* (HF checkpoint layout).

    No-op when no model is loaded. The tokenizer is saved only when
    present — the original unconditionally called
    ``self.tokenizer.save_pretrained`` and raised ``AttributeError``
    for a model loaded without a tokenizer.

    Args:
        path: Directory to write the checkpoint into.
    """
    if self.model is not None:
        self.model.save_pretrained(path)
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(path)