
Commit 2e73148

fix redis connection leaks + exclusions error: (fixes #1065) (#1066)
* fix redis connection leaks + exclusions error: (fixes #1065)
- use contextmanager for accessing redis to ensure redis.close() is always called
- add get_redis_client() to k8sapi to ensure a unified place to get a redis client
- use ConnectionPool.from_url() until redis 5.0.0 is released to ensure auto-close and single-client settings are applied
- also: catch invalid regex passed to re.compile() in the queue regex check, returning 400 instead of 500 for an invalid regex
- redis requirements: bump to 5.0.0rc2
1 parent 8998354 commit 2e73148

File tree: 5 files changed (+132 -96 lines)

backend/btrixcloud/basecrawls.py (+11 -6)

@@ -6,10 +6,11 @@
 from datetime import timedelta
 from typing import Optional, List, Union
 import urllib.parse
+import contextlib

 from pydantic import UUID4
 from fastapi import HTTPException, Depends
-from redis import asyncio as aioredis, exceptions
+from redis import exceptions

 from .models import (
     CrawlFile,
@@ -216,8 +217,8 @@ async def _resolve_crawl_refs(
         # more responsive, saves db update in operator
         if crawl.state in RUNNING_STATES:
             try:
-                redis = await self.get_redis(crawl.id)
-                crawl.stats = await get_redis_crawl_stats(redis, crawl.id)
+                async with self.get_redis(crawl.id) as redis:
+                    crawl.stats = await get_redis_crawl_stats(redis, crawl.id)
             # redis not available, ignore
             except exceptions.ConnectionError:
                 pass
@@ -281,13 +282,17 @@ async def _update_presigned(self, updates):
         for update in updates:
             await self.crawls.find_one_and_update(*update)

+    @contextlib.asynccontextmanager
     async def get_redis(self, crawl_id):
         """get redis url for crawl id"""
         redis_url = self.crawl_manager.get_redis_url(crawl_id)

-        return await aioredis.from_url(
-            redis_url, encoding="utf-8", decode_responses=True
-        )
+        redis = await self.crawl_manager.get_redis_client(redis_url)
+
+        try:
+            yield redis
+        finally:
+            await redis.close()

     async def add_to_collection(
         self, crawl_ids: List[uuid.UUID], collection_id: uuid.UUID, org: Organization
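
For reference, a minimal standalone sketch of the pattern the diff above adopts: an asynccontextmanager that yields a short-lived redis client and guarantees close() is called even when the body raises. The one_shot_redis() helper and the example key are illustrative, not part of this codebase.

import contextlib

from redis import asyncio as aioredis


@contextlib.asynccontextmanager
async def one_shot_redis(redis_url):
    """yield a redis client and always close it when the block exits"""
    redis = aioredis.from_url(redis_url, decode_responses=True)
    try:
        yield redis
    finally:
        # runs on normal exit, on exceptions, and on task cancellation
        await redis.close()


async def queue_length(redis_url, crawl_id):
    # the connection is released here no matter how the block exits
    async with one_shot_redis(redis_url) as redis:
        return await redis.llen(f"{crawl_id}:q")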

backend/btrixcloud/crawls.py (+85 -76)

@@ -363,106 +363,115 @@ async def get_crawl_queue(self, crawl_id, offset, count, regex):
         total = 0
         results = []
-        redis = None

         try:
-            redis = await self.get_redis(crawl_id)
+            async with self.get_redis(crawl_id) as redis:
+                total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
+                results = await self._crawl_queue_range(
+                    redis, f"{crawl_id}:q", offset, count
+                )
+                results = [json.loads(result)["url"] for result in results]

-            total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
-            results = await self._crawl_queue_range(
-                redis, f"{crawl_id}:q", offset, count
-            )
-            results = [json.loads(result)["url"] for result in results]
         except exceptions.ConnectionError:
             # can't connect to redis, likely not initialized yet
             pass

         matched = []
         if regex:
-            regex = re.compile(regex)
+            try:
+                regex = re.compile(regex)
+            except re.error as exc:
+                raise HTTPException(status_code=400, detail="invalid_regex") from exc
+
             matched = [result for result in results if regex.search(result)]

         return {"total": total, "results": results, "matched": matched}

     async def match_crawl_queue(self, crawl_id, regex):
         """get list of urls that match regex"""
         total = 0
-        redis = None
-
-        try:
-            redis = await self.get_redis(crawl_id)
-            total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
-        except exceptions.ConnectionError:
-            # can't connect to redis, likely not initialized yet
-            pass
-
-        regex = re.compile(regex)
         matched = []
         step = 50

-        for count in range(0, total, step):
-            results = await self._crawl_queue_range(redis, f"{crawl_id}:q", count, step)
-            for result in results:
-                url = json.loads(result)["url"]
-                if regex.search(url):
-                    matched.append(url)
+        async with self.get_redis(crawl_id) as redis:
+            try:
+                total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
+            except exceptions.ConnectionError:
+                # can't connect to redis, likely not initialized yet
+                pass
+
+            try:
+                regex = re.compile(regex)
+            except re.error as exc:
+                raise HTTPException(status_code=400, detail="invalid_regex") from exc
+
+            for count in range(0, total, step):
+                results = await self._crawl_queue_range(
+                    redis, f"{crawl_id}:q", count, step
+                )
+                for result in results:
+                    url = json.loads(result)["url"]
+                    if regex.search(url):
+                        matched.append(url)

         return {"total": total, "matched": matched}

     async def filter_crawl_queue(self, crawl_id, regex):
         """filter out urls that match regex"""
         # pylint: disable=too-many-locals
         total = 0
-        redis = None
-
         q_key = f"{crawl_id}:q"
         s_key = f"{crawl_id}:s"
-
-        try:
-            redis = await self.get_redis(crawl_id)
-            total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
-        except exceptions.ConnectionError:
-            # can't connect to redis, likely not initialized yet
-            pass
-
-        dircount = -1
-        regex = re.compile(regex)
         step = 50
-
-        count = 0
         num_removed = 0

-        # pylint: disable=fixme
-        # todo: do this in a more efficient way?
-        # currently quite inefficient as redis does not have a way
-        # to atomically check and remove value from list
-        # so removing each jsob block by value
-        while count < total:
-            if dircount == -1 and count > total / 2:
-                dircount = 1
-            results = await self._crawl_queue_range(redis, q_key, count, step)
-            count += step
-
-            qrems = []
-            srems = []
-
-            for result in results:
-                url = json.loads(result)["url"]
-                if regex.search(url):
-                    srems.append(url)
-                    # await redis.srem(s_key, url)
-                    # res = await self._crawl_queue_rem(redis, q_key, result, dircount)
-                    qrems.append(result)
-
-            if not srems:
-                continue
-
-            await redis.srem(s_key, *srems)
-            res = await self._crawl_queue_rem(redis, q_key, qrems, dircount)
-            if res:
-                count -= res
-                num_removed += res
-                print(f"Removed {res} from queue", flush=True)
+        async with self.get_redis(crawl_id) as redis:
+            try:
+                total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
+            except exceptions.ConnectionError:
+                # can't connect to redis, likely not initialized yet
+                pass
+
+            dircount = -1
+
+            try:
+                regex = re.compile(regex)
+            except re.error as exc:
+                raise HTTPException(status_code=400, detail="invalid_regex") from exc
+
+            count = 0
+
+            # pylint: disable=fixme
+            # todo: do this in a more efficient way?
+            # currently quite inefficient as redis does not have a way
+            # to atomically check and remove value from list
+            # so removing each jsob block by value
+            while count < total:
+                if dircount == -1 and count > total / 2:
+                    dircount = 1
+                results = await self._crawl_queue_range(redis, q_key, count, step)
+                count += step
+
+                qrems = []
+                srems = []
+
+                for result in results:
+                    url = json.loads(result)["url"]
+                    if regex.search(url):
+                        srems.append(url)
+                        # await redis.srem(s_key, url)
+                        # res = await self._crawl_queue_rem(redis, q_key, result, dircount)
+                        qrems.append(result)
+
+                if not srems:
+                    continue
+
+                await redis.srem(s_key, *srems)
+                res = await self._crawl_queue_rem(redis, q_key, qrems, dircount)
+                if res:
+                    count -= res
+                    num_removed += res
+                    print(f"Removed {res} from queue", flush=True)

         return num_removed

@@ -475,13 +484,13 @@ async def get_errors_from_redis(
         skip = page * page_size
         upper_bound = skip + page_size - 1

-        try:
-            redis = await self.get_redis(crawl_id)
-            errors = await redis.lrange(f"{crawl_id}:e", skip, upper_bound)
-            total = await redis.llen(f"{crawl_id}:e")
-        except exceptions.ConnectionError:
-            # pylint: disable=raise-missing-from
-            raise HTTPException(status_code=503, detail="redis_connection_error")
+        async with self.get_redis(crawl_id) as redis:
+            try:
+                errors = await redis.lrange(f"{crawl_id}:e", skip, upper_bound)
+                total = await redis.llen(f"{crawl_id}:e")
+            except exceptions.ConnectionError:
+                # pylint: disable=raise-missing-from
+                raise HTTPException(status_code=503, detail="redis_connection_error")

         parsed_errors = parse_jsonl_error_messages(errors)
         return parsed_errors, total
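
The queue endpoints above also pick up the exclusions-error fix: a user-supplied pattern is compiled inside try/except re.error and mapped to an HTTP 400 rather than surfacing as a 500. A self-contained sketch of that validation step, using a hypothetical endpoint and placeholder data:

import re

from fastapi import FastAPI, HTTPException

app = FastAPI()


@app.get("/queue/match")
async def match_queue(regex: str):
    # a malformed pattern becomes a 400 response instead of an unhandled 500
    try:
        compiled = re.compile(regex)
    except re.error as exc:
        raise HTTPException(status_code=400, detail="invalid_regex") from exc

    # placeholder data standing in for the crawl queue
    urls = ["https://example.com/a", "https://example.com/b"]
    return {"matched": [url for url in urls if compiled.search(url)]}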

backend/btrixcloud/k8sapi.py (+14)

@@ -13,6 +13,9 @@
 from kubernetes_asyncio.utils import create_from_dict
 from kubernetes_asyncio.client.exceptions import ApiException

+from redis.asyncio import Redis
+from redis.asyncio.connection import ConnectionPool
+
 from fastapi.templating import Jinja2Templates
 from .utils import get_templates_dir, dt_now, to_k8s_date

@@ -62,6 +65,17 @@ def get_redis_url(self, crawl_id):
         )
         return redis_url

+    async def get_redis_client(self, redis_url):
+        """return redis client with correct params for one-time use"""
+        # manual settings until redis 5.0.0 is released
+        pool = ConnectionPool.from_url(redis_url, decode_responses=True)
+        redis = Redis(
+            connection_pool=pool,
+            decode_responses=True,
+        )
+        redis.auto_close_connection_pool = True
+        return redis
+
     # pylint: disable=too-many-arguments
     async def new_crawl_job(
         self, cid, userid, oid, scale=1, crawl_timeout=0, manual=True
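
A sketch of how a caller might consume get_redis_client() with an explicit close; the crawl_manager parameter and the queue key are illustrative placeholders. Building the client from ConnectionPool.from_url() and setting auto_close_connection_pool means closing the client also disposes its pool, the manual equivalent of what the commit message says redis 5.0.0 will handle.

# sketch of consuming get_redis_client(); crawl_manager and the key
# name are illustrative placeholders, not part of this codebase
async def read_queue_length(crawl_manager, redis_url, crawl_id):
    redis = await crawl_manager.get_redis_client(redis_url)
    try:
        # each call gets its own short-lived client and connection pool
        return await redis.llen(f"{crawl_id}:q")
    finally:
        # closing the client also disposes the pool because
        # auto_close_connection_pool was set to True
        await redis.close()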

backend/btrixcloud/operator.py (+21 -13)

@@ -13,7 +13,6 @@
 import humanize

 from pydantic import BaseModel
-from redis import asyncio as aioredis

 from .utils import (
     from_k8s_date,
@@ -430,6 +429,7 @@ async def delete_pvc(self, crawl_id):
     async def cancel_crawl(self, redis_url, crawl_id, cid, status, state):
         """immediately cancel crawl with specified state
         return true if db mark_finished update succeeds"""
+        redis = None
         try:
             redis = await self._get_redis(redis_url)
             await self.mark_finished(redis, crawl_id, uuid.UUID(cid), status, state)
@@ -438,6 +438,10 @@ async def cancel_crawl(self, redis_url, crawl_id, cid, status, state):
         except:
             return False

+        finally:
+            if redis:
+                await redis.close()
+
     def _done_response(self, status, finalized=False):
         """done response for removing crawl"""
         return {
@@ -462,15 +466,16 @@ async def _get_redis(self, redis_url):
         """init redis, ensure connectivity"""
         redis = None
         try:
-            redis = await aioredis.from_url(
-                redis_url, encoding="utf-8", decode_responses=True
-            )
+            redis = await self.get_redis_client(redis_url)
             # test connection
             await redis.ping()
             return redis

        # pylint: disable=bare-except
         except:
+            if redis:
+                await redis.close()
+
             return None

     async def check_if_finished(self, crawl, status):
@@ -512,16 +517,16 @@ async def sync_crawl_state(self, redis_url, crawl, status, pods):
             status.resync_after = self.fast_retry_secs
             return status

-        # set state to running (if not already)
-        if status.state not in RUNNING_STATES:
-            await self.set_state(
-                "running",
-                status,
-                crawl.id,
-                allowed_from=["starting", "waiting_capacity"],
-            )
-
         try:
+            # set state to running (if not already)
+            if status.state not in RUNNING_STATES:
+                await self.set_state(
+                    "running",
+                    status,
+                    crawl.id,
+                    allowed_from=["starting", "waiting_capacity"],
+                )
+
             file_done = await redis.lpop(self.done_key)

             while file_done:
@@ -547,6 +552,9 @@ async def sync_crawl_state(self, redis_url, crawl, status, pods):
             print(f"Crawl get failed: {exc}, will try again")
             return status

+        finally:
+            await redis.close()
+
     def check_if_pods_running(self, pods):
         """check if at least one crawler pod has started"""
         try:
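
The operator changes apply the same idea to handlers that manage their own client: all redis work, including the set_state call, moves inside the try block so the new finally clause can always close the connection. A condensed sketch of that shape, with hypothetical get_client() and do_work() stand-ins:

async def handle(redis_url, get_client, do_work):
    # get_client() and do_work() are hypothetical stand-ins for
    # _get_redis() and the body of sync_crawl_state()
    redis = await get_client(redis_url)
    if not redis:
        return None

    try:
        # all redis work happens inside the try block
        return await do_work(redis)
    # pylint: disable=broad-except
    except Exception as exc:
        print(f"handler failed: {exc}, will try again", flush=True)
        return None
    finally:
        # close runs on success, failure, and early return
        await redis.close()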

backend/requirements.txt (+1 -1)

@@ -5,7 +5,7 @@ loguru
 aiofiles
 kubernetes-asyncio==22.6.5
 aiobotocore
-redis>=4.2.0rc1
+redis>=5.0.0rc2
 pyyaml
 jinja2
 humanize
