Fix concurrent generation threading issues

- Fix SessionLocal import error by using db_manager.get_session()
- Create thread-local ContentGenerator instances for each worker
- Ensure each thread uses its own database session
- Prevents database session conflicts in concurrent article generation
main
PeninsulaInd 2025-10-27 13:58:18 -05:00
parent 3ebef04ef0
commit d919ea25e1
7 changed files with 301 additions and 50 deletions
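
For reference, a minimal sketch of the per-thread session pattern the commit message describes. db_manager is imported from src.database.session exactly as in the diffs below; run_task, run_all, and do_work are illustrative placeholders, not names from this codebase.

from concurrent.futures import ThreadPoolExecutor, as_completed
from src.database.session import db_manager

def do_work(session, task):
    # Placeholder for the real per-article work (thread-local repositories + ContentGenerator).
    pass

def run_task(task):
    # Each worker thread opens its own session; sessions are never shared across threads.
    session = db_manager.get_session()
    try:
        do_work(session, task)
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

def run_all(tasks, max_workers=5):
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {pool.submit(run_task, t): t for t in tasks}
        for future in as_completed(futures):
            future.result()   # surface any worker exception in the caller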

View File

@@ -63,7 +63,7 @@ Implement concurrent API calls for article generation using ThreadPoolExecutor t
 - Console output shows concurrency setting at job start
 ### 5. Word Count Optimization
-**Status:** PENDING
+**Status:** COMPLETE
 - Content generation prompts compensate by adding 200 words to target
 - If `min_word_count` = 1000, prompt receives `target_word_count` = 1200

View File

@@ -52,3 +52,8 @@ LINK_BUILDER_API_KEY=your_link_builder_api_key_here
 # Application Configuration
 LOG_LEVEL=INFO
 ENVIRONMENT=development
+
+# Concurrent Processing
+# Number of concurrent article generation workers (default: 5)
+# Set to 1 for sequential processing (useful for debugging)
+CONCURRENT_WORKERS=5
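
Because BatchProcessor (diff further below) only fans out when max_workers > 1, setting this to 1 forces the sequential path. A quick illustrative check, assuming the package is importable; get_concurrent_workers is the helper added to src.core.config in this commit.

import os
os.environ["CONCURRENT_WORKERS"] = "1"   # e.g. while debugging
from src.core.config import get_concurrent_workers

workers = get_concurrent_workers()       # -> 1
use_concurrent_path = workers > 1        # False: BatchProcessor falls back to sequential processing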

View File

@@ -0,0 +1,42 @@
"""
List all available FQDN domains (custom hostnames)

Usage:
    uv run python scripts/list_domains.py
"""
import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.database.session import db_manager
from src.database.models import SiteDeployment


def list_domains():
    """List all site deployments with custom hostnames"""
    session = db_manager.get_session()
    try:
        sites = session.query(SiteDeployment).filter(
            SiteDeployment.custom_hostname.isnot(None)
        ).order_by(SiteDeployment.site_name).all()

        print(f"\nFound {len(sites)} domains with custom hostnames:\n")
        print(f"{'ID':<5} {'Site Name':<35} {'Custom Hostname'}")
        print("-" * 80)

        for s in sites:
            print(f"{s.id:<5} {s.site_name:<35} {s.custom_hostname}")

        print()
    finally:
        session.close()


if __name__ == "__main__":
    list_domains()

View File

@@ -4,7 +4,7 @@ CLI command definitions using Click
 import click
 from typing import Optional
-from src.core.config import get_config, get_bunny_account_api_key
+from src.core.config import get_config, get_bunny_account_api_key, get_concurrent_workers
 from src.auth.service import AuthService
 from src.database.session import db_manager
 from src.database.repositories import UserRepository, SiteDeploymentRepository, ProjectRepository
@@ -977,14 +977,20 @@ def generate_batch(
         content_repo=content_repo
     )

+    max_workers = get_concurrent_workers()
+    job_max_workers = jobs[0].max_workers if jobs and jobs[0].max_workers else None
+    final_max_workers = job_max_workers or max_workers
+
     batch_processor = BatchProcessor(
         content_generator=content_generator,
         content_repo=content_repo,
         project_repo=project_repo,
-        site_deployment_repo=site_deployment_repo
+        site_deployment_repo=site_deployment_repo,
+        max_workers=final_max_workers
     )

     click.echo(f"\nProcessing job file: {job_file}")
+    click.echo(f"Concurrent workers: {final_max_workers}")
     if debug:
         click.echo("Debug mode: AI responses will be saved to debug_output/\n")

View File

@@ -231,4 +231,14 @@ def get_bunny_account_api_key() -> str:
     api_key = os.getenv("BUNNY_ACCOUNT_API_KEY")
     if not api_key:
         raise ValueError("BUNNY_ACCOUNT_API_KEY environment variable is required")
     return api_key
+
+
+def get_concurrent_workers() -> int:
+    """
+    Get the number of concurrent workers for article generation
+
+    Returns:
+        Number of concurrent workers (default: 5)
+    """
+    return int(os.getenv("CONCURRENT_WORKERS", "5"))

View File

@@ -2,11 +2,13 @@
 Batch processor for content generation jobs
 """
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, List
 import click
 import os
 from pathlib import Path
 from datetime import datetime
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from threading import Lock
 from src.generation.service import ContentGenerator
 from src.generation.job_config import JobConfig, Job, TierConfig
 from src.generation.deployment_assignment import validate_and_resolve_targets, assign_site_for_article
@@ -28,12 +30,15 @@ class BatchProcessor:
         content_generator: ContentGenerator,
         content_repo: GeneratedContentRepository,
         project_repo: ProjectRepository,
-        site_deployment_repo: Optional[SiteDeploymentRepository] = None
+        site_deployment_repo: Optional[SiteDeploymentRepository] = None,
+        max_workers: int = 5
     ):
         self.generator = content_generator
         self.content_repo = content_repo
         self.project_repo = project_repo
         self.site_deployment_repo = site_deployment_repo
+        self.max_workers = max_workers
+        self.stats_lock = Lock()
         self.stats = {
             "total_jobs": 0,
             "processed_jobs": 0,
@@ -199,8 +204,8 @@
         debug: bool,
         continue_on_error: bool
     ):
-        """Process all articles for a tier with pre-generated titles"""
-        click.echo(f" {tier_name}: Generating {tier_config.count} articles")
+        """Process all articles for a tier with concurrent generation"""
+        click.echo(f" {tier_name}: Generating {tier_config.count} articles (concurrency: {self.max_workers})")
         project = self.project_repo.get_by_id(project_id)
         keyword = project.main_keyword
@@ -220,58 +225,34 @@
             titles = [line.strip() for line in f if line.strip()]

         click.echo(f"[{tier_name}] Generated {len(titles)} titles")
         click.echo(f"[{tier_name}] Titles saved to: {titles_file}")

         targets_for_tier = resolved_targets if tier_name == "tier1" else {}

+        article_tasks = []
         for article_num in range(1, tier_config.count + 1):
-            self.stats["total_articles"] += 1
             article_index = article_num - 1
             if article_index >= len(titles):
                 click.echo(f" Warning: Not enough titles generated, skipping article {article_num}")
                 continue

-            title = titles[article_index]
-
-            try:
-                self._generate_single_article(
-                    project_id,
-                    tier_name,
-                    tier_config,
-                    article_num,
-                    article_index,
-                    title,
-                    keyword,
-                    targets_for_tier,
-                    debug
-                )
-                self.stats["generated_articles"] += 1
-
-            except Exception as e:
-                self.stats["failed_articles"] += 1
-                import traceback
-                click.echo(f" [{article_num}/{tier_config.count}] FAILED: {e}")
-                click.echo(f" Traceback: {traceback.format_exc()}")
-
-                try:
-                    self.content_repo.create(
-                        project_id=project_id,
-                        tier=tier_name,
-                        keyword=keyword,
-                        title="Failed Generation",
-                        outline={"error": str(e)},
-                        content="",
-                        word_count=0,
-                        status="failed"
-                    )
-                except Exception as db_error:
-                    click.echo(f" Failed to save error record: {db_error}")
-
-                if not continue_on_error:
-                    raise
+            article_tasks.append({
+                'project_id': project_id,
+                'tier_name': tier_name,
+                'tier_config': tier_config,
+                'article_num': article_num,
+                'article_index': article_index,
+                'title': titles[article_index],
+                'keyword': keyword,
+                'resolved_targets': targets_for_tier,
+                'debug': debug
+            })
+
+        if self.max_workers > 1:
+            self._process_articles_concurrent(article_tasks, continue_on_error)
+        else:
+            self._process_articles_sequential(article_tasks, continue_on_error)

         # Post-processing: URL generation and interlinking (Story 3.1-3.3)
         try:
             self._post_process_tier(project_id, tier_name, job, debug)
         except Exception as e:
@@ -367,6 +348,206 @@
        click.echo(f"{prefix} Saved (ID: {saved_content.id}, Status: {status})")

    def _process_articles_concurrent(
        self,
        article_tasks: List[Dict[str, Any]],
        continue_on_error: bool
    ):
        """
        Process articles concurrently using ThreadPoolExecutor
        """
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_task = {
                executor.submit(self._generate_single_article_thread_safe, **task): task
                for task in article_tasks
            }

            for future in as_completed(future_to_task):
                task = future_to_task[future]
                article_num = task['article_num']
                tier_name = task['tier_name']
                tier_config = task['tier_config']

                try:
                    future.result()
                    with self.stats_lock:
                        self.stats["generated_articles"] += 1
                except Exception as e:
                    with self.stats_lock:
                        self.stats["failed_articles"] += 1
                    import traceback
                    click.echo(f" [{article_num}/{tier_config.count}] FAILED: {e}")
                    click.echo(f" Traceback: {traceback.format_exc()}")

                    try:
                        self.content_repo.create(
                            project_id=task['project_id'],
                            tier=tier_name,
                            keyword=task['keyword'],
                            title="Failed Generation",
                            outline={"error": str(e)},
                            content="",
                            word_count=0,
                            status="failed"
                        )
                    except Exception as db_error:
                        click.echo(f" Failed to save error record: {db_error}")

                    if not continue_on_error:
                        for f in future_to_task:
                            f.cancel()
                        raise

    def _process_articles_sequential(
        self,
        article_tasks: List[Dict[str, Any]],
        continue_on_error: bool
    ):
        """
        Process articles sequentially (fallback for max_workers=1)
        """
        for task in article_tasks:
            with self.stats_lock:
                self.stats["total_articles"] += 1
            try:
                self._generate_single_article(**task)
                with self.stats_lock:
                    self.stats["generated_articles"] += 1
            except Exception as e:
                with self.stats_lock:
                    self.stats["failed_articles"] += 1
                import traceback
                click.echo(f" [{task['article_num']}/{task['tier_config'].count}] FAILED: {e}")
                click.echo(f" Traceback: {traceback.format_exc()}")
                if not continue_on_error:
                    raise

    def _generate_single_article_thread_safe(
        self,
        project_id: int,
        tier_name: str,
        tier_config: TierConfig,
        article_num: int,
        article_index: int,
        title: str,
        keyword: str,
        resolved_targets: Dict[str, int],
        debug: bool
    ):
        """
        Thread-safe wrapper for article generation

        Creates a new database session for this thread
        """
        with self.stats_lock:
            self.stats["total_articles"] += 1

        from src.database.session import db_manager
        from src.generation.service import ContentGenerator

        thread_session = db_manager.get_session()
        try:
            thread_content_repo = GeneratedContentRepository(thread_session)
            thread_project_repo = ProjectRepository(thread_session)

            thread_generator = ContentGenerator(
                ai_client=self.generator.ai_client,
                prompt_manager=self.generator.prompt_manager,
                project_repo=thread_project_repo,
                content_repo=thread_content_repo,
                template_service=self.generator.template_service,
                site_deployment_repo=self.generator.site_deployment_repo
            )

            prefix = f" [{article_num}/{tier_config.count}]"
            models = self.current_job.models if hasattr(self, 'current_job') and self.current_job.models else None

            site_deployment_id = assign_site_for_article(article_index, resolved_targets)
            if site_deployment_id:
                hostname = next((h for h, id in resolved_targets.items() if id == site_deployment_id), None)
                click.echo(f"{prefix} Assigned to site: {hostname} (ID: {site_deployment_id})")

            click.echo(f"{prefix} Using title: \"{title}\"")

            click.echo(f"{prefix} Generating outline...")
            outline = thread_generator.generate_outline(
                project_id=project_id,
                title=title,
                min_h2=tier_config.min_h2_tags,
                max_h2=tier_config.max_h2_tags,
                min_h3=tier_config.min_h3_tags,
                max_h3=tier_config.max_h3_tags,
                debug=debug,
                model=models.outline if models else None
            )
            h2_count = len(outline["outline"])
            h3_count = sum(len(section.get("h3", [])) for section in outline["outline"])
            click.echo(f"{prefix} Generated outline: {h2_count} H2s, {h3_count} H3s")

            click.echo(f"{prefix} Generating content...")
            content = thread_generator.generate_content(
                project_id=project_id,
                title=title,
                outline=outline,
                min_word_count=tier_config.min_word_count,
                max_word_count=tier_config.max_word_count,
                debug=debug,
                model=models.content if models else None
            )
            word_count = thread_generator.count_words(content)
            click.echo(f"{prefix} Generated content: {word_count:,} words")

            status = "generated"
            if word_count < tier_config.min_word_count:
                click.echo(f"{prefix} Below minimum ({tier_config.min_word_count:,}), augmenting...")
                content = thread_generator.augment_content(
                    content=content,
                    target_word_count=tier_config.min_word_count,
                    debug=debug,
                    project_id=project_id,
                    model=models.content if models else None
                )
                word_count = thread_generator.count_words(content)
                click.echo(f"{prefix} Augmented content: {word_count:,} words")
                status = "augmented"
                with self.stats_lock:
                    self.stats["augmented_articles"] += 1

            saved_content = thread_content_repo.create(
                project_id=project_id,
                tier=tier_name,
                keyword=keyword,
                title=title,
                outline=outline,
                content=content,
                word_count=word_count,
                status=status,
                site_deployment_id=site_deployment_id
            )
            thread_session.commit()

            click.echo(f"{prefix} Saved (ID: {saved_content.id}, Status: {status})")

        except Exception as e:
            thread_session.rollback()
            raise
        finally:
            thread_session.close()

    def _post_process_tier(
        self,
        project_id: int,

View File

@@ -94,6 +94,7 @@ class Job:
     anchor_text_config: Optional[AnchorTextConfig] = None
     failure_config: Optional[FailureConfig] = None
     interlinking: Optional[InterlinkingConfig] = None
+    max_workers: Optional[int] = None


 class JobConfig:
@@ -288,6 +289,11 @@ class JobConfig:
             see_also_max=see_also_max
         )

+        max_workers = job_data.get("max_workers")
+        if max_workers is not None:
+            if not isinstance(max_workers, int) or max_workers < 1:
+                raise ValueError("'max_workers' must be a positive integer")
+
         return Job(
             project_id=project_id,
             tiers=tiers,
@@ -299,7 +305,8 @@ class JobConfig:
             tiered_link_count_range=tiered_link_count_range,
             anchor_text_config=anchor_text_config,
             failure_config=failure_config,
-            interlinking=interlinking
+            interlinking=interlinking,
+            max_workers=max_workers
         )

     def _parse_tier(self, tier_name: str, tier_data: dict) -> TierConfig: