Skip to content

Commit 69b3144

Browse files
authored
Fix stubgen regressions with pybind11 and mypy 1.7 (#16504)
This addresses several regressions identified in #16486 The primary regression from #15770 is that pybind11 properties with docstrings were erroneously assigned `typeshed. Incomplete`. The reason for the regression is that as of the introduction of the `--include-docstring` feature (#13284, not my PR, ftr), `./misc/test-stubgenc.sh` began always reporting success. That has been fixed. It was also pointed out that `--include-docstring` does not work for C-extensions. This was not actually a regression as it turns out this feature was never implemented for C-extensions (though the tests suggested it had been), but luckily my efforts to unify the pure-python and C-extension code-paths made fixing this super easy (barely an inconvenience)! So that is working now. I added back the extended list of `typing` objects that generate implicit imports for the inspection-based stub generator. I originally removed these because I encountered an issue generating stubs for `PySide2` (and another internal library) where there was an object with the same name as one of the `typing` objects and the auto-import created broken stubs. I felt somewhat justified in this decision as there was a straightforward solution -- e.g. use `list` or `typing.List` instead of `List`. That said, I recognize that the problem that I encountered is more niche than the general desire to add import statements for typing objects, so I've changed the behavior back for now, with the intention to eventually add a flag to control this behavior.
1 parent 379d59e commit 69b3144

File tree

9 files changed

+88
-35
lines changed

9 files changed

+88
-35
lines changed

misc/test-stubgenc.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function stubgenc_test() {
2424
# Compare generated stubs to expected ones
2525
if ! git diff --exit-code "$STUBGEN_OUTPUT_FOLDER";
2626
then
27-
EXIT=$?
27+
EXIT=1
2828
fi
2929
}
3030

mypy/stubdoc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,8 @@ def infer_ret_type_sig_from_docstring(docstr: str, name: str) -> str | None:
383383

384384
def infer_ret_type_sig_from_anon_docstring(docstr: str) -> str | None:
385385
"""Convert signature in form of "(self: TestClass, arg0) -> int" to their return type."""
386-
return infer_ret_type_sig_from_docstring("stub" + docstr.strip(), "stub")
386+
lines = ["stub" + line.strip() for line in docstr.splitlines() if line.strip().startswith("(")]
387+
return infer_ret_type_sig_from_docstring("".join(lines), "stub")
387388

388389

389390
def parse_signature(sig: str) -> tuple[str, list[str], list[str]] | None:

mypy/stubgen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,6 +1700,7 @@ def generate_stubs(options: Options) -> None:
17001700
doc_dir=options.doc_dir,
17011701
include_private=options.include_private,
17021702
export_less=options.export_less,
1703+
include_docstrings=options.include_docstrings,
17031704
)
17041705
num_modules = len(all_modules)
17051706
if not options.quiet and num_modules > 0:

