Skip to content

Commit fc8f33e

Browse files
CLI: add replace package name (#372)
* add replace package name --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 652d939 commit fc8f33e

File tree

4 files changed

+59
-8
lines changed

4 files changed

+59
-8
lines changed

CHANGELOG.md

+5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased] - 2025-MM-DD
99

10+
### Added
11+
12+
- CLI: replace package name in requirements ([#372](https://github.com/Lightning-AI/utilities/pull/372))
13+
14+
1015
### Changed
1116

1217
-

src/lightning_utilities/cli/__main__.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
#
55

66
import lightning_utilities
7-
from lightning_utilities.cli.dependencies import prune_pkgs_in_requirements, replace_oldest_ver
7+
from lightning_utilities.cli.dependencies import (
8+
prune_packages_in_requirements,
9+
replace_oldest_version,
10+
replace_package_in_requirements,
11+
)
812

913

1014
def main() -> None:
@@ -13,8 +17,9 @@ def main() -> None:
1317

1418
Fire({
1519
"requirements": {
16-
"prune-pkgs": prune_pkgs_in_requirements,
17-
"set-oldest": replace_oldest_ver,
20+
"prune-pkgs": prune_packages_in_requirements,
21+
"set-oldest": replace_oldest_version,
22+
"replace-pkg": replace_package_in_requirements,
1823
},
1924
"version": lambda: print(lightning_utilities.__version__),
2025
})

src/lightning_utilities/cli/dependencies.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#
55
import glob
66
import os.path
7+
import re
78
from collections.abc import Sequence
89
from pprint import pprint
910
from typing import Union
@@ -15,7 +16,7 @@
1516
REQUIREMENT_FILES_ALL += [REQUIREMENT_ROOT]
1617

1718

18-
def prune_pkgs_in_requirements(
19+
def prune_packages_in_requirements(
1920
packages: Union[str, Sequence[str]], req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
2021
) -> None:
2122
"""Remove some packages from given requirement files."""
@@ -49,9 +50,35 @@ def _replace_min(fname: str) -> None:
4950
fw.write(req)
5051

5152

52-
def replace_oldest_ver(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
53+
def replace_oldest_version(req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL) -> None:
5354
"""Replace the min package version by fixed one."""
5455
if isinstance(req_files, str):
5556
req_files = [req_files]
5657
for fname in req_files:
5758
_replace_min(fname)
59+
60+
61+
def _replace_package_name(requirements: list[str], old_package: str, new_package: str) -> list[str]:
62+
"""Replace one package by another with same version in given requirement file.
63+
64+
>>> _replace_package_name(["torch>=1.0 # comment", "torchvision>=0.2", "torchtext <0.3"], "torch", "pytorch")
65+
['pytorch>=1.0 # comment', 'torchvision>=0.2', 'torchtext <0.3']
66+
67+
"""
68+
for i, req in enumerate(requirements):
69+
requirements[i] = re.sub(r"^" + re.escape(old_package) + r"(?=[ <=>#]|$)", new_package, req)
70+
return requirements
71+
72+
73+
def replace_package_in_requirements(
74+
old_package: str, new_package: str, req_files: Union[str, Sequence[str]] = REQUIREMENT_FILES_ALL
75+
) -> None:
76+
"""Replace one package by another with same version in given requirement files."""
77+
if isinstance(req_files, str):
78+
req_files = [req_files]
79+
for fname in req_files:
80+
with open(fname) as fopen:
81+
reqs = fopen.readlines()
82+
reqs = _replace_package_name(reqs, old_package, new_package)
83+
with open(fname, "w") as fw:
84+
fw.writelines(reqs)
+17-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from pathlib import Path
22

3-
from lightning_utilities.cli.dependencies import prune_pkgs_in_requirements, replace_oldest_ver
3+
from lightning_utilities.cli.dependencies import (
4+
prune_packages_in_requirements,
5+
replace_oldest_version,
6+
replace_package_in_requirements,
7+
)
48

59
_PATH_ROOT = Path(__file__).parent.parent.parent
610

@@ -9,7 +13,7 @@ def test_prune_packages(tmpdir):
913
req_file = tmpdir / "requirements.txt"
1014
with open(req_file, "w") as fp:
1115
fp.writelines(["fire\n", "abc>=0.1\n"])
12-
prune_pkgs_in_requirements("abc", req_files=[str(req_file)])
16+
prune_packages_in_requirements("abc", req_files=[str(req_file)])
1317
with open(req_file) as fp:
1418
lines = fp.readlines()
1519
assert lines == ["fire\n"]
@@ -19,7 +23,17 @@ def test_oldest_packages(tmpdir):
1923
req_file = tmpdir / "requirements.txt"
2024
with open(req_file, "w") as fp:
2125
fp.writelines(["fire>0.2\n", "abc>=0.1\n"])
22-
replace_oldest_ver(req_files=[str(req_file)])
26+
replace_oldest_version(req_files=[str(req_file)])
2327
with open(req_file) as fp:
2428
lines = fp.readlines()
2529
assert lines == ["fire>0.2\n", "abc==0.1\n"]
30+
31+
32+
def test_replace_packages(tmpdir):
33+
req_file = tmpdir / "requirements.txt"
34+
with open(req_file, "w") as fp:
35+
fp.writelines(["torchvision>=0.2\n", "torch>=1.0 # comment\n", "torchtext <0.3\n"])
36+
replace_package_in_requirements(old_package="torch", new_package="pytorch", req_files=[str(req_file)])
37+
with open(req_file) as fp:
38+
lines = fp.readlines()
39+
assert lines == ["torchvision>=0.2\n", "pytorch>=1.0 # comment\n", "torchtext <0.3\n"]

0 commit comments

Comments
 (0)