From 6e07161fb0e512dcd4703a4330e31f486c684e60 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Mon, 6 Jan 2025 10:24:52 +0530 Subject: [PATCH 1/3] fix(17/recurring-task-lock): Add timeout to auto unlock task --- task_processor/decorators.py | 8 ++ .../0012_add_locked_at_and_timeout.py | 39 ++++++ .../0012_get_recurringtasks_to_process.sql | 32 +++++ task_processor/models.py | 10 +- task_processor/processor.py | 20 ++- .../test_unit_task_processor_processor.py | 130 +++++++++++++++++- 6 files changed, 228 insertions(+), 11 deletions(-) create mode 100644 task_processor/migrations/0012_add_locked_at_and_timeout.py create mode 100644 task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql diff --git a/task_processor/decorators.py b/task_processor/decorators.py index f5c306b..27958c7 100644 --- a/task_processor/decorators.py +++ b/task_processor/decorators.py @@ -26,6 +26,7 @@ class TaskHandler(typing.Generic[P]): "priority", "transaction_on_commit", "task_identifier", + "timeout", ) unwrapped: typing.Callable[P, None] @@ -38,11 +39,13 @@ def __init__( queue_size: int | None = None, priority: TaskPriority = TaskPriority.NORMAL, transaction_on_commit: bool = True, + timeout: timedelta | None = None, ) -> None: self.unwrapped = f self.queue_size = queue_size self.priority = priority self.transaction_on_commit = transaction_on_commit + self.timeout = timeout task_name = task_name or f.__name__ task_module = getmodule(f).__name__.rsplit(".")[-1] @@ -87,6 +90,7 @@ def delay( scheduled_for=delay_until or timezone.now(), priority=self.priority, queue_size=self.queue_size, + timeout=self.timeout, args=args, kwargs=kwargs, ) @@ -124,6 +128,7 @@ def register_task_handler( # noqa: C901 queue_size: int | None = None, priority: TaskPriority = TaskPriority.NORMAL, transaction_on_commit: bool = True, + timeout: timedelta | None = timedelta(seconds=60), ) -> typing.Callable[[typing.Callable[P, None]], TaskHandler[P]]: """ Turn a function into an asynchronous task. 
@@ -150,6 +155,7 @@ def wrapper(f: typing.Callable[P, None]) -> TaskHandler[P]:
             queue_size=queue_size,
             priority=priority,
             transaction_on_commit=transaction_on_commit,
+            timeout=timeout,
         )
 
     return wrapper
@@ -161,6 +167,7 @@ def register_recurring_task(
     args: tuple[typing.Any] = (),
     kwargs: dict[str, typing.Any] | None = None,
     first_run_time: time | None = None,
+    timeout: timedelta | None = timedelta(minutes=30),
 ) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
     if not os.environ.get("RUN_BY_PROCESSOR"):
         # Do not register recurring tasks if not invoked by task processor
@@ -182,6 +189,7 @@ def decorator(f: typing.Callable[..., None]) -> RecurringTask:
                 "serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
                 "run_every": run_every,
                 "first_run_time": first_run_time,
+                "timeout": timeout,
             },
         )
         return task
diff --git a/task_processor/migrations/0012_add_locked_at_and_timeout.py b/task_processor/migrations/0012_add_locked_at_and_timeout.py
new file mode 100644
index 0000000..a3f65a3
--- /dev/null
+++ b/task_processor/migrations/0012_add_locked_at_and_timeout.py
@@ -0,0 +1,39 @@
+# Generated by Django 3.2.23 on 2025-01-06 04:51
+
+from task_processor.migrations.helpers import PostgresOnlyRunSQL
+import datetime
+from django.db import migrations, models
+import os
+
+
+class Migration(migrations.Migration):
+
+    dependencies = [
+        ("task_processor", "0011_add_priority_to_get_tasks_to_process"),
+    ]
+
+    operations = [
+        migrations.AddField(
+            model_name="recurringtask",
+            name="locked_at",
+            field=models.DateTimeField(blank=True, null=True),
+        ),
+        migrations.AddField(
+            model_name="recurringtask",
+            name="timeout",
+            field=models.DurationField(default=datetime.timedelta(minutes=30)),
+        ),
+        migrations.AddField(
+            model_name="task",
+            name="timeout",
+            field=models.DurationField(blank=True, null=True),
+        ),
+        PostgresOnlyRunSQL.from_sql_file(
+            os.path.join(
+                os.path.dirname(__file__),
+                "sql",
+                "0012_get_recurringtasks_to_process.sql",
+            ),
+            reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process",
+        ),
+    ]
diff --git a/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql
new file mode 100644
index 0000000..d8483f2
--- /dev/null
+++ b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql
@@ -0,0 +1,32 @@
+CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer)
+RETURNS SETOF task_processor_recurringtask AS $$
+DECLARE
+    row_to_return task_processor_recurringtask;
+BEGIN
+    -- Select the tasks that need to be processed
+    FOR row_to_return IN
+        SELECT *
+        FROM task_processor_recurringtask
+        WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)
+        ORDER BY id
+        LIMIT num_tasks
+        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
+        FOR UPDATE SKIP LOCKED
+    LOOP
+        -- Lock every selected task (by updating `is_locked` to true)
+        UPDATE task_processor_recurringtask
+        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
+        -- transaction is complete (but the tasks are still being executed by the current worker)
+        SET is_locked = TRUE, locked_at = NOW()
+        WHERE id = row_to_return.id;
+        -- If we don't explicitly update the columns here, the client will receive a row
+        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
+ row_to_return.is_locked := TRUE; + row_to_return.locked_at := NOW(); + RETURN NEXT row_to_return; + END LOOP; + + RETURN; +END; +$$ LANGUAGE plpgsql + diff --git a/task_processor/models.py b/task_processor/models.py index 9093b22..c6e6248 100644 --- a/task_processor/models.py +++ b/task_processor/models.py @@ -1,6 +1,6 @@ import typing import uuid -from datetime import datetime +from datetime import datetime, timedelta import simplejson as json from django.core.serializers.json import DjangoJSONEncoder @@ -61,6 +61,7 @@ def mark_success(self): def unlock(self): self.is_locked = False + self.locked_at = None def run(self): return self.callable(*self.args, **self.kwargs) @@ -80,6 +81,8 @@ def callable(self) -> typing.Callable: class Task(AbstractBaseTask): scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now) + timeout = models.DurationField(null=True, blank=True) + # denormalise failures and completion so that we can use select_for_update num_failures = models.IntegerField(default=0) completed = models.BooleanField(default=False) @@ -109,6 +112,7 @@ def create( *, args: typing.Tuple[typing.Any] = None, kwargs: typing.Dict[str, typing.Any] = None, + timeout: timedelta | None = None, ) -> "Task": if queue_size and cls._is_queue_full(task_identifier, queue_size): raise TaskQueueFullError( @@ -121,6 +125,7 @@ def create( priority=priority, serialized_args=cls.serialize_data(args or tuple()), serialized_kwargs=cls.serialize_data(kwargs or dict()), + timeout=timeout, ) @classmethod @@ -146,6 +151,9 @@ def mark_success(self): class RecurringTask(AbstractBaseTask): run_every = models.DurationField() first_run_time = models.TimeField(blank=True, null=True) + locked_at = models.DateTimeField(blank=True, null=True) + + timeout = models.DurationField(default=timedelta(minutes=30)) objects = RecurringTaskManager() diff --git a/task_processor/processor.py b/task_processor/processor.py index 93d5436..e1281ae 100644 --- a/task_processor/processor.py +++ b/task_processor/processor.py @@ -1,6 +1,7 @@ import logging import traceback import typing +from concurrent.futures import ThreadPoolExecutor from datetime import timedelta from django.utils import timezone @@ -78,7 +79,7 @@ def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]: # update all tasks that were not deleted to_update = [task for task in tasks if task.id] - RecurringTask.objects.bulk_update(to_update, fields=["is_locked"]) + RecurringTask.objects.bulk_update(to_update, fields=["is_locked", "locked_at"]) if task_runs: RecurringTaskRun.objects.bulk_create(task_runs) @@ -93,16 +94,25 @@ def _run_task(task: typing.Union[Task, RecurringTask]) -> typing.Tuple[Task, Tas task_run = task.task_runs.model(started_at=timezone.now(), task=task) try: - task.run() - task_run.result = TaskResult.SUCCESS + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(task.run) + timeout = task.timeout.total_seconds() if task.timeout else None + future.result(timeout=timeout) # Wait for completion or timeout + task_run.result = TaskResult.SUCCESS task_run.finished_at = timezone.now() task.mark_success() + except Exception as e: + # For errors that don't include a default message (e.g., TimeoutError), + # fall back to using repr. + err_msg = str(e) or repr(e) + logger.error( - "Failed to execute task '%s'. Exception was: %s", + "Failed to execute task '%s', with id %d. 
Exception: %s", task.task_identifier, - str(e), + task.id, + err_msg, exc_info=True, ) logger.debug("args: %s", str(task.args)) diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py index 04826a3..aab4ae6 100644 --- a/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -1,5 +1,6 @@ import logging import time +import typing import uuid from datetime import timedelta from threading import Thread @@ -8,6 +9,7 @@ from django.core.cache import cache from django.utils import timezone from freezegun import freeze_time +from pytest import MonkeyPatch from task_processor.decorators import ( register_recurring_task, @@ -28,6 +30,11 @@ ) from task_processor.task_registry import registered_tasks +if typing.TYPE_CHECKING: + # This import breaks private-package-test workflow in core + from tests.unit.task_processor.conftest import GetTaskProcessorCaplog + + DEFAULT_CACHE_KEY = "foo" DEFAULT_CACHE_VALUE = "bar" @@ -63,6 +70,83 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db): assert task.completed +def test_run_task_kills_task_after_timeout( + db: None, + get_task_processor_caplog: "GetTaskProcessorCaplog", +) -> None: + # Given + caplog = get_task_processor_caplog(logging.ERROR) + task = Task.create( + _sleep.task_identifier, + scheduled_for=timezone.now(), + args=(1,), + timeout=timedelta(microseconds=1), + ) + task.save() + + # When + task_runs = run_tasks() + + # Then + assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.FAILURE + assert task_run.started_at + assert task_run.finished_at is None + assert "TimeoutError" in task_run.error_details + + task.refresh_from_db() + + assert task.completed is False + assert task.num_failures == 1 + assert task.is_locked is False + + assert len(caplog.records) == 1 + assert caplog.records[0].message == ( + f"Failed to execute task '{task.task_identifier}', with id {task.id}. Exception: TimeoutError()" + ) + + +def test_run_recurring_task_kills_task_after_timeout( + db: None, + monkeypatch: MonkeyPatch, + get_task_processor_caplog: "GetTaskProcessorCaplog", +) -> None: + # Given + caplog = get_task_processor_caplog(logging.ERROR) + monkeypatch.setenv("RUN_BY_PROCESSOR", "True") + + @register_recurring_task( + run_every=timedelta(seconds=1), timeout=timedelta(microseconds=1) + ) + def _dummy_recurring_task(): + time.sleep(1) + + task = RecurringTask.objects.get( + task_identifier=_dummy_recurring_task.task_identifier + ) + # When + task_runs = run_recurring_tasks() + + # Then + assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.FAILURE + assert task_run.started_at + assert task_run.finished_at is None + assert "TimeoutError" in task_run.error_details + + task.refresh_from_db() + + assert task.locked_at is None + assert task.is_locked is False + + assert len(caplog.records) == 1 + assert caplog.records[0].message == ( + f"Failed to execute task '{task.task_identifier}', with id {task.id}. 
Exception: TimeoutError()" + ) + + def test_run_recurring_tasks_runs_task_and_creates_recurring_task_run_object_when_success( db, monkeypatch, @@ -91,6 +175,43 @@ def _dummy_recurring_task(): assert task_run.error_details is None +def test_run_recurring_tasks_runs_locked_task_after_tiemout( + db: None, + monkeypatch: MonkeyPatch, +) -> None: + # Given + monkeypatch.setenv("RUN_BY_PROCESSOR", "True") + + @register_recurring_task(run_every=timedelta(hours=1)) + def _dummy_recurring_task(): + cache.set(DEFAULT_CACHE_KEY, DEFAULT_CACHE_VALUE) + + task = RecurringTask.objects.get( + task_identifier=_dummy_recurring_task.task_identifier + ) + task.is_locked = True + task.locked_at = timezone.now() - timedelta(hours=1) + task.save() + + # When + task_runs = run_recurring_tasks() + + # Then + assert cache.get(DEFAULT_CACHE_KEY) + + assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.SUCCESS + assert task_run.started_at + assert task_run.finished_at + assert task_run.error_details is None + + # And the task is no longer locked + task.refresh_from_db() + assert task.is_locked is False + assert task.locked_at is None + + @pytest.mark.django_db(transaction=True) def test_run_recurring_tasks_multiple_runs(db, run_by_processor): # Given @@ -211,12 +332,11 @@ def _a_task(): def test_run_task_runs_task_and_creates_task_run_object_when_failure( - db: None, caplog: pytest.LogCaptureFixture + db: None, + get_task_processor_caplog: "GetTaskProcessorCaplog", ) -> None: # Given - task_processor_logger = logging.getLogger("task_processor") - task_processor_logger.propagate = True - task_processor_logger.level = logging.DEBUG + caplog = get_task_processor_caplog(logging.DEBUG) msg = "Error!" task = Task.create( @@ -243,7 +363,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure( log_record = caplog.records[0] assert log_record.levelname == "ERROR" assert log_record.message == ( - f"Failed to execute task '{task.task_identifier}'. Exception was: {msg}" + f"Failed to execute task '{task.task_identifier}', with id {task.id}. 
Exception: {msg}" ) debug_log_args, debug_log_kwargs = caplog.records[1:] From d8a4c3740446d8ea12f96e4c6f4305bbcbfc05f0 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Wed, 8 Jan 2025 09:48:26 +0530 Subject: [PATCH 2/3] fix locks for non recurring tasks --- .../0012_add_locked_at_and_timeout.py | 25 +++++++++++-- .../0012_get_recurringtasks_to_process.sql | 2 +- .../sql/0012_get_tasks_to_process.sql | 31 ++++++++++++++++ task_processor/models.py | 7 ++-- task_processor/processor.py | 3 +- .../test_unit_task_processor_processor.py | 36 ++++++++++++++++++- 6 files changed, 96 insertions(+), 8 deletions(-) create mode 100644 task_processor/migrations/sql/0012_get_tasks_to_process.sql diff --git a/task_processor/migrations/0012_add_locked_at_and_timeout.py b/task_processor/migrations/0012_add_locked_at_and_timeout.py index a3f65a3..c779b63 100644 --- a/task_processor/migrations/0012_add_locked_at_and_timeout.py +++ b/task_processor/migrations/0012_add_locked_at_and_timeout.py @@ -18,6 +18,11 @@ class Migration(migrations.Migration): name="locked_at", field=models.DateTimeField(blank=True, null=True), ), + migrations.AddField( + model_name="task", + name="locked_at", + field=models.DateTimeField(blank=True, null=True), + ), migrations.AddField( model_name="recurringtask", name="timeout", @@ -26,7 +31,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="task", name="timeout", - field=models.DurationField(blank=True, null=True), + field=models.DurationField(default=datetime.timedelta(minutes=1)), ), PostgresOnlyRunSQL.from_sql_file( os.path.join( @@ -34,6 +39,22 @@ class Migration(migrations.Migration): "sql", "0012_get_recurringtasks_to_process.sql", ), - reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process", + reverse_sql=os.path.join( + os.path.dirname(__file__), + "sql", + "0008_get_recurringtasks_to_process.sql", + ), + ), + PostgresOnlyRunSQL.from_sql_file( + os.path.join( + os.path.dirname(__file__), + "sql", + "0012_get_tasks_to_process.sql", + ), + reverse_sql=os.path.join( + os.path.dirname(__file__), + "sql", + "0011_get_tasks_to_process.sql", + ), ), ] diff --git a/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql index d8483f2..9778640 100644 --- a/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql +++ b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql @@ -7,7 +7,7 @@ BEGIN FOR row_to_return IN SELECT * FROM task_processor_recurringtask - WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout) + WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout) ORDER BY id LIMIT num_tasks -- Select for update to ensure that no other workers can select these tasks while in this transaction block diff --git a/task_processor/migrations/sql/0012_get_tasks_to_process.sql b/task_processor/migrations/sql/0012_get_tasks_to_process.sql new file mode 100644 index 0000000..1c06ef3 --- /dev/null +++ b/task_processor/migrations/sql/0012_get_tasks_to_process.sql @@ -0,0 +1,31 @@ +CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer) +RETURNS SETOF task_processor_task AS $$ +DECLARE + row_to_return task_processor_task; +BEGIN + -- Select the tasks that needs to be processed + FOR row_to_return IN + SELECT * + FROM task_processor_task + WHERE num_failures < 3 AND scheduled_for < NOW() AND completed = FALSE AND (is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - 
+        ORDER BY priority ASC, scheduled_for ASC, created_at ASC
+        LIMIT num_tasks
+        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
+        FOR UPDATE SKIP LOCKED
+    LOOP
+        -- Lock every selected task (by updating `is_locked` to true)
+        UPDATE task_processor_task
+        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
+        -- transaction is complete (but the tasks are still being executed by the current worker)
+        SET is_locked = TRUE, locked_at = NOW()
+        WHERE id = row_to_return.id;
+        -- If we don't explicitly update the columns here, the client will receive a row
+        -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`.
+        row_to_return.is_locked := TRUE;
+        RETURN NEXT row_to_return;
+    END LOOP;
+
+    RETURN;
+END;
+$$ LANGUAGE plpgsql
+
diff --git a/task_processor/models.py b/task_processor/models.py
index c6e6248..c871c04 100644
--- a/task_processor/models.py
+++ b/task_processor/models.py
@@ -30,6 +30,8 @@ class AbstractBaseTask(models.Model):
     serialized_kwargs = models.TextField(blank=True, null=True)
     is_locked = models.BooleanField(default=False)
 
+    locked_at = models.DateTimeField(blank=True, null=True)
+
     class Meta:
         abstract = True
 
@@ -81,7 +83,7 @@ def callable(self) -> typing.Callable:
 class Task(AbstractBaseTask):
     scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)
 
-    timeout = models.DurationField(null=True, blank=True)
+    timeout = models.DurationField(default=timedelta(minutes=1))
 
     # denormalise failures and completion so that we can use select_for_update
     num_failures = models.IntegerField(default=0)
@@ -112,7 +114,7 @@ def create(
         *,
         args: typing.Tuple[typing.Any] = None,
         kwargs: typing.Dict[str, typing.Any] = None,
-        timeout: timedelta | None = None,
+        timeout: timedelta | None = timedelta(seconds=60),
     ) -> "Task":
         if queue_size and cls._is_queue_full(task_identifier, queue_size):
             raise TaskQueueFullError(
@@ -151,7 +153,6 @@ def mark_success(self):
 class RecurringTask(AbstractBaseTask):
     run_every = models.DurationField()
     first_run_time = models.TimeField(blank=True, null=True)
-    locked_at = models.DateTimeField(blank=True, null=True)
 
     timeout = models.DurationField(default=timedelta(minutes=30))
 
diff --git a/task_processor/processor.py b/task_processor/processor.py
index e1281ae..456ff51 100644
--- a/task_processor/processor.py
+++ b/task_processor/processor.py
@@ -37,7 +37,8 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:
 
     if executed_tasks:
         Task.objects.bulk_update(
-            executed_tasks, fields=["completed", "num_failures", "is_locked"]
+            executed_tasks,
+            fields=["completed", "num_failures", "is_locked", "locked_at"],
         )
 
     if task_runs:
diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py
index aab4ae6..bec719d 100644
--- a/tests/unit/task_processor/test_unit_task_processor_processor.py
+++ b/tests/unit/task_processor/test_unit_task_processor_processor.py
@@ -70,6 +70,39 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db):
     assert task.completed
 
 
+def test_run_tasks_runs_locked_task_after_timeout(
+    db: None,
+) -> None:
+    # Given
+    task = Task.create(
+        _dummy_task.task_identifier,
+        timeout=timedelta(seconds=10),
+        scheduled_for=timezone.now(),
+    )
+    task.is_locked = True
+    task.locked_at = timezone.now() - timedelta(minutes=1)
+    task.save()
+
+    # When
+    assert cache.get(DEFAULT_CACHE_KEY) is None
+    
task_runs = run_tasks() + + # Then + assert cache.get(DEFAULT_CACHE_KEY) == DEFAULT_CACHE_VALUE + + assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1 + task_run = task_runs[0] + assert task_run.result == TaskResult.SUCCESS + assert task_run.started_at + assert task_run.finished_at + assert task_run.error_details is None + + # And the task is no longer locked + task.refresh_from_db() + assert task.is_locked is False + assert task.locked_at is None + + def test_run_task_kills_task_after_timeout( db: None, get_task_processor_caplog: "GetTaskProcessorCaplog", @@ -194,10 +227,11 @@ def _dummy_recurring_task(): task.save() # When + assert cache.get(DEFAULT_CACHE_KEY) is None task_runs = run_recurring_tasks() # Then - assert cache.get(DEFAULT_CACHE_KEY) + assert cache.get(DEFAULT_CACHE_KEY) == DEFAULT_CACHE_VALUE assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1 task_run = task_runs[0] From 64fa4e784fa3c2941b3c676b2593fbd074596412 Mon Sep 17 00:00:00 2001 From: Gagan Trivedi Date: Fri, 10 Jan 2025 16:15:04 +0530 Subject: [PATCH 3/3] remove auto unlock from tasks --- task_processor/managers.py | 4 +-- .../0012_add_locked_at_and_timeout.py | 25 ++------------ .../0012_get_recurringtasks_to_process.sql | 7 ++-- .../sql/0012_get_tasks_to_process.sql | 31 ----------------- task_processor/models.py | 10 +++--- task_processor/processor.py | 9 ++--- task_processor/threads.py | 2 +- .../test_unit_task_processor_processor.py | 33 ------------------- 8 files changed, 18 insertions(+), 103 deletions(-) delete mode 100644 task_processor/migrations/sql/0012_get_tasks_to_process.sql diff --git a/task_processor/managers.py b/task_processor/managers.py index a04c3e4..ca3db32 100644 --- a/task_processor/managers.py +++ b/task_processor/managers.py @@ -12,5 +12,5 @@ def get_tasks_to_process(self, num_tasks: int) -> QuerySet["Task"]: class RecurringTaskManager(Manager): - def get_tasks_to_process(self, num_tasks: int) -> QuerySet["RecurringTask"]: - return self.raw("SELECT * FROM get_recurringtasks_to_process(%s)", [num_tasks]) + def get_tasks_to_process(self) -> QuerySet["RecurringTask"]: + return self.raw("SELECT * FROM get_recurringtasks_to_process()") diff --git a/task_processor/migrations/0012_add_locked_at_and_timeout.py b/task_processor/migrations/0012_add_locked_at_and_timeout.py index c779b63..866d40e 100644 --- a/task_processor/migrations/0012_add_locked_at_and_timeout.py +++ b/task_processor/migrations/0012_add_locked_at_and_timeout.py @@ -18,11 +18,6 @@ class Migration(migrations.Migration): name="locked_at", field=models.DateTimeField(blank=True, null=True), ), - migrations.AddField( - model_name="task", - name="locked_at", - field=models.DateTimeField(blank=True, null=True), - ), migrations.AddField( model_name="recurringtask", name="timeout", @@ -31,7 +26,7 @@ class Migration(migrations.Migration): migrations.AddField( model_name="task", name="timeout", - field=models.DurationField(default=datetime.timedelta(minutes=1)), + field=models.DurationField(blank=True, null=True), ), PostgresOnlyRunSQL.from_sql_file( os.path.join( @@ -39,22 +34,6 @@ class Migration(migrations.Migration): "sql", "0012_get_recurringtasks_to_process.sql", ), - reverse_sql=os.path.join( - os.path.dirname(__file__), - "sql", - "0008_get_recurringtasks_to_process.sql", - ), - ), - PostgresOnlyRunSQL.from_sql_file( - os.path.join( - os.path.dirname(__file__), - "sql", - "0012_get_tasks_to_process.sql", - ), - reverse_sql=os.path.join( - os.path.dirname(__file__), - "sql", - 
"0011_get_tasks_to_process.sql", - ), + reverse_sql="DROP FUNCTION IF EXISTS get_recurringtasks_to_process()", ), ] diff --git a/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql index 9778640..52bec14 100644 --- a/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql +++ b/task_processor/migrations/sql/0012_get_recurringtasks_to_process.sql @@ -1,4 +1,4 @@ -CREATE OR REPLACE FUNCTION get_recurringtasks_to_process(num_tasks integer) +CREATE OR REPLACE FUNCTION get_recurringtasks_to_process() RETURNS SETOF task_processor_recurringtask AS $$ DECLARE row_to_return task_processor_recurringtask; @@ -7,9 +7,10 @@ BEGIN FOR row_to_return IN SELECT * FROM task_processor_recurringtask - WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout) + -- Add one minute to the timeout as a grace period for overhead + WHERE is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout + INTERVAL '1 minute') ORDER BY id - LIMIT num_tasks + LIMIT 1 -- Select for update to ensure that no other workers can select these tasks while in this transaction block FOR UPDATE SKIP LOCKED LOOP diff --git a/task_processor/migrations/sql/0012_get_tasks_to_process.sql b/task_processor/migrations/sql/0012_get_tasks_to_process.sql deleted file mode 100644 index 1c06ef3..0000000 --- a/task_processor/migrations/sql/0012_get_tasks_to_process.sql +++ /dev/null @@ -1,31 +0,0 @@ -CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer) -RETURNS SETOF task_processor_task AS $$ -DECLARE - row_to_return task_processor_task; -BEGIN - -- Select the tasks that needs to be processed - FOR row_to_return IN - SELECT * - FROM task_processor_task - WHERE num_failures < 3 AND scheduled_for < NOW() AND completed = FALSE AND (is_locked = FALSE OR (locked_at IS NOT NULL AND locked_at < NOW() - timeout)) - ORDER BY priority ASC, scheduled_for ASC, created_at ASC - LIMIT num_tasks - -- Select for update to ensure that no other workers can select these tasks while in this transaction block - FOR UPDATE SKIP LOCKED - LOOP - -- Lock every selected task(by updating `is_locked` to true) - UPDATE task_processor_task - -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this - -- transaction is complete (but the tasks are still being executed by the current worker) - SET is_locked = TRUE, locked_at = NOW() - WHERE id = row_to_return.id; - -- If we don't explicitly update the columns here, the client will receive a row - -- that is locked but still shows `is_locked` as `False` and `locked_at` as `None`. 
-        row_to_return.is_locked := TRUE;
-        RETURN NEXT row_to_return;
-    END LOOP;
-
-    RETURN;
-END;
-$$ LANGUAGE plpgsql
-
diff --git a/task_processor/models.py b/task_processor/models.py
index c871c04..fc234a5 100644
--- a/task_processor/models.py
+++ b/task_processor/models.py
@@ -30,8 +30,6 @@ class AbstractBaseTask(models.Model):
     serialized_kwargs = models.TextField(blank=True, null=True)
     is_locked = models.BooleanField(default=False)
 
-    locked_at = models.DateTimeField(blank=True, null=True)
-
     class Meta:
         abstract = True
 
@@ -63,7 +61,6 @@ def mark_success(self):
 
     def unlock(self):
         self.is_locked = False
-        self.locked_at = None
 
     def run(self):
         return self.callable(*self.args, **self.kwargs)
@@ -83,7 +80,7 @@ def callable(self) -> typing.Callable:
 class Task(AbstractBaseTask):
     scheduled_for = models.DateTimeField(blank=True, null=True, default=timezone.now)
 
-    timeout = models.DurationField(default=timedelta(minutes=1))
+    timeout = models.DurationField(blank=True, null=True)
 
     # denormalise failures and completion so that we can use select_for_update
     num_failures = models.IntegerField(default=0)
@@ -154,6 +151,7 @@ class RecurringTask(AbstractBaseTask):
     run_every = models.DurationField()
     first_run_time = models.TimeField(blank=True, null=True)
 
+    locked_at = models.DateTimeField(blank=True, null=True)
     timeout = models.DurationField(default=timedelta(minutes=30))
 
     objects = RecurringTaskManager()
@@ -166,6 +164,10 @@ class Meta:
         ),
     ]
 
+    def unlock(self):
+        self.is_locked = False
+        self.locked_at = None
+
     @property
     def should_execute(self) -> bool:
         now = timezone.now()
diff --git a/task_processor/processor.py b/task_processor/processor.py
index 456ff51..da7ac83 100644
--- a/task_processor/processor.py
+++ b/task_processor/processor.py
@@ -38,7 +38,7 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:
     if executed_tasks:
         Task.objects.bulk_update(
             executed_tasks,
-            fields=["completed", "num_failures", "is_locked", "locked_at"],
+            fields=["completed", "num_failures", "is_locked"],
         )
 
     if task_runs:
@@ -50,14 +50,11 @@ def run_tasks(num_tasks: int = 1) -> typing.List[TaskRun]:
     return []
 
 
-def run_recurring_tasks(num_tasks: int = 1) -> typing.List[RecurringTaskRun]:
-    if num_tasks < 1:
-        raise ValueError("Number of tasks to process must be at least one")
-
+def run_recurring_tasks() -> typing.List[RecurringTaskRun]:
     # NOTE: We will probably see a lot of delay in the execution of recurring tasks
     # if the tasks take longer than `run_every` to execute. This is not
     # a problem for now, but we should be mindful of this limitation
-    tasks = RecurringTask.objects.get_tasks_to_process(num_tasks)
+    tasks = RecurringTask.objects.get_tasks_to_process()
 
     if tasks:
         task_runs = []
diff --git a/task_processor/threads.py b/task_processor/threads.py
index b9a3a2d..42973fb 100644
--- a/task_processor/threads.py
+++ b/task_processor/threads.py
@@ -34,7 +34,7 @@ def run(self) -> None:
     def run_iteration(self) -> None:
         try:
             run_tasks(self.queue_pop_size)
-            run_recurring_tasks(self.queue_pop_size)
+            run_recurring_tasks()
         except Exception as e:
             # To prevent task threads from dying if they get an error retrieving the tasks from the
             # database this will allow the thread to continue trying to retrieve tasks if it can
diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py
index bec719d..2794167 100644
--- a/tests/unit/task_processor/test_unit_task_processor_processor.py
+++ b/tests/unit/task_processor/test_unit_task_processor_processor.py
@@ -70,39 +70,6 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db):
     assert task.completed
 
 
-def test_run_tasks_runs_locked_task_after_timeout(
-    db: None,
-) -> None:
-    # Given
-    task = Task.create(
-        _dummy_task.task_identifier,
-        timeout=timedelta(seconds=10),
-        scheduled_for=timezone.now(),
-    )
-    task.is_locked = True
-    task.locked_at = timezone.now() - timedelta(minutes=1)
-    task.save()
-
-    # When
-    assert cache.get(DEFAULT_CACHE_KEY) is None
-    task_runs = run_tasks()
-
-    # Then
-    assert cache.get(DEFAULT_CACHE_KEY) == DEFAULT_CACHE_VALUE
-
-    assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1
-    task_run = task_runs[0]
-    assert task_run.result == TaskResult.SUCCESS
-    assert task_run.started_at
-    assert task_run.finished_at
-    assert task_run.error_details is None
-
-    # And the task is no longer locked
-    task.refresh_from_db()
-    assert task.is_locked is False
-    assert task.locked_at is None
-
-
 def test_run_task_kills_task_after_timeout(
     db: None,
     get_task_processor_caplog: "GetTaskProcessorCaplog",
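
A note on the mechanism this series settles on: a timed-out task is not killed. `_run_task` submits `task.run` to a single-worker `ThreadPoolExecutor` and waits on the future, so `future.result(timeout=...)` raises `TimeoutError` and the run is recorded as a failure, while the callable itself keeps executing until it returns. The sketch below reproduces the pattern in isolation, assuming Python 3.10+; `run_with_timeout` and `slow_task` are illustrative names only, not part of the task processor's API.

import time
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError
from datetime import timedelta
from typing import Callable


def run_with_timeout(fn: Callable[[], None], timeout: timedelta | None) -> None:
    # Run `fn` on a single worker thread, waiting at most `timeout` for it to
    # finish; future.result() raises TimeoutError when the deadline passes.
    # The worker thread is NOT interrupted: leaving the `with` block calls
    # executor.shutdown(wait=True), which blocks until `fn` actually returns.
    with ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(fn)
        seconds = timeout.total_seconds() if timeout else None
        future.result(timeout=seconds)


def slow_task() -> None:
    time.sleep(2)


try:
    run_with_timeout(slow_task, timeout=timedelta(microseconds=1))
except FuturesTimeoutError as e:
    # TimeoutError carries no message, so str(e) is empty and repr(e) yields
    # "TimeoutError()" -- this is why the processor logs `str(e) or repr(e)`.
    print(repr(e))

In other words, the timeout bounds how long a run waits before being marked as failed (and, for recurring tasks, how long until another worker may take over the lock); a runaway callable still occupies its worker thread until it returns.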