Skip to content
This repository was archived by the owner on Mar 13, 2022. It is now read-only.

Commit 1dd2756

Browse files
committed
Implement port forwarding.
1 parent 54d188f commit 1dd2756

File tree

3 files changed

+252
-66
lines changed

3 files changed

+252
-66
lines changed

Diff for: stream/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from .stream import stream
15+
from .stream import stream, portforward

Diff for: stream/stream.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -12,26 +12,31 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import types
16+
1517
from . import ws_client
1618

1719

1820
def stream(func, *args, **kwargs):
1921
"""Stream given API call using websocket.
2022
Extra kwarg: capture-all=True - captures all stdout+stderr for use with WSClient.read_all()"""
2123

22-
def _intercept_request_call(*args, **kwargs):
23-
# old generated code's api client has config. new ones has
24-
# configuration
25-
try:
26-
config = func.__self__.api_client.configuration
27-
except AttributeError:
28-
config = func.__self__.api_client.config
24+
api_client = func.__self__.api_client
25+
prev_request = api_client.request
26+
try:
27+
api_client.request = types.MethodType(ws_client.websocket_call, api_client)
28+
return func(*args, **kwargs)
29+
finally:
30+
api_client.request = prev_request
2931

30-
return ws_client.websocket_call(config, *args, **kwargs)
3132

32-
prev_request = func.__self__.api_client.request
33+
def portforward(func, *args, **kwargs):
34+
kwargs['_preload_content'] = False
35+
api_client = func.__self__.api_client
36+
prev_request = api_client.request
3337
try:
34-
func.__self__.api_client.request = _intercept_request_call
38+
api_client.request = types.MethodType(ws_client.portforward_call, api_client)
3539
return func(*args, **kwargs)
3640
finally:
37-
func.__self__.api_client.request = prev_request
41+
api_client.request = prev_request
42+

Diff for: stream/ws_client.py

+235-54
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from kubernetes.client.rest import ApiException
15+
from kubernetes.client.rest import ApiException, ApiValueError
1616

1717
import certifi
1818
import collections
1919
import select
20+
import socket
2021
import ssl
22+
import threading
2123
import time
2224

2325
import six
2426
import yaml
2527

26-
from six.moves.urllib.parse import urlencode, quote_plus, urlparse, urlunparse
28+
from six.moves.urllib.parse import urlencode, urlparse, urlunparse
2729
from six import StringIO
2830

