# CheddahBot/cheddahbot/memory.py

"""4-layer memory system.
Layers:
1. Identity - SOUL.md + USER.md (handled by router.py)
2. Long-term - memory/MEMORY.md (learned facts, decisions)
3. Daily logs - memory/YYYY-MM-DD.md (timestamped entries)
4. Semantic - memory/embeddings.db (vector search over all memory)
"""
from __future__ import annotations

import logging
import sqlite3
import threading
from datetime import UTC, datetime

import numpy as np

from .config import Config
from .db import Database

log = logging.getLogger(__name__)
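
# Minimal usage sketch (hypothetical wiring; Config and Database are constructed
# elsewhere in the package, see config.py and db.py):
#
#     mem = MemorySystem(config, db)
#     mem.remember("User prefers short answers")
#     mem.log_daily("Discussed deployment options")
#     context = mem.get_context("deployment")  # markdown sections for the system prompt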

class MemorySystem:
    def __init__(self, config: Config, db: Database):
        self.config = config
        self.db = db
        self.memory_dir = config.memory_dir
        # Ensure the memory directory exists before touching files or the
        # embeddings DB (sqlite3.connect fails on a missing parent directory).
        self.memory_dir.mkdir(parents=True, exist_ok=True)
        self._embedder = None
        self._embed_lock = threading.Lock()
        self._embed_db_path = self.memory_dir / "embeddings.db"
        self._init_embed_db()
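
    # Note: _embed_lock only guards lazy model loading. Each DB helper opens its
    # own short-lived sqlite3 connection, so no connection is shared across threads.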

    # ── Public API ──

    def get_context(self, query: str) -> str:
        """Build memory context string for the system prompt."""
        parts = []
        # Long-term memory
        lt = self._read_long_term()
        if lt:
            parts.append(f"## Long-Term Memory\n{lt}")
        # Today's log
        today_log = self._read_daily_log()
        if today_log:
            parts.append(f"## Today's Log\n{today_log}")
        # Semantic search results
        if query:
            results = self.search(query, top_k=self.config.memory.search_top_k)
            if results:
                formatted = "\n".join(f"- {r['text']}" for r in results)
                parts.append(f"## Related Memories\n{formatted}")
        return "\n\n".join(parts) if parts else ""

    def remember(self, text: str):
        """Save a fact/instruction to long-term memory."""
        memory_path = self.memory_dir / "MEMORY.md"
        timestamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M")
        entry = f"\n- [{timestamp}] {text}\n"
        if memory_path.exists():
            content = memory_path.read_text(encoding="utf-8")
        else:
            content = "# Long-Term Memory\n"
        content += entry
        memory_path.write_text(content, encoding="utf-8")
        self._index_text(text, f"memory:long_term:{timestamp}")
        log.info("Saved to long-term memory: %s", text[:80])

    def log_daily(self, text: str):
        """Append an entry to today's daily log."""
        now = datetime.now(UTC)  # capture once so date and time can't straddle midnight
        today = now.strftime("%Y-%m-%d")
        log_path = self.memory_dir / f"{today}.md"
        timestamp = now.strftime("%H:%M")
        if log_path.exists():
            content = log_path.read_text(encoding="utf-8")
        else:
            content = f"# Daily Log - {today}\n"
        content += f"\n- [{timestamp}] {text}\n"
        log_path.write_text(content, encoding="utf-8")
        self._index_text(text, f"daily:{today}:{timestamp}")

    def search(self, query: str, top_k: int = 5) -> list[dict]:
        """Semantic search over all indexed memory."""
        embedder = self._get_embedder()
        if embedder is None:
            return self._fallback_search(query, top_k)
        query_vec = embedder.encode([query])[0]
        return self._vector_search(query_vec, top_k)
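
    # Each hit is shaped {"id": str, "text": str, "score": float}. Vector hits
    # are ranked by cosine similarity; the keyword fallback returns a constant
    # score of 1.0.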

    def auto_flush(self, conv_id: str):
        """Summarize old messages and move to daily log, then delete flushed messages."""
        messages = self.db.get_messages(conv_id, limit=200)
        if len(messages) < self.config.memory.flush_threshold:
            return
        # Take older messages for summarization, keep the last 10 in context
        to_summarize = messages[:-10]
        if not to_summarize:
            return
        # Build a concise summary (skip tool results, keep user/assistant text)
        summary_parts = []
        for m in to_summarize:
            role = m.get("role", "")
            content = (m.get("content") or "").strip()
            if not content or role == "tool":
                continue
            summary_parts.append(f"{role}: {content[:150]}")
        if not summary_parts:
            return
        summary = f"Conversation summary ({len(to_summarize)} messages):\n" + "\n".join(
            summary_parts[:20]
        )
        self.log_daily(summary)
        # Delete the flushed messages from the DB so they don't get re-flushed
        flushed_ids = [m["id"] for m in to_summarize if "id" in m]
        if flushed_ids:
            self.db.delete_messages(flushed_ids)
        log.info("Auto-flushed %d messages from conv %s", len(to_summarize), conv_id)

    def reindex_all(self):
        """Rebuild the embedding index from all memory files."""
        self._clear_embeddings()
        for path in self.memory_dir.glob("*.md"):
            content = path.read_text(encoding="utf-8")
            for i, line in enumerate(content.split("\n")):
                line = line.strip().lstrip("- ")
                if len(line) > 10:
                    self._index_text(line, f"file:{path.name}:L{i}")
        log.info("Reindexed all memory files")

    # ── Private: Long-term memory ──

    def _read_long_term(self) -> str:
        path = self.memory_dir / "MEMORY.md"
        if path.exists():
            content = path.read_text(encoding="utf-8")
            # Return the last 2000 chars to keep the prompt manageable
            return content[-2000:] if len(content) > 2000 else content
        return ""

    def _read_daily_log(self) -> str:
        today = datetime.now(UTC).strftime("%Y-%m-%d")
        path = self.memory_dir / f"{today}.md"
        if path.exists():
            content = path.read_text(encoding="utf-8")
            return content[-1500:] if len(content) > 1500 else content
        return ""

    # ── Private: Embedding system ──

    def _init_embed_db(self):
        conn = sqlite3.connect(str(self._embed_db_path))
        conn.execute("""
            CREATE TABLE IF NOT EXISTS embeddings (
                id TEXT PRIMARY KEY,
                text TEXT NOT NULL,
                vector BLOB NOT NULL
            )
        """)
        conn.commit()
        conn.close()
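
    # Vectors are persisted as raw float32 bytes (ndarray.tobytes()) and decoded
    # with np.frombuffer(..., dtype=np.float32) in _vector_search, so stored and
    # query vectors must come from the same embedding model.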

    def _get_embedder(self):
        if self._embedder is not None:
            return self._embedder
        with self._embed_lock:
            if self._embedder is not None:
                return self._embedder
            try:
                from sentence_transformers import SentenceTransformer

                model_name = self.config.memory.embedding_model
                log.info("Loading embedding model: %s", model_name)
                self._embedder = SentenceTransformer(model_name)
                return self._embedder
            except ImportError:
                log.warning("sentence-transformers not installed; semantic search disabled")
                return None
            except Exception as e:
                log.warning("Failed to load embedding model: %s", e)
                return None
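
    # Double-checked locking: the unlocked fast path skips the lock once the model
    # is loaded; the re-check under the lock stops two threads from loading it twice.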

    def _index_text(self, text: str, doc_id: str):
        embedder = self._get_embedder()
        if embedder is None:
            return
        vec = embedder.encode([text])[0]
        conn = sqlite3.connect(str(self._embed_db_path))
        conn.execute(
            "INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
            (doc_id, text, vec.tobytes()),
        )
        conn.commit()
        conn.close()

    def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
        conn = sqlite3.connect(str(self._embed_db_path))
        rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
        conn.close()
        if not rows:
            return []
        scored = []
        for doc_id, text, vec_bytes in rows:
            vec = np.frombuffer(vec_bytes, dtype=np.float32)
            sim = float(
                np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec) + 1e-8)
            )
            scored.append({"id": doc_id, "text": text, "score": sim})
        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]
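
    # Brute-force cosine scan: O(N) per query, fine for a personal memory store of
    # a few thousand entries. A vector index (e.g. FAISS or sqlite-vec) would be
    # the usual next step if the table grows large.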

    def _clear_embeddings(self):
        conn = sqlite3.connect(str(self._embed_db_path))
        conn.execute("DELETE FROM embeddings")
        conn.commit()
        conn.close()

    def _fallback_search(self, query: str, top_k: int) -> list[dict]:
        """Simple keyword search when embeddings are unavailable."""
        results = []
        query_lower = query.lower()
        for path in self.memory_dir.glob("*.md"):
            try:
                content = path.read_text(encoding="utf-8")
            except Exception:
                continue
            for line in content.split("\n"):
                stripped = line.strip().lstrip("- ")
                if len(stripped) > 10 and query_lower in stripped.lower():
                    results.append({"id": path.name, "text": stripped, "score": 1.0})
                    if len(results) >= top_k:
                        return results
        return results