# cheddahbot/memory.py
"""4-layer memory system.
Layers:
1. Identity - SOUL.md + USER.md (handled by router.py)
2. Long-term - memory/MEMORY.md (learned facts, decisions)
3. Daily logs - memory/YYYY-MM-DD.md (timestamped entries)
4. Semantic - memory/embeddings.db (vector search over all memory)
"""
from __future__ import annotations
import logging
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
from .config import Config
from .db import Database
log = logging.getLogger(__name__)
class MemorySystem:
    """Layers 2-4 of the memory stack.

    Layer 2: long-term facts in memory/MEMORY.md.
    Layer 3: daily logs in memory/YYYY-MM-DD.md.
    Layer 4: semantic search over everything, via a sqlite-backed
             embedding index (memory/embeddings.db).

    Layer 1 (identity: SOUL.md / USER.md) is handled by router.py.
    """

    def __init__(self, config: Config, db: Database):
        self.config = config
        self.db = db
        self.memory_dir = config.memory_dir
        # The embedding model is heavy; load it lazily.  The lock guards
        # the one-time initialization against concurrent callers.
        self._embedder = None
        self._embed_lock = threading.Lock()
        # Ensure the directory exists before sqlite tries to create the db
        # file inside it — otherwise connect() raises OperationalError.
        self.memory_dir.mkdir(parents=True, exist_ok=True)
        self._embed_db_path = self.memory_dir / "embeddings.db"
        self._init_embed_db()

    # ── Public API ──

    def get_context(self, query: str) -> str:
        """Build the memory context string for the system prompt.

        Combines, in order: long-term memory tail, today's log tail, and
        (when *query* is non-empty) the top semantic search hits.
        Returns "" when there is nothing to include.
        """
        parts: list[str] = []
        lt = self._read_long_term()
        if lt:
            parts.append(f"## Long-Term Memory\n{lt}")
        today_log = self._read_daily_log()
        if today_log:
            parts.append(f"## Today's Log\n{today_log}")
        if query:
            results = self.search(query, top_k=self.config.memory.search_top_k)
            if results:
                formatted = "\n".join(f"- {r['text']}" for r in results)
                parts.append(f"## Related Memories\n{formatted}")
        return "\n\n".join(parts) if parts else ""

    def remember(self, text: str):
        """Save a fact/instruction to long-term memory and index it."""
        memory_path = self.memory_dir / "MEMORY.md"
        timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
        entry = f"\n- [{timestamp}] {text}\n"
        # Append rather than read-modify-write: O(1) in file size and does
        # not clobber data written by a concurrent process between the
        # read and the write.
        is_new = not memory_path.exists()
        with memory_path.open("a", encoding="utf-8") as f:
            if is_new:
                f.write("# Long-Term Memory\n")
            f.write(entry)
        self._index_text(text, f"memory:long_term:{timestamp}")
        log.info("Saved to long-term memory: %s", text[:80])

    def log_daily(self, text: str):
        """Append a timestamped entry to today's daily log and index it."""
        # Capture the clock once so date and time cannot straddle midnight
        # (the original read now() twice).
        now = datetime.now(timezone.utc)
        today = now.strftime("%Y-%m-%d")
        timestamp = now.strftime("%H:%M")
        log_path = self.memory_dir / f"{today}.md"
        is_new = not log_path.exists()
        with log_path.open("a", encoding="utf-8") as f:
            if is_new:
                f.write(f"# Daily Log - {today}\n")
            f.write(f"\n- [{timestamp}] {text}\n")
        self._index_text(text, f"daily:{today}:{timestamp}")

    def search(self, query: str, top_k: int = 5) -> list[dict]:
        """Semantic search over all indexed memory.

        Falls back to plain keyword matching when the embedding model is
        unavailable.  Returns a list of {"id", "text", "score"} dicts.
        """
        embedder = self._get_embedder()
        if embedder is None:
            return self._fallback_search(query, top_k)
        query_vec = embedder.encode([query])[0]
        return self._vector_search(query_vec, top_k)

    def auto_flush(self, conv_id: str):
        """Summarize old messages and move them to the daily log once the
        conversation exceeds the configured flush threshold."""
        messages = self.db.get_messages(conv_id, limit=200)
        if len(messages) < self.config.memory.flush_threshold:
            return
        to_summarize = messages[:-10]  # keep the last 10 in live context
        if not to_summarize:
            # flush_threshold <= 10: nothing old enough to flush; avoid
            # logging an empty "0 messages" summary.
            return
        text_block = "\n".join(
            f"{m['role']}: {m['content'][:200]}"
            for m in to_summarize
            if m.get("content")
        )
        summary = f"Conversation summary ({len(to_summarize)} messages): {text_block[:1000]}"
        self.log_daily(summary)
        log.info("Auto-flushed %d messages to daily log", len(to_summarize))

    def reindex_all(self):
        """Rebuild the embedding index from every *.md file in memory_dir."""
        self._clear_embeddings()
        for path in self.memory_dir.glob("*.md"):
            content = path.read_text(encoding="utf-8")
            for i, line in enumerate(content.split("\n")):
                line = line.strip().lstrip("- ")
                if len(line) > 10:  # skip blank/heading/trivial lines
                    self._index_text(line, f"file:{path.name}:L{i}")
        log.info("Reindexed all memory files")

    # ── Private: markdown-file layers ──

    def _read_long_term(self) -> str:
        """Return the tail (last 2000 chars) of MEMORY.md, or ""."""
        path = self.memory_dir / "MEMORY.md"
        if not path.exists():
            return ""
        content = path.read_text(encoding="utf-8")
        # Tail only, to keep the prompt size bounded.
        return content[-2000:]

    def _read_daily_log(self) -> str:
        """Return the tail (last 1500 chars) of today's log, or ""."""
        today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        path = self.memory_dir / f"{today}.md"
        if not path.exists():
            return ""
        content = path.read_text(encoding="utf-8")
        return content[-1500:]

    # ── Private: embedding index ──

    def _init_embed_db(self):
        """Create the embeddings table if it does not exist yet."""
        conn = sqlite3.connect(str(self._embed_db_path))
        try:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS embeddings (
                    id TEXT PRIMARY KEY,
                    text TEXT NOT NULL,
                    vector BLOB NOT NULL
                )
                """
            )
            conn.commit()
        finally:
            conn.close()

    def _get_embedder(self):
        """Lazily load the sentence-transformers model (double-checked
        locking).

        Returns None when the package is missing or the model fails to
        load — semantic search then degrades to keyword fallback.
        """
        if self._embedder is not None:
            return self._embedder
        with self._embed_lock:
            if self._embedder is not None:
                return self._embedder
            try:
                from sentence_transformers import SentenceTransformer

                model_name = self.config.memory.embedding_model
                log.info("Loading embedding model: %s", model_name)
                self._embedder = SentenceTransformer(model_name)
                return self._embedder
            except ImportError:
                log.warning("sentence-transformers not installed; semantic search disabled")
                return None
            except Exception as e:
                log.warning("Failed to load embedding model: %s", e)
                return None

    def _index_text(self, text: str, doc_id: str):
        """Embed *text* and upsert it into the index under *doc_id*.
        No-op when embeddings are unavailable."""
        embedder = self._get_embedder()
        if embedder is None:
            return
        # Pin the dtype: _vector_search reads blobs back with
        # np.frombuffer(dtype=float32), so a float64 encoder output would
        # silently corrupt every similarity score.
        vec = np.asarray(embedder.encode([text])[0], dtype=np.float32)
        conn = sqlite3.connect(str(self._embed_db_path))
        try:
            conn.execute(
                "INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
                (doc_id, text, vec.tobytes()),
            )
            conn.commit()
        finally:
            conn.close()

    def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
        """Cosine-similarity scan over all stored vectors.

        Returns at most *top_k* {"id", "text", "score"} dicts, best first.
        """
        conn = sqlite3.connect(str(self._embed_db_path))
        try:
            rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
        finally:
            conn.close()
        if not rows:
            return []
        query_vec = np.asarray(query_vec, dtype=np.float32)
        query_norm = np.linalg.norm(query_vec)  # loop invariant, hoisted
        scored = []
        for doc_id, text, vec_bytes in rows:
            vec = np.frombuffer(vec_bytes, dtype=np.float32)
            # +1e-8 guards against division by zero for degenerate vectors.
            sim = float(np.dot(query_vec, vec) / (query_norm * np.linalg.norm(vec) + 1e-8))
            scored.append({"id": doc_id, "text": text, "score": sim})
        scored.sort(key=lambda x: x["score"], reverse=True)
        return scored[:top_k]

    def _clear_embeddings(self):
        """Delete every row from the embeddings table."""
        conn = sqlite3.connect(str(self._embed_db_path))
        try:
            conn.execute("DELETE FROM embeddings")
            conn.commit()
        finally:
            conn.close()

    def _fallback_search(self, query: str, top_k: int) -> list[dict]:
        """Case-insensitive substring search over *.md files.

        Used when embeddings are unavailable; every hit gets a constant
        score of 1.0.  Stops as soon as *top_k* hits are collected.
        """
        results: list[dict] = []
        query_lower = query.lower()
        for path in self.memory_dir.glob("*.md"):
            try:
                content = path.read_text(encoding="utf-8")
            except Exception:
                # Best-effort: one unreadable file must not kill search.
                continue
            for line in content.split("\n"):
                stripped = line.strip().lstrip("- ")
                if len(stripped) > 10 and query_lower in stripped.lower():
                    results.append({"id": path.name, "text": stripped, "score": 1.0})
                    if len(results) >= top_k:
                        return results
        return results