Skip to content

Commit bde7db8

Browse files
committed
Overload Variables class for better typing experience
1 parent 8045b8f commit bde7db8

File tree

1 file changed

+12
-5
lines changed
  • src/prompt_toolkit/contrib/regular_languages

1 file changed

+12
-5
lines changed

Diff for: src/prompt_toolkit/contrib/regular_languages/compiler.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from __future__ import annotations
4343

4444
import re
45-
from typing import Callable, Dict, Iterable, Iterator, Pattern
45+
from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload
4646
from typing import Match as RegexMatch
4747

4848
from .regex_parser import (
@@ -57,9 +57,7 @@
5757
tokenize_regex,
5858
)
5959

60-
__all__ = [
61-
"compile",
62-
]
60+
__all__ = ["compile", "Match", "Variables"]
6361

6462

6563
# Name of the named group in the regex, matching trailing input.
@@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]:
491489
yield MatchVariable(varname, value, (reg[0], reg[1]))
492490

493491

492+
_T = TypeVar("_T")
493+
494+
494495
class Variables:
495496
def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None:
496497
#: List of (varname, value, slice) tuples.
@@ -502,7 +503,13 @@ def __repr__(self) -> str:
502503
", ".join(f"{k}={v!r}" for k, v, _ in self._tuples),
503504
)
504505

505-
def get(self, key: str, default: str | None = None) -> str | None:
506+
@overload
507+
def get(self, key: str) -> str | None: ...
508+
509+
@overload
510+
def get(self, key: str, default: _T = None) -> str | _T: ...
511+
512+
def get(self, key: str, default: _T = None) -> str | _T:
506513
items = self.getall(key)
507514
return items[0] if items else default
508515

0 commit comments

Comments
 (0)