Skip to content

Commit eb46788

Browse files
authored
Merge pull request #69 from nettoclaudio/fix/issue-68
Open a new connection for each probe call
2 parents 4dc55d3 + a0d6293 commit eb46788

File tree

2 files changed

+131
-25
lines changed

2 files changed

+131
-25
lines changed

cmd/livenessprobe/livenessprobe_test.go

Lines changed: 88 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,17 @@ limitations under the License.
1717
package main
1818

1919
import (
20+
"flag"
21+
"fmt"
22+
"io/ioutil"
2023
"net/http"
2124
"net/http/httptest"
25+
"os"
2226
"testing"
2327

2428
csi "github.com/container-storage-interface/spec/lib/go/csi"
2529
"github.com/golang/mock/gomock"
26-
connlib "github.com/kubernetes-csi/csi-lib-utils/connection"
2730
"github.com/kubernetes-csi/csi-test/driver"
28-
"google.golang.org/grpc"
2931
)
3032

3133
const (
@@ -38,8 +40,7 @@ func createMockServer(t *testing.T) (
3840
*driver.MockIdentityServer,
3941
*driver.MockControllerServer,
4042
*driver.MockNodeServer,
41-
*grpc.ClientConn,
42-
error) {
43+
func()) {
4344
// Start the mock server
4445
mockController := gomock.NewController(t)
4546
identityServer := driver.NewMockIdentityServer(mockController)
@@ -50,37 +51,77 @@ func createMockServer(t *testing.T) (
5051
Controller: controllerServer,
5152
Node: nodeServer,
5253
})
53-
drv.Start()
5454

55-
// Create a client connection to it
56-
addr := drv.Address()
57-
csiConn, err := connlib.Connect(addr)
55+
tmpDir, err := ioutil.TempDir("", "livenessprobe_test.*")
5856
if err != nil {
59-
return nil, nil, nil, nil, nil, nil, err
57+
t.Errorf("failed to create a temporary socket file name: %v", err)
6058
}
6159

62-
return mockController, drv, identityServer, controllerServer, nodeServer, csiConn, nil
60+
csiEndpoint := fmt.Sprintf("%s/csi.sock", tmpDir)
61+
err = drv.StartOnAddress("unix", csiEndpoint)
62+
if err != nil {
63+
t.Errorf("failed to start the csi driver at %s: %v", csiEndpoint, err)
64+
}
65+
66+
return mockController, drv, identityServer, controllerServer, nodeServer, func() {
67+
mockController.Finish()
68+
drv.Stop()
69+
os.RemoveAll(csiEndpoint)
70+
}
6371
}
6472

6573
func TestProbe(t *testing.T) {
66-
mockController, driver, idServer, _, _, csiConn, err := createMockServer(t)
74+
_, driver, idServer, _, _, cleanUpFunc := createMockServer(t)
75+
defer cleanUpFunc()
76+
77+
flag.Set("csi-address", driver.Address())
78+
flag.Parse()
79+
80+
var injectedErr error
81+
82+
inProbe := &csi.ProbeRequest{}
83+
outProbe := &csi.ProbeResponse{}
84+
idServer.EXPECT().Probe(gomock.Any(), inProbe).Return(outProbe, injectedErr).Times(1)
85+
86+
hp := &healthProbe{driverName: driverName}
87+
88+
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
89+
if req.URL.String() == "/healthz" {
90+
hp.checkProbe(rw, req)
91+
}
92+
}))
93+
defer server.Close()
94+
95+
httpreq, err := http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil)
6796
if err != nil {
68-
t.Fatal(err)
97+
t.Fatalf("failed to build test request for health check: %v", err)
6998
}
70-
defer mockController.Finish()
71-
defer driver.Stop()
72-
defer csiConn.Close()
99+
100+
httpresp, err := http.DefaultClient.Do(httpreq)
101+
if err != nil {
102+
t.Errorf("failed to check probe: %v", err)
103+
}
104+
105+
expectedStatusCode := http.StatusOK
106+
if httpresp.StatusCode != expectedStatusCode {
107+
t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode)
108+
}
109+
}
110+
111+
func TestProbe_issue68(t *testing.T) {
112+
_, driver, idServer, _, _, cleanUpFunc := createMockServer(t)
113+
defer cleanUpFunc()
114+
115+
flag.Set("csi-address", driver.Address())
116+
flag.Parse()
73117

74118
var injectedErr error
75119

76120
inProbe := &csi.ProbeRequest{}
77121
outProbe := &csi.ProbeResponse{}
78122
idServer.EXPECT().Probe(gomock.Any(), inProbe).Return(outProbe, injectedErr).Times(1)
79123

80-
hp := &healthProbe{
81-
conn: csiConn,
82-
driverName: driverName,
83-
}
124+
hp := &healthProbe{driverName: driverName}
84125

85126
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
86127
if req.URL.String() == "/healthz" {
@@ -89,12 +130,38 @@ func TestProbe(t *testing.T) {
89130
}))
90131
defer server.Close()
91132

