Skip to content

Commit 2807271

Browse files
authored
[CI] enforce import regex instead of re (#18665)
Signed-off-by: Aaron Pham <[email protected]>
1 parent b9018a3 commit 2807271

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ repos:
128128
name: Update Dockerfile dependency graph
129129
entry: tools/update-dockerfile-graph.sh
130130
language: script
131+
- id: enforce-import-regex-instead-of-re
132+
name: Enforce import regex as re
133+
entry: python tools/enforce_regex_import.py
134+
language: python
135+
types: [python]
136+
pass_filenames: false
137+
additional_dependencies: [regex]
131138
# forbid directly import triton
132139
- id: forbid-direct-triton-import
133140
name: "Forbid direct 'import triton'"

tools/enforce_regex_import.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
from __future__ import annotations
3+
4+
import subprocess
5+
from pathlib import Path
6+
7+
import regex as re
8+
9+
# Matches a line that imports the standard-library ``re`` module:
#   * ``import re`` with ``re`` at ANY position of a comma-separated import
#     list (``import re``, ``import re, os``, ``import os, re``,
#     ``import os as o, re as r``), and
#   * ``from re import ...``.
# The ``(?!\w)`` lookahead keeps names that merely start with "re"
# (``regex``, ``requests``, ``re_utils``) from matching, while still catching
# ``import re;`` and ``import re#comment``.
FORBIDDEN_PATTERNS = re.compile(
    r'^\s*(?:import\s+(?:[\w.]+(?:\s+as\s+\w+)?\s*,\s*)*re(?!\w)'
    r'|from\s+re\s+import)')
# Import spellings that are explicitly accepted even if a line were to match
# FORBIDDEN_PATTERNS.
ALLOWED_PATTERNS = [
    re.compile(r'^\s*import\s+regex\s+as\s+re\s*$'),
    re.compile(r'^\s*import\s+regex\s*$'),
]
15+
16+
17+
def get_staged_python_files() -> list[str]:
    """Return the staged (added or modified) paths that end in ``.py``.

    Runs ``git diff --cached --name-only --diff-filter=AM -z``. The ``-z``
    flag makes git emit NUL-terminated, unquoted paths; without it, paths
    containing non-ASCII or special characters are C-quoted (wrapped in
    double quotes with escapes) and would fail the ``.py`` suffix test and
    be skipped.

    Returns:
        The matching paths as printed by git, or an empty list when nothing
        is staged or the git command exits with a non-zero status.
    """
    try:
        result = subprocess.run(
            [
                'git', 'diff', '--cached', '--name-only',
                '--diff-filter=AM', '-z'
            ],
            capture_output=True,
            text=True,
            check=True)
    except subprocess.CalledProcessError:
        # Best effort: a failing git invocation yields no files to check.
        return []

    # Output is NUL-terminated, so splitting leaves a trailing empty string;
    # the suffix filter discards it along with every non-Python path.
    return [
        name for name in result.stdout.split('\0') if name.endswith('.py')
    ]
29+
30+
31+
def is_forbidden_import(line: str) -> bool:
    """Tell whether ``line`` is a disallowed import of ``re``.

    The line is stripped of surrounding whitespace first. It is forbidden
    when it matches ``FORBIDDEN_PATTERNS`` and none of ``ALLOWED_PATTERNS``.
    """
    candidate = line.strip()
    if not FORBIDDEN_PATTERNS.match(candidate):
        return False
    for allowed in ALLOWED_PATTERNS:
        if allowed.match(candidate):
            return False
    return True
36+
37+
38+
def check_file(filepath: str) -> list[tuple[int, str]]:
    """Scan one file for forbidden ``re`` imports.

    Args:
        filepath: Path of the file to read; it is decoded as UTF-8.

    Returns:
        A list of ``(line_number, stripped_line)`` pairs, with line numbers
        starting at 1. If the file cannot be opened or decoded, whatever was
        collected before the error is returned (possibly nothing).
    """
    found: list[tuple[int, str]] = []
    try:
        with open(filepath, encoding='utf-8') as handle:
            for number, text in enumerate(handle, start=1):
                if not is_forbidden_import(text):
                    continue
                found.append((number, text.strip()))
    except (OSError, UnicodeDecodeError):
        # Unreadable files are skipped on purpose; keep what we have so far.
        pass
    return found
48+
49+
50+
def main() -> int:
    """Check every staged Python file and report forbidden ``re`` imports.

    Returns:
        ``0`` when no staged file contains a violation (or nothing is
        staged), ``1`` otherwise, after printing each violation and a hint.
    """
    staged = get_staged_python_files()
    if not staged:
        return 0

    count = 0
    for name in staged:
        # Skip paths that are listed by git but absent from the working tree.
        if not Path(name).exists():
            continue

        hits = check_file(name)
        if not hits:
            continue

        print(f"\n{name}:")
        for number, text in hits:
            print(f" Line {number}: {text}")
        count += len(hits)

    if count == 0:
        return 0

    print(f"\n💡 Found {count} violation(s).")
    print("❌ Please replace 'import re' with 'import regex as re'")
    print(" Also replace 'from re import ...' with 'from regex import ...'"
          )  # noqa: E501
    print("✅ Allowed imports:")
    print(" - import regex as re")
    print(" - import regex")
    return 1
80+
81+
82+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)

0 commit comments

Comments
 (0)