1.2: Fix thread safety in memory.py embedding DB

Replace 4 standalone sqlite3.connect()/conn.close() pairs with a
thread-local _embed_conn property, matching the pattern in db.py.
Adds WAL mode for better concurrent read/write performance.

This prevents potential collisions between scheduler threads and
Gradio request threads accessing the embedding database.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
cora-start
PeninsulaInd 2026-02-17 09:57:41 -06:00
parent 0bef1e71b3
commit 9002fc08d2
1 changed files with 16 additions and 15 deletions

View File

@ -30,6 +30,7 @@ class MemorySystem:
self._embedder = None
self._embed_lock = threading.Lock()
self._embed_db_path = self.memory_dir / "embeddings.db"
self._embed_local = threading.local()
self._init_embed_db()
# ── Public API ──
@ -163,17 +164,23 @@ class MemorySystem:
# ── Private: Embedding system ──
@property
def _embed_conn(self) -> sqlite3.Connection:
"""Thread-local SQLite connection for embeddings DB (matches db.py pattern)."""
if not hasattr(self._embed_local, "conn"):
self._embed_local.conn = sqlite3.connect(str(self._embed_db_path))
self._embed_local.conn.execute("PRAGMA journal_mode=WAL")
return self._embed_local.conn
def _init_embed_db(self):
conn = sqlite3.connect(str(self._embed_db_path))
conn.execute("""
self._embed_conn.execute("""
CREATE TABLE IF NOT EXISTS embeddings (
id TEXT PRIMARY KEY,
text TEXT NOT NULL,
vector BLOB NOT NULL
)
""")
conn.commit()
conn.close()
self._embed_conn.commit()
def _get_embedder(self):
if self._embedder is not None:
@ -200,18 +207,14 @@ class MemorySystem:
if embedder is None:
return
vec = embedder.encode([text])[0]
conn = sqlite3.connect(str(self._embed_db_path))
conn.execute(
self._embed_conn.execute(
"INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
(doc_id, text, vec.tobytes()),
)
conn.commit()
conn.close()
self._embed_conn.commit()
def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
conn = sqlite3.connect(str(self._embed_db_path))
rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
conn.close()
rows = self._embed_conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
if not rows:
return []
@ -228,10 +231,8 @@ class MemorySystem:
return scored[:top_k]
def _clear_embeddings(self):
conn = sqlite3.connect(str(self._embed_db_path))
conn.execute("DELETE FROM embeddings")
conn.commit()
conn.close()
self._embed_conn.execute("DELETE FROM embeddings")
self._embed_conn.commit()
def _fallback_search(self, query: str, top_k: int) -> list[dict]:
"""Simple keyword search when embeddings are unavailable."""