1.1: Delete dead code and fix all lint errors

Remove unused modules that were never called at startup:
- cheddahbot/skills/__init__.py (dead @skill decorator system)
- cheddahbot/providers/__init__.py (empty placeholder)
- cheddahbot/tools/build_skill.py (depends on dead skills system)
- cheddahbot/tools/build_tool.py (security risk: generates arbitrary Python)

Also fix all pre-existing ruff lint errors across the codebase:
- Fix import sorting, unused imports, line length violations
- Fix type comparisons (use `is` instead of `==`)
- Fix implicit Optional types (dict -> dict | None)
- Fix unused variables, ambiguous variable names
- Apply ruff format for consistent style

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
cora-start
PeninsulaInd 2026-02-17 09:56:36 -06:00
parent a7171673fc
commit 0bef1e71b3
32 changed files with 577 additions and 486 deletions

View File

@ -1,12 +1,11 @@
"""Entry point: python -m cheddahbot"""
import logging
import sys
from .agent import Agent
from .config import load_config
from .db import Database
from .llm import LLMAdapter
from .agent import Agent
from .ui import create_ui
logging.basicConfig(
@ -35,7 +34,9 @@ def main():
if llm.is_execution_brain_available():
log.info("Execution brain: Claude Code CLI found in PATH")
else:
log.warning("Execution brain: Claude Code CLI NOT found — heartbeat/scheduler tasks will fail")
log.warning(
"Execution brain: Claude Code CLI NOT found — heartbeat/scheduler tasks will fail"
)
log.info("Creating agent...")
agent = Agent(config, db, llm)
@ -43,6 +44,7 @@ def main():
# Phase 2+: Memory system
try:
from .memory import MemorySystem
log.info("Initializing memory system...")
memory = MemorySystem(config, db)
agent.set_memory(memory)
@ -52,6 +54,7 @@ def main():
# Phase 3+: Tool system
try:
from .tools import ToolRegistry
log.info("Initializing tool system...")
tools = ToolRegistry(config, db, agent)
agent.set_tools(tools)
@ -62,6 +65,7 @@ def main():
notification_bus = None
try:
from .notifications import NotificationBus
log.info("Initializing notification bus...")
notification_bus = NotificationBus(db)
except Exception as e:
@ -70,6 +74,7 @@ def main():
# Phase 3+: Scheduler
try:
from .scheduler import Scheduler
log.info("Starting scheduler...")
scheduler = Scheduler(config, db, agent, notification_bus=notification_bus)
scheduler.start()
@ -77,13 +82,12 @@ def main():
log.warning("Scheduler not available: %s", e)
log.info("Launching Gradio UI on %s:%s...", config.host, config.port)
app, css = create_ui(agent, config, llm, notification_bus=notification_bus)
app = create_ui(agent, config, llm, notification_bus=notification_bus)
app.launch(
server_name=config.host,
server_port=config.port,
pwa=True,
show_error=True,
css=css,
)

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import json
import logging
import uuid
from typing import Generator
from collections.abc import Generator
from .config import Config
from .db import Database
@ -69,12 +69,14 @@ class Agent:
# Load conversation history
history = self.db.get_messages(conv_id, limit=self.config.memory.max_context_messages)
messages = format_messages_for_llm(system_prompt, history, self.config.memory.max_context_messages)
messages = format_messages_for_llm(
system_prompt, history, self.config.memory.max_context_messages
)
# Agent loop: LLM call → tool execution → repeat
seen_tool_calls: set[str] = set() # track (name, args_json) to prevent duplicates
for iteration in range(MAX_TOOL_ITERATIONS):
for _iteration in range(MAX_TOOL_ITERATIONS):
full_response = ""
tool_calls = []
@ -88,7 +90,9 @@ class Agent:
# If no tool calls, we're done
if not tool_calls:
if full_response:
self.db.add_message(conv_id, "assistant", full_response, model=self.llm.current_model)
self.db.add_message(
conv_id, "assistant", full_response, model=self.llm.current_model
)
break
# Filter out duplicate tool calls
@ -104,21 +108,30 @@ class Agent:
if not unique_tool_calls:
# All tool calls were duplicates — force the model to respond
if full_response:
self.db.add_message(conv_id, "assistant", full_response, model=self.llm.current_model)
self.db.add_message(
conv_id, "assistant", full_response, model=self.llm.current_model
)
else:
yield "(I already have the information needed to answer.)"
break
# Store assistant message with tool calls
self.db.add_message(
conv_id, "assistant", full_response,
conv_id,
"assistant",
full_response,
tool_calls=[{"name": tc["name"], "input": tc["input"]} for tc in unique_tool_calls],
model=self.llm.current_model,
)
# Execute tools
if self._tools:
messages.append({"role": "assistant", "content": full_response or "I'll use some tools to help with that."})
messages.append(
{
"role": "assistant",
"content": full_response or "I'll use some tools to help with that.",
}
)
for tc in unique_tool_calls:
yield f"\n\n**Using tool: {tc['name']}**\n"
@ -129,11 +142,15 @@ class Agent:
yield f"```\n{result[:2000]}\n```\n\n"
self.db.add_message(conv_id, "tool", result, tool_result=tc["name"])
messages.append({"role": "user", "content": f'[Tool "{tc["name"]}" result]\n{result}'})
messages.append(
{"role": "user", "content": f'[Tool "{tc["name"]}" result]\n{result}'}
)
else:
# No tool system configured - just mention tool was requested
if full_response:
self.db.add_message(conv_id, "assistant", full_response, model=self.llm.current_model)
self.db.add_message(
conv_id, "assistant", full_response, model=self.llm.current_model
)
for tc in unique_tool_calls:
yield f"\n(Tool requested: {tc['name']} - tool system not yet initialized)\n"
break

View File

@ -43,10 +43,9 @@ class ClickUpTask:
options = cf.get("type_config", {}).get("options", [])
order_index = cf_value if isinstance(cf_value, int) else None
for opt in options:
if order_index is not None and opt.get("orderindex") == order_index:
cf_value = opt.get("name", cf_value)
break
elif opt.get("id") == cf_value:
if (
order_index is not None and opt.get("orderindex") == order_index
) or opt.get("id") == cf_value:
cf_value = opt.get("name", cf_value)
break
@ -72,7 +71,9 @@ class ClickUpTask:
class ClickUpClient:
"""Thin wrapper around the ClickUp REST API v2."""
def __init__(self, api_token: str, workspace_id: str = "", task_type_field_name: str = "Task Type"):
def __init__(
self, api_token: str, workspace_id: str = "", task_type_field_name: str = "Task Type"
):
self._token = api_token
self.workspace_id = workspace_id
self._task_type_field_name = task_type_field_name
@ -110,7 +111,9 @@ class ClickUpClient:
tasks_data = resp.json().get("tasks", [])
return [ClickUpTask.from_api(t, self._task_type_field_name) for t in tasks_data]
def get_tasks_from_space(self, space_id: str, statuses: list[str] | None = None) -> list[ClickUpTask]:
def get_tasks_from_space(
self, space_id: str, statuses: list[str] | None = None
) -> list[ClickUpTask]:
"""Traverse all folders and lists in a space to collect tasks."""
all_tasks: list[ClickUpTask] = []
list_ids = set()
@ -142,7 +145,9 @@ class ClickUpClient:
except httpx.HTTPStatusError as e:
log.warning("Failed to fetch tasks from list %s: %s", list_id, e)
log.info("Found %d tasks across %d lists in space %s", len(all_tasks), len(list_ids), space_id)
log.info(
"Found %d tasks across %d lists in space %s", len(all_tasks), len(list_ids), space_id
)
return all_tasks
# ── Write (with retry) ──
@ -164,7 +169,7 @@ class ClickUpClient:
raise
last_exc = e
if attempt < max_attempts:
wait = backoff ** attempt
wait = backoff**attempt
log.warning("Retry %d/%d after %.1fs: %s", attempt, max_attempts, wait, e)
time.sleep(wait)
raise last_exc
@ -172,10 +177,12 @@ class ClickUpClient:
def update_task_status(self, task_id: str, status: str) -> bool:
"""Update a task's status."""
try:
def _call():
resp = self._client.put(f"/task/{task_id}", json={"status": status})
resp.raise_for_status()
return resp
self._retry(_call)
log.info("Updated task %s status to '%s'", task_id, status)
return True
@ -186,6 +193,7 @@ class ClickUpClient:
def add_comment(self, task_id: str, text: str) -> bool:
"""Add a comment to a task."""
try:
def _call():
resp = self._client.post(
f"/task/{task_id}/comment",
@ -193,6 +201,7 @@ class ClickUpClient:
)
resp.raise_for_status()
return resp
self._retry(_call)
log.info("Added comment to task %s", task_id)
return True
@ -212,6 +221,7 @@ class ClickUpClient:
log.warning("Attachment file not found: %s", fp)
return False
try:
def _call():
with open(fp, "rb") as f:
resp = httpx.post(
@ -222,6 +232,7 @@ class ClickUpClient:
)
resp.raise_for_status()
return resp
self._retry(_call)
log.info("Uploaded attachment %s to task %s", fp.name, task_id)
return True

View File

@ -29,7 +29,9 @@ class SchedulerConfig:
@dataclass
class ShellConfig:
blocked_commands: list[str] = field(default_factory=lambda: ["rm -rf /", "format", ":(){:|:&};:"])
blocked_commands: list[str] = field(
default_factory=lambda: ["rm -rf /", "format", ":(){:|:&};:"]
)
require_approval: bool = False

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import json
import sqlite3
import threading
from datetime import datetime, timezone
from datetime import UTC, datetime
from pathlib import Path
@ -105,7 +105,8 @@ class Database:
) -> int:
now = _now()
cur = self._conn.execute(
"""INSERT INTO messages (conv_id, role, content, tool_calls, tool_result, model, created_at)
"""INSERT INTO messages
(conv_id, role, content, tool_calls, tool_result, model, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(
conv_id,
@ -117,9 +118,7 @@ class Database:
now,
),
)
self._conn.execute(
"UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id)
)
self._conn.execute("UPDATE conversations SET updated_at = ? WHERE id = ?", (now, conv_id))
self._conn.commit()
return cur.lastrowid
@ -148,9 +147,7 @@ class Database:
if not message_ids:
return
placeholders = ",".join("?" for _ in message_ids)
self._conn.execute(
f"DELETE FROM messages WHERE id IN ({placeholders})", message_ids
)
self._conn.execute(f"DELETE FROM messages WHERE id IN ({placeholders})", message_ids)
self._conn.commit()
# -- Scheduled Tasks --
@ -167,7 +164,8 @@ class Database:
def get_due_tasks(self) -> list[dict]:
now = _now()
rows = self._conn.execute(
"SELECT * FROM scheduled_tasks WHERE enabled = 1 AND (next_run IS NULL OR next_run <= ?)",
"SELECT * FROM scheduled_tasks"
" WHERE enabled = 1 AND (next_run IS NULL OR next_run <= ?)",
(now,),
).fetchall()
return [dict(r) for r in rows]
@ -180,15 +178,15 @@ class Database:
def disable_task(self, task_id: int):
"""Disable a scheduled task (e.g. after a one-time task has run)."""
self._conn.execute(
"UPDATE scheduled_tasks SET enabled = 0 WHERE id = ?", (task_id,)
)
self._conn.execute("UPDATE scheduled_tasks SET enabled = 0 WHERE id = ?", (task_id,))
self._conn.commit()
def log_task_run(self, task_id: int, result: str | None = None, error: str | None = None):
now = _now()
self._conn.execute(
"INSERT INTO task_run_logs (task_id, started_at, finished_at, result, error) VALUES (?, ?, ?, ?, ?)",
"INSERT INTO task_run_logs"
" (task_id, started_at, finished_at, result, error)"
" VALUES (?, ?, ?, ?, ?)",
(task_id, now, now, result, error),
)
self._conn.commit()
@ -231,11 +229,12 @@ class Database:
def get_notifications_after(self, after_id: int = 0, limit: int = 50) -> list[dict]:
"""Get notifications with id > after_id."""
rows = self._conn.execute(
"SELECT id, message, category, created_at FROM notifications WHERE id > ? ORDER BY id ASC LIMIT ?",
"SELECT id, message, category, created_at FROM notifications"
" WHERE id > ? ORDER BY id ASC LIMIT ?",
(after_id, limit),
).fetchall()
return [dict(r) for r in rows]
def _now() -> str:
return datetime.now(timezone.utc).isoformat()
return datetime.now(UTC).isoformat()

View File

@ -19,8 +19,8 @@ import os
import shutil
import subprocess
import sys
from collections.abc import Generator
from dataclasses import dataclass
from typing import Generator
import httpx
@ -96,22 +96,28 @@ class LLMAdapter:
model_id = CLAUDE_OPENROUTER_MAP[model_id]
provider = "openrouter"
else:
yield {"type": "text", "content": (
yield {
"type": "text",
"content": (
"To chat with Claude models, you need an OpenRouter API key "
"(set OPENROUTER_API_KEY in .env). Alternatively, select a local "
"model from Ollama or LM Studio."
)}
),
}
return
# Check if provider is available
if provider == "openrouter" and not self.openrouter_key:
yield {"type": "text", "content": (
yield {
"type": "text",
"content": (
"No API key configured. To use cloud models:\n"
"1. Get an OpenRouter API key at https://openrouter.ai/keys\n"
"2. Set OPENROUTER_API_KEY in your .env file\n\n"
"Or install Ollama (free, local) and pull a model:\n"
" ollama pull llama3.2"
)}
),
}
return
base_url, api_key = self._resolve_endpoint(provider)
@ -138,14 +144,21 @@ class LLMAdapter:
"""
claude_bin = shutil.which("claude")
if not claude_bin:
return "Error: `claude` CLI not found in PATH. Install Claude Code: npm install -g @anthropic-ai/claude-code"
return (
"Error: `claude` CLI not found in PATH. "
"Install Claude Code: npm install -g @anthropic-ai/claude-code"
)
# Pipe prompt through stdin to avoid Windows 8191-char command-line limit.
cmd = [
claude_bin, "-p",
"--output-format", "json",
"--tools", tools,
"--allowedTools", tools,
claude_bin,
"-p",
"--output-format",
"json",
"--tools",
tools,
"--allowedTools",
tools,
]
if model:
cmd.extend(["--model", model])
@ -170,7 +183,10 @@ class LLMAdapter:
env=env,
)
except FileNotFoundError:
return "Error: `claude` CLI not found. Install Claude Code: npm install -g @anthropic-ai/claude-code"
return (
"Error: `claude` CLI not found. "
"Install Claude Code: npm install -g @anthropic-ai/claude-code"
)
try:
stdout, stderr = proc.communicate(input=prompt, timeout=300)
@ -234,7 +250,9 @@ class LLMAdapter:
if idx not in tool_calls_accum:
tool_calls_accum[idx] = {
"id": tc.id or "",
"name": tc.function.name if tc.function and tc.function.name else "",
"name": tc.function.name
if tc.function and tc.function.name
else "",
"arguments": "",
}
if tc.function and tc.function.arguments:
@ -276,7 +294,7 @@ class LLMAdapter:
# ── Helpers ──
def _resolve_endpoint(self, provider: str) -> tuple[str, str]:
if provider == "openrouter":
if provider == "openrouter": # noqa: SIM116
return "https://openrouter.ai/api/v1", self.openrouter_key or "sk-placeholder"
elif provider == "ollama":
return f"{self.ollama_url}/v1", "ollama"
@ -295,6 +313,7 @@ class LLMAdapter:
def _get_openai(self):
if self._openai_mod is None:
import openai
self._openai_mod = openai
return self._openai_mod
@ -307,11 +326,13 @@ class LLMAdapter:
r = httpx.get(f"{self.ollama_url}/api/tags", timeout=3)
if r.status_code == 200:
for m in r.json().get("models", []):
models.append(ModelInfo(
models.append(
ModelInfo(
id=f"local/ollama/{m['name']}",
name=f"[Ollama] {m['name']}",
provider="ollama",
))
)
)
except Exception:
pass
# LM Studio
@ -319,11 +340,13 @@ class LLMAdapter:
r = httpx.get(f"{self.lmstudio_url}/v1/models", timeout=3)
if r.status_code == 200:
for m in r.json().get("data", []):
models.append(ModelInfo(
models.append(
ModelInfo(
id=f"local/lmstudio/{m['id']}",
name=f"[LM Studio] {m['id']}",
provider="lmstudio",
))
)
)
except Exception:
pass
return models
@ -333,14 +356,19 @@ class LLMAdapter:
models = []
if self.openrouter_key:
models.extend([
models.extend(
[
# Anthropic (via OpenRouter — system prompts work correctly)
ModelInfo("anthropic/claude-sonnet-4.5", "Claude Sonnet 4.5", "openrouter"),
ModelInfo("anthropic/claude-opus-4.6", "Claude Opus 4.6", "openrouter"),
# Google
ModelInfo("google/gemini-3-flash-preview", "Gemini 3 Flash Preview", "openrouter"),
ModelInfo(
"google/gemini-3-flash-preview", "Gemini 3 Flash Preview", "openrouter"
),
ModelInfo("google/gemini-2.5-flash", "Gemini 2.5 Flash", "openrouter"),
ModelInfo("google/gemini-2.5-flash-lite", "Gemini 2.5 Flash Lite", "openrouter"),
ModelInfo(
"google/gemini-2.5-flash-lite", "Gemini 2.5 Flash Lite", "openrouter"
),
# OpenAI
ModelInfo("openai/gpt-5-nano", "GPT-5 Nano", "openrouter"),
ModelInfo("openai/gpt-4o-mini", "GPT-4o Mini", "openrouter"),
@ -349,7 +377,8 @@ class LLMAdapter:
ModelInfo("x-ai/grok-4.1-fast", "Grok 4.1 Fast", "openrouter"),
ModelInfo("moonshotai/kimi-k2.5", "Kimi K2.5", "openrouter"),
ModelInfo("minimax/minimax-m2.5", "MiniMax M2.5", "openrouter"),
])
]
)
models.extend(self.discover_local_models())
return models

View File

@ -13,6 +13,7 @@ log = logging.getLogger(__name__)
# ── Speech-to-Text ──
def transcribe_audio(audio_path: str | Path) -> str:
"""Transcribe audio to text. Tries OpenAI Whisper API, falls back to local whisper."""
audio_path = Path(audio_path)
@ -38,14 +39,17 @@ def transcribe_audio(audio_path: str | Path) -> str:
def _transcribe_local(audio_path: Path) -> str:
import whisper
model = whisper.load_model("base")
result = model.transcribe(str(audio_path))
return result.get("text", "").strip()
def _transcribe_openai_api(audio_path: Path) -> str:
import openai
import os
import openai
key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENROUTER_API_KEY")
if not key:
raise ValueError("No API key for Whisper")
@ -57,18 +61,20 @@ def _transcribe_openai_api(audio_path: Path) -> str:
# ── Text-to-Speech ──
def text_to_speech(text: str, output_path: str | Path | None = None, voice: str = "en-US-AriaNeural") -> Path:
def text_to_speech(
text: str, output_path: str | Path | None = None, voice: str = "en-US-AriaNeural"
) -> Path:
"""Convert text to speech using edge-tts (free, no API key)."""
if output_path is None:
output_path = Path(tempfile.mktemp(suffix=".mp3"))
else:
output_path = Path(output_path)
output_path = Path(tempfile.mktemp(suffix=".mp3")) if output_path is None else Path(output_path)
try:
import edge_tts
async def _generate():
communicate = edge_tts.Communicate(text, voice)
await communicate.save(str(output_path))
asyncio.run(_generate())
return output_path
except ImportError:
@ -80,6 +86,7 @@ def text_to_speech(text: str, output_path: str | Path | None = None, voice: str
# ── Video Frame Extraction ──
def extract_video_frames(video_path: str | Path, max_frames: int = 5) -> list[Path]:
"""Extract key frames from a video using ffmpeg."""
video_path = Path(video_path)
@ -91,18 +98,37 @@ def extract_video_frames(video_path: str | Path, max_frames: int = 5) -> list[Pa
try:
# Get video duration
result = subprocess.run(
["ffprobe", "-v", "error", "-show_entries", "format=duration",
"-of", "default=noprint_wrappers=1:nokey=1", str(video_path)],
capture_output=True, text=True, timeout=10,
[
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
str(video_path),
],
capture_output=True,
text=True,
timeout=10,
)
duration = float(result.stdout.strip()) if result.stdout.strip() else 10.0
interval = max(duration / (max_frames + 1), 1.0)
# Extract frames
subprocess.run(
["ffmpeg", "-i", str(video_path), "-vf", f"fps=1/{interval}",
"-frames:v", str(max_frames), str(output_dir / "frame_%03d.jpg")],
capture_output=True, timeout=30,
[
"ffmpeg",
"-i",
str(video_path),
"-vf",
f"fps=1/{interval}",
"-frames:v",
str(max_frames),
str(output_dir / "frame_%03d.jpg"),
],
capture_output=True,
timeout=30,
)
frames = sorted(output_dir.glob("frame_*.jpg"))

View File

@ -12,8 +12,7 @@ from __future__ import annotations
import logging
import sqlite3
import threading
from datetime import datetime, timezone
from pathlib import Path
from datetime import UTC, datetime
import numpy as np
@ -61,7 +60,7 @@ class MemorySystem:
def remember(self, text: str):
"""Save a fact/instruction to long-term memory."""
memory_path = self.memory_dir / "MEMORY.md"
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M")
timestamp = datetime.now(UTC).strftime("%Y-%m-%d %H:%M")
entry = f"\n- [{timestamp}] {text}\n"
if memory_path.exists():
@ -76,9 +75,9 @@ class MemorySystem:
def log_daily(self, text: str):
"""Append an entry to today's daily log."""
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
today = datetime.now(UTC).strftime("%Y-%m-%d")
log_path = self.memory_dir / f"{today}.md"
timestamp = datetime.now(timezone.utc).strftime("%H:%M")
timestamp = datetime.now(UTC).strftime("%H:%M")
if log_path.exists():
content = log_path.read_text(encoding="utf-8")
@ -121,7 +120,9 @@ class MemorySystem:
if not summary_parts:
return
summary = f"Conversation summary ({len(to_summarize)} messages):\n" + "\n".join(summary_parts[:20])
summary = f"Conversation summary ({len(to_summarize)} messages):\n" + "\n".join(
summary_parts[:20]
)
self.log_daily(summary)
# Delete the flushed messages from DB so they don't get re-flushed
@ -153,7 +154,7 @@ class MemorySystem:
return ""
def _read_daily_log(self) -> str:
today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
today = datetime.now(UTC).strftime("%Y-%m-%d")
path = self.memory_dir / f"{today}.md"
if path.exists():
content = path.read_text(encoding="utf-8")
@ -182,6 +183,7 @@ class MemorySystem:
return self._embedder
try:
from sentence_transformers import SentenceTransformer
model_name = self.config.memory.embedding_model
log.info("Loading embedding model: %s", model_name)
self._embedder = SentenceTransformer(model_name)
@ -217,7 +219,9 @@ class MemorySystem:
scored = []
for doc_id, text, vec_bytes in rows:
vec = np.frombuffer(vec_bytes, dtype=np.float32)
sim = float(np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec) + 1e-8))
sim = float(
np.dot(query_vec, vec) / (np.linalg.norm(query_vec) * np.linalg.norm(vec) + 1e-8)
)
scored.append({"id": doc_id, "text": text, "score": sim})
scored.sort(key=lambda x: x["score"], reverse=True)

