Skip to content
This repository was archived by the owner on Mar 28, 2025. It is now read-only.

fix(17): Add timeout to auto unlock recurring tasks #16

Merged
merged 3 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
"priority",
"transaction_on_commit",
"task_identifier",
"timeout",
)

unwrapped: typing.Callable[P, None]
Expand All @@ -38,11 +39,13 @@ def __init__(
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = None,
) -> None:
self.unwrapped = f
self.queue_size = queue_size
self.priority = priority
self.transaction_on_commit = transaction_on_commit
self.timeout = timeout

task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
Expand Down Expand Up @@ -87,6 +90,7 @@ def delay(
scheduled_for=delay_until or timezone.now(),
priority=self.priority,
queue_size=self.queue_size,
timeout=self.timeout,
args=args,
kwargs=kwargs,
)
Expand Down Expand Up @@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901
queue_size: int | None = None,
priority: TaskPriority = TaskPriority.NORMAL,
transaction_on_commit: bool = True,
timeout: timedelta | None = timedelta(seconds=60),
) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
"""
Turn a function into an asynchronous task.
Expand All @@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
queue_size=queue_size,
priority=priority,
transaction_on_commit=transaction_on_commit,
timeout=timeout,
)

return wrapper
Expand All @@ -161,6 +167,7 @@ def register_recurring_task(
args: tuple[typing.Any] = (),
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
timeout: timedelta | None = timedelta(minutes=30),
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
Expand All @@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
},
)
return task
Expand Down
4 changes: 2 additions & 2 deletions task_processor/managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def get_tasks_to_process(self, num_tasks: int) -> QuerySet["Task"]:


class RecurringTaskManager(Manager):
    def get_tasks_to_process(self) -> QuerySet["RecurringTask"]:
        """Fetch the next recurring task to run, locking it for this worker.

        Delegates to the ``get_recurringtasks_to_process()`` SQL function,
        which atomically selects at most one task that is either unlocked or
        whose lock has timed out, and marks it locked before returning it.

        Returns:
            A raw queryset yielding zero or one ``RecurringTask`` rows.
        """
        return self.raw("SELECT * FROM get_recurringtasks_to_process()")
39 changes: 39 additions & 0 deletions task_processor/migrations/0012_add_locked_at_and_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Generated by Django 3.2.23 on 2025-01-06 04:51

import datetime
import os

from django.db import migrations, models

from task_processor.migrations.helpers import PostgresOnlyRunSQL


class Migration(migrations.Migration):
    """Add lock-timeout support to tasks.

    * ``RecurringTask.locked_at`` + ``RecurringTask.timeout`` allow stale
      locks to be detected and reclaimed by the updated SQL function.
    * ``Task.timeout`` bounds how long a one-off task may run
      (NULL means no limit).
    """

    dependencies = [
        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
    ]

    operations = [
        migrations.AddField(
            model_name="recurringtask",
            name="locked_at",
            field=models.DateTimeField(blank=True, null=True),
        ),
        # Recurring tasks always have a timeout; default to 30 minutes.
        migrations.AddField(
            model_name="recurringtask",
            name="timeout",
            field=models.DurationField(default=datetime.timedelta(minutes=30)),
        ),
        # One-off tasks may have no timeout (nullable).
        migrations.AddField(
            model_name="task",
            name="timeout",
            field=models.DurationField(blank=True, null=True),
        ),
        # Replace get_recurringtasks_to_process() with a version that also
        # reclaims tasks whose lock has outlived its timeout.
        PostgresOnlyRunSQL.from_sql_file(
            os.path.join(
                os.path.dirname(__file__),
                "sql",
                "0012_get_recurringtasks_to_process.sql",
            ),
            reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process()",
        ),
    ]
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
CREATE OR REPLACE FUNCTION get_recurringtasks_to_process()
RETURNS SETOF task_processor_recurringtask AS $$
DECLARE
    row_to_return task_processor_recurringtask;
BEGIN
    -- Select the next task that needs to be processed: either it is unlocked,
    -- or its lock has expired — held for longer than its timeout plus a
    -- one-minute grace period for scheduling/processing overhead.
    -- (NOW() - timeout - INTERVAL '1 minute' means locked_at must be more
    -- than `timeout + 1 minute` in the past before the lock is reclaimed.)
    FOR row_to_return IN
        SELECT *
        FROM task_processor_recurringtask
        WHERE is_locked = FALSE
           OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout - INTERVAL '1 minute')
        ORDER BY id
        LIMIT 1
        -- FOR UPDATE ... SKIP LOCKED ensures no other worker can select this
        -- row while this transaction block is in progress.
        FOR UPDATE SKIP LOCKED
    LOOP
        -- Lock the selected task (is_locked = TRUE, locked_at = NOW()) so no
        -- other worker picks it up after this transaction completes, while
        -- the task is still being executed by the current worker.
        UPDATE task_processor_recurringtask
        SET is_locked = TRUE, locked_at = NOW()
        WHERE id = row_to_return.id;
        -- Reflect the update on the returned row; otherwise the client would
        -- receive a row that is locked but still shows `is_locked` as FALSE
        -- and `locked_at` as NULL.
        row_to_return.is_locked := TRUE;
        row_to_return.locked_at := NOW();
        RETURN NEXT row_to_return;
    END LOOP;

    RETURN;
END;
$$ LANGUAGE plpgsql;

13 changes: 12 additions & 1 deletion task_processor/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing
import uuid
from datetime import datetime
from datetime import datetime, timedelta

import simplejson as json
from django.core.serializers.json import DjangoJSONEncoder
Expand Down Expand Up @@ -80,6 +80,8 @@ def callable(self) -> typing.Callable:
class Task(AbstractBaseTask):
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)

timeout = models.DurationField(blank=True, null=True)

# denormalise failures and completion so that we can use select_for_update
num_failures = models.IntegerField(default=0)
completed = models.BooleanField(default=False)
Expand Down Expand Up @@ -109,6 +111,7 @@ def create(
*,
args: typing.Tuple[typing.Any] = None,
kwargs: typing.Dict[str, typing.Any] = None,
timeout: timedelta | None = timedelta(seconds=60),
) -> "Task":
if queue_size and cls._is_queue_full(task_identifier, queue_size):
raise TaskQueueFullError(
Expand All @@ -121,6 +124,7 @@ def create(
priority=priority,
serialized_args=cls.serialize_data(args or tuple()),
serialized_kwargs=cls.serialize_data(kwargs or dict()),
timeout=timeout,
)

@classmethod
Expand All @@ -147,6 +151,9 @@ class RecurringTask(AbstractBaseTask):
run_every = models.DurationField()
first_run_time = models.TimeField(blank=True, null=True)

locked_at = models.DateTimeField(blank=True, null=True)
timeout = models.DurationField(default=timedelta(minutes=30))

objects = RecurringTaskManager()

class Meta:
Expand All @@ -157,6 +164,10 @@ class Meta:
),
]

def unlock(self) -> None:
    """Clear the lock fields so another worker can pick this task up.

    Only mutates the in-memory instance; the caller is responsible for
    persisting the change (e.g. via ``bulk_update`` with
    ``fields=["is_locked", "locked_at"]``).
    """
    self.is_locked = False
    self.locked_at = None

@property
def should_execute(self) -> bool:
now = timezone.now()
Expand Down
30 changes: 19 additions & 11 deletions task_processor/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import traceback
import typing
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from django.utils import timezone
Expand Down Expand Up @@ -36,7 +37,8 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:

if executed_tasks:
Task.objects.bulk_update(
executed_tasks, fields=["completed", "num_failures", "is_locked"]
executed_tasks,
fields=["completed", "num_failures", "is_locked"],
)

if task_runs:
Expand All @@ -48,14 +50,11 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:
return []


def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:
if num_tasks < 1:
raise ValueError("Number of tasks to process must be at least one")

def run_recurring_tasks() -> typing.List[RecurringTaskRun]:
# NOTE: We will probably see a lot of delay in the execution of recurring tasks
# if the tasks take longer then `run_every` to execute. This is not
# a problem for now, but we should be mindful of this limitation
tasks = RecurringTask.objects.get_tasks_to_process(num_tasks)
tasks = RecurringTask.objects.get_tasks_to_process()
if tasks:
task_runs = []

Expand All @@ -78,7 +77,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:

# update all tasks that were not deleted
to_update = [task for task in tasks if task.id]
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])

if task_runs:
RecurringTaskRun.objects.bulk_create(task_runs)
Expand All @@ -93,16 +92,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
task_run = task.task_runs.model(started_at=timezone.now(), task=task)

try:
task.run()
task_run.result = TaskResult.SUCCESS
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(task.run)
timeout = task.timeout.total_seconds() if task.timeout else None
future.result(timeout=timeout) # Wait for completion or timeout

task_run.result = TaskResult.SUCCESS
task_run.finished_at = timezone.now()
task.mark_success()

except Exception as e:
# For errors that don't include a default message (e.g., TimeoutError),
# fall back to using repr.
err_msg = str(e) or repr(e)

logger.error(
"Failed to execute task '%s'. Exception was: %s",
"Failed to execute task '%s', with id %d. Exception: %s",
task.task_identifier,
str(e),
task.id,
err_msg,
exc_info=True,
)
logger.debug("args: %s", str(task.args))
Expand Down
2 changes: 1 addition & 1 deletion task_processor/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run(self) -> None:
def run_iteration(self) -> None:
try:
run_tasks(self.queue_pop_size)
run_recurring_tasks(self.queue_pop_size)
run_recurring_tasks()
except Exception as e:
# To prevent task threads from dying if they get an error retrieving the tasks from the
# database this will allow the thread to continue trying to retrieve tasks if it can
Expand Down
Loading
Loading