diff --git a/cmd/mock-driver/main.go b/cmd/mock-driver/main.go index 0b8a477e..12b64717 100644 --- a/cmd/mock-driver/main.go +++ b/cmd/mock-driver/main.go @@ -40,24 +40,11 @@ func main() { flag.Parse() endpoint := os.Getenv("CSI_ENDPOINT") - if len(endpoint) == 0 { - fmt.Println("CSI_ENDPOINT must be defined and must be a path") - os.Exit(1) - } - if strings.Contains(endpoint, ":") { - fmt.Println("CSI_ENDPOINT must be a unix path") - os.Exit(1) - } - controllerEndpoint := os.Getenv("CSI_CONTROLLER_ENDPOINT") if len(controllerEndpoint) == 0 { // If empty, set to the common endpoint. controllerEndpoint = endpoint } - if strings.Contains(controllerEndpoint, ":") { - fmt.Println("CSI_CONTROLLER_ENDPOINT must be a unix path") - os.Exit(1) - } // Create mock driver s := service.New(config) @@ -77,16 +64,14 @@ func main() { } // Listen - os.Remove(endpoint) - os.Remove(controllerEndpoint) - l, err := net.Listen("unix", endpoint) + l, cleanup, err := listen(endpoint) if err != nil { fmt.Printf("Error: Unable to listen on %s socket: %v\n", endpoint, err) os.Exit(1) } - defer os.Remove(endpoint) + defer cleanup() // Start server if err := d.Start(l); err != nil { @@ -129,15 +114,14 @@ func main() { } // Listen controller. - os.Remove(controllerEndpoint) - l, err := net.Listen("unix", controllerEndpoint) + l, cleanupController, err := listen(controllerEndpoint) if err != nil { fmt.Printf("Error: Unable to listen on %s socket: %v\n", controllerEndpoint, err) os.Exit(1) } - defer os.Remove(controllerEndpoint) + defer cleanupController() // Start controller server. if err = dc.Start(l); err != nil { @@ -148,15 +132,14 @@ func main() { fmt.Println("mock controller driver started") // Listen node. - os.Remove(endpoint) - l, err = net.Listen("unix", endpoint) + l, cleanupNode, err := listen(endpoint) if err != nil { fmt.Printf("Error: Unable to listen on %s socket: %v\n", endpoint, err) os.Exit(1) } - defer os.Remove(endpoint) + defer cleanupNode() // Start node server. if err = dn.Start(l); err != nil { @@ -182,3 +165,36 @@ func main() { fmt.Println("mock drivers stopped") } } + +func parseEndpoint(ep string) (string, string, error) { + if strings.HasPrefix(strings.ToLower(ep), "unix://") || strings.HasPrefix(strings.ToLower(ep), "tcp://") { + s := strings.SplitN(ep, "://", 2) + if s[1] != "" { + return s[0], s[1], nil + } + return "", "", fmt.Errorf("Invalid endpoint: %v", ep) + } + // Assume everything else is a file path for a Unix Domain Socket. + return "unix", ep, nil +} + +func listen(endpoint string) (net.Listener, func(), error) { + proto, addr, err := parseEndpoint(endpoint) + if err != nil { + return nil, nil, err + } + + cleanup := func() {} + if proto == "unix" { + addr = "/" + addr + if err := os.Remove(addr); err != nil && !os.IsNotExist(err) { //nolint: vetshadow + return nil, nil, fmt.Errorf("%s: %q", addr, err) + } + cleanup = func() { + os.Remove(addr) + } + } + + l, err := net.Listen(proto, addr) + return l, cleanup, err +} diff --git a/driver/mock.go b/driver/mock.go index 399d0170..53414d7e 100644 --- a/driver/mock.go +++ b/driver/mock.go @@ -75,7 +75,7 @@ func (m *MockCSIDriver) Nexus() (*grpc.ClientConn, error) { } // Create a client connection - m.conn, err = utils.Connect(m.Address()) + m.conn, err = utils.Connect(m.Address(), grpc.WithInsecure()) if err != nil { return nil, err } diff --git a/hack/e2e.sh b/hack/e2e.sh index e3a8c6cf..8cb59e9c 100755 --- a/hack/e2e.sh +++ b/hack/e2e.sh @@ -4,6 +4,10 @@ TESTARGS=$@ UDS="/tmp/e2e-csi-sanity.sock" UDS_NODE="/tmp/e2e-csi-sanity-node.sock" UDS_CONTROLLER="/tmp/e2e-csi-sanity-ctrl.sock" +# Protocol specified as for net.Listen... +TCP_SERVER="tcp://localhost:7654" +# ... and slightly differently for gRPC. +TCP_CLIENT="dns:///localhost:7654" CSI_ENDPOINTS="$CSI_ENDPOINTS ${UDS}" CSI_MOCK_VERSION="master" @@ -108,6 +112,7 @@ cd cmd/csi-sanity make clean install || exit 1 cd ../.. +runTest "${TCP_SERVER}" "${TCP_CLIENT}" && runTest "${UDS}" "${UDS}" && runTestWithCreds "${UDS}" "${UDS}" && runTestAPI "${UDS}" && diff --git a/pkg/sanity/sanity.go b/pkg/sanity/sanity.go index 86be7caa..432b9af2 100644 --- a/pkg/sanity/sanity.go +++ b/pkg/sanity/sanity.go @@ -71,10 +71,19 @@ type TestConfig struct { // is empty, it must provide both the controller and node service. Address string + // DialOptions specifies the options that are to be used + // when connecting to Address. The default is grpc.WithInsecure(). + // A dialer will be added for Unix Domain Sockets. + DialOptions []grpc.DialOption + // ControllerAddress optionally provides the gRPC endpoint of // the controller service. ControllerAddress string + // ControllerDialOptions specifies the options that are to be used + // for ControllerAddress. + ControllerDialOptions []grpc.DialOption + // SecretsFile is the filename of a .yaml file which is used // to populate CSISecrets which are then used for calls to the // CSI driver. @@ -175,6 +184,9 @@ func NewTestConfig() TestConfig { RemovePathCmdTimeout: 10 * time.Second, TestVolumeSize: 10 * 1024 * 1024 * 1024, // 10 GiB IDGen: &DefaultIDGenerator{}, + + DialOptions: []grpc.DialOption{grpc.WithInsecure()}, + ControllerDialOptions: []grpc.DialOption{grpc.WithInsecure()}, } } @@ -240,7 +252,7 @@ func (sc *TestContext) Setup() { sc.Conn.Close() } By("connecting to CSI driver") - sc.Conn, err = utils.Connect(sc.Config.Address) + sc.Conn, err = utils.Connect(sc.Config.Address, sc.Config.DialOptions...) Expect(err).NotTo(HaveOccurred()) sc.connAddress = sc.Config.Address } else { @@ -253,7 +265,7 @@ func (sc *TestContext) Setup() { sc.ControllerConn = sc.Conn sc.controllerConnAddress = sc.Config.Address } else { - sc.ControllerConn, err = utils.Connect(sc.Config.ControllerAddress) + sc.ControllerConn, err = utils.Connect(sc.Config.ControllerAddress, sc.Config.ControllerDialOptions...) Expect(err).NotTo(HaveOccurred()) sc.controllerConnAddress = sc.Config.ControllerAddress } diff --git a/test/driver_test.go b/test/driver_test.go index fe82aa6a..2a0d23c1 100644 --- a/test/driver_test.go +++ b/test/driver_test.go @@ -106,7 +106,7 @@ func TestSimpleDriver(t *testing.T) { defer s.Stop() // Setup a connection to the driver - conn, err := utils.Connect(s.Address()) + conn, err := utils.Connect(s.Address(), grpc.WithInsecure()) if err != nil { t.Errorf("Error: %s", err.Error()) } diff --git a/utils/grpcutil.go b/utils/grpcutil.go index ff0587f7..3ef7050b 100644 --- a/utils/grpcutil.go +++ b/utils/grpcutil.go @@ -25,14 +25,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/keepalive" ) // Connect address by grpc -func Connect(address string) (*grpc.ClientConn, error) { - dialOptions := []grpc.DialOption{ - grpc.WithInsecure(), - } +func Connect(address string, dialOptions ...grpc.DialOption) (*grpc.ClientConn, error) { u, err := url.Parse(address) if err == nil && (!u.IsAbs() || u.Scheme == "unix") { dialOptions = append(dialOptions, @@ -41,11 +37,6 @@ func Connect(address string) (*grpc.ClientConn, error) { return net.DialTimeout("unix", u.Path, timeout) })) } - // This is necessary when connecting via TCP and does not hurt - // when using Unix domain sockets. It ensures that gRPC detects a dead connection - // in a timely manner. - dialOptions = append(dialOptions, - grpc.WithKeepaliveParams(keepalive.ClientParameters{PermitWithoutStream: true})) conn, err := grpc.Dial(address, dialOptions...) if err != nil { diff --git a/vendor/modules.txt b/vendor/modules.txt index 8a931279..760e93c0 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -97,7 +97,6 @@ google.golang.org/grpc/codes google.golang.org/grpc/reflection google.golang.org/grpc/status google.golang.org/grpc/connectivity -google.golang.org/grpc/keepalive google.golang.org/grpc/balancer google.golang.org/grpc/balancer/roundrobin google.golang.org/grpc/credentials @@ -112,6 +111,7 @@ google.golang.org/grpc/internal/envconfig google.golang.org/grpc/internal/grpcrand google.golang.org/grpc/internal/grpcsync google.golang.org/grpc/internal/transport +google.golang.org/grpc/keepalive google.golang.org/grpc/metadata google.golang.org/grpc/naming google.golang.org/grpc/peer