"""4-layer memory system.

Layers:
1. Identity - SOUL.md + USER.md (handled by router.py)
2. Long-term - memory/MEMORY.md (learned facts, decisions)
3. Daily logs - memory/YYYY-MM-DD.md (timestamped entries)
4. Semantic - memory/embeddings.db (vector search over all memory)
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sqlite3
|
|
import threading
|
|
from datetime import UTC, datetime
|
|
|
|
import numpy as np
|
|
|
|
from .config import Config
|
|
from .db import Database
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MemorySystem:
|
|
def __init__(self, config: Config, db: Database):
|
|
self.config = config
|
|
self.db = db
|
|
self.memory_dir = config.memory_dir
|
|
self._embedder = None
|
|
self._embed_lock = threading.Lock()
|
|
self._embed_db_path = self.memory_dir / "embeddings.db"
|
|
self._embed_local = threading.local()
|
|
self._init_embed_db()
|
|
|
|
# ── Public API ──
|
|
|
|
def get_context(self, query: str) -> str:
|
|
"""Build memory context string for the system prompt."""
|
|
parts = []
|
|
|
|
# Long-term memory
|
|
lt = self._read_long_term()
|
|
if lt:
|
|
parts.append(f"## Long-Term Memory\n{lt}")
|
|
|
|
# Today's log
|
|
today_log = self._read_daily_log()
|
|
if today_log:
|
|
parts.append(f"## Today's Log\n{today_log}")
|
|
|
|
# Semantic search results
|
|
if query:
|
|
results = self.search(query, top_k=self.config.memory.search_top_k)
|
|
if results:
|
|
formatted = "\n".join(f"- {r['text']}" for r in results)
|
|
parts.append(f"## Related Memories\n{formatted}")
|
|
|
|
return "\n\n".join(parts) if parts else ""
|
|
|
|
def remember(self, text: str):
    """Persist *text* as a timestamped bullet in MEMORY.md and index it.

    Creates the file with a "# Long-Term Memory" header on first use.
    """
    path = self.memory_dir / "MEMORY.md"
    stamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M")

    existing = (
        path.read_text(encoding="utf-8")
        if path.exists()
        else "# Long-Term Memory\n"
    )
    path.write_text(existing + f"\n- [{stamp}] {text}\n", encoding="utf-8")

    # Make the new fact discoverable via semantic search.
    self._index_text(text, f"memory:long_term:{stamp}")
    log.info("Saved to long-term memory: %s", text[:80])
|
|
|
|
def log_daily(self, text: str):
|
|
"""Append an entry to today's daily log."""
|
|
today = datetime.now(UTC).strftime("%Y-%m-%d")
|
|
log_path = self.memory_dir / f"{today}.md"
|
|
timestamp = datetime.now(UTC).strftime("%H:%M")
|
|
|
|
if log_path.exists():
|
|
content = log_path.read_text(encoding="utf-8")
|
|
else:
|
|
content = f"# Daily Log - {today}\n"
|
|
|
|
content += f"\n- [{timestamp}] {text}\n"
|
|
log_path.write_text(content, encoding="utf-8")
|
|
self._index_text(text, f"daily:{today}:{timestamp}")
|
|
|
|
def search(self, query: str, top_k: int = 5) -> list[dict]:
|
|
"""Semantic search over all indexed memory."""
|
|
embedder = self._get_embedder()
|
|
if embedder is None:
|
|
return self._fallback_search(query, top_k)
|
|
|
|
query_vec = embedder.encode([query])[0]
|
|
return self._vector_search(query_vec, top_k)
|
|
|
|
def auto_flush(self, conv_id: str):
    """Summarize old messages into the daily log, then delete them.

    Keeps the 10 most recent messages of the conversation; older ones are
    condensed into a short text summary (tool results skipped, each entry
    truncated to 150 chars, at most 20 entries) that is appended to
    today's log before the originals are removed from the DB.
    """
    messages = self.db.get_messages(conv_id, limit=200)
    if len(messages) < self.config.memory.flush_threshold:
        return

    older = messages[:-10]  # last 10 stay in live context
    if not older:
        return

    # Keep only user/assistant text; tool output is noise in a summary.
    lines = []
    for msg in older:
        role = msg.get("role", "")
        body = (msg.get("content") or "").strip()
        if body and role != "tool":
            lines.append(f"{role}: {body[:150]}")

    if not lines:
        return

    summary = f"Conversation summary ({len(older)} messages):\n" + "\n".join(
        lines[:20]
    )
    self.log_daily(summary)

    # Remove flushed rows so the next pass doesn't re-summarize them.
    flushed = [msg["id"] for msg in older if "id" in msg]
    if flushed:
        self.db.delete_messages(flushed)

    log.info("Auto-flushed %d messages from conv %s", len(older), conv_id)
