Skip to content

Add task_expires_seconds for async endpoint task expiration#754

Open
lukasewecker wants to merge 8 commits intomainfrom
lukas/task-expires
Open

Add task_expires_seconds for async endpoint task expiration#754
lukasewecker wants to merge 8 commits intomainfrom
lukas/task-expires

Conversation

@lukasewecker
Copy link
Collaborator

@lukasewecker lukasewecker commented Feb 11, 2026

Pull Request Summary

Add configurable task expiration for async endpoints, allowing users to specify how long a task can wait in the queue before being discarded by Celery workers. Defaults to 86400 seconds (24 hours) if not specified.

Keep in mind that this change does not automatically discard a task from the queue once it is stale. It only marks a task as stale, which then waits in the queue to be picked up by the worker. Once picked up by a worker, the worker then notices that the task is stale and will not process it. This means that when a stale task is never picked up by the worker (e.g. no worker is available), it stays in the queue.

Changes:

  • Add task_expires_seconds field to DTOs, entities, and ORM model
  • Pass expires parameter to Celery task queue gateway
  • Add K8s annotations for celery autoscaler configuration
  • Include database migration for new column
  • Add comprehensive unit tests for the feature

Note

  • Before deploying, the db migration must be run (alembic update head).
  • There was a paremeter task_timeout_seconds before, which was not used in the downstream tasks though. It was all renamed to task_expires_seconds, and is now passed properly to the celery tasks.

Example API usage

When creating a new model endpoint via the api, you can now pass the parameter:

curl -X POST "https://api.example.com/v1/model-endpoints" \                                                                                                                                                                                                                                             
  -H "Content-Type: application/json" \                                                                                                                                                                                                                                                                 
  -d '{                                                                                                                                                                                                                                                                                                 
    "name": "my-endpoint",                                                                                                                                                                                                                                                                              
    "model_bundle_id": "bundle-123",                                                                                                                                                                                                                                                                    
    "endpoint_type": "async",                                                                                                                                                                                                                                                                           
    "min_workers": 1,                                                                                                                                                                                                                                                                                   
    "max_workers": 3,                                                                                                                                                                                                                                                                                   
    "per_worker": 1,                                                                                                                                                                                                                                                                                    
    "cpus": 1,                                                                                                                                                                                                                                                                                          
    "memory": "4Gi",                                                                                                                                                                                                                                                                                    
    "storage": "10Gi",                                                                                                                                                                                                                                                                                  
    "gpus": 0,                                                                                                                                                                                                                                                                                          
    "labels": {},                                                                                                                                                                                                                                                                                       
    "metadata": {},                                                                                                                                                                                                                                                                                     
    "task_expires_seconds": 3600                                                                                                                                                                                                                                                                        
  }'                                                                                                                                                                                                                                                                                                    

Test Plan and Usage Guide

Ran unit tests and added new tests for the new feature.

Also, tested locally with this script (spun up Redis locally):

 """Test task_expires_seconds is passed to Celery"""                                                                                                                                                                                                                                                    
 import os                                                                                                                                                                                                                                                                                              
 os.environ["USE_REDIS_LOCALHOST"] = "1"                                                                                                                                                                                                                                                                
 os.environ["GIT_TAG"] = "test"                                                                                                                                                                                                                                                                         
                                                                                                                                                                                                                                                                                                        
 from model_engine_server.common.dtos.model_endpoints import BrokerType                                                                                                                                                                                                                                 
 from model_engine_server.core.tracing.live_tracing_gateway import LiveTracingGateway                                                                                                                                                                                                                   
 from model_engine_server.infra.gateways.celery_task_queue_gateway import CeleryTaskQueueGateway                                                                                                                                                                                                        
                                                                                                                                                                                                                                                                                                        
 # Create gateway                                                                                                                                                                                                                                                                                       
 gateway = CeleryTaskQueueGateway(                                                                                                                                                                                                                                                                      
     broker_type=BrokerType.REDIS_24H,                                                                                                                                                                                                                                                                  
     tracing_gateway=LiveTracingGateway()                                                                                                                                                                                                                                                               
 )                                                                                                                                                                                                                                                                                                      
                                                                                                                                                                                                                                                                                                        
 print("Testing task submission with expires parameter...\n")                                                                                                                                                                                                                                           
                                                                                                                                                                                                                                                                                                        
 from unittest.mock import patch, MagicMock                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                                        
 with patch.object(gateway, '_get_celery_dest') as mock_get_dest:                                                                                                                                                                                                                                       
     mock_celery = MagicMock()                                                                                                                                                                                                                                                                          
     mock_celery.send_task.return_value = MagicMock(id="test-task-id")                                                                                                                                                                                                                                  
     mock_get_dest.return_value = mock_celery                                                                                                                                                                                                                                                           
                                                                                                                                                                                                                                                                                                        
     # Test 1: Custom expiration (1 hour)                                                                                                                                                                                                                                                               
     gateway.send_task(                                                                                                                                                                                                                                                                                 
         task_name="test.task",                                                                                                                                                                                                                                                                         
         queue_name="test-queue",                                                                                                                                                                                                                                                                       
         args=[{"test": "data"}],                                                                                                                                                                                                                                                                       
         expires=3600,                                                                                                                                                                                                                                                                                  
     )                                                                                                                                                                                                                                                                                                  
                                                                                                                                                                                                                                                                                                        
     call_kwargs = mock_celery.send_task.call_args.kwargs                                                                                                                                                                                                                                               
     print(f"Test 1 - Custom expiration:")                                                                                                                                                                                                                                                              
     print(f"  expires={call_kwargs.get('expires')} seconds")                                                                                                                                                                                                                                           
     assert call_kwargs['expires'] == 3600, "expires should be 3600"                                                                                                                                                                                                                                    
     print("  ✅ PASSED\n")                                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                                        
     # Test 2: Default expiration (86400 = 24 hours)                                                                                                                                                                                                                                                    
     mock_celery.reset_mock()                                                                                                                                                                                                                                                                           
     gateway.send_task(                                                                                                                                                                                                                                                                                 
         task_name="test.task",                                                                                                                                                                                                                                                                         
         queue_name="test-queue",                                                                                                                                                                                                                                                                       
         args=[{"test": "data"}],                                                                                                                                                                                                                                                                       
         expires=86400,                                                                                                                                                                                                                                                                                 
     )                                                                                                                                                                                                                                                                                                  
                                                                                                                                                                                                                                                                                                        
     call_kwargs = mock_celery.send_task.call_args.kwargs                                                                                                                                                                                                                                               
     print(f"Test 2 - Default expiration (24 hours):")                                                                                                                                                                                                                                                  
     print(f"  expires={call_kwargs.get('expires')} seconds")                                                                                                                                                                                                                                           
     assert call_kwargs['expires'] == 86400, "expires should be 86400"                                                                                                                                                                                                                                  
     print("  ✅ PASSED\n")                                                                                                                                                                                                                                                                             
                                                                                                                                                                                                                                                                                                        

