Skip to content

Commit c7e93a9

Browse files
authored
Fix py.typed in namespace packages (#122)
Previously, the py.typed file was incorrectly put in the namespace package, instead of the non-namespace sub-packages.
1 parent 4185d1a commit c7e93a9

File tree

2 files changed

+93
-14
lines changed

2 files changed

+93
-14
lines changed

stub_uploader/build_wheel.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,21 +135,71 @@ def __init__(self, base_path: Path, package_data: dict[str, list[str]]) -> None:
135135
self.base_path = base_path
136136
self.package_data = package_data
137137

138+
def package_path(self, package_name: str) -> Path:
139+
"""Return the path of a given package name.
140+
141+
The package name can use dotted notation to address sub-packages.
142+
The top-level package name can optionally include the "-stubs" suffix.
143+
"""
144+
top_level, *sub_packages = package_name.split(".")
145+
if top_level.endswith(SUFFIX):
146+
top_level = top_level[: -len(SUFFIX)]
147+
return self.base_path.joinpath(top_level, *sub_packages)
148+
149+
def is_single_file_package(self, package_name: str) -> bool:
150+
filename = package_name.split("-")[0] + ".pyi"
151+
return (self.base_path / filename).exists()
152+
138153
@property
139154
def top_level_packages(self) -> list[str]:
140155
"""Top level package names.
141156
142-
These are the packages that are not subpackages of any other package
143-
and includes namespace packages.
157+
These are the packages that are not sub-packages of any other package
158+
and includes namespace packages. Their name includes the "-stubs"
159+
suffix.
144160
"""
145161
return list(self.package_data.keys())
146162

163+
@property
164+
def top_level_non_namespace_packages(self) -> list[str]:
165+
"""Top level non-namespace package names.
166+
167+
This will return all packages that are not subpackages of any other
168+
package, other than namespace packages in dotted notation, e.g. if
169+
"flying" is a top level namespace package, and "circus" is a
170+
non-namespace sub-package, this will return ["flying.circus"].
171+
"""
172+
packages: list[str] = []
173+
for top_level in self.top_level_packages:
174+
if self.is_single_file_package(top_level):
175+
packages.append(top_level)
176+
else:
177+
packages.extend(self._find_non_namespace_sub_packages(top_level))
178+
return packages
179+
180+
def _find_non_namespace_sub_packages(self, package: str) -> list[str]:
181+
path = self.package_path(package)
182+
if is_namespace_package(path):
183+
sub_packages: list[str] = []
184+
for entry in path.iterdir():
185+
if entry.is_dir():
186+
sub_name = package + "." + entry.name
187+
sub_packages.extend(self._find_non_namespace_sub_packages(sub_name))
188+
return sub_packages
189+
else:
190+
return [package]
191+
147192
def add_file(self, package: str, filename: str, file_contents: str) -> None:
148193
"""Add a file to a package."""
149-
entry_path = self.base_path / package
194+
top_level = package.split(".")[0]
195+
entry_path = self.package_path(package)
150196
entry_path.mkdir(exist_ok=True)
151197
(entry_path / filename).write_text(file_contents)
152-
self.package_data[package].append(filename)
198+
self.package_data[top_level].append(filename)
199+
200+
201+
def is_namespace_package(path: Path) -> bool:
202+
return not (path / "__init__.pyi").exists()
153203

154204

155205
def find_stub_files(top: str) -> list[str]:
@@ -166,6 +216,8 @@ def find_stub_files(top: str) -> list[str]:
166216
name.isidentifier()
167217
), "All file names must be valid Python modules"
168218
result.append(os.path.relpath(os.path.join(root, file), top))
219+
elif file == "py.typed":
220+
result.append(os.path.relpath(os.path.join(root, file), top))
169221
elif not file.endswith((".md", ".rst")):
170222
# Allow having README docs, as some stubs have these (e.g. click).
171223
if (
@@ -257,7 +309,7 @@ def collect_package_data(base_path: Path) -> PackageData:
257309

258310

259311
def add_partial_markers(pkg_data: PackageData) -> None:
260-
for package in pkg_data.top_level_packages:
312+
for package in pkg_data.top_level_non_namespace_packages:
261313
pkg_data.add_file(package, "py.typed", "partial\n")
262314

263315

tests/test_integration.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from stub_uploader.ts_data import read_typeshed_data
2929

3030
TYPESHED = "../typeshed"
31+
THIRD_PARTY_PATH = Path(TYPESHED) / THIRD_PARTY_NAMESPACE
3132

3233

3334
def test_fetch_pypi_versions() -> None:
@@ -37,19 +38,15 @@ def test_fetch_pypi_versions() -> None:
3738
assert not get_version.fetch_pypi_versions("types-nonexistent-distribution")
3839

3940

40-
@pytest.mark.parametrize(
41-
"distribution", os.listdir(os.path.join(TYPESHED, THIRD_PARTY_NAMESPACE))
42-
)
41+
@pytest.mark.parametrize("distribution", os.listdir(THIRD_PARTY_PATH))
4342
def test_build_wheel(distribution: str) -> None:
4443
"""Check that we can build wheels for all distributions."""
4544
tmp_dir = build_wheel.main(TYPESHED, distribution, version="1.1.1")
4645
assert tmp_dir.endswith("/dist")
4746
assert list(os.listdir(tmp_dir)) # check it is not empty
4847

4948

50-
@pytest.mark.parametrize(
51-
"distribution", os.listdir(os.path.join(TYPESHED, THIRD_PARTY_NAMESPACE))
52-
)
49+
@pytest.mark.parametrize("distribution", os.listdir(THIRD_PARTY_PATH))
5350
def test_version_increment(distribution: str) -> None:
5451
get_version.determine_stub_version(read_metadata(TYPESHED, distribution))
5552

@@ -145,9 +142,7 @@ def test_dependency_order_single() -> None:
145142
]
146143