View File

@ -9,7 +9,8 @@ from __future__ import annotations
import logging
import threading
from typing import Callable, TYPE_CHECKING
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .db import Database

View File

@ -1 +0,0 @@
# Reserved for future custom providers

View File

@ -6,7 +6,7 @@ import json
import logging
import re
import threading
from datetime import datetime, timezone
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from croniter import croniter
@ -31,8 +31,13 @@ def _extract_docx_paths(result: str) -> list[str]:
class Scheduler:
def __init__(self, config: Config, db: Database, agent: Agent,
notification_bus: NotificationBus | None = None):
def __init__(
self,
config: Config,
db: Database,
agent: Agent,
notification_bus: NotificationBus | None = None,
):
self.config = config
self.db = db
self.agent = agent
@ -48,20 +53,28 @@ class Scheduler:
self._thread = threading.Thread(target=self._poll_loop, daemon=True, name="scheduler")
self._thread.start()
self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True, name="heartbeat")
self._heartbeat_thread = threading.Thread(
target=self._heartbeat_loop, daemon=True, name="heartbeat"
)
self._heartbeat_thread.start()
# Start ClickUp polling if configured
if self.config.clickup.enabled:
self._clickup_thread = threading.Thread(target=self._clickup_loop, daemon=True, name="clickup")
self._clickup_thread = threading.Thread(
target=self._clickup_loop, daemon=True, name="clickup"
)
self._clickup_thread.start()
log.info("ClickUp polling started (interval=%dm)", self.config.clickup.poll_interval_minutes)
log.info(
"ClickUp polling started (interval=%dm)", self.config.clickup.poll_interval_minutes
)
else:
log.info("ClickUp integration disabled (no API token)")
log.info("Scheduler started (poll=%ds, heartbeat=%dm)",
log.info(
"Scheduler started (poll=%ds, heartbeat=%dm)",
self.config.scheduler.poll_interval_seconds,
self.config.scheduler.heartbeat_interval_minutes)
self.config.scheduler.heartbeat_interval_minutes,
)
def stop(self):
self._stop_event.set()
@ -100,7 +113,7 @@ class Scheduler:
self.db.disable_task(task["id"])
else:
# Cron schedule - calculate next run
now = datetime.now(timezone.utc)
now = datetime.now(UTC)
cron = croniter(schedule, now)
next_run = cron.get_next(datetime)
self.db.update_task_next_run(task["id"], next_run.isoformat())
@ -147,6 +160,7 @@ class Scheduler:
"""Lazy-init the ClickUp API client."""
if self._clickup_client is None:
from .clickup import ClickUpClient
self._clickup_client = ClickUpClient(
api_token=self.config.clickup.api_token,
workspace_id=self.config.clickup.workspace_id,
@ -216,9 +230,8 @@ class Scheduler:
def _process_clickup_task(self, task, active_ids: set[str]):
"""Discover a new ClickUp task, map to skill, decide action."""
from .clickup import ClickUpTask
now = datetime.now(timezone.utc).isoformat()
now = datetime.now(UTC).isoformat()
skill_map = self.config.clickup.skill_map
# Build state object
@ -270,8 +283,8 @@ class Scheduler:
self._notify(
f"New ClickUp task needs your approval.\n"
f"Task: **{task.name}** → Skill: `{tool_name}`\n"
f"Use `clickup_approve_task(\"{task.id}\")` to approve or "
f"`clickup_decline_task(\"{task.id}\")` to decline."
f'Use `clickup_approve_task("{task.id}")` to approve or '
f'`clickup_decline_task("{task.id}")` to decline.'
)
log.info("ClickUp task awaiting approval: %s%s", task.name, tool_name)
@ -296,7 +309,7 @@ class Scheduler:
task_id = state["clickup_task_id"]
task_name = state["clickup_task_name"]
skill_name = state["skill_name"]
now = datetime.now(timezone.utc).isoformat()
now = datetime.now(UTC).isoformat()
log.info("Executing ClickUp task: %s%s", task_name, skill_name)
@ -314,7 +327,7 @@ class Scheduler:
args = self._build_tool_args(state)
# Execute the skill via the tool registry
if hasattr(self.agent, '_tools') and self.agent._tools:
if hasattr(self.agent, "_tools") and self.agent._tools:
result = self.agent._tools.execute(skill_name, args)
else:
result = self.agent.execute_task(
@ -334,7 +347,7 @@ class Scheduler:
# Success
state["state"] = "completed"
state["completed_at"] = datetime.now(timezone.utc).isoformat()
state["completed_at"] = datetime.now(UTC).isoformat()
self.db.kv_set(kv_key, json.dumps(state))
# Update ClickUp
@ -357,13 +370,12 @@ class Scheduler:
# Failure
state["state"] = "failed"
state["error"] = str(e)
state["completed_at"] = datetime.now(timezone.utc).isoformat()
state["completed_at"] = datetime.now(UTC).isoformat()
self.db.kv_set(kv_key, json.dumps(state))
# Comment the error on ClickUp
client.add_comment(
task_id,
f"❌ CheddahBot failed to complete this task.\n\nError: {str(e)[:2000]}"
task_id, f"❌ CheddahBot failed to complete this task.\n\nError: {str(e)[:2000]}"
)
self._notify(

View File

@ -1,63 +0,0 @@
"""Skill registry with @skill decorator and loader."""
from __future__ import annotations
import importlib.util
import logging
from pathlib import Path
from typing import Callable
log = logging.getLogger(__name__)
_SKILLS: dict[str, "SkillDef"] = {}
class SkillDef:
    """A registered skill: a named, described callable stored in the registry."""

    def __init__(self, name: str, description: str, func: Callable):
        # Unique identifier the skill is registered and looked up under.
        self.name = name
        # Human-readable summary of what the skill does.
        self.description = description
        # The callable executed when the skill is run.
        self.func = func
def skill(name: str, description: str):
    """Decorator to register a skill.

    The decorated function is stored in the module-level registry under
    *name* and is returned unmodified so it remains directly callable.
    """
    def _register(func: Callable) -> Callable:
        entry = SkillDef(name, description, func)
        _SKILLS[name] = entry
        return func
    return _register
def load_skill(path: Path):
    """Dynamically load a skill module from a .py file.

    Executing the module triggers any ``@skill`` decorators inside it,
    which register the skills as a side effect.

    Args:
        path: Filesystem path to the skill's ``.py`` file.

    Returns:
        The loaded module object, or ``None`` when no import spec could be
        built for *path* (previously this failure was silent; it is now
        logged so broken skill files are visible).
    """
    spec = importlib.util.spec_from_file_location(path.stem, path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unloadable paths;
        # surface it instead of silently skipping the file.
        log.warning("Could not build an import spec for %s", path)
        return None
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    log.info("Loaded skill from %s", path)
    return mod
def discover_skills(skills_dir: Path):
    """Import every non-private .py file found in *skills_dir*.

    Files whose names begin with an underscore are skipped.  A failure to
    load one skill is logged and does not stop discovery of the rest.
    """
    if not skills_dir.exists():
        return
    candidates = (p for p in skills_dir.glob("*.py") if not p.name.startswith("_"))
    for candidate in candidates:
        try:
            load_skill(candidate)
        except Exception as exc:
            log.warning("Failed to load skill %s: %s", candidate.name, exc)
def list_skills() -> list[SkillDef]:
    """Return every registered skill, in registration order."""
    return [*_SKILLS.values()]
def run_skill(name: str, **kwargs) -> str:
    """Execute a registered skill by name and return its result as text.

    Unknown names and exceptions raised by the skill are reported as
    human-readable strings rather than propagated.
    """
    if name not in _SKILLS:
        return f"Unknown skill: {name}"
    entry = _SKILLS[name]
    try:
        # str() stays inside the try so a raising __str__ is also caught.
        result = entry.func(**kwargs)
        if result is None:
            return "Done."
        return str(result)
    except Exception as e:
        return f"Skill error: {e}"

View File

@ -4,11 +4,11 @@ from __future__ import annotations
import importlib
import inspect
import json
import logging
import pkgutil
from collections.abc import Callable
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ..agent import Agent
@ -72,15 +72,15 @@ def _extract_params(func: Callable) -> dict:
prop: dict[str, Any] = {}
annotation = param.annotation
if annotation == str or annotation == inspect.Parameter.empty:
if annotation is str or annotation is inspect.Parameter.empty:
prop["type"] = "string"
elif annotation == int:
elif annotation is int:
prop["type"] = "integer"
elif annotation == float:
elif annotation is float:
prop["type"] = "number"
elif annotation == bool:
elif annotation is bool:
prop["type"] = "boolean"
elif annotation == list:
elif annotation is list:
prop["type"] = "array"
prop["items"] = {"type": "string"}
else:
@ -100,7 +100,7 @@ def _extract_params(func: Callable) -> dict:
class ToolRegistry:
"""Runtime tool registry with execution and schema generation."""
def __init__(self, config: "Config", db: "Database", agent: "Agent"):
def __init__(self, config: Config, db: Database, agent: Agent):
self.config = config
self.db = db
self.agent = agent

View File

@ -1,49 +0,0 @@
"""Meta-skill: create multi-step skills at runtime."""
from __future__ import annotations
import textwrap
from pathlib import Path
from . import tool
@tool("build_skill", "Create a new multi-step skill from a description", category="meta")
def build_skill(name: str, description: str, steps: str, ctx: dict | None = None) -> str:
    """Generate a new skill and save it to the skills directory.

    Args:
        name: Skill name (snake_case)
        description: What the skill does
        steps: Python code implementing the skill steps (must use @skill decorator)
        ctx: Runtime context; must contain a "config" whose ``skills_dir``
            points at the skills directory.  Fixed from the implicit-Optional
            ``ctx: dict = None`` (PEP 484 violation).

    Returns:
        A human-readable status message describing success or failure.
    """
    if not name.isidentifier():
        return f"Invalid skill name: {name}. Must be a valid Python identifier."
    if not ctx or not ctx.get("config"):
        return "Config context not available."
    skills_dir = ctx["config"].skills_dir
    skills_dir.mkdir(parents=True, exist_ok=True)
    # Wrap the caller-supplied steps in a module that imports the @skill
    # decorator so the generated file registers itself when loaded.
    module_code = textwrap.dedent(f'''\
"""Auto-generated skill: {description}"""
from __future__ import annotations
from cheddahbot.skills import skill
{steps}
''')
    file_path = skills_dir / f"{name}.py"
    if file_path.exists():
        return f"Skill '{name}' already exists. Choose a different name."
    file_path.write_text(module_code, encoding="utf-8")
    # Try to load it immediately so syntax errors surface to the caller.
    try:
        from cheddahbot.skills import load_skill
        load_skill(file_path)
        return f"Skill '{name}' created at {file_path}"
    except Exception as e:
        return f"Skill created at {file_path} but failed to load: {e}"

View File

@ -1,48 +0,0 @@
"""Meta-tool: dynamically create new tools at runtime."""
from __future__ import annotations
import importlib
import textwrap
from pathlib import Path
from . import tool
@tool("build_tool", "Create a new tool from a description. The agent writes Python code with @tool decorator.", category="meta")
def build_tool(name: str, description: str, code: str, ctx: dict | None = None) -> str:
    """Generate a new tool module and hot-load it.

    SECURITY NOTE(review): this writes and imports arbitrary caller-supplied
    Python — only expose it to trusted input.

    Args:
        name: Tool name (snake_case)
        description: What the tool does
        code: Full Python code for the tool function (must use @tool decorator)
        ctx: Runtime context dict (unused here; kept for tool-call interface
            compatibility).  Fixed from the implicit-Optional
            ``ctx: dict = None`` (PEP 484 violation).

    Returns:
        A human-readable status message describing success or failure.
    """
    if not name.isidentifier():
        return f"Invalid tool name: {name}. Must be a valid Python identifier."
    # Wrap code in a module with the import so the @tool decorator resolves.
    module_code = textwrap.dedent(f'''\
"""Auto-generated tool: {description}"""
from __future__ import annotations
from . import tool
{code}
''')
    # Write to the package's own tools directory.
    tools_dir = Path(__file__).parent
    file_path = tools_dir / f"{name}.py"
    if file_path.exists():
        return f"Tool module '{name}' already exists. Choose a different name."
    file_path.write_text(module_code, encoding="utf-8")
    # Hot-import the new module so the tool is available immediately.
    try:
        importlib.import_module(f".{name}", package=__package__)
        return f"Tool '{name}' created and loaded successfully at {file_path}"
    except Exception as e:
        # Clean up on failure so a broken module is not imported at startup.
        file_path.unlink(missing_ok=True)
        return f"Failed to load tool '{name}': {e}"

View File

@ -2,13 +2,13 @@
from __future__ import annotations
from datetime import datetime, timezone
from . import tool
@tool("remember_this", "Save an important fact or instruction to long-term memory", category="memory")
def remember_this(text: str, ctx: dict = None) -> str:
@tool(
"remember_this", "Save an important fact or instruction to long-term memory", category="memory"
)
def remember_this(text: str, ctx: dict | None = None) -> str:
if ctx and ctx.get("memory"):
ctx["memory"].remember(text)
return f"Saved to memory: {text}"
@ -16,7 +16,7 @@ def remember_this(text: str, ctx: dict = None) -> str:
@tool("search_memory", "Search through saved memories", category="memory")
def search_memory(query: str, ctx: dict = None) -> str:
def search_memory(query: str, ctx: dict | None = None) -> str:
if ctx and ctx.get("memory"):
results = ctx["memory"].search(query)
if results:
@ -26,7 +26,7 @@ def search_memory(query: str, ctx: dict = None) -> str:
@tool("log_note", "Add a timestamped note to today's daily log", category="memory")
def log_note(text: str, ctx: dict = None) -> str:
def log_note(text: str, ctx: dict | None = None) -> str:
if ctx and ctx.get("memory"):
ctx["memory"].log_daily(text)
return f"Logged: {text}"
@ -34,7 +34,7 @@ def log_note(text: str, ctx: dict = None) -> str:
@tool("schedule_task", "Schedule a recurring or one-time task", category="scheduling")
def schedule_task(name: str, prompt: str, schedule: str, ctx: dict = None) -> str:
def schedule_task(name: str, prompt: str, schedule: str, ctx: dict | None = None) -> str:
"""Schedule a task. Schedule format: cron expression or 'once:YYYY-MM-DDTHH:MM'."""
if ctx and ctx.get("db"):
task_id = ctx["db"].add_scheduled_task(name, prompt, schedule)
@ -43,11 +43,15 @@ def schedule_task(name: str, prompt: str, schedule: str, ctx: dict = None) -> st
@tool("list_tasks", "List all scheduled tasks", category="scheduling")
def list_tasks(ctx: dict = None) -> str:
def list_tasks(ctx: dict | None = None) -> str:
if ctx and ctx.get("db"):
tasks = ctx["db"]._conn.execute(
tasks = (
ctx["db"]
._conn.execute(
"SELECT id, name, schedule, enabled, next_run FROM scheduled_tasks ORDER BY id"
).fetchall()
)
.fetchall()
)
if not tasks:
return "No scheduled tasks."
lines = []

View File

@ -33,7 +33,7 @@ def _get_clickup_states(db) -> dict[str, dict]:
parts = key.split(":")
if len(parts) == 4 and parts[3] == "state":
task_id = parts[2]
try:
try: # noqa: SIM105
states[task_id] = json.loads(value)
except json.JSONDecodeError:
pass
@ -47,7 +47,7 @@ def _get_clickup_states(db) -> dict[str, dict]:
"and custom fields directly from the ClickUp API.",
category="clickup",
)
def clickup_query_tasks(status: str = "", task_type: str = "", ctx: dict = None) -> str:
def clickup_query_tasks(status: str = "", task_type: str = "", ctx: dict | None = None) -> str:
"""Query ClickUp API for tasks, optionally filtered by status and task type."""
client = _get_clickup_client(ctx)
if not client:
@ -98,7 +98,7 @@ def clickup_query_tasks(status: str = "", task_type: str = "", ctx: dict = None)
"(discovered, awaiting_approval, executing, completed, failed, declined, unmapped).",
category="clickup",
)
def clickup_list_tasks(status: str = "", ctx: dict = None) -> str:
def clickup_list_tasks(status: str = "", ctx: dict | None = None) -> str:
"""List tracked ClickUp tasks, optionally filtered by state."""
db = ctx["db"]
states = _get_clickup_states(db)
@ -130,7 +130,7 @@ def clickup_list_tasks(status: str = "", ctx: dict = None) -> str:
"Check the detailed internal processing state of a ClickUp task by its ID.",
category="clickup",
)
def clickup_task_status(task_id: str, ctx: dict = None) -> str:
def clickup_task_status(task_id: str, ctx: dict | None = None) -> str:
"""Get detailed state for a specific tracked task."""
db = ctx["db"]
raw = db.kv_get(f"clickup:task:{task_id}:state")
@ -168,7 +168,7 @@ def clickup_task_status(task_id: str, ctx: dict = None) -> str:
"Approve a ClickUp task that is waiting for permission to execute.",
category="clickup",
)
def clickup_approve_task(task_id: str, ctx: dict = None) -> str:
def clickup_approve_task(task_id: str, ctx: dict | None = None) -> str:
"""Approve a task in awaiting_approval state."""
db = ctx["db"]
key = f"clickup:task:{task_id}:state"
@ -182,11 +182,13 @@ def clickup_approve_task(task_id: str, ctx: dict = None) -> str:
return f"Corrupted state data for task '{task_id}'."
if state.get("state") != "awaiting_approval":
return f"Task '{task_id}' is in state '{state.get('state')}', not 'awaiting_approval'. Cannot approve."
current = state.get("state")
return f"Task '{task_id}' is in state '{current}', not 'awaiting_approval'. Cannot approve."
state["state"] = "approved"
db.kv_set(key, json.dumps(state))
return f"Task '{state.get('clickup_task_name', task_id)}' approved for execution. It will run on the next scheduler cycle."
name = state.get("clickup_task_name", task_id)
return f"Task '{name}' approved for execution. It will run on the next scheduler cycle."
@tool(
@ -194,7 +196,7 @@ def clickup_approve_task(task_id: str, ctx: dict = None) -> str:
"Decline a ClickUp task that is waiting for permission to execute.",
category="clickup",
)
def clickup_decline_task(task_id: str, ctx: dict = None) -> str:
def clickup_decline_task(task_id: str, ctx: dict | None = None) -> str:
"""Decline a task in awaiting_approval state."""
db = ctx["db"]
key = f"clickup:task:{task_id}:state"
@ -208,7 +210,8 @@ def clickup_decline_task(task_id: str, ctx: dict = None) -> str:
return f"Corrupted state data for task '{task_id}'."
if state.get("state") != "awaiting_approval":
return f"Task '{task_id}' is in state '{state.get('state')}', not 'awaiting_approval'. Cannot decline."
current = state.get("state")
return f"Task '{task_id}' is in state '{current}', not 'awaiting_approval'. Cannot decline."
state["state"] = "declined"
db.kv_set(key, json.dumps(state))

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import csv
import io
import json
from pathlib import Path
@ -34,7 +33,8 @@ def read_csv(path: str, max_rows: int = 20) -> str:
lines.append(" | ".join(str(c)[:50] for c in row))
result = "\n".join(lines)
total_line_count = sum(1 for _ in open(p, encoding="utf-8-sig"))
with open(p, encoding="utf-8-sig") as fcount:
total_line_count = sum(1 for _ in fcount)
if total_line_count > max_rows + 1:
result += f"\n\n... ({total_line_count - 1} total rows, showing first {max_rows})"
return result
@ -66,7 +66,11 @@ def query_json(path: str, json_path: str) -> str:
try:
data = json.loads(p.read_text(encoding="utf-8"))
result = _navigate(data, json_path.split("."))
return json.dumps(result, indent=2, ensure_ascii=False) if not isinstance(result, str) else result
return (
json.dumps(result, indent=2, ensure_ascii=False)
if not isinstance(result, str)
else result
)
except Exception as e:
return f"Error: {e}"

View File

@ -21,7 +21,7 @@ from . import tool
),
category="system",
)
def delegate_task(task_description: str, ctx: dict = None) -> str:
def delegate_task(task_description: str, ctx: dict | None = None) -> str:
"""Delegate a task to the execution brain."""
if not ctx or "agent" not in ctx:
return "Error: delegate tool requires agent context."

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import os
from pathlib import Path
from . import tool

View File

@ -9,14 +9,22 @@ from . import tool
@tool("analyze_image", "Describe or analyze an image file", category="media")
def analyze_image(path: str, question: str = "Describe this image in detail.", ctx: dict = None) -> str:
def analyze_image(
path: str, question: str = "Describe this image in detail.", ctx: dict | None = None
) -> str:
p = Path(path).resolve()
if not p.exists():
return f"Image not found: {path}"
suffix = p.suffix.lower()
mime_map = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg",
".gif": "image/gif", ".webp": "image/webp", ".bmp": "image/bmp"}
mime_map = {
".png": "image/png",
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".gif": "image/gif",
".webp": "image/webp",
".bmp": "image/bmp",
}
mime = mime_map.get(suffix, "image/png")
try:
@ -27,10 +35,13 @@ def analyze_image(path: str, question: str = "Describe this image in detail.", c
if ctx and ctx.get("agent"):
agent = ctx["agent"]
messages = [
{"role": "user", "content": [
{
"role": "user",
"content": [
{"type": "text", "text": question},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{data}"}},
]},
],
},
]
result_parts = []
for chunk in agent.llm.chat(messages, stream=False):

View File

@ -3,8 +3,8 @@
Autonomous workflow:
1. Generate 7 compliant headlines (chat brain)
2. AI judge picks the 2 best (chat brain)
3. Write 2 full press releases (execution brain × 2)
4. Generate 2 JSON-LD schemas (execution brain × 2, Sonnet + WebSearch)
3. Write 2 full press releases (execution brain x 2)
4. Generate 2 JSON-LD schemas (execution brain x 2, Sonnet + WebSearch)
5. Save 4 files, return cost summary
"""
@ -14,7 +14,7 @@ import json
import logging
import re
import time
from datetime import datetime
from datetime import UTC, datetime
from pathlib import Path
from ..docx_export import text_to_docx
@ -47,6 +47,7 @@ def _set_status(ctx: dict | None, message: str) -> None:
# Helpers
# ---------------------------------------------------------------------------
def _load_skill(filename: str) -> str:
"""Read a markdown skill file from the skills/ directory."""
path = _SKILLS_DIR / filename
@ -137,8 +138,10 @@ def _clean_pr_output(raw: str, headline: str) -> str:
# Prompt builders
# ---------------------------------------------------------------------------
def _build_headline_prompt(topic: str, company_name: str, url: str,
lsi_terms: str, headlines_ref: str) -> str:
def _build_headline_prompt(
topic: str, company_name: str, url: str, lsi_terms: str, headlines_ref: str
) -> str:
"""Build the prompt for Step 1: generate 7 headlines."""
prompt = (
f"Generate exactly 7 unique press release headline options for the following.\n\n"
@ -266,7 +269,7 @@ def _fuzzy_find_anchor(text: str, company_name: str, topic: str) -> str | None:
candidate = context[:phrase_end].strip()
# Clean: stop at sentence boundaries
for sep in (".", ",", ";", "\n"):
if sep in candidate[len(company_name):]:
if sep in candidate[len(company_name) :]:
break
else:
return candidate
@ -276,10 +279,17 @@ def _fuzzy_find_anchor(text: str, company_name: str, topic: str) -> str | None:
return None
def _build_pr_prompt(headline: str, topic: str, company_name: str,
url: str, lsi_terms: str, required_phrase: str,
skill_text: str, companies_file: str,
anchor_phrase: str = "") -> str:
def _build_pr_prompt(
headline: str,
topic: str,
company_name: str,
url: str,
lsi_terms: str,
required_phrase: str,
skill_text: str,
companies_file: str,
anchor_phrase: str = "",
) -> str:
"""Build the prompt for Step 3: write one full press release."""
prompt = (
f"{skill_text}\n\n"
@ -299,10 +309,10 @@ def _build_pr_prompt(headline: str, topic: str, company_name: str,
if anchor_phrase:
prompt += (
f'\nANCHOR TEXT REQUIREMENT: You MUST include the exact phrase '
f"\nANCHOR TEXT REQUIREMENT: You MUST include the exact phrase "
f'"{anchor_phrase}" somewhere naturally in the body of the press '
f'release. This phrase will be used as anchor text for an SEO link. '
f'Work it into a sentence where it reads naturally — for example: '
f"release. This phrase will be used as anchor text for an SEO link. "
f"Work it into a sentence where it reads naturally — for example: "
f'"As a {anchor_phrase.split(company_name, 1)[-1].strip()} provider, '
f'{company_name}..." or "{anchor_phrase} continues to...".\n'
)
@ -328,8 +338,7 @@ def _build_pr_prompt(headline: str, topic: str, company_name: str,
return prompt
def _build_schema_prompt(pr_text: str, company_name: str, url: str,
skill_text: str) -> str:
def _build_schema_prompt(pr_text: str, company_name: str, url: str, skill_text: str) -> str:
"""Build the prompt for Step 4: generate JSON-LD schema for one PR."""
prompt = (
f"{skill_text}\n\n"
@ -342,10 +351,7 @@ def _build_schema_prompt(pr_text: str, company_name: str, url: str,
"- No markdown fences, no commentary, no explanations\n"
"- The very first character of your output must be {\n"
)
prompt += (
f"\nCompany name: {company_name}\n\n"
f"Press release text:\n{pr_text}"
)
prompt += f"\nCompany name: {company_name}\n\nPress release text:\n{pr_text}"
return prompt
@ -353,6 +359,7 @@ def _build_schema_prompt(pr_text: str, company_name: str, url: str,
# Main tool
# ---------------------------------------------------------------------------
@tool(
"write_press_releases",
description=(
@ -371,7 +378,7 @@ def write_press_releases(
lsi_terms: str = "",
required_phrase: str = "",
clickup_task_id: str = "",
ctx: dict = None,
ctx: dict | None = None,
) -> str:
"""Run the full press-release pipeline and return results + cost summary."""
if not ctx or "agent" not in ctx:
@ -408,11 +415,13 @@ def write_press_releases(
{"role": "user", "content": headline_prompt},
]
headlines_raw = _chat_call(agent, messages)
cost_log.append({
cost_log.append(
{
"step": "1. Generate 7 headlines",
"model": agent.llm.current_model,
"elapsed_s": round(time.time() - step_start, 1),
})
}
)
if not headlines_raw.strip():
return "Error: headline generation returned empty result."
@ -432,20 +441,36 @@ def write_press_releases(
{"role": "user", "content": judge_prompt},
]
judge_result = _chat_call(agent, messages)
cost_log.append({
cost_log.append(
{
"step": "2. Judge picks best 2",
"model": agent.llm.current_model,
"elapsed_s": round(time.time() - step_start, 1),
})
}
)
# Parse the two winning headlines
winners = [line.strip().lstrip("0123456789.-) ") for line in judge_result.strip().splitlines() if line.strip()]
winners = [
line.strip().lstrip("0123456789.-) ")
for line in judge_result.strip().splitlines()
if line.strip()
]
if len(winners) < 2:
all_headlines = [line.strip().lstrip("0123456789.-) ") for line in headlines_raw.strip().splitlines() if line.strip()]
winners = all_headlines[:2] if len(all_headlines) >= 2 else [all_headlines[0], all_headlines[0]] if all_headlines else ["Headline A", "Headline B"]
all_headlines = [
line.strip().lstrip("0123456789.-) ")
for line in headlines_raw.strip().splitlines()
if line.strip()
]
winners = (
all_headlines[:2]
if len(all_headlines) >= 2
else [all_headlines[0], all_headlines[0]]
if all_headlines
else ["Headline A", "Headline B"]
)
winners = winners[:2]
# ── Step 3: Write 2 press releases (execution brain × 2) ─────────────
# ── Step 3: Write 2 press releases (execution brain x 2) ─────────────
log.info("[PR Pipeline] Step 3/4: Writing 2 press releases...")
anchor_phrase = _derive_anchor_phrase(company_name, topic)
pr_texts: list[str] = []
@ -454,21 +479,29 @@ def write_press_releases(
anchor_warnings: list[str] = []
for i, headline in enumerate(winners):
log.info("[PR Pipeline] Writing PR %d/2: %s", i + 1, headline[:60])
_set_status(ctx, f"Step 3/4: Writing press release {i+1}/2 — {headline[:60]}...")
_set_status(ctx, f"Step 3/4: Writing press release {i + 1}/2 — {headline[:60]}...")
step_start = time.time()
pr_prompt = _build_pr_prompt(
headline, topic, company_name, url, lsi_terms,
required_phrase, pr_skill, companies_file,
headline,
topic,
company_name,
url,
lsi_terms,
required_phrase,
pr_skill,
companies_file,
anchor_phrase=anchor_phrase,
)
exec_tools = "Bash,Read,Edit,Write,Glob,Grep,WebFetch"
raw_result = agent.execute_task(pr_prompt, tools=exec_tools)
elapsed = round(time.time() - step_start, 1)
cost_log.append({
"step": f"3{chr(97+i)}. Write PR '{headline[:40]}...'",
cost_log.append(
{
"step": f"3{chr(97 + i)}. Write PR '{headline[:40]}...'",
"model": "execution-brain (default)",
"elapsed_s": elapsed,
})
}
)
# Clean output: find the headline, strip preamble and markdown
clean_result = _clean_pr_output(raw_result, headline)
@ -487,13 +520,13 @@ def write_press_releases(
if fuzzy:
log.info("PR %d: exact anchor not found, fuzzy match: '%s'", i + 1, fuzzy)
anchor_warnings.append(
f"PR {chr(65+i)}: Exact anchor phrase \"{anchor_phrase}\" not found. "
f"Closest match: \"{fuzzy}\" — you may want to adjust before submitting."
f'PR {chr(65 + i)}: Exact anchor phrase "{anchor_phrase}" not found. '
f'Closest match: "{fuzzy}" — you may want to adjust before submitting.'
)
else:
log.warning("PR %d: anchor phrase '%s' NOT found", i + 1, anchor_phrase)
anchor_warnings.append(
f"PR {chr(65+i)}: Anchor phrase \"{anchor_phrase}\" NOT found in the text. "
f'PR {chr(65 + i)}: Anchor phrase "{anchor_phrase}" NOT found in the text. '
f"You'll need to manually add it before submitting to PA."
)
@ -515,7 +548,7 @@ def write_press_releases(
schema_files: list[str] = []
for i, pr_text in enumerate(pr_texts):
log.info("[PR Pipeline] Schema %d/2 for: %s", i + 1, winners[i][:60])
_set_status(ctx, f"Step 4/4: Generating schema {i+1}/2...")
_set_status(ctx, f"Step 4/4: Generating schema {i + 1}/2...")
step_start = time.time()
schema_prompt = _build_schema_prompt(pr_text, company_name, url, schema_skill)
exec_tools = "WebSearch,WebFetch"
@ -525,11 +558,13 @@ def write_press_releases(
model=SONNET_CLI_MODEL,
)
elapsed = round(time.time() - step_start, 1)
cost_log.append({
"step": f"4{chr(97+i)}. Schema for PR {i+1}",
cost_log.append(
{
"step": f"4{chr(97 + i)}. Schema for PR {i + 1}",
"model": SONNET_CLI_MODEL,
"elapsed_s": elapsed,
})
}
)
# Extract clean JSON and force correct mainEntityOfPage
schema_json = _extract_json(result)
@ -573,7 +608,7 @@ def write_press_releases(
# Anchor text warnings
if anchor_warnings:
output_parts.append("## Anchor Text Warnings\n")
output_parts.append(f"Required anchor phrase: **\"{anchor_phrase}\"**\n")
output_parts.append(f'Required anchor phrase: **"{anchor_phrase}"**\n')
for warning in anchor_warnings:
output_parts.append(f"- {warning}")
output_parts.append("")
@ -608,10 +643,11 @@ def write_press_releases(
# Post a result comment
attach_note = f"\n📎 {uploaded_count} file(s) attached." if uploaded_count else ""
result_text = "\n".join(output_parts)[:3000]
comment = (
f"✅ CheddahBot completed this task (via chat).\n\n"
f"Skill: write_press_releases\n"
f"Result:\n{'\n'.join(output_parts)[:3000]}{attach_note}"
f"Result:\n{result_text}{attach_note}"
)
client.add_comment(clickup_task_id, comment)
@ -622,19 +658,19 @@ def write_press_releases(
db = ctx.get("db")
if db:
import json as _json
kv_key = f"clickup:task:{clickup_task_id}:state"
existing = db.kv_get(kv_key)
if existing:
from datetime import timezone
state = _json.loads(existing)
state["state"] = "completed"
state["completed_at"] = datetime.now(timezone.utc).isoformat()
state["completed_at"] = datetime.now(UTC).isoformat()
state["deliverable_paths"] = docx_files
db.kv_set(kv_key, _json.dumps(state))
client.close()
output_parts.append(f"\n## ClickUp Sync\n")
output_parts.append("\n## ClickUp Sync\n")
output_parts.append(f"- Task `{clickup_task_id}` updated")
output_parts.append(f"- {uploaded_count} file(s) uploaded")
output_parts.append(f"- Status set to '{config.clickup.review_status}'")
@ -642,7 +678,7 @@ def write_press_releases(
log.info("ClickUp sync complete for task %s", clickup_task_id)
except Exception as e:
log.error("ClickUp sync failed for task %s: %s", clickup_task_id, e)
output_parts.append(f"\n## ClickUp Sync\n")
output_parts.append("\n## ClickUp Sync\n")
output_parts.append(f"- **Sync failed:** {e}")
output_parts.append("- Press release results are still valid above")
@ -683,7 +719,7 @@ def _parse_company_data(companies_text: str) -> dict[str, dict]:
current_data = {"name": current_company}
elif current_company:
if line.startswith("- **PA Org ID:**"):
try:
try: # noqa: SIM105
current_data["org_id"] = int(line.split(":**")[1].strip())
except (ValueError, IndexError):
pass
@ -804,20 +840,21 @@ def _extract_json(text: str) -> str | None:
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
candidate = text[start:end + 1]
candidate = text[start : end + 1]
try:
json.loads(candidate)
return candidate
except json.JSONDecodeError:
pass
return None # noqa: RET501
return None
# ---------------------------------------------------------------------------
# Submit tool
# ---------------------------------------------------------------------------
def _resolve_branded_url(branded_url: str, company_data: dict | None) -> str:
"""Resolve the branded link URL.
@ -867,12 +904,12 @@ def _build_links(
if fuzzy:
links.append({"url": target_url, "anchor": fuzzy})
warnings.append(
f"Brand+keyword link: exact phrase \"{anchor_phrase}\" not found. "
f"Used fuzzy match: \"{fuzzy}\""
f'Brand+keyword link: exact phrase "{anchor_phrase}" not found. '
f'Used fuzzy match: "{fuzzy}"'
)
else:
warnings.append(
f"Brand+keyword link: anchor phrase \"{anchor_phrase}\" NOT found in PR text. "
f'Brand+keyword link: anchor phrase "{anchor_phrase}" NOT found in PR text. '
f"Link to {target_url} could not be injected — add it manually in PA."
)
@ -883,7 +920,7 @@ def _build_links(
links.append({"url": branded_url_resolved, "anchor": company_name})
else:
warnings.append(
f"Branded link: company name \"{company_name}\" not found in PR text. "
f'Branded link: company name "{company_name}" not found in PR text. '
f"Link to {branded_url_resolved} could not be injected."
)
@ -911,7 +948,7 @@ def submit_press_release(
pr_text: str = "",
file_path: str = "",
description: str = "",
ctx: dict = None,
ctx: dict | None = None,
) -> str:
"""Submit a finished press release to Press Advantage as a draft."""
# --- Get config ---
@ -991,7 +1028,11 @@ def submit_press_release(
# --- Build links ---
branded_url_resolved = _resolve_branded_url(branded_url, company_data)
link_list, link_warnings = _build_links(
pr_text, company_name, topic, target_url, branded_url_resolved,
pr_text,
company_name,
topic,
target_url,
branded_url_resolved,
)
# --- Convert to HTML ---
@ -1039,7 +1080,7 @@ def submit_press_release(
if link_list:
output_parts.append("\n**Links:**")
for link in link_list:
output_parts.append(f" - \"{link['anchor']}\"{link['url']}")
output_parts.append(f' - "{link["anchor"]}"{link["url"]}')
if link_warnings:
output_parts.append("\n**Link warnings:**")

View File

@ -3,7 +3,6 @@
from __future__ import annotations
import subprocess
import sys
from . import tool

View File

@ -51,7 +51,7 @@ def fetch_url(url: str) -> str:
tag.decompose()
text = soup.get_text(separator="\n", strip=True)
# Collapse whitespace
lines = [l.strip() for l in text.split("\n") if l.strip()]
lines = [line.strip() for line in text.split("\n") if line.strip()]
text = "\n".join(lines)
if len(text) > 15000:
text = text[:15000] + "\n... (truncated)"

View File

@ -16,8 +16,9 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
_HEAD = '<meta name="viewport" content="width=device-width, initial-scale=1">'
_CSS = """
.contain { max-width: 900px; margin: auto; }
footer { display: none !important; }
.notification-banner {
background: #1a1a2e;
@ -27,11 +28,54 @@ footer { display: none !important; }
margin-bottom: 8px;
font-size: 0.9em;
}
/* Mobile optimizations */
@media (max-width: 768px) {
.gradio-container { padding: 4px !important; }
/* 16px base font on chat messages to prevent iOS zoom on focus */
.chatbot .message-row .message { font-size: 16px !important; }
/* Chat container: scrollable, no zoom-stuck overflow */
.chatbot {
overflow-y: auto !important;
-webkit-overflow-scrolling: touch;
height: calc(100dvh - 220px) !important;
max-height: none !important;
}
/* Tighten up header/status bar spacing */
.gradio-container > .main > .wrap { gap: 8px !important; }
/* Keep input area pinned at the bottom, never overlapping chat */
.gradio-container > .main {
display: flex;
flex-direction: column;
height: 100dvh;
}
.gradio-container > .main > .wrap:last-child {
position: sticky;
bottom: 0;
background: var(--background-fill-primary);
padding-bottom: env(safe-area-inset-bottom, 8px);
z-index: 10;
}
/* Input box: prevent tiny text that triggers zoom */
.multimodal-textbox textarea,
.multimodal-textbox input {
font-size: 16px !important;
}
/* Reduce model dropdown row padding */
.contain .gr-row { gap: 4px !important; }
}
"""
def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
notification_bus: NotificationBus | None = None) -> gr.Blocks:
def create_ui(
agent: Agent, config: Config, llm: LLMAdapter, notification_bus: NotificationBus | None = None
) -> gr.Blocks:
"""Build and return the Gradio app."""
available_models = llm.list_chat_models()
@ -41,7 +85,7 @@ def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
exec_status = "available" if llm.is_execution_brain_available() else "unavailable"
clickup_status = "enabled" if config.clickup.enabled else "disabled"
with gr.Blocks(title="CheddahBot") as app:
with gr.Blocks(title="CheddahBot", fill_width=True, css=_CSS, head=_HEAD) as app:
gr.Markdown("# CheddahBot", elem_classes=["contain"])
gr.Markdown(
f"*Chat Brain:* `{current_model}` &nbsp;|&nbsp; "
@ -90,7 +134,6 @@ def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
sources=["upload", "microphone"],
)
# -- Event handlers --
def on_model_change(model_id):
@ -125,12 +168,23 @@ def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
processed_files = []
for f in files:
fpath = f if isinstance(f, str) else f.get("path", f.get("name", ""))
if fpath and Path(fpath).suffix.lower() in (".wav", ".mp3", ".ogg", ".webm", ".m4a"):
if fpath and Path(fpath).suffix.lower() in (
".wav",
".mp3",
".ogg",
".webm",
".m4a",
):
try:
from .media import transcribe_audio
transcript = transcribe_audio(fpath)
if transcript:
text = f"{text}\n[Voice message]: {transcript}" if text else f"[Voice message]: {transcript}"
text = (
f"{text}\n[Voice message]: {transcript}"
if text
else f"[Voice message]: {transcript}"
)
continue
except Exception as e:
log.warning("Audio transcription failed: %s", e)
@ -142,13 +196,13 @@ def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
file_names = [Path(f).name for f in processed_files]
user_display += f"\n[Attached: {', '.join(file_names)}]"
chat_history = chat_history + [{"role": "user", "content": user_display}]
chat_history = [*chat_history, {"role": "user", "content": user_display}]
yield chat_history, gr.update(value=None)
# Stream assistant response
try:
response_text = ""
chat_history = chat_history + [{"role": "assistant", "content": ""}]
chat_history = [*chat_history, {"role": "assistant", "content": ""}]
for chunk in agent.respond(text, files=processed_files):
response_text += chunk
@ -157,11 +211,14 @@ def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
# If no response came through, show a fallback
if not response_text:
chat_history[-1] = {"role": "assistant", "content": "(No response received from model)"}
chat_history[-1] = {
"role": "assistant",
"content": "(No response received from model)",
}
yield chat_history, gr.update(value=None)
except Exception as e:
log.error("Error in agent.respond: %s", e, exc_info=True)
chat_history = chat_history + [{"role": "assistant", "content": f"Error: {e}"}]
chat_history = [*chat_history, {"role": "assistant", "content": f"Error: {e}"}]
yield chat_history, gr.update(value=None)
def poll_pipeline_status():
@ -209,4 +266,4 @@ def create_ui(agent: Agent, config: Config, llm: LLMAdapter,
timer = gr.Timer(10)
timer.tick(poll_notifications, None, [notification_display])
return app, _CSS
return app

View File

@ -114,7 +114,7 @@
- **Website:**
- **GBP:**
## FZE Industrial
## FZE Manufacturing
- **Executive:** Doug Pribyl, CEO
- **PA Org ID:** 22377
- **Website:**

View File

@ -2,9 +2,6 @@
from __future__ import annotations
import tempfile
from pathlib import Path
import pytest
from cheddahbot.db import Database

View File

@ -7,7 +7,6 @@ import respx
from cheddahbot.clickup import BASE_URL, ClickUpClient, ClickUpTask
# ── ClickUpTask.from_api ──
@ -183,9 +182,7 @@ class TestClickUpClient:
@respx.mock
def test_update_task_status(self):
respx.put(f"{BASE_URL}/task/t1").mock(
return_value=httpx.Response(200, json={})
)
respx.put(f"{BASE_URL}/task/t1").mock(return_value=httpx.Response(200, json={}))
client = ClickUpClient(api_token="pk_test_123")
result = client.update_task_status("t1", "in progress")
@ -210,9 +207,7 @@ class TestClickUpClient:
@respx.mock
def test_add_comment(self):
respx.post(f"{BASE_URL}/task/t1/comment").mock(
return_value=httpx.Response(200, json={})
)
respx.post(f"{BASE_URL}/task/t1/comment").mock(return_value=httpx.Response(200, json={}))
client = ClickUpClient(api_token="pk_test_123")
result = client.add_comment("t1", "CheddahBot completed this task.")
@ -260,9 +255,7 @@ class TestClickUpClient:
docx_file = tmp_path / "report.docx"
docx_file.write_bytes(b"fake docx content")
respx.post(f"{BASE_URL}/task/t1/attachment").mock(
return_value=httpx.Response(200, json={})
)
respx.post(f"{BASE_URL}/task/t1/attachment").mock(return_value=httpx.Response(200, json={}))
client = ClickUpClient(api_token="pk_test_123")
result = client.upload_attachment("t1", docx_file)

View File

@ -4,8 +4,6 @@ from __future__ import annotations
import json
import pytest
from cheddahbot.tools.clickup_tool import (
clickup_approve_task,
clickup_decline_task,

View File

@ -64,8 +64,8 @@ class TestNotifications:
def test_after_id_filters_correctly(self, tmp_db):
id1 = tmp_db.add_notification("First", "clickup")
id2 = tmp_db.add_notification("Second", "clickup")
id3 = tmp_db.add_notification("Third", "clickup")
_id2 = tmp_db.add_notification("Second", "clickup")
_id3 = tmp_db.add_notification("Third", "clickup")
# Should only get notifications after id1
notifs = tmp_db.get_notifications_after(id1)

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock
import httpx
@ -24,7 +23,6 @@ from cheddahbot.tools.press_release import (
submit_press_release,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@ -81,19 +79,21 @@ def submit_ctx(pa_config):
# PressAdvantageClient tests
# ---------------------------------------------------------------------------
class TestPressAdvantageClient:
class TestPressAdvantageClient:
@respx.mock
def test_get_organizations(self):
respx.get(
"https://app.pressadvantage.com/api/customers/organizations.json",
).mock(return_value=httpx.Response(
).mock(
return_value=httpx.Response(
200,
json=[
{"id": 19634, "name": "Advanced Industrial"},
{"id": 19800, "name": "Metal Craft"},
],
))
)
)
client = PressAdvantageClient("test-key")
try:
@ -108,10 +108,12 @@ class TestPressAdvantageClient:
def test_create_release_success(self):
respx.post(
"https://app.pressadvantage.com/api/customers/releases/with_content.json",
).mock(return_value=httpx.Response(
).mock(
return_value=httpx.Response(
200,
json={"id": 99999, "state": "draft", "title": "Test Headline"},
))
)
)
client = PressAdvantageClient("test-key")
try:
@ -154,10 +156,12 @@ class TestPressAdvantageClient:
def test_get_release(self):
respx.get(
"https://app.pressadvantage.com/api/customers/releases/81505.json",
).mock(return_value=httpx.Response(
).mock(
return_value=httpx.Response(
200,
json={"id": 81505, "state": "draft", "title": "Test"},
))
)
)
client = PressAdvantageClient("test-key")
try:
@ -171,10 +175,12 @@ class TestPressAdvantageClient:
def test_get_built_urls(self):
respx.get(
"https://app.pressadvantage.com/api/customers/releases/81505/built_urls.json",
).mock(return_value=httpx.Response(
).mock(
return_value=httpx.Response(
200,
json=[{"url": "https://example.com/press-release"}],
))
)
)
client = PressAdvantageClient("test-key")
try:
@ -204,6 +210,7 @@ class TestPressAdvantageClient:
# Company data parsing tests
# ---------------------------------------------------------------------------
class TestParseCompanyOrgIds:
def test_parses_all_companies(self):
mapping = _parse_company_org_ids(SAMPLE_COMPANIES_MD)
@ -280,12 +287,19 @@ class TestFuzzyMatchCompanyData:
# Anchor phrase helpers
# ---------------------------------------------------------------------------
class TestDeriveAnchorPhrase:
def test_basic(self):
assert _derive_anchor_phrase("Advanced Industrial", "PEEK machining") == "Advanced Industrial PEEK machining"
assert (
_derive_anchor_phrase("Advanced Industrial", "PEEK machining")
== "Advanced Industrial PEEK machining"
)
def test_strips_whitespace(self):
assert _derive_anchor_phrase("Metal Craft", " custom fabrication ") == "Metal Craft custom fabrication"
assert (
_derive_anchor_phrase("Metal Craft", " custom fabrication ")
== "Metal Craft custom fabrication"
)
class TestFindAnchorInText:
@ -325,10 +339,14 @@ class TestFuzzyFindAnchor:
# Branded URL resolution
# ---------------------------------------------------------------------------
class TestResolveBrandedUrl:
def test_literal_url(self):
data = {"website": "https://example.com", "gbp": "https://maps.google.com/123"}
assert _resolve_branded_url("https://linkedin.com/company/acme", data) == "https://linkedin.com/company/acme"
assert (
_resolve_branded_url("https://linkedin.com/company/acme", data)
== "https://linkedin.com/company/acme"
)
def test_gbp_shortcut(self):
data = {"website": "https://example.com", "gbp": "https://maps.google.com/maps?cid=123"}
@ -358,12 +376,16 @@ class TestResolveBrandedUrl:
# Link building
# ---------------------------------------------------------------------------
class TestBuildLinks:
def test_both_links_found(self):
text = "Advanced Industrial PEEK machining is excellent. Advanced Industrial leads the way."
links, warnings = _build_links(
text, "Advanced Industrial", "PEEK machining",
"https://example.com/peek", "https://linkedin.com/company/ai",
text,
"Advanced Industrial",
"PEEK machining",
"https://example.com/peek",
"https://linkedin.com/company/ai",
)
assert len(links) == 2
assert links[0]["url"] == "https://example.com/peek"
@ -380,9 +402,12 @@ class TestBuildLinks:
def test_brand_keyword_not_found_warns(self):
text = "This text has no relevant anchor phrases at all. " * 30
links, warnings = _build_links(
text, "Advanced Industrial", "PEEK machining",
"https://example.com/peek", "",
_links, warnings = _build_links(
text,
"Advanced Industrial",
"PEEK machining",
"https://example.com/peek",
"",
)
assert len(warnings) == 1
assert "NOT found" in warnings[0]
@ -390,8 +415,11 @@ class TestBuildLinks:
def test_fuzzy_match_used(self):
text = "Advanced Industrial provides excellent PEEK solutions to many clients worldwide."
links, warnings = _build_links(
text, "Advanced Industrial", "PEEK machining",
"https://example.com/peek", "",
text,
"Advanced Industrial",
"PEEK machining",
"https://example.com/peek",
"",
)
# Fuzzy should find "Advanced Industrial provides excellent PEEK" or similar
assert len(links) == 1
@ -404,6 +432,7 @@ class TestBuildLinks:
# Text to HTML
# ---------------------------------------------------------------------------
class TestTextToHtml:
def test_basic_paragraphs(self):
text = "First paragraph.\n\nSecond paragraph."
@ -451,12 +480,15 @@ class TestTextToHtml:
# submit_press_release tool tests
# ---------------------------------------------------------------------------
class TestSubmitPressRelease:
def test_missing_api_key(self):
config = MagicMock()
config.press_advantage.api_key = ""
result = submit_press_release(
headline="Test", company_name="Acme", pr_text=LONG_PR_TEXT,
headline="Test",
company_name="Acme",
pr_text=LONG_PR_TEXT,
ctx={"config": config},
)
assert "PRESS_ADVANTAGE_API" in result
@ -464,13 +496,16 @@ class TestSubmitPressRelease:
def test_missing_context(self):
result = submit_press_release(
headline="Test", company_name="Acme", pr_text=LONG_PR_TEXT,
headline="Test",
company_name="Acme",
pr_text=LONG_PR_TEXT,
)
assert "Error" in result
def test_no_pr_text_or_file(self, submit_ctx):
result = submit_press_release(
headline="Test", company_name="Advanced Industrial",
headline="Test",
company_name="Advanced Industrial",
ctx=submit_ctx,
)
assert "Error" in result
@ -479,16 +514,20 @@ class TestSubmitPressRelease:
def test_word_count_too_low(self, submit_ctx):
short_text = " ".join(["word"] * 100)
result = submit_press_release(
headline="Test", company_name="Advanced Industrial",
pr_text=short_text, ctx=submit_ctx,
headline="Test",
company_name="Advanced Industrial",
pr_text=short_text,
ctx=submit_ctx,
)
assert "Error" in result
assert "550 words" in result
def test_file_not_found(self, submit_ctx):
result = submit_press_release(
headline="Test", company_name="Advanced Industrial",
file_path="/nonexistent/file.txt", ctx=submit_ctx,
headline="Test",
company_name="Advanced Industrial",
file_path="/nonexistent/file.txt",
ctx=submit_ctx,
)
assert "Error" in result
assert "file not found" in result
@ -502,10 +541,12 @@ class TestSubmitPressRelease:
respx.post(
"https://app.pressadvantage.com/api/customers/releases/with_content.json",
).mock(return_value=httpx.Response(
).mock(
return_value=httpx.Response(
200,
json={"id": 88888, "state": "draft"},
))
)
)
result = submit_press_release(
headline="Advanced Industrial Expands PEEK Machining",
@ -526,7 +567,7 @@ class TestSubmitPressRelease:
lambda p: SAMPLE_COMPANIES_MD,
)
route = respx.post(
respx.post(
"https://app.pressadvantage.com/api/customers/releases/with_content.json",
).mock(return_value=httpx.Response(200, json={"id": 1, "state": "draft"}))
@ -549,7 +590,7 @@ class TestSubmitPressRelease:
lambda p: SAMPLE_COMPANIES_MD,
)
route = respx.post(
respx.post(
"https://app.pressadvantage.com/api/customers/releases/with_content.json",
).mock(return_value=httpx.Response(200, json={"id": 1, "state": "draft"}))
@ -599,8 +640,10 @@ class TestSubmitPressRelease:
).mock(return_value=httpx.Response(200, json=[]))
result = submit_press_release(
headline="Test", company_name="Totally Unknown Corp",
pr_text=LONG_PR_TEXT, ctx=submit_ctx,
headline="Test",
company_name="Totally Unknown Corp",
pr_text=LONG_PR_TEXT,
ctx=submit_ctx,
)
assert "Error" in result
@ -615,10 +658,12 @@ class TestSubmitPressRelease:
respx.get(
"https://app.pressadvantage.com/api/customers/organizations.json",
).mock(return_value=httpx.Response(
).mock(
return_value=httpx.Response(
200,
json=[{"id": 12345, "name": "New Client Co"}],
))
)
)
respx.post(
"https://app.pressadvantage.com/api/customers/releases/with_content.json",

View File

@ -26,11 +26,7 @@ class TestExtractDocxPaths:
assert paths == []
def test_only_matches_docx_extension(self):
result = (
"**Docx:** `report.docx`\n"
"**PDF:** `report.pdf`\n"
"**Docx:** `summary.txt`\n"
)
result = "**Docx:** `report.docx`\n**PDF:** `report.pdf`\n**Docx:** `summary.txt`\n"
paths = _extract_docx_paths(result)
assert paths == ["report.docx"]