Greptile Summary

This PR adds configurable task expiration for async endpoints via a new task_expires_seconds field, allowing users to control how long a Celery task can wait in the queue before being discarded by workers (defaults to 86400 seconds / 24 hours). It also introduces queue_message_timeout_seconds to configure SQS VisibilityTimeout and ASB lock_duration for queue message processing. The changes span the full stack: DTOs, domain entities, ORM model, database migration, use cases, gateways, K8s templates, and tests.

  • Adds task_expires_seconds (persisted to DB) and queue_message_timeout_seconds (passed to queue infrastructure) to create/update endpoint flows for both base and LLM endpoints
  • Renames task_timeout_seconds to task_expires_seconds throughout the codebase for clarity and now actually passes it to Celery's expires parameter
  • Adds K8s annotation celery.scaleml.autoscaler/taskExpiresSeconds to async deployment templates for autoscaler awareness
  • Includes Alembic migration for the new task_expires_seconds column and comprehensive unit tests
  • Note: The Helm chart template (charts/model-engine/templates/_helpers.tpl) was not updated with the new taskExpiresSeconds annotation, creating an inconsistency with the CircleCI template

Confidence Score: 4/5

  • This PR is generally safe to merge but has one deployment consistency gap in the Helm chart template that should be addressed.
  • The implementation is thorough and well-structured across the full stack with good test coverage. The only notable issue is the missing taskExpiresSeconds annotation in the Helm chart template, which could cause inconsistent behavior for Helm-based deployments. All other changes follow existing patterns correctly.
  • The Helm chart template at charts/model-engine/templates/_helpers.tpl needs the taskExpiresSeconds annotation added to modelEngine.serviceTemplateAsyncAnnotations for parity with the CircleCI template.

Important Files Changed

Filename Overview
model-engine/model_engine_server/common/dtos/model_endpoints.py Adds queue_message_timeout_seconds and task_expires_seconds fields to Create/Update request DTOs and Get response DTO with appropriate validation (gt=0 for task_expires, ge=1/le=43200 for queue_message_timeout).
model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py Renames DEFAULT_TASK_TIMEOUT_SECONDS to DEFAULT_TASK_EXPIRES_SECONDS, reads task_expires_seconds from the endpoint record with fallback to default. Correctly passes the value to the inference gateway.
model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py Passes the expires parameter through to celery_dest.send_task(). Simple and correct change.
model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py Adds TASK_EXPIRES_SECONDS to _AsyncDeploymentArguments TypedDict and populates it in all 4 async deployment argument builders using model_endpoint_record.task_expires_seconds or DEFAULT_TASK_EXPIRES_SECONDS.
model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml Adds celery.scaleml.autoscaler/taskExpiresSeconds annotation to all 4 async deployment templates. However, the equivalent Helm chart template (charts/model-engine/templates/_helpers.tpl) was not updated.
model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py Adds queue_message_timeout_seconds support for SQS VisibilityTimeout. Updates existing queues when parameter is provided, and uses it (with 43200 fallback) for new queues.
model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py Adds queue_message_timeout_seconds support for ASB lock_duration with proper guard (only updates when not None) and 300-second cap. Handles errors gracefully with logging.
model-engine/model_engine_server/db/migrations/alembic/versions/2026_02_10_1920-62da4f8b3403_add_task_expires_seconds_column.py Clean migration adding nullable task_expires_seconds integer column to the endpoints table. Proper upgrade/downgrade implemented.
model-engine/model_engine_server/infra/services/live_model_endpoint_service.py Correctly threads task_expires_seconds through create/update flows to record repository and queue_message_timeout_seconds to infra gateway.
model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py Properly adds task_expires_seconds to both create and update paths, using dict_not_none for the update to correctly handle None values.

