From 6e00071f25283f253272c651bf99548420107b55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 4 Feb 2025 19:50:05 +0545 Subject: [PATCH] ssh client: implement `password_auth_requested` --- .../git/backend/dulwich/asyncssh_vendor.py | 34 +++++++++---------- tests/vendor/test_paramiko_vendor.py | 2 +- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py b/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py index a60bf70d..762d4b87 100644 --- a/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py +++ b/src/scmrepo/git/backend/dulwich/asyncssh_vendor.py @@ -40,6 +40,12 @@ async def _read_all(read: Callable[[int], Coroutine], n: Optional[int] = None) - return b"".join(result) +async def _getpass(*args, **kwargs) -> str: + from getpass import getpass + + return await asyncio.to_thread(getpass, *args, **kwargs) + + class _StderrWrapper: def __init__(self, stderr: "SSHReader", loop: asyncio.AbstractEventLoop) -> None: self.stderr = stderr @@ -202,8 +208,6 @@ async def public_key_auth_requested( # noqa: C901 return None async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey": - from getpass import getpass - from asyncssh.public_key import ( KeyEncryptionError, KeyImportError, @@ -215,11 +219,8 @@ async def _read_private_key_interactive(self, path: "FilePath") -> "SSHKey": if passphrase: return read_private_key(path, passphrase=passphrase) - loop = asyncio.get_running_loop() for _ in range(3): - passphrase = await loop.run_in_executor( - None, getpass, f"Enter passphrase for key '{path}': " - ) + passphrase = await _getpass(f"Enter passphrase for key {path!r}: ") if passphrase: try: key = read_private_key(path, passphrase=passphrase) @@ -239,23 +240,20 @@ async def kbdint_challenge_received( # pylint: disable=invalid-overridden-metho lang: str, prompts: "KbdIntPrompts", ) -> Optional["KbdIntResponse"]: - from getpass import getpass - if os.environ.get("GIT_TERMINAL_PROMPT") == "0": return None - def _getpass(prompt: str) -> str: - return getpass(prompt=prompt).rstrip() - if instructions: pass - loop = asyncio.get_running_loop() - return [ - await loop.run_in_executor( - None, _getpass, f"({name}) {prompt}" if name else prompt - ) - for prompt, _ in prompts - ] + + response: list[str] = [] + for prompt, _echo in prompts: + p = await _getpass(f"({name}) {prompt}" if name else prompt) + response.append(p.rstrip()) + return response + + async def password_auth_requested(self) -> str: + return await _getpass() class AsyncSSHVendor(BaseAsyncObject, SSHVendor): diff --git a/tests/vendor/test_paramiko_vendor.py b/tests/vendor/test_paramiko_vendor.py index 8310916d..756cbc6e 100644 --- a/tests/vendor/test_paramiko_vendor.py +++ b/tests/vendor/test_paramiko_vendor.py @@ -56,7 +56,7 @@ def check_channel_request(self, kind, chanid): return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED def get_allowed_auths(self, username): - return "password,publickey" + return "publickey" USER = "testuser"