Skip to content
This repository was archived by the owner on Mar 28, 2025. It is now read-only.

Commit c833e22

Browse files
fix(recurring-task): reduce update load on task_processor_recurringtask table (#23)
Co-authored-by: Kim Gustyr <[email protected]>
1 parent c719934 commit c833e22

File tree

8 files changed

+183
-78
lines changed

8 files changed

+183
-78
lines changed

Diff for: task_processor/decorators.py

+13 -17
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,9 @@
99
from django.db.transaction import on_commit
1010
from django.utils import timezone
1111

12+
from task_processor import task_registry
1213
from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError
1314
from task_processor.models import RecurringTask, Task, TaskPriority
14-
from task_processor.task_registry import register_task
1515
from task_processor.task_run_method import TaskRunMethod
1616

1717
P = typing.ParamSpec("P")
@@ -50,7 +50,7 @@ def __init__(
5050
task_name = task_name or f.__name__
5151
task_module = getmodule(f).__name__.rsplit(".")[-1]
5252
self.task_identifier = task_identifier = f"{task_module}.{task_name}"
53-
register_task(task_identifier, f)
53+
task_registry.register_task(task_identifier, f)
5454

5555
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
5656
_validate_inputs(*args, **kwargs)
@@ -168,31 +168,27 @@ def register_recurring_task(
168168
kwargs: dict[str, typing.Any] | None = None,
169169
first_run_time: time | None = None,
170170
timeout: timedelta | None = timedelta(minutes=30),
171-
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
171+
) -> typing.Callable[[typing.Callable[..., None]], None]:
172172
if not os.environ.get("RUN_BY_PROCESSOR"):
173173
# Do not register recurring tasks if not invoked by task processor
174174
return lambda f: f
175175

176-
def decorator(f: typing.Callable[..., None]) -> RecurringTask:
176+
def decorator(f: typing.Callable[..., None]) -> None:
177177
nonlocal task_name
178178

179179
task_name = task_name or f.__name__
180180
task_module = getmodule(f).__name__.rsplit(".")[-1]
181181
task_identifier = f"{task_module}.{task_name}"
182182

183-
register_task(task_identifier, f)
184-
185-
task, _ = RecurringTask.objects.update_or_create(
186-
task_identifier=task_identifier,
187-
defaults={
188-
"serialized_args": RecurringTask.serialize_data(args or ()),
189-
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
190-
"run_every": run_every,
191-
"first_run_time": first_run_time,
192-
"timeout": timeout,
193-
},
194-
)
195-
return task
183+
task_kwargs = {
184+
"serialized_args": RecurringTask.serialize_data(args or ()),
185+
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
186+
"run_every": run_every,
187+
"first_run_time": first_run_time,
188+
"timeout": timeout,
189+
}
190+
191+
task_registry.register_recurring_task(task_identifier, f, **task_kwargs)
196192

197193
return decorator
198194

Diff for: task_processor/management/commands/runprocessor.py

+4 -5
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from django.core.management import BaseCommand
88
from django.utils import timezone
99

10-
from task_processor.task_registry import registered_tasks
10+
from task_processor.task_registry import initialise
1111
from task_processor.thread_monitoring import (
1212
clear_unhealthy_threads,
1313
write_unhealthy_threads,
@@ -74,10 +74,9 @@ def handle(self, *args, **options):
7474
]
7575
)
7676

77-
logger.info(
78-
"Processor starting. Registered tasks are: %s",
79-
list(registered_tasks.keys()),
80-
)
77+
logger.info("Processor starting")
78+
79+
initialise()
8180

8281
for thread in self._threads:
8382
thread.start()

Diff for: task_processor/models.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,8 @@ def run(self):
6868
@property
6969
def callable(self) -> typing.Callable:
7070
try:
71-
return registered_tasks[self.task_identifier]
71+
task = registered_tasks[self.task_identifier]
72+
return task.task_function
7273
except KeyError as e:
7374
raise TaskProcessingError(
7475
"No task registered with identifier '%s'. Ensure your task is "

Diff for: task_processor/task_registry.py

+60 -8
Original file line number | Diff line number | Diff line change
@@ -1,23 +1,75 @@
1+
import enum
12
import logging
23
import typing
4+
from dataclasses import dataclass
35

46
logger = logging.getLogger(__name__)
57

6-
registered_tasks: typing.Dict[str, typing.Callable] = {}
78

9+
class TaskType(enum.Enum):
10+
STANDARD = "STANDARD"
11+
RECURRING = "RECURRING"
812

9-
def register_task(task_identifier: str, callable_: typing.Callable):
13+
14+
@dataclass
15+
class RegisteredTask:
16+
task_identifier: str
17+
task_function: typing.Callable
18+
task_type: TaskType = TaskType.STANDARD
19+
task_kwargs: typing.Dict[str, typing.Any] = None
20+
21+
22+
registered_tasks: typing.Dict[str, RegisteredTask] = {}
23+
24+
25+
def initialise() -> None:
1026
global registered_tasks
1127

12-
logger.debug("Registering task '%s'", task_identifier)
28+
from task_processor.models import RecurringTask
29+
30+
for task_identifier, registered_task in registered_tasks.items():
31+
logger.debug("Initialising task '%s'", task_identifier)
32+
33+
if registered_task.task_type == TaskType.RECURRING:
34+
logger.debug("Persisting recurring task '%s'", task_identifier)
35+
RecurringTask.objects.update_or_create(
36+
task_identifier=task_identifier,
37+
defaults=registered_task.task_kwargs,
38+
)
1339

14-
registered_tasks[task_identifier] = callable_
40+
41+
def get_task(task_identifier: str) -> RegisteredTask:
42+
global registered_tasks
43+
44+
return registered_tasks[task_identifier]
45+
46+
47+
def register_task(task_identifier: str, callable_: typing.Callable) -> None:
48+
global registered_tasks
49+
50+
registered_task = RegisteredTask(
51+
task_identifier=task_identifier,
52+
task_function=callable_,
53+
)
54+
registered_tasks[task_identifier] = registered_task
55+
56+
57+
def register_recurring_task(
58+
task_identifier: str, callable_: typing.Callable, **task_kwargs
59+
) -> None:
60+
global registered_tasks
61+
62+
logger.debug("Registering recurring task '%s'", task_identifier)
63+
64+
registered_task = RegisteredTask(
65+
task_identifier=task_identifier,
66+
task_function=callable_,
67+
task_type=TaskType.RECURRING,
68+
task_kwargs=task_kwargs,
69+
)
70+
registered_tasks[task_identifier] = registered_task
1571

1672
logger.debug(
1773
"Registered tasks now has the following tasks registered: %s",
1874
list(registered_tasks.keys()),
1975
)
20-
21-
22-
def get_task(task_identifier: str) -> typing.Callable:
23-
return registered_tasks[task_identifier]

Diff for: tests/unit/task_processor/conftest.py

+10
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33

44
import pytest
55

6+
from task_processor.task_registry import RegisteredTask
7+
68

79
@pytest.fixture
810
def run_by_processor(monkeypatch):
@@ -33,3 +35,11 @@ def _inner(log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture:
3335
return caplog
3436

3537
return _inner
38+
39+
40+
@pytest.fixture(autouse=True)
41+
def task_registry() -> typing.Generator[dict[str, RegisteredTask], None, None]:
42+
from task_processor.task_registry import registered_tasks
43+
44+
registered_tasks.clear()
45+
yield registered_tasks

Diff for: tests/unit/task_processor/test_unit_task_processor_decorators.py

+3 -1
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@
1414
)
1515
from task_processor.exceptions import InvalidArgumentsError
1616
from task_processor.models import RecurringTask, Task, TaskPriority
17-
from task_processor.task_registry import get_task
17+
from task_processor.task_registry import get_task, initialise
1818
from task_processor.task_run_method import TaskRunMethod
1919

2020
if typing.TYPE_CHECKING:
@@ -113,6 +113,8 @@ def a_function(first_arg, second_arg):
113113
return first_arg + second_arg
114114

115115
# Then
116+
initialise()
117+
116118
task = RecurringTask.objects.get(task_identifier=task_identifier)
117119
assert task.serialized_kwargs == json.dumps(task_kwargs)
118120
assert task.run_every == run_every

Diff for: tests/unit/task_processor/test_unit_task_processor_models.py

+8 -6
Original file line number | Diff line number | Diff line change
@@ -6,23 +6,25 @@
66

77
from task_processor.decorators import register_task_handler
88
from task_processor.models import RecurringTask, Task
9+
from task_processor.task_registry import initialise
910

1011
now = timezone.now()
1112
one_hour_ago = now - timedelta(hours=1)
1213
one_hour_from_now = now + timedelta(hours=1)
1314

1415

15-
@register_task_handler()
16-
def my_callable(arg_one: str, arg_two: str = None):
17-
"""Example callable to use for tasks (needs to be global for registering to work)"""
18-
return arg_one, arg_two
19-
20-
2116
def test_task_run():
2217
# Given
18+
@register_task_handler()
19+
def my_callable(arg_one: str, arg_two: str = None):
20+
"""Example callable to use for tasks (needs to be global for registering to work)"""
21+
return arg_one, arg_two
22+
2323
args = ["foo"]
2424
kwargs = {"arg_two": "bar"}
2525

26+
initialise()
27+
2628
task = Task.create(
2729
my_callable.task_identifier,
2830
scheduled_for=timezone.now(),

0 commit comments

Comments
 (0)