"""
|
|
OpenRouter AI client and prompt management
|
|
"""
|
|
|
|
import time
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any
|
|
from openai import OpenAI, RateLimitError, APIError, APIConnectionError, APITimeoutError
|
|
from src.core.config import get_config
|
|
|
|
# Friendly model aliases mapped to their full OpenRouter model identifiers.
# Names not listed here are forwarded to the API verbatim (see AIClient.__init__).
AVAILABLE_MODELS = {
    "gpt-4o-mini": "openai/gpt-4o-mini",
    "claude-sonnet-3.5": "anthropic/claude-3.5-sonnet",
    "grok-4-fast": "x-ai/grok-4-fast"
}
|
|
|
|
|
|
class AIClient:
    """OpenRouter API client using OpenAI SDK."""

    def __init__(
        self,
        api_key: str,
        model: str,
        base_url: str = "https://openrouter.ai/api/v1"
    ):
        """
        Args:
            api_key: OpenRouter API key.
            model: Friendly alias from AVAILABLE_MODELS, or a raw OpenRouter
                model id which is used as-is.
            base_url: OpenAI-compatible endpoint; defaults to OpenRouter.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
            # OpenRouter attribution headers identifying the calling app.
            default_headers={
                "HTTP-Referer": "https://github.com/yourusername/Big-Link-Man",
                "X-Title": "Big-Link-Man"
            }
        )
        # Resolve friendly aliases; unknown names pass through unchanged so
        # callers may supply any OpenRouter model id directly.
        self.model = AVAILABLE_MODELS.get(model, model)

    def generate_completion(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_tokens: int = 4000,
        temperature: float = 0.7,
        json_mode: bool = False,
        override_model: Optional[str] = None
    ) -> str:
        """
        Generate completion from OpenRouter API.

        Retries up to 3 times on rate limits, connection/timeout errors,
        malformed-JSON responses, and generic API errors, backing off 10s
        then 20s between attempts.

        Args:
            prompt: User prompt text
            system_message: Optional system message
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0-1)
            json_mode: If True, requests JSON response format
            override_model: If provided, use this model instead of self.model

        Returns:
            Generated text completion (empty string if the API returned no content)

        Raises:
            RateLimitError, APIConnectionError, APITimeoutError,
            json.JSONDecodeError, APIError: re-raised after retries are exhausted.
        """
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})

        kwargs: Dict[str, Any] = {
            "model": override_model if override_model else self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        retries = 3
        wait_times = [10, 20]  # backoff before the 2nd and 3rd attempts

        for attempt in range(retries):
            try:
                response = self.client.chat.completions.create(**kwargs)
                # content can be None (e.g. refusals); normalize to "".
                content = response.choices[0].message.content or ""
                if json_mode:
                    print(f"[DEBUG] AI Response (first 200 chars): {content[:200]}")
                return content

            except RateLimitError:
                if attempt < retries - 1:
                    wait_time = wait_times[attempt]
                    print(f"[API] Rate limit hit. Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
                    time.sleep(wait_time)
                else:
                    print(f"[API] Rate limit exceeded after {retries} attempts")
                    raise

            except (APIConnectionError, APITimeoutError):
                if attempt < retries - 1:
                    wait_time = wait_times[attempt]
                    print(f"[API] Connection/timeout error. Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
                    time.sleep(wait_time)
                else:
                    print(f"[API] Connection failed after {retries} attempts")
                    raise

            except json.JSONDecodeError:
                if attempt < retries - 1:
                    wait_time = wait_times[attempt]
                    print(f"[API] Invalid JSON response (likely API error page). Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
                    time.sleep(wait_time)
                else:
                    print(f"[API] Failed to get valid response after {retries} attempts")
                    raise

            except APIError as e:
                if attempt < retries - 1:
                    wait_time = wait_times[attempt]
                    print(f"[API] API error: {str(e)[:100]}. Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
                    time.sleep(wait_time)
                else:
                    print(f"[API] API error after {retries} attempts: {str(e)[:200]}")
                    raise

            except Exception as e:
                # Unknown failure modes are never retried.
                print(f"[API] Unexpected error: {type(e).__name__}: {str(e)[:200]}")
                raise

        # Unreachable in practice (every path above returns or raises);
        # kept so the function's declared str return holds for retries == 0.
        return ""
|
|
|
|
|
|
class PromptManager:
    """Loads prompt templates from JSON files and fills in their variables."""

    def __init__(self, prompts_dir: str = "src/generation/prompts"):
        self.prompts_dir = Path(prompts_dir)
        # In-memory cache: prompt name -> parsed template dict.
        self.prompts: Dict[str, dict] = {}

    def load_prompt(self, prompt_name: str) -> dict:
        """Return the parsed template for *prompt_name*, reading from disk on first use."""
        try:
            return self.prompts[prompt_name]
        except KeyError:
            pass

        template_path = self.prompts_dir / f"{prompt_name}.json"
        if not template_path.exists():
            raise FileNotFoundError(f"Prompt file not found: {template_path}")

        with open(template_path, 'r', encoding='utf-8') as handle:
            template = json.load(handle)

        self.prompts[prompt_name] = template
        return template

    def format_prompt(self, prompt_name: str, **kwargs) -> tuple[str, str]:
        """
        Format prompt with variables.

        Args:
            prompt_name: Name of the prompt template
            **kwargs: Variables to inject into the template

        Returns:
            Tuple of (system_message, user_prompt)
        """
        template = self.load_prompt(prompt_name)

        system_text = template.get("system_message", "")
        user_text = template.get("user_prompt", "")

        # An absent/empty system message is returned untouched (no formatting).
        formatted_system = system_text.format(**kwargs) if system_text else system_text
        formatted_user = user_text.format(**kwargs)

        return formatted_system, formatted_user