23
23
)
24
24
25
25
import functools
26
+ from typing import Any , Optional
26
27
28
+ import paramiko .pkey
27
29
import paramiko .ssh_exception
28
30
29
31
from testinfra .backend import base
32
34
class IgnorePolicy (paramiko .MissingHostKeyPolicy ):
33
35
"""Policy for ignoring missing host key."""
34
36
35
- def missing_host_key (self , client , hostname , key ):
37
+ def missing_host_key (
38
+ self , client : paramiko .SSHClient , hostname : str , key : paramiko .pkey .PKey
39
+ ) -> None :
36
40
pass
37
41
38
42
@@ -41,12 +45,12 @@ class ParamikoBackend(base.BaseBackend):
41
45
42
46
def __init__ (
43
47
self ,
44
- hostspec ,
45
- ssh_config = None ,
46
- ssh_identity_file = None ,
47
- timeout = 10 ,
48
- * args ,
49
- ** kwargs ,
48
+ hostspec : str ,
49
+ ssh_config : Optional [ str ] = None ,
50
+ ssh_identity_file : Optional [ str ] = None ,
51
+ timeout : int = 10 ,
52
+ * args : Any ,
53
+ ** kwargs : Any ,
50
54
):
51
55
self .host = self .parse_hostspec (hostspec )
52
56
self .ssh_config = ssh_config
@@ -55,7 +59,13 @@ def __init__(
55
59
self .timeout = int (timeout )
56
60
super ().__init__ (self .host .name , * args , ** kwargs )
57
61
58
- def _load_ssh_config (self , client , cfg , ssh_config , ssh_config_dir = "~/.ssh" ):
62
+ def _load_ssh_config (
63
+ self ,
64
+ client : paramiko .SSHClient ,
65
+ cfg : dict [str , Any ],
66
+ ssh_config : paramiko .SSHConfig ,
67
+ ssh_config_dir : str = "~/.ssh" ,
68
+ ) -> None :
59
69
for key , value in ssh_config .lookup (self .host .name ).items ():
60
70
if key == "hostname" :
61
71
cfg [key ] = value
@@ -85,7 +95,7 @@ def _load_ssh_config(self, client, cfg, ssh_config, ssh_config_dir="~/.ssh"):
85
95
self ._load_ssh_config (client , cfg , new_ssh_config , ssh_config_dir )
86
96
87
97
@functools .cached_property
88
- def client (self ):
98
+ def client (self ) -> paramiko . SSHClient :
89
99
client = paramiko .SSHClient ()
90
100
client .set_missing_host_key_policy (paramiko .WarningPolicy ())
91
101
cfg = {
@@ -118,11 +128,13 @@ def client(self):
118
128
119
129
if self .ssh_identity_file :
120
130
cfg ["key_filename" ] = self .ssh_identity_file
121
- client .connect (** cfg )
131
+ client .connect (** cfg ) # type: ignore[arg-type]
122
132
return client
123
133
124
- def _exec_command (self , command ):
125
- chan = self .client .get_transport ().open_session ()
134
+ def _exec_command (self , command : bytes ) -> tuple [int , bytes , bytes ]:
135
+ transport = self .client .get_transport ()
136
+ assert transport is not None
137
+ chan = transport .open_session ()
126
138
if self .get_pty :
127
139
chan .get_pty ()
128
140
chan .exec_command (command )
@@ -131,17 +143,19 @@ def _exec_command(self, command):
131
143
stderr = b"" .join (chan .makefile_stderr ("rb" ))
132
144
return rc , stdout , stderr
133
145
134
- def run (self , command , * args , ** kwargs ) :
146
+ def run (self , command : str , * args : str , ** kwargs : Any ) -> base . CommandResult :
135
147
command = self .get_command (command , * args )
136
- command = self .encode (command )
148
+ cmd = self .encode (command )
137
149
try :
138
- rc , stdout , stderr = self ._exec_command (command )
150
+ rc , stdout , stderr = self ._exec_command (cmd )
139
151
except paramiko .ssh_exception .SSHException :
140
- if not self .client .get_transport ().is_active ():
152
+ transport = self .client .get_transport ()
153
+ assert transport is not None
154
+ if not transport .is_active ():
141
155
# try to reinit connection (once)
142
156
del self .client
143
- rc , stdout , stderr = self ._exec_command (command )
157
+ rc , stdout , stderr = self ._exec_command (cmd )
144
158
else :
145
159
raise
146
160
147
- return self .result (rc , command , stdout , stderr )
161
+ return self .result (rc , cmd , stdout , stderr )
0 commit comments