Skip to content

Commit 83d6d37

Browse files
Merge pull request #118182 from seans3/wsstream-refactor
Refactor wsstream library from apiserver to apimachinery Kubernetes-commit: 056f3a56b821a063210c2c4a67cc7a4d0a361afe
2 parents 07fcd5d + b714ff2 commit 83d6d37

File tree

5 files changed

+1116
-0
lines changed

5 files changed

+1116
-0
lines changed

pkg/util/httpstream/wsstream/conn.go

+350
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
/*
2+
Copyright 2015 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package wsstream
18+
19+
import (
20+
"encoding/base64"
21+
"fmt"
22+
"io"
23+
"net/http"
24+
"regexp"
25+
"strings"
26+
"time"
27+
28+
"golang.org/x/net/websocket"
29+
"k8s.io/klog/v2"
30+
31+
"k8s.io/apimachinery/pkg/util/runtime"
32+
)
33+
34+
// The Websocket subprotocol "channel.k8s.io" prepends each binary message with a byte indicating
35+
// the channel number (zero indexed) the message was sent on. Messages in both directions should
36+
// prefix their messages with this channel byte. When used for remote execution, the channel numbers
37+
// are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and STDERR
38+
// (0, 1, and 2). No other conversion is performed on the raw subprotocol - writes are sent as they
39+
// are received by the server.
40+
//
41+
// Example client session:
42+
//
43+
// CONNECT http://server.com with subprotocol "channel.k8s.io"
44+
// WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN)
45+
// READ []byte{1, 10} # receive "\n" on channel 1 (STDOUT)
46+
// CLOSE
47+
const ChannelWebSocketProtocol = "channel.k8s.io"
48+
49+
// The Websocket subprotocol "base64.channel.k8s.io" base64 encodes each message with a character
50+
// indicating the channel number (zero indexed) the message was sent on. Messages in both directions
51+
// should prefix their messages with this channel char. When used for remote execution, the channel
52+
// numbers are by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT,
53+
// and STDERR ('0', '1', and '2'). The data received on the server is base64 decoded (and must be
54+
// be valid) and data written by the server to the client is base64 encoded.
55+
//
56+
// Example client session:
57+
//
58+
// CONNECT http://server.com with subprotocol "base64.channel.k8s.io"
59+
// WRITE []byte{48, 90, 109, 57, 118, 67, 103, 111, 61} # send "foo\n" (base64: "Zm9vCgo=") on channel '0' (STDIN)
60+
// READ []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
61+
// CLOSE
62+
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
63+
64+
type codecType int
65+
66+
const (
67+
rawCodec codecType = iota
68+
base64Codec
69+
)
70+
71+
type ChannelType int
72+
73+
const (
74+
IgnoreChannel ChannelType = iota
75+
ReadChannel
76+
WriteChannel
77+
ReadWriteChannel
78+
)
79+
80+
var (
81+
// connectionUpgradeRegex matches any Connection header value that includes upgrade
82+
connectionUpgradeRegex = regexp.MustCompile("(^|.*,\\s*)upgrade($|\\s*,)")
83+
)
84+
85+
// IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
86+
// for WebSockets.
87+
func IsWebSocketRequest(req *http.Request) bool {
88+
if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
89+
return false
90+
}
91+
return connectionUpgradeRegex.MatchString(strings.ToLower(req.Header.Get("Connection")))
92+
}
93+
94+
// IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
95+
// read and write deadlines are pushed every time a new message is received.
96+
func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
97+
defer runtime.HandleCrash()
98+
var data []byte
99+
for {
100+
resetTimeout(ws, timeout)
101+
if err := websocket.Message.Receive(ws, &data); err != nil {
102+
return
103+
}
104+
}
105+
}
106+
107+
// handshake ensures the provided user protocol matches one of the allowed protocols. It returns
108+
// no error if no protocol is specified.
109+
func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
110+
protocols := config.Protocol
111+
if len(protocols) == 0 {
112+
protocols = []string{""}
113+
}
114+
115+
for _, protocol := range protocols {
116+
for _, allow := range allowed {
117+
if allow == protocol {
118+
config.Protocol = []string{protocol}
119+
return nil
120+
}
121+
}
122+
}
123+
124+
return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
125+
}
126+
127+
// ChannelProtocolConfig describes a websocket subprotocol with channels.
128+
type ChannelProtocolConfig struct {
129+
Binary bool
130+
Channels []ChannelType
131+
}
132+
133+
// NewDefaultChannelProtocols returns a channel protocol map with the
134+
// subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
135+
// channels.
136+
func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
137+
return map[string]ChannelProtocolConfig{
138+
"": {Binary: true, Channels: channels},
139+
ChannelWebSocketProtocol: {Binary: true, Channels: channels},
140+
Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
141+
}
142+
}
143+
144+
// Conn supports sending multiple binary channels over a websocket connection.
145+
type Conn struct {
146+
protocols map[string]ChannelProtocolConfig
147+
selectedProtocol string
148+
channels []*websocketChannel
149+
codec codecType
150+
ready chan struct{}
151+
ws *websocket.Conn
152+
timeout time.Duration
153+
}
154+
155+
// NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
156+
// web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
157+
// future use. The channel types for each channel are passed as an array, supporting the different
158+
// duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
159+
//
160+
// The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
161+
// name is used if websocket.Config.Protocol is empty.
162+
func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
163+
return &Conn{
164+
ready: make(chan struct{}),
165+
protocols: protocols,
166+
}
167+
}
168+
169+
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
170+
// there is no timeout on the connection.
171+
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
172+
conn.timeout = duration
173+
}
174+
175+
// Open the connection and create channels for reading and writing. It returns
176+
// the selected subprotocol, a slice of channels and an error.
177+
func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
178+
go func() {
179+
defer runtime.HandleCrash()
180+
defer conn.Close()
181+
websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
182+
}()
183+
<-conn.ready
184+
rwc := make([]io.ReadWriteCloser, len(conn.channels))
185+
for i := range conn.channels {
186+
rwc[i] = conn.channels[i]
187+
}
188+
return conn.selectedProtocol, rwc, nil
189+
}
190+
191+
func (conn *Conn) initialize(ws *websocket.Conn) {
192+
negotiated := ws.Config().Protocol
193+
conn.selectedProtocol = negotiated[0]
194+
p := conn.protocols[conn.selectedProtocol]
195+
if p.Binary {
196+
conn.codec = rawCodec
197+
} else {
198+
conn.codec = base64Codec
199+
}
200+
conn.ws = ws
201+
conn.channels = make([]*websocketChannel, len(p.Channels))
202+
for i, t := range p.Channels {
203+
switch t {
204+
case ReadChannel:
205+
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
206+
case WriteChannel:
207+
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
208+
case ReadWriteChannel:
209+
conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
210+
case IgnoreChannel:
211+
conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
212+
}
213+
}
214+
215+
close(conn.ready)
216+
}
217+
218+
func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
219+
supportedProtocols := make([]string, 0, len(conn.protocols))
220+
for p := range conn.protocols {
221+
supportedProtocols = append(supportedProtocols, p)
222+
}
223+
return handshake(config, req, supportedProtocols)
224+
}
225+
226+
func (conn *Conn) resetTimeout() {
227+
if conn.timeout > 0 {
228+
conn.ws.SetDeadline(time.Now().Add(conn.timeout))
229+
}
230+
}
231+
232+
// Close is only valid after Open has been called
233+
func (conn *Conn) Close() error {
234+
<-conn.ready
235+
for _, s := range conn.channels {
236+
s.Close()
237+
}
238+
conn.ws.Close()
239+
return nil
240+
}
241+
242+
// handle implements a websocket handler.
243+
func (conn *Conn) handle(ws *websocket.Conn) {
244+
defer conn.Close()
245+
conn.initialize(ws)
246+
247+
for {
248+
conn.resetTimeout()
249+
var data []byte
250+
if err := websocket.Message.Receive(ws, &data); err != nil {
251+
if err != io.EOF {
252+
klog.Errorf("Error on socket receive: %v", err)
253+
}
254+
break
255+
}
256+
if len(data) == 0 {
257+
continue
258+
}
259+
channel := data[0]
260+
if conn.codec == base64Codec {
261+
channel = channel - '0'
262+
}
263+
data = data[1:]
264+
if int(channel) >= len(conn.channels) {
265+
klog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
266+
continue
267+
}
268+
if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
269+
klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
270+
continue
271+
}
272+
}
273+
}
274+
275+
// write multiplexes the specified channel onto the websocket
276+
func (conn *Conn) write(num byte, data []byte) (int, error) {
277+
conn.resetTimeout()
278+
switch conn.codec {
279+
case rawCodec:
280+
frame := make([]byte, len(data)+1)
281+
frame[0] = num
282+
copy(frame[1:], data)
283+
if err := websocket.Message.Send(conn.ws, frame); err != nil {
284+
return 0, err
285+
}
286+
case base64Codec:
287+
frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
288+
if err := websocket.Message.Send(conn.ws, frame); err != nil {
289+
return 0, err
290+
}
291+
}
292+
return len(data), nil
293+
}
294+
295+
// websocketChannel represents a channel in a connection
296+
type websocketChannel struct {
297+
conn *Conn
298+
num byte
299+
r io.Reader
300+
w io.WriteCloser
301+
302+
read, write bool
303+
}
304+
305+
// newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe
306+
// prior to the connection being opened. It may be no, half, or full duplex depending on
307+
// read and write.
308+
func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
309+
r, w := io.Pipe()
310+
return &websocketChannel{conn, num, r, w, read, write}
311+
}
312+
313+
func (p *websocketChannel) Write(data []byte) (int, error) {
314+
if !p.write {
315+
return len(data), nil
316+
}
317+
return p.conn.write(p.num, data)
318+
}
319+
320+
// DataFromSocket is invoked by the connection receiver to move data from the connection
321+
// into a specific channel.
322+
func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
323+
if !p.read {
324+
return len(data), nil
325+
}
326+
327+
switch p.conn.codec {
328+
case rawCodec:
329+
return p.w.Write(data)
330+
case base64Codec:
331+
dst := make([]byte, len(data))
332+
n, err := base64.StdEncoding.Decode(dst, data)
333+
if err != nil {
334+
return 0, err
335+
}
336+
return p.w.Write(dst[:n])
337+
}
338+
return 0, nil
339+
}
340+
341+
func (p *websocketChannel) Read(data []byte) (int, error) {
342+
if !p.read {
343+
return 0, io.EOF
344+
}
345+
return p.r.Read(data)
346+
}
347+
348+
func (p *websocketChannel) Close() error {
349+
return p.w.Close()
350+
}

0 commit comments

Comments
 (0)