CheddahBot/cheddahbot/llm.py

335 lines
12 KiB
Python

"""Model-agnostic LLM adapter.
Routing:
- Claude models → Claude Code SDK (subprocess, uses Max subscription)
- Cloud models → OpenRouter (single API key, OpenAI-compatible)
- Local models → direct HTTP (Ollama / LM Studio, OpenAI-compatible)
"""
from __future__ import annotations
import json
import logging
import os
import shutil
import subprocess
import sys
from dataclasses import dataclass
from typing import Generator
import httpx
log = logging.getLogger(__name__)
@dataclass
class ModelInfo:
    """Descriptor for one model that the adapter can be switched to."""

    # Model identifier; locally discovered models in this file carry a
    # "local/ollama/" or "local/lmstudio/" prefix.
    id: str
    # Human-readable label (e.g. "Claude Sonnet 4", "[Ollama] <name>").
    name: str
    provider: str  # "claude" | "openrouter" | "ollama" | "lmstudio"
    # Optional context window size; no constructor call visible in this file sets it.
    context_length: int | None = None
# Claude model ids known to this module; all of them are routed through the
# Claude CLI rather than an OpenAI-compatible endpoint.
CLAUDE_MODELS = {
    "claude-sonnet-4-20250514",
    "claude-opus-4-20250514",
    "claude-haiku-4-20250514",
}


def _is_claude_model(model_id: str) -> bool:
    """Return True when *model_id* should be served by the Claude CLI.

    Any id starting with ``"claude-"`` qualifies, as does any id listed
    explicitly in ``CLAUDE_MODELS``.
    """
    if model_id.startswith("claude-"):
        return True
    return model_id in CLAUDE_MODELS
def _provider_for(model_id: str, openrouter_key: str, ollama_url: str, lmstudio_url: str) -> str:
    """Return the provider name that should serve *model_id*.

    Routing is decided purely from the shape of the model id:

    - Claude ids (see ``_is_claude_model``)   -> ``"claude"``
    - ids prefixed with ``local/ollama/``     -> ``"ollama"``
    - ids prefixed with ``local/lmstudio/``   -> ``"lmstudio"``
    - anything else                           -> ``"openrouter"``

    ``openrouter_key``, ``ollama_url`` and ``lmstudio_url`` are accepted so
    existing callers keep working, but they do not influence the result: the
    original code returned ``"openrouter"`` whether or not a key was set.
    """
    if _is_claude_model(model_id):
        return "claude"
    if model_id.startswith("local/ollama/"):
        return "ollama"
    if model_id.startswith("local/lmstudio/"):
        return "lmstudio"
    # Every remaining id is treated as a cloud model, with or without a key.
    return "openrouter"
