diff --git a/docs/source/stubgen.rst b/docs/source/stubgen.rst index f06c9c066bb7..2de0743572e7 100644 --- a/docs/source/stubgen.rst +++ b/docs/source/stubgen.rst @@ -163,6 +163,11 @@ Additional flags Instead, only export imported names that are not referenced in the module that contains the import. +.. option:: --include-docstrings + + Include docstrings in stubs. This will add docstrings to Python function and + classes stubs and to C extension function stubs. + .. option:: --search-path PATH Specify module search directories, separated by colons (only used if diff --git a/misc/test-stubgenc.sh b/misc/test-stubgenc.sh index 7da135f0bf16..7713e1b04e43 100755 --- a/misc/test-stubgenc.sh +++ b/misc/test-stubgenc.sh @@ -3,17 +3,33 @@ set -e set -x -cd "$(dirname $0)/.." +cd "$(dirname "$0")/.." # Install dependencies, demo project and mypy python -m pip install -r test-requirements.txt python -m pip install ./test-data/pybind11_mypy_demo python -m pip install . -# Remove expected stubs and generate new inplace -STUBGEN_OUTPUT_FOLDER=./test-data/pybind11_mypy_demo/stubgen -rm -rf $STUBGEN_OUTPUT_FOLDER/* -stubgen -p pybind11_mypy_demo -o $STUBGEN_OUTPUT_FOLDER +EXIT=0 -# Compare generated stubs to expected ones -git diff --exit-code $STUBGEN_OUTPUT_FOLDER +# performs the stubgenc test +# first argument is the test result folder +# everything else is passed to stubgen as its arguments +function stubgenc_test() { + # Remove expected stubs and generate new inplace + STUBGEN_OUTPUT_FOLDER=./test-data/pybind11_mypy_demo/$1 + rm -rf "${STUBGEN_OUTPUT_FOLDER:?}/*" + stubgen -o "$STUBGEN_OUTPUT_FOLDER" "${@:2}" + + # Compare generated stubs to expected ones + if ! git diff --exit-code "$STUBGEN_OUTPUT_FOLDER"; + then + EXIT=$? + fi +} + +# create stubs without docstrings +stubgenc_test stubgen -p pybind11_mypy_demo +# create stubs with docstrings +stubgenc_test stubgen-include-docs -p pybind11_mypy_demo --include-docstrings +exit $EXIT diff --git a/mypy/fastparse.py b/mypy/fastparse.py index f7a98e9b2b8f..3a26cfe7d6ff 100644 --- a/mypy/fastparse.py +++ b/mypy/fastparse.py @@ -1008,6 +1008,8 @@ def do_func_def( # FuncDef overrides set_line -- can't use self.set_line func_def.set_line(lineno, n.col_offset, end_line, end_column) retval = func_def + if self.options.include_docstrings: + func_def.docstring = ast3.get_docstring(n, clean=False) self.class_and_function_stack.pop() return retval @@ -1121,6 +1123,8 @@ def visit_ClassDef(self, n: ast3.ClassDef) -> ClassDef: cdef.line = n.lineno cdef.deco_line = n.decorator_list[0].lineno if n.decorator_list else None + if self.options.include_docstrings: + cdef.docstring = ast3.get_docstring(n, clean=False) cdef.column = n.col_offset cdef.end_line = getattr(n, "end_lineno", None) cdef.end_column = getattr(n, "end_col_offset", None) diff --git a/mypy/nodes.py b/mypy/nodes.py index ebd222f4f253..452a4f643255 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -751,6 +751,7 @@ class FuncDef(FuncItem, SymbolNode, Statement): "is_mypy_only", # Present only when a function is decorated with @typing.datasclass_transform or similar "dataclass_transform_spec", + "docstring", ) __match_args__ = ("name", "arguments", "type", "body") @@ -779,6 +780,7 @@ def __init__( # Definitions that appear in if TYPE_CHECKING are marked with this flag. self.is_mypy_only = False self.dataclass_transform_spec: DataclassTransformSpec | None = None + self.docstring: str | None = None @property def name(self) -> str: @@ -1081,6 +1083,7 @@ class ClassDef(Statement): "analyzed", "has_incompatible_baseclass", "deco_line", + "docstring", "removed_statements", ) @@ -1127,6 +1130,7 @@ def __init__( self.has_incompatible_baseclass = False # Used for error reporting (to keep backwad compatibility with pre-3.8) self.deco_line: int | None = None + self.docstring: str | None = None self.removed_statements = [] @property diff --git a/mypy/options.py b/mypy/options.py index 75343acd38bb..46039b4d457f 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -279,6 +279,12 @@ def __init__(self) -> None: # mypy. (Like mypyc.) self.preserve_asts = False + # If True, function and class docstrings will be extracted and retained. + # This isn't exposed as a command line option + # because it is intended for software integrating with + # mypy. (Like stubgen.) + self.include_docstrings = False + # Paths of user plugins self.plugins: list[str] = [] diff --git a/mypy/stubgen.py b/mypy/stubgen.py index a77ee738d56f..b6fc3e8b7377 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -243,6 +243,7 @@ def __init__( verbose: bool, quiet: bool, export_less: bool, + include_docstrings: bool, ) -> None: # See parse_options for descriptions of the flags. self.pyversion = pyversion @@ -261,6 +262,7 @@ def __init__( self.verbose = verbose self.quiet = quiet self.export_less = export_less + self.include_docstrings = include_docstrings class StubSource: @@ -624,6 +626,7 @@ def __init__( include_private: bool = False, analyzed: bool = False, export_less: bool = False, + include_docstrings: bool = False, ) -> None: # Best known value of __all__. self._all_ = _all_ @@ -638,6 +641,7 @@ def __init__( self._state = EMPTY self._toplevel_names: list[str] = [] self._include_private = include_private + self._include_docstrings = include_docstrings self._current_class: ClassDef | None = None self.import_tracker = ImportTracker() # Was the tree semantically analysed before? @@ -809,7 +813,13 @@ def visit_func_def(self, o: FuncDef) -> None: retfield = " -> " + retname self.add(", ".join(args)) - self.add(f"){retfield}: ...\n") + self.add(f"){retfield}:") + if self._include_docstrings and o.docstring: + docstring = mypy.util.quote_docstring(o.docstring) + self.add(f"\n{self._indent} {docstring}\n") + else: + self.add(" ...\n") + self._state = FUNC def is_none_expr(self, expr: Expression) -> bool: @@ -910,8 +920,11 @@ def visit_class_def(self, o: ClassDef) -> None: if base_types: self.add(f"({', '.join(base_types)})") self.add(":\n") - n = len(self._output) self._indent += " " + if self._include_docstrings and o.docstring: + docstring = mypy.util.quote_docstring(o.docstring) + self.add(f"{self._indent}{docstring}\n") + n = len(self._output) self._vars.append([]) super().visit_class_def(o) self._indent = self._indent[:-4] @@ -920,7 +933,8 @@ def visit_class_def(self, o: ClassDef) -> None: if len(self._output) == n: if self._state == EMPTY_CLASS and sep is not None: self._output[sep] = "" - self._output[-1] = self._output[-1][:-1] + " ...\n" + if not (self._include_docstrings and o.docstring): + self._output[-1] = self._output[-1][:-1] + " ...\n" self._state = EMPTY_CLASS else: self._state = CLASS @@ -1710,6 +1724,7 @@ def mypy_options(stubgen_options: Options) -> MypyOptions: options.show_traceback = True options.transform_source = remove_misplaced_type_comments options.preserve_asts = True + options.include_docstrings = stubgen_options.include_docstrings # Override cache_dir if provided in the environment environ_cache_dir = os.getenv("MYPY_CACHE_DIR", "") @@ -1773,6 +1788,7 @@ def generate_stub_from_ast( parse_only: bool = False, include_private: bool = False, export_less: bool = False, + include_docstrings: bool = False, ) -> None: """Use analysed (or just parsed) AST to generate type stub for single file. @@ -1784,6 +1800,7 @@ def generate_stub_from_ast( include_private=include_private, analyzed=not parse_only, export_less=export_less, + include_docstrings=include_docstrings, ) assert mod.ast is not None, "This function must be used only with analyzed modules" mod.ast.accept(gen) @@ -1845,7 +1862,12 @@ def generate_stubs(options: Options) -> None: files.append(target) with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): generate_stub_from_ast( - mod, target, options.parse_only, options.include_private, options.export_less + mod, + target, + options.parse_only, + options.include_private, + options.export_less, + include_docstrings=options.include_docstrings, ) # Separately analyse C modules using different logic. @@ -1859,7 +1881,11 @@ def generate_stubs(options: Options) -> None: files.append(target) with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): generate_stub_for_c_module( - mod.module, target, known_modules=all_modules, sig_generators=sig_generators + mod.module, + target, + known_modules=all_modules, + sig_generators=sig_generators, + include_docstrings=options.include_docstrings, ) num_modules = len(py_modules) + len(c_modules) if not options.quiet and num_modules > 0: @@ -1913,6 +1939,11 @@ def parse_options(args: list[str]) -> Options: action="store_true", help="don't implicitly export all names imported from other modules in the same package", ) + parser.add_argument( + "--include-docstrings", + action="store_true", + help="include existing docstrings with the stubs", + ) parser.add_argument("-v", "--verbose", action="store_true", help="show more verbose messages") parser.add_argument("-q", "--quiet", action="store_true", help="show fewer messages") parser.add_argument( @@ -1993,6 +2024,7 @@ def parse_options(args: list[str]) -> Options: verbose=ns.verbose, quiet=ns.quiet, export_less=ns.export_less, + include_docstrings=ns.include_docstrings, ) diff --git a/mypy/stubgenc.py b/mypy/stubgenc.py index 8aa1fb3d2c0a..31487f9d0dcf 100755 --- a/mypy/stubgenc.py +++ b/mypy/stubgenc.py @@ -14,6 +14,7 @@ from types import ModuleType from typing import Any, Final, Iterable, Mapping +import mypy.util from mypy.moduleinspect import is_c_module from mypy.stubdoc import ( ArgSig, @@ -169,6 +170,7 @@ def generate_stub_for_c_module( target: str, known_modules: list[str], sig_generators: Iterable[SignatureGenerator], + include_docstrings: bool = False, ) -> None: """Generate stub for C module. @@ -201,6 +203,7 @@ def generate_stub_for_c_module( known_modules=known_modules, imports=imports, sig_generators=sig_generators, + include_docstrings=include_docstrings, ) done.add(name) types: list[str] = [] @@ -216,6 +219,7 @@ def generate_stub_for_c_module( known_modules=known_modules, imports=imports, sig_generators=sig_generators, + include_docstrings=include_docstrings, ) done.add(name) variables = [] @@ -319,15 +323,17 @@ def generate_c_function_stub( self_var: str | None = None, cls: type | None = None, class_name: str | None = None, + include_docstrings: bool = False, ) -> None: """Generate stub for a single function or method. - The result (always a single line) will be appended to 'output'. + The result will be appended to 'output'. If necessary, any required names will be added to 'imports'. The 'class_name' is used to find signature of __init__ or __new__ in 'class_sigs'. """ inferred: list[FunctionSig] | None = None + docstr: str | None = None if class_name: # method: assert cls is not None, "cls should be provided for methods" @@ -379,13 +385,19 @@ def generate_c_function_stub( # a sig generator indicates @classmethod by specifying the cls arg if class_name and signature.args and signature.args[0].name == "cls": output.append("@classmethod") - output.append( - "def {function}({args}) -> {ret}: ...".format( - function=name, - args=", ".join(args), - ret=strip_or_import(signature.ret_type, module, known_modules, imports), - ) + output_signature = "def {function}({args}) -> {ret}:".format( + function=name, + args=", ".join(args), + ret=strip_or_import(signature.ret_type, module, known_modules, imports), ) + if include_docstrings and docstr: + docstr_quoted = mypy.util.quote_docstring(docstr.strip()) + docstr_indented = "\n ".join(docstr_quoted.split("\n")) + output.append(output_signature) + output.extend(f" {docstr_indented}".split("\n")) + else: + output_signature += " ..." + output.append(output_signature) def strip_or_import( @@ -493,6 +505,7 @@ def generate_c_type_stub( known_modules: list[str], imports: list[str], sig_generators: Iterable[SignatureGenerator], + include_docstrings: bool = False, ) -> None: """Generate stub for a single class using runtime introspection. @@ -535,6 +548,7 @@ def generate_c_type_stub( cls=obj, class_name=class_name, sig_generators=sig_generators, + include_docstrings=include_docstrings, ) elif is_c_property(raw_value): generate_c_property_stub( @@ -557,6 +571,7 @@ def generate_c_type_stub( imports=imports, known_modules=known_modules, sig_generators=sig_generators, + include_docstrings=include_docstrings, ) else: attrs.append((attr, value)) diff --git a/mypy/util.py b/mypy/util.py index 8a079c5256bc..d0f2f8c6cc36 100644 --- a/mypy/util.py +++ b/mypy/util.py @@ -809,3 +809,20 @@ def plural_s(s: int | Sized) -> str: return "s" else: return "" + + +def quote_docstring(docstr: str) -> str: + """Returns docstring correctly encapsulated in a single or double quoted form.""" + # Uses repr to get hint on the correct quotes and escape everything properly. + # Creating multiline string for prettier output. + docstr_repr = "\n".join(re.split(r"(?<=[^\\])\\n", repr(docstr))) + + if docstr_repr.startswith("'"): + # Enforce double quotes when it's safe to do so. + # That is when double quotes are not in the string + # or when it doesn't end with a single quote. + if '"' not in docstr_repr[1:-1] and docstr_repr[-2] != "'": + return f'"""{docstr_repr[1:-1]}"""' + return f"''{docstr_repr}''" + else: + return f'""{docstr_repr}""' diff --git a/test-data/pybind11_mypy_demo/src/main.cpp b/test-data/pybind11_mypy_demo/src/main.cpp index ff0f93bf7017..00e5b2f4e871 100644 --- a/test-data/pybind11_mypy_demo/src/main.cpp +++ b/test-data/pybind11_mypy_demo/src/main.cpp @@ -119,8 +119,8 @@ void bind_basics(py::module& basics) { using namespace basics; // Functions - basics.def("answer", &answer); - basics.def("sum", &sum); + basics.def("answer", &answer, "answer docstring, with end quote\""); // tests explicit docstrings + basics.def("sum", &sum, "multiline docstring test, edge case quotes \"\"\"'''"); basics.def("midpoint", &midpoint, py::arg("left"), py::arg("right")); basics.def("weighted_midpoint", weighted_midpoint, py::arg("left"), py::arg("right"), py::arg("alpha")=0.5); diff --git a/test-data/pybind11_mypy_demo/stubgen-include-docs/pybind11_mypy_demo/__init__.pyi b/test-data/pybind11_mypy_demo/stubgen-include-docs/pybind11_mypy_demo/__init__.pyi new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/test-data/pybind11_mypy_demo/stubgen-include-docs/pybind11_mypy_demo/basics.pyi b/test-data/pybind11_mypy_demo/stubgen-include-docs/pybind11_mypy_demo/basics.pyi new file mode 100644 index 000000000000..676d7f6d3f15 --- /dev/null +++ b/test-data/pybind11_mypy_demo/stubgen-include-docs/pybind11_mypy_demo/basics.pyi @@ -0,0 +1,112 @@ +from typing import ClassVar + +from typing import overload +PI: float + +class Point: + class AngleUnit: + __members__: ClassVar[dict] = ... # read-only + __entries: ClassVar[dict] = ... + degree: ClassVar[Point.AngleUnit] = ... + radian: ClassVar[Point.AngleUnit] = ... + def __init__(self, value: int) -> None: + """__init__(self: pybind11_mypy_demo.basics.Point.AngleUnit, value: int) -> None""" + def __eq__(self, other: object) -> bool: + """__eq__(self: object, other: object) -> bool""" + def __getstate__(self) -> int: + """__getstate__(self: object) -> int""" + def __hash__(self) -> int: + """__hash__(self: object) -> int""" + def __index__(self) -> int: + """__index__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int""" + def __int__(self) -> int: + """__int__(self: pybind11_mypy_demo.basics.Point.AngleUnit) -> int""" + def __ne__(self, other: object) -> bool: + """__ne__(self: object, other: object) -> bool""" + def __setstate__(self, state: int) -> None: + """__setstate__(self: pybind11_mypy_demo.basics.Point.AngleUnit, state: int) -> None""" + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + + class LengthUnit: + __members__: ClassVar[dict] = ... # read-only + __entries: ClassVar[dict] = ... + inch: ClassVar[Point.LengthUnit] = ... + mm: ClassVar[Point.LengthUnit] = ... + pixel: ClassVar[Point.LengthUnit] = ... + def __init__(self, value: int) -> None: + """__init__(self: pybind11_mypy_demo.basics.Point.LengthUnit, value: int) -> None""" + def __eq__(self, other: object) -> bool: + """__eq__(self: object, other: object) -> bool""" + def __getstate__(self) -> int: + """__getstate__(self: object) -> int""" + def __hash__(self) -> int: + """__hash__(self: object) -> int""" + def __index__(self) -> int: + """__index__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int""" + def __int__(self) -> int: + """__int__(self: pybind11_mypy_demo.basics.Point.LengthUnit) -> int""" + def __ne__(self, other: object) -> bool: + """__ne__(self: object, other: object) -> bool""" + def __setstate__(self, state: int) -> None: + """__setstate__(self: pybind11_mypy_demo.basics.Point.LengthUnit, state: int) -> None""" + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + angle_unit: ClassVar[Point.AngleUnit] = ... + length_unit: ClassVar[Point.LengthUnit] = ... + x_axis: ClassVar[Point] = ... # read-only + y_axis: ClassVar[Point] = ... # read-only + origin: ClassVar[Point] = ... + x: float + y: float + @overload + def __init__(self) -> None: + """__init__(*args, **kwargs) + Overloaded function. + + 1. __init__(self: pybind11_mypy_demo.basics.Point) -> None + + 2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None""" + @overload + def __init__(self, x: float, y: float) -> None: + """__init__(*args, **kwargs) + Overloaded function. + + 1. __init__(self: pybind11_mypy_demo.basics.Point) -> None + + 2. __init__(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> None""" + @overload + def distance_to(self, x: float, y: float) -> float: + """distance_to(*args, **kwargs) + Overloaded function. + + 1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float + + 2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float""" + @overload + def distance_to(self, other: Point) -> float: + """distance_to(*args, **kwargs) + Overloaded function. + + 1. distance_to(self: pybind11_mypy_demo.basics.Point, x: float, y: float) -> float + + 2. distance_to(self: pybind11_mypy_demo.basics.Point, other: pybind11_mypy_demo.basics.Point) -> float""" + @property + def length(self) -> float: ... + +def answer() -> int: + '''answer() -> int + + answer docstring, with end quote"''' +def midpoint(left: float, right: float) -> float: + """midpoint(left: float, right: float) -> float""" +def sum(arg0: int, arg1: int) -> int: + '''sum(arg0: int, arg1: int) -> int + + multiline docstring test, edge case quotes """\'\'\'''' +def weighted_midpoint(left: float, right: float, alpha: float = ...) -> float: + """weighted_midpoint(left: float, right: float, alpha: float = 0.5) -> float""" diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index f6b71a994153..774a17b76161 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -3183,6 +3183,85 @@ def f2(): def f1(): ... def f2(): ... +[case testIncludeDocstrings] +# flags: --include-docstrings +class A: + """class docstring + + a multiline docstring""" + def func(): + """func docstring + don't forget to indent""" + ... + def nodoc(): + ... +class B: + def quoteA(): + '''func docstring with quotes"""\\n + and an end quote\'''' + ... + def quoteB(): + '''func docstring with quotes""" + \'\'\' + and an end quote\\"''' + ... + def quoteC(): + """func docstring with end quote\\\"""" + ... + def quoteD(): + r'''raw with quotes\"''' + ... +[out] +class A: + """class docstring + + a multiline docstring""" + def func() -> None: + """func docstring + don't forget to indent""" + def nodoc() -> None: ... + +class B: + def quoteA() -> None: + '''func docstring with quotes"""\\n + and an end quote\'''' + def quoteB() -> None: + '''func docstring with quotes""" + \'\'\' + and an end quote\\"''' + def quoteC() -> None: + '''func docstring with end quote\\"''' + def quoteD() -> None: + '''raw with quotes\\"''' + +[case testIgnoreDocstrings] +class A: + """class docstring + + a multiline docstring""" + def func(): + """func docstring + + don't forget to indent""" + def nodoc(): + ... + +class B: + def func(): + """func docstring""" + ... + def nodoc(): + ... + +[out] +class A: + def func() -> None: ... + def nodoc() -> None: ... + +class B: + def func() -> None: ... + def nodoc() -> None: ... + [case testKnownMagicMethodsReturnTypes] class Some: def __len__(self): ...