diff --git a/cheddahbot/ntfy.py b/cheddahbot/ntfy.py index 702954f..f90da3f 100644 --- a/cheddahbot/ntfy.py +++ b/cheddahbot/ntfy.py @@ -6,10 +6,13 @@ topics based on category and message-pattern matching. from __future__ import annotations +import hashlib import logging import re import threading +import time from dataclasses import dataclass, field +from datetime import date import httpx @@ -48,8 +51,24 @@ class NtfyChannel: class NtfyNotifier: """Posts notifications to ntfy.sh topics.""" - def __init__(self, channels: list[NtfyChannel]): + def __init__( + self, + channels: list[NtfyChannel], + *, + daily_cap: int = 200, + dedup_window_secs: int = 3600, + ): self._channels = [ch for ch in channels if ch.topic] + self._daily_cap = daily_cap + self._dedup_window_secs = dedup_window_secs + self._lock = threading.Lock() + # dedup: hash(channel.name + message) -> last-sent epoch + self._recent: dict[str, float] = {} + # daily cap tracking + self._daily_count = 0 + self._daily_date = "" + # 429 backoff: date string when rate-limited + self._rate_limited_until = "" if self._channels: log.info( "ntfy notifier initialized with %d channel(s): %s", @@ -61,6 +80,59 @@ class NtfyNotifier: def enabled(self) -> bool: return bool(self._channels) + def _today(self) -> str: + return date.today().isoformat() + + def _check_and_track(self, channel_name: str, message: str) -> bool: + """Return True if this message should be sent. Updates internal state.""" + now = time.monotonic() + today = self._today() + + with self._lock: + # 429 backoff: skip all sends for rest of day + if self._rate_limited_until == today: + return False + + # Reset daily counter on date rollover + if self._daily_date != today: + self._daily_date = today + self._daily_count = 0 + self._rate_limited_until = "" + self._recent.clear() + + # Daily cap check + if self._daily_count >= self._daily_cap: + return False + + # Dedup check + key = hashlib.md5( + (channel_name + "\0" + message).encode() + ).hexdigest() + last_sent = self._recent.get(key) + if last_sent is not None and (now - last_sent) < self._dedup_window_secs: + log.debug( + "ntfy dedup: suppressed duplicate to '%s'", channel_name, + ) + return False + + # All checks passed — record send + self._recent[key] = now + self._daily_count += 1 + + if self._daily_count == self._daily_cap: + log.warning( + "ntfy daily cap reached (%d). No more sends today.", + self._daily_cap, + ) + + return True + + def _mark_rate_limited(self) -> None: + """Flag that we got a 429 — suppress all sends for rest of day.""" + with self._lock: + self._rate_limited_until = self._today() + log.warning("ntfy 429 received. Suppressing all sends for rest of day.") + def notify(self, message: str, category: str) -> None: """Route a notification to matching ntfy channels. @@ -70,6 +142,8 @@ class NtfyNotifier: """ for channel in self._channels: if channel.accepts(message, category): + if not self._check_and_track(channel.name, message): + continue t = threading.Thread( target=self._post, args=(channel, message, category), @@ -93,7 +167,9 @@ class NtfyNotifier: headers=headers, timeout=10.0, ) - if resp.status_code >= 400: + if resp.status_code == 429: + self._mark_rate_limited() + elif resp.status_code >= 400: log.warning( "ntfy '%s' returned %d: %s", channel.name, resp.status_code, resp.text[:200], diff --git a/tests/test_ntfy.py b/tests/test_ntfy.py index cb9e764..c04871b 100644 --- a/tests/test_ntfy.py +++ b/tests/test_ntfy.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from unittest.mock import MagicMock, patch import httpx @@ -288,3 +289,124 @@ class TestPostFormat: t.join(timeout=2) assert mock_post.call_args[1]["headers"]["Priority"] == "urgent" + + +# --------------------------------------------------------------------------- +# Dedup window +# --------------------------------------------------------------------------- + + +def _make_channel(**overrides) -> NtfyChannel: + defaults = dict( + name="errors", + server="https://ntfy.sh", + topic="test-topic", + categories=["alert"], + ) + defaults.update(overrides) + return NtfyChannel(**defaults) + + +class TestDedup: + def test_first_message_goes_through(self): + notifier = NtfyNotifier([_make_channel()], dedup_window_secs=3600) + assert notifier._check_and_track("errors", "task X skipped") is True + + def test_duplicate_within_window_suppressed(self): + notifier = NtfyNotifier([_make_channel()], dedup_window_secs=3600) + assert notifier._check_and_track("errors", "task X skipped") is True + assert notifier._check_and_track("errors", "task X skipped") is False + + def test_duplicate_after_window_passes(self): + notifier = NtfyNotifier([_make_channel()], dedup_window_secs=60) + assert notifier._check_and_track("errors", "task X skipped") is True + # Simulate time passing beyond the window + key = list(notifier._recent.keys())[0] + notifier._recent[key] = time.monotonic() - 120 + assert notifier._check_and_track("errors", "task X skipped") is True + + def test_different_messages_not_deduped(self): + notifier = NtfyNotifier([_make_channel()], dedup_window_secs=3600) + assert notifier._check_and_track("errors", "task A skipped") is True + assert notifier._check_and_track("errors", "task B skipped") is True + + def test_same_message_different_channel_not_deduped(self): + notifier = NtfyNotifier([_make_channel()], dedup_window_secs=3600) + assert notifier._check_and_track("errors", "task X skipped") is True + assert notifier._check_and_track("alerts", "task X skipped") is True + + +# --------------------------------------------------------------------------- +# Daily cap +# --------------------------------------------------------------------------- + + +class TestDailyCap: + def test_sends_up_to_cap(self): + notifier = NtfyNotifier([_make_channel()], daily_cap=3) + for i in range(3): + assert notifier._check_and_track("errors", f"msg {i}") is True + assert notifier._check_and_track("errors", "msg 3") is False + + def test_cap_resets_on_new_day(self): + notifier = NtfyNotifier([_make_channel()], daily_cap=2) + assert notifier._check_and_track("errors", "msg 0") is True + assert notifier._check_and_track("errors", "msg 1") is True + assert notifier._check_and_track("errors", "msg 2") is False + + with patch.object(notifier, "_today", return_value="2099-01-01"): + assert notifier._check_and_track("errors", "msg 2") is True + + +# --------------------------------------------------------------------------- +# 429 backoff +# --------------------------------------------------------------------------- + + +class TestRateLimitBackoff: + def test_429_suppresses_rest_of_day(self): + notifier = NtfyNotifier([_make_channel()]) + notifier._mark_rate_limited() + assert notifier._check_and_track("errors", "new message") is False + + def test_429_resets_next_day(self): + notifier = NtfyNotifier([_make_channel()]) + notifier._mark_rate_limited() + assert notifier._check_and_track("errors", "blocked") is False + + with patch.object(notifier, "_today", return_value="2099-01-01"): + assert notifier._check_and_track("errors", "unblocked") is True + + def test_post_sets_rate_limit_on_429(self): + channel = _make_channel() + notifier = NtfyNotifier([channel]) + + mock_resp = MagicMock(status_code=429, text="Rate limited") + with patch("cheddahbot.ntfy.httpx.post", return_value=mock_resp): + notifier._post(channel, "test msg", "alert") + + assert notifier._rate_limited_until == notifier._today() + + +# --------------------------------------------------------------------------- +# Notify integration with dedup +# --------------------------------------------------------------------------- + + +class TestNotifyDedup: + @patch("cheddahbot.ntfy.httpx.post") + def test_notify_skips_deduped_messages(self, mock_post): + mock_post.return_value = MagicMock(status_code=200) + channel = _make_channel() + notifier = NtfyNotifier([channel]) + + notifier.notify("same msg", "alert") + notifier.notify("same msg", "alert") + + import threading + for t in threading.enumerate(): + if t.daemon and t.is_alive(): + t.join(timeout=2) + + # Only one post — second was deduped + mock_post.assert_called_once()