"""Content generation service with a three-stage pipeline (title, outline, content)."""
|
|
|
|
import re
|
|
import json
|
|
from html import unescape
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from typing import Optional, Tuple, List
|
|
from src.generation.ai_client import AIClient, PromptManager
|
|
from src.database.repositories import ProjectRepository, GeneratedContentRepository, SiteDeploymentRepository
|
|
from src.templating.service import TemplateService
|
|
|
|
|
|
class ContentGenerator:
|
|
"""Main service for generating content through AI pipeline"""
|
|
|
|
    def __init__(
        self,
        ai_client: AIClient,
        prompt_manager: PromptManager,
        project_repo: ProjectRepository,
        content_repo: GeneratedContentRepository,
        template_service: Optional[TemplateService] = None,
        site_deployment_repo: Optional[SiteDeploymentRepository] = None
    ):
        """
        Initialize the generator with its injected collaborators.

        Args:
            ai_client: Client used for all AI completion calls
            prompt_manager: Loads and formats named prompt templates
            project_repo: Repository for Project records
            content_repo: Repository for GeneratedContent records
            template_service: Optional HTML templating service; a default
                TemplateService wrapping content_repo is created when omitted
            site_deployment_repo: Optional repository used for per-site
                template selection (may be None)
        """
        self.ai_client = ai_client
        self.prompt_manager = prompt_manager
        self.project_repo = project_repo
        self.content_repo = content_repo
        # Build a default template service if the caller did not supply one.
        self.template_service = template_service or TemplateService(content_repo)
        self.site_deployment_repo = site_deployment_repo
