"""Gradio interface for CheddahBot."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
import gradio as gr
|
|
|
|
if TYPE_CHECKING:
|
|
from .agent import Agent
|
|
from .config import Config
|
|
from .llm import LLMAdapter
|
|
from .notifications import NotificationBus
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
_CSS = """
|
|
.contain { max-width: 900px; margin: auto; }
|
|
footer { display: none !important; }
|
|
.notification-banner {
|
|
background: #1a1a2e;
|
|
border: 1px solid #16213e;
|
|
border-radius: 8px;
|
|
padding: 10px 16px;
|
|
margin-bottom: 8px;
|
|
font-size: 0.9em;
|
|
}
|
|
"""
|
|
|
|
|
|
def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
              notification_bus: NotificationBus | None = None) -> tuple[gr.Blocks, str]:
    """Build the Gradio app for CheddahBot.

    Args:
        agent: Conversation agent; supplies ``respond`` (a chunk generator),
            ``new_conversation`` and ``db.kv_get`` for pipeline status.
        config: App configuration; only ``config.clickup.enabled`` is read here.
        llm: Model adapter; supplies model listing/switching and the
            execution-brain availability check.
        notification_bus: Optional bus polled for UI notifications. When
            ``None``, the notification banner and its timer are disabled.

    Returns:
        A ``(app, css)`` tuple: the ``gr.Blocks`` instance and the CSS string
        (``_CSS``) for the caller to apply when launching/mounting the app.
        (Fixed: the previous ``-> gr.Blocks`` annotation and docstring did not
        match the actual tuple return.)
    """
    available_models = llm.list_chat_models()
    # Dropdown choices as (display name, value) pairs.
    model_choices = [(m.name, m.id) for m in available_models]
    current_model = llm.current_model

    # Status strings for the header line.
    exec_status = "available" if llm.is_execution_brain_available() else "unavailable"
    clickup_status = "enabled" if config.clickup.enabled else "disabled"

    with gr.Blocks(title="CheddahBot") as app:
        gr.Markdown("# CheddahBot", elem_classes=["contain"])
        gr.Markdown(
            f"*Chat Brain:* `{current_model}` | "
            f"*Execution Brain (Claude Code CLI):* `{exec_status}` | "
            f"*ClickUp:* `{clickup_status}`",
            elem_classes=["contain"],
        )

        # -- Notification banner (hidden until poll_notifications fills it) --
        notification_display = gr.Markdown(
            value="",
            visible=False,
            elem_classes=["contain", "notification-banner"],
        )

        with gr.Row(elem_classes=["contain"]):
            model_dropdown = gr.Dropdown(
                choices=model_choices,
                value=current_model,
                label="Model",
                interactive=True,
                allow_custom_value=True,
                scale=3,
            )
            refresh_btn = gr.Button("Refresh", scale=0, min_width=70)
            new_chat_btn = gr.Button("New Chat", scale=1, variant="secondary")

        # NOTE(review): chat history is built from {"role": ..., "content": ...}
        # dicts below; depending on the installed Gradio version the Chatbot may
        # require type="messages" for that format, and `buttons=["copy"]` vs
        # `show_copy_button=True` is also version-dependent — confirm against
        # the pinned Gradio release.
        chatbot = gr.Chatbot(
            label="Chat",
            height=500,
            buttons=["copy"],
            elem_classes=["contain"],
        )

        # Transient pipeline-progress line, driven by poll_pipeline_status.
        pipeline_status = gr.Markdown(
            value="",
            visible=False,
            elem_classes=["contain"],
        )

        with gr.Row(elem_classes=["contain"]):
            # NOTE(review): placeholder mentions a camera but sources only
            # enable upload + microphone (no "webcam") — confirm intent.
            msg_input = gr.MultimodalTextbox(
                placeholder="Type a message... (attach files, use mic, or camera)",
                show_label=False,
                scale=4,
                sources=["upload", "microphone"],
            )

        # -- Event handlers --

        def on_model_change(model_id):
            """Switch the chat model. Return value is discarded (wired with outputs=None)."""
            llm.switch_model(model_id)
            return f"Switched to {model_id}"

        def on_refresh_models():
            """Re-query available models and refresh the dropdown in place."""
            models = llm.list_chat_models()
            choices = [(m.name, m.id) for m in models]
            return gr.update(choices=choices, value=llm.current_model)

        def on_new_chat():
            """Reset the agent conversation and clear the chatbot display."""
            agent.new_conversation()
            return []

        def on_user_message(message, chat_history):
            """Handle a submitted message: transcribe audio, echo the user turn,
            then stream the assistant reply.

            Generator yielding ``(chat_history, msg_input update)`` pairs; the
            input box is cleared (value=None) on every yield.
            """
            chat_history = chat_history or []

            # Extract text and files from MultimodalTextbox (dict payload) or
            # a plain string fallback.
            if isinstance(message, dict):
                text = message.get("text", "")
                files = message.get("files", [])
            else:
                text = str(message)
                files = []

            # Nothing to do: clear the input and bail out.
            if not text and not files:
                yield chat_history, gr.update(value=None)
                return

            # Handle audio files - transcribe them. A successful transcript is
            # folded into the text; failed transcriptions fall through and the
            # file is attached as-is.
            processed_files = []
            for f in files:
                fpath = f if isinstance(f, str) else f.get("path", f.get("name", ""))
                if fpath and Path(fpath).suffix.lower() in (".wav", ".mp3", ".ogg", ".webm", ".m4a"):
                    try:
                        # Local import keeps heavy transcription deps lazy.
                        from .media import transcribe_audio
                        transcript = transcribe_audio(fpath)
                        if transcript:
                            text = f"{text}\n[Voice message]: {transcript}" if text else f"[Voice message]: {transcript}"
                            continue
                    except Exception as e:
                        log.warning("Audio transcription failed: %s", e)
                processed_files.append(fpath)

            # Add user message (with attachment names, if any) to the display.
            user_display = text
            if processed_files:
                file_names = [Path(f).name for f in processed_files]
                user_display += f"\n[Attached: {', '.join(file_names)}]"

            chat_history = chat_history + [{"role": "user", "content": user_display}]
            yield chat_history, gr.update(value=None)

            # Stream assistant response, rewriting the last message per chunk.
            try:
                response_text = ""
                chat_history = chat_history + [{"role": "assistant", "content": ""}]

                for chunk in agent.respond(text, files=processed_files):
                    response_text += chunk
                    chat_history[-1] = {"role": "assistant", "content": response_text}
                    yield chat_history, gr.update(value=None)

                # If no response came through, show a fallback
                if not response_text:
                    chat_history[-1] = {"role": "assistant", "content": "(No response received from model)"}
                    yield chat_history, gr.update(value=None)
            except Exception as e:
                log.error("Error in agent.respond: %s", e, exc_info=True)
                chat_history = chat_history + [{"role": "assistant", "content": f"Error: {e}"}]
                yield chat_history, gr.update(value=None)

        def poll_pipeline_status():
            """Poll the DB for pipeline progress updates."""
            status = agent.db.kv_get("pipeline:status")
            if status:
                return gr.update(value=f"⏳ {status}", visible=True)
            return gr.update(value="", visible=False)

        def poll_notifications():
            """Poll the notification bus for pending messages."""
            if not notification_bus:
                return gr.update(value="", visible=False)

            messages = notification_bus.get_pending("gradio")
            if not messages:
                return gr.update()  # No change

            # Format notifications as markdown
            lines = []
            for msg in messages[-5:]:  # Show last 5 notifications max
                lines.append(f"**Notification:** {msg}")
            banner = "\n\n".join(lines)
            return gr.update(value=banner, visible=True)

        # -- Wire events --

        model_dropdown.change(on_model_change, [model_dropdown], None)
        refresh_btn.click(on_refresh_models, None, [model_dropdown])
        new_chat_btn.click(on_new_chat, None, [chatbot])

        msg_input.submit(
            on_user_message,
            [msg_input, chatbot],
            [chatbot, msg_input],
        )

        # Pipeline status polling timer (every 3 seconds)
        status_timer = gr.Timer(3)
        status_timer.tick(poll_pipeline_status, None, [pipeline_status])

        # Notification polling timer (every 10 seconds)
        if notification_bus:
            # NOTE(review): the no-op subscriber presumably marks the "gradio"
            # channel active so get_pending buffers for it — confirm against
            # NotificationBus semantics.
            notification_bus.subscribe("gradio", lambda msg, cat: None)  # Register listener
            timer = gr.Timer(10)
            timer.tick(poll_notifications, None, [notification_display])

    return app, _CSS
|