async def stream(self) -> AsyncGenerator[str, None]:
    """Async generator that yields SSE-formatted strings.

    Runs the blocking agent in a worker thread, relays intermediate agent
    events from the internal queue as named SSE events, then streams the
    final answer as OpenAI-compatible ``chat.completion.chunk`` frames,
    ending with a usage-bearing final chunk and the ``[DONE]`` sentinel.

    Yields:
        ``data: ...\\n\\n`` SSE frames; always terminated by
        ``data: [DONE]\\n\\n`` on the success and handled-error paths.
    """
    import logging  # function-scope import, matching the file's lazy-import style

    logger = logging.getLogger("openjarvis.server")
    self._subscribe_all()
    # Kick off agent.run() in a background thread. We are inside a
    # coroutine, so get_running_loop() is always valid and avoids the
    # deprecated get_event_loop() semantics.
    loop = asyncio.get_running_loop()
    agent_task = asyncio.ensure_future(asyncio.to_thread(self._run_agent))

    def _on_done(fut: "asyncio.Future") -> None:
        # Unblock the queue-drain loop below once the agent finishes.
        # call_soon_threadsafe is defensive: done callbacks normally run
        # on the loop thread already, but this stays correct either way.
        loop.call_soon_threadsafe(self._queue.put_nowait, _DONE)

    agent_task.add_done_callback(_on_done)
    try:
        # Send initial role chunk (OpenAI-compatible)
        first_chunk = ChatCompletionChunk(
            id=self._chunk_id,
            model=self._model,
            choices=[
                StreamChoice(
                    delta=DeltaMessage(role="assistant"),
                )
            ],
        )
        yield f"data: {first_chunk.model_dump_json()}\n\n"

        # Drain queue until the agent finishes (signalled by _DONE).
        while True:
            item = await self._queue.get()
            if item is _DONE:
                break
            if isinstance(item, Event):
                sse_name = _EVENT_MAP.get(item.event_type)
                if sse_name:
                    yield self._format_named_event(sse_name, item.data)

        # Agent is done -- retrieve result (raises if the run failed).
        try:
            agent_result = agent_task.result()
        except Exception as exc:
            logger.error("Agent stream error: %s", exc, exc_info=True)
            error_str = str(exc)
            # Map common provider failures to friendlier user-facing text.
            if "context length" in error_str.lower() or (
                "400" in error_str and "too long" in error_str.lower()
            ):
                error_content = (
                    "The input is too long for the model's context window. "
                    "Please try a shorter message."
                )
            elif "400" in error_str:
                error_content = f"The model returned an error: {error_str}"
            else:
                error_content = f"Sorry, an error occurred: {error_str}"
            error_chunk = ChatCompletionChunk(
                id=self._chunk_id,
                model=self._model,
                choices=[
                    StreamChoice(
                        delta=DeltaMessage(content=error_content),
                        finish_reason="stop",
                    )
                ],
            )
            yield f"data: {error_chunk.model_dump_json()}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Emit tool results metadata if any
        tool_results_data = [
            {
                "tool_name": tr.tool_name,
                "success": tr.success,
                "output": tr.content,
                "latency_ms": tr.latency_seconds * 1000,
            }
            for tr in agent_result.tool_results
        ]
        if tool_results_data:
            yield self._format_named_event(
                "tool_results",
                {"results": tool_results_data},
            )

        # Stream content using real LLM token streaming via
        # engine.stream_full() when the engine is available.
        content = agent_result.content or ""
        engine = getattr(self._agent, "_engine", None)
        used_real_streaming = False
        if engine is not None and hasattr(engine, "stream_full") and content:
            # Re-stream using the engine for real token delivery.
            # Build the same messages the agent used for its final turn.
            try:
                from openjarvis.core.types import Message as MsgType
                from openjarvis.core.types import Role as RoleType

                replay_messages = []
                for m in self._request.messages:
                    # Fall back to USER for roles the enum doesn't know.
                    role = (
                        RoleType(m.role)
                        if m.role in {r.value for r in RoleType}
                        else RoleType.USER
                    )
                    replay_messages.append(
                        MsgType(
                            role=role,
                            content=m.content or "",
                            name=m.name,
                            tool_call_id=m.tool_call_id,
                        )
                    )
                async for sc in engine.stream_full(
                    replay_messages,
                    model=self._model,
                ):
                    if sc.content:
                        chunk = ChatCompletionChunk(
                            id=self._chunk_id,
                            model=self._model,
                            choices=[
                                StreamChoice(
                                    delta=DeltaMessage(content=sc.content),
                                )
                            ],
                        )
                        yield f"data: {chunk.model_dump_json()}\n\n"
                used_real_streaming = True
            except Exception as stream_exc:
                logger.warning(
                    "Real streaming failed, falling back to word replay: %s",
                    stream_exc,
                )

        # Fallback: word-by-word replay if real streaming was not used
        if not used_real_streaming and content:
            words = content.split(" ")
            for i, word in enumerate(words):
                # Re-insert the separating space dropped by split().
                token = word if i == 0 else " " + word
                chunk = ChatCompletionChunk(
                    id=self._chunk_id,
                    model=self._model,
                    choices=[
                        StreamChoice(
                            delta=DeltaMessage(content=token),
                        )
                    ],
                )
                yield f"data: {chunk.model_dump_json()}\n\n"
                # Small pacing delay so the client renders a stream.
                await asyncio.sleep(0.012)

        # Final chunk: finish_reason + usage
        prompt_tokens = agent_result.metadata.get("prompt_tokens", 0)
        completion_tokens = agent_result.metadata.get(
            "completion_tokens",
            0,
        )
        total_tokens = agent_result.metadata.get("total_tokens", 0)
        if total_tokens == 0:
            # Fallback: estimate from request messages (incl. system) + content
            completion_tokens = max(len(content) // 4, 1)
            prompt_tokens = _estimate_prompt_tokens(self._request.messages)
            total_tokens = prompt_tokens + completion_tokens
        final_chunk = ChatCompletionChunk(
            id=self._chunk_id,
            model=self._model,
            choices=[
                StreamChoice(
                    delta=DeltaMessage(),
                    finish_reason="stop",
                )
            ],
        )
        # The chunk model has no usage field, so splice usage into the
        # serialized payload by hand.
        final_data = json.loads(final_chunk.model_dump_json())
        final_data["usage"] = UsageInfo(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        ).model_dump()
        yield f"data: {json.dumps(final_data)}\n\n"
        yield "data: [DONE]\n\n"
    finally:
        # Cancel the agent on ANY exit path -- including GeneratorExit
        # when the client disconnects mid-stream, which the previous
        # ``except Exception`` clause could not catch (GeneratorExit
        # derives from BaseException). No-op once the task is done.
        if not agent_task.done():
            agent_task.cancel()
        self._unsubscribe_all()