Big-Link-Man/src/generation/ai_client.py

186 lines
6.5 KiB
Python

"""
OpenRouter AI client and prompt management
"""
import time
import json
from pathlib import Path
from typing import Optional, Dict, Any
from openai import OpenAI, RateLimitError, APIError, APIConnectionError, APITimeoutError
from src.core.config import get_config
# Short aliases accepted by AIClient(model=...), mapped to the full OpenRouter
# model identifiers that are actually sent in the request's "model" field.
# Names not listed here are passed through to the API unchanged.
AVAILABLE_MODELS: Dict[str, str] = {
    "gpt-4o-mini": "openai/gpt-4o-mini",
    "claude-sonnet-3.5": "anthropic/claude-3.5-sonnet",
    "grok-4-fast": "x-ai/grok-4-fast",
}
class AIClient:
    """OpenRouter API client using the OpenAI SDK.

    The stock ``OpenAI`` client is pointed at ``base_url`` and sends fixed
    ``HTTP-Referer`` / ``X-Title`` headers with every request.
    """

    def __init__(
        self,
        api_key: str,
        model: str,
        base_url: str = "https://openrouter.ai/api/v1"
    ):
        """
        Args:
            api_key: API key passed to the OpenAI SDK client.
            model: Either a short alias from AVAILABLE_MODELS or a full model
                id, which is used verbatim.
            base_url: API root the SDK client sends requests to.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url,
            default_headers={
                "HTTP-Referer": "https://github.com/yourusername/Big-Link-Man",
                "X-Title": "Big-Link-Man"
            }
        )
        # Resolve short aliases; anything unknown is treated as a full model id.
        self.model = AVAILABLE_MODELS.get(model, model)

    def generate_completion(
        self,
        prompt: str,
        system_message: Optional[str] = None,
        max_tokens: int = 4000,
        temperature: float = 0.7,
        json_mode: bool = False,
        override_model: Optional[str] = None,
        title: Optional[str] = None
    ) -> tuple[str, Optional[str]]:
        """
        Generate a chat completion from the OpenRouter API.

        Retryable failures are retried for up to 3 attempts in total, sleeping
        10s and then 20s between attempts. Retryable failures are rate limits,
        connection errors and timeouts, JSON decode errors, and API errors
        other than non-transient 4xx client errors.

        Args:
            prompt: User prompt text.
            system_message: Optional system message, sent first when non-empty.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            json_mode: If True, requests ``{"type": "json_object"}`` response
                format and prints the first 200 chars of the reply.
            override_model: If provided, use this model id instead of self.model.
            title: Optional label included in the diagnostic line printed when
                finish_reason is not "stop" ("N/A" when omitted).

        Returns:
            Tuple of (generated text, finish_reason). The text is "" when the
            API returns no message content.

        Raises:
            RateLimitError, APIConnectionError, APITimeoutError,
            json.JSONDecodeError, APIError: re-raised after the final attempt.
            APIError carrying a 4xx status other than 408/409/429 is
            re-raised immediately, because retrying cannot succeed.
            Any other exception is printed and re-raised immediately.
        """
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})

        model_to_use = override_model if override_model else self.model
        kwargs: Dict[str, Any] = {
            "model": model_to_use,
            "messages": messages,
            "max_tokens": max_tokens,
            "temperature": temperature
        }
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}

        retries = 3
        wait_times = [10, 20]
        for attempt in range(retries):
            is_last_attempt = attempt == retries - 1
            # Clamp the index so changing `retries` can never overrun wait_times.
            wait_time = wait_times[min(attempt, len(wait_times) - 1)]
            try:
                response = self.client.chat.completions.create(**kwargs)
                content = response.choices[0].message.content or ""
                finish_reason = response.choices[0].finish_reason
                if finish_reason != "stop":
                    # Log any finish_reason other than "stop" for diagnosis.
                    title_str = title if title else "N/A"
                    print(f"{title_str} - {finish_reason} - {model_to_use}")
                if json_mode:
                    print(f"[DEBUG] AI Response (first 200 chars): {content[:200]}")
                return content, finish_reason
            except RateLimitError:
                if is_last_attempt:
                    print(f"[API] Rate limit exceeded after {retries} attempts")
                    raise
                print(f"[API] Rate limit hit. Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
            except (APIConnectionError, APITimeoutError):
                if is_last_attempt:
                    print(f"[API] Connection failed after {retries} attempts")
                    raise
                print(f"[API] Connection/timeout error. Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
            except json.JSONDecodeError:
                if is_last_attempt:
                    print(f"[API] Failed to get valid response after {retries} attempts")
                    raise
                print(f"[API] Invalid JSON response (likely API error page). Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
            except APIError as e:
                # Only status-bearing errors expose status_code; the base class may not.
                status = getattr(e, "status_code", None)
                # Client errors (bad request, auth, payment, not found, validation)
                # fail identically on every retry; 408/409/429 are transient.
                if isinstance(status, int) and 400 <= status < 500 and status not in (408, 409, 429):
                    print(f"[API] Non-retryable API error (HTTP {status}): {str(e)[:200]}")
                    raise
                if is_last_attempt:
                    print(f"[API] API error after {retries} attempts: {str(e)[:200]}")
                    raise
                print(f"[API] API error: {str(e)[:100]}. Retrying in {wait_time}s... (attempt {attempt + 1}/{retries})")
            except Exception as e:
                print(f"[API] Unexpected error: {type(e).__name__}: {str(e)[:200]}")
                raise
            # Only retryable, non-final failures reach this point; every other
            # path above has already returned or raised.
            time.sleep(wait_time)
        # Unreachable while retries >= 1: the final attempt always returns or raises.
        raise RuntimeError("generate_completion: retry loop exited without a result")
class PromptManager:
    """Load prompt templates from JSON files and render them with str.format."""

    def __init__(self, prompts_dir: str = "src/generation/prompts"):
        """
        Args:
            prompts_dir: Directory holding ``<prompt_name>.json`` template
                files. A relative path is resolved against the current
                working directory when files are opened.
        """
        self.prompts_dir = Path(prompts_dir)
        # Cache of parsed template files, keyed by prompt name.
        self.prompts: Dict[str, dict] = {}

    def load_prompt(self, prompt_name: str) -> dict:
        """
        Return the parsed JSON for ``prompt_name``, reading it from disk only once.

        The cached object itself is returned, so later calls yield the same object.

        Raises:
            FileNotFoundError: if ``<prompts_dir>/<prompt_name>.json`` does not exist.
        """
        if prompt_name not in self.prompts:
            path = self.prompts_dir / f"{prompt_name}.json"
            if not path.exists():
                raise FileNotFoundError(f"Prompt file not found: {path}")
            with path.open('r', encoding='utf-8') as handle:
                self.prompts[prompt_name] = json.load(handle)
        return self.prompts[prompt_name]

    def format_prompt(self, prompt_name: str, **kwargs) -> tuple[str, str]:
        """
        Render a prompt template by substituting ``kwargs`` into its fields.

        Args:
            prompt_name: Name of the prompt template (JSON file stem).
            **kwargs: Values substituted into ``{placeholder}`` fields via str.format.

        Returns:
            Tuple of (system_message, user_prompt). Missing fields default to "".
            A falsy system_message is returned as-is, without formatting.

        Raises:
            KeyError: if a template references a named placeholder absent from kwargs.
        """
        template = self.load_prompt(prompt_name)
        system_tpl = template.get("system_message", "")
        user_tpl = template.get("user_prompt", "")
        rendered_system = system_tpl.format(**kwargs) if system_tpl else system_tpl
        return rendered_system, user_tpl.format(**kwargs)