Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions ami/jobs/migrations/0020_schedule_job_monitoring_beat_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from django.db import migrations


def create_periodic_tasks(apps, schema_editor):
    """Install the ``jobs.health_check`` beat task on a 15-minute crontab.

    Uses the historical models via ``apps.get_model`` (migration convention)
    rather than importing django_celery_beat models directly. ``get_or_create``
    keeps the operation idempotent if the migration is replayed.
    """
    CrontabSchedule = apps.get_model("django_celery_beat", "CrontabSchedule")
    PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask")

    # One shared "every 15 minutes" crontab row; reused if it already exists.
    every_15_minutes, _ = CrontabSchedule.objects.get_or_create(
        minute="*/15",
        hour="*",
        day_of_week="*",
        day_of_month="*",
        month_of_year="*",
    )

    # defaults only apply on creation — an existing row with this name is
    # left untouched, which keeps replays harmless.
    task_defaults = {
        "task": "ami.jobs.tasks.jobs_health_check",
        "crontab": every_15_minutes,
        "description": (
            "Umbrella job-health checks: stale-job reconciler plus a NATS "
            "consumer snapshot for each running async_api job."
        ),
    }
    PeriodicTask.objects.get_or_create(name="jobs.health_check", defaults=task_defaults)


def delete_periodic_tasks(apps, schema_editor):
    """Reverse operation: remove the beat task installed by this migration.

    The shared crontab row is deliberately left in place — other periodic
    tasks may reference the same schedule.
    """
    PeriodicTask = apps.get_model("django_celery_beat", "PeriodicTask")
    existing = PeriodicTask.objects.filter(name="jobs.health_check")
    existing.delete()


class Migration(migrations.Migration):
    # Requires the jobs.dispatch_mode field (added in 0019) that the health
    # check filters on, and a django_celery_beat schema recent enough for the
    # CrontabSchedule/PeriodicTask rows created here.
    dependencies = [
        ("jobs", "0019_job_dispatch_mode"),
        ("django_celery_beat", "0018_improve_crontab_helptext"),
    ]

    operations = [
        # Reversible data migration: forward installs the beat task,
        # backward deletes it.
        migrations.RunPython(create_periodic_tasks, delete_periodic_tasks),
    ]
133 changes: 133 additions & 0 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import datetime
import functools
import logging
Expand All @@ -9,6 +10,7 @@
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction

