Skip to content

Commit 149b147

Browse files
authored
Merge pull request #1777 from andyzhangx/main-test
test: add unit test for main function
2 parents c7e463e + 715d92c commit 149b147

File tree

2 files changed

+91
-5
lines changed

2 files changed

+91
-5
lines changed

Diff for: pkg/blobplugin/main.go

+9-5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ func init() {
5151
driverOptions.AddFlags()
5252
}
5353

54+
// exit is a separate function to handle program termination
55+
var exit = func(code int) {
56+
os.Exit(code)
57+
}
58+
5459
func main() {
5560
klog.InitFlags(nil)
5661
_ = flag.Set("logtostderr", "true")
@@ -61,12 +66,11 @@ func main() {
6166
klog.Fatalln(err)
6267
}
6368
fmt.Println(info) // nolint
64-
os.Exit(0)
69+
} else {
70+
exportMetrics()
71+
handle()
6572
}
66-
67-
exportMetrics()
68-
handle()
69-
os.Exit(0)
73+
exit(0)
7074
}
7175

7276
func handle() {

Diff for: pkg/blobplugin/main_test.go

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
Copyright 2024 The Kubernetes Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package main
18+
19+
import (
20+
"fmt"
21+
"net"
22+
"os"
23+
"reflect"
24+
"testing"
25+
)
26+
27+
func TestMain(t *testing.T) {
28+
// Set the version flag to true
29+
os.Args = []string{"cmd", "-version"}
30+
31+
// Capture stdout
32+
old := os.Stdout
33+
_, w, _ := os.Pipe()
34+
os.Stdout = w
35+
36+
// Replace exit function with mock function
37+
var exitCode int
38+
exit = func(code int) {
39+
exitCode = code
40+
}
41+
42+
// Call main function
43+
main()
44+
45+
// Restore stdout
46+
w.Close()
47+
os.Stdout = old
48+
exit = func(code int) {
49+
os.Exit(code)
50+
}
51+
52+
if exitCode != 0 {
53+
t.Errorf("Expected exit code 0, but got %d", exitCode)
54+
}
55+
}
56+
57+
func TestTrapClosedConnErr(t *testing.T) {
58+
tests := []struct {
59+
err error
60+
expectedErr error
61+
}{
62+
{
63+
err: net.ErrClosed,
64+
expectedErr: nil,
65+
},
66+
{
67+
err: nil,
68+
expectedErr: nil,
69+
},
70+
{
71+
err: fmt.Errorf("some error"),
72+
expectedErr: fmt.Errorf("some error"),
73+
},
74+
}
75+
76+
for _, test := range tests {
77+
err := trapClosedConnErr(test.err)
78+
if !reflect.DeepEqual(err, test.expectedErr) {
79+
t.Errorf("Expected error %v, but got %v", test.expectedErr, err)
80+
}
81+
}
82+
}

0 commit comments

Comments
 (0)