diff --git a/env.example b/env.example index bae667b..f44c613 100644 --- a/env.example +++ b/env.example @@ -49,6 +49,9 @@ CLOUDFLARE_ACCOUNT_ID=your_cloudflare_account_id_here LINK_BUILDER_API_URL=http://localhost:8001/api LINK_BUILDER_API_KEY=your_link_builder_api_key_here +# fal.ai Image Generation API +FAL_API_KEY=your_fal_api_key_here + # Application Configuration LOG_LEVEL=INFO ENVIRONMENT=development diff --git a/requirements.txt b/requirements.txt index 4162c00..4798c60 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,7 +32,7 @@ beautifulsoup4==4.14.2 # AI/ML openai==2.5.0 - +fal-client==0.9.1 # Testing pytest==8.4.2 pytest-asyncio==0.21.1 diff --git a/scripts/migrate_add_image_fields.py b/scripts/migrate_add_image_fields.py new file mode 100644 index 0000000..39b9839 --- /dev/null +++ b/scripts/migrate_add_image_fields.py @@ -0,0 +1,101 @@ +""" +Migration script to add image fields to projects and generated_content tables +Story 7.1: Generate and Insert Images into Articles +""" + +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.database.session import db_manager +from sqlalchemy import text + + +def migrate(): + """Add image fields to projects and generated_content tables""" + + session = db_manager.get_session() + + try: + print("Starting migration: Add image fields...") + + print(" Adding image_theme_prompt to projects table...") + session.execute(text(""" + ALTER TABLE projects + ADD COLUMN image_theme_prompt TEXT NULL + """)) + + print(" Adding hero_image_url to generated_content table...") + session.execute(text(""" + ALTER TABLE generated_content + ADD COLUMN hero_image_url TEXT NULL + """)) + + print(" Adding content_images to generated_content table...") + session.execute(text(""" + ALTER TABLE generated_content + ADD COLUMN content_images JSON NULL + """)) + + session.commit() + + print("Migration completed successfully!") + print("\nNew fields added:") + print(" - projects.image_theme_prompt (TEXT, nullable)") + print(" - generated_content.hero_image_url (TEXT, nullable)") + print(" - generated_content.content_images (JSON, nullable)") + + except Exception as e: + session.rollback() + print(f"Migration failed: {e}") + raise + + finally: + session.close() + + +def rollback(): + """Rollback migration (remove image fields)""" + + session = db_manager.get_session() + + try: + print("Rolling back migration: Remove image fields...") + + print(" Removing content_images column...") + session.execute(text(""" + ALTER TABLE generated_content + DROP COLUMN content_images + """)) + + print(" Removing hero_image_url column...") + session.execute(text(""" + ALTER TABLE generated_content + DROP COLUMN hero_image_url + """)) + + print(" Removing image_theme_prompt column...") + session.execute(text(""" + ALTER TABLE projects + DROP COLUMN image_theme_prompt + """)) + + session.commit() + + print("Rollback completed successfully!") + + except Exception as e: + session.rollback() + print(f"Rollback failed: {e}") + raise + + finally: + session.close() + + +if __name__ == "__main__": + if len(sys.argv) > 1 and sys.argv[1] == "rollback": + rollback() + else: + migrate() + diff --git a/scripts/test_image_generation.py b/scripts/test_image_generation.py new file mode 100644 index 0000000..d427524 --- /dev/null +++ b/scripts/test_image_generation.py @@ -0,0 +1,288 @@ +""" +Test script to generate images for existing articles +Tests image generation on project 23: first 2 T1 articles and first 3 T2 articles +""" + +import sys +from pathlib import Path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.database.session import db_manager +from src.database.repositories import ( + ProjectRepository, + GeneratedContentRepository, + SiteDeploymentRepository +) +from src.generation.service import ContentGenerator +from src.generation.ai_client import AIClient, PromptManager +from src.generation.image_generator import ImageGenerator, truncate_title, slugify +from src.generation.image_injection import insert_hero_after_h1, insert_content_images_after_h2s, generate_alt_text +from src.generation.image_upload import upload_image_to_storage +from src.deployment.bunny_storage import BunnyStorageClient +from src.core.config import get_config +import click +import random +from pathlib import Path + + +def test_image_generation(project_id: int): + """Test image generation on existing articles""" + + # Create output directory for test images + output_dir = Path("test_images") + output_dir.mkdir(exist_ok=True) + click.echo(f"Test images will be saved to: {output_dir.absolute()}\n") + + session = db_manager.get_session() + + try: + # Get repositories + project_repo = ProjectRepository(session) + content_repo = GeneratedContentRepository(session) + site_repo = SiteDeploymentRepository(session) + + # Get project + project = project_repo.get_by_id(project_id) + if not project: + click.echo(f"Project {project_id} not found") + return + + click.echo(f"\n{'='*60}") + click.echo(f"Testing Image Generation for Project {project_id}") + click.echo(f"Project: {project.name}") + click.echo(f"Main Keyword: {project.main_keyword}") + click.echo(f"{'='*60}\n") + + # Get articles + t1_articles = content_repo.get_by_project_and_tier(project_id, "tier1", require_site=False) + t2_articles = content_repo.get_by_project_and_tier(project_id, "tier2", require_site=False) + + click.echo(f"Found {len(t1_articles)} T1 articles, using first 2") + click.echo(f"Found {len(t2_articles)} T2 articles, using first 3\n") + + # Initialize AI client and image generator + import os + from dotenv import load_dotenv + load_dotenv() + + api_key = os.getenv("OPENROUTER_API_KEY") + if not api_key: + click.echo("Error: OPENROUTER_API_KEY not set in environment", err=True) + return + + fal_api_key = os.getenv("FAL_API_KEY") + if not fal_api_key: + click.echo("\n[WARN] FAL_API_KEY not set - image generation will fail") + click.echo(" Set FAL_API_KEY in your .env file to test image generation\n") + + ai_client = AIClient( + api_key=api_key, + model=os.getenv("AI_MODEL", "gpt-4o-mini") + ) + prompt_manager = PromptManager() + + image_generator = ImageGenerator( + ai_client=ai_client, + prompt_manager=prompt_manager, + project_repo=project_repo + ) + + storage_client = BunnyStorageClient() + + # Test T1 articles (first 2) + click.echo(f"\n{'='*60}") + click.echo("T1 ARTICLES") + click.echo(f"{'='*60}\n") + + for i, article in enumerate(t1_articles[:2], 1): + click.echo(f"\n--- T1 Article {i}: {article.title[:60]}... ---") + + if not article.site_deployment_id: + click.echo(" [WARN] No site assigned, skipping image upload") + site = None + else: + site = site_repo.get_by_id(article.site_deployment_id) + if not site: + click.echo(" [WARN] Site not found, skipping image upload") + site = None + + # Generate theme prompt (if not exists) + click.echo("\n1. Theme Prompt:") + if project.image_theme_prompt: + click.echo(f" (Using cached): {project.image_theme_prompt}") + else: + click.echo(" Generating theme prompt...") + theme = image_generator.get_theme_prompt(project_id) + click.echo(f" Generated: {theme}") + + # Generate hero image + click.echo("\n2. Hero Image:") + try: + # Show the prompt that will be used + theme = image_generator.get_theme_prompt(project_id) + title_short = truncate_title(article.title, 4) + hero_prompt = f"{theme} Text: '{title_short}' in clean simple uppercase letters, positioned in middle of image." + click.echo(f" Prompt: {hero_prompt}") + + hero_image = image_generator.generate_hero_image( + project_id=project_id, + title=article.title, + width=1280, + height=720 + ) + + if hero_image: + click.echo(f" [OK] Generated ({len(hero_image):,} bytes)") + + # Save to local file + main_keyword_slug = slugify(project.main_keyword) + local_file = output_dir / f"hero-t1-{main_keyword_slug}-{i}.jpg" + local_file.write_bytes(hero_image) + click.echo(f" [OK] Saved to: {local_file}") + + if site: + file_path = f"images/{main_keyword_slug}.jpg" + hero_url = upload_image_to_storage(storage_client, site, hero_image, file_path) + if hero_url: + click.echo(f" [OK] Uploaded: {hero_url}") + else: + click.echo(" [FAIL] Upload failed") + else: + click.echo(" (Skipped upload - no site)") + else: + click.echo(" [FAIL] Generation failed") + except Exception as e: + click.echo(f" [ERROR] {str(e)[:200]}") + + # Generate content images (1-3 for T1) + click.echo("\n3. Content Images:") + num_content_images = random.randint(1, 3) + click.echo(f" Generating {num_content_images} content image(s)...") + + entities = project.entities or [] + related_searches = project.related_searches or [] + + if not entities or not related_searches: + click.echo(" [WARN] No entities/related_searches, skipping") + else: + for j in range(num_content_images): + entity = random.choice(entities) + related_search = random.choice(related_searches) + + click.echo(f"\n Image {j+1}/{num_content_images}:") + click.echo(f" Entity: {entity}") + click.echo(f" Related Search: {related_search}") + + try: + # Show the prompt that will be used + theme = image_generator.get_theme_prompt(project_id) + content_prompt = f"{theme} Focus on {entity} and {related_search}, professional illustration style." + click.echo(f" Prompt: {content_prompt}") + + content_image = image_generator.generate_content_image( + project_id=project_id, + entity=entity, + related_search=related_search, + width=512, + height=512 + ) + + if content_image: + click.echo(f" [OK] Generated ({len(content_image):,} bytes)") + + # Save to local file + main_keyword_slug = slugify(project.main_keyword) + entity_slug = slugify(entity) + related_slug = slugify(related_search) + local_file = output_dir / f"content-{main_keyword_slug}-{i}-{j+1}-{entity_slug}-{related_slug}.jpg" + local_file.write_bytes(content_image) + click.echo(f" [OK] Saved to: {local_file}") + + if site: + file_path = f"images/{main_keyword_slug}-{entity_slug}-{related_slug}.jpg" + img_url = upload_image_to_storage(storage_client, site, content_image, file_path) + if img_url: + click.echo(f" [OK] Uploaded: {img_url}") + else: + click.echo(" [FAIL] Upload failed") + else: + click.echo(" (Skipped upload - no site)") + else: + click.echo(" [FAIL] Generation failed") + except Exception as e: + click.echo(f" [ERROR] {str(e)[:200]}") + + # Test T2 articles (first 3) + click.echo(f"\n\n{'='*60}") + click.echo("T2 ARTICLES") + click.echo(f"{'='*60}\n") + + for i, article in enumerate(t2_articles[:3], 1): + click.echo(f"\n--- T2 Article {i}: {article.title[:60]}... ---") + + if not article.site_deployment_id: + click.echo(" [WARN] No site assigned, skipping image upload") + site = None + else: + site = site_repo.get_by_id(article.site_deployment_id) + if not site: + click.echo(" [WARN] Site not found, skipping image upload") + site = None + + # Generate hero image only (T2 doesn't get content images by default) + click.echo("\n1. Hero Image:") + try: + # Show the prompt that will be used + theme = image_generator.get_theme_prompt(project_id) + title_short = truncate_title(article.title, 4) + hero_prompt = f"{theme} Text: '{title_short}' in clean simple uppercase letters, positioned in middle of image." + click.echo(f" Prompt: {hero_prompt}") + + hero_image = image_generator.generate_hero_image( + project_id=project_id, + title=article.title, + width=1280, + height=720 + ) + + if hero_image: + click.echo(f" [OK] Generated ({len(hero_image):,} bytes)") + + # Save to local file + main_keyword_slug = slugify(project.main_keyword) + local_file = output_dir / f"hero-t2-{main_keyword_slug}-{i}.jpg" + local_file.write_bytes(hero_image) + click.echo(f" [OK] Saved to: {local_file}") + + if site: + file_path = f"images/{main_keyword_slug}.jpg" + hero_url = upload_image_to_storage(storage_client, site, hero_image, file_path) + if hero_url: + click.echo(f" [OK] Uploaded: {hero_url}") + else: + click.echo(" [FAIL] Upload failed") + else: + click.echo(" (Skipped upload - no site)") + else: + click.echo(" [FAIL] Generation failed") + except Exception as e: + click.echo(f" [ERROR] {str(e)[:200]}") + + click.echo("\n2. Content Images:") + click.echo(" (Skipped - T2 articles don't get content images by default)") + + click.echo(f"\n\n{'='*60}") + click.echo("TEST COMPLETE") + click.echo(f"{'='*60}\n") + + except Exception as e: + click.echo(f"Error: {e}", err=True) + import traceback + traceback.print_exc() + finally: + session.close() + + +if __name__ == "__main__": + test_image_generation(23) + diff --git a/src/database/models.py b/src/database/models.py index 0f6f317..69cefd2 100644 --- a/src/database/models.py +++ b/src/database/models.py @@ -109,6 +109,7 @@ class Project(Base): custom_anchor_text: Mapped[Optional[list]] = mapped_column(JSON, nullable=True) spintax_related_search_terms: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + image_theme_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( @@ -140,6 +141,8 @@ class GeneratedContent(Base): site_deployment_id: Mapped[Optional[int]] = mapped_column(Integer, ForeignKey('site_deployments.id'), nullable=True, index=True) deployed_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) deployed_at: Mapped[Optional[datetime]] = mapped_column(DateTime, nullable=True, index=True) + hero_image_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + content_images: Mapped[Optional[list]] = mapped_column(JSON, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, nullable=False) updated_at: Mapped[datetime] = mapped_column( DateTime, diff --git a/src/database/repositories.py b/src/database/repositories.py index 13e3de2..f4fb9cc 100644 --- a/src/database/repositories.py +++ b/src/database/repositories.py @@ -411,7 +411,9 @@ class GeneratedContentRepository: content: str, word_count: int, status: str, - site_deployment_id: Optional[int] = None + site_deployment_id: Optional[int] = None, + hero_image_url: Optional[str] = None, + content_images: Optional[list] = None ) -> GeneratedContent: """ Create a new generated content record @@ -439,7 +441,9 @@ class GeneratedContentRepository: content=content, word_count=word_count, status=status, - site_deployment_id=site_deployment_id + site_deployment_id=site_deployment_id, + hero_image_url=hero_image_url, + content_images=content_images ) self.session.add(content_record) diff --git a/src/generation/batch_processor.py b/src/generation/batch_processor.py index 7690eac..52bf97e 100644 --- a/src/generation/batch_processor.py +++ b/src/generation/batch_processor.py @@ -21,6 +21,11 @@ from src.generation.site_assignment import assign_sites_to_batch from src.deployment.bunny_storage import BunnyStorageClient from src.deployment.deployment_service import DeploymentService from src.deployment.url_logger import URLLogger +from src.generation.image_generator import ImageGenerator +from src.generation.image_injection import insert_hero_after_h1, insert_content_images_after_h2s, generate_alt_text +from src.generation.image_upload import upload_image_to_storage +from src.generation.image_generator import slugify +import random class BatchProcessor: @@ -352,6 +357,17 @@ class BatchProcessor: status = "augmented" self.stats["augmented_articles"] += 1 + # Generate and insert images + content, hero_url, content_image_urls = self._generate_and_insert_images( + project_id=project_id, + tier_name=tier_name, + tier_config=tier_config, + title=title, + content=content, + site_deployment_id=site_deployment_id, + prefix=prefix + ) + saved_content = self.content_repo.create( project_id=project_id, tier=tier_name, @@ -361,11 +377,128 @@ class BatchProcessor: content=content, word_count=word_count, status=status, - site_deployment_id=site_deployment_id + site_deployment_id=site_deployment_id, + hero_image_url=hero_url, + content_images=content_image_urls if content_image_urls else None ) click.echo(f"{prefix} Saved (ID: {saved_content.id}, Status: {status})") + def _generate_and_insert_images( + self, + project_id: int, + tier_name: str, + tier_config: TierConfig, + title: str, + content: str, + site_deployment_id: Optional[int], + prefix: str + ) -> tuple[str, Optional[str], List[str]]: + """ + Generate images and insert into HTML content + + Note: image_config is always created by job config parser (with defaults if not in JSON). + Defaults: hero images for all tiers (1280x720), content images for T1 only (1-3 images). + """ + if not tier_config.image_config: + return content, None, [] + + project = self.project_repo.get_by_id(project_id) + if not project: + return content, None, [] + + # Initialize image generator + image_generator = ImageGenerator( + ai_client=self.generator.ai_client, + prompt_manager=self.generator.prompt_manager, + project_repo=self.project_repo + ) + + storage_client = BunnyStorageClient() + hero_url = None + content_image_urls = [] + + # Generate hero image (all tiers if enabled) + if tier_config.image_config.hero: + try: + click.echo(f"{prefix} Generating hero image...") + hero_image = image_generator.generate_hero_image( + project_id=project_id, + title=title, + width=tier_config.image_config.hero.width, + height=tier_config.image_config.hero.height + ) + + if hero_image and site_deployment_id: + site = self.site_deployment_repo.get_by_id(site_deployment_id) if self.site_deployment_repo else None + if site: + main_keyword_slug = slugify(project.main_keyword) + file_path = f"images/{main_keyword_slug}.jpg" + hero_url = upload_image_to_storage(storage_client, site, hero_image, file_path) + if hero_url: + click.echo(f"{prefix} Hero image uploaded: {hero_url}") + else: + click.echo(f"{prefix} Hero image upload failed") + except Exception as e: + click.echo(f"{prefix} Hero image generation failed: {e}") + + # Generate content images (T1 only, if enabled) + if tier_config.image_config.content and tier_config.image_config.content.max_num_images > 0: + try: + num_images = random.randint( + tier_config.image_config.content.min_num_images, + tier_config.image_config.content.max_num_images + ) + + if num_images > 0: + click.echo(f"{prefix} Generating {num_images} content image(s)...") + + entities = project.entities or [] + related_searches = project.related_searches or [] + + if not entities or not related_searches: + click.echo(f"{prefix} Skipping content images (no entities/related_searches)") + else: + for i in range(num_images): + try: + entity = random.choice(entities) + related_search = random.choice(related_searches) + + content_image = image_generator.generate_content_image( + project_id=project_id, + entity=entity, + related_search=related_search, + width=tier_config.image_config.content.width, + height=tier_config.image_config.content.height + ) + + if content_image and site_deployment_id: + site = self.site_deployment_repo.get_by_id(site_deployment_id) if self.site_deployment_repo else None + if site: + main_keyword_slug = slugify(project.main_keyword) + entity_slug = slugify(entity) + related_slug = slugify(related_search) + file_path = f"images/{main_keyword_slug}-{entity_slug}-{related_slug}.jpg" + img_url = upload_image_to_storage(storage_client, site, content_image, file_path) + if img_url: + content_image_urls.append(img_url) + click.echo(f"{prefix} Content image {i+1}/{num_images} uploaded") + except Exception as e: + click.echo(f"{prefix} Content image {i+1} generation failed: {e}") + except Exception as e: + click.echo(f"{prefix} Content image generation failed: {e}") + + # Insert images into HTML + if hero_url: + alt_text = generate_alt_text(project) + content = insert_hero_after_h1(content, hero_url, alt_text) + + if content_image_urls: + alt_texts = [generate_alt_text(project) for _ in content_image_urls] + content = insert_content_images_after_h2s(content, content_image_urls, alt_texts) + + return content, hero_url, content_image_urls + def _process_articles_concurrent( self, article_tasks: List[Dict[str, Any]], @@ -547,6 +680,17 @@ class BatchProcessor: with self.stats_lock: self.stats["augmented_articles"] += 1 + # Generate and insert images + content, hero_url, content_image_urls = self._generate_and_insert_images( + project_id=project_id, + tier_name=tier_name, + tier_config=tier_config, + title=title, + content=content, + site_deployment_id=site_deployment_id, + prefix=prefix + ) + saved_content = thread_content_repo.create( project_id=project_id, tier=tier_name, @@ -556,7 +700,9 @@ class BatchProcessor: content=content, word_count=word_count, status=status, - site_deployment_id=site_deployment_id + site_deployment_id=site_deployment_id, + hero_image_url=hero_url, + content_images=content_image_urls if content_image_urls else None ) thread_session.commit() diff --git a/src/generation/image_generator.py b/src/generation/image_generator.py new file mode 100644 index 0000000..6fb913f --- /dev/null +++ b/src/generation/image_generator.py @@ -0,0 +1,222 @@ +""" +Image generation using fal.ai FLUX.1 schnell API +""" + +import os +import re +import random +import logging +import requests +from typing import Optional, Tuple +from concurrent.futures import ThreadPoolExecutor, as_completed +import fal_client +from src.generation.ai_client import AIClient, PromptManager +from src.database.repositories import ProjectRepository + +logger = logging.getLogger(__name__) + + +def truncate_title(title: str, max_words: int = 4) -> str: + """Truncate title to max_words and convert to UPPERCASE""" + words = title.split()[:max_words] + return " ".join(words).upper() + + +def slugify(text: str) -> str: + """Convert text to URL-friendly slug""" + text = text.lower() + text = re.sub(r'[^a-z0-9]+', '-', text) + text = text.strip('-') + return text + + +class ImageGenerator: + """Generate images using fal.ai API""" + + def __init__( + self, + ai_client: AIClient, + prompt_manager: PromptManager, + project_repo: ProjectRepository + ): + self.ai_client = ai_client + self.prompt_manager = prompt_manager + self.project_repo = project_repo + # fal_client library expects FAL_KEY, but we use FAL_API_KEY in our env + # Set both for compatibility + self.fal_key = os.getenv("FAL_API_KEY") or os.getenv("FAL_KEY") + if self.fal_key and not os.getenv("FAL_KEY"): + os.environ["FAL_KEY"] = self.fal_key + if not self.fal_key: + logger.warning("FAL_API_KEY not set, image generation will fail") + self.max_concurrent = 5 + self.executor = ThreadPoolExecutor(max_workers=self.max_concurrent) + + def get_theme_prompt(self, project_id: int) -> str: + """Get or generate theme prompt for project""" + project = self.project_repo.get_by_id(project_id) + if not project: + raise ValueError(f"Project {project_id} not found") + + if project.image_theme_prompt: + return project.image_theme_prompt + + # Generate theme prompt using AI + entities_str = ", ".join(project.entities or []) + related_str = ", ".join(project.related_searches or []) + + system_msg, user_prompt = self.prompt_manager.format_prompt( + "image_theme_generation", + main_keyword=project.main_keyword, + entities=entities_str, + related_searches=related_str + ) + + theme_prompt, _ = self.ai_client.generate_completion( + prompt=user_prompt, + system_message=system_msg, + max_tokens=200, + temperature=0.7 + ) + + # Save to project + project.image_theme_prompt = theme_prompt.strip() + self.project_repo.session.commit() + + logger.info(f"Generated theme prompt for project {project_id}") + return project.image_theme_prompt + + def generate_hero_image( + self, + project_id: int, + title: str, + width: int = 1280, + height: int = 720 + ) -> Optional[bytes]: + """Generate hero image with title text""" + if not self.fal_key: + logger.error("FAL_API_KEY not set") + return None + + try: + theme = self.get_theme_prompt(project_id) + title_short = truncate_title(title, 4) + prompt = f"{theme} Text: '{title_short}' in clean simple uppercase letters, positioned in middle of image." + + logger.info(f"Generating hero image with prompt: {prompt}") + + result = fal_client.subscribe( + "fal-ai/flux-1/schnell", + arguments={ + "prompt": prompt, + "image_size": {"width": width, "height": height}, + "num_inference_steps": 4, + "guidance_scale": 3.5, + "output_format": "jpeg" + }, + with_logs=True + ) + + logger.debug(f"API response keys: {result.keys() if result else 'None'}") + logger.debug(f"API response type: {type(result)}") + + # Check different possible response structures + images = None + if result: + if "images" in result: + images = result["images"] + elif "data" in result and "images" in result["data"]: + images = result["data"]["images"] + elif isinstance(result, dict) and len(result) == 1 and "images" in list(result.values())[0]: + images = list(result.values())[0]["images"] + + if images and len(images) > 0: + image_data = images[0] + image_url = image_data.get("url") + + if not image_url: + logger.error(f"No URL in image response. Image data keys: {image_data.keys() if isinstance(image_data, dict) else 'not a dict'}") + return None + + logger.info(f"Downloading hero image from URL: {image_url}") + response = requests.get(image_url, timeout=30) + response.raise_for_status() + return response.content + + logger.error(f"No image returned from fal.ai. Response: {result}") + return None + + except Exception as e: + logger.error(f"Failed to generate hero image: {e}", exc_info=True) + logger.error(f"Exception type: {type(e).__name__}") + if hasattr(e, 'response'): + logger.error(f"Response: {e.response}") + return None + + def generate_content_image( + self, + project_id: int, + entity: str, + related_search: str, + width: int = 512, + height: int = 512 + ) -> Optional[bytes]: + """Generate content image with entity and related search""" + if not self.fal_key: + logger.error("FAL_API_KEY not set") + return None + + try: + theme = self.get_theme_prompt(project_id) + prompt = f"{theme} Focus on {entity} and {related_search}, professional illustration style." + + logger.info(f"Generating content image with prompt: {prompt}") + + result = fal_client.subscribe( + "fal-ai/flux-1/schnell", + arguments={ + "prompt": prompt, + "image_size": {"width": width, "height": height}, + "num_inference_steps": 4, + "guidance_scale": 3.5, + "output_format": "jpeg" + }, + with_logs=True + ) + + logger.debug(f"API response keys: {result.keys() if result else 'None'}") + logger.debug(f"API response type: {type(result)}") + + # Check different possible response structures + images = None + if result: + if "images" in result: + images = result["images"] + elif "data" in result and "images" in result["data"]: + images = result["data"]["images"] + elif isinstance(result, dict) and len(result) == 1 and "images" in list(result.values())[0]: + images = list(result.values())[0]["images"] + + if images and len(images) > 0: + image_data = images[0] + image_url = image_data.get("url") + + if not image_url: + logger.error(f"No URL in image response. Image data keys: {image_data.keys() if isinstance(image_data, dict) else 'not a dict'}") + return None + + logger.info(f"Downloading content image from URL: {image_url}") + response = requests.get(image_url, timeout=30) + response.raise_for_status() + return response.content + + logger.error(f"No image returned from fal.ai. Response: {result}") + return None + + except Exception as e: + logger.error(f"Failed to generate content image: {e}", exc_info=True) + logger.error(f"Exception type: {type(e).__name__}") + if hasattr(e, 'response'): + logger.error(f"Response: {e.response}") + return None + diff --git a/src/generation/image_injection.py b/src/generation/image_injection.py new file mode 100644 index 0000000..a225b3b --- /dev/null +++ b/src/generation/image_injection.py @@ -0,0 +1,94 @@ +""" +HTML image insertion logic +""" + +import re +import random +from typing import List, Optional +from src.database.models import Project + + +def generate_alt_text(project: Project) -> str: + """Generate alt text with 3 entities and 2 related searches""" + entities = project.entities or [] + related_searches = project.related_searches or [] + + # Pick 3 random entities (or all if less than 3) + selected_entities = random.sample(entities, min(3, len(entities))) if entities else [] + # Pick 2 random related searches (or all if less than 2) + selected_related = random.sample(related_searches, min(2, len(related_searches))) if related_searches else [] + + # Combine: entity1 related_search1 entity2 related_search2 entity3 + parts = [] + # Add entities and related searches in order: entity1, related1, entity2, related2, entity3 + for i in range(max(len(selected_entities), len(selected_related))): + if i < len(selected_entities): + parts.append(selected_entities[i]) + if i < len(selected_related): + parts.append(selected_related[i]) + if len(parts) >= 5: + break + + return " ".join(parts[:5]) if parts else project.main_keyword + + +def insert_hero_after_h1(html: str, hero_url: str, alt_text: str) -> str: + """Insert hero image immediately after first H1 tag""" + # Find first