From 9002fc08d2f170cb66f251d50a3ab7a552db5de4 Mon Sep 17 00:00:00 2001 From: PeninsulaInd Date: Tue, 17 Feb 2026 09:57:41 -0600 Subject: [PATCH] 1.2: Fix thread safety in memory.py embedding DB Replace 4 standalone sqlite3.connect()/conn.close() pairs with a thread-local _embed_conn property, matching the pattern in db.py. Adds WAL mode for better concurrent read/write performance. This prevents potential collisions between scheduler threads and Gradio request threads accessing the embedding database. Co-Authored-By: Claude Opus 4.6 --- cheddahbot/memory.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/cheddahbot/memory.py b/cheddahbot/memory.py index 1d350fb..98e7d80 100644 --- a/cheddahbot/memory.py +++ b/cheddahbot/memory.py @@ -30,6 +30,7 @@ class MemorySystem: self._embedder = None self._embed_lock = threading.Lock() self._embed_db_path = self.memory_dir / "embeddings.db" + self._embed_local = threading.local() self._init_embed_db() # ── Public API ── @@ -163,17 +164,23 @@ class MemorySystem: # ── Private: Embedding system ── + @property + def _embed_conn(self) -> sqlite3.Connection: + """Thread-local SQLite connection for embeddings DB (matches db.py pattern).""" + if not hasattr(self._embed_local, "conn"): + self._embed_local.conn = sqlite3.connect(str(self._embed_db_path)) + self._embed_local.conn.execute("PRAGMA journal_mode=WAL") + return self._embed_local.conn + def _init_embed_db(self): - conn = sqlite3.connect(str(self._embed_db_path)) - conn.execute(""" + self._embed_conn.execute(""" CREATE TABLE IF NOT EXISTS embeddings ( id TEXT PRIMARY KEY, text TEXT NOT NULL, vector BLOB NOT NULL ) """) - conn.commit() - conn.close() + self._embed_conn.commit() def _get_embedder(self): if self._embedder is not None: @@ -200,18 +207,14 @@ class MemorySystem: if embedder is None: return vec = embedder.encode([text])[0] - conn = sqlite3.connect(str(self._embed_db_path)) - conn.execute( + self._embed_conn.execute( "INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)", (doc_id, text, vec.tobytes()), ) - conn.commit() - conn.close() + self._embed_conn.commit() def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]: - conn = sqlite3.connect(str(self._embed_db_path)) - rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall() - conn.close() + rows = self._embed_conn.execute("SELECT id, text, vector FROM embeddings").fetchall() if not rows: return [] @@ -228,10 +231,8 @@ class MemorySystem: return scored[:top_k] def _clear_embeddings(self): - conn = sqlite3.connect(str(self._embed_db_path)) - conn.execute("DELETE FROM embeddings") - conn.commit() - conn.close() + self._embed_conn.execute("DELETE FROM embeddings") + self._embed_conn.commit() def _fallback_search(self, query: str, top_k: int) -> list[dict]: """Simple keyword search when embeddings are unavailable."""