"""Press-release pipeline tool.

Autonomous workflow:

1. Generate 7 compliant headlines (chat brain)
2. AI judge picks the 2 best (chat brain)
3. Write 2 full press releases (execution brain × 2)
4. Generate 2 JSON-LD schemas (execution brain × 2, Sonnet + WebSearch)
5. Save 4 files, return cost summary
"""
|
||
|
||
from __future__ import annotations

import json
import logging
import re
import time
from datetime import datetime
from pathlib import Path

from . import tool

# Module-level logger; pipeline steps log warnings (word count, bad JSON) here.
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------

# Repository root: three directory levels above this module file.
_ROOT_DIR = Path(__file__).resolve().parent.parent.parent
_SKILLS_DIR = _ROOT_DIR / "skills"  # markdown "skill" prompt files live here
_DATA_DIR = _ROOT_DIR / "data"
# All generated artifacts (headline candidates, PRs, schemas) are written here.
_OUTPUT_DIR = _DATA_DIR / "generated" / "press_releases"
_COMPANIES_FILE = _SKILLS_DIR / "companies.md"    # optional executive-name directory
_HEADLINES_FILE = _SKILLS_DIR / "headlines.md"    # optional reference headlines

# Model alias passed to the execution brain for schema generation (Step 4).
SONNET_CLI_MODEL = "sonnet"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _load_skill(filename: str) -> str:
    """Load one markdown skill file from the skills/ directory.

    Raises FileNotFoundError when the file is missing — skill prompts are a
    hard requirement for the pipeline.
    """
    skill_path = _SKILLS_DIR / filename
    if skill_path.exists():
        return skill_path.read_text(encoding="utf-8")
    raise FileNotFoundError(f"Skill file not found: {skill_path}")
