1.2: Fix thread safety in memory.py embedding DB

Replace 4 standalone sqlite3.connect()/conn.close() pairs with a
thread-local _embed_conn property, matching the pattern in db.py.
Adds WAL mode for better concurrent read/write performance.

This prevents potential collisions between scheduler threads and
Gradio request threads accessing the embedding database.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
cora-start
PeninsulaInd 2026-02-17 09:57:41 -06:00
parent 0bef1e71b3
commit 9002fc08d2
1 changed files with 16 additions and 15 deletions

View File

@ -30,6 +30,7 @@ class MemorySystem:
self._embedder = None self._embedder = None
self._embed_lock = threading.Lock() self._embed_lock = threading.Lock()
self._embed_db_path = self.memory_dir / "embeddings.db" self._embed_db_path = self.memory_dir / "embeddings.db"
self._embed_local = threading.local()
self._init_embed_db() self._init_embed_db()
# ── Public API ── # ── Public API ──
@ -163,17 +164,23 @@ class MemorySystem:
# ── Private: Embedding system ── # ── Private: Embedding system ──
@property
def _embed_conn(self) -> sqlite3.Connection:
"""Thread-local SQLite connection for embeddings DB (matches db.py pattern)."""
if not hasattr(self._embed_local, "conn"):
self._embed_local.conn = sqlite3.connect(str(self._embed_db_path))
self._embed_local.conn.execute("PRAGMA journal_mode=WAL")
return self._embed_local.conn
def _init_embed_db(self): def _init_embed_db(self):
conn = sqlite3.connect(str(self._embed_db_path)) self._embed_conn.execute("""
conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings ( CREATE TABLE IF NOT EXISTS embeddings (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
text TEXT NOT NULL, text TEXT NOT NULL,
vector BLOB NOT NULL vector BLOB NOT NULL
) )
""") """)
conn.commit() self._embed_conn.commit()
conn.close()
def _get_embedder(self): def _get_embedder(self):
if self._embedder is not None: if self._embedder is not None:
@ -200,18 +207,14 @@ class MemorySystem:
if embedder is None: if embedder is None:
return return
vec = embedder.encode([text])[0] vec = embedder.encode([text])[0]
conn = sqlite3.connect(str(self._embed_db_path)) self._embed_conn.execute(
conn.execute(
"INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)", "INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
(doc_id, text, vec.tobytes()), (doc_id, text, vec.tobytes()),
) )
conn.commit() self._embed_conn.commit()
conn.close()
def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]: def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
conn = sqlite3.connect(str(self._embed_db_path)) rows = self._embed_conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
conn.close()
if not rows: if not rows:
return [] return []
@ -228,10 +231,8 @@ class MemorySystem:
return scored[:top_k] return scored[:top_k]
def _clear_embeddings(self): def _clear_embeddings(self):
conn = sqlite3.connect(str(self._embed_db_path)) self._embed_conn.execute("DELETE FROM embeddings")
conn.execute("DELETE FROM embeddings") self._embed_conn.commit()
conn.commit()
conn.close()
def _fallback_search(self, query: str, top_k: int) -> list[dict]: def _fallback_search(self, query: str, top_k: int) -> list[dict]:
"""Simple keyword search when embeddings are unavailable.""" """Simple keyword search when embeddings are unavailable."""