Skip to content

Commit a68ba27

Browse files
authored
Merge pull request #239 from pohly/keepalive-uds
keepalive API
2 parents fb5a4a4 + cebc604 commit a68ba27

File tree

7 files changed

+62
-38
lines changed

7 files changed

+62
-38
lines changed

cmd/mock-driver/main.go

+39-23
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,11 @@ func main() {
4040
flag.Parse()
4141

4242
endpoint := os.Getenv("CSI_ENDPOINT")
43-
if len(endpoint) == 0 {
44-
fmt.Println("CSI_ENDPOINT must be defined and must be a path")
45-
os.Exit(1)
46-
}
47-
if strings.Contains(endpoint, ":") {
48-
fmt.Println("CSI_ENDPOINT must be a unix path")
49-
os.Exit(1)
50-
}
51-
5243
controllerEndpoint := os.Getenv("CSI_CONTROLLER_ENDPOINT")
5344
if len(controllerEndpoint) == 0 {
5445
// If empty, set to the common endpoint.
5546
controllerEndpoint = endpoint
5647
}
57-
if strings.Contains(controllerEndpoint, ":") {
58-
fmt.Println("CSI_CONTROLLER_ENDPOINT must be a unix path")
59-
os.Exit(1)
60-
}
6148

6249
// Create mock driver
6350
s := service.New(config)
@@ -77,16 +64,14 @@ func main() {
7764
}
7865

7966
// Listen
80-
os.Remove(endpoint)
81-
os.Remove(controllerEndpoint)
82-
l, err := net.Listen("unix", endpoint)
67+
l, cleanup, err := listen(endpoint)
8368
if err != nil {
8469
fmt.Printf("Error: Unable to listen on %s socket: %v\n",
8570
endpoint,
8671
err)
8772
os.Exit(1)
8873
}
89-
defer os.Remove(endpoint)
74+
defer cleanup()
9075

9176
// Start server
9277
if err := d.Start(l); err != nil {
@@ -129,15 +114,14 @@ func main() {
129114
}
130115

131116
// Listen controller.
132-
os.Remove(controllerEndpoint)
133-
l, err := net.Listen("unix", controllerEndpoint)
117+
l, cleanupController, err := listen(controllerEndpoint)
134118
if err != nil {
135119
fmt.Printf("Error: Unable to listen on %s socket: %v\n",
136120
controllerEndpoint,
137121
err)
138122
os.Exit(1)
139123
}
140-
defer os.Remove(controllerEndpoint)
124+
defer cleanupController()
141125

142126
// Start controller server.
143127
if err = dc.Start(l); err != nil {
@@ -148,15 +132,14 @@ func main() {
148132
fmt.Println("mock controller driver started")
149133

150134
// Listen node.
151-
os.Remove(endpoint)
152-
l, err = net.Listen("unix", endpoint)
135+
l, cleanupNode, err := listen(endpoint)
153136
if err != nil {
154137
fmt.Printf("Error: Unable to listen on %s socket: %v\n",
155138
endpoint,
156139
err)
157140
os.Exit(1)
158141
}
159-
defer os.Remove(endpoint)
142+
defer cleanupNode()
160143

161144
// Start node server.
162145
if err = dn.Start(l); err != nil {
@@ -182,3 +165,36 @@ func main() {
182165
fmt.Println("mock drivers stopped")
183166
}
184167
}
168+
169+
func parseEndpoint(ep string) (string, string, error) {
170+
if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") {
171+
s := strings.SplitN(ep, "://", 2)
172+
if s[1] != "" {
173+
return s[0], s[1], nil
174+
}
175+
return "", "", fmt.Errorf("Invalid endpoint: %v", ep)
176+
}
177+
// Assume everything else is a file path for a Unix Domain Socket.
178+
return "unix", ep, nil
179+
}
180+
181+
func listen(endpoint string) (net.Listener, func(), error) {
182+
proto, addr, err := parseEndpoint(endpoint)
183+
if err != nil {
184+
return nil, nil, err
185+
}
186+
187+
cleanup := func() {}
188+
if proto == "unix" {
189+
addr = "/" + addr
190+
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow
191+
return nil, nil, fmt.Errorf("%s: %q", addr, err)
192+
}
193+
cleanup = func() {
194+
os.Remove(addr)
195+
}
196+
}
197+
198+
l, err := net.Listen(proto, addr)
199+
return l, cleanup, err
200+
}

driver/mock.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (m *MockCSIDriver) Nexus() (*grpc.ClientConn, error) {
7575
}
7676

7777
// Create a client connection
78-
m.conn, err = utils.Connect(m.Address())
78+
m.conn, err = utils.Connect(m.Address(), grpc.WithInsecure())
7979
if err != nil {
8080
return nil, err
8181
}

hack/e2e.sh

+5
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ TESTARGS=$@
44
UDS="/tmp/e2e-csi-sanity.sock"
55
UDS_NODE="/tmp/e2e-csi-sanity-node.sock"
66
UDS_CONTROLLER="/tmp/e2e-csi-sanity-ctrl.sock"
7+
# Protocol specified as for net.Listen...
8+
TCP_SERVER="tcp://localhost:7654"
9+
# ... and slightly differently for gRPC.
10+
TCP_CLIENT="dns:///localhost:7654"
711
CSI_ENDPOINTS="$CSI_ENDPOINTS ${UDS}"
812
CSI_MOCK_VERSION="master"
913

@@ -108,6 +112,7 @@ cd cmd/csi-sanity
108112
make clean install || exit 1
109113
cd ../..
110114

115+
runTest "${TCP_SERVER}" "${TCP_CLIENT}" &&
111116
runTest "${UDS}" "${UDS}" &&
112117
runTestWithCreds "${UDS}" "${UDS}" &&
113118
runTestAPI "${UDS}" &&

pkg/sanity/sanity.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,19 @@ type TestConfig struct {
7070
// is empty, it must provide both the controller and node service.
7171
Address string
7272

73+
// DialOptions specifies the options that are to be used
74+
// when connecting to Address. The default is grpc.WithInsecure().
75+
// A dialer will be added for Unix Domain Sockets.
76+
DialOptions []grpc.DialOption
77+
7378
// ControllerAddress optionally provides the gRPC endpoint of
7479
// the controller service.
7580
ControllerAddress string
7681

82+
// ControllerDialOptions specifies the options that are to be used
83+
// for ControllerAddress.
84+
ControllerDialOptions []grpc.DialOption
85+
7786
// SecretsFile is the filename of a .yaml file which is used
7887
// to populate CSISecrets which are then used for calls to the
7988
// CSI driver.
@@ -174,6 +183,9 @@ func NewTestConfig() TestConfig {
174183
RemovePathCmdTimeout: 10 * time.Second,
175184
TestVolumeSize: 10 * 1024 * 1024 * 1024, // 10 GiB
176185
IDGen: &DefaultIDGenerator{},
186+
187+
DialOptions: []grpc.DialOption{grpc.WithInsecure()},
188+
ControllerDialOptions: []grpc.DialOption{grpc.WithInsecure()},
177189
}
178190
}
179191

@@ -239,7 +251,7 @@ func (sc *TestContext) Setup() {
239251
sc.Conn.Close()
240252
}
241253
By("connecting to CSI driver")
242-
sc.Conn, err = utils.Connect(sc.Config.Address)
254+
sc.Conn, err = utils.Connect(sc.Config.Address, sc.Config.DialOptions...)
243255
Expect(err).NotTo(HaveOccurred())
244256
sc.connAddress = sc.Config.Address
245257
} else {
@@ -252,7 +264,7 @@ func (sc *TestContext) Setup() {
252264
sc.ControllerConn = sc.Conn
253265
sc.controllerConnAddress = sc.Config.Address
254266
} else {
255-
sc.ControllerConn, err = utils.Connect(sc.Config.ControllerAddress)
267+
sc.ControllerConn, err = utils.Connect(sc.Config.ControllerAddress, sc.Config.ControllerDialOptions...)
256268
Expect(err).NotTo(HaveOccurred())
257269
sc.controllerConnAddress = sc.Config.ControllerAddress
258270
}

test/driver_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func TestSimpleDriver(t *testing.T) {
106106
defer s.Stop()
107107

108108
// Setup a connection to the driver
109-
conn, err := utils.Connect(s.Address())
109+
conn, err := utils.Connect(s.Address(), grpc.WithInsecure())
110110
if err != nil {
111111
t.Errorf("Error: %s", err.Error())
112112
}

utils/grpcutil.go

+1-10
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,10 @@ import (
2525

2626
"google.golang.org/grpc"
2727
"google.golang.org/grpc/connectivity"
28-
"google.golang.org/grpc/keepalive"
2928
)
3029

3130
// Connect address by grpc
32-
func Connect(address string) (*grpc.ClientConn, error) {
33-
dialOptions := []grpc.DialOption{
34-
grpc.WithInsecure(),
35-
}
31+
func Connect(address string, dialOptions ...grpc.DialOption) (*grpc.ClientConn, error) {
3632
u, err := url.Parse(address)
3733
if err == nil && (!u.IsAbs() || u.Scheme == "unix") {
3834
dialOptions = append(dialOptions,
@@ -41,11 +37,6 @@ func Connect(address string) (*grpc.ClientConn, error) {
4137
return net.DialTimeout("unix", u.Path, timeout)
4238
}))
4339
}
44-
// This is necessary when connecting via TCP and does not hurt
45-
// when using Unix domain sockets. It ensures that gRPC detects a dead connection
46-
// in a timely manner.
47-
dialOptions = append(dialOptions,
48-
grpc.WithKeepaliveParams(keepalive.ClientParameters{PermitWithoutStream: true}))
4940

5041
conn, err := grpc.Dial(address, dialOptions...)
5142
if err != nil {

vendor/modules.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ google.golang.org/grpc/codes
9797
google.golang.org/grpc/reflection
9898
google.golang.org/grpc/status
9999
google.golang.org/grpc/connectivity
100-
google.golang.org/grpc/keepalive
101100
google.golang.org/grpc/balancer
102101
google.golang.org/grpc/balancer/roundrobin
103102
google.golang.org/grpc/credentials
@@ -112,6 +111,7 @@ google.golang.org/grpc/internal/envconfig
112111
google.golang.org/grpc/internal/grpcrand
113112
google.golang.org/grpc/internal/grpcsync
114113
google.golang.org/grpc/internal/transport
114+
google.golang.org/grpc/keepalive
115115
google.golang.org/grpc/metadata
116116
google.golang.org/grpc/naming
117117
google.golang.org/grpc/peer

0 commit comments

Comments
 (0)