Skip to content

stream_bridge

stream_bridge

Bridge sync agent.run() + EventBus events to an async SSE generator.

Subscribes to EventBus callbacks that push events into an asyncio.Queue, runs agent.run() in a background thread, and yields SSE-formatted strings from the queue for consumption by FastAPI's StreamingResponse.

Classes

AgentStreamBridge

AgentStreamBridge(agent: BaseAgent, bus: EventBus, model: str, request: ChatCompletionRequest)

Bridge between a synchronous agent and an async SSE stream.

Pattern:

1. Subscribe EventBus callbacks that push events into an asyncio.Queue via loop.call_soon_threadsafe().
2. Run agent.run() in a background thread via asyncio.to_thread().
3. The async generator reads from the queue and yields SSE-formatted strings.
4. Unsubscribe from the EventBus in a finally block.

Source code in src/openjarvis/server/stream_bridge.py
def __init__(
    self,
    agent: BaseAgent,
    bus: EventBus,
    model: str,
    request: ChatCompletionRequest,
) -> None:
    """Capture collaborators and initialize per-stream state.

    Args:
        agent: The synchronous agent whose run() will be bridged.
        bus: Event bus the bridge subscribes to for streaming events.
        model: Model name echoed back in every SSE chunk.
        request: The originating chat-completion request.
    """
    # Per-stream OpenAI-style completion id, e.g. "chatcmpl-1a2b3c4d5e6f".
    self._chunk_id = "chatcmpl-" + uuid.uuid4().hex[:12]
    # Hand-off channel: events produced on the agent thread are consumed
    # by the async generator through this queue.
    self._queue: asyncio.Queue = asyncio.Queue()
    # EventType -> callback registry; presumably populated when callbacks
    # are subscribed (e.g. by _subscribe_all()) -- not shown in this view.
    self._callbacks: dict[EventType, object] = {}
    # Collaborators handed in by the server layer.
    self._agent = agent
    self._bus = bus
    self._model = model
    self._request = request
Functions
stream async
stream() -> AsyncGenerator[str, None]

Async generator that yields SSE-formatted strings.

