260 lines
9.2 KiB
Python
260 lines
9.2 KiB
Python
"""SQLite persistence layer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import json
|
|
import sqlite3
|
|
import threading
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
|
|
|
|
class Database:
    """SQLite persistence for conversations, messages, scheduled tasks, a key-value
    store, and notifications.

    Each thread lazily opens its own ``sqlite3.Connection`` (cached on a
    ``threading.local``), so one ``Database`` instance can be shared by several
    threads without sharing a connection. This class never closes those
    connections explicitly.
    """

    # Upper bound on bound parameters per statement when building ``IN (?, ...)``
    # lists; SQLite's default SQLITE_MAX_VARIABLE_NUMBER is 999 before 3.32.0.
    _MAX_SQL_PARAMS = 500

    def __init__(self, db_path: Path):
        """Open (or create) the database at ``db_path`` and ensure the schema exists.

        Args:
            db_path: Filesystem path of the SQLite database file.
        """
        self._path = db_path
        self._local = threading.local()
        self._init_schema()

    @property
    def _conn(self) -> sqlite3.Connection:
        """Return the calling thread's connection, creating it on first use."""
        if not hasattr(self._local, "conn"):
            conn = sqlite3.connect(str(self._path))
            conn.row_factory = sqlite3.Row  # rows support dict(row) and row["col"]
            conn.execute("PRAGMA journal_mode=WAL")
            # SQLite only enforces REFERENCES clauses when enabled per connection.
            conn.execute("PRAGMA foreign_keys=ON")
            # Cache only after configuration succeeded, so a failed PRAGMA is retried
            # on the next access instead of leaving a half-configured connection.
            self._local.conn = conn
        return self._local.conn

    def _init_schema(self):
        """Create all tables and indexes if missing, then apply column migrations."""
        self._conn.executescript("""
            CREATE TABLE IF NOT EXISTS conversations (
                id TEXT PRIMARY KEY,
                title TEXT,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                conv_id TEXT NOT NULL REFERENCES conversations(id),
                role TEXT NOT NULL,
                content TEXT NOT NULL,
                tool_calls TEXT,
                tool_result TEXT,
                model TEXT,
                created_at TEXT NOT NULL
            );
            CREATE INDEX IF NOT EXISTS idx_messages_conv ON messages(conv_id, created_at);

            CREATE TABLE IF NOT EXISTS scheduled_tasks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                prompt TEXT NOT NULL,
                schedule TEXT NOT NULL,
                enabled INTEGER NOT NULL DEFAULT 1,
                next_run TEXT,
                created_at TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS task_run_logs (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                task_id INTEGER NOT NULL REFERENCES scheduled_tasks(id),
                started_at TEXT NOT NULL,
                finished_at TEXT,
                result TEXT,
                error TEXT
            );
            CREATE TABLE IF NOT EXISTS kv_store (
                key TEXT PRIMARY KEY,
                value TEXT NOT NULL
            );
            CREATE TABLE IF NOT EXISTS notifications (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                message TEXT NOT NULL,
                category TEXT NOT NULL DEFAULT 'clickup',
                created_at TEXT NOT NULL
            );
        """)
        # Migration: add agent_name column to conversations (idempotent).
        # SQLite has no "ADD COLUMN IF NOT EXISTS"; re-running raises
        # OperationalError (duplicate column), which is suppressed here.
        # NOTE(review): this also hides any other OperationalError (e.g. a locked
        # database) raised by the ALTER — confirm that is acceptable.
        with contextlib.suppress(sqlite3.OperationalError):
            self._conn.execute(
                "ALTER TABLE conversations ADD COLUMN agent_name TEXT DEFAULT 'default'"
            )
        self._conn.commit()

    # -- Conversations --

    def create_conversation(
        self, conv_id: str, title: str = "New Chat", agent_name: str = "default"
    ) -> str:
        """Insert a new conversation and return its id.

        Args:
            conv_id: Caller-chosen primary key.
            title: Display title.
            agent_name: Agent that owns the conversation.

        Returns:
            ``conv_id``, unchanged.

        Raises:
            sqlite3.IntegrityError: if a conversation with ``conv_id`` already exists.
        """
        now = _now()
        self._conn.execute(
            "INSERT INTO conversations (id, title, created_at, updated_at, agent_name)"
            " VALUES (?, ?, ?, ?, ?)",
            (conv_id, title, now, now, agent_name),
        )
        self._conn.commit()
        return conv_id

    def list_conversations(
        self, limit: int = 50, agent_name: str | None = None
    ) -> list[dict]:
        """Return up to ``limit`` conversations, most recently updated first.

        A falsy ``agent_name`` (``None`` or ``""``) disables the agent filter.
        Each dict has the keys ``id``, ``title``, ``updated_at``, and ``agent_name``.
        """
        if agent_name:
            rows = self._conn.execute(
                "SELECT id, title, updated_at, agent_name FROM conversations"
                " WHERE agent_name = ? ORDER BY updated_at DESC LIMIT ?",
                (agent_name, limit),
            ).fetchall()
        else:
            rows = self._conn.execute(
                "SELECT id, title, updated_at, agent_name FROM conversations"
                " ORDER BY updated_at DESC LIMIT ?",
                (limit,),
            ).fetchall()
        return [dict(r) for r in rows]

    # -- Messages --

    def add_message(
        self,
        conv_id: str,
        role: str,
        content: str,
        tool_calls: list | None = None,
        tool_result: str | None = None,
        model: str | None = None,
    ) -> int:
        """Append a message to ``conv_id`` and bump the conversation's ``updated_at``.

        ``tool_calls`` is stored as JSON. An empty list is stored as NULL, the same
        as ``None``.

        Returns:
            The new message's row id.
        """
        now = _now()
        conn = self._conn
        # ``with conn`` commits both statements together, or rolls both back if
        # either fails, so the insert and the updated_at bump cannot diverge.
        with conn:
            cur = conn.execute(
                """INSERT INTO messages
                   (conv_id, role, content, tool_calls, tool_result, model, created_at)
                   VALUES (?, ?, ?, ?, ?, ?, ?)""",
                (
                    conv_id,
                    role,
                    content,
                    json.dumps(tool_calls) if tool_calls else None,
                    tool_result,
                    model,
                    now,
                ),
            )
            conn.execute(
                "UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id)
            )
        return cur.lastrowid

    def get_messages(self, conv_id: str, limit: int = 100) -> list[dict]:
        """Return up to ``limit`` messages of ``conv_id`` in chronological order.

        ``tool_calls`` is JSON-decoded when present. Messages sharing a
        ``created_at`` value are ordered by insertion id, so the order is always
        deterministic.

        NOTE(review): LIMIT is applied after the ascending sort, so a long
        conversation yields its *oldest* ``limit`` messages — confirm intended.
        """
        rows = self._conn.execute(
            """SELECT id, role, content, tool_calls, tool_result, model, created_at
               FROM messages WHERE conv_id = ? ORDER BY created_at ASC, id ASC LIMIT ?""",
            (conv_id, limit),
        ).fetchall()
        result = []
        for r in rows:
            msg = dict(r)
            if msg["tool_calls"]:
                msg["tool_calls"] = json.loads(msg["tool_calls"])
            result.append(msg)
        return result

    def count_messages(self, conv_id: str) -> int:
        """Return the number of messages stored for ``conv_id``."""
        row = self._conn.execute(
            "SELECT COUNT(*) as cnt FROM messages WHERE conv_id = ?", (conv_id,)
        ).fetchone()
        return row["cnt"]

    def delete_messages(self, message_ids: list[int]):
        """Delete messages by their IDs (used by auto_flush).

        IDs are deleted in batches of at most ``_MAX_SQL_PARAMS`` so long lists do
        not exceed SQLite's bound-parameter limit. All batches commit together, or
        roll back together if any batch fails. An empty list is a no-op.
        """
        ids = list(message_ids)
        if not ids:
            return
        conn = self._conn
        step = self._MAX_SQL_PARAMS
        with conn:
            for start in range(0, len(ids), step):
                batch = ids[start : start + step]
                placeholders = ",".join("?" * len(batch))
                conn.execute(
                    f"DELETE FROM messages WHERE id IN ({placeholders})", batch
                )

    # -- Scheduled Tasks --

    def add_scheduled_task(self, name: str, prompt: str, schedule: str) -> int:
        """Insert a new (enabled, never-run) scheduled task and return its id."""
        now = _now()
        cur = self._conn.execute(
            "INSERT INTO scheduled_tasks (name, prompt, schedule, created_at) VALUES (?, ?, ?, ?)",
            (name, prompt, schedule, now),
        )
        self._conn.commit()
        return cur.lastrowid

    def get_due_tasks(self) -> list[dict]:
        """Return enabled tasks whose ``next_run`` is NULL or not after the current time.

        ``next_run`` is compared to the current timestamp as a *string*. This
        presumably requires callers to store it in the same UTC ISO-8601 form —
        verify against callers.
        """
        now = _now()
        rows = self._conn.execute(
            "SELECT * FROM scheduled_tasks"
            " WHERE enabled = 1 AND (next_run IS NULL OR next_run <= ?)",
            (now,),
        ).fetchall()
        return [dict(r) for r in rows]

    def update_task_next_run(self, task_id: int, next_run: str):
        """Set the ``next_run`` timestamp string of ``task_id``."""
        self._conn.execute(
            "UPDATE scheduled_tasks SET next_run = ? WHERE id = ?", (next_run, task_id)
        )
        self._conn.commit()

    def disable_task(self, task_id: int):
        """Disable a scheduled task (e.g. after a one-time task has run)."""
        self._conn.execute("UPDATE scheduled_tasks SET enabled = 0 WHERE id = ?", (task_id,))
        self._conn.commit()

    def log_task_run(self, task_id: int, result: str | None = None, error: str | None = None):
        """Record a run of ``task_id`` with its result and/or error.

        ``started_at`` and ``finished_at`` are both set to the time of this call,
        so the run's duration is not captured.
        """
        now = _now()
        self._conn.execute(
            "INSERT INTO task_run_logs"
            " (task_id, started_at, finished_at, result, error)"
            " VALUES (?, ?, ?, ?, ?)",
            (task_id, now, now, result, error),
        )
        self._conn.commit()

    # -- Key-Value Store --

    def kv_set(self, key: str, value: str):
        """Insert or overwrite ``key`` with ``value``."""
        self._conn.execute(
            "INSERT OR REPLACE INTO kv_store (key, value) VALUES (?, ?)", (key, value)
        )
        self._conn.commit()

    def kv_get(self, key: str) -> str | None:
        """Return the value stored for ``key``, or ``None`` if it is absent."""
        row = self._conn.execute("SELECT value FROM kv_store WHERE key = ?", (key,)).fetchone()
        return row["value"] if row else None

    def kv_scan(self, prefix: str) -> list[tuple[str, str]]:
        """Return all key-value pairs where key starts with prefix.

        The match is literal and case-sensitive. ``LIKE`` is deliberately avoided:
        it would treat ``%`` and ``_`` inside ``prefix`` as wildcards and compare
        ASCII letters case-insensitively. An empty prefix returns every pair.
        """
        rows = self._conn.execute(
            "SELECT key, value FROM kv_store WHERE substr(key, 1, ?) = ?",
            (len(prefix), prefix),
        ).fetchall()
        return [(r["key"], r["value"]) for r in rows]

    # -- Notifications --

    def add_notification(self, message: str, category: str = "clickup") -> int:
        """Insert a notification and return its id."""
        now = _now()
        cur = self._conn.execute(
            "INSERT INTO notifications (message, category, created_at) VALUES (?, ?, ?)",
            (message, category, now),
        )
        self._conn.commit()
        return cur.lastrowid

    def get_max_notification_id(self) -> int:
        """Return the highest notification id, or 0 if the table is empty."""
        row = self._conn.execute("SELECT MAX(id) FROM notifications").fetchone()
        return row[0] or 0

    def get_notifications_after(self, after_id: int = 0, limit: int = 50) -> list[dict]:
        """Get notifications with id > after_id, oldest first, at most ``limit`` rows."""
        rows = self._conn.execute(
            "SELECT id, message, category, created_at FROM notifications"
            " WHERE id > ? ORDER BY id ASC LIMIT ?",
            (after_id, limit),
        ).fetchall()
        return [dict(r) for r in rows]
|
|
|
|
|
|
def _now() -> str:
|
|
return datetime.now(UTC).isoformat()
|