CheddahBot/cheddahbot/web/routes_chat.py

271 lines
8.0 KiB
Python

"""Chat routes: send messages, stream responses, manage conversations."""
from __future__ import annotations

import asyncio
import logging
import tempfile
import threading
import time
from pathlib import Path
from typing import TYPE_CHECKING
from urllib.parse import quote

from fastapi import APIRouter, Form, Request, UploadFile
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from sse_starlette.sse import EventSourceResponse

if TYPE_CHECKING:
    from ..agent_registry import AgentRegistry
    from ..config import Config
    from ..db import Database
    from ..llm import LLMAdapter
log = logging.getLogger(__name__)

# All routes in this module are mounted under /chat.
router = APIRouter(prefix="/chat")

# Shared services injected by setup(); each stays None until setup() runs.
_registry: AgentRegistry | None = None
_config: Config | None = None
_llm: LLMAdapter | None = None
_db: Database | None = None
_templates: Jinja2Templates | None = None

# Messages accepted by POST /chat/send that are waiting for their SSE stream,
# keyed by conversation id. Each entry holds "text", "files" (saved upload
# paths), "timestamp" (time.time() at submission) and "agent_name".
# GET /chat/stream/{conv_id} pops the entry; _cleanup_pending() drops stale
# entries whose stream never connected.
_pending: dict[str, dict] = {}
def setup(registry, config, llm, db, templates):
global _registry, _config, _llm, _db, _templates
_registry = registry
_config = config
_llm = llm
_db = db
_templates = templates
def _get_agent(name: str):
    """Look up agent *name*, falling back to the registry's default agent.

    Returns None when no (truthy) registry has been installed via setup().
    """
    registry = _registry
    if not registry:
        return None
    found = registry.get(name)
    return found if found else registry.default
def _discard_upload(path: str) -> None:
    """Best-effort removal of one saved upload and its now-empty temp dir."""
    file_path = Path(path)
    try:
        file_path.unlink(missing_ok=True)
        # rmdir() refuses non-empty directories; send_message saves each
        # upload in its own mkdtemp() directory, so this removes just that.
        file_path.parent.rmdir()
    except OSError:
        # A leftover temp file is harmless; cleanup must never break a request.
        pass


def _cleanup_pending(max_age: float = 60.0) -> None:
    """Remove pending entries older than *max_age* seconds (default 60).

    Entries whose SSE stream never connected would otherwise stay in
    ``_pending`` forever. Files saved for an expired entry are deleted too.
    Once the entry is gone, nothing else references those paths, so leaving
    them on disk would leak temp files.
    """
    now = time.time()
    expired = [
        key for key, entry in _pending.items() if now - entry["timestamp"] > max_age
    ]
    for key in expired:
        entry = _pending.pop(key)
        for saved_path in entry.get("files") or ():
            _discard_upload(saved_path)
def _safe_upload_name(filename: str | None) -> str:
    """Reduce a client-supplied filename to a bare final path component.

    Both ``/`` and ``\\`` are treated as separators, so names such as
    ``../../etc/passwd`` or ``C:\\Users\\me\\a.txt`` cannot point outside
    the temp directory. Returns "" when no usable name is left.
    """
    base = (filename or "").replace("\x00", "").replace("\\", "/").rsplit("/", 1)[-1]
    return "" if base in ("", ".", "..") else base


async def _save_uploads(files: list[UploadFile] | None) -> list[str]:
    """Save each non-empty upload into its own fresh temp directory.

    Returns the saved paths in upload order. Uploads without a usable
    filename or without content are skipped.
    """
    saved: list[str] = []
    for upload in files or []:
        name = _safe_upload_name(upload.filename)
        if not name:
            continue
        # Check the actual bytes rather than UploadFile.size, which may be None.
        content = await upload.read()
        if not content:
            continue
        target = Path(tempfile.mkdtemp()) / name
        target.write_bytes(content)
        saved.append(str(target))
    return saved


