Commit 1a78a2c

Author: Crucifixion-Fxl (committed)

[Bugfix] Migrate to REGEX Library to prevent catastrophic backtracking

Signed-off-by: Crucifixion-Fxl <[email protected]>

1 parent 47fda6d, commit 1a78a2c
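For context on the failure mode named in the title: nested quantifiers can make Python's built-in re engine backtrack exponentially on inputs that almost match, and re offers no way to bound that work. The third-party regex package mirrors re's search/match/sub API and, in recent releases, also accepts a timeout argument so a runaway match aborts instead of hanging. The pattern, input, and timeout below are a minimal illustrative sketch and are not taken from the vLLM codebase.

# Illustrative sketch only; pattern, input, and timeout are not from this commit.
import regex

pathological = r"(a+)+$"       # nested quantifiers: a classic catastrophic-backtracking shape
almost_match = "a" * 40 + "!"  # almost matches, forcing the engine to try many partitions

try:
    # regex keeps re's API and adds a timeout (in seconds); re.search has no
    # equivalent escape hatch for runaway matches.
    regex.search(pathological, almost_match, timeout=1.0)
    print("finished within the time budget")
except TimeoutError:
    print("aborted after 1 s instead of hanging indefinitely")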

52 files changed (+153, -150 lines)

.github/scripts/cleanup_pr_body.sh

Lines changed: 3 additions & 3 deletions

@@ -26,13 +26,13 @@ sed -i '/\*\*BEFORE SUBMITTING, PLEASE READ.*\*\*/,$d' "${NEW}"
 
 # Remove HTML <details> section that includes <summary> text of "PR Checklist (Click to Expand)"
 python3 - <<EOF
-import re
+import regex
 
 with open("${NEW}", "r") as file:
     content = file.read()
 
-pattern = re.compile(r'(---\n\n)?<details>.*?<summary>.*?PR Checklist \(Click to Expand\).*?</summary>.*?</details>', re.DOTALL)
-content = re.sub(pattern, '', content)
+pattern = regex.compile(r'(---\n\n)?<details>.*?<summary>.*?PR Checklist \(Click to Expand\).*?</summary>.*?</details>', regex.DOTALL)
+content = regex.sub(pattern, '', content)
 
 with open("${NEW}", "w") as file:
     file.write(content)
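The migrated pattern chains several lazy .*? spans under DOTALL, which is the kind of expression that can slow down badly on large PR bodies that open a <details> block but never complete the expected structure. A quick self-contained check that the regex version still strips a well-formed checklist block; the sample PR body below is invented for illustration:

# Illustrative check of the migrated pattern; the sample PR body is made up.
import regex

pattern = regex.compile(
    r'(---\n\n)?<details>.*?<summary>.*?PR Checklist \(Click to Expand\).*?'
    r'</summary>.*?</details>', regex.DOTALL)

body = ("Fixes a bug in the scheduler.\n\n---\n\n"
        "<details>\n<summary>PR Checklist (Click to Expand)</summary>\n"
        "- [ ] Tests added\n</details>\nExtra notes.\n")

cleaned = regex.sub(pattern, '', body)
assert "PR Checklist" not in cleaned
assert "Extra notes." in cleaned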

benchmarks/benchmark_serving_structured_output.py

Lines changed: 4 additions & 4 deletions

@@ -672,11 +672,11 @@ def process_one_metric(
 def evaluate(ret, args):
     def _eval_correctness_json(expected, actual):
         # extract json string from string using regex
-        import re
+        import regex
 
         actual = actual.replace("\n", "").replace(" ", "").strip()
         try:
-            actual = re.search(r"\{.*\}", actual).group()
+            actual = regex.search(r"\{.*\}", actual).group()
             actual = json.loads(actual)
         except Exception:
             return False
@@ -687,9 +687,9 @@ def _eval_correctness_choice(expected, actual):
         return actual in args.choice
 
     def _eval_correctness_regex(expected, actual):
-        import re
+        import regex
 
-        return re.match(args.regex, actual) is not None
+        return regex.match(args.regex, actual) is not None
 
     def _eval_correctness(expected, actual):
         if args.structure_type == "guided_json":

benchmarks/kernels/graph_machete_bench.py

Lines changed: 3 additions & 3 deletions

@@ -2,11 +2,11 @@
 
 import math
 import pickle
-import re
 from collections import defaultdict
 
 import matplotlib.pyplot as plt
 import pandas as pd
+import regex
 import seaborn as sns
 from torch.utils.benchmark import Measurement as TMeasurement
 
@@ -27,12 +27,12 @@
 
     results = defaultdict(lambda: list())
    for v in raw_results:
-        result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
+        result = regex.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
         if result is not None:
             KN = result.group(1)
         else:
             raise Exception("MKN not found")
-        result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
+        result = regex.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
         if result is not None:
             M = result.group(1)
         else:

docs/source/conf.py

Lines changed: 2 additions & 2 deletions

@@ -15,10 +15,10 @@
 import datetime
 import logging
 import os
-import re
 import sys
 from pathlib import Path
 
+import regex
 import requests
 
 logger = logging.getLogger(__name__)
@@ -198,7 +198,7 @@ def linkcode_resolve(domain, info):
     for lineno, line in enumerate(lines, 1):
         if not line or line.startswith("#"):
             continue
-        if re.match(pattern, line):
+        if regex.match(pattern, line):
            break
 
    # If the line number is not found, return None

docs/source/generate_examples.py

Lines changed: 3 additions & 2 deletions

@@ -1,10 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import itertools
-import re
 from dataclasses import dataclass, field
 from pathlib import Path
 
+import regex
+
 ROOT_DIR = Path(__file__).parent.parent.parent.resolve()
 ROOT_DIR_RELATIVE = '../../../..'
 EXAMPLE_DIR = ROOT_DIR / "examples"
@@ -32,7 +33,7 @@ def fix_case(text: str) -> str:
         r"int\d+": lambda x: x.group(0).upper(),  # e.g. int8, int16
     }
     for pattern, repl in subs.items():
-        text = re.sub(rf'\b{pattern}\b', repl, text, flags=re.IGNORECASE)
+        text = regex.sub(rf'\b{pattern}\b', repl, text, flags=regex.IGNORECASE)
     return text
 
 

examples/offline_inference/prithvi_geospatial_mae.py

Lines changed: 2 additions & 2 deletions

@@ -20,12 +20,12 @@
 import argparse
 import datetime
 import os
-import re
 from typing import Union
 
 import albumentations
 import numpy as np
 import rasterio
+import regex
 import torch
 from einops import rearrange
 from terratorch.datamodules import Sen1Floods11NonGeoDataModule
@@ -300,7 +300,7 @@ def load_example(
             location_coords.append(coords)
 
         try:
-            match = re.search(r'(\d{7,8}T\d{6})', file)
+            match = regex.search(r'(\d{7,8}T\d{6})', file)
             if match:
                 year = int(match.group(1)[:4])
                 julian_day = match.group(1).split('T')[0][4:]

requirements/common.txt

Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+regex # Replace re for higher-performance regex matching
 cachetools
 psutil
 sentencepiece # Required for LLaMA tokenizer.

requirements/docs.txt

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@ myst-parser==3.0.1 # `myst-parser==4.0.1` breaks inline code in titles
 msgspec
 snowballstemmer<3 # https://github.com/snowballstem/snowball/issues/229
 commonmark # Required by sphinx-argparse when using :markdownhelp:
+regex # Replace re for higher-performance regex matching
 
 # Custom autodoc2 is necessary for faster docstring processing
 # see: https://github.com/sphinx-extensions2/sphinx-autodoc2/issues/33#issuecomment-2856386035

setup.py

Lines changed: 3 additions & 4 deletions

@@ -5,12 +5,12 @@
 import json
 import logging
 import os
-import re
 import subprocess
 import sys
 from pathlib import Path
 from shutil import which
 
+import regex
 import torch
 from packaging.version import Version, parse
 from setuptools import Extension, setup
@@ -389,8 +389,7 @@ def run(self) -> None:
         # vllm_flash_attn python code:
         # Regex from
         # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
-        import re
-        compiled_regex = re.compile(
+        compiled_regex = regex.compile(
             r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
         file_members += list(
             filter(lambda x: compiled_regex.match(x.filename),
@@ -510,7 +509,7 @@ def get_neuronxcc_version():
         content = fp.read()
 
     # Extract the version using a regular expression
-    match = re.search(r"__version__ = '(\S+)'", content)
+    match = regex.search(r"__version__ = '(\S+)'", content)
     if match:
         # Return the version string
         return match.group(1)

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 3 additions & 3 deletions

@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import json
-import re
 import weakref
 from enum import Enum
 
 import jsonschema
 import pytest
+import regex
 from pydantic import BaseModel
 
 from vllm.distributed import cleanup_dist_env_and_memory
@@ -62,7 +62,7 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str,
         generated_text = output.outputs[0].text
         print(generated_text)
         assert generated_text is not None
-        assert re.fullmatch(sample_regex, generated_text) is not None
+        assert regex.fullmatch(sample_regex, generated_text) is not None
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
 
@@ -479,7 +479,7 @@ def test_guided_number_range_json_completion(llm, guided_decoding_backend: str,
         jsonschema.validate(instance=output_json, schema=sample_output_schema)
         assert 18 <= output_json["age"] <= 99
         assert 0.0 <= output_json["score"] <= 100.0
-        assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
+        assert (regex.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
                 is not None)
 
 

tests/entrypoints/openai/test_chat.py

Lines changed: 3 additions & 3 deletions

@@ -2,13 +2,13 @@
 
 # imports for guided decoding tests
 import json
-import re
 from typing import Optional
 
 import jsonschema
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import regex
 import requests
 import torch
 from openai import BadRequestError, OpenAI
@@ -585,7 +585,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):
         extra_body=dict(guided_regex=sample_regex))
     ip1 = chat_completion.choices[0].message.content
     assert ip1 is not None
-    assert re.fullmatch(sample_regex, ip1) is not None
+    assert regex.fullmatch(sample_regex, ip1) is not None
 
     messages.append({"role": "assistant", "content": ip1})
     messages.append({"role": "user", "content": "Give me a different one"})
@@ -596,7 +596,7 @@ async def test_guided_regex_chat(client: openai.AsyncOpenAI, sample_regex):
         extra_body=dict(guided_regex=sample_regex))
     ip2 = chat_completion.choices[0].message.content
     assert ip2 is not None
-    assert re.fullmatch(sample_regex, ip2) is not None
+    assert regex.fullmatch(sample_regex, ip2) is not None
     assert ip1 != ip2
 
 

tests/entrypoints/openai/test_completion.py

Lines changed: 4 additions & 5 deletions

@@ -1,8 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
-
 # imports for guided decoding tests
 import json
-import re
 import shutil
 from tempfile import TemporaryDirectory
 from typing import Optional
@@ -11,6 +9,7 @@
 import openai  # use the official client for correctness check
 import pytest
 import pytest_asyncio
+import regex
 # downloading lora to test lora requests
 from huggingface_hub import snapshot_download
 from openai import BadRequestError
@@ -677,8 +676,8 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
     assert completion.id is not None
     assert len(completion.choices) == 3
     for i in range(3):
-        assert re.fullmatch(sample_regex,
-                            completion.choices[i].text) is not None
+        assert regex.fullmatch(sample_regex,
+                               completion.choices[i].text) is not None
 
 
 @pytest.mark.asyncio
@@ -747,7 +746,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
 
     prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                          list) else prompt
-    assert re.search(r"^" + prompt_text, completion.choices[0].text)
+    assert regex.search(r"^" + prompt_text, completion.choices[0].text)
     logprobs = completion.choices[0].logprobs
     assert logprobs is not None
     assert len(logprobs.text_offset) > 5

tests/entrypoints/openai/test_prompt_validation.py

Lines changed: 3 additions & 4 deletions

@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
 # imports for guided decoding tests
-import re
-
 import openai
 import pytest
+import regex
 
 from ...utils import RemoteOpenAIServer
 
@@ -32,7 +31,7 @@ async def test_out_of_vocab_token_ids():
         client = remote_server.get_async_client()
 
         with pytest.raises(openai.BadRequestError,
-                           match=re.compile('.*out of vocabulary.*')):
+                           match=regex.compile('.*out of vocabulary.*')):
             await client.completions.create(model=model_name,
                                             prompt=[999999],
                                             max_tokens=5,
@@ -47,7 +46,7 @@ async def test_reject_multistep_with_guided_decoding():
         client = remote_server.get_async_client()
 
         with pytest.raises(openai.BadRequestError,
-                           match=re.compile(
+                           match=regex.compile(
                                '.*Guided decoding .* multi-step decoding.*')):
             await client.completions.create(
                 model=model_name,

tests/models/multimodal/generation/test_phi4mm.py

Lines changed: 2 additions & 2 deletions

@@ -1,12 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-import re
 from collections.abc import Sequence
 from typing import Optional
 
 import librosa
 import pytest
+import regex
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
@@ -44,7 +44,7 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str,
     """Sanitize vllm output to be comparable with hf output."""
     _, output_str, out_logprobs = vllm_output
 
-    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
+    output_str_without_image = regex.sub(r"(<\|image_\d+\|>)+", "", output_str)
     assert output_str_without_image[0] == " "
     output_str_without_image = output_str_without_image[1:]
 

tests/models/multimodal/generation/vlm_utils/model_utils.py

Lines changed: 3 additions & 3 deletions

@@ -3,11 +3,11 @@
 for manipulating the input / output of HF & vLLM test runners, which are
 typically specific to a small subset of models.
 """
-import re
 import types
 from pathlib import PosixPath
 from typing import Optional, Union
 
+import regex
 import torch
 from PIL.Image import Image
 from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
@@ -160,7 +160,7 @@ def phi3v_vllm_to_hf_output(vllm_output: RunnerOutput,
     """Sanitize vllm output [phi3v] to be comparable with hf output."""
     _, output_str, out_logprobs = vllm_output
 
-    output_str_without_image = re.sub(r"(<\|image_\d+\|>)+", "", output_str)
+    output_str_without_image = regex.sub(r"(<\|image_\d+\|>)+", "", output_str)
     assert output_str_without_image[0] == " "
     output_str_without_image = output_str_without_image[1:]
 
@@ -335,7 +335,7 @@ def processor(*args, text="", images=None, **kwargs):
 
         images = [images] if isinstance(images, Image) else images
 
-        contents = re.findall(
+        contents = regex.findall(
             r"<\|begin_of_image\|><\|endoftext\|><\|end_of_image\|>(.*?)<\|assistant\|>",
             text,
         )

tests/tool_use/test_tool_choice_required.py

Lines changed: 1 addition & 2 deletions

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 import json
-import re
 from copy import deepcopy
 from unittest.mock import MagicMock
 
@@ -73,7 +72,7 @@ def _compile_and_check(tools: list[ChatCompletionToolsParam], sample_output,
     # use build_regex_from_schema used in JSONLogitsProcessor to create Guide
     from outlines_core.fsm.json_schema import build_regex_from_schema
     regex = build_regex_from_schema(json.dumps(schema))
-    compiled = re.compile(regex)
+    compiled = regex.compile(regex)
     matches = compiled.fullmatch(json.dumps(sample_output)) is not None
 
     assert matches == should_match
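One caveat in the hunk above: this file never gains an import regex line, and the local variable regex (the pattern string returned by build_regex_from_schema) shadows the module name anyway, so regex.compile(regex) would fail at runtime. A minimal follow-up sketch with the local renamed; schema_regex and compile_schema_pattern are illustrative names, not from the commit:

# Hypothetical fix sketch; `schema_regex` / `compile_schema_pattern` are illustrative names.
import json

import regex
from outlines_core.fsm.json_schema import build_regex_from_schema

def compile_schema_pattern(schema: dict):
    # Keep the generated pattern string and the regex module under different names
    # so the module-level `regex` is not shadowed.
    schema_regex = build_regex_from_schema(json.dumps(schema))
    return regex.compile(schema_regex)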
