Skip to content
This repository was archived by the owner on Mar 28, 2025. It is now read-only.

fix(recurring-task): reduce update load on task_processor_recurringtask table #23

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 13 additions & 17 deletions task_processor/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
from django.db.transaction import on_commit
from django.utils import timezone

from task_processor import task_registry
from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import register_task
from task_processor.task_run_method import TaskRunMethod

P = typing.ParamSpec("P")
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
self.task_identifier = task_identifier = f"{task_module}.{task_name}"
register_task(task_identifier, f)
task_registry.register_task(task_identifier, f)

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None:
_validate_inputs(*args, **kwargs)
Expand Down Expand Up @@ -168,31 +168,27 @@ def register_recurring_task(
kwargs: dict[str, typing.Any] | None = None,
first_run_time: time | None = None,
timeout: timedelta | None = timedelta(minutes=30),
) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]:
) -> typing.Callable[[typing.Callable[..., None]], None]:
if not os.environ.get("RUN_BY_PROCESSOR"):
# Do not register recurring tasks if not invoked by task processor
return lambda f: f

def decorator(f: typing.Callable[..., None]) -> RecurringTask:
def decorator(f: typing.Callable[..., None]) -> None:
nonlocal task_name

task_name = task_name or f.__name__
task_module = getmodule(f).__name__.rsplit(".")[-1]
task_identifier = f"{task_module}.{task_name}"

register_task(task_identifier, f)

task, _ = RecurringTask.objects.update_or_create(
task_identifier=task_identifier,
defaults={
"serialized_args": RecurringTask.serialize_data(args or ()),
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
},
)
return task
task_kwargs = {
"serialized_args": RecurringTask.serialize_data(args or ()),
"serialized_kwargs": RecurringTask.serialize_data(kwargs or {}),
"run_every": run_every,
"first_run_time": first_run_time,
"timeout": timeout,
}

task_registry.register_recurring_task(task_identifier, f, **task_kwargs)

return decorator

Expand Down
9 changes: 4 additions & 5 deletions task_processor/management/commands/runprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from django.core.management import BaseCommand
from django.utils import timezone

from task_processor.task_registry import registered_tasks
from task_processor.task_registry import initialise
from task_processor.thread_monitoring import (
clear_unhealthy_threads,
write_unhealthy_threads,
Expand Down Expand Up @@ -74,10 +74,9 @@ def handle(self, *args, **options):
]
)

logger.info(
"Processor starting. Registered tasks are: %s",
list(registered_tasks.keys()),
)
logger.info("Processor starting")

initialise()

