# Big-Link-Man/src/generation/batch_processor.py
"""
Batch job processor for generating multiple articles across tiers
"""
import time
from typing import Callable, Optional

from sqlalchemy.orm import Session

from src.database.models import Project
from src.database.repositories import ProjectRepository
from src.generation.service import ContentGenerationService, GenerationError
from src.generation.job_config import JobConfig, JobResult
from src.core.config import Config, get_config
class BatchProcessor:
    """Processes batch content generation jobs.

    Walks a job's tier configurations, generating the requested number of
    articles per tier via ContentGenerationService, and accumulates
    success/failure/skip statistics in a JobResult.
    """

    def __init__(
        self,
        session: Session,
        config: Optional[Config] = None
    ):
        """
        Initialize batch processor.

        Args:
            session: Database session
            config: Application configuration; falls back to get_config()
        """
        self.session = session
        self.config = config or get_config()
        self.project_repo = ProjectRepository(session)
        # Fix: pass the resolved config (self.config) rather than the raw
        # argument, which may be None, so the processor and the generation
        # service always share the same configuration object.
        self.generation_service = ContentGenerationService(session, self.config)

    def process_job(
        self,
        job_config: JobConfig,
        progress_callback: Optional[Callable] = None,
        debug: bool = False
    ) -> JobResult:
        """
        Process a batch job according to configuration.

        Args:
            job_config: Job configuration describing tiers and failure policy
            progress_callback: Optional callback invoked as
                ``fn(tier=..., article_num=..., total=..., status=..., ...)``
                with status "starting", "completed", "skipped", or "failed"
            debug: Forwarded to the generation service for verbose output

        Returns:
            JobResult with per-tier statistics, collected errors, and the
            total elapsed duration in seconds

        Raises:
            ValueError: If the job's project does not exist
        """
        start_time = time.time()

        project = self.project_repo.get_by_id(job_config.project_id)
        if not project:
            raise ValueError(f"Project {job_config.project_id} not found")

        result = JobResult(
            job_name=job_config.job_name,
            project_id=job_config.project_id,
            total_articles=job_config.get_total_articles(),
            successful=0,
            failed=0,
            skipped=0
        )

        # Counts failures in a row (carried across tiers); reset on success.
        consecutive_failures = 0

        for tier_config in job_config.tiers:
            tier = tier_config.tier
            for article_num in range(1, tier_config.article_count + 1):
                if progress_callback:
                    progress_callback(
                        tier=tier,
                        article_num=article_num,
                        total=tier_config.article_count,
                        status="starting"
                    )
                try:
                    content = self.generation_service.generate_article(
                        project=project,
                        tier=tier,
                        title_model=tier_config.models.title,
                        outline_model=tier_config.models.outline,
                        content_model=tier_config.models.content,
                        max_retries=tier_config.validation_attempts,
                        progress_callback=progress_callback,
                        debug=debug
                    )
                    result.successful += 1
                    result.add_tier_result(tier, "successful")
                    consecutive_failures = 0
                    if progress_callback:
                        progress_callback(
                            tier=tier,
                            article_num=article_num,
                            total=tier_config.article_count,
                            status="completed",
                            content_id=content.id
                        )
                except GenerationError as e:
                    result.add_error(
                        f"Tier {tier}, Article {article_num}: {str(e)}"
                    )
                    consecutive_failures += 1
                    if job_config.failure_config.skip_on_failure:
                        result.skipped += 1
                        result.add_tier_result(tier, "skipped")
                        if progress_callback:
                            progress_callback(
                                tier=tier,
                                article_num=article_num,
                                total=tier_config.article_count,
                                status="skipped",
                                error=str(e)
                            )
                        # Even in skip mode, abort the job after too many
                        # back-to-back failures to avoid burning the budget.
                        if consecutive_failures >= job_config.failure_config.max_consecutive_failures:
                            result.add_error(
                                f"Stopping job: {consecutive_failures} consecutive failures exceeded threshold"
                            )
                            return self._finalize(result, start_time)
                    else:
                        # Fail-fast policy: first generation failure ends the job.
                        result.failed += 1
                        result.add_tier_result(tier, "failed")
                        if progress_callback:
                            progress_callback(
                                tier=tier,
                                article_num=article_num,
                                total=tier_config.article_count,
                                status="failed",
                                error=str(e)
                            )
                        return self._finalize(result, start_time)
                except Exception as e:
                    # Unexpected (non-GenerationError) errors always stop the
                    # job — they indicate a bug rather than a flaky generation.
                    result.add_error(
                        f"Tier {tier}, Article {article_num}: Unexpected error: {str(e)}"
                    )
                    result.failed += 1
                    result.add_tier_result(tier, "failed")
                    if progress_callback:
                        progress_callback(
                            tier=tier,
                            article_num=article_num,
                            total=tier_config.article_count,
                            status="failed",
                            error=str(e)
                        )
                    return self._finalize(result, start_time)

        return self._finalize(result, start_time)

    @staticmethod
    def _finalize(result: JobResult, start_time: float) -> JobResult:
        """Record total elapsed time on the result and return it."""
        result.duration = time.time() - start_time
        return result

    def process_job_from_file(
        self,
        job_file_path: str,
        progress_callback: Optional[Callable] = None,
        debug: bool = False
    ) -> JobResult:
        """
        Load and process a job from a JSON file.

        Args:
            job_file_path: Path to job configuration JSON file
            progress_callback: Optional progress callback
            debug: Forwarded to process_job (defaults to False, preserving
                the previous behavior)

        Returns:
            JobResult with statistics
        """
        job_config = JobConfig.from_file(job_file_path)
        return self.process_job(job_config, progress_callback, debug=debug)