class LLMAdapter:
    """Send chat requests to whichever backend serves the current model.

    Claude models are handled by running the ``claude`` CLI as a subprocess;
    all other models go through the ``openai`` client pointed at OpenRouter,
    Ollama or LM Studio. The provider is recomputed from ``current_model`` on
    every access, so switching models only means changing that attribute.
    """

    # Seconds to wait for the `claude` subprocess before it is killed.
    CLAUDE_TIMEOUT = 120

    def __init__(
        self,
        default_model: str = "claude-sonnet-4-20250514",
        openrouter_key: str = "",
        ollama_url: str = "http://localhost:11434",
        lmstudio_url: str = "http://localhost:1234",
    ):
        """Store the initial model and endpoint settings.

        Trailing slashes are stripped from both local URLs so that paths can
        be appended with a single ``/``.
        """
        self.current_model = default_model
        self.openrouter_key = openrouter_key
        self.ollama_url = ollama_url.rstrip("/")
        self.lmstudio_url = lmstudio_url.rstrip("/")
        self._openai_mod = None  # lazy import, see _get_openai()

    @property
    def provider(self) -> str:
        """Provider name for ``current_model`` as decided by ``_provider_for``."""
        return _provider_for(
            self.current_model, self.openrouter_key, self.ollama_url, self.lmstudio_url
        )

    def switch_model(self, model_id: str):
        """Make *model_id* the current model and log the resulting provider."""
        self.current_model = model_id
        log.info("Switched to model: %s (provider: %s)", model_id, self.provider)

    # ── Main entry point ──

    def chat(
        self,
        messages: list[dict],
        tools: list[dict] | None = None,
        stream: bool = True,
    ) -> Generator[dict, None, None]:
        """Yield chunks: {"type": "text", "content": "..."} or {"type": "tool_use", ...}."""
        provider = self.provider
        if provider == "claude":
            yield from self._chat_claude_sdk(messages, tools, stream)
            return
        base_url, api_key = self._resolve_endpoint(provider)
        model_id = self._resolve_model_id(provider)
        yield from self._chat_openai_sdk(messages, tools, stream, base_url, api_key, model_id)

    # ── Claude Code SDK (subprocess) ──

    def _chat_claude_sdk(
        self, messages: list[dict], tools: list[dict] | None, stream: bool
    ) -> Generator[dict, None, None]:
        """Run the ``claude`` CLI once and yield its answer as text chunks.

        ``tools`` and ``stream`` are accepted for signature parity with the
        OpenAI path but are not used: the CLI is asked for a single JSON
        document. Every failure (CLI missing, timeout, non-zero exit) is
        reported as a text chunk rather than raised.
        """
        system_prompt, user_prompt = self._build_claude_prompts(messages)

        # shutil.which resolves the platform-specific executable name
        # (e.g. the .cmd shim that npm installs on Windows).
        claude_bin = shutil.which("claude")
        if not claude_bin:
            yield {"type": "text", "content": "Error: `claude` CLI not found in PATH. Install Claude Code: npm install -g @anthropic-ai/claude-code"}
            return

        # `--tools ""` presumably disables the CLI's built-in tools — confirm
        # against the CLI reference.
        cmd = [claude_bin, "-p", user_prompt, "--model", self.current_model,
               "--output-format", "json", "--tools", ""]
        if system_prompt:
            cmd.extend(["--system-prompt", system_prompt])
        log.debug("Claude SDK using: %s", claude_bin)

        # Strip CLAUDECODE env var so the subprocess doesn't think it's nested
        env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"}

        # NOTE(review): on Windows this runs through the shell with the prompt
        # text in argv, so shell metacharacters in a prompt are interpreted by
        # cmd.exe and long prompts can exceed the command-line limit. Left
        # as-is to keep behavior; consider feeding the prompt via stdin.
        try:
            completed = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                encoding="utf-8",
                errors="replace",  # never crash on non-UTF-8 CLI output
                shell=(sys.platform == "win32"),
                env=env,
                timeout=self.CLAUDE_TIMEOUT,
            )
        except FileNotFoundError:
            yield {"type": "text", "content": "Error: `claude` CLI not found. Install Claude Code: npm install -g @anthropic-ai/claude-code"}
            return
        except subprocess.TimeoutExpired:
            # subprocess.run has already killed and reaped the child here.
            yield {"type": "text", "content": f"Claude SDK error: timed out after {self.CLAUDE_TIMEOUT}s"}
            return

        if completed.returncode != 0:
            yield {"type": "text", "content": f"Claude SDK error: {completed.stderr or 'unknown error'}"}
            return

        yield from self._parse_claude_output(completed.stdout)

    @staticmethod
    def _content_text(content) -> str:
        """Reduce a message ``content`` value to plain text.

        ``None`` becomes ``""``; a list is treated as multimodal parts and the
        ``text`` of every ``{"type": "text"}`` part is joined with spaces;
        anything else is returned as a string.
        """
        if content is None:
            return ""
        if isinstance(content, list):
            return " ".join(
                part.get("text") or ""
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        return content if isinstance(content, str) else str(content)

    @classmethod
    def _build_claude_prompts(cls, messages: list[dict]) -> tuple[str, str]:
        """Split *messages* into ``(system_prompt, user_prompt)`` for the CLI.

        System messages are joined with newlines; assistant turns are tagged
        with an ``[Assistant]`` header; all other roles are passed through as
        plain text. Turns are separated by a blank line.
        """
        system_parts: list[str] = []
        turn_parts: list[str] = []
        for m in messages:
            role = m.get("role", "user")
            text = cls._content_text(m.get("content", ""))
            if role == "system":
                system_parts.append(text)
            elif role == "assistant":
                turn_parts.append(f"[Assistant]\n{text}")
            else:
                turn_parts.append(text)
        return "\n".join(system_parts).strip(), "\n\n".join(turn_parts)

    @staticmethod
    def _parse_claude_output(stdout: str) -> list[dict]:
        """Turn the CLI's stdout into a list of text chunks.

        The CLI is run with ``--output-format json``; a JSON object is read
        for its ``result`` and ``is_error`` fields. Output that is not a JSON
        object is passed through verbatim (if non-empty).
        """
        try:
            result = json.loads(stdout)
        except json.JSONDecodeError:
            result = None
        if not isinstance(result, dict):
            # Fallback: treat as plain text
            raw = stdout.strip()
            return [{"type": "text", "content": raw}] if raw else []

        text = result.get("result") or ""
        # Check the error flag first so an error message is never presented as
        # a normal answer, and so the 'unknown' fallback can actually apply.
        if result.get("is_error"):
            return [{"type": "text", "content": f"Claude error: {text or 'unknown'}"}]
        return [{"type": "text", "content": text}] if text else []

    # ── OpenAI-compatible SDK (OpenRouter / Ollama / LM Studio) ──

    def _chat_openai_sdk(
        self,
        messages: list[dict],
        tools: list[dict] | None,
        stream: bool,
        base_url: str,
        api_key: str,
        model_id: str,
    ) -> Generator[dict, None, None]:
        """Call an OpenAI-compatible chat endpoint and yield text / tool_use chunks.

        When streaming, text deltas are yielded as they arrive and tool calls
        are assembled from their fragments and yielded once the stream ends.
        Any exception is reported as a text chunk instead of being raised.
        """
        openai = self._get_openai()
        client = openai.OpenAI(base_url=base_url, api_key=api_key)
        kwargs: dict = {
            "model": model_id,
            "messages": messages,
            "stream": stream,
        }
        if tools:
            kwargs["tools"] = tools

        try:
            response = client.chat.completions.create(**kwargs)
            if stream:
                tool_calls_accum: dict[int, dict] = {}
                for chunk in response:
                    delta = chunk.choices[0].delta if chunk.choices else None
                    if not delta:
                        continue
                    if delta.content:
                        yield {"type": "text", "content": delta.content}
                    for tc in delta.tool_calls or ():
                        self._accumulate_tool_call(tool_calls_accum, tc)
                for idx in sorted(tool_calls_accum):
                    call = tool_calls_accum[idx]
                    yield {
                        "type": "tool_use",
                        "id": call["id"],
                        "name": call["name"],
                        "input": self._parse_tool_args(call["arguments"]),
                    }
            else:
                msg = response.choices[0].message
                if msg.content:
                    yield {"type": "text", "content": msg.content}
                for tc in msg.tool_calls or ():
                    yield {
                        "type": "tool_use",
                        "id": tc.id,
                        "name": tc.function.name,
                        "input": self._parse_tool_args(tc.function.arguments),
                    }
        except Exception as e:
            # Boundary handler: surface the failure to the caller as text.
            yield {"type": "text", "content": f"LLM error ({self.provider}): {e}"}

    @staticmethod
    def _accumulate_tool_call(accum: dict[int, dict], tc) -> None:
        """Merge one streamed tool-call fragment *tc* into *accum* (keyed by index).

        ``id`` and ``name`` are taken from whichever fragment carries them;
        ``arguments`` fragments are concatenated in arrival order.
        """
        entry = accum.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
        if tc.id:
            entry["id"] = tc.id
        fn = tc.function
        if fn is None:
            return
        if fn.name:
            entry["name"] = fn.name
        if fn.arguments:
            entry["arguments"] += fn.arguments

    @staticmethod
    def _parse_tool_args(raw):
        """Decode a tool call's JSON argument string; return ``{}`` if it is not valid JSON."""
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}

    # ── Helpers ──

    def _resolve_endpoint(self, provider: str) -> tuple[str, str]:
        """Return ``(base_url, api_key)`` for *provider*.

        Local providers get a fixed dummy key. OpenRouter — and any provider
        name not recognised here — gets the configured key, or
        ``"sk-placeholder"`` when none is set.
        """
        if provider == "ollama":
            return f"{self.ollama_url}/v1", "ollama"
        if provider == "lmstudio":
            return f"{self.lmstudio_url}/v1", "lm-studio"
        return "https://openrouter.ai/api/v1", self.openrouter_key or "sk-placeholder"

    def _resolve_model_id(self, provider: str) -> str:
        """Return the model id to send upstream, minus any ``local/<provider>/`` prefix."""
        model = self.current_model
        if provider == "ollama" and model.startswith("local/ollama/"):
            return model.removeprefix("local/ollama/")
        if provider == "lmstudio" and model.startswith("local/lmstudio/"):
            return model.removeprefix("local/lmstudio/")
        return model

    def _messages_to_prompt(self, messages: list[dict]) -> str:
        """Flatten messages into a single prompt string for Claude SDK -p flag."""
        parts = []
        for m in messages:
            role = m.get("role", "user")
            text = self._content_text(m.get("content", ""))
            if role == "system":
                parts.append(f"[System]\n{text}")
            elif role == "assistant":
                parts.append(f"[Assistant]\n{text}")
            else:
                parts.append(text)
        return "\n\n".join(parts)

    def _get_openai(self):
        """Import the ``openai`` package on first use and cache the module."""
        if self._openai_mod is None:
            import openai
            self._openai_mod = openai
        return self._openai_mod

    # ── Model Discovery ──

    def _discover(
        self, url: str, list_key: str, id_key: str, provider: str, label: str
    ) -> list[ModelInfo]:
        """Query one local server's model list; return ``[]`` on any failure.

        Discovery is best-effort: an unreachable server or an unexpected
        payload is logged at debug level and otherwise ignored.
        """
        try:
            r = httpx.get(url, timeout=3)
            if r.status_code != 200:
                return []
            entries = r.json().get(list_key) or []
        except Exception as exc:
            log.debug("%s model discovery failed: %s", label, exc)
            return []
        return [
            ModelInfo(
                id=f"local/{provider}/{entry[id_key]}",
                name=f"[{label}] {entry[id_key]}",
                provider=provider,
            )
            for entry in entries
            # Skip malformed entries instead of aborting the whole listing.
            if isinstance(entry, dict) and entry.get(id_key)
        ]

    def discover_local_models(self) -> list[ModelInfo]:
        """Return the models reported by the local Ollama and LM Studio servers."""
        models = self._discover(
            f"{self.ollama_url}/api/tags", "models", "name", "ollama", "Ollama"
        )
        models.extend(
            self._discover(
                f"{self.lmstudio_url}/v1/models", "data", "id", "lmstudio", "LM Studio"
            )
        )
        return models

    def list_available_models(self) -> list[ModelInfo]:
        """Return all available models across all providers."""
        models = [
            ModelInfo("claude-sonnet-4-20250514", "Claude Sonnet 4", "claude"),
            ModelInfo("claude-opus-4-20250514", "Claude Opus 4", "claude"),
            ModelInfo("claude-haiku-4-20250514", "Claude Haiku 4", "claude"),
        ]
        # Cloud models are only offered when an OpenRouter key is configured.
        if self.openrouter_key:
            models.extend([
                ModelInfo("openai/gpt-4o", "GPT-4o", "openrouter"),
                ModelInfo("openai/gpt-4o-mini", "GPT-4o Mini", "openrouter"),
                ModelInfo("google/gemini-2.0-flash-001", "Gemini 2.0 Flash", "openrouter"),
                ModelInfo("google/gemini-2.5-pro-preview", "Gemini 2.5 Pro", "openrouter"),
                ModelInfo("mistralai/mistral-large", "Mistral Large", "openrouter"),
                ModelInfo("meta-llama/llama-3.3-70b-instruct", "Llama 3.3 70B", "openrouter"),
            ])
        models.extend(self.discover_local_models())
        return models