|
|
|
|
def generate_title(self, project_id: int, debug: bool = False, model: Optional[str] = None) -> str:
|
|
"""
|
|
Generate SEO-optimized title
|
|
|
|
Args:
|
|
project_id: Project ID to generate title for
|
|
debug: If True, save response to debug_output/
|
|
model: Optional model override for this generation stage
|
|
|
|
Returns:
|
|
Generated title string
|
|
"""
|
|
project = self.project_repo.get_by_id(project_id)
|
|
if not project:
|
|
raise ValueError(f"Project {project_id} not found")
|
|
|
|
entities_str = ", ".join(project.entities or [])
|
|
related_str = ", ".join(project.related_searches or [])
|
|
|
|
system_msg, user_prompt = self.prompt_manager.format_prompt(
|
|
"title_generation",
|
|
keyword=project.main_keyword,
|
|
entities=entities_str,
|
|
related_searches=related_str
|
|
)
|
|
|
|
title, _ = self.ai_client.generate_completion(
|
|
prompt=user_prompt,
|
|
system_message=system_msg,
|
|
max_tokens=100,
|
|
temperature=0.7,
|
|
override_model=model
|
|
)
|
|
|
|
title = title.strip().strip('"').strip("'")
|
|
|
|
if debug:
|
|
self._save_debug_output(
|
|
project_id, "title", title, "txt"
|
|
)
|
|
|
|
return title
|
|
|
|
def generate_titles_batch(
|
|
self,
|
|
project_id: int,
|
|
count: int,
|
|
batch_size: int = 25,
|
|
debug: bool = False,
|
|
model: Optional[str] = None
|
|
) -> List[str]:
|
|
"""
|
|
Generate multiple titles in batches
|
|
|
|
Args:
|
|
project_id: Project ID to generate titles for
|
|
count: Total number of titles needed
|
|
batch_size: Number of titles per AI call (default: 25)
|
|
debug: If True, save responses to debug_output/
|
|
model: Optional model override for this generation stage
|
|
|
|
Returns:
|
|
List of generated title strings
|
|
"""
|
|
project = self.project_repo.get_by_id(project_id)
|
|
if not project:
|
|
raise ValueError(f"Project {project_id} not found")
|
|
|
|
entities_str = ", ".join(project.entities or [])
|
|
related_str = ", ".join(project.related_searches or [])
|
|
|
|
all_titles = []
|
|
titles_remaining = count
|
|
|
|
while titles_remaining > 0:
|
|
current_batch_size = min(batch_size, titles_remaining)
|
|
|
|
system_msg, user_prompt = self.prompt_manager.format_prompt(
|
|
"batch_title_generation",
|
|
keyword=project.main_keyword,
|
|
entities=entities_str,
|
|
related_searches=related_str,
|
|
count=current_batch_size
|
|
)
|
|
|
|
batch_titles = None
|
|
for attempt in range(3):
|
|
try:
|
|
response, _ = self.ai_client.generate_completion(
|
|
prompt=user_prompt,
|
|
system_message=system_msg,
|
|
max_tokens=100 * current_batch_size,
|
|
temperature=0.7,
|
|
override_model=model
|
|
)
|
|
|
|
lines = response.strip().split('\n')
|
|
batch_titles = []
|
|
|
|
for line in lines:
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
line = re.sub(r'^\d+[\.\)]\s*', '', line)
|
|
line = line.strip('"').strip("'")
|
|
if line:
|
|
batch_titles.append(line)
|
|
|
|
if len(batch_titles) < current_batch_size:
|
|
print(f"Warning: Requested {current_batch_size} titles but received {len(batch_titles)}. Continuing with partial batch.")
|
|
|
|
if len(batch_titles) > current_batch_size:
|
|
batch_titles = batch_titles[:current_batch_size]
|
|
|
|
break
|
|
|
|
except Exception as e:
|
|
if attempt == 2:
|
|
raise ValueError(f"Failed to generate batch after 3 attempts: {e}")
|
|
print(f"Batch generation attempt {attempt + 1} failed: {e}, retrying...")
|
|
|
|
if batch_titles:
|
|
all_titles.extend(batch_titles)
|
|
titles_remaining -= len(batch_titles)
|
|
else:
|
|
raise ValueError("Failed to generate any titles in batch")
|
|
|
|
if debug:
|
|
for i, title in enumerate(all_titles, 1):
|
|
self._save_debug_output(
|
|
project_id, f"batch_title_{i}", title, "txt"
|
|
)
|
|
|
|
return all_titles
|
|
|
|
def generate_outline(
|
|
self,
|
|
project_id: int,
|
|
title: str,
|
|
min_h2: int,
|
|
max_h2: int,
|
|
min_h3: int,
|
|
max_h3: int,
|
|
debug: bool = False,
|
|
model: Optional[str] = None
|
|
) -> dict:
|
|
"""
|
|
Generate article outline in JSON format
|
|
|
|
Args:
|
|
project_id: Project ID
|
|
title: Article title
|
|
min_h2: Minimum H2 headings
|
|
max_h2: Maximum H2 headings
|
|
min_h3: Minimum H3 subheadings total
|
|
max_h3: Maximum H3 subheadings total
|
|
debug: If True, save response to debug_output/
|
|
model: Optional model override for this generation stage
|
|
|
|
Returns:
|
|
Outline dictionary: {"outline": [{"h2": "...", "h3": ["...", "..."]}]}
|
|
|
|
Raises:
|
|
ValueError: If outline doesn't meet minimum requirements
|
|
"""
|
|
project = self.project_repo.get_by_id(project_id)
|
|
if not project:
|
|
raise ValueError(f"Project {project_id} not found")
|
|
|
|
entities_str = ", ".join(project.entities or [])
|
|
related_str = ", ".join(project.related_searches or [])
|
|
|
|
system_msg, user_prompt = self.prompt_manager.format_prompt(
|
|
"outline_generation",
|
|
title=title,
|
|
keyword=project.main_keyword,
|
|
min_h2=min_h2,
|
|
max_h2=max_h2,
|
|
min_h3=min_h3,
|
|
max_h3=max_h3,
|
|
entities=entities_str,
|
|
related_searches=related_str
|
|
)
|
|
|
|
outline_json, _ = self.ai_client.generate_completion(
|
|
prompt=user_prompt,
|
|
system_message=system_msg,
|
|
max_tokens=2000,
|
|
temperature=0.7,
|
|
json_mode=True,
|
|
override_model=model
|
|
)
|
|
print(f"[DEBUG] Raw outline response: {outline_json}")
|
|
# Save raw response immediately
|
|
if debug:
|
|
self._save_debug_output(project_id, "outline_raw", outline_json, "txt")
|
|
print(f"[DEBUG] Raw outline response: {outline_json}")
|
|
|
|
try:
|
|
outline = json.loads(outline_json)
|
|
except json.JSONDecodeError as e:
|
|
if debug:
|
|
self._save_debug_output(project_id, "outline_error", outline_json, "txt")
|
|
raise ValueError(f"Failed to parse outline JSON: {e}\nResponse: {outline_json[:500]}")
|
|
|
|
if "outline" not in outline:
|
|
if debug:
|
|
self._save_debug_output(project_id, "outline_invalid", json.dumps(outline, indent=2), "json")
|
|
raise ValueError(f"Outline missing 'outline' key. Got keys: {list(outline.keys())}\nContent: {outline}")
|
|
|
|
h2_count = len(outline["outline"])
|
|
h3_count = sum(len(section.get("h3", [])) for section in outline["outline"])
|
|
|
|
if h2_count < min_h2:
|
|
raise ValueError(f"Outline has {h2_count} H2s, minimum is {min_h2}")
|
|
|
|
if h3_count < min_h3:
|
|
raise ValueError(f"Outline has {h3_count} H3s, minimum is {min_h3}")
|
|
|
|
if debug:
|
|
self._save_debug_output(
|
|
project_id, "outline", json.dumps(outline, indent=2), "json"
|
|
)
|
|
|
|
return outline
|
|
|
|
def generate_content(
|
|
self,
|
|
project_id: int,
|
|
title: str,
|
|
outline: dict,
|
|
min_word_count: int,
|
|
max_word_count: int,
|
|
debug: bool = False,
|
|
model: Optional[str] = None
|
|
) -> str:
|
|
"""
|
|
Generate full article HTML fragment
|
|
|
|
Compensates for AI undershoot by adding 200 words to targets.
|
|
|
|
Args:
|
|
project_id: Project ID
|
|
title: Article title
|
|
outline: Article outline dict
|
|
min_word_count: Minimum word count for guidance
|
|
max_word_count: Maximum word count for guidance
|
|
debug: If True, save response to debug_output/
|
|
model: Optional model override for this generation stage
|
|
|
|
Returns:
|
|
Tuple of (HTML string with <h2>, <h3>, <p> tags, finish_reason)
|
|
"""
|
|
project = self.project_repo.get_by_id(project_id)
|
|
if not project:
|
|
raise ValueError(f"Project {project_id} not found")
|
|
|
|
entities_str = ", ".join(project.entities or [])
|
|
related_str = ", ".join(project.related_searches or [])
|
|
outline_str = json.dumps(outline, indent=2)
|
|
|
|
compensated_min = min_word_count + 100
|
|
compensated_max = max_word_count + 100
|
|
|
|
h3_count = sum(len(section.get("h3", [])) for section in outline.get("outline", []))
|
|
words_per_section = int(compensated_min / h3_count) if h3_count > 0 else 100
|
|
|
|
system_msg, user_prompt = self.prompt_manager.format_prompt(
|
|
"content_generation",
|
|
title=title,
|
|
outline=outline_str,
|
|
keyword=project.main_keyword,
|
|
entities=entities_str,
|
|
related_searches=related_str,
|
|
min_word_count=compensated_min,
|
|
max_word_count=compensated_max,
|
|
words_per_section=words_per_section
|
|
)
|
|
|
|
content, finish_reason = self.ai_client.generate_completion(
|
|
prompt=user_prompt,
|
|
system_message=system_msg,
|
|
max_tokens=12000,
|
|
temperature=0.7,
|
|
override_model=model,
|
|
title=title
|
|
)
|
|
|
|
content = content.strip()
|
|
content = self._clean_markdown_fences(content)
|
|
|
|
if debug:
|
|
self._save_debug_output(
|
|
project_id, "content", content, "html"
|
|
)
|
|
|
|
return content, finish_reason
|
|
|
|
def validate_word_count(self, content: str, min_words: int, max_words: int) -> Tuple[bool, int]:
|
|
"""
|
|
Validate content word count
|
|
|
|
Args:
|
|
content: HTML content string
|
|
min_words: Minimum word count
|
|
max_words: Maximum word count
|
|
|
|
Returns:
|
|
Tuple of (is_valid, actual_count)
|
|
"""
|
|
word_count = self.count_words(content)
|
|
is_valid = min_words <= word_count <= max_words
|
|
return is_valid, word_count
|
|
|
|
def count_words(self, html_content: str) -> int:
|
|
"""
|
|
Count words in HTML content
|
|
|
|
Args:
|
|
html_content: HTML string
|
|
|
|
Returns:
|
|
Number of words
|
|
"""
|
|
text = re.sub(r'<[^>]+>', '', html_content)
|
|
text = unescape(text)
|
|
words = text.split()
|
|
return len(words)
|
|
|
|
def augment_content(
|
|
self,
|
|
content: str,
|
|
target_word_count: int,
|
|
debug: bool = False,
|
|
project_id: Optional[int] = None,
|
|
model: Optional[str] = None
|
|
) -> str:
|
|
"""
|
|
Expand article content to meet minimum word count
|
|
|
|
Args:
|
|
content: Current HTML content
|
|
target_word_count: Target word count
|
|
debug: If True, save response to debug_output/
|
|
project_id: Optional project ID for debug output
|
|
model: Optional model override for this generation stage
|
|
|
|
Returns:
|
|
Expanded HTML content
|
|
"""
|
|
system_msg, user_prompt = self.prompt_manager.format_prompt(
|
|
"content_augmentation",
|
|
content=content,
|
|
target_word_count=target_word_count
|
|
)
|
|
|
|
augmented, _ = self.ai_client.generate_completion(
|
|
prompt=user_prompt,
|
|
system_message=system_msg,
|
|
max_tokens=8000,
|
|
temperature=0.7,
|
|
override_model=model
|
|
)
|
|
|
|
augmented = augmented.strip()
|
|
augmented = self._clean_markdown_fences(augmented)
|
|
|
|
if debug and project_id:
|
|
self._save_debug_output(
|
|
project_id, "augmented", augmented, "html"
|
|
)
|
|
|
|
return augmented
|
|
|
|
    def apply_template(
        self,
        content_id: int,
        meta_description: Optional[str] = None,
        canonical_url: Optional[str] = None
    ) -> bool:
        """
        Apply HTML template to generated content and save to database.

        Best-effort: any failure is logged and reported as False rather than
        raised to the caller.

        Args:
            content_id: GeneratedContent ID to format
            meta_description: Optional meta description (defaults to the first
                25 words of the tag-stripped article text)
            canonical_url: Optional canonical URL for SEO canonical link tag

        Returns:
            True if successful, False otherwise
        """
        try:
            # Refresh to ensure we have latest content (especially after image reinsertion)
            content_record = self.content_repo.get_by_id(content_id)
            if not content_record:
                print(f"Warning: Content {content_id} not found")
                return False

            # Force refresh from database to get latest content
            self.content_repo.session.refresh(content_record)

            if not meta_description:
                # Derive a fallback description: strip tags, unescape entities,
                # take the first 25 words, and append an ellipsis.
                text = re.sub(r'<[^>]+>', '', content_record.content)
                text = unescape(text)
                words = text.split()[:25]
                meta_description = ' '.join(words) + '...'

            # Template may vary per site deployment; the repo may be None.
            template_name = self.template_service.select_template_for_content(
                site_deployment_id=content_record.site_deployment_id,
                site_deployment_repo=self.site_deployment_repo
            )

            formatted_html = self.template_service.format_content(
                content=content_record.content,
                title=content_record.title,
                meta_description=meta_description,
                template_name=template_name,
                canonical_url=canonical_url
            )

            content_record.formatted_html = formatted_html
            content_record.template_used = template_name
            self.content_repo.update(content_record)

            # Verify it was saved
            self.content_repo.session.refresh(content_record)
            if content_record.template_used != template_name:
                print(f"ERROR: template_used not saved! Expected '{template_name}', got '{content_record.template_used}'")
                return False

            print(f"Applied template '{template_name}' to content {content_id}")
            return True

        except Exception as e:
            # Swallow and report: callers treat False as "templating failed".
            print(f"Error applying template to content {content_id}: {e}")
            import traceback
            traceback.print_exc()
            return False