|
|
|
|
def reindex_all(self):
    """Rebuild the embedding index from every *.md file in memory_dir."""
    self._clear_embeddings()
    for path in self.memory_dir.glob("*.md"):
        text = path.read_text(encoding="utf-8")
        for lineno, raw in enumerate(text.split("\n")):
            # lstrip("- ") drops any leading run of '-' and ' ' chars
            # (bullet markers); very short lines are not worth indexing.
            entry = raw.strip().lstrip("- ")
            if len(entry) > 10:
                self._index_text(entry, f"file:{path.name}:L{lineno}")
    log.info("Reindexed all memory files")
|
|
|
|
# ── Private: Long-term memory ──
|
|
|
|
def _read_long_term(self) -> str:
|
|
path = self.memory_dir / "MEMORY.md"
|
|
if path.exists():
|
|
content = path.read_text(encoding="utf-8")
|
|
# Return last 2000 chars to keep prompt manageable
|
|
return content[-2000:] if len(content) > 2000 else content
|
|
return ""
|
|
|
|
def _read_daily_log(self) -> str:
|
|
today = datetime.now(UTC).strftime("%Y-%m-%d")
|
|
path = self.memory_dir / f"{today}.md"
|
|
if path.exists():
|
|
content = path.read_text(encoding="utf-8")
|
|
return content[-1500:] if len(content) > 1500 else content
|
|
return ""
|
|
|
|
# ── Private: Embedding system ──
|
|
|
|
@property
|
|
def _embed_conn(self) -> sqlite3.Connection:
|
|
"""Thread-local SQLite connection for embeddings DB (matches db.py pattern)."""
|
|
if not hasattr(self._embed_local, "conn"):
|
|
self._embed_local.conn = sqlite3.connect(str(self._embed_db_path))
|
|
self._embed_local.conn.execute("PRAGMA journal_mode=WAL")
|
|
return self._embed_local.conn
|
|
|
|
def _init_embed_db(self):
|
|
self._embed_conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS embeddings (
|
|
id TEXT PRIMARY KEY,
|
|
text TEXT NOT NULL,
|
|
vector BLOB NOT NULL
|
|
)
|
|
""")
|
|
self._embed_conn.commit()
|
|
|
|
def _get_embedder(self):
    """Return the lazily-loaded SentenceTransformer model, or None.

    Uses double-checked locking: the unlocked fast path avoids lock
    contention once the model is cached; the re-check under the lock
    prevents two threads from loading the model twice. A failure is NOT
    cached, so later calls will retry the import/load.
    """
    if self._embedder is not None:
        return self._embedder
    with self._embed_lock:
        if self._embedder is not None:
            # Another thread finished loading while we waited on the lock.
            return self._embedder
        try:
            from sentence_transformers import SentenceTransformer

            model_name = self.config.memory.embedding_model
            log.info("Loading embedding model: %s", model_name)
            self._embedder = SentenceTransformer(model_name)
            return self._embedder
        except ImportError:
            # Optional dependency missing: search() degrades to keyword mode.
            log.warning("sentence-transformers not installed; semantic search disabled")
            return None
        except Exception as e:
            # e.g. unknown model name or download failure.
            log.warning("Failed to load embedding model: %s", e)
            return None
|
|
|
|
def _index_text(self, text: str, doc_id: str):
|
|
embedder = self._get_embedder()
|
|
if embedder is None:
|
|
return
|
|
vec = embedder.encode([text])[0]
|
|
self._embed_conn.execute(
|
|
"INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
|
|
(doc_id, text, vec.tobytes()),
|
|
)
|
|
self._embed_conn.commit()
|
|
|
|
def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
|
|
rows = self._embed_conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
|
|
|
|
if not rows:
|
|
return []
|
|
|
|
scored = []
|
|
for doc_id, text, vec_bytes in rows:
|
|
vec = np.frombuffer(vec_bytes, dtype=np.float32)
|
|
sim = float(
|
|
np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec) + 1e-8)
|
|
)
|
|
scored.append({"id": doc_id, "text": text, "score": sim})
|
|
|
|
scored.sort(key=lambda x: x["score"], reverse=True)
|
|
return scored[:top_k]
|
|
|
|
def _clear_embeddings(self):
|
|
self._embed_conn.execute("DELETE FROM embeddings")
|
|
self._embed_conn.commit()
|
|
|
|
def _fallback_search(self, query: str, top_k: int) -> list[dict]:
|
|
"""Simple keyword search when embeddings are unavailable."""
|
|
results = []
|
|
query_lower = query.lower()
|
|
for path in self.memory_dir.glob("*.md"):
|
|
try:
|
|
content = path.read_text(encoding="utf-8")
|
|
except Exception:
|
|
continue
|
|
for line in content.split("\n"):
|
|
stripped = line.strip().lstrip("- ")
|
|
if len(stripped) > 10 and query_lower in stripped.lower():
|
|
results.append({"id": path.name, "text": stripped, "score": 1.0})
|
|
if len(results) >= top_k:
|
|
return results
|
|
return results
|