mypy/stubgenc.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,12 @@ def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> s
126126
"""Infer property type from docstring or docstring signature."""
127127
if ctx.docstring is not None:
128128
inferred = infer_ret_type_sig_from_anon_docstring(ctx.docstring)
129-
if not inferred:
130-
inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name)
131-
if not inferred:
132-
inferred = infer_prop_type_from_docstring(ctx.docstring)
129+
if inferred:
130+
return inferred
131+
inferred = infer_ret_type_sig_from_docstring(ctx.docstring, ctx.name)
132+
if inferred:
133+
return inferred
134+
inferred = infer_prop_type_from_docstring(ctx.docstring)
133135
return inferred
134136
else:
135137
return None
@@ -237,6 +239,26 @@ def __init__(
237239
self.resort_members = self.is_c_module
238240
super().__init__(_all_, include_private, export_less, include_docstrings)
239241
self.module_name = module_name
242+
if self.is_c_module:
243+
# Add additional implicit imports.
244+
# C-extensions are given more lattitude since they do not import the typing module.
245+
self.known_imports.update(
246+
{
247+
"typing": [
248+
"Any",
249+
"Callable",
250+
"ClassVar",
251+
"Dict",
252+
"Iterable",
253+
"Iterator",
254+
"List",
255+
"NamedTuple",
256+
"Optional",
257+
"Tuple",
258+
"Union",
259+
]
260+
}
261+
)
240262

241263
def get_default_function_sig(self, func: object, ctx: FunctionContext) -> FunctionSig:
242264
argspec = None
@@ -590,9 +612,29 @@ def generate_function_stub(
590612
if inferred[0].args and inferred[0].args[0].name == "cls":
591613
decorators.append("@classmethod")
592614

615+
if docstring:
616+
docstring = self._indent_docstring(docstring)
593617
output.extend(self.format_func_def(inferred, decorators=decorators, docstring=docstring))
594618
self._fix_iter(ctx, inferred, output)
595619

620+
def _indent_docstring(self, docstring: str) -> str:
621+
"""Fix indentation of docstring extracted from pybind11 or other binding generators."""
622+
lines = docstring.splitlines(keepends=True)
623+
indent = self._indent + " "
624+
if len(lines) > 1:
625+
if not all(line.startswith(indent) or not line.strip() for line in lines):
626+
# if the docstring is not indented, then indent all but the first line
627+
for i, line in enumerate(lines[1:]):
628+
if line.strip():
629+
lines[i + 1] = indent + line
630+
# if there's a trailing newline, add a final line to visually indent the quoted docstring
631+
if lines[-1].endswith("\n"):
632+
if len(lines) > 1:
633+
lines.append(indent)
634+
else:
635+
lines[-1] = lines[-1][:-1]
636+
return "".join(lines)
637+
596638
def _fix_iter(
597639
self, ctx: FunctionContext, inferred: list[FunctionSig], output: list[str]
598640
) -> None:
@@ -640,7 +682,7 @@ def generate_property_stub(
640682
if fget:
641683
alt_docstr = getattr(fget, "__doc__", None)
642684
if alt_docstr and docstring:
643-
docstring += alt_docstr
685+
docstring += "\n" + alt_docstr
644686
elif alt_docstr:
645687
docstring = alt_docstr
646688

mypy/stubutil.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,14 @@ def __init__(
576576
self.sig_generators = self.get_sig_generators()
577577
# populated by visit_mypy_file
578578
self.module_name: str = ""
579+
# These are "soft" imports for objects which might appear in annotations but not have
580+
# a corresponding import statement.
581+
self.known_imports = {
582+
"_typeshed": ["Incomplete"],
583+
"typing": ["Any", "TypeVar", "NamedTuple"],
584+
"collections.abc": ["Generator"],
585+
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
586+
}
579587

580588
def get_sig_generators(self) -> list[SignatureGenerator]:
581589
return []
@@ -667,15 +675,7 @@ def set_defined_names(self, defined_names: set[str]) -> None:
667675
for name in self._all_ or ():
668676
self.import_tracker.reexport(name)
669677

670-
# These are "soft" imports for objects which might appear in annotations but not have
671-
# a corresponding import statement.
672-
known_imports = {
673-
"_typeshed": ["Incomplete"],
674-
"typing": ["Any", "TypeVar", "NamedTuple"],
675-
"collections.abc": ["Generator"],
676-
"typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"],
677-
}
678-
for pkg, imports in known_imports.items():
678+
for pkg, imports in self.known_imports.items():
679679
for t in imports:
680680
# require=False means that the import won't be added unless require_name() is called
681681
# for the object during generation.

test-data/pybind11_mypy_demo/src/main.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
#include <cmath>
4646
#include <pybind11/pybind11.h>
47+
#include <pybind11/stl.h>
4748

4849
namespace py = pybind11;
4950

@@ -102,6 +103,11 @@ struct Point {
102103
return distance_to(other.x, other.y);
103104
}
104105

106+
std::vector<double> as_vector()
107+
{
108+
return std::vector<double>{x, y};
109+
}
110+
105111
double x, y;
106112
};
107113

@@ -134,14 +140,15 @@ void bind_basics(py::module& basics) {
134140
.def(py::init<double, double>(), py::arg("x"), py::arg("y"))
135141
.def("distance_to", py::overload_cast<double, double>(&Point::distance_to, py::const_), py::arg("x"), py::arg("y"))
136142
.def("distance_to", py::overload_cast<const Point&>(&Point::distance_to, py::const_), py::arg("other"))
137-
.def_readwrite("x", &Point::x)
143+
.def("as_list", &Point::as_vector)
144+
.def_readwrite("x", &Point::x, "some docstring")
138145
.def_property("y",
139146
[](Point& self){ return self.y; },
140147
[](Point& self, double value){ self.y = value; }
141148
)
142149
.def_property_readonly("length", &Point::length)
143150
.def_property_readonly_static("x_axis", [](py::object cls){return Point::x_axis;})
144-
.def_property_readonly_static("y_axis", [](py::object cls){return Point::y_axis;})
151+
.def_property_readonly_static("y_axis", [](py::object cls){return Point::y_axis;}, "another docstring")
145152
.def_readwrite_static("length_unit", &Point::length_unit)
146153
.def_property_static("angle_unit",
147154
[](py::object& /*cls*/){ return Point::angle_unit; },
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import basics as basics

test-data/pybind11_mypy_demo/stubgen-include-docs/pybind11_mypy_demo/basics.pyi

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import ClassVar
1+
from typing import ClassVar, List, overload
22

3-
from typing import overload
43
PI: float
4+
__version__: str
55

66
class Point:
77
class AngleUnit:
@@ -13,8 +13,6 @@ class Point:
1313
"""__init__(self: pybind11_mypy_demo.basics.Point.AngleUnit, value: int) -> None"""
1414
def __eq__(self, other: object) -> bool:
1515
"""__eq__(self: object, other: object) -> bool"""
16-
def __getstate__(self) -> int:
17-
"""__getstate__(self: object) -> int"""
1816
def __hash__(self) -> int:
1917
"""__hash__(self: object) -> int"""
2018
def __index__(self) -> int:
@@ -23,8 +21,6 @@ class Point:
2321
"""__int__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int"""
2422
def __ne__(self, other: object) -> bool:
2523
"""__ne__(self: object, other: object) -> bool"""
26-
def __setstate__(self, state: int) -> None:
27-
"""__setstate__(self: pybind11_mypy_demo.basics.Point.AngleUnit, state: int) -> None"""
2824
@property
2925
def name(self) -> str: ...
3026
@property
@@ -40,8 +36,6 @@ class Point:
4036
"""__init__(self: pybind11_mypy_demo.basics.Point.LengthUnit, value: int) -> None"""
4137
def __eq__(self, other: object) -> bool:
4238
"""__eq__(self: object, other: object) -> bool"""
43-
def __getstate__(self) -> int:
44-
"""__getstate__(self: object) -> int"""
4539
def __hash__(self) -> int:
4640
"""__hash__(self: object) -> int"""
4741
def __index__(self) -> int:
@@ -50,8 +44,6 @@ class Point:
5044
"""__int__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int"""
5145
def __ne__(self, other: object) -> bool:
5246
"""__ne__(self: object, other: object) -> bool"""
53-
def __setstate__(self, state: int) -> None:
54-
"""__setstate__(self: pybind11_mypy_demo.basics.Point.LengthUnit, state: int) -> None"""
5547
@property
5648
def name(self) -> str: ...
5749
@property
@@ -70,43 +62,51 @@ class Point:
7062
7163
1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
7264
73-
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None"""
65+
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
66+
"""
7467
@overload
7568
def __init__(self, x: float, y: float) -> None:
7669
"""__init__(*args, **kwargs)
7770
Overloaded function.
7871
7972
1. __init__(self: pybind11_mypy_demo.basics.Point) -> None
8073
81-
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None"""
74+
2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None
75+
"""
76+
def as_list(self) -> List[float]:
77+
"""as_list(self: pybind11_mypy_demo.basics.Point) -> List[float]"""
8278
@overload
8379
def distance_to(self, x: float, y: float) -> float:
8480
"""distance_to(*args, **kwargs)
8581
Overloaded function.
8682
8783
1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
8884
89-
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float"""
85+
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
86+
"""
9087
@overload
9188
def distance_to(self, other: Point) -> float:
9289
"""distance_to(*args, **kwargs)
9390
Overloaded function.
9491
9592
1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float
9693
97-
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float"""
94+
2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float
95+
"""
9896
@property
9997
def length(self) -> float: ...
10098

10199
def answer() -> int:
102100
'''answer() -> int
103101
104-
answer docstring, with end quote"'''
102+
answer docstring, with end quote"
103+
'''
105104
def midpoint(left: float, right: float) -> float:
106105
"""midpoint(left: float, right: float) -> float"""
107106
def sum(arg0: int, arg1: int) -> int:
108107
'''sum(arg0: int, arg1: int) -> int
109108
110-
multiline docstring test, edge case quotes """\'\'\''''
109+
multiline docstring test, edge case quotes """\'\'\'
110+
'''
111111
def weighted_midpoint(left: float, right: float, alpha: float = ...) -> float:
112112
"""weighted_midpoint(left: float, right: float, alpha: float = 0.5) -> float"""

test-data/pybind11_mypy_demo/stubgen/pybind11_mypy_demo/basics.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import ClassVar, overload
1+
from typing import ClassVar, List, overload
22

33
PI: float
44
__version__: str
@@ -47,6 +47,7 @@ class Point:
4747
def __init__(self) -> None: ...
4848
@overload
4949
def __init__(self, x: float, y: float) -> None: ...
50+
def as_list(self) -> List[float]: ...
5051
@overload
5152
def distance_to(self, x: float, y: float) -> float: ...
5253
@overload

0 commit comments

Comments
 (0)