92-
httpreq, err := http.NewRequest("GET", server.URL+"/healthz", nil)
133+
httpreq, err := http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil)
93134
if err != nil {
94135
t.Fatalf("failed to build test request for health check: %v", err)
95136
}
96-
_, err = http.DefaultClient.Do(httpreq)
137+
138+
httpresp, err := http.DefaultClient.Do(httpreq)
97139
if err != nil {
98140
t.Errorf("failed to check probe: %v", err)
99141
}
142+
143+
expectedStatusCode := http.StatusOK
144+
if httpresp.StatusCode != expectedStatusCode {
145+
t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode)
146+
}
147+
148+
err = os.Remove(driver.Address())
149+
if err != nil {
150+
t.Errorf("failed to remove the csi driver socket file: %v", err)
151+
}
152+
153+
httpreq, err = http.NewRequest("GET", fmt.Sprintf("%s/healthz", server.URL), nil)
154+
if err != nil {
155+
t.Fatalf("failed to build test request for health check: %v", err)
156+
}
157+
158+
httpresp, err = http.DefaultClient.Do(httpreq)
159+
if err != nil {
160+
t.Errorf("failed to check probe: %v", err)
161+
}
162+
163+
expectedStatusCode = http.StatusInternalServerError
164+
if httpresp.StatusCode != expectedStatusCode {
165+
t.Errorf("expected status code %d but got %d", expectedStatusCode, httpresp.StatusCode)
166+
}
100167
}

cmd/livenessprobe/main.go

Lines changed: 43 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"flag"
2222
"net"
2323
"net/http"
24+
"sync"
2425
"time"
2526

2627
"k8s.io/klog"
@@ -39,16 +40,24 @@ var (
3940
)
4041

4142
type healthProbe struct {
42-
conn *grpc.ClientConn
4343
driverName string
4444
}
4545

4646
func (h *healthProbe) checkProbe(w http.ResponseWriter, req *http.Request) {
4747
ctx, cancel := context.WithTimeout(req.Context(), *probeTimeout)
4848
defer cancel()
4949

50+
conn, err := acquireConnection(ctx)
51+
if err != nil {
52+
w.WriteHeader(http.StatusInternalServerError)
53+
w.Write([]byte(err.Error()))
54+
klog.Errorf("failed to establish connection to CSI driver: %v", err)
55+
return
56+
}
57+
defer conn.Close()
58+
5059
klog.V(5).Infof("Sending probe request to CSI driver %q", h.driverName)
51-
ready, err := rpc.Probe(ctx, h.conn)
60+
ready, err := rpc.Probe(ctx, conn)
5261
if err != nil {
5362
w.WriteHeader(http.StatusInternalServerError)
5463
w.Write([]byte(err.Error()))
@@ -68,12 +77,42 @@ func (h *healthProbe) checkProbe(w http.ResponseWriter, req *http.Request) {
6877
klog.V(5).Infof("Health check succeeded")
6978
}
7079

80+
// acquireConnection wraps the connlib.Connect but adding support to context
81+
// cancelation.
82+
func acquireConnection(ctx context.Context) (conn *grpc.ClientConn, err error) {
83+
var m sync.Mutex
84+
var canceled bool
85+
ready := make(chan bool)
86+
go func() {
87+
conn, err = connlib.Connect(*csiAddress)
88+
89+
m.Lock()
90+
defer m.Unlock()
91+
if err != nil && canceled {
92+
conn.Close()
93+
}
94+
95+
close(ready)
96+
}()
97+
98+
select {
99+
case <-ctx.Done():
100+
m.Lock()
101+
defer m.Unlock()
102+
canceled = true
103+
return nil, ctx.Err()
104+
105+
case <-ready:
106+
return conn, err
107+
}
108+
}
109+
71110
func main() {
72111
klog.InitFlags(nil)
73112
flag.Set("logtostderr", "true")
74113
flag.Parse()
75114

76-
csiConn, err := connlib.Connect(*csiAddress)
115+
csiConn, err := acquireConnection(context.Background())
77116
if err != nil {
78117
// connlib should retry forever so a returned error should mean
79118
// the grpc client is misconfigured rather than an error on the network
@@ -82,13 +121,13 @@ func main() {
82121

83122
klog.Infof("calling CSI driver to discover driver name")
84123
csiDriverName, err := rpc.GetDriverName(context.Background(), csiConn)
124+
csiConn.Close()
85125
if err != nil {
86126
klog.Fatalf("failed to get CSI driver name: %v", err)
87127
}
88128
klog.Infof("CSI driver name: %q", csiDriverName)
89129

90130
hp := &healthProbe{
91-
conn: csiConn,
92131
driverName: csiDriverName,
93132
}
94133

0 commit comments

Comments
 (0)