Skip to content

♻️Mypy: webserver2 #6200

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/ci-testing-deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ jobs:
run: ./ci/github/unit-testing/webserver.bash install
- name: typecheck
run: ./ci/github/unit-testing/webserver.bash typecheck
continue-on-error: true
- name: test isolated
if: always()
run: ./ci/github/unit-testing/webserver.bash test_isolated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

class WalletGet(OutputSchema):
wallet_id: WalletID
name: str
name: IDStr
description: str | None
owner: GroupID
thumbnail: str | None
Expand Down
23 changes: 22 additions & 1 deletion packages/models-library/src/models_library/basic_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from enum import Enum
from typing import TypeAlias
from typing import Final, TypeAlias

from pydantic import (
ConstrainedDecimal,
Expand Down Expand Up @@ -83,11 +83,32 @@ class UUIDStr(ConstrainedStr):

# non-empty bounded string used as identifier
# e.g. "123" or "name_123" or "fa327c73-52d8-462a-9267-84eeaf0f90e3" but NOT ""
_ELLIPSIS_CHAR: Final[str] = "..."


class IDStr(ConstrainedStr):
strip_whitespace = True
min_length = 1
max_length = 100

@staticmethod
def concatenate(*args: "IDStr", link_char: str = " ") -> "IDStr":
result = link_char.join(args).strip()
assert IDStr.min_length # nosec
assert IDStr.max_length # nosec
if len(result) > IDStr.max_length:
if IDStr.max_length > len(_ELLIPSIS_CHAR):
result = (
result[: IDStr.max_length - len(_ELLIPSIS_CHAR)].rstrip()
+ _ELLIPSIS_CHAR
)
else:
result = _ELLIPSIS_CHAR[0] * IDStr.max_length
if len(result) < IDStr.min_length:
msg = f"IDStr.concatenate: result is too short: {result}"
raise ValueError(msg)
return IDStr(result)


class ShortTruncatedStr(ConstrainedStr):
# NOTE: Use to input e.g. titles or display names
Expand Down
8 changes: 8 additions & 0 deletions packages/service-library/src/servicelib/redis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from collections.abc import Awaitable, Callable
from datetime import timedelta
from typing import Any

import arrow

Expand Down Expand Up @@ -100,3 +101,10 @@ def start_exclusive_periodic_task(
usr_tsk_task_name=task_name,
**kwargs,
)


async def handle_redis_returns_union_types(result: Any | Awaitable[Any]) -> Any:
"""Used to handle mypy issues with redis 5.x return types"""
if isinstance(result, Awaitable):
return await result
return result
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from collections.abc import Iterator
from typing import Any
from typing import Any, cast

from aiohttp import web
from aiohttp.web import Request
Expand Down Expand Up @@ -319,8 +319,11 @@ async def get_service_output(
service = await client.get_service(
ctx.app, ctx.user_id, service_key, service_version, ctx.product_name
)
return await ServiceOutputGetFactory.from_catalog_service_api_model(
service=service, output_key=output_key
return cast( # mypy -> aiocache is not typed.
ServiceOutputGet,
await ServiceOutputGetFactory.from_catalog_service_api_model(
service=service, output_key=output_key
),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Any, Final

from aiocache import cached
from aiocache import cached # type: ignore[import-untyped]
from models_library.api_schemas_webserver.catalog import (
ServiceInputGet,
ServiceInputKey,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
from collections.abc import AsyncIterator
from typing import Any
from typing import Any, cast

from aiohttp import web
from aiopg.sa import Engine, create_engine
Expand All @@ -31,7 +31,7 @@
async def _ensure_pg_ready(settings: PostgresSettings) -> Engine:

_logger.info("Connecting to postgres with %s", f"{settings=}")
engine = await create_engine(
engine: Engine = await create_engine(
settings.dsn,
application_name=settings.POSTGRES_CLIENT_NAME,
minsize=settings.POSTGRES_MINSIZE,
Expand Down Expand Up @@ -91,7 +91,7 @@ def get_engine_state(app: web.Application) -> dict[str, Any]:


def get_database_engine(app: web.Application) -> Engine:
return app[APP_DB_ENGINE_KEY]
return cast(Engine, app[APP_DB_ENGINE_KEY])


@app_module_setup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NodeGetIdle,
NodeGetUnknown,
)
from models_library.basic_types import IDStr
from models_library.progress_bar import ProgressReport
from models_library.projects import ProjectID
from models_library.projects_nodes_io import NodeID
Expand Down Expand Up @@ -108,7 +109,7 @@ async def stop_dynamic_services_in_project(
user_id,
project_id,
),
description="stopping services",
description=IDStr("stopping services"),
)
)

Expand All @@ -123,7 +124,7 @@ async def stop_dynamic_services_in_project(
save_state=save_state,
),
progress=progress_bar.sub_progress(
1, description=f"{service.node_uuid}"
1, description=IDStr(f"{service.node_uuid}")
),
)
for service in running_dynamic_services
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
MIME: Multipurpose Internet Mail Extensions

"""

import logging

import aiohttp_jinja2
import jinja_app_loader
import jinja_app_loader # type: ignore[import-untyped]
from aiohttp import web
from servicelib.aiohttp.application_setup import ModuleCategory, app_module_setup

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ async def export_project(request: web.Request):
)

headers = {"Content-Disposition": f'attachment; filename="{file_to_download.name}"'}

assert delete_tmp_dir # nosec
return CleanupFileResponse(
remove_tmp_dir_cb=delete_tmp_dir, path=file_to_download, headers=headers
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
FolderGetPage,
PutFolderBodyParams,
)
from models_library.basic_types import IDStr
from models_library.folders import FolderID
from models_library.rest_ordering import OrderBy, OrderDirection
from models_library.rest_pagination import Page, PageQueryParameters
Expand Down Expand Up @@ -70,8 +71,9 @@ class FoldersPathParams(StrictRequestParams):


class FolderListWithJsonStrQueryParams(PageQueryParameters):
order_by: Json[OrderBy] = Field( # pylint: disable=unsubscriptable-object
default=OrderBy(field="modified", direction=OrderDirection.DESC),
# pylint: disable=unsubscriptable-object
order_by: Json[OrderBy] = Field( # type: ignore[type-arg]
default=OrderBy(field=IDStr("modified"), direction=OrderDirection.DESC),
description="Order by field (modified_at|name|description) and direction (asc|desc). The default sorting order is ascending.",
example='{"field": "name", "direction": "desc"}',
alias="order_by",
Expand All @@ -89,7 +91,8 @@ def validate_order_by_field(cls, v):
"name",
"description",
}:
raise ValueError(f"We do not support ordering by provided field {v.field}")
msg = f"We do not support ordering by provided field {v.field}"
raise ValueError(msg)
if v.field == "modified_at":
v.field = "modified"
return v
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,16 @@ async def _get_user_group(
group = await result.fetchone()
if not group:
raise GroupNotFoundError(gid)
assert isinstance(group, RowProxy) # nosec
return group


async def get_user_from_email(conn: SAConnection, email: str) -> RowProxy:
result = await conn.execute(sa.select(users).where(users.c.email == email))
user: RowProxy = await result.fetchone()
user = await result.fetchone()
if not user:
raise UserNotFoundError(email=email)
assert isinstance(user, RowProxy) # nosec
return user


Expand Down
Loading
Loading