147144

148-
@pytest.mark.parametrize(
149-
"distribution", os.listdir(os.path.join(TYPESHED, THIRD_PARTY_NAMESPACE))
150-
)
145+
@pytest.mark.parametrize("distribution", os.listdir(THIRD_PARTY_PATH))
151146
def test_recursive_verify(distribution: str) -> None:
152147
recursive_verify(read_metadata(TYPESHED, distribution), TYPESHED)
153148

@@ -170,3 +165,35 @@ def test_verify_requires_python() -> None:
170165
InvalidRequires, match="Expected requires_python to be a '>=' specifier"
171166
):
172167
verify_requires_python("==3.10")
168+
169+
170+
@pytest.mark.parametrize(
171+
"distribution,expected_packages",
172+
[
173+
("pytz", ["pytz-stubs"]),
174+
("Pillow", ["PIL-stubs"]),
175+
("protobuf", ["google-stubs"]),
176+
("google-cloud-ndb", ["google-stubs"]),
177+
],
178+
)
179+
def test_pkg_data_top_level_packages(
180+
distribution: str, expected_packages: list[str]
181+
) -> None:
182+
pkg_data = build_wheel.collect_package_data(THIRD_PARTY_PATH / distribution)
183+
assert pkg_data.top_level_packages == expected_packages
184+
185+
186+
@pytest.mark.parametrize(
187+
"distribution,expected_packages",
188+
[
189+
("pytz", ["pytz-stubs"]),
190+
("Pillow", ["PIL-stubs"]),
191+
("protobuf", ["google-stubs.protobuf"]),
192+
("google-cloud-ndb", ["google-stubs.cloud.ndb"]),
193+
],
194+
)
195+
def test_pkg_data_non_namespace_packages(
196+
distribution: str, expected_packages: list[str]
197+
) -> None:
198+
pkg_data = build_wheel.collect_package_data(THIRD_PARTY_PATH / distribution)
199+
assert pkg_data.top_level_non_namespace_packages == expected_packages

0 commit comments

Comments
 (0)