Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from ami.jobs.tasks import process_nats_pipeline_result
from ami.main.api.schemas import project_id_doc_param
from ami.main.api.views import DefaultViewSet
from ami.ml.auth import HasProcessingServiceAPIKey
from ami.ml.models.processing_service import ProcessingService
from ami.utils.fields import url_boolean_param

from .models import Job, JobDispatchMode, JobState
Expand Down Expand Up @@ -146,6 +148,13 @@ class JobViewSet(DefaultViewSet, ProjectMixin):

permission_classes = [ObjectPermission]

def _update_processing_service_heartbeat(self, request):
Comment thread
mihow marked this conversation as resolved.
"""Update heartbeat for the specific PS identified by API key auth."""
from ami.ml.schemas import get_client_info

if isinstance(request.auth, ProcessingService):
request.auth.mark_seen(client_info=get_client_info(request))

def get_serializer_class(self):
"""
Return different serializers for list and detail views.
Expand Down Expand Up @@ -247,7 +256,12 @@ def list(self, request, *args, **kwargs):
responses={200: MLJobTasksResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["post"], name="tasks")
@action(
detail=True,
methods=["post"],
name="tasks",
permission_classes=[ObjectPermission | HasProcessingServiceAPIKey],
)
def tasks(self, request, pk=None):
"""
Fetch tasks from the job queue (POST).
Expand Down Expand Up @@ -275,8 +289,13 @@ def tasks(self, request, pk=None):
if not job.pipeline:
raise ValidationError("This job does not have a pipeline configured")

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)
# Record heartbeat. When the request is API-key-authenticated we know the
# exact PS, so use the precise per-PS heartbeat. Fall back to the bulk
# pipeline-level heartbeat for token-authenticated requests (transition period).
if isinstance(request.auth, ProcessingService):
self._update_processing_service_heartbeat(request)
else:
_mark_pipeline_pull_services_seen(job)

# Get tasks from NATS JetStream
from ami.ml.orchestration.nats_queue import TaskQueueManager
Expand All @@ -298,7 +317,12 @@ async def get_tasks():
responses={200: MLJobResultsResponseSerializer},
parameters=[project_id_doc_param],
)
@action(detail=True, methods=["post"], name="result")
@action(
detail=True,
methods=["post"],
name="result",
permission_classes=[ObjectPermission | HasProcessingServiceAPIKey],
)
def result(self, request, pk=None):
"""
Submit pipeline results.
Expand All @@ -310,8 +334,11 @@ def result(self, request, pk=None):

job = self.get_object()

# Record heartbeat for async processing services on this pipeline
_mark_pipeline_pull_services_seen(job)
# Record heartbeat (see comment in tasks() for rationale)
if isinstance(request.auth, ProcessingService):
self._update_processing_service_heartbeat(request)
else:
_mark_pipeline_pull_services_seen(job)

serializer = MLJobResultsRequestSerializer(data=request.data)
serializer.is_valid(raise_exception=True)
Expand Down
87 changes: 43 additions & 44 deletions ami/main/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3648,9 +3648,16 @@ def test_nonexistent_taxa_list_returns_404(self):


class TestProjectPipelinesAPI(APITestCase):
"""Test the project pipelines API endpoint."""
"""Test the project pipelines API endpoint.

Pipeline registration requires API key authentication (since PR #1194).
The processing service is identified by its API key, not by name.
"""

def setUp(self):
from unittest.mock import patch

from ami.ml.models.processing_service import ProcessingServiceAPIKey
from ami.users.roles import ProjectManager, create_roles_for_project

self.user = User.objects.create_user(email="test@example.com") # type: ignore
Expand All @@ -3665,76 +3672,68 @@ def setUp(self):
create_roles_for_project(self.other_project)
ProjectManager.assign_user(self.user, self.project)

# Create a processing service with API key for registration tests
with patch.object(ProcessingService, "get_status"):
self.service = ProcessingService.objects.create(name="TestService", endpoint_url=None)
self.service.projects.add(self.project)
_, self.api_key = ProcessingServiceAPIKey.objects.create_key(name="test-key", processing_service=self.service)

def _get_pipelines_url(self, project_id):
"""Get the pipelines API URL for a project."""
return f"/api/v2/projects/{project_id}/pipelines/"

def _get_test_payload(self, service_name: str):
"""Get a minimal test payload for pipeline registration."""
return {
"processing_service_name": service_name,
"pipelines": [],
}

def test_create_new_service_success(self):
"""Test creating a new processing service if it doesn't exist."""
def test_registration_with_api_key_succeeds(self):
"""Test that API-key-authenticated registration succeeds."""
url = self._get_pipelines_url(self.project.pk)
payload = self._get_test_payload("NewService")
payload = {"pipelines": []}

self.client.force_authenticate(user=self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}")
response = self.client.post(url, payload, format="json")

self.assertEqual(response.status_code, status.HTTP_201_CREATED)

# Verify service was created and associated
service = ProcessingService.objects.get(name="NewService")
self.assertIn(self.project, service.projects.all())

def test_reregistration_is_idempotent(self):
"""Test that re-registering a service already associated with the project succeeds."""
# Create and associate service
service = ProcessingService.objects.create(name="ExistingService")
service.projects.add(self.project)

"""Test that re-registering the same service succeeds."""
url = self._get_pipelines_url(self.project.pk)
payload = self._get_test_payload("ExistingService")
payload = {"pipelines": []}

self.client.force_authenticate(user=self.user)
response = self.client.post(url, payload, format="json")
self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}")

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
response1 = self.client.post(url, payload, format="json")
self.assertEqual(response1.status_code, status.HTTP_201_CREATED)

