"""
Unit tests for job configuration
"""

import json
import tempfile
from pathlib import Path

import pytest

from src.generation.job_config import (
    AnchorTextConfig,
    FailureConfig,
    InterlinkingConfig,
    JobConfig,
    ModelConfig,
    TierConfig,
)


def test_model_config_creation():
    """Test that ModelConfig keeps the model name given for each stage."""
    expected = {"title": "model1", "outline": "model2", "content": "model3"}
    config = ModelConfig(**expected)

    # Each keyword passed in must be readable back as an attribute.
    for stage, model_name in expected.items():
        assert getattr(config, stage) == model_name


def test_anchor_text_config_modes():
    """Test AnchorTextConfig for each anchor text mode exercised here.

    - "default": no extra text supplied.
    - "override": ``custom_text`` is supplied and must be kept in order.
    - "append": ``additional_text`` is supplied and must be kept.
    """
    default_config = AnchorTextConfig(mode="default")
    assert default_config.mode == "default"

    override_config = AnchorTextConfig(
        mode="override",
        custom_text=["anchor1", "anchor2"],
    )
    assert override_config.mode == "override"
    # Compare contents, not just the length, so dropped or reordered anchors
    # are caught; list() tolerates tuple-backed storage.
    assert list(override_config.custom_text) == ["anchor1", "anchor2"]

    append_config = AnchorTextConfig(
        mode="append",
        additional_text=["extra"],
    )
    assert append_config.mode == "append"
    # Verify the appended text is actually stored, not just the mode.
    assert list(append_config.additional_text) == ["extra"]


def test_tier_config_creation():
    """Test TierConfig stores tier/article_count and defaults validation_attempts to 3."""
    tier_config = TierConfig(
        tier=1,
        article_count=15,
        models=ModelConfig(title="model1", outline="model2", content="model3"),
    )

    assert (tier_config.tier, tier_config.article_count) == (1, 15)
    # validation_attempts was not passed, so this checks the default.
    assert tier_config.validation_attempts == 3


def test_job_config_creation():
    """Test a single-tier JobConfig exposes its fields and that tier's article total."""
    model_config = ModelConfig(title="model1", outline="model2", content="model3")
    only_tier = TierConfig(tier=1, article_count=10, models=model_config)

    job = JobConfig(job_name="Test Job", project_id=1, tiers=[only_tier])

    assert job.job_name == "Test Job"
    assert job.project_id == 1
    assert len(job.tiers) == 1
    assert job.get_total_articles() == 10


def test_job_config_multiple_tiers():
    """Test get_total_articles() counts every tier (10 + 20 = 30)."""
    shared_models = ModelConfig(title="model1", outline="model2", content="model3")
    tiers = [
        TierConfig(tier=tier_number, article_count=count, models=shared_models)
        for tier_number, count in ((1, 10), (2, 20))
    ]

    job = JobConfig(job_name="Multi-Tier Job", project_id=1, tiers=tiers)

    assert job.get_total_articles() == 30


def test_job_config_unique_tiers_validation():
    """Test JobConfig rejects a tier list in which a tier number repeats."""
    shared_models = ModelConfig(title="model1", outline="model2", content="model3")
    # Both tiers claim tier=1; only their article counts differ.
    duplicated_tiers = [
        TierConfig(tier=1, article_count=count, models=shared_models)
        for count in (10, 20)
    ]

    with pytest.raises(ValueError, match="unique"):
        JobConfig(job_name="Duplicate Tiers", project_id=1, tiers=duplicated_tiers)


def test_job_config_from_file():
    """Test loading a JobConfig from a JSON file on disk.

    The JSON is written into a TemporaryDirectory, so the file is cleaned up
    even if writing it fails part-way (the previous NamedTemporaryFile +
    try/finally only protected the load, not the write).
    """
    config_data = {
        "job_name": "Test Job",
        "project_id": 1,
        "tiers": [
            {
                "tier": 1,
                "article_count": 5,
                "models": {
                    "title": "model1",
                    "outline": "model2",
                    "content": "model3",
                },
            }
        ],
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = Path(tmp_dir) / "job_config.json"
        # Explicit encoding avoids depending on the platform default.
        config_path.write_text(json.dumps(config_data), encoding="utf-8")

        job = JobConfig.from_file(str(config_path))

    assert job.job_name == "Test Job"
    assert job.project_id == 1
    assert len(job.tiers) == 1
    # Check the tier itself was parsed, not just that one entry exists.
    assert job.tiers[0].tier == 1
    assert job.tiers[0].article_count == 5


def test_job_config_to_file():
    """Test that JobConfig.to_file writes a file that from_file reads back equal.

    The target path sits in a fresh temporary directory and does not exist
    beforehand. The existence check therefore proves to_file created it. A
    NamedTemporaryFile(delete=False) path already exists, which made that
    assertion vacuous.
    """
    models = ModelConfig(title="model1", outline="model2", content="model3")
    tier = TierConfig(tier=1, article_count=5, models=models)
    job = JobConfig(job_name="Test Job", project_id=1, tiers=[tier])

    with tempfile.TemporaryDirectory() as tmp_dir:
        config_path = Path(tmp_dir) / "job_config.json"
        assert not config_path.exists()

        job.to_file(str(config_path))
        assert config_path.exists()

        loaded_job = JobConfig.from_file(str(config_path))

    assert loaded_job.job_name == job.job_name
    assert loaded_job.project_id == job.project_id
    # Round-trip the tier data too, not only the top-level fields.
    assert len(loaded_job.tiers) == len(job.tiers)
    assert loaded_job.tiers[0].tier == tier.tier
    assert loaded_job.tiers[0].article_count == tier.article_count
    assert loaded_job.get_total_articles() == job.get_total_articles()


def test_interlinking_config_validation():
    """Test InterlinkingConfig keeps a valid min/max links-per-article pair.

    Only an accepted (min < max) input is exercised here; no rejected input
    is tested.
    """
    bounds = {"links_per_article_min": 2, "links_per_article_max": 4}
    config = InterlinkingConfig(**bounds)

    for attr_name, expected in bounds.items():
        assert getattr(config, attr_name) == expected


def test_failure_config_defaults():
    """Test the values FailureConfig() falls back to when given no arguments."""
    defaults = FailureConfig()

    assert defaults.max_consecutive_failures == 5
    assert defaults.skip_on_failure is True