Sequence Diagram

sequenceDiagram
    participant Client
    participant API
    participant UseCase as AsyncInferenceUseCase
    participant Service as ModelEndpointService
    participant Record as EndpointRecord (DB)
    participant InfraGW as InfraGateway
    participant Queue as Queue (SQS/ASB/Redis)
    participant CeleryGW as CeleryTaskQueueGateway
    participant Worker as Celery Worker

    Note over Client,Worker: Endpoint Creation Flow
    Client->>API: POST /model-endpoints (task_expires_seconds=3600, queue_message_timeout_seconds=600)
    API->>Service: create_model_endpoint(task_expires_seconds, queue_message_timeout_seconds)
    Service->>Record: create_record(task_expires_seconds=3600)
    Service->>InfraGW: create_infra(queue_message_timeout_seconds=600)
    InfraGW->>Queue: create_queue(VisibilityTimeout=600)
    InfraGW-->>Service: K8s deployment with taskExpiresSeconds annotation

    Note over Client,Worker: Async Task Submission Flow
    Client->>API: POST /async-tasks
    API->>UseCase: execute(endpoint_id, request)
    UseCase->>Record: read task_expires_seconds (3600)
    UseCase->>CeleryGW: send_task(expires=3600)
    CeleryGW->>Queue: send_task(expires=3600)
    Queue->>Worker: pick up task
    Worker->>Worker: check if task expired (>3600s old → discard)
Loading

Last reviewed commit: a60fb88

Add configurable task expiration for async endpoints, allowing users to
specify how long a task can wait in the queue before being discarded by
Celery workers. Defaults to 86400 seconds (24 hours) if not specified.

Changes:
- Add task_expires_seconds field to DTOs, entities, and ORM model
- Pass expires parameter to Celery task queue gateway
- Add K8s annotations for celery autoscaler configuration
- Include database migration for new column
- Add comprehensive unit tests for the feature

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
self,
topic: str,
predict_request: EndpointPredictV1Request,
task_timeout_seconds: int,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is being used in async_inference_use_cases.py

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just renamed the parameter to task_expires_seconds here. Can you explain more what you mean?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue was that it wasnt properly passed down to when celery creates the task:

res = celery_dest.send_task(
name=task_name,
args=args,
kwargs=kwargs,
queue=queue_name,
)

Also it wasnt possible to configure this parameter. So basically what I did was

  • rename it (you could argue if thats necessary)
  • pass it down properly to downstream code
  • make it configurable from outside

@olliestanley
Copy link
Member

olliestanley commented Feb 18, 2026

incorporated changes from this other PR #760 as requested by @dmchoiboi

reasoning for these timeout changes:

  • On Azure, the default time for ASBs (Azure Service Bus queues) is 1 minute. Therefore if inference time is longer than 1 minute for any of our models, that will make inference time out repeatedly and eventually the message gets discarded.
  • This PR adds a parameter that can be used to set the timeout.

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

34 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

34 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

34 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

task_expires_seconds: Optional[int] = Field(
default=None,
description="For async endpoints, how long a task can wait in queue before expiring (in seconds).",
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add queue_message_timeout_seconds to this

@@ -35,6 +35,7 @@ class BuildEndpointRequest(BaseModel):
high_priority: Optional[bool] = None
default_callback_url: Optional[str] = None
default_callback_auth: Optional[CallbackAuth] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to pipe task_expires_seconds in here as well?

resource_state: Optional[ModelEndpointResourceState] = Field(default=None)
num_queued_items: Optional[int] = Field(default=None)
public_inference: Optional[bool] = Field(default=None)
task_expires_seconds: Optional[int] = Field(default=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add queue_message_timeout_seconds

resource_state=(None if infra_state is None else infra_state.resource_state),
num_queued_items=(None if infra_state is None else infra_state.num_queued_items),
public_inference=model_endpoint.record.public_inference,
task_expires_seconds=model_endpoint.record.task_expires_seconds,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add queue_message_timeout_seconds

pass

if queue_message_timeout_seconds is not None:
lock_duration = timedelta(seconds=min(queue_message_timeout_seconds, 300))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract 300 to a const, and can we bump to a higher default? some requests can take > 5min

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants

Comments