 """
 Try to acquire a lock on the MPI resource.

-Due to pour non async implementation aioredlock will be used
+Due to our non-async implementation, aioredlock will be used.
+All configuration is specified upfront.

-How it works:

-- Try to acquire a lock the lock in a tight loop for about X seconds.
-- If it works start a task which updates the expiration every X second is spawned.
-- Ensures sleeper can be started as MPI sleeper again.
 """
 import asyncio
-import datetime
 import logging
-from threading import Thread
-from typing import Any, Callable, Optional, Tuple
+import multiprocessing
+import os

-from aioredlock import Aioredlock, Lock, LockError
+import aioredis
+import tenacity
+from aioredlock import Aioredlock, LockError
+from pydantic.networks import RedisDsn

 from . import config

-logger = logging.getLogger(__name__)
-
+# ptvsd causes issues with multiprocessing
+# SEE: https://github.com/microsoft/ptvsd/issues/1443
+if os.environ.get("SC_BOOT_MODE") == "debug-ptvsd":  # pragma: no cover
+    multiprocessing.set_start_method("spawn", True)

-async def retry_for_result(
-    result_validator: Callable[[Any], Any], coroutine_factory: Callable
-) -> Tuple[bool, Any]:
-    """
-    Will execute the given callback until the expected result is reached.
-    Between each retry it will wait 1/5 of REDLOCK_REFRESH_INTERVAL_SECONDS
-    """
-    sleep_interval = config.REDLOCK_REFRESH_INTERVAL_SECONDS / 5.0
-    elapsed = 0.0
-    start = datetime.datetime.utcnow()
-
-    while elapsed < config.REDLOCK_REFRESH_INTERVAL_SECONDS:
-        result = await coroutine_factory()
-        if result_validator(result):
-            return True, result
-        await asyncio.sleep(sleep_interval)
-        elapsed = (datetime.datetime.utcnow() - start).total_seconds()
-
-    return False, None
+logger = logging.getLogger(__name__)


-def start_background_lock_extender(
-    lock_manager: Aioredlock, lock: Lock, loop: asyncio.BaseEventLoop
+async def _wrapped_acquire_and_extend_lock_worker(
+    reply_queue: multiprocessing.Queue, cpu_count: int
 ) -> None:
-    """Will periodically extend the duration of the lock"""
-
-    async def extender_worker(lock_manager: Aioredlock):
-        sleep_interval = 0.9 * config.REDLOCK_REFRESH_INTERVAL_SECONDS
-        while True:
-            await lock_manager.extend(lock, config.REDLOCK_REFRESH_INTERVAL_SECONDS)
-
-            await asyncio.sleep(sleep_interval)
-
-    loop.run_until_complete(extender_worker(lock_manager))
-
-
-def thread_worker(
-    lock_manager: Aioredlock, lock: Lock, loop: asyncio.BaseEventLoop
+    try:
+        # if the lock is acquired, _acquire_and_extend_lock_forever blocks here
+        await _acquire_and_extend_lock_forever(reply_queue, cpu_count)
+    finally:
+        # if _acquire_and_extend_lock_forever returns, the lock was not acquired;
+        # make sure acquire_mpi_lock always gets a result so it does not block forever
+        reply_queue.put(False)
+
+
+@tenacity.retry(
+    wait=tenacity.wait_fixed(5),
+    stop=tenacity.stop_after_attempt(60),
+    before_sleep=tenacity.before_sleep_log(logger, logging.INFO),
+    reraise=True,
+)
+async def wait_till_redis_responsive(dsn: RedisDsn) -> None:
+    logger.info("Trying to connect to %s", dsn)
+    client = await aioredis.create_redis_pool(dsn, encoding="utf-8")
+    client.close()
+    await client.wait_closed()
+
+
+# NOTE: traps LockError internally
+async def _acquire_and_extend_lock_forever(
+    reply_queue: multiprocessing.Queue, cpu_count: int
 ) -> None:
-    start_background_lock_extender(lock_manager, lock, loop)
+    await wait_till_redis_responsive(config.CELERY_CONFIG.redis.dsn)

+    resource_name = f"aioredlock:mpi_lock:{cpu_count}"
+    endpoint = [
+        {
+            "host": config.CELERY_CONFIG.redis.host,
+            "port": config.CELERY_CONFIG.redis.port,
+            "db": int(config.CELERY_CONFIG.redis.db),
+        }
+    ]
+
+    logger.info("Will try to acquire an mpi_lock on %s", resource_name)
+    logger.info("Connecting to %s", endpoint)
+    lock_manager = Aioredlock(
+        redis_connections=endpoint,
+        retry_count=10,
+        internal_lock_timeout=config.REDLOCK_REFRESH_INTERVAL_SECONDS,
+    )

-async def try_to_acquire_lock(
-    lock_manager: Aioredlock, resource_name: str
-) -> Optional[Tuple[bool, Lock]]:
-    # Try to acquire the lock:
+    # Try to acquire the lock: aioredlock retries internally (retry_count times)
+    # with a short wait between attempts; if the lock cannot be acquired
+    # a LockError is raised
     try:
-        return await lock_manager.lock(
-            resource_name, lock_timeout=config.REDLOCK_REFRESH_INTERVAL_SECONDS
-        )
+        lock = await lock_manager.lock(resource_name)
     except LockError:
-        pass
-
-    return None
+        logger.warning("Could not acquire lock on resource %s", resource_name)
+        await lock_manager.destroy()
+        return

+    # NOTE: in high-concurrency situations multiple instances can end up
+    # acquiring the same lock; wait a tiny amount and read back the result
+    # of the lock acquisition
+    await asyncio.sleep(0.1)
+    # read back the result to make sure it was locked
+    is_locked = await lock_manager.is_locked(resource_name)

-async def acquire_lock(cpu_count: int) -> bool:
-    resource_name = f"aioredlock:mpi_lock:{cpu_count}"
-    lock_manager = Aioredlock([config.CELERY_CONFIG.redis.dsn])
-    logger.info("Will try to acquire an mpi_lock")
+    # the lock was successfully acquired, put the result in the queue
+    reply_queue.put(is_locked)

-    def is_locked_factory():
-        return lock_manager.is_locked(resource_name)
+    # continue renewing the lock at regular intervals
+    sleep_interval = 0.5 * config.REDLOCK_REFRESH_INTERVAL_SECONDS
+    logger.info("Starting lock extension at %s seconds interval", sleep_interval)

-    is_lock_free, _ = await retry_for_result(
-        result_validator=lambda x: x is False,
-        coroutine_factory=is_locked_factory,
-    )
+    try:
+        while True:
+            try:
+                await lock_manager.extend(lock)
+            except LockError:
+                logger.warning(
+                    "There was an error trying to extend the lock %s", resource_name
+                )

-    if not is_lock_free:
-        # it was not possible to acquire the lock
-        return False
+            await asyncio.sleep(sleep_interval)
+    finally:
+        # in case some other error occurs, recycle all connections to redis
+        await lock_manager.destroy()

-    def try_to_acquire_lock_factory():
-        return try_to_acquire_lock(lock_manager, resource_name)

-    # lock is free try to acquire and start background extention
-    managed_to_acquire_lock, lock = await retry_for_result(
-        result_validator=lambda x: type(x) == Lock,
-        coroutine_factory=try_to_acquire_lock_factory,
+def _process_worker(queue: multiprocessing.Queue, cpu_count: int) -> None:
+    logger.info("Starting background process for mpi lock result")
+    asyncio.get_event_loop().run_until_complete(
+        _wrapped_acquire_and_extend_lock_worker(queue, cpu_count)
     )
-
-    if managed_to_acquire_lock:
-        Thread(
-            target=thread_worker,
-            args=(
-                lock_manager,
-                lock,
-                asyncio.get_event_loop(),
-            ),
-            daemon=True,
-        ).start()
-
-    logger.info("mpi_lock acquisition result %s", managed_to_acquire_lock)
-    return managed_to_acquire_lock
+    logger.info("Background asyncio task finished. Background process will despawn.")


 def acquire_mpi_lock(cpu_count: int) -> bool:
     """
     returns True if successful
     Will try to acquire a distributed shared lock.
-    This operation will last up to 2 x config.REDLOCK_REFRESH_INTERVAL_SECONDS
     """
-    from .utils import wrap_async_call
+    reply_queue = multiprocessing.Queue()
+    multiprocessing.Process(
+        target=_process_worker, args=(reply_queue, cpu_count), daemon=True
+    ).start()

-    was_acquired = wrap_async_call(acquire_lock(cpu_count))
-    return was_acquired
+    lock_acquired = reply_queue.get()
+    return lock_acquired
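For context, a minimal caller sketch (not part of this commit) showing how the new acquire_mpi_lock entry point might be used at boot time; the import path mpi_lock and the printed messages are assumptions for illustration only.

# Hypothetical usage sketch, assuming the module above is importable as `mpi_lock`.
import multiprocessing

import mpi_lock  # assumed import path, adjust to the actual package layout


def boot_as_mpi_node() -> bool:
    """Return True when this node won the distributed MPI lock."""
    cpu_count = multiprocessing.cpu_count()
    # acquire_mpi_lock blocks until the background process reports
    # whether the redlock for this cpu_count was acquired
    return mpi_lock.acquire_mpi_lock(cpu_count)


if __name__ == "__main__":
    if boot_as_mpi_node():
        print("MPI lock acquired: starting as the MPI-capable node")
    else:
        print("MPI lock not acquired: starting as a regular node")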