|
|
|
|
def _clean_markdown_fences(self, content: str) -> str:
|
|
"""
|
|
Remove markdown code fences from AI-generated content
|
|
|
|
Args:
|
|
content: Raw content that may contain ```html or ``` markers
|
|
|
|
Returns:
|
|
Cleaned content without markdown code fences
|
|
"""
|
|
import re
|
|
|
|
content = re.sub(r'^```html\s*\n?', '', content, flags=re.MULTILINE)
|
|
content = re.sub(r'^```\s*$', '', content, flags=re.MULTILINE)
|
|
content = re.sub(r'\n```\s*$', '', content)
|
|
|
|
return content.strip()
|
|
|
|
def _save_debug_output(
|
|
self,
|
|
project_id: int,
|
|
stage: str,
|
|
content: str,
|
|
extension: str,
|
|
tier: Optional[str] = None,
|
|
article_num: Optional[int] = None
|
|
):
|
|
"""Save debug output to file"""
|
|
debug_dir = Path("debug_output")
|
|
debug_dir.mkdir(exist_ok=True)
|
|
|
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
|
|
tier_part = f"_tier{tier}" if tier else ""
|
|
article_part = f"_article{article_num}" if article_num else ""
|
|
|
|
filename = f"{stage}_project{project_id}{tier_part}{article_part}_{timestamp}.{extension}"
|
|
filepath = debug_dir / filename
|
|
|
|
with open(filepath, 'w', encoding='utf-8') as f:
|
|
f.write(content) |