361 lines
15 KiB
Python
361 lines
15 KiB
Python
"""
|
|
Content generation service - orchestrates the three-stage AI generation pipeline
|
|
"""
|
|
|
|
import json
import re
import time
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple

from sqlalchemy.orm import Session

from src.core.config import Config, get_config
from src.database.models import Project, GeneratedContent
from src.database.repositories import GeneratedContentRepository
from src.generation.ai_client import AIClient, AIClientError
from src.generation.augmenter import ContentAugmenter
from src.generation.rule_engine import ContentRuleEngine
from src.generation.validator import StageValidator
|
|
|
|
|
|
class GenerationError(Exception):
    """Raised when article generation cannot be completed.

    Covers exhausted retries (AI-client or validation failures in any stage),
    missing prompt templates, and any error surfaced by the top-level pipeline.
    """
|
|
|
|
|
|
class ContentGenerationService:
    """Service for AI-powered content generation with validation.

    Orchestrates a three-stage pipeline: title -> outline -> HTML content.
    Each stage prompts the AI client and validates the response with
    ``StageValidator``. When validation fails, the outline and content stages
    first try to repair the output with ``ContentAugmenter``. Otherwise the
    stage retries with the validation errors appended to the prompt.
    Attempt counters, the current stage and validation results are persisted
    on a ``GeneratedContent`` record through ``GeneratedContentRepository``.
    """

    def __init__(
        self,
        session: Session,
        config: Optional[Config] = None,
        ai_client: Optional[AIClient] = None
    ):
        """
        Initialize service

        Args:
            session: Database session (used to build the content repository)
            config: Application configuration (``get_config()`` if None)
            ai_client: AI client (creates new if None)
        """
        self.session = session
        self.config = config or get_config()
        self.ai_client = ai_client or AIClient(self.config)
        self.content_repo = GeneratedContentRepository(session)
        self.rule_engine = ContentRuleEngine(self.config)
        self.validator = StageValidator(self.config, self.rule_engine)
        self.augmenter = ContentAugmenter()

        # Prompt templates are JSON files in a "prompts" directory next to this module.
        self.prompts_dir = Path(__file__).parent / "prompts"

    def generate_article(
        self,
        project: Project,
        tier: int,
        title_model: str,
        outline_model: str,
        content_model: str,
        max_retries: int = 3
    ) -> GeneratedContent:
        """
        Generate a complete article through the three-stage pipeline.

        Creates a ``GeneratedContent`` record for ``project``/``tier`` and then
        runs title, outline and content generation. ``generation_stage`` is
        updated between stages. On success the record's status becomes
        "completed". On any exception during the stages it becomes "failed" and
        the error message is stored. The elapsed time is recorded in both cases.

        Args:
            project: Project with CORA data
            tier: Tier level stored on the record
            title_model: Model for title generation
            outline_model: Model for outline generation
            content_model: Model for content generation
            max_retries: Maximum number of attempts per stage

        Returns:
            GeneratedContent record with completed article

        Raises:
            GenerationError: If any stage fails after all retries. The
                underlying exception is chained as ``__cause__``.
        """
        # Monotonic clock: a duration must not jump with wall-clock adjustments.
        start_time = time.monotonic()

        content_record = self.content_repo.create(project.id, tier)
        content_record.title_model = title_model
        content_record.outline_model = outline_model
        content_record.content_model = content_model
        self.content_repo.update(content_record)

        try:
            title = self._generate_title(project, content_record, title_model, max_retries)

            content_record.generation_stage = "outline"
            self.content_repo.update(content_record)
            outline = self._generate_outline(
                project, title, content_record, outline_model, max_retries
            )

            content_record.generation_stage = "content"
            self.content_repo.update(content_record)
            # _generate_content stores the final HTML on content_record itself.
            self._generate_content(
                project, title, outline, content_record, content_model, max_retries
            )

            content_record.status = "completed"
            content_record.generation_duration = time.monotonic() - start_time
            self.content_repo.update(content_record)
            return content_record

        except Exception as e:
            # Pipeline boundary: persist the failure, then re-raise as
            # GenerationError while keeping the original traceback.
            content_record.status = "failed"
            content_record.error_message = str(e)
            content_record.generation_duration = time.monotonic() - start_time
            self.content_repo.update(content_record)
            raise GenerationError(f"Article generation failed: {e}") from e

    def _generate_title(
        self,
        project: Project,
        content_record: GeneratedContent,
        model: str,
        max_retries: int
    ) -> str:
        """Generate and validate the article title.

        Makes up to ``max_retries`` attempts. An ``AIClientError`` on a
        non-final attempt is retried. After a validation failure, the
        validator's messages are appended to the prompt for the next attempt.

        Returns:
            The first title that passes validation (also stored on the record).

        Raises:
            GenerationError: If the AI call fails on the last attempt, or no
                attempt produced a valid title.
        """
        prompt_template = self._load_prompt("title_generation.json")
        prompt = prompt_template["user_template"].format(
            main_keyword=project.main_keyword,
            word_count=project.word_count,
            entities=self._join_terms(project.entities, 10),
            related_searches=self._join_terms(project.related_searches, 10)
        )

        last_errors: List[str] = []
        for attempt in range(1, max_retries + 1):
            content_record.title_attempts = attempt
            self.content_repo.update(content_record)

            try:
                title = self.ai_client.generate(
                    prompt=prompt,
                    model=model,
                    temperature=0.7
                )
            except AIClientError as e:
                if attempt == max_retries:
                    raise GenerationError(
                        f"Title generation failed after {max_retries} attempts: {e}"
                    ) from e
                continue

            is_valid, errors = self.validator.validate_title(title, project)
            if is_valid:
                content_record.title = title
                self.content_repo.update(content_record)
                return title

            last_errors = list(errors)
            if attempt < max_retries:
                prompt += f"\n\nPrevious attempt failed: {', '.join(errors)}. Please fix these issues."

        raise GenerationError(self._exhausted_message("Title", max_retries, last_errors))

    def _generate_outline(
        self,
        project: Project,
        title: str,
        content_record: GeneratedContent,
        model: str,
        max_retries: int
    ) -> Dict[str, Any]:
        """Generate and validate the article outline.

        The prompt is filled with the project's CORA heading targets. Defaults
        are used when a target is missing or zero. Each AI response is parsed
        as JSON and validated.

        If the validator reports missing elements, ``ContentAugmenter`` tries
        to add them, and the augmented outline is accepted when it validates.
        This repair is attempted on every attempt, including the last. If the
        outline is still invalid, the validation errors are appended to the
        prompt before the next attempt.

        Returns:
            The accepted outline dict (also stored as JSON on the record).

        Raises:
            GenerationError: If the AI call or JSON parsing fails on the last
                attempt, or no attempt produced a valid outline.
        """
        prompt_template = self._load_prompt("outline_generation.json")

        # CORA targets may be fractional averages. The totals honour the
        # round_averages_down setting; the per-type counts are always truncated.
        round_down = self.config.content_rules.cora_validation.round_averages_down
        h2_total = self._cora_count(project.h2_total, 5, round_down)
        h2_exact = self._cora_count(project.h2_exact, 1)
        h2_related = self._cora_count(project.h2_related_search, 1)
        h2_entities = self._cora_count(project.h2_entities, 2)

        h3_total = self._cora_count(project.h3_total, 10, round_down)
        h3_exact = self._cora_count(project.h3_exact, 1)
        h3_related = self._cora_count(project.h3_related_search, 2)
        h3_entities = self._cora_count(project.h3_entities, 3)

        prompt = prompt_template["user_template"].format(
            title=title,
            main_keyword=project.main_keyword,
            word_count=project.word_count,
            h2_total=h2_total,
            h2_exact=h2_exact,
            h2_related_search=h2_related,
            h2_entities=h2_entities,
            h3_total=h3_total,
            h3_exact=h3_exact,
            h3_related_search=h3_related,
            h3_entities=h3_entities,
            entities=self._join_terms(project.entities, 20),
            related_searches=self._join_terms(project.related_searches, 20)
        )

        last_errors: List[str] = []
        for attempt in range(1, max_retries + 1):
            content_record.outline_attempts = attempt
            self.content_repo.update(content_record)

            try:
                raw_outline = self.ai_client.generate_json(
                    prompt=prompt,
                    model=model,
                    temperature=0.7,
                    max_tokens=2000
                )
                # generate_json may return either a JSON string or an already-parsed object.
                outline = json.loads(raw_outline) if isinstance(raw_outline, str) else raw_outline
            except (AIClientError, json.JSONDecodeError) as e:
                if attempt == max_retries:
                    raise GenerationError(
                        f"Outline generation failed after {max_retries} attempts: {e}"
                    ) from e
                continue

            is_valid, errors, missing = self.validator.validate_outline(outline, project)
            if is_valid:
                content_record.outline = json.dumps(outline)
                self.content_repo.update(content_record)
                return outline

            if missing:
                augmented_outline, aug_log = self.augmenter.augment_outline(
                    outline, missing, project.main_keyword,
                    project.entities or [], project.related_searches or []
                )
                is_valid_aug, _, _ = self.validator.validate_outline(augmented_outline, project)
                if is_valid_aug:
                    content_record.outline = json.dumps(augmented_outline)
                    content_record.augmented = True
                    content_record.augmentation_log = aug_log
                    self.content_repo.update(content_record)
                    return augmented_outline

            last_errors = list(errors)
            if attempt < max_retries:
                prompt += f"\n\nPrevious attempt failed: {', '.join(errors)}. Please meet ALL CORA targets exactly."

        raise GenerationError(self._exhausted_message("Outline", max_retries, last_errors))

    def _generate_content(
        self,
        project: Project,
        title: str,
        outline: Dict[str, Any],
        content_record: GeneratedContent,
        model: str,
        max_retries: int
    ) -> str:
        """Generate and validate the full HTML content.

        The validation counts and report of every attempt are written to the
        record. If the validator's extracted missing elements are non-empty,
        ``ContentAugmenter`` tries to insert them, and the augmented HTML is
        accepted when it validates. This repair is attempted on every attempt,
        including the last. Otherwise up to five validation error messages are
        appended to the prompt before the next attempt.

        Returns:
            The accepted HTML (also stored on the record with its word count).

        Raises:
            GenerationError: If the AI call fails on the last attempt, or no
                attempt produced valid content.
        """
        prompt_template = self._load_prompt("content_generation.json")
        prompt = prompt_template["user_template"].format(
            outline=self._format_outline_for_prompt(outline),
            title=title,
            main_keyword=project.main_keyword,
            word_count=project.word_count,
            term_frequency=project.term_frequency or 3,
            entities=self._join_terms(project.entities, 30),
            related_searches=self._join_terms(project.related_searches, 30)
        )

        last_errors: List[str] = []
        for attempt in range(1, max_retries + 1):
            content_record.content_attempts = attempt
            self.content_repo.update(content_record)

            try:
                html_content = self.ai_client.generate(
                    prompt=prompt,
                    model=model,
                    temperature=0.7,
                    max_tokens=self.config.ai_service.max_tokens
                )
            except AIClientError as e:
                if attempt == max_retries:
                    raise GenerationError(
                        f"Content generation failed after {max_retries} attempts: {e}"
                    ) from e
                continue

            is_valid, validation_result = self.validator.validate_content(html_content, project)
            self._store_validation(content_record, validation_result)
            self.content_repo.update(content_record)

            if is_valid:
                content_record.content = html_content
                content_record.word_count = self._count_words(html_content)
                self.content_repo.update(content_record)
                return html_content

            missing = self.validator.extract_missing_elements(validation_result, project)
            if missing and any(missing.values()):
                augmented_html, aug_log = self.augmenter.augment_content(
                    html_content, missing, project.main_keyword,
                    project.entities or [], project.related_searches or []
                )
                is_valid_aug, validation_result_aug = self.validator.validate_content(
                    augmented_html, project
                )
                if is_valid_aug:
                    content_record.content = augmented_html
                    content_record.augmented = True
                    # Build a new dict instead of mutating the stored one in place.
                    # Assigning back the same mutated object may not be detected
                    # as a change by SQLAlchemy for a plain JSON column.
                    augmentation_log = dict(content_record.augmentation_log or {})
                    augmentation_log["content_augmentation"] = aug_log
                    content_record.augmentation_log = augmentation_log
                    self._store_validation(content_record, validation_result_aug)
                    content_record.word_count = self._count_words(augmented_html)
                    self.content_repo.update(content_record)
                    return augmented_html

            last_errors = [err.message for err in validation_result.errors[:5]]
            if attempt < max_retries:
                error_summary = ", ".join(last_errors)
                prompt += f"\n\nPrevious content failed validation: {error_summary}. Please fix these issues."

        raise GenerationError(self._exhausted_message("Content", max_retries, last_errors))

    def _load_prompt(self, filename: str) -> Dict[str, Any]:
        """Load a prompt template from a JSON file in ``self.prompts_dir``.

        Callers read the template's ``user_template`` string and fill it with
        ``str.format``.

        Raises:
            GenerationError: If the file does not exist (or is not a file).
        """
        prompt_path = self.prompts_dir / filename
        if not prompt_path.is_file():
            raise GenerationError(f"Prompt template not found: {filename}")

        with open(prompt_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _format_outline_for_prompt(self, outline: Dict[str, Any]) -> str:
        """Format the outline dict as readable "H1/H2/H3" lines for the content prompt."""
        lines = [f"H1: {outline.get('h1', '')}"]

        for section in outline.get("sections", []):
            lines.append(f"\nH2: {section['h2']}")
            for h3 in section.get("h3s", []):
                lines.append(f"  H3: {h3}")

        return "\n".join(lines)

    @staticmethod
    def _join_terms(terms: Optional[List[str]], limit: int) -> str:
        """Join the first ``limit`` terms with ", ", or return "N/A" when empty/None."""
        return ", ".join(terms[:limit]) if terms else "N/A"

    @staticmethod
    def _cora_count(value: Any, default: int, round_down: bool = True) -> int:
        """Convert a CORA target (possibly a fractional average) to a heading count.

        Falsy values (None, 0, "") fall back to ``default``. When
        ``round_down`` is true the value is truncated, as before. Otherwise it
        is rounded half-up; targets are assumed non-negative.
        """
        if not value:
            return default
        number = float(value)
        return int(number) if round_down else int(number + 0.5)

    @staticmethod
    def _count_words(html_content: str) -> int:
        """Approximate the visible word count of HTML.

        Tags are replaced with spaces before splitting on whitespace, so
        markup such as ``<p class="x">`` is not counted as words.
        """
        text = re.sub(r"<[^>]*>", " ", html_content)
        return len(text.split())

    @staticmethod
    def _store_validation(content_record: GeneratedContent, validation_result: Any) -> None:
        """Copy error/warning counts and the full report onto the record (no DB write)."""
        content_record.validation_errors = len(validation_result.errors)
        content_record.validation_warnings = len(validation_result.warnings)
        content_record.validation_report = validation_result.to_dict()

    @staticmethod
    def _exhausted_message(stage: str, max_retries: int, errors: List[str]) -> str:
        """Build the "<stage> validation failed" message, appending the last errors if any."""
        message = f"{stage} validation failed after {max_retries} attempts"
        if errors:
            message += f": {', '.join(errors)}"
        return message
|