184 lines
6.8 KiB
Python
184 lines
6.8 KiB
Python
"""
|
|
Batch job processor for generating multiple articles across tiers
|
|
"""
|
|
|
|
import time
|
|
from typing import Optional
|
|
from sqlalchemy.orm import Session
|
|
from src.database.models import Project
|
|
from src.database.repositories import ProjectRepository
|
|
from src.generation.service import ContentGenerationService, GenerationError
|
|
from src.generation.job_config import JobConfig, JobResult
|
|
from src.core.config import Config, get_config
|
|
|
|
|
|
class BatchProcessor:
    """Processes batch content generation jobs

    Walks the tiers of a job configuration, requests one article at a time
    from the generation service and tallies the outcomes in a JobResult.
    """

    def __init__(
        self,
        session: Session,
        config: Optional[Config] = None
    ):
        """
        Initialize batch processor

        Args:
            session: Database session
            config: Application configuration; when omitted (or falsy) the
                value returned by get_config() is used instead
        """
        self.session = session
        self.config = config or get_config()
        self.project_repo = ProjectRepository(session)
        # Pass the resolved configuration rather than the raw argument so the
        # service works with the same config object as this processor even
        # when the caller did not supply one.
        self.generation_service = ContentGenerationService(session, self.config)

    def process_job(
        self,
        job_config: JobConfig,
        progress_callback: Optional[callable] = None,
        debug: bool = False
    ) -> JobResult:
        """
        Process a batch job according to configuration

        Args:
            job_config: Job configuration
            progress_callback: Optional callback invoked with the keyword
                arguments tier, article_num, total and status ("starting",
                "completed", "skipped" or "failed"). "completed" calls also
                carry content_id; "skipped" and "failed" calls also carry
                error. The same callback is forwarded to generate_article.
            debug: Forwarded unchanged to generate_article

        Returns:
            JobResult with statistics; its duration attribute is set to the
            elapsed wall-clock seconds on every return path

        Raises:
            ValueError: If the project named by job_config.project_id is not found
        """
        start_time = time.time()

        project = self.project_repo.get_by_id(job_config.project_id)
        if not project:
            raise ValueError(f"Project {job_config.project_id} not found")

        result = JobResult(
            job_name=job_config.job_name,
            project_id=job_config.project_id,
            total_articles=job_config.get_total_articles(),
            successful=0,
            failed=0,
            skipped=0
        )

        try:
            self._run_tiers(job_config, project, result, progress_callback, debug)
        finally:
            # Single place that stamps the duration, whichever way the run ends.
            result.duration = time.time() - start_time

        return result

    def _run_tiers(self, job_config, project, result, progress_callback, debug):
        """Generate every configured article, returning early when the job must stop."""
        consecutive_failures = 0

        for tier_config in job_config.tiers:
            tier = tier_config.tier
            total = tier_config.article_count

            for article_num in range(1, total + 1):
                position = {"tier": tier, "article_num": article_num, "total": total}
                self._notify(progress_callback, **position, status="starting")

                try:
                    content = self.generation_service.generate_article(
                        project=project,
                        tier=tier,
                        title_model=tier_config.models.title,
                        outline_model=tier_config.models.outline,
                        content_model=tier_config.models.content,
                        max_retries=tier_config.validation_attempts,
                        progress_callback=progress_callback,
                        debug=debug
                    )

                    result.successful += 1
                    result.add_tier_result(tier, "successful")
                    consecutive_failures = 0

                    # content.id is only read when somebody is listening.
                    # NOTE(review): an exception raised by this callback is
                    # handled by the except clauses below, like a generation error.
                    if progress_callback:
                        progress_callback(
                            **position,
                            status="completed",
                            content_id=content.id
                        )

                except GenerationError as e:
                    detail = str(e)
                    result.add_error(f"Tier {tier}, Article {article_num}: {detail}")
                    consecutive_failures += 1

                    if not job_config.failure_config.skip_on_failure:
                        # Without skip-on-failure the first generation error ends the job.
                        result.failed += 1
                        result.add_tier_result(tier, "failed")
                        self._notify(progress_callback, **position, status="failed", error=detail)
                        return

                    result.skipped += 1
                    result.add_tier_result(tier, "skipped")
                    self._notify(progress_callback, **position, status="skipped", error=detail)

                    if consecutive_failures >= job_config.failure_config.max_consecutive_failures:
                        result.add_error(
                            f"Stopping job: {consecutive_failures} consecutive failures exceeded threshold"
                        )
                        return

                except Exception as e:
                    # Anything other than GenerationError always aborts the job,
                    # regardless of the skip_on_failure setting.
                    detail = str(e)
                    result.add_error(
                        f"Tier {tier}, Article {article_num}: Unexpected error: {detail}"
                    )
                    result.failed += 1
                    result.add_tier_result(tier, "failed")
                    self._notify(progress_callback, **position, status="failed", error=detail)
                    return

    @staticmethod
    def _notify(progress_callback, **event):
        """Invoke the progress callback with the event keywords, if a callback was given."""
        if progress_callback:
            progress_callback(**event)

    def process_job_from_file(
        self,
        job_file_path: str,
        progress_callback: Optional[callable] = None,
        debug: bool = False
    ) -> JobResult:
        """
        Load and process a job from a JSON file

        Args:
            job_file_path: Path to job configuration JSON file
            progress_callback: Optional progress callback, passed on to process_job
            debug: Passed on to process_job; defaults to False, which matches
                the behaviour before this parameter existed

        Returns:
            JobResult with statistics
        """
        job_config = JobConfig.from_file(job_file_path)
        return self.process_job(job_config, progress_callback, debug=debug)
|
|
|