diff --git a/task_processor/decorators.py b/task_processor/decorators.py index 27958c7..66f44d8 100644 --- a/task_processor/decorators.py +++ b/task_processor/decorators.py @@ -9,9 +9,9 @@ from django.db.transaction import on_commit from django.utils import timezone +from task_processor import task_registry from task_processor.exceptions import InvalidArgumentsError, TaskQueueFullError from task_processor.models import RecurringTask, Task, TaskPriority -from task_processor.task_registry import register_task from task_processor.task_run_method import TaskRunMethod P = typing.ParamSpec("P") @@ -50,7 +50,7 @@ def __init__( task_name = task_name or f.__name__ task_module = getmodule(f).__name__.rsplit(".")[-1] self.task_identifier = task_identifier = f"{task_module}.{task_name}" - register_task(task_identifier, f) + task_registry.register_task(task_identifier, f) def __call__(self, *args: P.args, **kwargs: P.kwargs) -> None: _validate_inputs(*args, **kwargs) @@ -168,31 +168,27 @@ def register_recurring_task( kwargs: dict[str, typing.Any] | None = None, first_run_time: time | None = None, timeout: timedelta | None = timedelta(minutes=30), -) -> typing.Callable[[typing.Callable[..., None]], RecurringTask]: +) -> typing.Callable[[typing.Callable[..., None]], None]: if not os.environ.get("RUN_BY_PROCESSOR"): # Do not register recurring tasks if not invoked by task processor return lambda f: f - def decorator(f: typing.Callable[..., None]) -> RecurringTask: + def decorator(f: typing.Callable[..., None]) -> None: nonlocal task_name task_name = task_name or f.__name__ task_module = getmodule(f).__name__.rsplit(".")[-1] task_identifier = f"{task_module}.{task_name}" - register_task(task_identifier, f) - - task, _ = RecurringTask.objects.update_or_create( - task_identifier=task_identifier, - defaults={ - "serialized_args": RecurringTask.serialize_data(args or ()), - "serialized_kwargs": RecurringTask.serialize_data(kwargs or {}), - "run_every": run_every, - "first_run_time": first_run_time, - "timeout": timeout, - }, - ) - return task + task_kwargs = { + "serialized_args": RecurringTask.serialize_data(args or ()), + "serialized_kwargs": RecurringTask.serialize_data(kwargs or {}), + "run_every": run_every, + "first_run_time": first_run_time, + "timeout": timeout, + } + + task_registry.register_recurring_task(task_identifier, f, **task_kwargs) return decorator diff --git a/task_processor/management/commands/runprocessor.py b/task_processor/management/commands/runprocessor.py index e82758e..15dc11c 100644 --- a/task_processor/management/commands/runprocessor.py +++ b/task_processor/management/commands/runprocessor.py @@ -7,7 +7,7 @@ from django.core.management import BaseCommand from django.utils import timezone -from task_processor.task_registry import registered_tasks +from task_processor.task_registry import initialise from task_processor.thread_monitoring import ( clear_unhealthy_threads, write_unhealthy_threads, @@ -74,10 +74,9 @@ def handle(self, *args, **options): ] ) - logger.info( - "Processor starting. Registered tasks are: %s", - list(registered_tasks.keys()), - ) + logger.info("Processor starting") + + initialise() for thread in self._threads: thread.start() diff --git a/task_processor/models.py b/task_processor/models.py index fc234a5..d7dd220 100644 --- a/task_processor/models.py +++ b/task_processor/models.py @@ -68,7 +68,8 @@ def run(self): @property def callable(self) -> typing.Callable: try: - return registered_tasks[self.task_identifier] + task = registered_tasks[self.task_identifier] + return task.task_function except KeyError as e: raise TaskProcessingError( "No task registered with identifier '%s'. Ensure your task is " diff --git a/task_processor/task_registry.py b/task_processor/task_registry.py index 50ae62e..c3f8b7d 100644 --- a/task_processor/task_registry.py +++ b/task_processor/task_registry.py @@ -1,23 +1,75 @@ +import enum import logging import typing +from dataclasses import dataclass logger = logging.getLogger(__name__) -registered_tasks: typing.Dict[str, typing.Callable] = {} +class TaskType(enum.Enum): + STANDARD = "STANDARD" + RECURRING = "RECURRING" -def register_task(task_identifier: str, callable_: typing.Callable): + +@dataclass +class RegisteredTask: + task_identifier: str + task_function: typing.Callable + task_type: TaskType = TaskType.STANDARD + task_kwargs: typing.Dict[str, typing.Any] = None + + +registered_tasks: typing.Dict[str, RegisteredTask] = {} + + +def initialise() -> None: global registered_tasks - logger.debug("Registering task '%s'", task_identifier) + from task_processor.models import RecurringTask + + for task_identifier, registered_task in registered_tasks.items(): + logger.debug("Initialising task '%s'", task_identifier) + + if registered_task.task_type == TaskType.RECURRING: + logger.debug("Persisting recurring task '%s'", task_identifier) + RecurringTask.objects.update_or_create( + task_identifier=task_identifier, + defaults=registered_task.task_kwargs, + ) - registered_tasks[task_identifier] = callable_ + +def get_task(task_identifier: str) -> RegisteredTask: + global registered_tasks + + return registered_tasks[task_identifier] + + +def register_task(task_identifier: str, callable_: typing.Callable) -> None: + global registered_tasks + + registered_task = RegisteredTask( + task_identifier=task_identifier, + task_function=callable_, + ) + registered_tasks[task_identifier] = registered_task + + +def register_recurring_task( + task_identifier: str, callable_: typing.Callable, **task_kwargs +) -> None: + global registered_tasks + + logger.debug("Registering recurring task '%s'", task_identifier) + + registered_task = RegisteredTask( + task_identifier=task_identifier, + task_function=callable_, + task_type=TaskType.RECURRING, + task_kwargs=task_kwargs, + ) + registered_tasks[task_identifier] = registered_task logger.debug( "Registered tasks now has the following tasks registered: %s", list(registered_tasks.keys()), ) - - -def get_task(task_identifier: str) -> typing.Callable: - return registered_tasks[task_identifier] diff --git a/tests/unit/task_processor/conftest.py b/tests/unit/task_processor/conftest.py index 6f38020..59c2823 100644 --- a/tests/unit/task_processor/conftest.py +++ b/tests/unit/task_processor/conftest.py @@ -3,6 +3,8 @@ import pytest +from task_processor.task_registry import RegisteredTask + @pytest.fixture def run_by_processor(monkeypatch): @@ -33,3 +35,11 @@ def _inner(log_level: str | int = logging.INFO) -> pytest.LogCaptureFixture: return caplog return _inner + + +@pytest.fixture(autouse=True) +def task_registry() -> typing.Generator[dict[str, RegisteredTask], None, None]: + from task_processor.task_registry import registered_tasks + + registered_tasks.clear() + yield registered_tasks diff --git a/tests/unit/task_processor/test_unit_task_processor_decorators.py b/tests/unit/task_processor/test_unit_task_processor_decorators.py index c655322..7d0ef8b 100644 --- a/tests/unit/task_processor/test_unit_task_processor_decorators.py +++ b/tests/unit/task_processor/test_unit_task_processor_decorators.py @@ -14,7 +14,7 @@ ) from task_processor.exceptions import InvalidArgumentsError from task_processor.models import RecurringTask, Task, TaskPriority -from task_processor.task_registry import get_task +from task_processor.task_registry import get_task, initialise from task_processor.task_run_method import TaskRunMethod if typing.TYPE_CHECKING: @@ -113,6 +113,8 @@ def a_function(first_arg, second_arg): return first_arg + second_arg # Then + initialise() + task = RecurringTask.objects.get(task_identifier=task_identifier) assert task.serialized_kwargs == json.dumps(task_kwargs) assert task.run_every == run_every diff --git a/tests/unit/task_processor/test_unit_task_processor_models.py b/tests/unit/task_processor/test_unit_task_processor_models.py index 4b4b020..268924a 100644 --- a/tests/unit/task_processor/test_unit_task_processor_models.py +++ b/tests/unit/task_processor/test_unit_task_processor_models.py @@ -6,23 +6,25 @@ from task_processor.decorators import register_task_handler from task_processor.models import RecurringTask, Task +from task_processor.task_registry import initialise now = timezone.now() one_hour_ago = now - timedelta(hours=1) one_hour_from_now = now + timedelta(hours=1) -@register_task_handler() -def my_callable(arg_one: str, arg_two: str = None): - """Example callable to use for tasks (needs to be global for registering to work)""" - return arg_one, arg_two - - def test_task_run(): # Given + @register_task_handler() + def my_callable(arg_one: str, arg_two: str = None): + """Example callable to use for tasks (needs to be global for registering to work)""" + return arg_one, arg_two + args = ["foo"] kwargs = {"arg_two": "bar"} + initialise() + task = Task.create( my_callable.task_identifier, scheduled_for=timezone.now(), diff --git a/tests/unit/task_processor/test_unit_task_processor_processor.py b/tests/unit/task_processor/test_unit_task_processor_processor.py index 3fbcb0e..c89a0b6 100644 --- a/tests/unit/task_processor/test_unit_task_processor_processor.py +++ b/tests/unit/task_processor/test_unit_task_processor_processor.py @@ -10,8 +10,10 @@ from django.utils import timezone from freezegun import freeze_time from pytest import MonkeyPatch +from pytest_mock import MockerFixture from task_processor.decorators import ( + TaskHandler, register_recurring_task, register_task_handler, ) @@ -28,7 +30,7 @@ run_recurring_tasks, run_tasks, ) -from task_processor.task_registry import registered_tasks +from task_processor.task_registry import initialise, registered_tasks if typing.TYPE_CHECKING: # This import breaks private-package-test workflow in core @@ -45,10 +47,40 @@ def reset_cache(): cache.clear() -def test_run_task_runs_task_and_creates_task_run_object_when_success(db): +@pytest.fixture +def dummy_task(db: None) -> TaskHandler: + @register_task_handler() + def _dummy_task(key: str = DEFAULT_CACHE_KEY, value: str = DEFAULT_CACHE_VALUE): + """function used to test that task is being run successfully""" + cache.set(key, value) + + return _dummy_task + + +@pytest.fixture +def raise_exception_task(db: None) -> TaskHandler: + @register_task_handler() + def _raise_exception_task(msg: str): + raise Exception(msg) + + return _raise_exception_task + + +@pytest.fixture +def sleep_task(db: None) -> TaskHandler: + @register_task_handler() + def _sleep_task(seconds: int): + time.sleep(seconds) + + return _sleep_task + + +def test_run_task_runs_task_and_creates_task_run_object_when_success( + dummy_task: TaskHandler, +): # Given task = Task.create( - _dummy_task.task_identifier, + dummy_task.task_identifier, scheduled_for=timezone.now(), ) task.save() @@ -71,13 +103,13 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(db): def test_run_task_kills_task_after_timeout( - db: None, + sleep_task: TaskHandler, get_task_processor_caplog: "GetTaskProcessorCaplog", ) -> None: # Given caplog = get_task_processor_caplog(logging.ERROR) task = Task.create( - _sleep.task_identifier, + sleep_task.task_identifier, scheduled_for=timezone.now(), args=(1,), timeout=timedelta(microseconds=1), @@ -122,8 +154,10 @@ def test_run_recurring_task_kills_task_after_timeout( def _dummy_recurring_task(): time.sleep(1) + initialise() + task = RecurringTask.objects.get( - task_identifier=_dummy_recurring_task.task_identifier + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", ) # When task_runs = run_recurring_tasks() @@ -158,8 +192,10 @@ def test_run_recurring_tasks_runs_task_and_creates_recurring_task_run_object_whe def _dummy_recurring_task(): cache.set(DEFAULT_CACHE_KEY, DEFAULT_CACHE_VALUE) + initialise() + task = RecurringTask.objects.get( - task_identifier=_dummy_recurring_task.task_identifier + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", ) # When task_runs = run_recurring_tasks() @@ -186,8 +222,10 @@ def test_run_recurring_tasks_runs_locked_task_after_tiemout( def _dummy_recurring_task(): cache.set(DEFAULT_CACHE_KEY, DEFAULT_CACHE_VALUE) + initialise() + task = RecurringTask.objects.get( - task_identifier=_dummy_recurring_task.task_identifier + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", ) task.is_locked = True task.locked_at = timezone.now() - timedelta(hours=1) @@ -221,8 +259,10 @@ def _dummy_recurring_task(): val = cache.get(DEFAULT_CACHE_KEY, 0) + 1 cache.set(DEFAULT_CACHE_KEY, val) + initialise() + task = RecurringTask.objects.get( - task_identifier=_dummy_recurring_task.task_identifier + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", ) # When @@ -265,8 +305,10 @@ def _dummy_recurring_task(): val = cache.get(DEFAULT_CACHE_KEY, 0) + 1 cache.set(DEFAULT_CACHE_KEY, val) + initialise() + task = RecurringTask.objects.get( - task_identifier=_dummy_recurring_task.task_identifier + task_identifier="test_unit_task_processor_processor._dummy_recurring_task", ) # When - we call run_recurring_tasks twice @@ -293,7 +335,11 @@ def test_run_recurring_tasks_does_nothing_if_unregistered_task_is_new( def _a_task(): pass + initialise() + # now - remove the task from the registry + from task_processor.task_registry import registered_tasks + registered_tasks.pop(task_identifier) # When @@ -305,7 +351,9 @@ def _a_task(): def test_run_recurring_tasks_deletes_the_task_if_unregistered_task_is_old( - db: None, run_by_processor: None, caplog: pytest.LogCaptureFixture + db: None, + run_by_processor: None, + mocker: MockerFixture, ) -> None: # Given task_processor_logger = logging.getLogger("task_processor") @@ -319,6 +367,8 @@ def test_run_recurring_tasks_deletes_the_task_if_unregistered_task_is_old( def _a_task(): pass + initialise() + # now - remove the task from the registry registered_tasks.pop(task_identifier) @@ -333,7 +383,7 @@ def _a_task(): def test_run_task_runs_task_and_creates_task_run_object_when_failure( - db: None, + raise_exception_task: TaskHandler, get_task_processor_caplog: "GetTaskProcessorCaplog", ) -> None: # Given @@ -341,7 +391,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure( msg = "Error!" task = Task.create( - _raise_exception.task_identifier, args=(msg,), scheduled_for=timezone.now() + raise_exception_task.task_identifier, args=(msg,), scheduled_for=timezone.now() ) task.save() @@ -377,9 +427,11 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure( ] -def test_run_task_runs_failed_task_again(db): +def test_run_task_runs_failed_task_again(raise_exception_task: TaskHandler): # Given - task = Task.create(_raise_exception.task_identifier, scheduled_for=timezone.now()) + task = Task.create( + raise_exception_task.task_identifier, scheduled_for=timezone.now() + ) task.save() # When @@ -415,6 +467,8 @@ def test_run_recurring_task_runs_task_and_creates_recurring_task_run_object_when def _raise_exception(organisation_name): raise RuntimeError("test exception") + initialise() + task = RecurringTask.objects.get(task_identifier=task_identifier) # When @@ -439,11 +493,11 @@ def test_run_task_does_nothing_if_no_tasks(db): @pytest.mark.django_db(transaction=True) -def test_run_task_runs_tasks_in_correct_priority(): +def test_run_task_runs_tasks_in_correct_priority(dummy_task: TaskHandler): # Given # 2 tasks task_1 = Task.create( - _dummy_task.task_identifier, + dummy_task.task_identifier, scheduled_for=timezone.now(), args=("task 1 organisation",), priority=TaskPriority.HIGH, @@ -451,7 +505,7 @@ def test_run_task_runs_tasks_in_correct_priority(): task_1.save() task_2 = Task.create( - _dummy_task.task_identifier, + dummy_task.task_identifier, scheduled_for=timezone.now(), args=("task 2 organisation",), priority=TaskPriority.HIGH, @@ -459,7 +513,7 @@ def test_run_task_runs_tasks_in_correct_priority(): task_2.save() task_3 = Task.create( - _dummy_task.task_identifier, + dummy_task.task_identifier, scheduled_for=timezone.now(), args=("task 3 organisation",), priority=TaskPriority.HIGHEST, @@ -478,7 +532,10 @@ def test_run_task_runs_tasks_in_correct_priority(): @pytest.mark.django_db(transaction=True) -def test_run_tasks_skips_locked_tasks(): +def test_run_tasks_skips_locked_tasks( + dummy_task: TaskHandler, + sleep_task: TaskHandler, +): """ This test verifies that tasks are locked while being executed, and hence new task runners are not able to pick up 'in progress' tasks. @@ -488,13 +545,13 @@ def test_run_tasks_skips_locked_tasks(): # One which is configured to just sleep for 3 seconds, to simulate a task # being held for a short period of time task_1 = Task.create( - _sleep.task_identifier, scheduled_for=timezone.now(), args=(3,) + sleep_task.task_identifier, scheduled_for=timezone.now(), args=(3,) ) task_1.save() # and another which should create an organisation task_2 = Task.create( - _dummy_task.task_identifier, + dummy_task.task_identifier, scheduled_for=timezone.now(), args=("task 2 organisation",), ) @@ -516,7 +573,7 @@ def test_run_tasks_skips_locked_tasks(): task_runner_thread.join() -def test_run_more_than_one_task(db): +def test_run_more_than_one_task(dummy_task: TaskHandler): # Given num_tasks = 5 @@ -525,7 +582,7 @@ def test_run_more_than_one_task(db): organisation_name = f"test-org-{uuid.uuid4()}" tasks.append( Task.create( - _dummy_task.task_identifier, + dummy_task.task_identifier, scheduled_for=timezone.now(), args=(organisation_name,), ) @@ -557,6 +614,8 @@ def test_recurring_tasks_are_unlocked_if_picked_up_but_not_executed( def my_task(): pass + initialise() + recurring_task = RecurringTask.objects.get( task_identifier="test_unit_task_processor_processor.my_task" ) @@ -578,19 +637,3 @@ def my_task(): # Then recurring_task.refresh_from_db() assert recurring_task.is_locked is False - - -@register_task_handler() -def _dummy_task(key: str = DEFAULT_CACHE_KEY, value: str = DEFAULT_CACHE_VALUE): - """function used to test that task is being run successfully""" - cache.set(key, value) - - -@register_task_handler() -def _raise_exception(msg: str): - raise Exception(msg) - - -@register_task_handler() -def _sleep(seconds: int): - time.sleep(seconds)