from ami.main.checks.schemas import IntegrityCheckResult
from ami.ml.orchestration.async_job_state import AsyncJobStateManager
from ami.ml.orchestration.nats_queue import TaskQueueManager
from ami.ml.schemas import PipelineResultsError, PipelineResultsResponse
Expand Down Expand Up @@ -407,6 +409,137 @@ def check_stale_jobs(hours: int | None = None, dry_run: bool = False) -> list[di
return results


# Expire queued copies that accumulate while a worker is unavailable so we
# don't process a backlog when a worker reconnects. Kept below the 15-minute
# schedule interval so a backlog is dropped but a single delayed copy still
# runs. Going well below the interval would risk every copy expiring before
# a worker picks it up under moderate broker pressure — change this in lock-
# step with the crontab in migration 0020.
# Value: 14 minutes, expressed in seconds (Celery's `expires` unit).
_JOBS_HEALTH_BEAT_EXPIRES = 60 * 14


@dataclasses.dataclass
class JobsHealthCheckResult:
    """Nested result of one :func:`jobs_health_check` tick.

    Each field is the summary for one sub-check and uses the shared
    :class:`IntegrityCheckResult` shape so operators see a uniform
    ``checked / fixed / unfixable`` triple regardless of which check ran.
    Add a new field here when adding a sub-check to the umbrella.
    """

    # Reconciliation summary produced by _run_stale_jobs_check.
    stale_jobs: IntegrityCheckResult
    # Observation summary produced by _run_running_job_snapshot_check.
    running_job_snapshots: IntegrityCheckResult


def _run_stale_jobs_check() -> IntegrityCheckResult:
    """Reconcile jobs stuck in running states past FAILED_CUTOFF_HOURS.

    Delegates the actual work to :func:`check_stale_jobs` and condenses its
    per-job action list into the shared ``checked / fixed / unfixable``
    triple. Both reconciliation outcomes ("updated" from Celery state and
    "revoked" when no live task exists) count as fixed.
    """
    stale_entries = check_stale_jobs()
    actions = [entry["action"] for entry in stale_entries]
    updated = actions.count("updated")
    revoked = actions.count("revoked")
    logger.info(
        "stale_jobs check: %d stale job(s), %d updated from Celery, %d revoked",
        len(stale_entries),
        updated,
        revoked,
    )
    return IntegrityCheckResult(checked=len(stale_entries), fixed=updated + revoked, unfixable=0)


def _run_running_job_snapshot_check() -> IntegrityCheckResult:
    """Log a NATS consumer snapshot for each running async_api job.

    Observation-only: ``fixed`` stays 0 because no state is altered. Jobs
    that error during snapshot are counted in ``unfixable`` — a persistently
    stuck job will be picked up on the next tick by ``_run_stale_jobs_check``.

    Returns:
        IntegrityCheckResult with ``checked`` = number of running async_api
        jobs found and ``unfixable`` = number whose snapshot errored.
    """
    # Local import — presumably avoids an import cycle between tasks and
    # models at module load time; TODO confirm.
    from ami.jobs.models import Job, JobDispatchMode, JobState

    running_jobs = list(
        Job.objects.filter(
            status__in=JobState.running_states(),
            dispatch_mode=JobDispatchMode.ASYNC_API,
        )
    )
    if not running_jobs:
        # Nothing to observe this tick: all-zero result.
        return IntegrityCheckResult()

    # Resolve each job's per-job logger synchronously before entering the
    # event loop — ``Job.logger`` attaches a ``JobLogHandler`` on first access
    # which touches the Django ORM, so it is only safe to call from a sync
    # context.
    job_loggers = [(job, job.logger) for job in running_jobs]
    errors = 0

    async def _snapshot_all() -> None:
        nonlocal errors
        # One NATS connection per tick — on a 15-min cadence a per-job fallback
        # is not worth the code. If the shared connection fails to set up, we
        # skip this tick's snapshots and try fresh on the next one.
        async with TaskQueueManager(job_logger=job_loggers[0][1]) as manager:
            for job, job_logger in job_loggers:
                try:
                    # ``log_async`` reads ``job_logger`` fresh each call, so
                    # swapping per iteration routes lifecycle lines to the
                    # right job's UI log.
                    manager.job_logger = job_logger
                    await manager.log_consumer_stats_snapshot(job.pk)
                except Exception:
                    # NOTE(review): ``log_consumer_stats_snapshot`` may swallow
                    # fetch failures internally and return normally — confirm;
                    # if so, ``errors`` undercounts real per-job failures.
                    errors += 1
                    logger.exception("Failed to snapshot NATS consumer stats for job %s", job.pk)

    try:
        async_to_sync(_snapshot_all)()
    except Exception:
        # Covers both ``__aenter__`` setup failures (no iteration ran) and the
        # rare ``__aexit__`` teardown failure after a clean loop. In the
        # teardown case this overwrites the per-iteration count with the total
        # — accepted: a persistent failure will show up again next tick.
        logger.exception("Shared-connection snapshot failed; marking tick unfixable")
        errors = len(running_jobs)

    # Escalate to WARNING only when at least one snapshot failed.
    log_fn = logger.warning if errors else logger.info
    log_fn(
        "running_job_snapshots check: %d running async job(s), %d error(s)",
        len(running_jobs),
        errors,
    )
    return IntegrityCheckResult(checked=len(running_jobs), fixed=0, unfixable=errors)
Comment on lines +449 to +507
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

running_job_snapshots.unfixable can undercount real per-job failures.

_run_running_job_snapshot_check increments errors only when log_consumer_stats_snapshot raises, but that method is implemented to swallow common snapshot-fetch failures internally and return early. This can report unfixable=0 even when snapshots failed for individual jobs. Consider returning a success/failure signal from log_consumer_stats_snapshot (or a companion API) and count from that result instead of exception flow.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@ami/jobs/tasks.py` around lines 449 - 507, The check
_run_running_job_snapshot_check undercounts failures because
TaskQueueManager.log_consumer_stats_snapshot swallows errors and doesn't signal
failure; update that API to return a boolean (or similar success flag) and have
_run_running_job_snapshot_check call and inspect the return value instead of
relying only on exceptions: change TaskQueueManager.log_consumer_stats_snapshot
(and any internal helpers it uses) to return True on success / False on failure,
then in _run_running_job_snapshot_check's loop (where
manager.log_consumer_stats_snapshot(job.pk) is awaited) check the result and
increment errors when it returns False (preserving the existing exception
handling for shared-connection setup/teardown via the outer try/except and
logger.exception).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If snapshots failed, we won't see logs. But I do like this idea. Perhaps unfixable could actually increment when jobs have alarming stats.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
rg -n "_log_consumer_stats|log_consumer_stats_snapshot|ConsumerInfo|num_redelivered|num_pending|num_ack_pending|ack_floor|formatted_stats" ami/ml/orchestration/nats_queue.py -A 3 -B 1

Length of output: 2063


🏁 Script executed:

#!/bin/bash
rg -n "_log_consumer_stats" ami/ml/orchestration/nats_queue.py -A 30 | grep -A 30 "async def _log_consumer_stats"

Length of output: 1346


@mihow — both observations are valid, and they compound each other:

  1. Silent on failure: _log_consumer_stats currently catches all exceptions and does only logger.debug(...), so a fetch failure produces no INFO/WARNING log and no exception to the _run_running_job_snapshot_check loop. An operator watching job logs sees nothing, which looks identical to a healthy quiet job.

  2. Alarming-stats idea: This would make unfixable semantically richer and more actionable. Instead of log_consumer_stats_snapshot → None, it could return ConsumerInfo | None, and _run_running_job_snapshot_check could then inspect the result across two axes:

info = await manager.log_consumer_stats_snapshot(job.pk)
if info is None:
    errors += 1  # fetch failed → can't observe
elif _is_alarming(info):
    errors += 1  # observed, but stats are concerning

For _is_alarming, the most actionable thresholds from the available ConsumerInfo fields would be:

  • num_redelivered above some threshold — repeated redelivery usually signals a processing loop or a stuck acker
  • num_ack_pending significantly exceeding num_pending + 1 — messages being consumed but not acknowledged

A simple start could be a single configurable threshold on num_redelivered (e.g., > 10) since that's the clearest "something is wrong" signal without being noisy for normal in-flight jobs.

Would you like me to draft the changes to log_consumer_stats_snapshot / _log_consumer_stats (return ConsumerInfo | None) and the updated _run_running_job_snapshot_check loop (with the _is_alarming helper)? Happy to also add a WARNING-level log in _log_consumer_stats's except branch so fetch failures are visible at the INFO tier too.


🧠 Learnings used
Learnt from: mihow
Repo: RolnickLab/antenna PR: 1130
File: ami/ml/orchestration/nats_queue.py:217-217
Timestamp: 2026-02-13T23:11:30.276Z
Learning: In the RolnickLab/antenna project (Pydantic v1), use .dict() for serialization of Pydantic models instead of .model_dump(). Update code paths in Python files where model_dump() is used to ensure compatibility with Pydantic v1 behavior. If any files are later migrated to Pydantic v2, consider switching to .model_dump() or using model_dump(..., mode='json') as appropriate.

Learnt from: mihow
Repo: RolnickLab/antenna PR: 1222
File: ami/ml/orchestration/nats_queue.py:126-132
Timestamp: 2026-04-13T22:08:32.779Z
Learning: In code that uses the `nats-py` JetStream client, do not divide `ConsumerInfo.config.ack_wait` (or the `ack_wait` returned by `consumer_info()` / `add_consumer()`) by `1e9`. `nats-py` already deserializes the server’s raw nanosecond value by dividing by `_NANOSECOND` (1e9), so application-level code should treat `ack_wait` as already human-readable seconds (e.g., `30.0` means 30 seconds). Only apply nanosecond-to-second conversion if you are working with the raw server nanosecond field directly (not the `ack_wait` property exposed by `ConsumerInfo`/`ConsumerConfig`).



def _safe_run_sub_check(name: str, fn: Callable[[], IntegrityCheckResult]) -> IntegrityCheckResult:
    """Run one umbrella sub-check, returning an ``unfixable=1`` sentinel on failure.

    Sub-checks are independent, so a crash in one must not stop the rest of
    the umbrella. The exception is logged with its traceback and reported as
    a single ``unfixable`` entry, so an operator watching the task result in
    Flower sees that the check failed instead of reading all zeros and
    assuming all-clear.
    """
    try:
        result = fn()
    except Exception:
        logger.exception("%s sub-check failed; continuing umbrella", name)
        result = IntegrityCheckResult(checked=0, fixed=0, unfixable=1)
    return result


@celery_app.task(soft_time_limit=300, time_limit=360, expires=_JOBS_HEALTH_BEAT_EXPIRES)
def jobs_health_check() -> dict:
    """Umbrella beat task for periodic job-health checks.

    Runs reconciliation (stale jobs) and observation (NATS consumer snapshots
    for running async jobs) in the same 15-minute tick, so a quietly hung
    async job gets a snapshot entry right before the reconciler decides
    whether to revoke it. Each sub-check is wrapped by
    :func:`_safe_run_sub_check`, so one failing does not block the other.

    Returns the serialized form of :class:`JobsHealthCheckResult` so celery's
    default JSON backend can store it; add new sub-checks by extending that
    dataclass and invoking them here.
    """
    stale_summary = _safe_run_sub_check("stale_jobs", _run_stale_jobs_check)
    snapshot_summary = _safe_run_sub_check("running_job_snapshots", _run_running_job_snapshot_check)
    return dataclasses.asdict(
        JobsHealthCheckResult(stale_jobs=stale_summary, running_job_snapshots=snapshot_summary)
    )


def cleanup_async_job_if_needed(job) -> None:
"""
Clean up async resources (NATS/Redis) if this job uses them.
Expand Down
175 changes: 175 additions & 0 deletions ami/jobs/tests/test_periodic_beat_tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
from datetime import timedelta
from unittest.mock import AsyncMock, patch

from django.test import TestCase
from django.utils import timezone

from ami.jobs.models import Job, JobDispatchMode, JobState
from ami.jobs.tasks import jobs_health_check
from ami.main.models import Project


def _empty_check_dict() -> dict:
return {"checked": 0, "fixed": 0, "unfixable": 0}


@patch("ami.jobs.tasks.cleanup_async_job_if_needed")
@patch("ami.jobs.tasks.TaskQueueManager")
class JobsHealthCheckTest(TestCase):
    """Tests for the ``jobs_health_check`` umbrella beat task.

    ``TaskQueueManager`` is patched at class level so no real NATS
    connection is attempted, and ``cleanup_async_job_if_needed`` is patched
    so reconciliation does not touch async resources. Because decorators
    apply bottom-up, every test method receives the mocks in the order
    ``(mock_manager_cls, _mock_cleanup)``.
    """

    def setUp(self):
        # Jobs require a project; one per test keeps tests isolated.
        self.project = Project.objects.create(name="Beat schedule test project")

    def _create_stale_job(self, status=JobState.STARTED, hours_ago=100):
        """Create a job whose ``updated_at`` is backdated by *hours_ago*.

        Backdating goes through ``queryset.update()`` so an auto-updating
        timestamp field cannot overwrite it on save — presumably
        ``updated_at`` is ``auto_now``; confirm against the model.
        """
        job = Job.objects.create(project=self.project, name="stale", status=status)
        Job.objects.filter(pk=job.pk).update(updated_at=timezone.now() - timedelta(hours=hours_ago))
        job.refresh_from_db()
        return job

    def _create_async_job(self, status=JobState.STARTED):
        """Create a running job dispatched via the async API path.

        ``dispatch_mode`` is set via ``queryset.update()`` after creation —
        NOTE(review): presumably to bypass model-save logic; confirm.
        """
        job = Job.objects.create(project=self.project, name=f"async {status}", status=status)
        Job.objects.filter(pk=job.pk).update(dispatch_mode=JobDispatchMode.ASYNC_API)
        job.refresh_from_db()
        return job

    def _stub_manager(self, mock_manager_cls) -> AsyncMock:
        """Wire the TaskQueueManager mock as a well-behaved async context manager."""
        instance = mock_manager_cls.return_value
        instance.__aenter__ = AsyncMock(return_value=instance)
        instance.__aexit__ = AsyncMock(return_value=False)
        instance.log_consumer_stats_snapshot = AsyncMock()
        return instance

    def test_reports_both_sub_check_results(self, mock_manager_cls, _mock_cleanup):
        # Two stale jobs exercise the reconciler; no async jobs exist so the
        # snapshot sub-check reports all zeros.
        self._create_stale_job()
        self._create_stale_job()
        self._stub_manager(mock_manager_cls)

        result = jobs_health_check()

        self.assertEqual(
            result,
            {
                "stale_jobs": {"checked": 2, "fixed": 2, "unfixable": 0},
                "running_job_snapshots": _empty_check_dict(),
            },
        )

    def test_idle_deployment_returns_all_zeros(self, mock_manager_cls, _mock_cleanup):
        # No stale jobs, no running async jobs.
        self._create_stale_job(hours_ago=1)  # recent — not stale
        self._stub_manager(mock_manager_cls)

        self.assertEqual(
            jobs_health_check(),
            {
                "stale_jobs": _empty_check_dict(),
                "running_job_snapshots": _empty_check_dict(),
            },
        )

    def test_snapshots_each_running_async_job(self, mock_manager_cls, _mock_cleanup):
        job_a = self._create_async_job()
        job_b = self._create_async_job()
        instance = self._stub_manager(mock_manager_cls)

        result = jobs_health_check()

        self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 0})
        # Observation-only contract: the snapshot sub-check must never report
        # ``fixed > 0`` since it does not mutate state. Lock this in explicitly
        # so a future refactor that accidentally increments ``fixed`` breaks
        # this assertion rather than silently shipping.
        self.assertEqual(result["running_job_snapshots"]["fixed"], 0)
        snapshots = [call.args[0] for call in instance.log_consumer_stats_snapshot.await_args_list]
        self.assertCountEqual(snapshots, [job_a.pk, job_b.pk])

    def test_one_job_snapshot_failure_counts_as_unfixable(self, mock_manager_cls, _mock_cleanup):
        job_ok = self._create_async_job()
        job_broken = self._create_async_job()
        instance = self._stub_manager(mock_manager_cls)

        calls = []

        async def _snapshot(job_id):
            # Fail only for one job so the loop's per-job error handling
            # (not the shared-connection fallback) is exercised.
            calls.append(job_id)
            if job_id == job_broken.pk:
                raise RuntimeError("nats down for this one")

        instance.log_consumer_stats_snapshot = AsyncMock(side_effect=_snapshot)

        result = jobs_health_check()

        # Both jobs were attempted; only the broken one failed.
        self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 1})
        self.assertIn(job_ok.pk, calls)
        self.assertIn(job_broken.pk, calls)

    def test_shared_connection_setup_failure_marks_all_unfixable(self, mock_manager_cls, _mock_cleanup):
        self._create_async_job()
        self._create_async_job()

        # __aenter__ raising models a NATS connection that cannot be set up
        # at all: no snapshot loop runs for this tick.
        instance = mock_manager_cls.return_value
        instance.__aenter__ = AsyncMock(side_effect=RuntimeError("nats down"))
        instance.__aexit__ = AsyncMock(return_value=False)
        instance.log_consumer_stats_snapshot = AsyncMock()

        result = jobs_health_check()

        # All running jobs are counted as unfixable for this tick; no
        # snapshots ran and the shared-connection error was swallowed.
        self.assertEqual(result["running_job_snapshots"], {"checked": 2, "fixed": 0, "unfixable": 2})
        instance.log_consumer_stats_snapshot.assert_not_awaited()

    def test_non_async_running_jobs_are_ignored_by_snapshot_check(self, mock_manager_cls, _mock_cleanup):
        # A running job with the default (non-async) dispatch mode must not
        # appear in the snapshot sub-check at all.
        job = Job.objects.create(project=self.project, name="sync job", status=JobState.STARTED)
        self.assertNotEqual(job.dispatch_mode, JobDispatchMode.ASYNC_API)
        instance = self._stub_manager(mock_manager_cls)

        result = jobs_health_check()

        self.assertEqual(result["running_job_snapshots"], _empty_check_dict())
        instance.log_consumer_stats_snapshot.assert_not_awaited()

    def test_sub_check_exception_does_not_block_the_other(self, mock_manager_cls, _mock_cleanup):
        # One stale job to prove the reconciler would have had work; the
        # snapshot sub-check raises and must not prevent the stale-jobs
        # sub-check from running and reporting its own result.
        self._create_stale_job()
        self._stub_manager(mock_manager_cls)

        with patch(
            "ami.jobs.tasks._run_running_job_snapshot_check",
            side_effect=RuntimeError("pretend the observation check blew up"),
        ):
            result = jobs_health_check()

        # Stale-jobs sub-check completes normally and reports the reconciliation.
        self.assertEqual(result["stale_jobs"], {"checked": 1, "fixed": 1, "unfixable": 0})
        # Snapshot sub-check returns the `unfixable=1` sentinel on failure so
        # operators reading the task result see the check failed, not "nothing
        # to do."
        self.assertEqual(result["running_job_snapshots"], {"checked": 0, "fixed": 0, "unfixable": 1})

    def test_stale_jobs_fixed_counts_celery_updated_and_revoked_paths(self, mock_manager_cls, _mock_cleanup):
        # Two stale jobs in different reconciliation states: one has a Celery
        # task_id that returns a terminal state (counts as "updated from Celery"),
        # the other has no task_id and is forced to REVOKED. Both contribute to
        # `fixed` — this test guards against a refactor dropping one branch.
        from celery import states

        job_with_task = self._create_stale_job()
        job_with_task.task_id = "terminal-task"
        job_with_task.save(update_fields=["task_id"])
        self._create_stale_job()  # no task_id → revoked path
        self._stub_manager(mock_manager_cls)

        class _FakeAsyncResult:
            # Minimal stand-in: only the ``state`` attribute is read by the
            # reconciler's Celery lookup.
            def __init__(self, task_id):
                self.state = states.SUCCESS if task_id == "terminal-task" else states.PENDING

        # `check_stale_jobs` imports AsyncResult locally from celery.result,
        # so patch at source rather than at the call site.
        with patch("celery.result.AsyncResult", _FakeAsyncResult):
            result = jobs_health_check()

        # checked == 2 (both stale), fixed == 2 (one per branch), unfixable == 0
        self.assertEqual(result["stale_jobs"], {"checked": 2, "fixed": 2, "unfixable": 0})
12 changes: 12 additions & 0 deletions ami/main/checks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Integrity and health check primitives shared across apps.

Sub-modules in this package (added by per-domain check PRs such as
``ami.main.checks.occurrences`` in #1188) define ``get_*`` and
``reconcile_*`` function pairs. The shared result schema lives in
:mod:`ami.main.checks.schemas` so reconciliation and observation checks
across apps return the same shape.
"""

# Re-exported so callers can import the shared result type from the package
# root (``from ami.main.checks import IntegrityCheckResult``).
from ami.main.checks.schemas import IntegrityCheckResult

__all__ = ["IntegrityCheckResult"]
Loading
Loading