271 lines
8.0 KiB
Python
271 lines
8.0 KiB
Python
"""Chat routes: send messages, stream responses, manage conversations."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import tempfile
|
|
import time
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi import APIRouter, Form, Request, UploadFile
|
|
from fastapi.responses import HTMLResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
if TYPE_CHECKING:
|
|
from ..agent_registry import AgentRegistry
|
|
from ..config import Config
|
|
from ..db import Database
|
|
from ..llm import LLMAdapter
|
|
|
|
log = logging.getLogger(__name__)
router = APIRouter(prefix="/chat")

# --- Shared services --------------------------------------------------------
# Assigned once by setup(); each stays None until setup() has been called.
_registry: AgentRegistry | None = None
_config: Config | None = None
_llm: LLMAdapter | None = None
_db: Database | None = None
_templates: Jinja2Templates | None = None

# --- Hand-off between POST /chat/send and GET /chat/stream/{conv_id} ---------
# conv_id -> {"text", "files", "timestamp", "agent_name"}.
# send_message() stores an entry, stream_response() pops it, and
# _cleanup_pending() discards entries that were never streamed.
_pending: dict[str, dict] = {}
|
|
|
|
|
|
def setup(registry, config, llm, db, templates):
|
|
global _registry, _config, _llm, _db, _templates
|
|
_registry = registry
|
|
_config = config
|
|
_llm = llm
|
|
_db = db
|
|
_templates = templates
|
|
|
|
|
|
def _get_agent(name: str):
    """Look up an agent by *name*, falling back to the registry default.

    Returns None only when no registry has been configured. An unknown name
    silently resolves to ``_registry.default``.
    """
    registry = _registry
    if not registry:
        return None
    return registry.get(name) or registry.default
|
|
|
|
|
|
def _cleanup_pending(max_age: float = 60.0) -> None:
    """Drop stashed messages that were never picked up by the SSE stream.

    An entry in ``_pending`` normally lives only between ``POST /chat/send``
    and the matching ``GET /chat/stream/{conv_id}``, which pops it. Entries
    older than *max_age* seconds were abandoned. They are removed together
    with the uploaded temp files they reference; previously those files were
    left on disk forever.

    Args:
        max_age: Age in seconds after which an entry counts as stale
            (default 60, matching the previous hard-coded value).
    """
    now = time.time()
    expired = [key for key, entry in _pending.items() if now - entry["timestamp"] > max_age]
    for key in expired:
        entry = _pending.pop(key)
        for path_str in entry.get("files") or ():
            path = Path(path_str)
            try:
                path.unlink(missing_ok=True)
                # Uploads are written into their own mkdtemp() directory.
                # rmdir() only succeeds if that directory is now empty, so
                # a shared parent directory is never removed.
                path.parent.rmdir()
            except OSError:
                # Best-effort cleanup: a stale entry must not break the request.
                pass
|
|
|
|
|
|
def _safe_upload_name(filename: str) -> str:
    """Reduce a client-supplied upload filename to a bare, safe basename.

    The name comes straight from the multipart request. It may therefore hold
    directory components (``../../x``, ``/etc/passwd``, ``C:\\x``) or NUL
    bytes. Joining it onto a directory unchecked would let a client write
    outside the temp dir.
    """
    # Treat backslashes as separators too, so Windows-style paths are stripped.
    cleaned = filename.replace("\x00", "").replace("\\", "/")
    name = Path(cleaned).name
    if name in ("", ".", ".."):
        return "upload"
    return name


async def _save_uploads(files: list[UploadFile] | None) -> list[str]:
    """Write each non-empty named upload to a fresh temp dir; return the paths."""
    saved: list[str] = []
    for upload in files or []:
        if not upload.filename:
            # Browsers send a nameless empty part when no file was chosen.
            continue
        # Read first and check the bytes. UploadFile.size can be None when the
        # size is unknown, and relying on it silently dropped such files.
        content = await upload.read()
        if not content:
            continue
        # One private directory per file, so identical names never collide.
        target = Path(tempfile.mkdtemp()) / _safe_upload_name(upload.filename)
        target.write_bytes(content)
        saved.append(str(target))
    return saved


@router.post("/send")
async def send_message(
    request: Request,
    text: str = Form(""),
    agent_name: str = Form("default"),
    conv_id: str = Form(""),
    files: list[UploadFile] | None = None,
):
    """Accept a user message; return the user bubble plus an SSE trigger.

    The message text and uploaded file paths are stashed in ``_pending`` under
    the conversation id. The returned HTML contains a div that opens
    ``/chat/stream/{conv_id}`` to stream the assistant reply into
    ``#assistant-response``.

    Returns:
        * 400 with an error div if no agent can be resolved.
        * An empty response if there is neither text nor a non-empty upload.
        * Otherwise the rendered HTML, with ``HX-Trigger-After-Swap`` and
          ``HX-Push-Url`` headers.
    """
    from urllib.parse import quote

    _cleanup_pending()

    agent = _get_agent(agent_name)
    if not agent:
        return HTMLResponse("<div class='error'>Agent not found</div>", status_code=400)

    saved_files = await _save_uploads(files)

    if not text.strip() and not saved_files:
        return HTMLResponse("")

    # Ensure conversation exists
    if not conv_id:
        agent.new_conversation()
        conv_id = agent.ensure_conversation()
    else:
        agent.conv_id = conv_id

    # Build display text
    display_text = text
    if saved_files:
        file_names = [Path(f).name for f in saved_files]
        display_text += f"\n[Attached: {', '.join(file_names)}]"

    # Stash for SSE stream
    _pending[conv_id] = {
        "text": text,
        "files": saved_files,
        "timestamp": time.time(),
        "agent_name": agent_name,
    }

    # Render user bubble + SSE trigger div
    user_html = _templates.get_template("partials/chat_message.html").render(
        role="user", content=display_text
    )

    # conv_id may be client-controlled form input. Percent-encode it before it
    # goes into an HTML attribute, a URL or a response header. With safe=""
    # the result contains only [A-Za-z0-9_.~-] and %XX, so it cannot break out
    # of the attribute. The server decodes it back for the path parameter.
    conv_q = quote(conv_id, safe="")

    # The SSE trigger div connects to the stream endpoint
    sse_div = (
        f'<div id="sse-trigger" '
        f'hx-ext="sse" '
        f'sse-connect="/chat/stream/{conv_q}" '
        f'sse-swap="chunk" '
        f'hx-target="#assistant-response" '
        f'hx-swap="beforeend">'
        f'</div>'
        f'<div id="assistant-bubble" class="message assistant">'
        f'<div class="message-avatar">CB</div>'
        f'<div class="message-body">'
        f'<div id="assistant-response" class="message-content"></div>'
        f'</div></div>'
    )

    headers = {
        "HX-Trigger-After-Swap": "scrollChat",
        "HX-Push-Url": f"/?conv={conv_q}",
    }

    return HTMLResponse(user_html + sse_div, headers=headers)
|
|
|
|
|
|
@router.get("/stream/{conv_id}")
async def stream_response(conv_id: str):
    """SSE endpoint: stream the assistant reply for a message stashed by /send.

    Emits ``chunk`` events carrying response text, then a final ``done`` event
    whose data is the conversation id. The pending entry is popped, so each
    stashed message is streamed at most once. An unknown or expired
    *conv_id* yields a single ``done`` event with empty data.
    """
    pending = _pending.pop(conv_id, None)
    if not pending:
        async def empty():
            yield {"event": "done", "data": ""}

        return EventSourceResponse(empty())

    agent = _get_agent(pending["agent_name"])
    if not agent:
        async def error():
            yield {"event": "chunk", "data": "Agent not found"}
            yield {"event": "done", "data": ""}

        return EventSourceResponse(error())

    agent.conv_id = conv_id

    async def generate():
        import threading

        # Inside a coroutine, the asyncio docs recommend get_running_loop()
        # over get_event_loop().
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()

        def emit(event: str, data: str) -> None:
            """Hand one event from the worker thread to the event loop."""
            try:
                loop.call_soon_threadsafe(queue.put_nowait, (event, data))
            except RuntimeError:
                # The loop is already closed (e.g. server shutdown) and nobody
                # can read the event. Drop it rather than crash the thread.
                pass

        def run_agent():
            # agent.respond() is iterated synchronously, so it runs in a
            # worker thread to keep the event loop responsive.
            try:
                for chunk in agent.respond(pending["text"], files=pending.get("files")):
                    emit("chunk", chunk)
            except Exception as exc:
                log.exception("Stream error: %s", exc)
                emit("chunk", f"\n\nError: {exc}")
            finally:
                emit("done", "")

        threading.Thread(target=run_agent, daemon=True).start()

        while True:
            event, data = await queue.get()
            if event == "done":
                yield {"event": "done", "data": conv_id}
                break
            yield {"event": "chunk", "data": data}

    return EventSourceResponse(generate())
|
|
|
|
|
|
@router.get("/conversations")
async def list_conversations(agent_name: str = "default"):
    """Render the sidebar's conversation list as an HTML partial.

    Shows up to 50 conversations from the resolved agent's database, filtered
    by *agent_name*. Returns an empty response when no agent is available.
    """
    agent = _get_agent(agent_name)
    if agent:
        recent = agent.db.list_conversations(limit=50, agent_name=agent_name)
        sidebar = _templates.get_template("partials/chat_sidebar.html")
        return HTMLResponse(sidebar.render(conversations=recent))
    return HTMLResponse("")
|
|
|
|
|
|
@router.post("/new")
async def new_conversation(agent_name: str = Form("default")):
    """Start a fresh conversation for the agent.

    Responds with an empty ``#chat-messages`` container and an out-of-band
    refresh of ``#sidebar-conversations``. The ``HX-Push-Url`` header points
    the browser at ``/?conv=<new id>``. Returns an empty response when no
    agent is available.
    """
    agent = _get_agent(agent_name)
    if not agent:
        return HTMLResponse("")

    agent.new_conversation()
    new_id = agent.ensure_conversation()

    recent = agent.db.list_conversations(limit=50, agent_name=agent_name)
    sidebar = _templates.get_template("partials/chat_sidebar.html").render(
        conversations=recent
    )

    # Empty chat area plus an OOB swap that replaces the sidebar contents.
    body = "".join(
        (
            '<div id="chat-messages"></div>',
            '<div id="sidebar-conversations" hx-swap-oob="innerHTML">',
            sidebar,
            "</div>",
        )
    )
    return HTMLResponse(body, headers={"HX-Push-Url": f"/?conv={new_id}"})
|
|
|
|
|
|
@router.get("/load/{conv_id}")
async def load_conversation(conv_id: str, agent_name: str = "default"):
    """Render a stored conversation's history as chat-message HTML.

    Only ``user`` and ``assistant`` messages with non-empty content are
    rendered. The response pushes ``/?conv=<conv_id>`` into browser history.
    Returns an empty response when no agent is available.
    """
    from urllib.parse import quote

    agent = _get_agent(agent_name)
    if not agent:
        return HTMLResponse("")

    messages = agent.load_conversation(conv_id)
    parts = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role in ("user", "assistant") and content:
            parts.append(
                _templates.get_template("partials/chat_message.html").render(
                    role=role, content=content
                )
            )

    # conv_id is a request path parameter, so it is attacker-controlled and
    # already percent-decoded. Re-encode it before placing it in a header or
    # query string, so characters such as CR/LF, "&" or "#" cannot inject
    # headers or alter the URL.
    headers = {"HX-Push-Url": f"/?conv={quote(conv_id, safe='')}"}
    return HTMLResponse("\n".join(parts), headers=headers)
|
|
|
|
|
|
@router.post("/agent/{name}")
async def switch_agent(name: str):
    """Open a fresh conversation on the agent resolved from *name*.

    Returns an empty ``#chat-messages`` container plus an out-of-band refresh
    of ``#sidebar-conversations`` listing that agent's conversations. The
    ``HX-Push-Url`` header points at ``/?conv=<new id>``. Responds 400 when
    no agent can be resolved.
    """
    target = _get_agent(name)
    if not target:
        return HTMLResponse("<div class='error'>Agent not found</div>", status_code=400)

    target.new_conversation()
    fresh_id = target.ensure_conversation()

    history = target.db.list_conversations(limit=50, agent_name=name)
    sidebar_markup = _templates.get_template("partials/chat_sidebar.html").render(
        conversations=history
    )

    page = (
        '<div id="chat-messages"></div>'
        + '<div id="sidebar-conversations" hx-swap-oob="innerHTML">'
        + sidebar_markup
        + "</div>"
    )
    push_headers = {"HX-Push-Url": f"/?conv={fresh_id}"}
    return HTMLResponse(page, headers=push_headers)
|