2931
from websocket import WebSocket, ABNF, enableTrace
@@ -51,47 +53,13 @@ def __init__(self, configuration, url, headers, capture_all):
5153
like port forwarding can forward different pods' streams to different
5254
channels.
5355
"""
54-
enableTrace(False)
55-
header = []
5656
self._connected = False
5757
self._channels = {}
5858
if capture_all:
5959
self._all = StringIO()
6060
else:
6161
self._all = _IgnoredIO()
62-
63-
# We just need to pass the Authorization, ignore all the other
64-
# http headers we get from the generated code
65-
if headers and 'authorization' in headers:
66-
header.append("authorization: %s" % headers['authorization'])
67-
68-
if headers and 'sec-websocket-protocol' in headers:
69-
header.append("sec-websocket-protocol: %s" %
70-
headers['sec-websocket-protocol'])
71-
else:
72-
header.append("sec-websocket-protocol: v4.channel.k8s.io")
73-
74-
if url.startswith('wss://') and configuration.verify_ssl:
75-
ssl_opts = {
76-
'cert_reqs': ssl.CERT_REQUIRED,
77-
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
78-
}
79-
if configuration.assert_hostname is not None:
80-
ssl_opts['check_hostname'] = configuration.assert_hostname
81-
else:
82-
ssl_opts = {'cert_reqs': ssl.CERT_NONE}
83-
84-
if configuration.cert_file:
85-
ssl_opts['certfile'] = configuration.cert_file
86-
if configuration.key_file:
87-
ssl_opts['keyfile'] = configuration.key_file
88-
89-
self.sock = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
90-
if configuration.proxy:
91-
proxy_url = urlparse(configuration.proxy)
92-
self.sock.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port)
93-
else:
94-
self.sock.connect(url, header=header)
62+
self.sock = create_websocket(configuration, url, headers=headers)
9563
self._connected = True
9664

9765
def peek_channel(self, channel, timeout=0):
@@ -259,44 +227,257 @@ def close(self, **kwargs):
259227
WSResponse = collections.namedtuple('WSResponse', ['data'])
260228

261229

262-
def get_websocket_url(url):
230+
class PortForwardClient:
231+
def __init__(self, websocket, ports):
232+
"""A websocket client with support for port forwarding.
233+
234+
Port Forward command sends on 2 channels per port, a read/write
235+
data channel and a read only error channel. Both channels are sent an
236+
initial frame contaning the port number that channel is associated with.
237+
"""
238+
239+
self.websocket = websocket
240+
self.ports = {}
241+
for ix, port_number in enumerate(ports):
242+
self.ports[port_number] = self.Port(ix, port_number)
243+
threading.Thread(
244+
name="Kubernetes port forward proxy", target=self._proxy, daemon=True
245+
).start()
246+
247+
def socket(self, port_number):
248+
if port_number not in self.ports:
249+
raise ValueError("Invalid port number")
250+
return self.ports[port_number].socket
251+
252+
def error(self, port_number):
253+
if port_number not in self.ports:
254+
raise ValueError("Invalid port number")
255+
return self.ports[port_number].error
256+
257+
def close(self):
258+
for port in self.ports.values():
259+
port.socket.close()
260+
261+
class Port:
262+
def __init__(self, ix, number):
263+
self.number = number
264+
self.channel = bytes([ix * 2])
265+
s, self.python = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
266+
self.socket = self.Socket(s)
267+
self.data = b''
268+
self.error = None
269+
270+
class Socket:
271+
def __init__(self, socket):
272+
self._socket = socket
273+
274+
def __getattr__(self, name):
275+
return getattr(self._socket, name)
276+
277+
def setsockopt(self, level, optname, value):
278+
# The following socket option is not valid with a socket created from socketpair,
279+
# and is set when creating an SSLSocket from this socket.
280+
if level == socket.IPPROTO_TCP and optname == socket.TCP_NODELAY:
281+
return
282+
self._socket.setsockopt(level, optname, value)
283+
284+
# Proxy all socket data between the python code and the kubernetes websocket.
285+
def _proxy(self):
286+
channel_ports = []
287+
channel_initialized = []
288+
python_ports = {}
289+
rlist = []
290+
for port in self.ports.values():
291+
channel_ports.append(port)
292+
channel_initialized.append(False)
293+
channel_ports.append(port)
294+
channel_initialized.append(False)
295+
python_ports[port.python] = port
296+
rlist.append(port.python)
297+
rlist.append(self.websocket.sock)
298+
kubernetes_data = b''
299+
while True:
300+
wlist = []
301+
for port in self.ports.values():
302+
if port.data:
303+
wlist.append(port.python)
304+
if kubernetes_data:
305+
wlist.append(self.websocket.sock)
306+
r, w, _ = select.select(rlist, wlist, [])
307+
for s in w:
308+
if s == self.websocket.sock:
309+
sent = self.websocket.sock.send(kubernetes_data)
310+
kubernetes_data = kubernetes_data[sent:]
311+
else:
312+
port = python_ports[s]
313+
sent = port.python.send(port.data)
314+
port.data = port.data[sent:]
315+
for s in r:
316+
if s == self.websocket.sock:
317+
opcode, frame = self.websocket.recv_data_frame(True)
318+
if opcode == ABNF.OPCODE_CLOSE:
319+
for port in self.ports.values():
320+
port.python.close()
321+
return
322+
if opcode == ABNF.OPCODE_BINARY:
323+
if not frame.data:
324+
raise RuntimeError("Unexpected frame data size")
325+
channel = frame.data[0]
326+
if channel >= len(channel_ports):
327+
raise RuntimeError("Unexpected channel number: " + str(channel))
328+
port = channel_ports[channel]
329+
if channel_initialized[channel]:
330+
if channel % 2:
331+
port.error = frame.data[1:].decode()
332+
if port.python in rlist:
333+
port.python.close()
334+
rlist.remove(port.python)
335+
port.data = b''
336+
else:
337+
port.data += frame.data[1:]
338+
else:
339+
if len(frame.data) != 3:
340+
raise RuntimeError(
341+
"Unexpected initial channel frame data size"
342+
)
343+
port_number = frame.data[1] + (frame.data[2] * 256)
344+
if port_number != port.number:
345+
raise RuntimeError(
346+
"Unexpected port number in initial channel frame: " + str(port_number)
347+
)
348+
channel_initialized[channel] = True
349+
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG):
350+
raise RuntimeError("Unexpected websocket opcode: " + str(opcode))
351+
else:
352+
port = python_ports[s]
353+
data = port.python.recv(1024 * 1024)
354+
if data:
355+
kubernetes_data += ABNF.create_frame(
356+
port.channel + data,
357+
ABNF.OPCODE_BINARY,
358+
).format()
359+
else:
360+
port.python.close()
361+
rlist.remove(s)
362+
if len(rlist) == 1:
363+
self.websocket.close()
364+
return
365+
366+
367+
def get_websocket_url(url, query_params=None):
263368
parsed_url = urlparse(url)
264369
parts = list(parsed_url)
265370
if parsed_url.scheme == 'http':
266371
parts[0] = 'ws'
267372
elif parsed_url.scheme == 'https':
268373
parts[0] = 'wss'
374+
if query_params:
375+
query = []
376+
for key, value in query_params:
377+
if key == 'command' and isinstance(value, list):
378+
for command in value:
379+
query.append((key, command))
380+
else:
381+
query.append((key, value))
382+
if query:
383+
parts[4] = urlencode(query)
269384
return urlunparse(parts)
270385

271386

272-
def websocket_call(configuration, *args, **kwargs):
387+
def create_websocket(configuration, url, headers=None):
388+
enableTrace(False)
389+
390+
# We just need to pass the Authorization, ignore all the other
391+
# http headers we get from the generated code
392+
header = []
393+
if headers and 'authorization' in headers:
394+
header.append("authorization: %s" % headers['authorization'])
395+
if headers and 'sec-websocket-protocol' in headers:
396+
header.append("sec-websocket-protocol: %s" %
397+
headers['sec-websocket-protocol'])
398+
else:
399+
header.append("sec-websocket-protocol: v4.channel.k8s.io")
400+
401+
if url.startswith('wss://') and configuration.verify_ssl:
402+
ssl_opts = {
403+
'cert_reqs': ssl.CERT_REQUIRED,
404+
'ca_certs': configuration.ssl_ca_cert or certifi.where(),
405+
}
406+
if configuration.assert_hostname is not None:
407+
ssl_opts['check_hostname'] = configuration.assert_hostname
408+
else:
409+
ssl_opts = {'cert_reqs': ssl.CERT_NONE}
410+
411+
if configuration.cert_file:
412+
ssl_opts['certfile'] = configuration.cert_file
413+
if configuration.key_file:
414+
ssl_opts['keyfile'] = configuration.key_file
415+
416+
websocket = WebSocket(sslopt=ssl_opts, skip_utf8_validation=False)
417+
if configuration.proxy:
418+
proxy_url = urlparse(configuration.proxy)
419+
websocket.connect(url, header=header, http_proxy_host=proxy_url.hostname, http_proxy_port=proxy_url.port)
420+
else:
421+
websocket.connect(url, header=header)
422+
return websocket
423+
424+
425+
def _configuration(api_client):
426+
# old generated code's api client has config. new ones has
427+
# configuration
428+
try:
429+
return api_client.configuration
430+
except AttributeError:
431+
return api_client.config
432+
433+
434+
def websocket_call(api_client, _method, url, **kwargs):
273435
"""An internal function to be called in api-client when a websocket
274436
connection is required. args and kwargs are the parameters of
275437
apiClient.request method."""
276438

277-
url = args[1]
439+
url = get_websocket_url(url, kwargs.get("query_params"))
440+
headers = kwargs.get("headers")
278441
_request_timeout = kwargs.get("_request_timeout", 60)
279442
_preload_content = kwargs.get("_preload_content", True)
280443
capture_all = kwargs.get("capture_all", True)
281-
headers = kwargs.get("headers")
282-
283-
# Expand command parameter list to indivitual command params
284-
query_params = []
285-
for key, value in kwargs.get("query_params", {}):
286-
if key == 'command' and isinstance(value, list):
287-
for command in value:
288-
query_params.append((key, command))
289-
else:
290-
query_params.append((key, value))
291-
292-
if query_params:
293-
url += '?' + urlencode(query_params)
294444

295445
try:
296-
client = WSClient(configuration, get_websocket_url(url), headers, capture_all)
446+
client = WSClient(_configuration(api_client), url, headers, capture_all)
297447
if not _preload_content:
298448
return client
299449
client.run_forever(timeout=_request_timeout)
300450
return WSResponse('%s' % ''.join(client.read_all()))
301451
except (Exception, KeyboardInterrupt, SystemExit) as e:
302452
raise ApiException(status=0, reason=str(e))
453+
454+
455+
def portforward_call(api_client, _method, url, **kwargs):
456+
"""An internal function to be called in api-client when a websocket
457+
connection is required for port forwarding. args and kwargs are the
458+
parameters of apiClient.request method."""
459+
460+
query_params = kwargs.get("query_params")
461+
462+
ports = []
463+
for key, value in query_params:
464+
if key == 'ports':
465+
for port in value.split(','):
466+
try:
467+
port = int(port)
468+
if not (0 < port < 65536):
469+
raise ValueError
470+
ports.append(port)
471+
except ValueError:
472+
raise ApiValueError("Invalid port number `" + str(port) + "`")
473+
if not ports:
474+
raise ApiValueError("Missing required parameter `ports`")
475+
476+
url = get_websocket_url(url, query_params)
477+
headers = kwargs.get("headers")
478+
479+
try:
480+
websocket = create_websocket(_configuration(api_client), url, headers)
481+
return PortForwardClient(websocket, ports)
482+
except (Exception, KeyboardInterrupt, SystemExit) as e:
483+
raise ApiException(status=0, reason=str(e))

0 commit comments

Comments
 (0)