Skip to content

Commit 917d08c

Browse files
committed
Improve code
Signed-off-by: Ruben Vargas <[email protected]>
1 parent 206ccd6 commit 917d08c

File tree

2 files changed

+64
-42
lines changed

2 files changed

+64
-42
lines changed

tls/cert_watcher.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ func (w *CertWatcher) isModified() bool {
6868
hash, err := w.hashFile(w.CAPath)
6969
if err != nil {
7070
level.Error(w.logger).Log("unable to read the file", "error", err.Error())
71-
return true
71+
return false
7272
}
7373
changed := w.fileHashContent != hash
7474
w.fileHashContent = hash

tls/options.go

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,58 +22,80 @@ type UpstreamOptions struct {
2222
func NewUpstreamOptions(upstreamCertFile, upstreamKeyFile, upstreamCAFile string,
2323
interval *time.Duration, logger log.Logger, ctx context.Context, g run.Group) (*UpstreamOptions, error) {
2424

25+
if interval != nil {
26+
return newWithWatchers(upstreamCertFile, upstreamKeyFile, upstreamCAFile, interval, logger, ctx, g)
27+
}
28+
29+
return newNoWatchers(upstreamCertFile, upstreamKeyFile, upstreamCAFile)
30+
}
31+
32+
func newWithWatchers(upstreamCertFile, upstreamKeyFile, upstreamCAFile string,
33+
interval *time.Duration, logger log.Logger, ctx context.Context, g run.Group) (*UpstreamOptions, error) {
2534
options := &UpstreamOptions{}
2635

27-
if interval != nil {
28-
if upstreamCertFile != "" && upstreamKeyFile != "" {
29-
certReloader, err := rbacproxytls.NewCertReloader(
30-
upstreamKeyFile,
31-
upstreamCertFile,
32-
*interval,
33-
)
34-
if err != nil {
35-
return nil, err
36-
}
37-
options.certReloader = certReloader
38-
ctx, cancel := context.WithCancel(ctx)
39-
g.Add(func() error {
40-
return certReloader.Watch(ctx)
41-
}, func(error) {
42-
cancel()
43-
})
44-
}
45-
if upstreamCAFile != "" {
46-
caReloader := NewCertWatcher(upstreamCAFile, logger)
47-
options.caReloader = caReloader
48-
ctx, cancel := context.WithCancel(ctx)
49-
g.Add(func() error {
50-
return caReloader.Watch(ctx)
51-
}, func(error) {
52-
cancel()
53-
})
36+
if upstreamCertFile != "" && upstreamKeyFile != "" {
37+
certReloader, err := startCertReloader(ctx, g, upstreamCertFile, upstreamKeyFile, *interval)
38+
if err != nil {
39+
return nil, err
5440
}
55-
} else {
56-
if upstreamCertFile != "" && upstreamKeyFile != "" {
57-
cert, err := stdtls.LoadX509KeyPair(upstreamCertFile, upstreamKeyFile)
58-
if err != nil {
59-
return nil, err
60-
}
61-
options.cert = &cert
41+
options.certReloader = certReloader
42+
}
43+
if upstreamCAFile != "" {
44+
options.caReloader = startCAReloader(ctx, g, upstreamCertFile, logger)
45+
}
46+
return options, nil
47+
}
6248

49+
func newNoWatchers(upstreamCertFile, upstreamKeyFile, upstreamCAFile string) (*UpstreamOptions, error) {
50+
options := &UpstreamOptions{}
51+
if upstreamCertFile != "" && upstreamKeyFile != "" {
52+
cert, err := stdtls.LoadX509KeyPair(upstreamCertFile, upstreamKeyFile)
53+
if err != nil {
54+
return nil, err
6355
}
56+
options.cert = &cert
57+
}
6458

65-
if upstreamCAFile != "" {
66-
ca, err := os.ReadFile(upstreamCAFile)
67-
if err != nil {
68-
return nil, err
69-
}
70-
options.ca = ca
59+
if upstreamCAFile != "" {
60+
ca, err := os.ReadFile(upstreamCAFile)
61+
if err != nil {
62+
return nil, err
7163
}
72-
64+
options.ca = ca
7365
}
7466
return options, nil
7567
}
7668

69+
func startCertReloader(ctx context.Context, g run.Group,
70+
upstreamKeyFile, upstreamCertFile string, interval time.Duration) (*rbacproxytls.CertReloader, error) {
71+
certReloader, err := rbacproxytls.NewCertReloader(
72+
upstreamKeyFile,
73+
upstreamCertFile,
74+
interval,
75+
)
76+
if err != nil {
77+
return nil, err
78+
}
79+
ctx, cancel := context.WithCancel(ctx)
80+
g.Add(func() error {
81+
return certReloader.Watch(ctx)
82+
}, func(error) {
83+
cancel()
84+
})
85+
return certReloader, nil
86+
}
87+
88+
func startCAReloader(ctx context.Context, g run.Group, upstreamCAFile string, logger log.Logger) *CertWatcher {
89+
caReloader := NewCertWatcher(upstreamCAFile, logger)
90+
ctx, cancel := context.WithCancel(ctx)
91+
g.Add(func() error {
92+
return caReloader.Watch(ctx)
93+
}, func(error) {
94+
cancel()
95+
})
96+
return caReloader
97+
}
98+
7799
func (uo *UpstreamOptions) hasCA() bool {
78100
return len(uo.ca) != 0 || uo.caReloader != nil
79101
}

0 commit comments

Comments
 (0)