diff --git a/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py new file mode 100644 index 000000000..6f1ee3ef3 --- /dev/null +++ b/ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py @@ -0,0 +1,41 @@ +from django.db import migrations + + +def create_periodic_tasks(apps, schema_editor): + CrontabSchedule = apps.get_model("django_celery_beat", "CrontabSchedule") + PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask") + + schedule, _ = CrontabSchedule.objects.get_or_create( + minute="*/15", + hour="*", + day_of_week="*", + day_of_month="*", + month_of_year="*", + ) + PeriodicTask.objects.get_or_create( + name="jobs.health_check", + defaults={ + "task": "ami.jobs.tasks.jobs_health_check", + "crontab": schedule, + "description": ( + "Umbrella job-health checks: stale-job reconciler plus a NATS " + "consumer snapshot for each running async_api job." + ), + }, + ) + + +def delete_periodic_tasks(apps, schema_editor): + PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask") + PeriodicTask.objects.filter(name="jobs.health_check").delete() + + +class Migration(migrations.Migration): + dependencies = [ + ("jobs", "0019_job_dispatch_mode"), + ("django_celery_beat", "0018_improve_crontab_helptext"), + ] + + operations = [ + migrations.RunPython(create_periodic_tasks, delete_periodic_tasks), + ] diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index ad3e18ca8..d3c8e9070 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import functools import logging @@ -9,6 +10,7 @@ from celery.signals import task_failure, task_postrun, task_prerun from django.db import transaction +from ami.main.checks.schemas import IntegrityCheckResult from ami.ml.orchestration.async_job_state import AsyncJobStateManager from ami.ml.orchestration.nats_queue import TaskQueueManager from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse @@ -407,6 +409,137 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di return results +# Expire queued copies that accumulate while a worker is unavailable so we +# don't process a backlog when a worker reconnects. Kept below the 15-minute +# schedule interval so a backlog is dropped but a single delayed copy still +# runs. Going well below the interval would risk every copy expiring before +# a worker picks it up under moderate broker pressure — change this in lock- +# step with the crontab in migration 0020. +_JOBS_HEALTH_BEAT_EXPIRES = 60 * 14 + + +@dataclasses.dataclass +class JobsHealthCheckResult: + """Nested result of one :func:`jobs_health_check` tick. + + Each field is the summary for one sub-check and uses the shared + :class:`IntegrityCheckResult` shape so operators see a uniform + ``checked / fixed / unfixable`` triple regardless of which check ran. + Add a new field here when adding a sub-check to the umbrella. + """ + + stale_jobs: IntegrityCheckResult + running_job_snapshots: IntegrityCheckResult + + +def _run_stale_jobs_check() -> IntegrityCheckResult: + """Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS.""" + results = check_stale_jobs() + updated = sum(1 for r in results if r["action"] == "updated") + revoked = sum(1 for r in results if r["action"] == "revoked") + logger.info( + "stale_jobs check: %d stale job(s), %d updated from Celery, %d revoked", + len(results), + updated, + revoked, + ) + return IntegrityCheckResult(checked=len(results), fixed=updated + revoked, unfixable=0) + + +def _run_running_job_snapshot_check() -> IntegrityCheckResult: + """Log a NATS consumer snapshot for each running async_api job. + + Observation-only: ``fixed`` stays 0 because no state is altered. Jobs + that error during snapshot are counted in ``unfixable`` — a persistently + stuck job will be picked up on the next tick by ``_run_stale_jobs_check``. + """ + from ami.jobs.models import Job, JobDispatchMode, JobState + + running_jobs = list( + Job.objects.filter( + status__in=JobState.running_states(), + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + ) + if not running_jobs: + return IntegrityCheckResult() + + # Resolve each job's per-job logger synchronously before entering the + # event loop — ``Job.logger`` attaches a ``JobLogHandler`` on first access + # which touches the Django ORM, so it is only safe to call from a sync + # context. + job_loggers = [(job, job.logger) for job in running_jobs] + errors = 0 + + async def _snapshot_all() -> None: + nonlocal errors + # One NATS connection per tick — on a 15-min cadence a per-job fallback + # is not worth the code. If the shared connection fails to set up, we + # skip this tick's snapshots and try fresh on the next one. + async with TaskQueueManager(job_logger=job_loggers[0][1]) as manager: + for job, job_logger in job_loggers: + try: + # ``log_async`` reads ``job_logger`` fresh each call, so + # swapping per iteration routes lifecycle lines to the + # right job's UI log. + manager.job_logger = job_logger + await manager.log_consumer_stats_snapshot(job.pk) + except Exception: + errors += 1 + logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk) + + try: + async_to_sync(_snapshot_all)() + except Exception: + # Covers both ``__aenter__`` setup failures (no iteration ran) and the + # rare ``__aexit__`` teardown failure after a clean loop. In the + # teardown case this overwrites the per-iteration count with the total + # — accepted: a persistent failure will show up again next tick. + logger.exception("Shared-connection snapshot failed; marking tick unfixable") + errors = len(running_jobs) + + log_fn = logger.warning if errors else logger.info + log_fn( + "running_job_snapshots check: %d running async job(s), %d error(s)", + len(running_jobs), + errors, + ) + return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors) + + +def _safe_run_sub_check(name: str, fn: Callable[[], IntegrityCheckResult]) -> IntegrityCheckResult: + """Run one umbrella sub-check, returning an ``unfixable=1`` sentinel on failure. + + The umbrella composes independent sub-checks; one failing must not block + the others. A raised exception is logged and surfaced as a single + ``unfixable`` entry so operators watching the task result in Flower see + the check failed rather than reading zero and assuming all-clear. + """ + try: + return fn() + except Exception: + logger.exception("%s sub-check failed; continuing umbrella", name) + return IntegrityCheckResult(checked=0, fixed=0, unfixable=1) + + +@celery_app.task(soft_time_limit=300, time_limit=360, expires=_JOBS_HEALTH_BEAT_EXPIRES) +def jobs_health_check() -> dict: + """Umbrella beat task for periodic job-health checks. + + Composes reconciliation (stale jobs) with observation (NATS consumer + snapshots for running async jobs) so both land in the same 15-minute + tick — a quietly hung async job gets a snapshot entry right before the + reconciler decides whether to revoke it. Returns the serialized form of + :class:`JobsHealthCheckResult` so celery's default JSON backend can store + it; add new sub-checks by extending that dataclass and calling them here. + """ + result = JobsHealthCheckResult( + stale_jobs=_safe_run_sub_check("stale_jobs", _run_stale_jobs_check), + running_job_snapshots=_safe_run_sub_check("running_job_snapshots", _run_running_job_snapshot_check), + ) + return dataclasses.asdict(result) + + def cleanup_async_job_if_needed(job) -> None: """ Clean up async resources (NATS/Redis) if this job uses them. diff --git a/ami/jobs/tests/test_periodic_beat_tasks.py b/ami/jobs/tests/test_periodic_beat_tasks.py new file mode 100644 index 000000000..eaf2f3368 --- /dev/null +++ b/ami/jobs/tests/test_periodic_beat_tasks.py @@ -0,0 +1,175 @@ +from datetime import timedelta +from unittest.mock import AsyncMock, patch + +from django.test import TestCase +from django.utils import timezone + +from ami.jobs.models import Job, JobDispatchMode, JobState +from ami.jobs.tasks import jobs_health_check +from ami.main.models import Project + + +def _empty_check_dict() -> dict: + return {"checked": 0, "fixed": 0, "unfixable": 0} + + +@patch("ami.jobs.tasks.cleanup_async_job_if_needed") +@patch("ami.jobs.tasks.TaskQueueManager") +class JobsHealthCheckTest(TestCase): + def setUp(self): + self.project = Project.objects.create(name="Beat schedule test project") + + def _create_stale_job(self, status=JobState.STARTED, hours_ago=100): + job = Job.objects.create(project=self.project, name="stale", status=status) + Job.objects.filter(pk=job.pk).update(updated_at=timezone.now() - timedelta(hours=hours_ago)) + job.refresh_from_db() + return job + + def _create_async_job(self, status=JobState.STARTED): + job = Job.objects.create(project=self.project, name=f"async {status}", status=status) + Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API) + job.refresh_from_db() + return job + + def _stub_manager(self, mock_manager_cls) -> AsyncMock: + instance = mock_manager_cls.return_value + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=False) + instance.log_consumer_stats_snapshot = AsyncMock() + return instance + + def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup): + self._create_stale_job() + self._create_stale_job() + self._stub_manager(mock_manager_cls) + + result = jobs_health_check() + + self.assertEqual( + result, + { + "stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0}, + "running_job_snapshots": _empty_check_dict(), + }, + ) + + def test_idle_deployment_returns_all_zeros(self, mock_manager_cls, _mock_cleanup): + # No stale jobs, no running async jobs. + self._create_stale_job(hours_ago=1) # recent — not stale + self._stub_manager(mock_manager_cls) + + self.assertEqual( + jobs_health_check(), + { + "stale_jobs": _empty_check_dict(), + "running_job_snapshots": _empty_check_dict(), + }, + ) + + def test_snapshots_each_running_async_job(self, mock_manager_cls, _mock_cleanup): + job_a = self._create_async_job() + job_b = self._create_async_job() + instance = self._stub_manager(mock_manager_cls) + + result = jobs_health_check() + + self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 0}) + # Observation-only contract: the snapshot sub-check must never report + # ``fixed > 0`` since it does not mutate state. Lock this in explicitly + # so a future refactor that accidentally increments ``fixed`` breaks + # this assertion rather than silently shipping. + self.assertEqual(result["running_job_snapshots"]["fixed"], 0) + snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list] + self.assertCountEqual(snapshots, [job_a.pk, job_b.pk]) + + def test_one_job_snapshot_failure_counts_as_unfixable(self, mock_manager_cls, _mock_cleanup): + job_ok = self._create_async_job() + job_broken = self._create_async_job() + instance = self._stub_manager(mock_manager_cls) + + calls = [] + + async def _snapshot(job_id): + calls.append(job_id) + if job_id == job_broken.pk: + raise RuntimeError("nats down for this one") + + instance.log_consumer_stats_snapshot = AsyncMock(side_effect=_snapshot) + + result = jobs_health_check() + + # Both jobs were attempted; only the broken one failed. + self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 1}) + self.assertIn(job_ok.pk, calls) + self.assertIn(job_broken.pk, calls) + + def test_shared_connection_setup_failure_marks_all_unfixable(self, mock_manager_cls, _mock_cleanup): + self._create_async_job() + self._create_async_job() + + instance = mock_manager_cls.return_value + instance.__aenter__ = AsyncMock(side_effect=RuntimeError("nats down")) + instance.__aexit__ = AsyncMock(return_value=False) + instance.log_consumer_stats_snapshot = AsyncMock() + + result = jobs_health_check() + + # All running jobs are counted as unfixable for this tick; no + # snapshots ran and the shared-connection error was swallowed. + self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 2}) + instance.log_consumer_stats_snapshot.assert_not_awaited() + + def test_non_async_running_jobs_are_ignored_by_snapshot_check(self, mock_manager_cls, _mock_cleanup): + job = Job.objects.create(project=self.project, name="sync job", status=JobState.STARTED) + self.assertNotEqual(job.dispatch_mode, JobDispatchMode.ASYNC_API) + instance = self._stub_manager(mock_manager_cls) + + result = jobs_health_check() + + self.assertEqual(result["running_job_snapshots"], _empty_check_dict()) + instance.log_consumer_stats_snapshot.assert_not_awaited() + + def test_sub_check_exception_does_not_block_the_other(self, mock_manager_cls, _mock_cleanup): + # One stale job to prove the reconciler would have had work; the + # snapshot sub-check raises and must not prevent the stale-jobs + # sub-check from running and reporting its own result. + self._create_stale_job() + self._stub_manager(mock_manager_cls) + + with patch( + "ami.jobs.tasks._run_running_job_snapshot_check", + side_effect=RuntimeError("pretend the observation check blew up"), + ): + result = jobs_health_check() + + # Stale-jobs sub-check completes normally and reports the reconciliation. + self.assertEqual(result["stale_jobs"], {"checked": 1, "fixed": 1, "unfixable": 0}) + # Snapshot sub-check returns the `unfixable=1` sentinel on failure so + # operators reading the task result see the check failed, not "nothing + # to do." + self.assertEqual(result["running_job_snapshots"], {"checked": 0, "fixed": 0, "unfixable": 1}) + + def test_stale_jobs_fixed_counts_celery_updated_and_revoked_paths(self, mock_manager_cls, _mock_cleanup): + # Two stale jobs in different reconciliation states: one has a Celery + # task_id that returns a terminal state (counts as "updated from Celery"), + # the other has no task_id and is forced to REVOKED. Both contribute to + # `fixed` — this test guards against a refactor dropping one branch. + from celery import states + + job_with_task = self._create_stale_job() + job_with_task.task_id = "terminal-task" + job_with_task.save(update_fields=["task_id"]) + self._create_stale_job() # no task_id → revoked path + self._stub_manager(mock_manager_cls) + + class _FakeAsyncResult: + def __init__(self, task_id): + self.state = states.SUCCESS if task_id == "terminal-task" else states.PENDING + + # `check_stale_jobs` imports AsyncResult locally from celery.result, + # so patch at source rather than at the call site. + with patch("celery.result.AsyncResult", _FakeAsyncResult): + result = jobs_health_check() + + # checked == 2 (both stale), fixed == 2 (one per branch), unfixable == 0 + self.assertEqual(result["stale_jobs"], {"checked": 2, "fixed": 2, "unfixable": 0}) diff --git a/ami/main/checks/__init__.py b/ami/main/checks/__init__.py new file mode 100644 index 000000000..506d49ffe --- /dev/null +++ b/ami/main/checks/__init__.py @@ -0,0 +1,12 @@ +"""Integrity and health check primitives shared across apps. + +Sub-modules in this package (added by per-domain check PRs such as +``ami.main.checks.occurrences`` in #1188) define ``get_*`` and +``reconcile_*`` function pairs. The shared result schema lives in +:mod:`ami.main.checks.schemas` so reconciliation and observation checks +across apps return the same shape. +""" + +from ami.main.checks.schemas import IntegrityCheckResult + +__all__ = ["IntegrityCheckResult"] diff --git a/ami/main/checks/schemas.py b/ami/main/checks/schemas.py new file mode 100644 index 000000000..75347b2aa --- /dev/null +++ b/ami/main/checks/schemas.py @@ -0,0 +1,27 @@ +"""Shared result schemas for integrity and health checks. + +A check is any function that inspects some slice of state and returns an +:class:`IntegrityCheckResult`. Reconciliation checks populate ``fixed`` +with the number of rows actually mutated; observation checks (e.g. +logging a snapshot) keep ``fixed`` at 0 and use ``unfixable`` to count +items the check could not complete for. +""" + +import dataclasses + + +@dataclasses.dataclass +class IntegrityCheckResult: + """Summary of a single integrity or health check pass. + + Attributes: + checked: Rows / items the check inspected this pass. + fixed: Rows the check mutated to a correct state. Observation-only + checks must leave this at 0 — ``fixed`` means state was altered. + unfixable: Rows the check inspected but could not repair or observe + (for observation checks this counts errors per item). + """ + + checked: int = 0 + fixed: int = 0 + unfixable: int = 0 diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 43d9d65e5..767298af4 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -477,6 +477,19 @@ async def _log_final_consumer_stats(self, job_id: int) -> None: redelivered before the consumer vanished. Failures here must NOT block cleanup — if the consumer or stream is already gone, just skip it. """ + await self._log_consumer_stats(job_id, prefix="Finalizing NATS consumer", suffix="before deletion") + + async def log_consumer_stats_snapshot(self, job_id: int) -> None: + """Log a mid-flight snapshot of the consumer state for a running job. + + Called by the ``running_job_snapshots`` sub-check of the periodic + ``jobs_health_check`` beat task so operators can see deliver/ack/pending + counts without waiting for the job to finish. Tolerant of missing + stream/consumer like the cleanup-time variant. + """ + await self._log_consumer_stats(job_id, prefix="NATS consumer status") + + async def _log_consumer_stats(self, job_id: int, *, prefix: str, suffix: str = "") -> None: if self.js is None: return stream_name = self._get_stream_name(job_id) @@ -487,15 +500,15 @@ async def _log_final_consumer_stats(self, job_id: int) -> None: timeout=NATS_JETSTREAM_TIMEOUT, ) except Exception as e: - # Broad catch is intentional here (unlike _ensure_consumer): at - # cleanup time we tolerate any failure — stream gone, consumer - # already deleted, auth, timeout — so the delete calls below - # still get a chance to run. - logger.debug(f"Could not fetch consumer info for {consumer_name} before deletion: {e}") + # Broad catch is intentional: if the consumer or stream is gone we + # just skip — callers (cleanup, periodic snapshot) should never fail + # because we couldn't read stats. + logger.debug(f"Could not fetch consumer info for {consumer_name}: {e}") return + tail = f" {suffix}" if suffix else "" await self.log_async( logging.INFO, - f"Finalizing NATS consumer {consumer_name} before deletion ({self._format_consumer_stats(info)})", + f"{prefix} {consumer_name}{tail} ({self._format_consumer_stats(info)})", ) async def delete_consumer(self, job_id: int) -> bool: diff --git a/ami/ml/orchestration/tests/test_nats_queue.py b/ami/ml/orchestration/tests/test_nats_queue.py index d1d651450..9c35a4dae 100644 --- a/ami/ml/orchestration/tests/test_nats_queue.py +++ b/ami/ml/orchestration/tests/test_nats_queue.py @@ -493,6 +493,44 @@ async def test_publish_failure_surfaces_on_job_logger(self): f"expected publish failure on job_logger, got {messages}", ) + async def test_log_consumer_stats_snapshot_writes_current_stats(self): + """The periodic snapshot helper logs delivered/ack/pending WITHOUT + deleting the consumer — it's a mid-flight observability hook.""" + nc, js = self._create_mock_nats_connection() + js.consumer_info.return_value = self._make_consumer_info( + delivered=50, ack_floor=40, num_pending=10, num_ack_pending=10, num_redelivered=2 + ) + + job_logger = self._make_captured_logger() + captured = job_logger._captured # type: ignore[attr-defined] + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager(job_logger=job_logger) as manager: + await manager.log_consumer_stats_snapshot(9) + + messages = [m for _, m in captured] + self.assertTrue( + any("NATS consumer status job-9-consumer" in m for m in messages), + f"expected snapshot line on job_logger, got {messages}", + ) + snapshot_line = next(m for m in messages if "NATS consumer status" in m) + for expected in ("delivered=50", "ack_floor=40", "num_redelivered=2"): + self.assertIn(expected, snapshot_line) + # Must NOT have triggered a delete — this is read-only observability. + js.delete_consumer.assert_not_called() + js.delete_stream.assert_not_called() + + async def test_log_consumer_stats_snapshot_tolerates_missing_consumer(self): + """If the consumer is already gone, the snapshot helper just no-ops.""" + nc, js = self._create_mock_nats_connection() + js.consumer_info.side_effect = nats.js.errors.NotFoundError() + + job_logger = self._make_captured_logger() + + with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))): + async with TaskQueueManager(job_logger=job_logger) as manager: + await manager.log_consumer_stats_snapshot(99) # must not raise + async def test_no_job_logger_falls_back_to_module_logger_only(self): """When job_logger is None (e.g., module-level uses like advisory listener), lifecycle logs must still be emitted to the module logger