335 lines
12 KiB
Python
335 lines
12 KiB
Python
"""Model-agnostic LLM adapter.
|
|
|
|
Routing:
|
|
- Claude models → Claude Code SDK (subprocess, uses Max subscription)
|
|
- Cloud models → OpenRouter (single API key, OpenAI-compatible)
|
|
- Local models → direct HTTP (Ollama / LM Studio, OpenAI-compatible)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import Generator
|
|
|
|
import httpx
|
|
|
|
# Module-level logger, named after this module per stdlib logging convention.
log = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ModelInfo:
    """One selectable model plus the backend that serves it."""

    id: str  # routing identifier, e.g. "openai/gpt-4o" or "local/ollama/llama3"
    name: str  # human-readable label for display in model pickers
    provider: str  # "claude" | "openrouter" | "ollama" | "lmstudio"
    context_length: int | None = None  # max context window in tokens, when known
|
|
|
|
|
|
# Well-known Claude models that route through the SDK
|
|
CLAUDE_MODELS = {
|
|
"claude-sonnet-4-20250514",
|
|
"claude-opus-4-20250514",
|
|
"claude-haiku-4-20250514",
|
|
}
|
|
|
|
|
|
def _is_claude_model(model_id: str) -> bool:
|
|
return model_id in CLAUDE_MODELS or model_id.startswith("claude-")
|
|
|
|
|
|
def _provider_for(model_id: str, openrouter_key: str, ollama_url: str, lmstudio_url: str) -> str:
    """Map a model id to its provider routing key.

    Claude models go to the Claude Code SDK; "local/..." prefixed ids go to
    the matching local server; everything else falls through to OpenRouter.
    `ollama_url` / `lmstudio_url` are accepted for signature parity with the
    adapter but are not consulted — routing is purely name-based.

    Fix: the original ended with `if openrouter_key: return "openrouter"`
    followed by an identical unconditional `return "openrouter"`, a redundant
    dead branch; collapsed into a single fallthrough.
    """
    if _is_claude_model(model_id):
        return "claude"
    if model_id.startswith("local/ollama/"):
        return "ollama"
    if model_id.startswith("local/lmstudio/"):
        return "lmstudio"
    # OpenRouter is the default even without a key; the adapter substitutes a
    # placeholder key downstream in _resolve_endpoint.
    return "openrouter"
|
|
|
|
|
|
class LLMAdapter:
|
|
def __init__(
|
|
self,
|
|
default_model: str = "claude-sonnet-4-20250514",
|
|
openrouter_key: str = "",
|
|
ollama_url: str = "http://localhost:11434",
|
|
lmstudio_url: str = "http://localhost:1234",
|
|
):
|
|
self.current_model = default_model
|
|
self.openrouter_key = openrouter_key
|
|
self.ollama_url = ollama_url.rstrip("/")
|
|
self.lmstudio_url = lmstudio_url.rstrip("/")
|
|
self._openai_mod = None # lazy import
|
|
|
|
@property
|
|
def provider(self) -> str:
|
|
return _provider_for(self.current_model, self.openrouter_key, self.ollama_url, self.lmstudio_url)
|
|
|
|
def switch_model(self, model_id: str):
|
|
self.current_model = model_id
|
|
log.info("Switched to model: %s (provider: %s)", model_id, self.provider)
|
|
|
|
# ── Main entry point ──
|
|
|
|
def chat(
|
|
self,
|
|
messages: list[dict],
|
|
tools: list[dict] | None = None,
|
|
stream: bool = True,
|
|
) -> Generator[dict, None, None]:
|
|
"""Yield chunks: {"type": "text", "content": "..."} or {"type": "tool_use", ...}."""
|
|
provider = self.provider
|
|
if provider == "claude":
|
|
yield from self._chat_claude_sdk(messages, tools, stream)
|
|
else:
|
|
base_url, api_key = self._resolve_endpoint(provider)
|
|
model_id = self._resolve_model_id(provider)
|
|
yield from self._chat_openai_sdk(messages, tools, stream, base_url, api_key, model_id)
|
|
|
|
# ── Claude Code SDK (subprocess) ──
|
|
|
|
def _chat_claude_sdk(
|
|
self, messages: list[dict], tools: list[dict] | None, stream: bool
|
|
) -> Generator[dict, None, None]:
|
|
# Separate system prompt from user messages
|
|
system_prompt = ""
|
|
user_prompt_parts = []
|
|
for m in messages:
|
|
role = m.get("role", "user")
|
|
content = m.get("content", "")
|
|
if isinstance(content, list):
|
|
content = " ".join(c.get("text", "") for c in content if c.get("type") == "text")
|
|
if role == "system":
|
|
system_prompt += content + "\n"
|
|
elif role == "assistant":
|
|
user_prompt_parts.append(f"[Assistant]\n{content}")
|
|
else:
|
|
user_prompt_parts.append(content)
|
|
user_prompt = "\n\n".join(user_prompt_parts)
|
|
|
|
# Find claude CLI - on Windows needs .cmd extension for npm-installed binaries
|
|
claude_bin = shutil.which("claude")
|
|
if not claude_bin:
|
|
yield {"type": "text", "content": "Error: `claude` CLI not found in PATH. Install Claude Code: npm install -g @anthropic-ai/claude-code"}
|
|
return
|
|
|
|
cmd = [claude_bin, "-p", user_prompt, "--model", self.current_model,
|
|
"--output-format", "json", "--tools", ""]
|
|
if system_prompt.strip():
|
|
cmd.extend(["--system-prompt", system_prompt.strip()])
|
|
log.debug("Claude SDK using: %s", claude_bin)
|
|
|
|
# Strip CLAUDECODE env var so the subprocess doesn't think it's nested
|
|
env = {k: v for k, v in os.environ.items() if k != "CLAUDECODE"}
|
|
|
|
try:
|
|
proc = subprocess.Popen(
|
|
cmd,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE,
|
|
text=True,
|
|
encoding="utf-8",
|
|
shell=(sys.platform == "win32"),
|
|
env=env,
|
|
)
|
|
except FileNotFoundError:
|
|
yield {"type": "text", "content": "Error: `claude` CLI not found. Install Claude Code: npm install -g @anthropic-ai/claude-code"}
|
|
return
|
|
|
|
stdout, stderr = proc.communicate(timeout=120)
|
|
|
|
if proc.returncode != 0:
|
|
yield {"type": "text", "content": f"Claude SDK error: {stderr or 'unknown error'}"}
|
|
return
|
|
|
|
# --output-format json returns a single JSON object
|
|
try:
|
|
result = json.loads(stdout)
|
|
text = result.get("result", "")
|
|
if text:
|
|
yield {"type": "text", "content": text}
|
|
elif result.get("is_error"):
|
|
yield {"type": "text", "content": f"Claude error: {result.get('result', 'unknown')}"}
|
|
return
|
|
except json.JSONDecodeError:
|
|
# Fallback: treat as plain text
|
|
if stdout.strip():
|
|
yield {"type": "text", "content": stdout.strip()}
|
|
|
|
# ── OpenAI-compatible SDK (OpenRouter / Ollama / LM Studio) ──
|
|
|
|
def _chat_openai_sdk(
|
|
self,
|
|
messages: list[dict],
|
|
tools: list[dict] | None,
|
|
stream: bool,
|
|
base_url: str,
|
|
api_key: str,
|
|
model_id: str,
|
|
) -> Generator[dict, None, None]:
|
|
openai = self._get_openai()
|
|
client = openai.OpenAI(base_url=base_url, api_key=api_key)
|
|
|
|
kwargs: dict = {
|
|
"model": model_id,
|
|
"messages": messages,
|
|
"stream": stream,
|
|
}
|
|
if tools:
|
|
kwargs["tools"] = tools
|
|
|
|
try:
|
|
if stream:
|
|
response = client.chat.completions.create(**kwargs)
|
|
tool_calls_accum: dict[int, dict] = {}
|
|
for chunk in response:
|
|
delta = chunk.choices[0].delta if chunk.choices else None
|
|
if not delta:
|
|
continue
|
|
if delta.content:
|
|
yield {"type": "text", "content": delta.content}
|
|
if delta.tool_calls:
|
|
for tc in delta.tool_calls:
|
|
idx = tc.index
|
|
if idx not in tool_calls_accum:
|
|
tool_calls_accum[idx] = {
|
|
"id": tc.id or "",
|
|
"name": tc.function.name if tc.function and tc.function.name else "",
|
|
"arguments": "",
|
|
}
|
|
if tc.function and tc.function.arguments:
|
|
tool_calls_accum[idx]["arguments"] += tc.function.arguments
|
|
if tc.id:
|
|
tool_calls_accum[idx]["id"] = tc.id
|
|
|
|
for _, tc in sorted(tool_calls_accum.items()):
|
|
try:
|
|
args = json.loads(tc["arguments"])
|
|
except json.JSONDecodeError:
|
|
args = {}
|
|
yield {
|
|
"type": "tool_use",
|
|
"id": tc["id"],
|
|
"name": tc["name"],
|
|
"input": args,
|
|
}
|
|
else:
|
|
response = client.chat.completions.create(**kwargs)
|
|
msg = response.choices[0].message
|
|
if msg.content:
|
|
yield {"type": "text", "content": msg.content}
|
|
if msg.tool_calls:
|
|
for tc in msg.tool_calls:
|
|
try:
|
|
args = json.loads(tc.function.arguments)
|
|
except json.JSONDecodeError:
|
|
args = {}
|
|
yield {
|
|
"type": "tool_use",
|
|
"id": tc.id,
|
|
"name": tc.function.name,
|
|
"input": args,
|
|
}
|
|
except Exception as e:
|
|
yield {"type": "text", "content": f"LLM error ({self.provider}): {e}"}
|
|
|
|
# ── Helpers ──
|
|
|
|
def _resolve_endpoint(self, provider: str) -> tuple[str, str]:
|
|
if provider == "openrouter":
|
|
return "https://openrouter.ai/api/v1", self.openrouter_key or "sk-placeholder"
|
|
elif provider == "ollama":
|
|
return f"{self.ollama_url}/v1", "ollama"
|
|
elif provider == "lmstudio":
|
|
return f"{self.lmstudio_url}/v1", "lm-studio"
|
|
return "https://openrouter.ai/api/v1", self.openrouter_key or "sk-placeholder"
|
|
|
|
def _resolve_model_id(self, provider: str) -> str:
|
|
model = self.current_model
|
|
if provider == "ollama" and model.startswith("local/ollama/"):
|
|
return model.removeprefix("local/ollama/")
|
|
if provider == "lmstudio" and model.startswith("local/lmstudio/"):
|
|
return model.removeprefix("local/lmstudio/")
|
|
return model
|
|
|
|
def _messages_to_prompt(self, messages: list[dict]) -> str:
|
|
"""Flatten messages into a single prompt string for Claude SDK -p flag."""
|
|
parts = []
|
|
for m in messages:
|
|
role = m.get("role", "user")
|
|
content = m.get("content", "")
|
|
if isinstance(content, list):
|
|
# multimodal - extract text parts
|
|
content = " ".join(
|
|
c.get("text", "") for c in content if c.get("type") == "text"
|
|
)
|
|
if role == "system":
|
|
parts.append(f"[System]\n{content}")
|
|
elif role == "assistant":
|
|
parts.append(f"[Assistant]\n{content}")
|
|
else:
|
|
parts.append(content)
|
|
return "\n\n".join(parts)
|
|
|
|
def _get_openai(self):
|
|
if self._openai_mod is None:
|
|
import openai
|
|
self._openai_mod = openai
|
|
return self._openai_mod
|
|
|
|
# ── Model Discovery ──
|
|
|
|
def discover_local_models(self) -> list[ModelInfo]:
|
|
models = []
|
|
# Ollama
|
|
try:
|
|
r = httpx.get(f"{self.ollama_url}/api/tags", timeout=3)
|
|
if r.status_code == 200:
|
|
for m in r.json().get("models", []):
|
|
models.append(ModelInfo(
|
|
id=f"local/ollama/{m['name']}",
|
|
name=f"[Ollama] {m['name']}",
|
|
provider="ollama",
|
|
))
|
|
except Exception:
|
|
pass
|
|
# LM Studio
|
|
try:
|
|
r = httpx.get(f"{self.lmstudio_url}/v1/models", timeout=3)
|
|
if r.status_code == 200:
|
|
for m in r.json().get("data", []):
|
|
models.append(ModelInfo(
|
|
id=f"local/lmstudio/{m['id']}",
|
|
name=f"[LM Studio] {m['id']}",
|
|
provider="lmstudio",
|
|
))
|
|
except Exception:
|
|
pass
|
|
return models
|
|
|
|
def list_available_models(self) -> list[ModelInfo]:
|
|
"""Return all available models across all providers."""
|
|
models = [
|
|
ModelInfo("claude-sonnet-4-20250514", "Claude Sonnet 4", "claude"),
|
|
ModelInfo("claude-opus-4-20250514", "Claude Opus 4", "claude"),
|
|
ModelInfo("claude-haiku-4-20250514", "Claude Haiku 4", "claude"),
|
|
]
|
|
if self.openrouter_key:
|
|
models.extend([
|
|
ModelInfo("openai/gpt-4o", "GPT-4o", "openrouter"),
|
|
ModelInfo("openai/gpt-4o-mini", "GPT-4o Mini", "openrouter"),
|
|
ModelInfo("google/gemini-2.0-flash-001", "Gemini 2.0 Flash", "openrouter"),
|
|
ModelInfo("google/gemini-2.5-pro-preview", "Gemini 2.5 Pro", "openrouter"),
|
|
ModelInfo("mistralai/mistral-large", "Mistral Large", "openrouter"),
|
|
ModelInfo("meta-llama/llama-3.3-70b-instruct", "Llama 3.3 70B", "openrouter"),
|
|
])
|
|
models.extend(self.discover_local_models())
|
|
return models
|