for thread in self._threads:
thread.start()
Expand Down
3 changes: 2 additions & 1 deletion task_processor/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def run(self):
@property
def callable(self) -> typing.Callable:
try:
return registered_tasks[self.task_identifier]
task = registered_tasks[self.task_identifier]
return task.task_function
except KeyError as e:
raise TaskProcessingError(
"No task registered with identifier '%s'. Ensure your task is "
Expand Down
68 changes: 60 additions & 8 deletions task_processor/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,75 @@
import enum
import logging
import typing
from dataclasses import dataclass

logger = logging.getLogger(__name__)

registered_tasks: typing.Dict[str, typing.Callable] = {}

class TaskType(enum.Enum):
    """Kinds of task that can be held in the task registry."""

    # Default type given to a RegisteredTask.
    STANDARD = "STANDARD"
    # Tasks of this type are persisted as RecurringTask rows by initialise().
    RECURRING = "RECURRING"

def register_task(task_identifier: str, callable_: typing.Callable):

@dataclass
class RegisteredTask:
    """An entry in the in-memory task registry.

    Bundles the function to run with the metadata needed to initialise it.
    """

    # Key under which the task is stored in ``registered_tasks``.
    task_identifier: str
    # The function executed when the task runs.
    task_function: typing.Callable
    # STANDARD unless the task was registered as a recurring task.
    task_type: TaskType = TaskType.STANDARD
    # Field values used as ``defaults`` when a recurring task is persisted;
    # left as None when nothing was supplied at registration.
    task_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None


registered_tasks: typing.Dict[str, RegisteredTask] = {}


def initialise() -> None:
    """Persist every registered recurring task to the database.

    Walks the in-memory registry and, for each task whose type is
    ``TaskType.RECURRING``, creates or updates the matching ``RecurringTask``
    row using the task's stored ``task_kwargs`` as the field defaults.
    Tasks of any other type are only logged.
    """
    # NOTE(review): imported at call time, presumably to avoid a circular
    # import with task_processor.models and to wait until Django models are
    # loadable - confirm before moving to module level.
    from task_processor.models import RecurringTask

    for task_identifier, registered_task in registered_tasks.items():
        logger.debug("Initialising task '%s'", task_identifier)

        if registered_task.task_type != TaskType.RECURRING:
            continue

        logger.debug("Persisting recurring task '%s'", task_identifier)
        RecurringTask.objects.update_or_create(
            task_identifier=task_identifier,
            defaults=registered_task.task_kwargs,
        )

registered_tasks[task_identifier] = callable_

def get_task(task_identifier: str) -> RegisteredTask:
    """Return the registry entry stored under ``task_identifier``.

    Raises:
        KeyError: if no task has been registered with that identifier.
    """
    # Read-only access to the module-level dict; no ``global`` needed.
    return registered_tasks[task_identifier]


def register_task(task_identifier: str, callable_: typing.Callable) -> None:
    """Add a standard task to the in-memory registry.

    Any entry already stored under ``task_identifier`` is replaced.

    Args:
        task_identifier: key under which the task is registered.
        callable_: function to run for this task.
    """
    # The dict is mutated in place, never rebound, so ``global`` is not needed.
    registered_tasks[task_identifier] = RegisteredTask(
        task_identifier=task_identifier,
        task_function=callable_,
    )


def register_recurring_task(
    task_identifier: str, callable_: typing.Callable, **task_kwargs
) -> None:
    """Add a recurring task to the in-memory registry.

    Nothing is written to the database here; the entry is only stored in
    ``registered_tasks`` with type ``TaskType.RECURRING``. Any entry already
    stored under ``task_identifier`` is replaced.

    Args:
        task_identifier: key under which the task is registered.
        callable_: function to run for this task.
        **task_kwargs: kept on the entry as ``task_kwargs``.
    """
    logger.debug("Registering recurring task '%s'", task_identifier)

    # The dict is mutated in place, never rebound, so ``global`` is not needed.
    registered_tasks[task_identifier] = RegisteredTask(
        task_identifier=task_identifier,
        task_function=callable_,
        task_type=TaskType.RECURRING,
        task_kwargs=task_kwargs,
    )

    logger.debug(
        "Registered tasks now has the following tasks registered: %s",
        list(registered_tasks.keys()),
    )


def get_task(task_identifier: str) -> RegisteredTask:
    """Return the registry entry stored under ``task_identifier``.

    The registry maps identifiers to ``RegisteredTask`` objects, so that is
    what callers receive (the function itself is ``.task_function``).

    Raises:
        KeyError: if no task has been registered with that identifier.
    """
    return registered_tasks[task_identifier]
10 changes: 10 additions & 0 deletions tests/unit/task_processor/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import pytest

from task_processor.task_registry import RegisteredTask


@pytest.fixture
def run_by_processor(monkeypatch):
Expand Down Expand Up @@ -33,3 +35,11 @@ def _inner(log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture:
return caplog

return _inner


@pytest.fixture(autouse=True)
def task_registry() -> typing.Generator[dict[str, RegisteredTask], None, None]:
    """Give each test an empty task registry.

    Applied automatically (``autouse``): the module-level registry is emptied
    before the test runs and then yielded, so tests can inspect it directly.
    """
    # Imported inside the fixture so the dict cleared here is the same object
    # the task_processor code mutates.
    from task_processor.task_registry import registered_tasks

    registered_tasks.clear()
    yield registered_tasks
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
from task_processor.exceptions import InvalidArgumentsError
from task_processor.models import RecurringTask, Task, TaskPriority
from task_processor.task_registry import get_task
from task_processor.task_registry import get_task, initialise
from task_processor.task_run_method import TaskRunMethod

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -113,6 +113,8 @@ def a_function(first_arg, second_arg):
return first_arg + second_arg

# Then
initialise()

task = RecurringTask.objects.get(task_identifier=task_identifier)
assert task.serialized_kwargs == json.dumps(task_kwargs)
assert task.run_every == run_every
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/task_processor/test_unit_task_processor_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,25 @@

from task_processor.decorators import register_task_handler
from task_processor.models import RecurringTask, Task
from task_processor.task_registry import initialise

now = timezone.now()
one_hour_ago = now - timedelta(hours=1)
one_hour_from_now = now + timedelta(hours=1)


@register_task_handler()
def my_callable(arg_one: str, arg_two: str | None = None):
    """Example callable to use for tasks (needs to be global for registering to work)"""
    # Echo the inputs back so tests can check what the task was called with.
    return arg_one, arg_two


def test_task_run():
# Given
@register_task_handler()
def my_callable(arg_one: str, arg_two: str = None):
"""Example callable to use for tasks (needs to be global for registering to work)"""
return arg_one, arg_two

args = ["foo"]
kwargs = {"arg_two": "bar"}

initialise()

task = Task.create(
my_callable.task_identifier,
scheduled_for=timezone.now(),
Expand Down
Loading