"""4-layer memory system.

Layers:
1. Identity - SOUL.md + USER.md (handled by router.py)
2. Long-term - memory/MEMORY.md (learned facts, decisions)
3. Daily logs - memory/YYYY-MM-DD.md (timestamped entries)
4. Semantic - memory/embeddings.db (vector search over all memory)
"""
|
|
|
|
from __future__ import annotations

import contextlib
import logging
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path

import numpy as np

from .config import Config
from .db import Database
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
class MemorySystem:
|
|
def __init__(self, config: Config, db: Database):
|
|
self.config = config
|
|
self.db = db
|
|
self.memory_dir = config.memory_dir
|
|
self._embedder = None
|
|
self._embed_lock = threading.Lock()
|
|
self._embed_db_path = self.memory_dir / "embeddings.db"
|
|
self._init_embed_db()
|
|
|
|
# ── Public API ──
|
|
|
|
def get_context(self, query: str) -> str:
|
|
"""Build memory context string for the system prompt."""
|
|
parts = []
|
|
|
|
# Long-term memory
|
|
lt = self._read_long_term()
|
|
if lt:
|
|
parts.append(f"## Long-Term Memory\n{lt}")
|
|
|
|
# Today's log
|
|
today_log = self._read_daily_log()
|
|
if today_log:
|
|
parts.append(f"## Today's Log\n{today_log}")
|
|
|
|
# Semantic search results
|
|
if query:
|
|
results = self.search(query, top_k=self.config.memory.search_top_k)
|
|
if results:
|
|
formatted = "\n".join(f"- {r['text']}" for r in results)
|
|
parts.append(f"## Related Memories\n{formatted}")
|
|
|
|
return "\n\n".join(parts) if parts else ""
|
|
|
|
def remember(self, text: str):
|
|
"""Save a fact/instruction to long-term memory."""
|
|
memory_path = self.memory_dir / "MEMORY.md"
|
|
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
|
|
entry = f"\n- [{timestamp}] {text}\n"
|
|
|
|
if memory_path.exists():
|
|
content = memory_path.read_text(encoding="utf-8")
|
|
else:
|
|
content = "# Long-Term Memory\n"
|
|
|
|
content += entry
|
|
memory_path.write_text(content, encoding="utf-8")
|
|
self._index_text(text, f"memory:long_term:{timestamp}")
|
|
log.info("Saved to long-term memory: %s", text[:80])
|
|
|
|
def log_daily(self, text: str):
|
|
"""Append an entry to today's daily log."""
|
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
|
log_path = self.memory_dir / f"{today}.md"
|
|
timestamp = datetime.now(timezone.utc).strftime("%H:%M")
|
|
|
|
if log_path.exists():
|
|
content = log_path.read_text(encoding="utf-8")
|
|
else:
|
|
content = f"# Daily Log - {today}\n"
|
|
|
|
content += f"\n- [{timestamp}] {text}\n"
|
|
log_path.write_text(content, encoding="utf-8")
|
|
self._index_text(text, f"daily:{today}:{timestamp}")
|
|
|
|
def search(self, query: str, top_k: int = 5) -> list[dict]:
|
|
"""Semantic search over all indexed memory."""
|
|
embedder = self._get_embedder()
|
|
if embedder is None:
|
|
return self._fallback_search(query, top_k)
|
|
|
|
query_vec = embedder.encode([query])[0]
|
|
return self._vector_search(query_vec, top_k)
|
|
|
|
def auto_flush(self, conv_id: str):
|
|
"""Summarize old messages and move to daily log."""
|
|
messages = self.db.get_messages(conv_id, limit=200)
|
|
if len(messages) < self.config.memory.flush_threshold:
|
|
return
|
|
|
|
# Take older messages for summarization
|
|
to_summarize = messages[:-10] # keep last 10 in context
|
|
text_block = "\n".join(
|
|
f"{m['role']}: {m['content'][:200]}" for m in to_summarize
|
|
if m.get("content")
|
|
)
|
|
|
|
summary = f"Conversation summary ({len(to_summarize)} messages): {text_block[:1000]}"
|
|
self.log_daily(summary)
|
|
log.info("Auto-flushed %d messages to daily log", len(to_summarize))
|
|
|
|
def reindex_all(self):
|
|
"""Rebuild the embedding index from all memory files."""
|
|
self._clear_embeddings()
|
|
for path in self.memory_dir.glob("*.md"):
|
|
content = path.read_text(encoding="utf-8")
|
|
for i, line in enumerate(content.split("\n")):
|
|
line = line.strip().lstrip("- ")
|
|
if len(line) > 10:
|
|
self._index_text(line, f"file:{path.name}:L{i}")
|
|
log.info("Reindexed all memory files")
|
|
|
|
# ── Private: Long-term memory ──
|
|
|
|
def _read_long_term(self) -> str:
|
|
path = self.memory_dir / "MEMORY.md"
|
|
if path.exists():
|
|
content = path.read_text(encoding="utf-8")
|
|
# Return last 2000 chars to keep prompt manageable
|
|
return content[-2000:] if len(content) > 2000 else content
|
|
return ""
|
|
|
|
def _read_daily_log(self) -> str:
|
|
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
|
path = self.memory_dir / f"{today}.md"
|
|
if path.exists():
|
|
content = path.read_text(encoding="utf-8")
|
|
return content[-1500:] if len(content) > 1500 else content
|
|
return ""
|
|
|
|
# ── Private: Embedding system ──
|
|
|
|
def _init_embed_db(self):
|
|
conn = sqlite3.connect(str(self._embed_db_path))
|
|
conn.execute("""
|
|
CREATE TABLE IF NOT EXISTS embeddings (
|
|
id TEXT PRIMARY KEY,
|
|
text TEXT NOT NULL,
|
|
vector BLOB NOT NULL
|
|
)
|
|
""")
|
|
conn.commit()
|
|
conn.close()
|
|
|
|
def _get_embedder(self):
|
|
if self._embedder is not None:
|
|
return self._embedder
|
|
with self._embed_lock:
|
|
if self._embedder is not None:
|
|
return self._embedder
|
|
try:
|
|
from sentence_transformers import SentenceTransformer
|
|
model_name = self.config.memory.embedding_model
|
|
log.info("Loading embedding model: %s", model_name)
|
|
self._embedder = SentenceTransformer(model_name)
|
|
return self._embedder
|
|
except ImportError:
|
|
log.warning("sentence-transformers not installed; semantic search disabled")
|
|
return None
|
|
except Exception as e:
|
|
log.warning("Failed to load embedding model: %s", e)
|
|
return None
|
|
|
|
def _index_text(self, text: str, doc_id: str):
|
|
embedder = self._get_embedder()
|
|
if embedder is None:
|
|
return
|
|
vec = embedder.encode([text])[0]
|
|
conn = sqlite3.connect(str(self._embed_db_path))
|
|
conn.execute(
|
|
"INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
|
|
(doc_id, text, vec.tobytes()),
|
|
)
|
|
conn.commit()
|
|
conn.close()
|
|
|
|
def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
|
|
conn = sqlite3.connect(str(self._embed_db_path))
|
|
rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
|
|
conn.close()
|
|
|
|
if not rows:
|
|
return []
|
|
|
|
scored = []
|
|
for doc_id, text, vec_bytes in rows:
|
|
vec = np.frombuffer(vec_bytes, dtype=np.float32)
|
|
sim = float(np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec) + 1e-8))
|
|
scored.append({"id": doc_id, "text": text, "score": sim})
|
|
|
|
scored.sort(key=lambda x: x["score"], reverse=True)
|
|
return scored[:top_k]
|
|
|
|
def _clear_embeddings(self):
|
|
conn = sqlite3.connect(str(self._embed_db_path))
|
|
conn.execute("DELETE FROM embeddings")
|
|
conn.commit()
|
|
conn.close()
|
|
|
|
def _fallback_search(self, query: str, top_k: int) -> list[dict]:
|
|
"""Simple keyword search when embeddings are unavailable."""
|
|
results = []
|
|
query_lower = query.lower()
|
|
for path in self.memory_dir.glob("*.md"):
|
|
try:
|
|
content = path.read_text(encoding="utf-8")
|
|
except Exception:
|
|
continue
|
|
for line in content.split("\n"):
|
|
stripped = line.strip().lstrip("- ")
|
|
if len(stripped) > 10 and query_lower in stripped.lower():
|
|
results.append({"id": path.name, "text": stripped, "score": 1.0})
|
|
if len(results) >= top_k:
|
|
return results
|
|
return results
|