@@ -40,24 +40,11 @@ func main() {
40
40
flag .Parse ()
41
41
42
42
endpoint := os .Getenv ("CSI_ENDPOINT" )
43
- if len (endpoint ) == 0 {
44
- fmt .Println ("CSI_ENDPOINT must be defined and must be a path" )
45
- os .Exit (1 )
46
- }
47
- if strings .Contains (endpoint , ":" ) {
48
- fmt .Println ("CSI_ENDPOINT must be a unix path" )
49
- os .Exit (1 )
50
- }
51
-
52
43
controllerEndpoint := os .Getenv ("CSI_CONTROLLER_ENDPOINT" )
53
44
if len (controllerEndpoint ) == 0 {
54
45
// If empty, set to the common endpoint.
55
46
controllerEndpoint = endpoint
56
47
}
57
- if strings .Contains (controllerEndpoint , ":" ) {
58
- fmt .Println ("CSI_CONTROLLER_ENDPOINT must be a unix path" )
59
- os .Exit (1 )
60
- }
61
48
62
49
// Create mock driver
63
50
s := service .New (config )
@@ -77,16 +64,14 @@ func main() {
77
64
}
78
65
79
66
// Listen
80
- os .Remove (endpoint )
81
- os .Remove (controllerEndpoint )
82
- l , err := net .Listen ("unix" , endpoint )
67
+ l , cleanup , err := listen (endpoint )
83
68
if err != nil {
84
69
fmt .Printf ("Error: Unable to listen on %s socket: %v\n " ,
85
70
endpoint ,
86
71
err )
87
72
os .Exit (1 )
88
73
}
89
- defer os . Remove ( endpoint )
74
+ defer cleanup ( )
90
75
91
76
// Start server
92
77
if err := d .Start (l ); err != nil {
@@ -129,15 +114,14 @@ func main() {
129
114
}
130
115
131
116
// Listen controller.
132
- os .Remove (controllerEndpoint )
133
- l , err := net .Listen ("unix" , controllerEndpoint )
117
+ l , cleanupController , err := listen (controllerEndpoint )
134
118
if err != nil {
135
119
fmt .Printf ("Error: Unable to listen on %s socket: %v\n " ,
136
120
controllerEndpoint ,
137
121
err )
138
122
os .Exit (1 )
139
123
}
140
- defer os . Remove ( controllerEndpoint )
124
+ defer cleanupController ( )
141
125
142
126
// Start controller server.
143
127
if err = dc .Start (l ); err != nil {
@@ -148,15 +132,14 @@ func main() {
148
132
fmt .Println ("mock controller driver started" )
149
133
150
134
// Listen node.
151
- os .Remove (endpoint )
152
- l , err = net .Listen ("unix" , endpoint )
135
+ l , cleanupNode , err := listen (endpoint )
153
136
if err != nil {
154
137
fmt .Printf ("Error: Unable to listen on %s socket: %v\n " ,
155
138
endpoint ,
156
139
err )
157
140
os .Exit (1 )
158
141
}
159
- defer os . Remove ( endpoint )
142
+ defer cleanupNode ( )
160
143
161
144
// Start node server.
162
145
if err = dn .Start (l ); err != nil {
@@ -182,3 +165,36 @@ func main() {
182
165
fmt .Println ("mock drivers stopped" )
183
166
}
184
167
}
168
+
169
+ func parseEndpoint (ep string ) (string , string , error ) {
170
+ if strings .HasPrefix (strings .ToLower (ep ), "unix://" ) || strings .HasPrefix (strings .ToLower (ep ), "tcp://" ) {
171
+ s := strings .SplitN (ep , "://" , 2 )
172
+ if s [1 ] != "" {
173
+ return s [0 ], s [1 ], nil
174
+ }
175
+ return "" , "" , fmt .Errorf ("Invalid endpoint: %v" , ep )
176
+ }
177
+ // Assume everything else is a file path for a Unix Domain Socket.
178
+ return "unix" , ep , nil
179
+ }
180
+
181
+ func listen (endpoint string ) (net.Listener , func (), error ) {
182
+ proto , addr , err := parseEndpoint (endpoint )
183
+ if err != nil {
184
+ return nil , nil , err
185
+ }
186
+
187
+ cleanup := func () {}
188
+ if proto == "unix" {
189
+ addr = "/" + addr
190
+ if err := os .Remove (addr ); err != nil && ! os .IsNotExist (err ) { //nolint: vetshadow
191
+ return nil , nil , fmt .Errorf ("%s: %q" , addr , err )
192
+ }
193
+ cleanup = func () {
194
+ os .Remove (addr )
195
+ }
196
+ }
197
+
198
+ l , err := net .Listen (proto , addr )
199
+ return l , cleanup , err
200
+ }
0 commit comments