Skip to content

Commit 12c3d15

Browse files
Recognize stub pyi Python files. (#2182)
Recognize stub ``pyi`` Python files. Refs pylint-dev/pylint#4987 Co-authored-by: Jacob Walls <[email protected]>
1 parent 1177acc commit 12c3d15

File tree

12 files changed

+33
-8
lines changed

12 files changed

+33
-8
lines changed

ChangeLog

+4
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ Release date: TBA
156156

157157
Refs pylint-dev/pylint#8554
158158

159+
* Recognize stub ``pyi`` Python files.
160+
161+
Refs pylint-dev/pylint#4987
162+
159163

160164
What's New in astroid 2.15.4?
161165
=============================

astroid/interpreter/_import/spec.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def find_module(
163163

164164
for entry in submodule_path:
165165
package_directory = os.path.join(entry, modname)
166-
for suffix in (".py", importlib.machinery.BYTECODE_SUFFIXES[0]):
166+
for suffix in (".py", ".pyi", importlib.machinery.BYTECODE_SUFFIXES[0]):
167167
package_file_name = "__init__" + suffix
168168
file_path = os.path.join(package_directory, package_file_name)
169169
if os.path.isfile(file_path):

astroid/modutils.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@
4444

4545

4646
if sys.platform.startswith("win"):
47-
PY_SOURCE_EXTS = ("py", "pyw")
47+
PY_SOURCE_EXTS = ("py", "pyw", "pyi")
4848
PY_COMPILED_EXTS = ("dll", "pyd")
4949
else:
50-
PY_SOURCE_EXTS = ("py",)
50+
PY_SOURCE_EXTS = ("py", "pyi")
5151
PY_COMPILED_EXTS = ("so",)
5252

5353

@@ -274,9 +274,6 @@ def _get_relative_base_path(filename: str, path_to_check: str) -> list[str] | No
274274
if os.path.normcase(real_filename).startswith(path_to_check):
275275
importable_path = real_filename
276276

277-
# if "var" in path_to_check:
278-
# breakpoint()
279-
280277
if importable_path:
281278
base_path = os.path.splitext(importable_path)[0]
282279
relative_base_path = base_path[len(path_to_check) :]
@@ -476,7 +473,7 @@ def get_module_files(
476473
continue
477474
_handle_blacklist(blacklist, dirnames, filenames)
478475
# check for __init__.py
479-
if not list_all and "__init__.py" not in filenames:
476+
if not list_all and {"__init__.py", "__init__.pyi"}.isdisjoint(filenames):
480477
dirnames[:] = ()
481478
continue
482479
for filename in filenames:
@@ -499,6 +496,8 @@ def get_source_file(filename: str, include_no_ext: bool = False) -> str:
499496
"""
500497
filename = os.path.abspath(_path_from_filename(filename))
501498
base, orig_ext = os.path.splitext(filename)
499+
if orig_ext == ".pyi" and os.path.exists(f"{base}{orig_ext}"):
500+
return f"{base}{orig_ext}"
502501
for ext in PY_SOURCE_EXTS:
503502
source_path = f"{base}.{ext}"
504503
if os.path.exists(source_path):
@@ -663,7 +662,7 @@ def _is_python_file(filename: str) -> bool:
663662
664663
.pyc and .pyo are ignored
665664
"""
666-
return filename.endswith((".py", ".so", ".pyd", ".pyw"))
665+
return filename.endswith((".py", ".pyi", ".so", ".pyd", ".pyw"))
667666

668667

669668
def _has_init(directory: str) -> str | None:

tests/test_modutils.py

+22
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ def test(self) -> None:
287287
def test_raise(self) -> None:
288288
self.assertRaises(modutils.NoSourceFile, modutils.get_source_file, "whatever")
289289

290+
def test_(self) -> None:
291+
package = resources.find("pyi_data")
292+
module = os.path.join(package, "__init__.pyi")
293+
self.assertEqual(modutils.get_source_file(module), os.path.normpath(module))
294+
290295

291296
class IsStandardModuleTest(resources.SysPathSetup, unittest.TestCase):
292297
"""
@@ -417,8 +422,12 @@ def test_success(self) -> None:
417422
assert modutils.module_in_path("data.module", datadir)
418423
assert modutils.module_in_path("data.module", (datadir,))
419424
assert modutils.module_in_path("data.module", os.path.abspath(datadir))
425+
assert modutils.module_in_path("pyi_data.module", datadir)
426+
assert modutils.module_in_path("pyi_data.module", (datadir,))
427+
assert modutils.module_in_path("pyi_data.module", os.path.abspath(datadir))
420428
# "" will evaluate to cwd
421429
assert modutils.module_in_path("data.module", "")
430+
assert modutils.module_in_path("pyi_data.module", "")
422431

423432
def test_bad_import(self) -> None:
424433
datadir = resources.find("")
@@ -496,6 +505,19 @@ def test_get_module_files_1(self) -> None:
496505
]
497506
self.assertEqual(modules, {os.path.join(package, x) for x in expected})
498507

508+
def test_get_module_files_2(self) -> None:
509+
package = resources.find("pyi_data/find_test")
510+
modules = set(modutils.get_module_files(package, []))
511+
expected = [
512+
"__init__.py",
513+
"__init__.pyi",
514+
"module.py",
515+
"module2.py",
516+
"noendingnewline.py",
517+
"nonregr.py",
518+
]
519+
self.assertEqual(modules, {os.path.join(package, x) for x in expected})
520+
499521
def test_get_all_files(self) -> None:
500522
"""Test that list_all returns all Python files from given location."""
501523
non_package = resources.find("data/notamodule")

tests/testdata/python3/pyi_data/__init__.pyi

Whitespace-only changes.

tests/testdata/python3/pyi_data/find_test/__init__.py

Whitespace-only changes.

tests/testdata/python3/pyi_data/find_test/__init__.pyi

Whitespace-only changes.

tests/testdata/python3/pyi_data/find_test/module.py

Whitespace-only changes.

tests/testdata/python3/pyi_data/find_test/module2.py

Whitespace-only changes.

tests/testdata/python3/pyi_data/find_test/noendingnewline.py

Whitespace-only changes.

tests/testdata/python3/pyi_data/find_test/nonregr.py

Whitespace-only changes.

tests/testdata/python3/pyi_data/module.py

Whitespace-only changes.

0 commit comments

Comments
 (0)