def test_associate_existing_service_success(self):
"""Test associating existing service with project when not yet associated."""
# Create service but don't associate with project
service = ProcessingService.objects.create(name="UnassociatedService")
response2 = self.client.post(url, payload, format="json")
self.assertEqual(response2.status_code, status.HTTP_201_CREATED)

def test_registration_updates_heartbeat(self):
"""Test that registration marks the service as seen."""
url = self._get_pipelines_url(self.project.pk)
payload = self._get_test_payload("UnassociatedService")
payload = {"pipelines": []}

self.client.force_authenticate(user=self.user)
response = self.client.post(url, payload, format="json")
self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}")
self.client.post(url, payload, format="json")

self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertIn(self.project, service.projects.all())
self.service.refresh_from_db()
self.assertIsNotNone(self.service.last_seen)
self.assertTrue(self.service.last_seen_live)

def test_unauthorized_project_access_returns_403(self):
"""Test 403 when user doesn't have write access to project."""
def test_wrong_project_denied(self):
"""Test that API key for a PS not linked to the target project is denied."""
url = self._get_pipelines_url(self.other_project.pk)
payload = self._get_test_payload("UnauthorizedService")
payload = {"pipelines": []}

self.client.force_authenticate(user=self.user)
self.client.credentials(HTTP_AUTHORIZATION=f"Api-Key {self.api_key}")
response = self.client.post(url, payload, format="json")

self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertIn(response.status_code, [status.HTTP_403_FORBIDDEN, status.HTTP_404_NOT_FOUND])

def test_invalid_payload_returns_400(self):
"""Test 400 when payload is invalid."""
def test_user_token_auth_rejected_for_registration(self):
"""Test that user-token auth is rejected for pipeline registration."""
url = self._get_pipelines_url(self.project.pk)
invalid_payload = {"invalid": "data"}
payload = {"pipelines": []}

self.client.force_authenticate(user=self.user)
response = self.client.post(url, invalid_payload, format="json")
response = self.client.post(url, payload, format="json")

self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

Expand Down Expand Up @@ -3771,7 +3770,7 @@ def test_list_pipelines_draft_project_non_member(self):
def test_unauthenticated_write_returns_401(self):
"""Unauthenticated users cannot register pipelines."""
url = self._get_pipelines_url(self.project.pk)
payload = self._get_test_payload("AnonService")
payload = {"pipelines": []}
response = self.client.post(url, payload, format="json")
self.assertIn(response.status_code, [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN])

Expand Down
27 changes: 26 additions & 1 deletion ami/ml/admin.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from django.contrib import admin
from rest_framework_api_key.admin import APIKeyModelAdmin

from ami.main.admin import AdminBase, ProjectPipelineConfigInline

from .models.algorithm import Algorithm, AlgorithmCategoryMap
from .models.pipeline import Pipeline
from .models.processing_service import ProcessingService
from .models.processing_service import ProcessingService, ProcessingServiceAPIKey


@admin.register(Algorithm)
Expand Down Expand Up @@ -70,8 +71,32 @@ class ProcessingServiceAdmin(AdminBase):
"id",
"name",
"endpoint_url",
"last_seen_live",
"created_at",
]
readonly_fields = ["last_seen_client_info"]

@admin.action(description="Generate API key for selected processing services (revokes existing)")
def generate_api_key(self, request, queryset):
for ps in queryset:
ps.api_keys.filter(revoked=False).update(revoked=True)
_, plaintext_key = ProcessingServiceAPIKey.objects.create_key(
name=f"{ps.name} key",
processing_service=ps,
)
self.message_user(
request,
f"{ps.name}: {plaintext_key} (copy now — it won't be shown again)",
)

actions = [generate_api_key]


