"""4-layer memory system. Layers: 1. Identity - SOUL.md + USER.md (handled by router.py) 2. Long-term - memory/MEMORY.md (learned facts, decisions) 3. Daily logs - memory/YYYY-MM-DD.md (timestamped entries) 4. Semantic - memory/embeddings.db (vector search over all memory) """ from __future__ import annotations import logging import sqlite3 import threading from datetime import UTC, datetime import numpy as np from .config import Config from .db import Database log = logging.getLogger(__name__) class MemorySystem: def __init__(self, config: Config, db: Database): self.config = config self.db = db self.memory_dir = config.memory_dir self._embedder = None self._embed_lock = threading.Lock() self._embed_db_path = self.memory_dir / "embeddings.db" self._embed_local = threading.local() self._init_embed_db() # ── Public API ── def get_context(self, query: str) -> str: """Build memory context string for the system prompt.""" parts = [] # Long-term memory lt = self._read_long_term() if lt: parts.append(f"## Long-Term Memory\n{lt}") # Today's log today_log = self._read_daily_log() if today_log: parts.append(f"## Today's Log\n{today_log}") # Semantic search results if query: results = self.search(query, top_k=self.config.memory.search_top_k) if results: formatted = "\n".join(f"- {r['text']}" for r in results) parts.append(f"## Related Memories\n{formatted}") return "\n\n".join(parts) if parts else "" def remember(self, text: str): """Save a fact/instruction to long-term memory.""" memory_path = self.memory_dir / "MEMORY.md" timestamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M") entry = f"\n- [{timestamp}] {text}\n" if memory_path.exists(): content = memory_path.read_text(encoding="utf-8") else: content = "# Long-Term Memory\n" content += entry memory_path.write_text(content, encoding="utf-8") self._index_text(text, f"memory:long_term:{timestamp}") log.info("Saved to long-term memory: %s", text[:80]) def log_daily(self, text: str): """Append an entry to today's daily log.""" today = datetime.now(UTC).strftime("%Y-%m-%d") log_path = self.memory_dir / f"{today}.md" timestamp = datetime.now(UTC).strftime("%H:%M") if log_path.exists(): content = log_path.read_text(encoding="utf-8") else: content = f"# Daily Log - {today}\n" content += f"\n- [{timestamp}] {text}\n" log_path.write_text(content, encoding="utf-8") self._index_text(text, f"daily:{today}:{timestamp}") def search(self, query: str, top_k: int = 5) -> list[dict]: """Semantic search over all indexed memory.""" embedder = self._get_embedder() if embedder is None: return self._fallback_search(query, top_k) query_vec = embedder.encode([query])[0] return self._vector_search(query_vec, top_k) def auto_flush(self, conv_id: str): """Summarize old messages and move to daily log, then delete flushed messages.""" messages = self.db.get_messages(conv_id, limit=200) if len(messages) < self.config.memory.flush_threshold: return # Take older messages for summarization, keep last 10 in context to_summarize = messages[:-10] if not to_summarize: return # Build a concise summary (skip tool results, keep user/assistant text) summary_parts = [] for m in to_summarize: role = m.get("role", "") content = (m.get("content") or "").strip() if not content or role == "tool": continue summary_parts.append(f"{role}: {content[:150]}") if not summary_parts: return summary = f"Conversation summary ({len(to_summarize)} messages):\n" + "\n".join( summary_parts[:20] ) self.log_daily(summary) # Delete the flushed messages from DB so they don't get re-flushed flushed_ids = [m["id"] for m in to_summarize if "id" in m] if flushed_ids: self.db.delete_messages(flushed_ids) log.info("Auto-flushed %d messages from conv %s", len(to_summarize), conv_id) def reindex_all(self): """Rebuild the embedding index from all memory files.""" self._clear_embeddings() for path in self.memory_dir.glob("*.md"): content = path.read_text(encoding="utf-8") for i, line in enumerate(content.split("\n")): line = line.strip().lstrip("- ") if len(line) > 10: self._index_text(line, f"file:{path.name}:L{i}") log.info("Reindexed all memory files") # ── Private: Long-term memory ── def _read_long_term(self) -> str: path = self.memory_dir / "MEMORY.md" if path.exists(): content = path.read_text(encoding="utf-8") # Return last 2000 chars to keep prompt manageable return content[-2000:] if len(content) > 2000 else content return "" def _read_daily_log(self) -> str: today = datetime.now(UTC).strftime("%Y-%m-%d") path = self.memory_dir / f"{today}.md" if path.exists(): content = path.read_text(encoding="utf-8") return content[-1500:] if len(content) > 1500 else content return "" # ── Private: Embedding system ── @property def _embed_conn(self) -> sqlite3.Connection: """Thread-local SQLite connection for embeddings DB (matches db.py pattern).""" if not hasattr(self._embed_local, "conn"): self._embed_local.conn = sqlite3.connect(str(self._embed_db_path)) self._embed_local.conn.execute("PRAGMA journal_mode=WAL") return self._embed_local.conn def _init_embed_db(self): self._embed_conn.execute(""" CREATE TABLE IF NOT EXISTS embeddings ( id TEXT PRIMARY KEY, text TEXT NOT NULL, vector BLOB NOT NULL ) """) self._embed_conn.commit() def _get_embedder(self): if self._embedder is not None: return self._embedder with self._embed_lock: if self._embedder is not None: return self._embedder try: from sentence_transformers import SentenceTransformer model_name = self.config.memory.embedding_model log.info("Loading embedding model: %s", model_name) self._embedder = SentenceTransformer(model_name) return self._embedder except ImportError: log.warning("sentence-transformers not installed; semantic search disabled") return None except Exception as e: log.warning("Failed to load embedding model: %s", e) return None def _index_text(self, text: str, doc_id: str): embedder = self._get_embedder() if embedder is None: return vec = embedder.encode([text])[0] self._embed_conn.execute( "INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)", (doc_id, text, vec.tobytes()), ) self._embed_conn.commit() def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]: rows = self._embed_conn.execute("SELECT id, text, vector FROM embeddings").fetchall() if not rows: return [] scored = [] for doc_id, text, vec_bytes in rows: vec = np.frombuffer(vec_bytes, dtype=np.float32) sim = float( np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec) + 1e-8) ) scored.append({"id": doc_id, "text": text, "score": sim}) scored.sort(key=lambda x: x["score"], reverse=True) return scored[:top_k] def _clear_embeddings(self): self._embed_conn.execute("DELETE FROM embeddings") self._embed_conn.commit() def _fallback_search(self, query: str, top_k: int) -> list[dict]: """Simple keyword search when embeddings are unavailable.""" results = [] query_lower = query.lower() for path in self.memory_dir.glob("*.md"): try: content = path.read_text(encoding="utf-8") except Exception: continue for line in content.split("\n"): stripped = line.strip().lstrip("- ") if len(stripped) > 10 and query_lower in stripped.lower(): results.append({"id": path.name, "text": stripped, "score": 1.0}) if len(results) >= top_k: return results return results