1.2: Fix thread safety in memory.py embedding DB
Replace 4 standalone sqlite3.connect()/conn.close() pairs with a thread-local _embed_conn property, matching the pattern in db.py. Add WAL mode for better concurrent read/write performance. This prevents potential collisions between scheduler threads and Gradio request threads accessing the embedding database.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

cora-start
parent
0bef1e71b3
commit
9002fc08d2
|
|
@@ -30,6 +30,7 @@ class MemorySystem:
         self._embedder = None
         self._embed_lock = threading.Lock()
         self._embed_db_path = self.memory_dir / "embeddings.db"
+        self._embed_local = threading.local()
         self._init_embed_db()
 
     # ── Public API ──
@@ -163,17 +164,23 @@ class MemorySystem:
 
     # ── Private: Embedding system ──
 
+    @property
+    def _embed_conn(self) -> sqlite3.Connection:
+        """Thread-local SQLite connection for embeddings DB (matches db.py pattern)."""
+        if not hasattr(self._embed_local, "conn"):
+            self._embed_local.conn = sqlite3.connect(str(self._embed_db_path))
+            self._embed_local.conn.execute("PRAGMA journal_mode=WAL")
+        return self._embed_local.conn
+
     def _init_embed_db(self):
-        conn = sqlite3.connect(str(self._embed_db_path))
-        conn.execute("""
+        self._embed_conn.execute("""
             CREATE TABLE IF NOT EXISTS embeddings (
                 id TEXT PRIMARY KEY,
                 text TEXT NOT NULL,
                 vector BLOB NOT NULL
             )
         """)
-        conn.commit()
-        conn.close()
+        self._embed_conn.commit()
 
     def _get_embedder(self):
         if self._embedder is not None:
@@ -200,18 +207,14 @@ class MemorySystem:
         if embedder is None:
             return
         vec = embedder.encode([text])[0]
-        conn = sqlite3.connect(str(self._embed_db_path))
-        conn.execute(
+        self._embed_conn.execute(
             "INSERT OR REPLACE INTO embeddings (id, text, vector) VALUES (?, ?, ?)",
             (doc_id, text, vec.tobytes()),
         )
-        conn.commit()
-        conn.close()
+        self._embed_conn.commit()
 
     def _vector_search(self, query_vec: np.ndarray, top_k: int) -> list[dict]:
-        conn = sqlite3.connect(str(self._embed_db_path))
-        rows = conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
-        conn.close()
+        rows = self._embed_conn.execute("SELECT id, text, vector FROM embeddings").fetchall()
 
         if not rows:
             return []
@@ -228,10 +231,8 @@ class MemorySystem:
         return scored[:top_k]
 
     def _clear_embeddings(self):
-        conn = sqlite3.connect(str(self._embed_db_path))
-        conn.execute("DELETE FROM embeddings")
-        conn.commit()
-        conn.close()
+        self._embed_conn.execute("DELETE FROM embeddings")
+        self._embed_conn.commit()
 
     def _fallback_search(self, query: str, top_k: int) -> list[dict]:
         """Simple keyword search when embeddings are unavailable."""
|
|||
Loading…
Reference in New Issue