Skip to content
This repository was archived by the owner on Mar 28, 2025. It is now read-only.

Commit f0ac380

Browse files
committed
fix(17/recurring-task-lock): Add timeout to auto unlock task
1 parent f92adfd commit f0ac380

File tree

6 files changed

+230
-12
lines changed

6 files changed

+230
-12
lines changed

task_processor/decorators.py

+8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]):
2626
"priority",
2727
"transaction_on_commit",
2828
"task_identifier",
29+
"timeout",
2930
)
3031

3132
unwrapped: typing.Callable[P, None]
@@ -38,11 +39,13 @@ def __init__(
3839
queue_size: int | None = None,
3940
priority: TaskPriority = TaskPriority.NORMAL,
4041
transaction_on_commit: bool = True,
42+
timeout: timedelta | None = None,
4143
) -> None:
4244
self.unwrapped = f
4345
self.queue_size = queue_size
4446
self.priority = priority
4547
self.transaction_on_commit = transaction_on_commit
48+
self.timeout = timeout
4649

4750
task_name = task_name or f.__name__
4851
task_module = getmodule(f).__name__.rsplit(".")[-1]
@@ -87,6 +90,7 @@ def delay(
8790
scheduled_for=delay_until or timezone.now(),
8891
priority=self.priority,
8992
queue_size=self.queue_size,
93+
timeout=self.timeout,
9094
args=args,
9195
kwargs=kwargs,
9296
)
@@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901
124128
queue_size: int | None = None,
125129
priority: TaskPriority = TaskPriority.NORMAL,
126130
transaction_on_commit: bool = True,
131+
timeout: timedelta | None = timedelta(seconds=60),
127132
) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]:
128133
"""
129134
Turn a function into an asynchronous task.
@@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
150155
queue_size=queue_size,
151156
priority=priority,
152157
transaction_on_commit=transaction_on_commit,
158+
timeout=timeout,
153159
)
154160

155161
return wrapper
@@ -161,6 +167,7 @@ def register_recurring_task(
161167
args: tuple[typing.Any] = (),
162168
kwargs: dict[str, typing.Any] | None = None,
163169
first_run_time: time | None = None,
170+
timeout: timedelta | None = timedelta(minutes=30),
164171
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
165172
if not os.environ.get("RUN_BY_PROCESSOR"):
166173
# Do not register recurring tasks if not invoked by task processor
@@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
182189
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
183190
"run_every": run_every,
184191
"first_run_time": first_run_time,
192+
"timeout": timeout,
185193
},
186194
)
187195
return task
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Generated by Django 3.2.23 on 2025-01-06 04:51
2+
3+
from task_processor.migrations.helpers import PostgresOnlyRunSQL
4+
import datetime
5+
from django.db import migrations, models
6+
import os
7+
8+
9+
class Migration(migrations.Migration):
10+
11+
dependencies = [
12+
("task_processor", "0011_add_priority_to_get_tasks_to_process"),
13+
]
14+
15+
operations = [
16+
migrations.AddField(
17+
model_name="recurringtask",
18+
name="locked_at",
19+
field=models.DateTimeField(blank=True, null=True),
20+
),
21+
migrations.AddField(
22+
model_name="recurringtask",
23+
name="timeout",
24+
field=models.DurationField(default=datetime.timedelta(minutes=30)),
25+
),
26+
migrations.AddField(
27+
model_name="task",
28+
name="timeout",
29+
field=models.DurationField(blank=True, null=True),
30+
),
31+
PostgresOnlyRunSQL.from_sql_file(
32+
os.path.join(
33+
os.path.dirname(__file__),
34+
"sql",
35+
"0012_get_recurringtasks_to_process.sql",
36+
),
37+
reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
38+
),
39+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer)
2+
RETURNS SETOF task_processor_recurringtask AS $$
3+
DECLARE
4+
row_to_return task_processor_recurringtask;
5+
BEGIN
6+
-- Select the tasks that need to be processed
7+
FOR row_to_return IN
8+
SELECT *
9+
FROM task_processor_recurringtask
10+
WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
11+
ORDER BY id
12+
LIMIT num_tasks
13+
-- Select for update to ensure that no other workers can select these tasks while in this transaction block
14+
FOR UPDATE SKIP LOCKED
15+
LOOP
16+
-- Lock every selected task (by updating `is_locked` to TRUE)
17+
UPDATE task_processor_recurringtask
18+
-- Lock this row by setting `is_locked` to TRUE, so that no other workers can select these tasks after this
19+
-- transaction is complete (but the tasks are still being executed by the current worker)
20+
SET is_locked = TRUE, locked_at = NOW()
21+
WHERE id = row_to_return.id;
22+
-- If we don't explicitly update the columns here, the client will receive a row
23+
-- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
24+
row_to_return.is_locked := TRUE;
25+
row_to_return.locked_at := NOW();
26+
RETURN NEXT row_to_return;
27+
END LOOP;
28+
29+
RETURN;
30+
END;
31+
$$ LANGUAGE plpgsql
32+

task_processor/models.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import typing
22
import uuid
3-
from datetime import datetime
3+
from datetime import datetime, timedelta
44

55
import simplejson as json
66
from django.core.serializers.json import DjangoJSONEncoder
@@ -61,6 +61,7 @@ def mark_success(self):
6161

6262
def unlock(self):
6363
self.is_locked = False
64+
self.locked_at = None
6465

6566
def run(self):
6667
return self.callable(*self.args, **self.kwargs)
@@ -80,6 +81,8 @@ def callable(self) -> typing.Callable:
8081
class Task(AbstractBaseTask):
8182
scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)
8283

84+
timeout = models.DurationField(null=True, blank=True)
85+
8386
# denormalise failures and completion so that we can use select_for_update
8487
num_failures = models.IntegerField(default=0)
8588
completed = models.BooleanField(default=False)
@@ -109,6 +112,7 @@ def create(
109112
*,
110113
args: typing.Tuple[typing.Any] = None,
111114
kwargs: typing.Dict[str, typing.Any] = None,
115+
timeout: timedelta | None = None,
112116
) -> "Task":
113117
if queue_size and cls._is_queue_full(task_identifier, queue_size):
114118
raise TaskQueueFullError(
@@ -121,6 +125,7 @@ def create(
121125
priority=priority,
122126
serialized_args=cls.serialize_data(args or tuple()),
123127
serialized_kwargs=cls.serialize_data(kwargs or dict()),
128+
timeout=timeout,
124129
)
125130

126131
@classmethod
@@ -146,6 +151,9 @@ def mark_success(self):
146151
class RecurringTask(AbstractBaseTask):
147152
run_every = models.DurationField()
148153
first_run_time = models.TimeField(blank=True, null=True)
154+
locked_at = models.DateTimeField(blank=True, null=True)
155+
156+
timeout = models.DurationField(default=timedelta(minutes=30))
149157

150158
objects = RecurringTaskManager()
151159

task_processor/processor.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import logging
22
import traceback
33
import typing
4+
from concurrent.futures import ThreadPoolExecutor
45
from datetime import timedelta
56

67
from django.utils import timezone
@@ -78,7 +79,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:
7879

7980
# update all tasks that were not deleted
8081
to_update = [task for task in tasks if task.id]
81-
RecurringTask.objects.bulk_update(to_update, fields=["is_locked"])
82+
RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"])
8283

8384
if task_runs:
8485
RecurringTaskRun.objects.bulk_create(task_runs)
@@ -93,16 +94,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas
9394
task_run = task.task_runs.model(started_at=timezone.now(), task=task)
9495

9596
try:
96-
task.run()
97-
task_run.result = TaskResult.SUCCESS
97+
with ThreadPoolExecutor(max_workers=1) as executor:
98+
future = executor.submit(task.run)
99+
timeout = task.timeout.total_seconds() if task.timeout else None
100+
future.result(timeout=timeout) # Wait for completion or timeout
98101

102+
task_run.result = TaskResult.SUCCESS
99103
task_run.finished_at = timezone.now()
100104
task.mark_success()
105+
101106
except Exception as e:
107+
# For errors that don't include a default message (e.g., TimeoutError),
108+
# fall back to using repr.
109+
err_msg = str(e) or repr(e)
110+
102111
logger.error(
103-
"Failed to execute task '%s'. Exception was: %s",
112+
"Failed to execute task '%s', with id %d. Exception: %s",
104113
task.task_identifier,
105-
str(e),
114+
task.id,
115+
err_msg,
106116
exc_info=True,
107117
)
108118
logger.debug("args: %s", str(task.args))

0 commit comments

Comments
 (0)