@router.post("/send")
async def send_message(
    request: Request,
    text: str = Form(""),
    agent_name: str = Form("default"),
    conv_id: str = Form(""),
    files: list[UploadFile] | None = None,
):
    """Accept a user message; return the user bubble plus an SSE trigger.

    The message and any saved uploads are stashed in ``_pending`` under the
    conversation id. The returned markup connects to
    ``/chat/stream/{conv_id}``, which consumes that entry and streams the
    reply into ``#assistant-response``.

    Returns an empty response when there is neither text nor a non-empty
    upload. Returns a 400 error fragment when no agent can be resolved.
    """
    _cleanup_pending()

    agent = _get_agent(agent_name)
    if not agent:
        return HTMLResponse("<div class='error'>Agent not found</div>", status_code=400)

    saved_files = await _save_uploads(files)
    if not text.strip() and not saved_files:
        return HTMLResponse("")

    # Ensure conversation exists
    if not conv_id:
        agent.new_conversation()
        conv_id = agent.ensure_conversation()
    else:
        agent.conv_id = conv_id

    # The bubble shows the text plus the names of any attached files.
    display_text = text
    if saved_files:
        file_names = [Path(f).name for f in saved_files]
        display_text += f"\n[Attached: {', '.join(file_names)}]"

    # Stash for the SSE stream endpoint, which pops it by conv_id.
    _pending[conv_id] = {
        "text": text,
        "files": saved_files,
        "timestamp": time.time(),
        "agent_name": agent_name,
    }

    user_html = _templates.get_template("partials/chat_message.html").render(
        role="user", content=display_text
    )

    # conv_id may come straight from the client form. Percent-encoding it
    # keeps it from breaking out of the HTML attribute or the response header.
    # The stream route receives the decoded value, so the _pending key matches.
    safe_conv = quote(str(conv_id), safe="")

    # The SSE trigger div connects to the stream endpoint
    sse_div = (
        f'<div id="sse-trigger" '
        f'hx-ext="sse" '
        f'sse-connect="/chat/stream/{safe_conv}" '
        f'sse-swap="chunk" '
        f'hx-target="#assistant-response" '
        f'hx-swap="beforeend">'
        f'</div>'
        f'<div id="assistant-bubble" class="message assistant">'
        f'<div class="message-avatar">CB</div>'
        f'<div class="message-body">'
        f'<div id="assistant-response" class="message-content"></div>'
        f'</div></div>'
    )
    headers = {
        "HX-Trigger-After-Swap": "scrollChat",
        "HX-Push-Url": f"/?conv={safe_conv}",
    }
    return HTMLResponse(user_html + sse_div, headers=headers)
@router.get("/stream/{conv_id}")
async def stream_response(conv_id: str):
    """SSE endpoint: stream assistant response chunks for *conv_id*.

    Pops the message stashed in ``_pending`` by the send endpoint. Emits a
    ``chunk`` event for each piece the agent yields, then one ``done`` event
    whose data is the conversation id.

    If nothing is pending for *conv_id* (already consumed or expired), only a
    ``done`` event with empty data is sent.
    """
    pending = _pending.pop(conv_id, None)
    if not pending:
        async def empty():
            yield {"event": "done", "data": ""}
        return EventSourceResponse(empty())

    agent = _get_agent(pending["agent_name"])
    if not agent:
        async def error():
            yield {"event": "chunk", "data": "Agent not found"}
            yield {"event": "done", "data": ""}
        return EventSourceResponse(error())

    agent.conv_id = conv_id

    async def generate():
        # get_running_loop() is the supported call inside a coroutine;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def emit(event: str, data: str) -> bool:
            """Hand one event to the event loop; False if the loop is closed."""
            try:
                loop.call_soon_threadsafe(queue.put_nowait, (event, data))
            except RuntimeError:
                # The loop has been closed (e.g. server shutdown); nobody is
                # left to consume further events.
                return False
            return True

        def run_agent():
            try:
                for chunk in agent.respond(pending["text"], files=pending.get("files")):
                    if not emit("chunk", chunk):
                        return
            except Exception as e:
                log.exception("Stream error: %s", e)
                emit("chunk", f"\n\nError: {e}")
            finally:
                emit("done", "")

        # agent.respond() is iterated synchronously, so run it in a worker
        # thread to keep the event loop free while chunks are produced.
        threading.Thread(target=run_agent, daemon=True).start()

        while True:
            event, data = await queue.get()
            if event == "done":
                yield {"event": "done", "data": conv_id}
                break
            yield {"event": "chunk", "data": data}

    return EventSourceResponse(generate())