Source code in src/openjarvis/server/stream_bridge.py
async def stream(self) -> AsyncGenerator[str, None]:
    """Async generator that yields SSE-formatted strings.

    Lifecycle:
      1. Subscribe EventBus callbacks, then launch ``agent.run()`` in a
         worker thread.
      2. Relay queued ``Event`` objects as named SSE frames until the
         agent signals completion via the ``_DONE`` sentinel.
      3. Stream the final answer content -- real token streaming through
         ``engine.stream_full()`` when available, otherwise word-by-word
         replay -- then emit a usage chunk and the ``[DONE]`` terminator.

    Callbacks are always unsubscribed in the ``finally`` block, and the
    background agent task is cancelled if the consumer aborts mid-stream.
    """
    self._subscribe_all()

    # Kick off agent.run() in a background thread
    # NOTE(review): asyncio.get_event_loop() inside a coroutine is
    # deprecated since Python 3.10 -- prefer asyncio.get_running_loop().
    loop = asyncio.get_event_loop()
    agent_task = asyncio.ensure_future(asyncio.to_thread(self._run_agent))

    # Wake the queue reader when the agent task completes (success or
    # failure).  call_soon_threadsafe keeps the enqueue safe even if the
    # callback were ever invoked off the event-loop thread.
    def _on_done(fut):
        loop.call_soon_threadsafe(self._queue.put_nowait, _DONE)

    agent_task.add_done_callback(_on_done)

    try:
        # Send initial role chunk (OpenAI-compatible)
        first_chunk = ChatCompletionChunk(
            id=self._chunk_id,
            model=self._model,
            choices=[
                StreamChoice(
                    delta=DeltaMessage(role="assistant"),
                )
            ],
        )
        yield f"data: {first_chunk.model_dump_json()}\n\n"

        # Drain queue until the agent finishes
        while True:
            item = await self._queue.get()

            if item is _DONE:
                break

            if isinstance(item, Event):
                # Only event types with a mapped SSE name are forwarded;
                # unmapped events are silently dropped.
                sse_name = _EVENT_MAP.get(item.event_type)
                if sse_name:
                    yield self._format_named_event(sse_name, item.data)

        # Agent is done -- retrieve result
        try:
            agent_result = agent_task.result()
        except Exception as exc:
            import logging

            logger = logging.getLogger("openjarvis.server")
            logger.error("Agent stream error: %s", exc, exc_info=True)

            # Map common failure shapes (context overflow, generic HTTP
            # 400s) onto user-friendly messages via string matching.
            error_str = str(exc)
            if "context length" in error_str.lower() or (
                "400" in error_str and "too long" in error_str.lower()
            ):
                error_content = (
                    "The input is too long for the model's context window. "
                    "Please try a shorter message."
                )
            elif "400" in error_str:
                error_content = f"The model returned an error: {error_str}"
            else:
                error_content = f"Sorry, an error occurred: {error_str}"
            # Deliver the error as a normal content chunk so OpenAI-style
            # clients render it, then terminate the stream cleanly.
            error_chunk = ChatCompletionChunk(
                id=self._chunk_id,
                model=self._model,
                choices=[
                    StreamChoice(
                        delta=DeltaMessage(content=error_content),
                        finish_reason="stop",
                    )
                ],
            )
            yield f"data: {error_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Emit tool results metadata if any
        tool_results_data = []
        for tr in agent_result.tool_results:
            tool_results_data.append(
                {
                    "tool_name": tr.tool_name,
                    "success": tr.success,
                    "output": tr.content,
                    "latency_ms": tr.latency_seconds * 1000,
                }
            )

        if tool_results_data:
            yield self._format_named_event(
                "tool_results",
                {"results": tool_results_data},
            )

        # Stream content using real LLM token streaming via
        # engine.stream_full() when the engine is available.
        content = agent_result.content or ""
        engine = getattr(self._agent, "_engine", None)
        used_real_streaming = False

        if engine is not None and hasattr(engine, "stream_full") and content:
            # Re-stream using the engine for real token delivery.
            # Build the same messages the agent used for its final turn.
            # NOTE(review): this issues a second LLM call, so the streamed
            # text may differ from agent_result.content -- confirm this
            # divergence is acceptable.
            try:
                from openjarvis.core.types import Message as MsgType
                from openjarvis.core.types import Role as RoleType

                replay_messages = []
                for m in self._request.messages:
                    # Unknown role strings degrade to USER instead of
                    # raising on the RoleType(...) conversion.
                    role = (
                        RoleType(m.role)
                        if m.role in {r.value for r in RoleType}
                        else RoleType.USER
                    )
                    replay_messages.append(
                        MsgType(
                            role=role,
                            content=m.content or "",
                            name=m.name,
                            tool_call_id=m.tool_call_id,
                        )
                    )

                async for sc in engine.stream_full(
                    replay_messages,
                    model=self._model,
                ):
                    if sc.content:
                        chunk = ChatCompletionChunk(
                            id=self._chunk_id,
                            model=self._model,
                            choices=[
                                StreamChoice(
                                    delta=DeltaMessage(content=sc.content),
                                )
                            ],
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"
                used_real_streaming = True
            except Exception as stream_exc:
                # Best-effort: a failed re-stream falls through to the
                # word-replay path below.
                # NOTE(review): if chunks were already yielded before the
                # failure, the fallback repeats the full text -- confirm
                # whether partial emission can occur here.
                import logging as _logging

                _logger = _logging.getLogger("openjarvis.server")
                _logger.warning(
                    "Real streaming failed, falling back to word replay: %s",
                    stream_exc,
                )

        # Fallback: word-by-word replay if real streaming was not used
        if not used_real_streaming and content:
            words = content.split(" ")
            for i, word in enumerate(words):
                # Re-insert the separating space that split() removed.
                token = word if i == 0 else " " + word
                chunk = ChatCompletionChunk(
                    id=self._chunk_id,
                    model=self._model,
                    choices=[
                        StreamChoice(
                            delta=DeltaMessage(content=token),
                        )
                    ],
                )
                yield f"data: {chunk.model_dump_json()}\n\n"
                # Small pause so clients render a typing effect.
                await asyncio.sleep(0.012)

        # Final chunk: finish_reason + usage
        prompt_tokens = agent_result.metadata.get("prompt_tokens", 0)
        completion_tokens = agent_result.metadata.get(
            "completion_tokens",
            0,
        )
        total_tokens = agent_result.metadata.get("total_tokens", 0)
        if total_tokens == 0:
            # Fallback: estimate from request messages (incl. system) + content
            # using a rough chars/4 heuristic for the completion side.
            completion_tokens = max(len(content) // 4, 1)
            prompt_tokens = _estimate_prompt_tokens(self._request.messages)
            total_tokens = prompt_tokens + completion_tokens

        final_chunk = ChatCompletionChunk(
            id=self._chunk_id,
            model=self._model,
            choices=[
                StreamChoice(
                    delta=DeltaMessage(),
                    finish_reason="stop",
                )
            ],
        )
        # Usage is spliced into the serialized dict -- presumably because
        # ChatCompletionChunk has no usage field; verify against the model.
        final_data = json.loads(final_chunk.model_dump_json())
        final_data["usage"] = UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        ).model_dump()
        yield f"data: {json.dumps(final_data)}\n\n"

        yield "data: [DONE]\n\n"

    except Exception:
        # On error, cancel the agent task if still running
        if not agent_task.done():
            agent_task.cancel()
        raise
    finally:
        self._unsubscribe_all()

Functions

create_agent_stream async

create_agent_stream(agent: BaseAgent, bus: EventBus, model: str, request: ChatCompletionRequest) -> StreamingResponse

Create an AgentStreamBridge and return a FastAPI StreamingResponse.

Source code in src/openjarvis/server/stream_bridge.py
async def create_agent_stream(
    agent: BaseAgent,
    bus: EventBus,
    model: str,
    request: ChatCompletionRequest,
) -> StreamingResponse:
    """Create an AgentStreamBridge and return a FastAPI StreamingResponse.

    Args:
        agent: The synchronous agent to bridge.
        bus: Event bus the bridge listens on.
        model: Model name echoed in each SSE chunk.
        request: The originating chat-completion request.

    Returns:
        A StreamingResponse wrapping the bridge's SSE generator.
    """
    # Standard SSE response headers: disable caching, keep the
    # connection open for the stream's lifetime.
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    bridge = AgentStreamBridge(agent, bus, model, request)
    return StreamingResponse(
        bridge.stream(),
        media_type="text/event-stream",
        headers=sse_headers,
    )