|
||
|
||
|
||
def _load_file_if_exists(path: Path) -> str:
|
||
"""Read a file if it exists, return empty string otherwise."""
|
||
if path.exists():
|
||
return path.read_text(encoding="utf-8")
|
||
return ""
|
||
|
||
|
||
def _slugify(text: str) -> str:
|
||
"""Turn a headline into a filesystem-safe slug."""
|
||
text = text.lower().strip()
|
||
text = re.sub(r"[^\w\s-]", "", text)
|
||
text = re.sub(r"[\s_]+", "-", text)
|
||
return text[:60].strip("-")
|
||
|
||
|
||
def _word_count(text: str) -> int:
|
||
return len(text.split())
|
||
|
||
|
||
def _chat_call(agent, messages: list[dict]) -> str:
|
||
"""Make a non-streaming chat-brain call and return the full text."""
|
||
parts: list[str] = []
|
||
for chunk in agent.llm.chat(messages, tools=None, stream=False):
|
||
if chunk["type"] == "text":
|
||
parts.append(chunk["content"])
|
||
return "".join(parts)
|
||
|
||
|
||
def _clean_pr_output(raw: str, headline: str) -> str:
|
||
"""Clean execution brain output to just the press release text.
|
||
|
||
Strategy: find the headline we asked for in the output, take everything
|
||
from that point forward. Strip any markdown formatting artifacts.
|
||
"""
|
||
# Normalize the headline for matching
|
||
headline_lower = headline.strip().lower()
|
||
|
||
lines = raw.strip().splitlines()
|
||
|
||
# Try to find the exact headline in the output
|
||
pr_start = None
|
||
for i, line in enumerate(lines):
|
||
clean_line = re.sub(r"\*\*", "", line).strip().lower()
|
||
if clean_line == headline_lower:
|
||
pr_start = i
|
||
break
|
||
|
||
# Fallback: find a line that contains most of the headline words
|
||
if pr_start is None:
|
||
headline_words = set(headline_lower.split())
|
||
for i, line in enumerate(lines):
|
||
clean_line = re.sub(r"\*\*", "", line).strip().lower()
|
||
line_words = set(clean_line.split())
|
||
# If >70% of headline words are in this line, it's probably the headline
|
||
if len(headline_words & line_words) >= len(headline_words) * 0.7:
|
||
pr_start = i
|
||
break
|
||
|
||
# If we still can't find it, just take the whole output
|
||
if pr_start is None:
|
||
pr_start = 0
|
||
|
||
# Rebuild from the headline forward
|
||
result_lines = []
|
||
for line in lines[pr_start:]:
|
||
# Strip markdown formatting
|
||
line = re.sub(r"\*\*", "", line)
|
||
line = re.sub(r"^#{1,6}\s+", "", line)
|
||
result_lines.append(line)
|
||
|
||
result = "\n".join(result_lines).strip()
|
||
|
||
# Remove trailing horizontal rules
|
||
result = re.sub(r"\n---\s*$", "", result).strip()
|
||
|
||
return result
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Prompt builders
|
||
# ---------------------------------------------------------------------------
|
||
|
||
def _build_headline_prompt(topic: str, company_name: str, url: str,
|
||
lsi_terms: str, headlines_ref: str) -> str:
|
||
"""Build the prompt for Step 1: generate 7 headlines."""
|
||
prompt = (
|
||
f"Generate exactly 7 unique press release headline options for the following.\n\n"
|
||
f"Topic: {topic}\n"
|
||
f"Company: {company_name}\n"
|
||
)
|
||
if url:
|
||
prompt += f"Reference URL: {url}\n"
|
||
if lsi_terms:
|
||
prompt += f"LSI terms to consider: {lsi_terms}\n"
|
||
|
||
prompt += (
|
||
"\nRules for EVERY headline:\n"
|
||
"- Maximum 70 characters (including spaces)\n"
|
||
"- Title case\n"
|
||
"- News-focused, not promotional\n"
|
||
"- NO location/geographic keywords\n"
|
||
"- NO superlatives (best, top, leading, #1)\n"
|
||
"- NO questions\n"
|
||
"- NO colons — colons are considered lower quality\n"
|
||
"- Must contain an actual news announcement\n"
|
||
)
|
||
|
||
if headlines_ref:
|
||
prompt += (
|
||
"\nHere are examples of high-quality headlines to use as reference "
|
||
"for tone, structure, and length:\n\n"
|
||
f"{headlines_ref}\n"
|
||
)
|
||
|
||
prompt += (
|
||
"\nReturn ONLY a numbered list (1-7), one headline per line. "
|
||
"No commentary, no character counts, just the headlines."
|
||
)
|
||
return prompt
|
||
|
||
|
||
def _build_judge_prompt(headlines: str, headlines_ref: str) -> str:
|
||
"""Build the prompt for Step 2: pick the 2 best headlines."""
|
||
prompt = (
|
||
"You are judging press release headlines for Press Advantage distribution. "
|
||
"Pick the 2 best headlines from the candidates below.\n\n"
|
||
"DISQUALIFY any headline that:\n"
|
||
"- Contains a colon\n"
|
||
"- Contains location/geographic keywords\n"
|
||
"- Contains superlatives (best, top, leading, #1)\n"
|
||
"- Is a question\n"
|
||
"- Exceeds 70 characters\n"
|
||
"- Implies a NEW product launch when none exists (avoid 'launches', "
|
||
"'introduces', 'unveils', 'announces new' unless the topic is genuinely new)\n\n"
|
||
"PREFER headlines that:\n"
|
||
"- Match the tone and structure of the reference examples below\n"
|
||
"- Use action verbs like 'Highlights', 'Expands', 'Strengthens', "
|
||
"'Reinforces', 'Delivers', 'Adds'\n"
|
||
"- Describe what the company DOES or OFFERS, not what it just invented\n"
|
||
"- Read like a real news wire headline, not a product announcement\n\n"
|
||
f"Candidates:\n{headlines}\n\n"
|
||
)
|
||
|
||
if headlines_ref:
|
||
prompt += (
|
||
"Reference headlines (these scored 77+ on quality — match their style):\n"
|
||
f"{headlines_ref}\n\n"
|
||
)
|
||
|
||
prompt += (
|
||
"Return ONLY the 2 best headlines, one per line, exactly as written in the candidates. "
|
||
"No numbering, no commentary."
|
||
)
|
||
return prompt
|
||
|
||
|
||
def _build_pr_prompt(headline: str, topic: str, company_name: str,
|
||
url: str, lsi_terms: str, required_phrase: str,
|
||
skill_text: str, companies_file: str) -> str:
|
||
"""Build the prompt for Step 3: write one full press release."""
|
||
prompt = (
|
||
f"{skill_text}\n\n"
|
||
"---\n\n"
|
||
f"Write a press release using the headline below. "
|
||
f"Follow every rule in the skill instructions above.\n\n"
|
||
f"Headline: {headline}\n"
|
||
f"Topic: {topic}\n"
|
||
f"Company: {company_name}\n"
|
||
)
|
||
if url:
|
||
prompt += f"Reference URL (fetch for context): {url}\n"
|
||
if lsi_terms:
|
||
prompt += f"LSI terms to integrate: {lsi_terms}\n"
|
||
if required_phrase:
|
||
prompt += f'Required phrase (use exactly once): "{required_phrase}"\n'
|
||
|
||
if companies_file:
|
||
prompt += (
|
||
f"\nCompany directory — look up the executive name and title for {company_name}. "
|
||
f"If the company is NOT listed below, use 'a company spokesperson' for quotes "
|
||
f"instead of making up a name:\n"
|
||
f"{companies_file}\n"
|
||
)
|
||
|
||
prompt += (
|
||
"\nTarget 600-750 words. Minimum 575, maximum 800.\n\n"
|
||
"CRITICAL OUTPUT RULES:\n"
|
||
"- Output ONLY the press release text\n"
|
||
"- Start with the headline on the first line, then the body\n"
|
||
"- Do NOT include any commentary, reasoning, notes, or explanations\n"
|
||
"- Do NOT use markdown formatting (no **, no ##, no ---)\n"
|
||
"- Do NOT prefix with 'Here is the press release' or similar\n"
|
||
"- The very first line of your output must be the headline"
|
||
)
|
||
return prompt
|
||
|
||
|
||
def _build_schema_prompt(pr_text: str, company_name: str, url: str,
|
||
skill_text: str) -> str:
|
||
"""Build the prompt for Step 4: generate JSON-LD schema for one PR."""
|
||
prompt = (
|
||
f"{skill_text}\n\n"
|
||
"---\n\n"
|
||
"Generate a NewsArticle JSON-LD schema for the press release below. "
|
||
"Follow every rule in the skill instructions above. "
|
||
"Use WebSearch to find Wikipedia URLs for each entity.\n\n"
|
||
"CRITICAL OUTPUT RULES:\n"
|
||
"- Output ONLY valid JSON\n"
|
||
"- No markdown fences, no commentary, no explanations\n"
|
||
"- The very first character of your output must be {\n"
|
||
)
|
||
prompt += (
|
||
f"\nCompany name: {company_name}\n\n"
|
||
f"Press release text:\n{pr_text}"
|
||
)
|
||
return prompt
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Main tool
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@tool(
    "write_press_releases",
    description=(
        "Full autonomous press-release pipeline. Generates 7 headlines, "
        "AI-picks the best 2, writes 2 complete press releases (600-750 words each), "
        "generates JSON-LD schema for each, and saves all files. "
        "Returns both press releases, both schemas, file paths, and a cost summary. "
        "Use when the user asks to write, create, or draft a press release."
    ),
    category="content",
)
def write_press_releases(
    topic: str,
    company_name: str,
    url: str = "",
    lsi_terms: str = "",
    required_phrase: str = "",
    ctx: dict | None = None,
) -> str:
    """Run the full press-release pipeline and return results + cost summary.

    Pipeline: (1) chat brain drafts 7 headlines, (2) chat brain judges and
    keeps the best 2, (3) execution brain writes a full press release per
    winner, (4) Sonnet + WebSearch generates a JSON-LD schema per PR.
    All artifacts are written under ``_OUTPUT_DIR``.

    Args:
        topic: What the press release announces.
        company_name: Company the release is about.
        url: Optional reference URL; also forced into each schema's
            ``mainEntityOfPage`` after parsing.
        lsi_terms: Optional LSI keywords to weave in.
        required_phrase: Optional phrase each PR must use exactly once.
        ctx: Tool context dict; must contain the running ``agent``.

    Returns:
        Markdown report with both PRs, both schemas, file paths and a
        per-step timing table — or an ``"Error: ..."`` string on failure.
    """
    if not ctx or "agent" not in ctx:
        return "Error: press release tool requires agent context."

    agent = ctx["agent"]

    def _strip_list_marker(line: str) -> str:
        # Remove only a leading list marker ("1. ", "2) ", "- ").  The old
        # code used line.lstrip("0123456789.-) "), which strips a character
        # *set* and therefore mangled headlines that legitimately start with
        # digits (e.g. "2024 Revenue Grows ..." became "Revenue Grows ...").
        return re.sub(r"^\s*(?:\d+[.)\-]+\s*|[-*]\s+)", "", line).strip()

    # Load skill prompts — a hard requirement; abort with a readable error.
    try:
        pr_skill = _load_skill("press_release_prompt.md")
        schema_skill = _load_skill("press-release-schema.md")
    except FileNotFoundError as e:
        return f"Error: {e}"

    # Reference files are optional; an empty string disables the related
    # prompt sections downstream.
    companies_file = _load_file_if_exists(_COMPANIES_FILE)
    headlines_ref = _load_file_if_exists(_HEADLINES_FILE)

    # Ensure output directory
    _OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    today = datetime.now().strftime("%Y-%m-%d")

    cost_log: list[dict] = []  # one entry per model call, for the summary table

    # ── Step 1: Generate 7 headlines (chat brain) ─────────────────────────
    step_start = time.time()
    headline_prompt = _build_headline_prompt(topic, company_name, url, lsi_terms, headlines_ref)
    messages = [
        {"role": "system", "content": "You are a senior press-release headline writer."},
        {"role": "user", "content": headline_prompt},
    ]
    headlines_raw = _chat_call(agent, messages)
    cost_log.append({
        "step": "1. Generate 7 headlines",
        "model": agent.llm.current_model,
        "elapsed_s": round(time.time() - step_start, 1),
    })

    if not headlines_raw.strip():
        return "Error: headline generation returned empty result."

    # Save all 7 headline candidates to file
    slug_base = _slugify(f"{company_name}-{topic}")
    headlines_file = _OUTPUT_DIR / f"{slug_base}_{today}_headlines.txt"
    headlines_file.write_text(headlines_raw.strip(), encoding="utf-8")

    # ── Step 2: AI judge picks best 2 (chat brain) ───────────────────────
    step_start = time.time()
    judge_prompt = _build_judge_prompt(headlines_raw, headlines_ref)
    messages = [
        {"role": "system", "content": "You are a senior PR editor."},
        {"role": "user", "content": judge_prompt},
    ]
    judge_result = _chat_call(agent, messages)
    cost_log.append({
        "step": "2. Judge picks best 2",
        "model": agent.llm.current_model,
        "elapsed_s": round(time.time() - step_start, 1),
    })

    # Parse the two winning headlines.
    winners = [_strip_list_marker(line) for line in judge_result.strip().splitlines() if line.strip()]
    if len(winners) < 2:
        # Judge output unusable — fall back to the first two raw candidates
        # (duplicating a lone candidate, or placeholders as a last resort).
        all_headlines = [_strip_list_marker(line) for line in headlines_raw.strip().splitlines() if line.strip()]
        winners = all_headlines[:2] if len(all_headlines) >= 2 else [all_headlines[0], all_headlines[0]] if all_headlines else ["Headline A", "Headline B"]
    winners = winners[:2]

    # ── Step 3: Write 2 press releases (execution brain × 2) ─────────────
    pr_texts: list[str] = []
    pr_files: list[str] = []
    for i, headline in enumerate(winners):
        step_start = time.time()
        pr_prompt = _build_pr_prompt(
            headline, topic, company_name, url, lsi_terms,
            required_phrase, pr_skill, companies_file,
        )
        exec_tools = "Bash,Read,Edit,Write,Glob,Grep,WebFetch"
        raw_result = agent.execute_task(pr_prompt, tools=exec_tools)
        elapsed = round(time.time() - step_start, 1)
        cost_log.append({
            "step": f"3{chr(97+i)}. Write PR '{headline[:40]}...'",
            "model": "execution-brain (default)",
            "elapsed_s": elapsed,
        })

        # Clean output: find the headline, strip preamble and markdown
        clean_result = _clean_pr_output(raw_result, headline)
        pr_texts.append(clean_result)

        # Validate word count (warn only — the PR is still saved/returned).
        wc = _word_count(clean_result)
        if wc < 575 or wc > 800:
            log.warning("PR %d word count %d outside 575-800 range", i + 1, wc)

        # Save PR to file
        slug = _slugify(headline)
        filename = f"{slug}_{today}.txt"
        filepath = _OUTPUT_DIR / filename
        filepath.write_text(clean_result, encoding="utf-8")
        pr_files.append(str(filepath))

    # ── Step 4: Generate 2 JSON-LD schemas (Sonnet + WebSearch) ───────────
    schema_texts: list[str] = []
    schema_files: list[str] = []
    for i, pr_text in enumerate(pr_texts):
        step_start = time.time()
        schema_prompt = _build_schema_prompt(pr_text, company_name, url, schema_skill)
        exec_tools = "WebSearch,WebFetch"
        result = agent.execute_task(
            schema_prompt,
            tools=exec_tools,
            model=SONNET_CLI_MODEL,
        )
        elapsed = round(time.time() - step_start, 1)
        cost_log.append({
            "step": f"4{chr(97+i)}. Schema for PR {i+1}",
            "model": SONNET_CLI_MODEL,
            "elapsed_s": elapsed,
        })

        # Extract clean JSON and force correct mainEntityOfPage
        schema_json = _extract_json(result)
        if schema_json:
            try:
                schema_obj = json.loads(schema_json)
                if url:
                    schema_obj["mainEntityOfPage"] = url
                schema_json = json.dumps(schema_obj, indent=2)
            except json.JSONDecodeError:
                log.warning("Schema %d is not valid JSON", i + 1)
        # Fall back to the raw model output when no valid JSON was found.
        schema_texts.append(schema_json or result)

        # Save schema to file
        slug = _slugify(winners[i])
        filename = f"{slug}_{today}_schema.json"
        filepath = _OUTPUT_DIR / filename
        filepath.write_text(schema_json or result, encoding="utf-8")
        schema_files.append(str(filepath))

    # ── Build final output ────────────────────────────────────────────────
    total_elapsed = sum(c["elapsed_s"] for c in cost_log)
    output_parts = []

    for i in range(2):
        label = chr(65 + i)  # A, B
        wc = _word_count(pr_texts[i])
        output_parts.append(f"## Press Release {label}: {winners[i]}")
        output_parts.append(f"**Word count:** {wc}")
        output_parts.append(f"**File:** `{pr_files[i]}`\n")
        output_parts.append(pr_texts[i])
        output_parts.append("\n---\n")
        output_parts.append(f"### Schema {label}")
        output_parts.append(f"**File:** `{schema_files[i]}`\n")
        output_parts.append(f"```json\n{schema_texts[i]}\n```")
        output_parts.append("\n---\n")

    # Cost summary table
    output_parts.append("## Cost Summary\n")
    output_parts.append("| Step | Model | Time (s) |")
    output_parts.append("|------|-------|----------|")
    for c in cost_log:
        output_parts.append(f"| {c['step']} | {c['model']} | {c['elapsed_s']} |")
    output_parts.append(f"| **Total** | | **{round(total_elapsed, 1)}** |")

    return "\n".join(output_parts)
|
||
|
||
|
||
def _extract_json(text: str) -> str | None:
|
||
"""Try to pull a JSON object out of LLM output (strip fences, prose, etc)."""
|
||
stripped = text.strip()
|
||
if stripped.startswith("{"):
|
||
try:
|
||
json.loads(stripped)
|
||
return stripped
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
# Strip markdown fences
|
||
fence_match = re.search(r"```(?:json)?\s*\n?([\s\S]*?)\n?```", text)
|
||
if fence_match:
|
||
candidate = fence_match.group(1).strip()
|
||
try:
|
||
json.loads(candidate)
|
||
return candidate
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
# Last resort: find first { to last }
|
||
start = text.find("{")
|
||
end = text.rfind("}")
|
||
if start != -1 and end != -1 and end > start:
|
||
candidate = text[start:end + 1]
|
||
try:
|
||
json.loads(candidate)
|
||
return candidate
|
||
except json.JSONDecodeError:
|
||
pass
|
||
|
||
return None
|