"""SQLite persistence layer.""" from __future__ import annotations import contextlib import json import sqlite3 import threading from datetime import UTC, datetime from pathlib import Path class Database: def __init__(self, db_path: Path): self._path = db_path self._local = threading.local() self._init_schema() @property def _conn(self) -> sqlite3.Connection: if not hasattr(self._local, "conn"): self._local.conn = sqlite3.connect(str(self._path)) self._local.conn.row_factory = sqlite3.Row self._local.conn.execute("PRAGMA journal_mode=WAL") self._local.conn.execute("PRAGMA foreign_keys=ON") return self._local.conn def _init_schema(self): self._conn.executescript(""" CREATE TABLE IF NOT EXISTS conversations ( id TEXT PRIMARY KEY, title TEXT, created_at TEXT NOT NULL, updated_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, conv_id TEXT NOT NULL REFERENCES conversations(id), role TEXT NOT NULL, content TEXT NOT NULL, tool_calls TEXT, tool_result TEXT, model TEXT, created_at TEXT NOT NULL ); CREATE INDEX IF NOT EXISTS idx_messages_conv ON messages(conv_id, created_at); CREATE TABLE IF NOT EXISTS scheduled_tasks ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, prompt TEXT NOT NULL, schedule TEXT NOT NULL, enabled INTEGER NOT NULL DEFAULT 1, next_run TEXT, created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS task_run_logs ( id INTEGER PRIMARY KEY AUTOINCREMENT, task_id INTEGER NOT NULL REFERENCES scheduled_tasks(id), started_at TEXT NOT NULL, finished_at TEXT, result TEXT, error TEXT ); CREATE TABLE IF NOT EXISTS kv_store ( key TEXT PRIMARY KEY, value TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS notifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, message TEXT NOT NULL, category TEXT NOT NULL DEFAULT 'clickup', created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS api_usage ( id INTEGER PRIMARY KEY AUTOINCREMENT, timestamp TEXT NOT NULL, model TEXT NOT NULL, provider TEXT NOT NULL, prompt_tokens INTEGER NOT NULL DEFAULT 0, completion_tokens INTEGER NOT NULL DEFAULT 0, total_tokens INTEGER NOT NULL DEFAULT 0, estimated_cost REAL NOT NULL DEFAULT 0.0, conv_id TEXT, agent_name TEXT ); """) # Migration: add agent_name column to conversations (idempotent) with contextlib.suppress(sqlite3.OperationalError): self._conn.execute( "ALTER TABLE conversations ADD COLUMN agent_name TEXT DEFAULT 'default'" ) self._conn.commit() # -- Conversations -- def create_conversation( self, conv_id: str, title: str = "New Chat", agent_name: str = "default" ) -> str: now = _now() self._conn.execute( "INSERT INTO conversations (id, title, created_at, updated_at, agent_name)" " VALUES (?, ?, ?, ?, ?)", (conv_id, title, now, now, agent_name), ) self._conn.commit() return conv_id def list_conversations(self, limit: int = 50, agent_name: str | None = None) -> list[dict]: # Only return conversations that have at least one message if agent_name: rows = self._conn.execute( "SELECT c.id, c.title, c.updated_at, c.agent_name" " FROM conversations c" " WHERE c.agent_name = ?" " AND EXISTS (SELECT 1 FROM messages m WHERE m.conv_id = c.id)" " ORDER BY c.updated_at DESC LIMIT ?", (agent_name, limit), ).fetchall() else: rows = self._conn.execute( "SELECT c.id, c.title, c.updated_at, c.agent_name" " FROM conversations c" " WHERE EXISTS (SELECT 1 FROM messages m WHERE m.conv_id = c.id)" " ORDER BY c.updated_at DESC LIMIT ?", (limit,), ).fetchall() return [dict(r) for r in rows] def get_conversation_title(self, conv_id: str) -> str | None: row = self._conn.execute( "SELECT title FROM conversations WHERE id = ?", (conv_id,) ).fetchone() return row["title"] if row else None def update_conversation_title(self, conv_id: str, title: str): self._conn.execute( "UPDATE conversations SET title = ? WHERE id = ?", (title, conv_id) ) self._conn.commit() # -- Messages -- def add_message( self, conv_id: str, role: str, content: str, tool_calls: list | None = None, tool_result: str | None = None, model: str | None = None, ) -> int: now = _now() cur = self._conn.execute( """INSERT INTO messages (conv_id, role, content, tool_calls, tool_result, model, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)""", ( conv_id, role, content, json.dumps(tool_calls) if tool_calls else None, tool_result, model, now, ), ) self._conn.execute("UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id)) self._conn.commit() return cur.lastrowid def get_messages(self, conv_id: str, limit: int = 100) -> list[dict]: rows = self._conn.execute( """SELECT id, role, content, tool_calls, tool_result, model, created_at FROM messages WHERE conv_id = ? ORDER BY created_at ASC LIMIT ?""", (conv_id, limit), ).fetchall() result = [] for r in rows: msg = dict(r) if msg["tool_calls"]: msg["tool_calls"] = json.loads(msg["tool_calls"]) result.append(msg) return result def count_messages(self, conv_id: str) -> int: row = self._conn.execute( "SELECT COUNT(*) as cnt FROM messages WHERE conv_id = ?", (conv_id,) ).fetchone() return row["cnt"] def delete_messages(self, message_ids: list[int]): """Delete messages by their IDs (used by auto_flush).""" if not message_ids: return placeholders = ",".join("?" for _ in message_ids) self._conn.execute(f"DELETE FROM messages WHERE id IN ({placeholders})", message_ids) self._conn.commit() # -- Scheduled Tasks -- def add_scheduled_task(self, name: str, prompt: str, schedule: str) -> int: now = _now() cur = self._conn.execute( "INSERT INTO scheduled_tasks (name, prompt, schedule, created_at) VALUES (?, ?, ?, ?)", (name, prompt, schedule, now), ) self._conn.commit() return cur.lastrowid def get_due_tasks(self) -> list[dict]: now = _now() rows = self._conn.execute( "SELECT * FROM scheduled_tasks" " WHERE enabled = 1 AND (next_run IS NULL OR next_run <= ?)", (now,), ).fetchall() return [dict(r) for r in rows] def update_task_next_run(self, task_id: int, next_run: str): self._conn.execute( "UPDATE scheduled_tasks SET next_run = ? WHERE id = ?", (next_run, task_id) ) self._conn.commit() def disable_task(self, task_id: int): """Disable a scheduled task (e.g. after a one-time task has run).""" self._conn.execute("UPDATE scheduled_tasks SET enabled = 0 WHERE id = ?", (task_id,)) self._conn.commit() def log_task_run(self, task_id: int, result: str | None = None, error: str | None = None): now = _now() self._conn.execute( "INSERT INTO task_run_logs" " (task_id, started_at, finished_at, result, error)" " VALUES (?, ?, ?, ?, ?)", (task_id, now, now, result, error), ) self._conn.commit() # -- Key-Value Store -- def kv_set(self, key: str, value: str): self._conn.execute( "INSERT OR REPLACE INTO kv_store (key, value) VALUES (?, ?)", (key, value) ) self._conn.commit() def kv_get(self, key: str) -> str | None: row = self._conn.execute("SELECT value FROM kv_store WHERE key = ?", (key,)).fetchone() return row["value"] if row else None def kv_scan(self, prefix: str) -> list[tuple[str, str]]: """Return all key-value pairs where key starts with prefix.""" rows = self._conn.execute( "SELECT key, value FROM kv_store WHERE key LIKE ?", (prefix + "%",) ).fetchall() return [(r["key"], r["value"]) for r in rows] def kv_delete(self, key: str): """Delete a key from the kv_store.""" self._conn.execute("DELETE FROM kv_store WHERE key = ?", (key,)) self._conn.commit() # -- Notifications -- def add_notification(self, message: str, category: str = "clickup") -> int: now = _now() cur = self._conn.execute( "INSERT INTO notifications (message, category, created_at) VALUES (?, ?, ?)", (message, category, now), ) self._conn.commit() return cur.lastrowid def get_max_notification_id(self) -> int: """Return the highest notification id, or 0 if the table is empty.""" row = self._conn.execute("SELECT MAX(id) FROM notifications").fetchone() return row[0] or 0 def get_notifications_after(self, after_id: int = 0, limit: int = 50) -> list[dict]: """Get notifications with id > after_id.""" rows = self._conn.execute( "SELECT id, message, category, created_at FROM notifications" " WHERE id > ? ORDER BY id ASC LIMIT ?", (after_id, limit), ).fetchall() return [dict(r) for r in rows] # -- API Usage -- def log_api_usage( self, model: str, provider: str, prompt_tokens: int, completion_tokens: int, total_tokens: int, estimated_cost: float, conv_id: str | None = None, agent_name: str | None = None, ): now = _now() self._conn.execute( """INSERT INTO api_usage (timestamp, model, provider, prompt_tokens, completion_tokens, total_tokens, estimated_cost, conv_id, agent_name) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)""", (now, model, provider, prompt_tokens, completion_tokens, total_tokens, estimated_cost, conv_id, agent_name), ) self._conn.commit() def get_api_usage_summary(self, days: int = 30) -> dict: """Return total tokens, total cost, and per-model breakdown for the period.""" cutoff = datetime.now(UTC).isoformat()[:10] # today # Compute cutoff date from datetime import timedelta cutoff_dt = datetime.now(UTC) - timedelta(days=days) cutoff = cutoff_dt.isoformat() row = self._conn.execute( "SELECT COALESCE(SUM(prompt_tokens), 0) as prompt_tokens," " COALESCE(SUM(completion_tokens), 0) as completion_tokens," " COALESCE(SUM(total_tokens), 0) as total_tokens," " COALESCE(SUM(estimated_cost), 0.0) as total_cost" " FROM api_usage WHERE timestamp >= ?", (cutoff,), ).fetchone() model_rows = self._conn.execute( "SELECT model," " COALESCE(SUM(prompt_tokens), 0) as prompt_tokens," " COALESCE(SUM(completion_tokens), 0) as completion_tokens," " COALESCE(SUM(total_tokens), 0) as total_tokens," " COALESCE(SUM(estimated_cost), 0.0) as total_cost," " COUNT(*) as call_count" " FROM api_usage WHERE timestamp >= ?" " GROUP BY model ORDER BY total_cost DESC", (cutoff,), ).fetchall() return { "prompt_tokens": row["prompt_tokens"], "completion_tokens": row["completion_tokens"], "total_tokens": row["total_tokens"], "total_cost": row["total_cost"], "by_model": [dict(r) for r in model_rows], } def get_api_usage_daily(self, days: int = 7) -> list[dict]: """Return daily totals for trending.""" from datetime import timedelta cutoff_dt = datetime.now(UTC) - timedelta(days=days) cutoff = cutoff_dt.isoformat() rows = self._conn.execute( "SELECT DATE(timestamp) as day," " COALESCE(SUM(total_tokens), 0) as total_tokens," " COALESCE(SUM(estimated_cost), 0.0) as total_cost," " COUNT(*) as call_count" " FROM api_usage WHERE timestamp >= ?" " GROUP BY DATE(timestamp) ORDER BY day ASC", (cutoff,), ).fetchall() return [dict(r) for r in rows] def _now() -> str: return datetime.now(UTC).isoformat()