# CheddahBot/cheddahbot/db.py (188 lines, 6.3 KiB, Python)

"""SQLite persistence layer."""
from __future__ import annotations
import json
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path
class Database:
    """SQLite-backed store for conversations, messages, scheduled tasks and key-value pairs.

    Each thread that uses an instance gets its own lazily opened
    ``sqlite3.Connection`` (kept in a ``threading.local``).  Every write
    method runs inside ``with conn:`` so the transaction is committed on
    success and rolled back when a statement raises; without the rollback a
    failed write would leave the thread's cached connection stuck inside an
    open transaction.

    Timestamps are produced by the module-level ``_now()`` helper and stored
    as TEXT.
    """

    def __init__(self, db_path: Path):
        """Remember *db_path* and create the tables and index if they are missing."""
        self._path = db_path
        self._local = threading.local()
        self._init_schema()

    @property
    def _conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, opening it on first access.

        New connections return ``sqlite3.Row`` rows, switch the journal to WAL
        and enable foreign-key enforcement.
        """
        conn = getattr(self._local, "conn", None)
        if conn is None:
            conn = sqlite3.connect(str(self._path))
            conn.row_factory = sqlite3.Row
            conn.execute("PRAGMA journal_mode=WAL")
            conn.execute("PRAGMA foreign_keys=ON")
            # Cache only once fully configured, so a failed PRAGMA cannot
            # leave a half-initialised connection behind for this thread.
            self._local.conn = conn
        return conn

    def _init_schema(self) -> None:
        """Create every table and index with ``IF NOT EXISTS`` semantics."""
        conn = self._conn
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS conversations (
                id TEXT PRIMARY KEY,
                title TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conv_id TEXT NOT NULL REFERENCES conversations(id),
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                tool_calls TEXT,
                tool_result TEXT,
                model TEXT,
                created_at TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_messages_conv ON messages(conv_id, created_at);
            CREATE TABLE IF NOT EXISTS scheduled_tasks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                prompt TEXT NOT NULL,
                schedule TEXT NOT NULL,
                enabled INTEGER NOT NULL DEFAULT 1,
                next_run TEXT,
                created_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS task_run_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                task_id INTEGER NOT NULL REFERENCES scheduled_tasks(id),
                started_at TEXT NOT NULL,
                finished_at TEXT,
                result TEXT,
                error TEXT
            );
            CREATE TABLE IF NOT EXISTS kv_store (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL
            );
        """)
        conn.commit()

    # -- Conversations --

    def create_conversation(self, conv_id: str, title: str = "New Chat") -> str:
        """Insert a conversation row and return *conv_id*.

        ``created_at`` and ``updated_at`` are both set to the current time.
        Raises ``sqlite3.IntegrityError`` if *conv_id* already exists (it is
        the primary key); the transaction is rolled back in that case.
        """
        now = _now()
        conn = self._conn
        with conn:
            conn.execute(
                "INSERT INTO conversations (id, title, created_at, updated_at) VALUES (?, ?, ?, ?)",
                (conv_id, title, now, now),
            )
        return conv_id

    def list_conversations(self, limit: int = 50) -> list[dict]:
        """Return up to *limit* conversations, most recently updated first.

        Each dict has the keys ``id``, ``title`` and ``updated_at``.
        """
        rows = self._conn.execute(
            "SELECT id, title, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ?",
            (limit,),
        ).fetchall()
        return [dict(r) for r in rows]

    # -- Messages --

    def add_message(
        self,
        conv_id: str,
        role: str,
        content: str,
        tool_calls: list | None = None,
        tool_result: str | None = None,
        model: str | None = None,
    ) -> int:
        """Append a message to a conversation and return the new row id.

        *tool_calls* is stored as JSON text; a falsy value (``None`` or an
        empty list) is stored as NULL.  The conversation's ``updated_at`` is
        set to the message timestamp in the same transaction, so either both
        changes are applied or neither is.

        Raises ``sqlite3.IntegrityError`` when *conv_id* does not reference an
        existing conversation (foreign keys are enforced on the connection).
        """
        now = _now()
        encoded_calls = json.dumps(tool_calls) if tool_calls else None
        conn = self._conn
        with conn:
            cur = conn.execute(
                """INSERT INTO messages (conv_id, role, content, tool_calls, tool_result, model, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (conv_id, role, content, encoded_calls, tool_result, model, now),
            )
            conn.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id)
            )
        return cur.lastrowid

    def get_messages(self, conv_id: str, limit: int = 100) -> list[dict]:
        """Return the first *limit* messages of a conversation, oldest first.

        Each dict has the keys ``role``, ``content``, ``tool_calls``,
        ``tool_result``, ``model`` and ``created_at``.  A non-empty
        ``tool_calls`` column is decoded from JSON; NULL stays ``None``.
        """
        # ``id`` breaks ties between rows sharing a ``created_at`` value so the
        # result order (and which rows survive LIMIT) is deterministic.
        rows = self._conn.execute(
            """SELECT role, content, tool_calls, tool_result, model, created_at
               FROM messages WHERE conv_id = ? ORDER BY created_at ASC, id ASC LIMIT ?""",
            (conv_id, limit),
        ).fetchall()
        result = []
        for row in rows:
            msg = dict(row)
            if msg["tool_calls"]:
                msg["tool_calls"] = json.loads(msg["tool_calls"])
            result.append(msg)
        return result

    def count_messages(self, conv_id: str) -> int:
        """Return how many messages are stored for *conv_id* (0 if none)."""
        row = self._conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE conv_id = ?", (conv_id,)
        ).fetchone()
        return row["cnt"]

    # -- Scheduled Tasks --

    def add_scheduled_task(self, name: str, prompt: str, schedule: str) -> int:
        """Insert a scheduled task and return its row id.

        ``enabled`` takes the column default (1) and ``next_run`` starts as
        NULL, so the task is reported by ``get_due_tasks`` until a next run
        time is recorded.
        """
        now = _now()
        conn = self._conn
        with conn:
            cur = conn.execute(
                "INSERT INTO scheduled_tasks (name, prompt, schedule, created_at) VALUES (?, ?, ?, ?)",
                (name, prompt, schedule, now),
            )
        return cur.lastrowid

    def get_due_tasks(self) -> list[dict]:
        """Return enabled tasks whose ``next_run`` is NULL or not after now.

        The comparison is done by SQLite on the stored TEXT values.
        NOTE(review): this presumably relies on ``next_run`` being written in
        the same format as ``_now()`` produces - confirm against callers of
        ``update_task_next_run``.
        """
        now = _now()
        rows = self._conn.execute(
            "SELECT * FROM scheduled_tasks WHERE enabled = 1 AND (next_run IS NULL OR next_run <= ?)",
            (now,),
        ).fetchall()
        return [dict(r) for r in rows]

    def update_task_next_run(self, task_id: int, next_run: str) -> None:
        """Store *next_run* on the task with id *task_id*."""
        conn = self._conn
        with conn:
            conn.execute(
                "UPDATE scheduled_tasks SET next_run = ? WHERE id = ?", (next_run, task_id)
            )

    def log_task_run(self, task_id: int, result: str | None = None, error: str | None = None) -> None:
        """Record one run of a task with its *result* and/or *error* text.

        ``started_at`` and ``finished_at`` are both set to the time of this
        call.  Raises ``sqlite3.IntegrityError`` when *task_id* does not
        reference an existing task.
        """
        now = _now()
        conn = self._conn
        with conn:
            conn.execute(
                "INSERT INTO task_run_logs (task_id, started_at, finished_at, result, error) VALUES (?, ?, ?, ?, ?)",
                (task_id, now, now, result, error),
            )

    # -- Key-Value Store --

    def kv_set(self, key: str, value: str) -> None:
        """Store *value* under *key*, replacing any existing entry.

        Raises ``sqlite3.IntegrityError`` if *value* is ``None`` (the column
        is NOT NULL); the previous entry is left untouched in that case.
        """
        conn = self._conn
        with conn:
            conn.execute(
                "INSERT OR REPLACE INTO kv_store (key, value) VALUES (?, ?)", (key, value)
            )

    def kv_get(self, key: str) -> str | None:
        """Return the value stored under *key*, or ``None`` if it is absent."""
        row = self._conn.execute("SELECT value FROM kv_store WHERE key = ?", (key,)).fetchone()
        return row["value"] if row else None
def _now() -> str:
    """Return the current UTC time as a timezone-aware ISO 8601 string."""
    return datetime.now(tz=timezone.utc).isoformat()