diff --git a/backend/btrixcloud/basecrawls.py b/backend/btrixcloud/basecrawls.py index cd1c344f67..73e1cc347e 100644 --- a/backend/btrixcloud/basecrawls.py +++ b/backend/btrixcloud/basecrawls.py @@ -6,10 +6,11 @@ from datetime import timedelta from typing import Optional, List, Union import urllib.parse +import contextlib from pydantic import UUID4 from fastapi import HTTPException, Depends -from redis import asyncio as aioredis, exceptions +from redis import exceptions from .models import ( CrawlFile, @@ -216,8 +217,8 @@ async def _resolve_crawl_refs( # more responsive, saves db update in operator if crawl.state in RUNNING_STATES: try: - redis = await self.get_redis(crawl.id) - crawl.stats = await get_redis_crawl_stats(redis, crawl.id) + async with self.get_redis(crawl.id) as redis: + crawl.stats = await get_redis_crawl_stats(redis, crawl.id) # redis not available, ignore except exceptions.ConnectionError: pass @@ -281,13 +282,17 @@ async def _update_presigned(self, updates): for update in updates: await self.crawls.find_one_and_update(*update) + @contextlib.asynccontextmanager async def get_redis(self, crawl_id): """get redis url for crawl id""" redis_url = self.crawl_manager.get_redis_url(crawl_id) - return await aioredis.from_url( - redis_url, encoding="utf-8", decode_responses=True - ) + redis = await self.crawl_manager.get_redis_client(redis_url) + + try: + yield redis + finally: + await redis.close() async def add_to_collection( self, crawl_ids: List[uuid.UUID], collection_id: uuid.UUID, org: Organization diff --git a/backend/btrixcloud/crawls.py b/backend/btrixcloud/crawls.py index 2ed4c5a778..80db4f3fc4 100644 --- a/backend/btrixcloud/crawls.py +++ b/backend/btrixcloud/crawls.py @@ -363,23 +363,26 @@ async def get_crawl_queue(self, crawl_id, offset, count, regex): total = 0 results = [] - redis = None try: - redis = await self.get_redis(crawl_id) + async with self.get_redis(crawl_id) as redis: + total = await self._crawl_queue_len(redis, f"{crawl_id}:q") + results = await self._crawl_queue_range( + redis, f"{crawl_id}:q", offset, count + ) + results = [json.loads(result)["url"] for result in results] - total = await self._crawl_queue_len(redis, f"{crawl_id}:q") - results = await self._crawl_queue_range( - redis, f"{crawl_id}:q", offset, count - ) - results = [json.loads(result)["url"] for result in results] except exceptions.ConnectionError: # can't connect to redis, likely not initialized yet pass matched = [] if regex: - regex = re.compile(regex) + try: + regex = re.compile(regex) + except re.error as exc: + raise HTTPException(status_code=400, detail="invalid_regex") from exc + matched = [result for result in results if regex.search(result)] return {"total": total, "results": results, "matched": matched} @@ -387,25 +390,29 @@ async def get_crawl_queue(self, crawl_id, offset, count, regex): async def match_crawl_queue(self, crawl_id, regex): """get list of urls that match regex""" total = 0 - redis = None - - try: - redis = await self.get_redis(crawl_id) - total = await self._crawl_queue_len(redis, f"{crawl_id}:q") - except exceptions.ConnectionError: - # can't connect to redis, likely not initialized yet - pass - - regex = re.compile(regex) matched = [] step = 50 - for count in range(0, total, step): - results = await self._crawl_queue_range(redis, f"{crawl_id}:q", count, step) - for result in results: - url = json.loads(result)["url"] - if regex.search(url): - matched.append(url) + async with self.get_redis(crawl_id) as redis: + try: + total = await 
self._crawl_queue_len(redis, f"{crawl_id}:q") + except exceptions.ConnectionError: + # can't connect to redis, likely not initialized yet + pass + + try: + regex = re.compile(regex) + except re.error as exc: + raise HTTPException(status_code=400, detail="invalid_regex") from exc + + for count in range(0, total, step): + results = await self._crawl_queue_range( + redis, f"{crawl_id}:q", count, step + ) + for result in results: + url = json.loads(result)["url"] + if regex.search(url): + matched.append(url) return {"total": total, "matched": matched} @@ -413,56 +420,58 @@ async def filter_crawl_queue(self, crawl_id, regex): """filter out urls that match regex""" # pylint: disable=too-many-locals total = 0 - redis = None - q_key = f"{crawl_id}:q" s_key = f"{crawl_id}:s" - - try: - redis = await self.get_redis(crawl_id) - total = await self._crawl_queue_len(redis, f"{crawl_id}:q") - except exceptions.ConnectionError: - # can't connect to redis, likely not initialized yet - pass - - dircount = -1 - regex = re.compile(regex) step = 50 - - count = 0 num_removed = 0 - # pylint: disable=fixme - # todo: do this in a more efficient way? - # currently quite inefficient as redis does not have a way - # to atomically check and remove value from list - # so removing each jsob block by value - while count < total: - if dircount == -1 and count > total / 2: - dircount = 1 - results = await self._crawl_queue_range(redis, q_key, count, step) - count += step - - qrems = [] - srems = [] - - for result in results: - url = json.loads(result)["url"] - if regex.search(url): - srems.append(url) - # await redis.srem(s_key, url) - # res = await self._crawl_queue_rem(redis, q_key, result, dircount) - qrems.append(result) - - if not srems: - continue - - await redis.srem(s_key, *srems) - res = await self._crawl_queue_rem(redis, q_key, qrems, dircount) - if res: - count -= res - num_removed += res - print(f"Removed {res} from queue", flush=True) + async with self.get_redis(crawl_id) as redis: + try: + total = await self._crawl_queue_len(redis, f"{crawl_id}:q") + except exceptions.ConnectionError: + # can't connect to redis, likely not initialized yet + pass + + dircount = -1 + + try: + regex = re.compile(regex) + except re.error as exc: + raise HTTPException(status_code=400, detail="invalid_regex") from exc + + count = 0 + + # pylint: disable=fixme + # todo: do this in a more efficient way? 
+ # currently quite inefficient as redis does not have a way + # to atomically check and remove value from list + # so removing each jsob block by value + while count < total: + if dircount == -1 and count > total / 2: + dircount = 1 + results = await self._crawl_queue_range(redis, q_key, count, step) + count += step + + qrems = [] + srems = [] + + for result in results: + url = json.loads(result)["url"] + if regex.search(url): + srems.append(url) + # await redis.srem(s_key, url) + # res = await self._crawl_queue_rem(redis, q_key, result, dircount) + qrems.append(result) + + if not srems: + continue + + await redis.srem(s_key, *srems) + res = await self._crawl_queue_rem(redis, q_key, qrems, dircount) + if res: + count -= res + num_removed += res + print(f"Removed {res} from queue", flush=True) return num_removed @@ -475,13 +484,13 @@ async def get_errors_from_redis( skip = page * page_size upper_bound = skip + page_size - 1 - try: - redis = await self.get_redis(crawl_id) - errors = await redis.lrange(f"{crawl_id}:e", skip, upper_bound) - total = await redis.llen(f"{crawl_id}:e") - except exceptions.ConnectionError: - # pylint: disable=raise-missing-from - raise HTTPException(status_code=503, detail="redis_connection_error") + async with self.get_redis(crawl_id) as redis: + try: + errors = await redis.lrange(f"{crawl_id}:e", skip, upper_bound) + total = await redis.llen(f"{crawl_id}:e") + except exceptions.ConnectionError: + # pylint: disable=raise-missing-from + raise HTTPException(status_code=503, detail="redis_connection_error") parsed_errors = parse_jsonl_error_messages(errors) return parsed_errors, total diff --git a/backend/btrixcloud/k8sapi.py b/backend/btrixcloud/k8sapi.py index 414ee62a1c..705c9b11e1 100644 --- a/backend/btrixcloud/k8sapi.py +++ b/backend/btrixcloud/k8sapi.py @@ -13,6 +13,9 @@ from kubernetes_asyncio.utils import create_from_dict from kubernetes_asyncio.client.exceptions import ApiException +from redis.asyncio import Redis +from redis.asyncio.connection import ConnectionPool + from fastapi.templating import Jinja2Templates from .utils import get_templates_dir, dt_now, to_k8s_date @@ -62,6 +65,17 @@ def get_redis_url(self, crawl_id): ) return redis_url + async def get_redis_client(self, redis_url): + """return redis client with correct params for one-time use""" + # manual settings until redis 5.0.0 is released + pool = ConnectionPool.from_url(redis_url, decode_responses=True) + redis = Redis( + connection_pool=pool, + decode_responses=True, + ) + redis.auto_close_connection_pool = True + return redis + # pylint: disable=too-many-arguments async def new_crawl_job( self, cid, userid, oid, scale=1, crawl_timeout=0, manual=True diff --git a/backend/btrixcloud/operator.py b/backend/btrixcloud/operator.py index 71b0da13ad..03dccd3285 100644 --- a/backend/btrixcloud/operator.py +++ b/backend/btrixcloud/operator.py @@ -13,7 +13,6 @@ import humanize from pydantic import BaseModel -from redis import asyncio as aioredis from .utils import ( from_k8s_date, @@ -430,6 +429,7 @@ async def delete_pvc(self, crawl_id): async def cancel_crawl(self, redis_url, crawl_id, cid, status, state): """immediately cancel crawl with specified state return true if db mark_finished update succeeds""" + redis = None try: redis = await self._get_redis(redis_url) await self.mark_finished(redis, crawl_id, uuid.UUID(cid), status, state) @@ -438,6 +438,10 @@ async def cancel_crawl(self, redis_url, crawl_id, cid, status, state): except: return False + finally: + if redis: + await redis.close() + def 
_done_response(self, status, finalized=False): """done response for removing crawl""" return { @@ -462,15 +466,16 @@ async def _get_redis(self, redis_url): """init redis, ensure connectivity""" redis = None try: - redis = await aioredis.from_url( - redis_url, encoding="utf-8", decode_responses=True - ) + redis = await self.get_redis_client(redis_url) # test connection await redis.ping() return redis # pylint: disable=bare-except except: + if redis: + await redis.close() + return None async def check_if_finished(self, crawl, status): @@ -512,16 +517,16 @@ async def sync_crawl_state(self, redis_url, crawl, status, pods): status.resync_after = self.fast_retry_secs return status - # set state to running (if not already) - if status.state not in RUNNING_STATES: - await self.set_state( - "running", - status, - crawl.id, - allowed_from=["starting", "waiting_capacity"], - ) - try: + # set state to running (if not already) + if status.state not in RUNNING_STATES: + await self.set_state( + "running", + status, + crawl.id, + allowed_from=["starting", "waiting_capacity"], + ) + file_done = await redis.lpop(self.done_key) while file_done: @@ -547,6 +552,9 @@ async def sync_crawl_state(self, redis_url, crawl, status, pods): print(f"Crawl get failed: {exc}, will try again") return status + finally: + await redis.close() + def check_if_pods_running(self, pods): """check if at least one crawler pod has started""" try: diff --git a/backend/requirements.txt b/backend/requirements.txt index 38a1ef60dd..46f718efa2 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,7 +5,7 @@ loguru aiofiles kubernetes-asyncio==22.6.5 aiobotocore -redis>=4.2.0rc1 +redis>=5.0.0rc2 pyyaml jinja2 humanize
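
A minimal usage sketch of the connection-handling pattern this patch introduces, for reference only and not part of the diff. The helpers get_redis and queue_length below are illustrative stand-ins rather than functions from the codebase; the only assumptions are redis>=5.0.0rc2 (as pinned in requirements.txt above) and a reachable redis_url. The idea mirrors the patched code: build the client from its own connection pool, mark the pool for auto-close, and always close in a finally block so one-shot uses cannot leak connections.

import contextlib

from redis import exceptions
from redis.asyncio import Redis
from redis.asyncio.connection import ConnectionPool


@contextlib.asynccontextmanager
async def get_redis(redis_url):
    """yield a one-time-use redis client, closing it (and its pool) on exit"""
    pool = ConnectionPool.from_url(redis_url, decode_responses=True)
    redis = Redis(connection_pool=pool)
    # close the pool together with the client so no connections are leaked
    redis.auto_close_connection_pool = True
    try:
        yield redis
    finally:
        await redis.close()


async def queue_length(redis_url, crawl_id):
    """read the crawl queue length, treating an unreachable redis as empty"""
    try:
        async with get_redis(redis_url) as redis:
            return await redis.llen(f"{crawl_id}:q")
    except exceptions.ConnectionError:
        # redis not available yet, mirror the patched endpoints and ignore
        return 0

Tying the pool's lifetime to the client via auto_close_connection_pool is what lets a single close() call tear everything down, which is why the handlers and operator code in the diff can clean up with one call in a finally block or context-manager exit.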