Fix types, run typechecker in CI #1393

Merged
merged 22 commits into from
Nov 28, 2023
Commits
bf6b23c
python: add trivial type casts to make typechecker happy
yorickvP Nov 21, 2023
38313b9
worker.cancel: don't crash on missing child
yorickvP Nov 21, 2023
109a900
webhook_caller_filtered: expects as filter set()
yorickvP Nov 21, 2023
bd0bffd
cog.server.http: don't crash when cpu_count couldn't be determined
yorickvP Nov 21, 2023
b96e171
ast_openapi_schema: deal with bytes() defaults
yorickvP Nov 21, 2023
e7c273f
cog.server.http: change PredictionRequest to make pyright happy
yorickvP Nov 21, 2023
2bf8de9
cog.predictor: more possible types
yorickvP Nov 21, 2023
9946ffc
Fix: UnionType is new in 3.10
yorickvP Nov 21, 2023
0b23405
Fix: | syntax is only available in 3.10
yorickvP Nov 21, 2023
0ba9855
cog.types: fix untyped definition
yorickvP Nov 22, 2023
ab66b2a
Add tool.pyright config to pyproject.toml
yorickvP Nov 22, 2023
6264eaf
python: error on missing type arguments, add them everywhere
yorickvP Nov 22, 2023
8dfe4e6
ast_openapi_schema.extract_info: make type checker actually happy
yorickvP Nov 22, 2023
ddc9333
Add typecheck python github action
yorickvP Nov 22, 2023
5dfd4f3
cog.server.http: quote typenames
yorickvP Nov 22, 2023
38f0222
python: process some review comments
yorickvP Nov 23, 2023
2c8a65c
async runner/http: fix types
yorickvP Nov 23, 2023
4c9b8d9
python: mypy -> pyright everywhere
yorickvP Nov 24, 2023
d18c755
pyproject: unneccesary -> unnecessary
yorickvP Nov 24, 2023
c8f3636
resolve threading|asyncio.Event difference the other way
yorickvP Nov 24, 2023
9aca73b
python: make tests pass for PredictionResponse
yorickvP Nov 24, 2023
068f6a7
Fix test in python 3.7
yorickvP Nov 24, 2023
21 changes: 18 additions & 3 deletions .github/workflows/ci.yaml
@@ -37,7 +37,6 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install '.[dev]'
yes | python -m mypy --install-types replicate || true
- name: Build
run: make cog
- name: Test
@@ -61,12 +60,29 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install '.[dev]'
yes | python -m mypy --install-types replicate || true
- name: Test
run: make test-python
env:
HYPOTHESIS_PROFILE: ci

typecheck-python:
name: "Typecheck Python"
runs-on: ubuntu-latest
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install Python dependencies
run: |
python -m pip install '.[dev]'
- name: Run typechecking
run: |
python -m pyright

# cannot run this on mac due to licensing issues: https://github.com/actions/virtual-environments/issues/2150
test-integration:
name: "Test integration"
@@ -82,7 +98,6 @@ jobs:
- name: Install Python dependencies
run: |
python -m pip install '.[dev]'
yes | python -m mypy --install-types replicate || true
- name: Test
run: make test-integration

4 changes: 2 additions & 2 deletions Makefile
@@ -13,7 +13,7 @@ GOARCH := $(shell $(GO) env GOARCH)

PYTHON := python
PYTEST := $(PYTHON) -m pytest
MYPY := $(PYTHON) -m mypy
PYRIGHT := $(PYTHON) -m pyright
RUFF := $(PYTHON) -m ruff

default: all
@@ -94,7 +94,7 @@ lint-go:
.PHONY: lint-python
lint-python:
$(RUFF) python/cog
$(MYPY) python/cog
$(PYRIGHT)

.PHONY: lint
lint: lint-go lint-python
22 changes: 15 additions & 7 deletions pyproject.toml
@@ -30,10 +30,10 @@ optional-dependencies = { "dev" = [
"httpx",
'hypothesis<6.80.0; python_version < "3.8"',
'hypothesis; python_version >= "3.8"',
"mypy",
'numpy<1.22.0; python_version < "3.8"',
'numpy; python_version >= "3.8"',
"pillow",
"pyright",
"pytest",
"pytest-asyncio",
"pytest-httpserver",
@@ -48,12 +48,20 @@ dynamic = ["version"]
[tool.setuptools_scm]
write_to = "python/cog/_version.py"

[tool.mypy]
plugins = "pydantic.mypy"
disallow_untyped_defs = true
# TODO: remove this and bring the codebase inline with the current mypy default
no_implicit_optional = false
exclude = ["python/tests/"]
[tool.pyright]
Member

Should we pick one of pyright or mypy and drop the other?

Contributor

Would agree; I'd kinda lean towards sticking with mypy rather than making changes. pylsp-mypy exists, though I assume it's slower.

Contributor Author

As a long-time mypy user, I like pyright more:

Notably pyright actually passes the pydantic stuff :)

Contributor

...I've had check-untyped-defs on the entire time I've used mypy, but I have to agree that this is fairly compelling even if the javascript stuff is weird.

Contributor

@yorickvP Thanks for the nudge. I just gave Pyright a look, and found it to be an improvement over mypy. I'd be very happy for us to adopt it across the board. (To that end, I just opened this PR in our Python client.)

# TODO: remove this and bring the codebase in line with the current default
strictParameterNoneValue = false
# legacy behavior, fixed in PEP688
disableBytesTypePromotions = true
include = ["python"]
exclude = ["python/tests"]
reportMissingParameterType = "error"
reportUnknownLambdaType = "error"
reportUnnecessaryIsInstance = "warning"
reportUnnecessaryComparison = "warning"
reportUnnecessaryContains = "warning"
reportMissingTypeArgument = "error"
reportUnusedExpression = "warning"

[tool.setuptools]
package-dir = { "" = "python" }
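
Note on `strictParameterNoneValue = false` in the `[tool.pyright]` block above: the setting concerns parameters that are annotated with a non-Optional type but default to `None`, as in the `output_file_prefix: str = None` signature that appears in the `python/cog/files.py` hunk of this PR. A minimal sketch of the two spellings follows; the function names are hypothetical and not taken from the codebase.

```python
from typing import Optional


def implicit_optional(prefix: str = None) -> str:
    # Annotated as plain `str` but defaulted to None. To my understanding,
    # pyright only accepts this when strictParameterNoneValue is false,
    # treating the parameter as if it were Optional[str].
    return "" if prefix is None else prefix


def explicit_optional(prefix: Optional[str] = None) -> str:
    # The explicit spelling, which should pass regardless of the setting.
    return "" if prefix is None else prefix


print(implicit_optional(), explicit_optional("out/"))
```

Both functions behave identically at runtime; the difference is only in what the type checker reports, which is presumably why the config carries a TODO to remove the override once annotations are made explicit.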
83 changes: 57 additions & 26 deletions python/cog/command/ast_openapi_schema.py
@@ -1,6 +1,8 @@
import ast
import json
import sys
import types
import typing
from pathlib import Path

try:
@@ -307,8 +309,25 @@ def find(obj: ast.AST, name: str) -> ast.AST:
"""Find a particular named node in a tree"""
return next(node for node in ast.walk(obj) if getattr(node, "name", "") == name)

if typing.TYPE_CHECKING:
AstVal: "typing.TypeAlias" = "int | float | complex | str | list[AstVal] | bytes | None"
AstValNoBytes: "typing.TypeAlias" = "int | float | str | list[AstValNoBytes]"
JSONObject: "typing.TypeAlias" = "int | float | str | list[JSONObject] | JSONDict | None"
JSONDict: "typing.TypeAlias" = "dict[str, JSONObject]"


def to_serializable(val: "AstVal") -> "JSONObject":
if isinstance(val, bytes):
return val.decode("utf-8")
elif isinstance(val, list):
return [to_serializable(x) for x in val]
elif isinstance(val, complex):
msg = "complex inputs are not supported"
raise ValueError(msg)
else:
return val

def get_value(node: ast.AST) -> "int | float | complex | str | list":
def get_value(node: ast.AST) -> "AstVal":
"""Return the value of constant or list of constants"""
if isinstance(node, ast.Constant):
return node.value
@@ -320,7 +339,7 @@ def get_value(node: ast.AST) -> "int | float | complex | str | list":
if isinstance(node, (ast.List, ast.Tuple)):
return [get_value(e) for e in node.elts]
if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
return -get_value(node.operand)
return -typing.cast(typing.Union[int, float, complex], get_value(node.operand))
raise ValueError("Unexpected node type", type(node))


@@ -344,7 +363,7 @@ def get_call_name(call: ast.Call) -> str:
raise ValueError("Unexpected node type", type(call), ast.unparse(call))


def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | Ellipsis]]":
def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | types.EllipsisType]]":
"""Parse argument, default pairs from a file with a predict function"""
predict = find(tree, "predict")
assert isinstance(predict, ast.FunctionDef)
@@ -353,34 +372,38 @@ def parse_args(tree: ast.AST) -> "list[tuple[ast.arg, ast.expr | Ellipsis]]":
defaults = [...] * (len(args) - len(predict.args.defaults)) + predict.args.defaults
return list(zip(args, defaults))


def parse_assignment(assignment: ast.AST) -> "tuple[str | None, dict | None]":
def parse_assignment(assignment: ast.AST) -> "None | tuple[str, JSONObject]":
"""Parse an assignment into an OpenAPI object property"""
if isinstance(assignment, ast.AnnAssign):
assert isinstance(assignment.target, ast.Name) # shouldn't be an Attribute
default = {"default": get_value(assignment.value)} if assignment.value else {}
default = {}
if assignment.value:
try:
default = {"default": to_serializable(get_value(assignment.value))}
except UnicodeDecodeError:
pass
return assignment.target.id, {
"title": assignment.target.id.replace("_", " ").title(),
"type": OPENAPI_TYPES[get_annotation(assignment.annotation)],
**default,
}
if isinstance(assignment, ast.Assign):
if len(assignment.targets) == 1 and isinstance(assignment.targets[0], ast.Name):
value = get_value(assignment.value)
value = to_serializable(get_value(assignment.value))
return assignment.targets[0].id, {
"title": assignment.targets[0].id.replace("_", " ").title(),
"type": OPENAPI_TYPES[type(value).__name__],
"default": value,
}
raise ValueError("Unexpected assignment", assignment)
return None, None
return None


def parse_class(classdef: ast.AST) -> dict:
def parse_class(classdef: ast.AST) -> "JSONDict":
"""Parse a class definition into an OpenAPI object"""
assert isinstance(classdef, ast.ClassDef)
properties = {
key: property for key, property in map(parse_assignment, classdef.body) if key
assignment[0]: assignment[1] for assignment in map(parse_assignment, classdef.body) if assignment
}
return {
"title": classdef.name,
@@ -404,16 +427,16 @@ def resolve_name(node: ast.expr) -> str:
if isinstance(node, ast.Name):
return node.id
if isinstance(node, ast.Index):
# depricated, but needed for py3.8
return resolve_name(node.value)
# deprecated, but needed for py3.8
return resolve_name(node.value) # type: ignore
if isinstance(node, ast.Attribute):
return node.attr
if isinstance(node, ast.Subscript):
return resolve_name(node.value)
raise ValueError("Unexpected node type", type(node), ast.unparse(node))


def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[dict, dict]":
def parse_return_annotation(tree: ast.AST, fn: str = "predict") -> "tuple[JSONDict, JSONDict]":
predict = find(tree, fn)
if not isinstance(predict, ast.FunctionDef):
raise ValueError("Could not find predict function")
@@ -459,7 +482,7 @@ def predict(
format = {"format": "uri"} if name in ("Path", "File") else {}
return {}, {"title": "Output", "type": OPENAPI_TYPES.get(name, name), **format}
# it must be a custom object
schema = {name: parse_class(find(tree, name))}
schema: "JSONDict" = {name: parse_class(find(tree, name))}
return schema, {
"title": "Output",
"$ref": f"#/components/schemas/{name}",
@@ -469,24 +492,30 @@ def predict(
KEPT_ATTRS = ("description", "default", "ge", "le", "max_length", "min_length", "regex")


def extract_info(code: str) -> dict:
def extract_info(code: str) -> "JSONDict":
"""Parse the schemas from a file with a predict function"""
tree = ast.parse(code)
inputs = {"title": "Input", "type": "object", "properties": {}}
properties: "JSONDict" = {}
inputs: "JSONDict" = {"title": "Input", "type": "object", "properties": properties}
required: "list[str]" = []
schemas: "dict[str, dict]" = {}
schemas: "JSONDict" = {}
for arg, default in parse_args(tree):
if arg.arg == "self":
continue
if isinstance(default, ast.Call) and get_call_name(default) == "Input":
kws = {kw.arg: get_value(kw.value) for kw in default.keywords}
kws = {}
for kw in default.keywords:
if kw.arg is None:
msg = "unknown argument for Input"
raise ValueError(msg)
kws[kw.arg] = to_serializable(get_value(kw.value))
elif isinstance(default, (ast.Constant, ast.List, ast.Tuple, ast.Str, ast.Num)):
kws = {"default": get_value(default)} # could be None
kws = {"default": to_serializable(get_value(default))} # could be None
elif default == ...: # no default
kws = {}
else:
raise ValueError("Unexpected default value", default)
input: dict = {"x-order": len(inputs["properties"])}
input: "JSONDict" = {"x-order": len(properties)}
# need to handle other types?
arg_type = OPENAPI_TYPES.get(get_annotation(arg.annotation), "string")
if get_annotation(arg.annotation) in ("Path", "File"):
@@ -508,23 +537,25 @@ def extract_info(code: str) -> dict:
else:
input["title"] = arg.arg.replace("_", " ").title()
input["type"] = arg_type
inputs["properties"][arg.arg] = input # type: ignore
properties[arg.arg] = input
if required:
inputs["required"] = required
inputs["required"] = list(required)
# List[Path], list[Path], str, Iterator[str], MyOutput, Output
return_schema, output = parse_return_annotation(tree, "predict")
schema = json.loads(BASE_SCHEMA)
components = {
schema: "JSONDict" = json.loads(BASE_SCHEMA)
components: "JSONDict" = {
"Input": inputs,
"Output": output,
**schemas,
**return_schema,
}
schema["components"]["schemas"].update(components)
# trust me, typechecker, I know BASE_SCHEMA
x: "JSONDict" = schema["components"]["schemas"] # type: ignore
x.update(components)
return schema


def extract_file(fname: "str | Path") -> dict:
def extract_file(fname: "str | Path") -> "JSONObject":
return extract_info(open(fname, encoding="utf-8").read())
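
Note on the `to_serializable` helper introduced in this file: it appears to exist because `ast` constants can be `bytes` or `complex`, neither of which JSON can represent. The sketch below restates the helper in standalone form and adds a hypothetical wrapper, `default_property`, that mirrors how `parse_assignment` skips defaults that cannot be decoded. Annotations are loosened to `Any` so the block runs on its own.

```python
from typing import Any


def to_serializable(val: Any) -> Any:
    """Convert a value extracted from the AST into something JSON can hold."""
    if isinstance(val, bytes):
        # bytes defaults are decoded; invalid UTF-8 raises UnicodeDecodeError
        return val.decode("utf-8")
    if isinstance(val, list):
        return [to_serializable(x) for x in val]
    if isinstance(val, complex):
        raise ValueError("complex inputs are not supported")
    return val


def default_property(value: Any) -> dict:
    # Hypothetical helper: drop the default when the bytes cannot be decoded,
    # as the try/except in parse_assignment does.
    try:
        return {"default": to_serializable(value)}
    except UnicodeDecodeError:
        return {}


print(default_property([b"a", 1]), default_property(b"\xff"))
```

Nested lists are handled recursively, so a default such as `[b"a", [1, b"b"]]` comes out as plain strings and numbers.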


3 changes: 2 additions & 1 deletion python/cog/files.py
@@ -23,7 +23,8 @@ def upload_file(fh: io.IOBase, output_file_prefix: str = None) -> str:
b = b.encode("utf-8")
encoded_body = base64.b64encode(b)
if getattr(fh, "name", None):
# despite doing a getattr check here, mypy complains that io.IOBase has no attribute name
# despite doing a getattr check here, pyright complains that io.IOBase has no attribute name
# TODO: switch to typing.IO[]?
mime_type = mimetypes.guess_type(fh.name)[0] # type: ignore
else:
mime_type = "application/octet-stream"
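
Note on the `# type: ignore` above: one way the suppression could be avoided, short of the `typing.IO[]` switch mentioned in the TODO, is to bind the result of `getattr` to a local and narrow it with `isinstance`. This is only a sketch; `guess_mime_type` is a hypothetical name and not part of `cog.files`.

```python
import io
import mimetypes


def guess_mime_type(fh: io.IOBase) -> str:
    # Read the attribute once so the checker sees a local it can narrow,
    # rather than accessing fh.name directly.
    name = getattr(fh, "name", None)
    if isinstance(name, str):
        guessed = mimetypes.guess_type(name)[0]
        if guessed is not None:
            return guessed
    return "application/octet-stream"


unnamed = io.BytesIO(b"")
named = io.BytesIO(b"")
named.name = "picture.png"  # BytesIO accepts ad-hoc attributes
print(guess_mime_type(unnamed), guess_mime_type(named))
```

Unlike the code in the diff, this sketch also falls back to `application/octet-stream` when the extension is unknown; that is a design choice of the example, not behavior confirmed by the PR.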
6 changes: 2 additions & 4 deletions python/cog/json.py
@@ -29,11 +29,9 @@ def make_encodeable(obj: Any) -> Any:
return obj.isoformat()
try:
import numpy as np # type: ignore

has_numpy = True
except ImportError:
has_numpy = False
if has_numpy:
pass
else:
if isinstance(obj, np.integer):
return int(obj)
if isinstance(obj, np.floating):
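
Note on the hunk above: the change seems to replace the `has_numpy` flag with a `try` / `except ImportError` / `else` block, so that `np` is only referenced on the path where the import succeeded. A minimal standalone sketch of that shape follows; `to_builtin_number` is a hypothetical name, and the rest of `make_encodeable` is not reproduced.

```python
from typing import Any


def to_builtin_number(obj: Any) -> Any:
    try:
        import numpy as np  # optional dependency
    except ImportError:
        pass  # numpy is not installed, so there is nothing to convert
    else:
        # This branch runs only when the import succeeded, so `np` is bound.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
    return obj


print(to_builtin_number(3))
```

The `else` clause keeps the numpy-specific checks out of the `try` body, so an unrelated `ImportError` raised elsewhere would not be swallowed by accident.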