@router.get("/conversations")
async def list_conversations(agent_name: str = "default"):
    """Render the sidebar conversation list (at most 50) for *agent_name*.

    Responds with an empty body when no agent can be resolved.
    """
    agent = _get_agent(agent_name)
    if agent:
        recent = agent.db.list_conversations(limit=50, agent_name=agent_name)
        sidebar_template = _templates.get_template("partials/chat_sidebar.html")
        return HTMLResponse(sidebar_template.render(conversations=recent))
    return HTMLResponse("")
@router.post("/new")
async def new_conversation(agent_name: str = Form("default")):
    """Start a fresh conversation for *agent_name*.

    Responds with an empty ``#chat-messages`` container plus an out-of-band
    swap that refreshes ``#sidebar-conversations``. Also pushes the new
    conversation's URL via ``HX-Push-Url``. Returns an empty body when no
    agent can be resolved.
    """
    agent = _get_agent(agent_name)
    if not agent:
        return HTMLResponse("")

    agent.new_conversation()
    fresh_id = agent.ensure_conversation()

    recent = agent.db.list_conversations(limit=50, agent_name=agent_name)
    sidebar_template = _templates.get_template("partials/chat_sidebar.html")
    sidebar_markup = sidebar_template.render(conversations=recent)

    body = "".join(
        (
            '<div id="chat-messages"></div>',
            '<div id="sidebar-conversations" hx-swap-oob="innerHTML">',
            sidebar_markup,
            "</div>",
        )
    )
    return HTMLResponse(body, headers={"HX-Push-Url": f"/?conv={fresh_id}"})
@router.get("/load/{conv_id}")
async def load_conversation(conv_id: str, agent_name: str = "default"):
    """Render the stored history of *conv_id* as chat-message HTML.

    Only messages whose role is ``user`` or ``assistant`` and whose content
    is non-empty are rendered; other messages are skipped. Also pushes
    ``/?conv=<id>`` into browser history via ``HX-Push-Url``. Returns an
    empty body when no agent can be resolved.
    """
    agent = _get_agent(agent_name)
    if not agent:
        return HTMLResponse("")

    messages = agent.load_conversation(conv_id)

    # One template lookup for the whole history instead of one per message.
    template = _templates.get_template("partials/chat_message.html")
    parts = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role in ("user", "assistant") and content:
            parts.append(template.render(role=role, content=content))

    # conv_id comes from the URL path. Percent-encoding it keeps arbitrary
    # characters (non-latin-1, control chars) from breaking the header.
    headers = {"HX-Push-Url": f"/?conv={quote(str(conv_id), safe='')}"}
    return HTMLResponse("\n".join(parts), headers=headers)
@router.post("/agent/{name}")
async def switch_agent(name: str):
    """Switch to agent *name* by starting a fresh conversation for it.

    Responds with an empty ``#chat-messages`` container plus an out-of-band
    refresh of ``#sidebar-conversations`` listing that agent's conversations.
    Also pushes the new conversation's URL via ``HX-Push-Url``. Responds with
    400 and an error fragment when no agent can be resolved.
    """
    agent = _get_agent(name)
    if agent:
        agent.new_conversation()
        fresh_id = agent.ensure_conversation()
        recent = agent.db.list_conversations(limit=50, agent_name=name)
        sidebar_markup = _templates.get_template(
            "partials/chat_sidebar.html"
        ).render(conversations=recent)
        body = (
            '<div id="chat-messages"></div>'
            '<div id="sidebar-conversations" hx-swap-oob="innerHTML">'
            + sidebar_markup
            + "</div>"
        )
        return HTMLResponse(body, headers={"HX-Push-Url": f"/?conv={fresh_id}"})
    return HTMLResponse("<div class='error'>Agent not found</div>", status_code=400)