@admin.register(ProcessingServiceAPIKey)
class ProcessingServiceAPIKeyAdmin(APIKeyModelAdmin):
Comment thread
mihow marked this conversation as resolved.
list_display = [*APIKeyModelAdmin.list_display, "processing_service"]
list_filter = ["processing_service"]
search_fields = [*APIKeyModelAdmin.search_fields, "processing_service__name"]


@admin.register(AlgorithmCategoryMap)
Expand Down
106 changes: 106 additions & 0 deletions ami/ml/auth.py
Comment thread
mihow marked this conversation as resolved.
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
"""
API key authentication for processing services.

Uses djangorestframework-api-key to provide key-based auth. Each ProcessingService
can have one or more API keys. When a request arrives with `Authorization: Api-Key <key>`,
the authentication class identifies the ProcessingService and sets request.auth to it.

Contains:
- ProcessingServiceAPIKeyAuthentication: DRF auth backend
- HasProcessingServiceAPIKey: DRF permission class

The ProcessingServiceAPIKey model lives in ami.ml.models.processing_service.
"""

import logging

from rest_framework import authentication, exceptions, permissions
from rest_framework_api_key.permissions import KeyParser

from ami.ml.models.processing_service import ProcessingServiceAPIKey

logger = logging.getLogger(__name__)


class ProcessingServiceAPIKeyAuthentication(authentication.BaseAuthentication):
"""
DRF authentication class that identifies a ProcessingService from an API key.

Sets:
request.user = AnonymousUser (required by django-guardian/ObjectPermission)
request.auth = ProcessingService instance

This allows views to check `request.auth` to get the calling service,
and permission classes to verify project access.
"""

key_parser = KeyParser()

def authenticate(self, request):
key = self.key_parser.get(request)
if not key:
return None # No Api-Key header; fall through to next auth class

try:
api_key = ProcessingServiceAPIKey.objects.get_from_key(key)
except ProcessingServiceAPIKey.DoesNotExist:
raise exceptions.AuthenticationFailed("Invalid API key.")

if not api_key.is_valid:
raise exceptions.AuthenticationFailed("API key has been revoked or expired.")

from django.contrib.auth.models import AnonymousUser

return (AnonymousUser(), api_key.processing_service)

def authenticate_header(self, request):
return "Api-Key"
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this class and any methods already part of the DRF api_key package?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude says: BaseHasAPIKey from the library handles key validation (is_valid() check), but our HasProcessingServiceAPIKey serves a different purpose — it checks authorization (project membership) after the auth backend has already validated the key. The auth backend (ProcessingServiceAPIKeyAuthentication) handles authentication and puts the PS on request.auth. The permission class then checks that the PS belongs to the right project. So they complement each other rather than overlap — we can't inherit from BaseHasAPIKey because we need the isinstance(request.auth, ProcessingService) check rather than re-validating the key.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Claude says: The library provides BaseHasAPIKey (permission class) and KeyParser (header extraction), but not a DRF authentication backend. Our ProcessingServiceAPIKeyAuthentication serves a different role — it's an auth backend that sets request.auth = ProcessingService, which DRF's auth pipeline requires. The library's BaseHasAPIKey is a permission class that checks key validity but doesn't identify the caller.

We use KeyParser from the library (for header parsing) and its get_from_key() manager method (for hashed lookup). The auth backend and permission class are ours because they need to: (1) place the PS on request.auth, and (2) check project membership — both are app-specific concerns.



class HasProcessingServiceAPIKey(permissions.BasePermission):
"""
Allow access for requests authenticated with a ProcessingService API key.

The auth backend places the ProcessingService on request.auth.
This permission verifies project membership.

Compose with ObjectPermission for endpoints used by both users and services:
permission_classes = [ObjectPermission | HasProcessingServiceAPIKey]
"""

def has_permission(self, request, view):
from ami.ml.models.processing_service import ProcessingService

if not isinstance(request.auth, ProcessingService):
return False

# For detail views (e.g. /jobs/{pk}/tasks/), defer project scoping
# to has_object_permission where we can derive it from the object.
# CONTRACT: all detail-level actions using this permission MUST call
# self.get_object() so that DRF invokes has_object_permission().
# Actions that fetch objects manually without get_object() will bypass
# project-scoping checks.
if view.kwargs.get("pk"):
return True

get_active_project = getattr(view, "get_active_project", None)
if not callable(get_active_project):
return False

project = get_active_project()
if not project:
return False

return request.auth.projects.filter(pk=project.pk).exists()

def has_object_permission(self, request, view, obj):
from ami.ml.models.processing_service import ProcessingService

if not isinstance(request.auth, ProcessingService):
return False

ps = request.auth
project = obj.get_project() if hasattr(obj, "get_project") else None
if not project:
return False
return ps.projects.filter(pk=project.pk).exists()
Loading