Skip to content

Overload variables class for better typing experience #1919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@ jobs:
pip list
- name: Ruff
run: |
ruff .
ruff check .
ruff format --check .
typos .
- name: Tests
17 changes: 12 additions & 5 deletions src/prompt_toolkit/contrib/regular_languages/compiler.py
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@
from __future__ import annotations

import re
from typing import Callable, Dict, Iterable, Iterator, Pattern
from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload
from typing import Match as RegexMatch

from .regex_parser import (
@@ -57,9 +57,7 @@
tokenize_regex,
)

__all__ = [
"compile",
]
__all__ = ["compile", "Match", "Variables"]


# Name of the named group in the regex, matching trailing input.
@@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]:
yield MatchVariable(varname, value, (reg[0], reg[1]))


_T = TypeVar("_T")


class Variables:
def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None:
#: List of (varname, value, slice) tuples.
@@ -502,7 +503,13 @@ def __repr__(self) -> str:
", ".join(f"{k}={v!r}" for k, v, _ in self._tuples),
)

def get(self, key: str, default: str | None = None) -> str | None:
@overload
def get(self, key: str) -> str | None: ...

@overload
def get(self, key: str, default: str | _T) -> str | _T: ...

def get(self, key: str, default: str | _T | None = None) -> str | _T | None:
items = self.getall(key)
return items[0] if items else default

22 changes: 13 additions & 9 deletions src/prompt_toolkit/output/defaults.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import TextIO, cast
from typing import TYPE_CHECKING, TextIO, cast

from prompt_toolkit.utils import (
get_bell_environment_variable,
@@ -13,13 +13,17 @@
from .color_depth import ColorDepth
from .plain_text import PlainTextOutput

if TYPE_CHECKING:
from prompt_toolkit.patch_stdout import StdoutProxy


__all__ = [
"create_output",
]


def create_output(
stdout: TextIO | None = None, always_prefer_tty: bool = False
stdout: TextIO | StdoutProxy | None = None, always_prefer_tty: bool = False
) -> Output:
"""
Return an :class:`~prompt_toolkit.output.Output` instance for the command
@@ -54,13 +58,6 @@ def create_output(
stdout = io
break

# If the output is still `None`, use a DummyOutput.
# This happens for instance on Windows, when running the application under
# `pythonw.exe`. In that case, there won't be a terminal Window, and
# stdin/stdout/stderr are `None`.
if stdout is None:
return DummyOutput()

# If the patch_stdout context manager has been used, then sys.stdout is
# replaced by this proxy. For prompt_toolkit applications, we want to use
# the real stdout.
@@ -69,6 +66,13 @@ def create_output(
while isinstance(stdout, StdoutProxy):
stdout = stdout.original_stdout

# If the output is still `None`, use a DummyOutput.
# This happens for instance on Windows, when running the application under
# `pythonw.exe`. In that case, there won't be a terminal Window, and
# stdin/stdout/stderr are `None`.
if stdout is None:
return DummyOutput()

if sys.platform == "win32":
from .conemu import ConEmuOutput
from .win32 import Win32Output
2 changes: 1 addition & 1 deletion src/prompt_toolkit/patch_stdout.py
Original file line number Diff line number Diff line change
@@ -273,7 +273,7 @@ def flush(self) -> None:
self._flush()

@property
def original_stdout(self) -> TextIO:
def original_stdout(self) -> TextIO | None:
return self._output.stdout or sys.__stdout__

# Attributes for compatibility with sys.__stdout__: