Skip to content

Commit 44abbae

Browse files
committed
Add watcher to TLS certificates
Signed-off-by: Ruben Vargas <[email protected]>
1 parent ca8d2de commit 44abbae

File tree

8 files changed

+337
-44
lines changed

8 files changed

+337
-44
lines changed

api/traces/v1/api.go

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package v1
22

33
import (
44
"context"
5-
stdtls "crypto/tls"
65
"time"
76

87
"github.com/go-kit/log"
@@ -18,9 +17,8 @@ import (
1817
const TraceRoute = "/opentelemetry.proto.collector.trace.v1.TraceService/Export"
1918

2019
type connOptions struct {
21-
logger log.Logger
22-
tracesUpstreamCert *stdtls.Certificate
23-
tracesUpstreamCA []byte
20+
logger log.Logger
21+
watcher *tls.CertWatcher
2422
}
2523

2624
// ClientOption modifies the connection's configuration.
@@ -33,15 +31,14 @@ func WithLogger(logger log.Logger) ClientOption {
3331
}
3432
}
3533

36-
func WithUpstreamTLS(tracesUpstreamCA []byte, tracesUpstreamCert *stdtls.Certificate) ClientOption {
34+
func WithUpstreamTLSWatcher(watcher *tls.CertWatcher) ClientOption {
3735
return func(h *connOptions) {
38-
h.tracesUpstreamCA = tracesUpstreamCA
39-
h.tracesUpstreamCert = tracesUpstreamCert
36+
h.watcher = watcher
4037
}
4138
}
4239

43-
func newCredentials(upstreamCA []byte, upstreamCert *stdtls.Certificate) credentials.TransportCredentials {
44-
tlsConfig := tls.NewClientConfig(upstreamCA, upstreamCert)
40+
func newCredentials(watcher *tls.CertWatcher) credentials.TransportCredentials {
41+
tlsConfig := tls.NewClientConfigFromWatcher(watcher)
4542
if tlsConfig == nil {
4643
return insecure.NewCredentials()
4744
}
@@ -70,5 +67,5 @@ func NewOTelConnection(write string, opts ...ClientOption) (*grpc.ClientConn, er
7067
// because the codec we need to register is also deprecated. A better fix, is the newer
7168
// version of mwitkow/grpc-proxy, but that version doesn't (currently) work with OTel protocol.
7269
grpc.WithCodec(grpcproxy.Codec()), // nolint: staticcheck
73-
grpc.WithTransportCredentials(newCredentials(c.tracesUpstreamCA, c.tracesUpstreamCert)))
70+
grpc.WithTransportCredentials(newCredentials(c.watcher)))
7471
}

api/traces/v1/http.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"compress/flate"
66
"compress/gzip"
7-
stdtls "crypto/tls"
87
"fmt"
98
"io"
109
"net"
@@ -109,7 +108,7 @@ func (n nopInstrumentHandler) NewHandler(labels prometheus.Labels, handler http.
109108
// The web UI handler is able to rewrite
110109
// HTML to change the <base> attribute so that it works with the Observatorium-style
111110
// "/api/v1/traces/{tenant}/" URLs.
112-
func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPHttp *url.URL, upstreamCA []byte, upstreamCert *stdtls.Certificate, opts ...HandlerOption) http.Handler {
111+
func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPHttp *url.URL, upstreamTLSWatcher *tls.CertWatcher, opts ...HandlerOption) http.Handler {
113112

114113
if read == nil && readTemplate == "" && tempo == nil {
115114
panic("missing Jaeger read url")
@@ -152,7 +151,7 @@ func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPH
152151
DialContext: (&net.Dialer{
153152
Timeout: dialTimeout,
154153
}).DialContext,
155-
TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert),
154+
TLSClientConfig: tls.NewClientConfigFromWatcher(upstreamTLSWatcher),
156155
}
157156

158157
proxyRead = &httputil.ReverseProxy{
@@ -203,7 +202,7 @@ func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPH
203202
DialContext: (&net.Dialer{
204203
Timeout: dialTimeout,
205204
}).DialContext,
206-
TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert),
205+
TLSClientConfig: tls.NewClientConfigFromWatcher(upstreamTLSWatcher),
207206
}
208207

209208
proxyOTLP := &httputil.ReverseProxy{
@@ -229,7 +228,7 @@ func NewV2Handler(read *url.URL, readTemplate string, tempo *url.URL, writeOTLPH
229228
DialContext: (&net.Dialer{
230229
Timeout: dialTimeout,
231230
}).DialContext,
232-
TLSClientConfig: tls.NewClientConfig(upstreamCA, upstreamCert),
231+
TLSClientConfig: tls.NewClientConfigFromWatcher(upstreamTLSWatcher),
233232
}
234233

235234
middlewares := proxy.Middlewares(

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/deepmap/oapi-codegen v1.16.2
1111
github.com/efficientgo/core v1.0.0-rc.2
1212
github.com/efficientgo/e2e v0.14.1-0.20230413162904-ebc233c5a32f
13+
github.com/fsnotify/fsnotify v1.7.0
1314
github.com/ghodss/yaml v1.0.0
1415
github.com/go-chi/chi v4.1.2+incompatible
1516
github.com/go-chi/chi/v5 v5.0.12
@@ -84,7 +85,6 @@ require (
8485
github.com/fatih/structs v1.1.0 // indirect
8586
github.com/felixge/httpsnoop v1.0.4 // indirect
8687
github.com/flosch/pongo2/v4 v4.0.2 // indirect
87-
github.com/fsnotify/fsnotify v1.7.0 // indirect
8888
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
8989
github.com/gin-contrib/sse v0.1.0 // indirect
9090
github.com/gin-gonic/gin v1.9.1 // indirect

main.go

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -505,8 +505,6 @@ func main() {
505505
metricsUpstreamClientCert *stdtls.Certificate
506506
logsUpstreamCACert []byte
507507
logsUpstreamClientCert *stdtls.Certificate
508-
tracesUpstreamCACert []byte
509-
tracesUpstreamClientCert *stdtls.Certificate
510508
)
511509

512510
if cfg.metrics.upstreamCAFile != "" {
@@ -540,20 +538,14 @@ func main() {
540538
logsUpstreamClientCert = &clientCert
541539
}
542540

543-
if cfg.traces.upstreamCAFile != "" {
544-
tracesUpstreamCACert, err = os.ReadFile(cfg.traces.upstreamCAFile)
545-
if err != nil {
546-
stdlog.Fatalf("failed to read upstream traces TLS CA: %v", err)
547-
}
548-
549-
}
541+
tracesUpstreamTLSWatcher, err := tls.NewCertWatcher(tls.CertPaths{
542+
CAPath: cfg.traces.upstreamCAFile,
543+
CertPath: cfg.traces.upstreamCertFile,
544+
KeyPath: cfg.traces.upstreamKeyFile,
545+
}, logger)
550546

551-
if cfg.traces.upstreamCertFile != "" && cfg.traces.upstreamKeyFile != "" {
552-
clientCert, err := stdtls.LoadX509KeyPair(cfg.traces.upstreamCertFile, cfg.traces.upstreamKeyFile)
553-
if err != nil {
554-
stdlog.Fatalf("failed to read upstream traces client TLS cert/key pair: %v", err)
555-
}
556-
tracesUpstreamClientCert = &clientCert
547+
if err != nil {
548+
stdlog.Fatalf("failed to read upstream traces TLS watcher: %v", err)
557549
}
558550

559551
r := chi.NewRouter()
@@ -804,8 +796,7 @@ func main() {
804796
cfg.traces.readTemplateEndpoint,
805797
cfg.traces.tempoEndpoint,
806798
cfg.traces.writeOTLPHTTPEndpoint,
807-
tracesUpstreamCACert,
808-
tracesUpstreamClientCert,
799+
tracesUpstreamTLSWatcher,
809800
tracesv1.Logger(logger),
810801
tracesv1.WithRegistry(reg),
811802
tracesv1.WithHandlerInstrumenter(instrumenter),
@@ -895,8 +886,7 @@ func main() {
895886
pm.GRPCMiddlewares,
896887
authorizers,
897888
logger,
898-
tracesUpstreamCACert,
899-
tracesUpstreamClientCert,
889+
tracesUpstreamTLSWatcher,
900890
)
901891
if err != nil {
902892
stdlog.Fatalf("failed to initialize gRPC server: %v", err)
@@ -1458,12 +1448,12 @@ var gRPCRBAC = authorization.GRPCRBac{
14581448
}
14591449

14601450
func newGRPCServer(cfg *config, tenantHeader string, tenantIDs map[string]string, pmis authentication.GRPCMiddlewareFunc,
1461-
authorizers map[string]rbac.Authorizer, logger log.Logger, tracesUpstreamCA []byte, tracesUpstreamCert *stdtls.Certificate,
1451+
authorizers map[string]rbac.Authorizer, logger log.Logger, watcher *tls.CertWatcher,
14621452
) (*grpc.Server, error) {
14631453
connOtel, err := tracesv1.NewOTelConnection(
14641454
cfg.traces.writeOTLPGRPCEndpoint,
14651455
tracesv1.WithLogger(logger),
1466-
tracesv1.WithUpstreamTLS(tracesUpstreamCA, tracesUpstreamCert),
1456+
tracesv1.WithUpstreamTLSWatcher(watcher),
14671457
)
14681458
if err != nil {
14691459
return nil, err

tls/cert_watcher.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
package tls
2+
3+
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"errors"
7+
"fmt"
8+
"path/filepath"
9+
"sync"
10+
11+
"github.com/go-kit/log"
12+
"github.com/go-kit/log/level"
13+
)
14+
15+
type CertWatcher struct {
16+
mu sync.RWMutex
17+
watchers []*fsWatcher
18+
cert *tls.Certificate
19+
paths CertPaths
20+
certPool *x509.CertPool
21+
logger log.Logger
22+
}
23+
24+
// CertPaths names the files a CertWatcher works with. Fields may be left
// empty: the key pair is only used when both CertPath and KeyPath are set, and
// the CA is only used when CAPath is set.
type CertPaths struct {
	// CertPath is the certificate file of the key pair.
	CertPath string
	// KeyPath is the private key file of the key pair.
	KeyPath string
	// CAPath is the CA certificate file.
	CAPath string
}
29+
30+
func NewCertWatcher(paths CertPaths, logger log.Logger) (*CertWatcher, error) {
31+
var cert *tls.Certificate
32+
if paths.CertPath != "" && paths.KeyPath != "" {
33+
// load certs at startup to catch missing certs error early
34+
c, err := tls.LoadX509KeyPair(filepath.Clean(paths.CertPath), filepath.Clean(paths.KeyPath))
35+
if err != nil {
36+
return nil, fmt.Errorf("failed to load server TLS cert and key: %w", err)
37+
}
38+
cert = &c
39+
}
40+
41+
// TLS disabled, no key and no CA specified
42+
if cert == nil && paths.CAPath == "" {
43+
return nil, nil
44+
}
45+
46+
w := &CertWatcher{
47+
paths: paths,
48+
cert: cert,
49+
logger: logger,
50+
}
51+
52+
if cert != nil {
53+
if err := w.watchCertPair(); err != nil {
54+
return nil, err
55+
}
56+
}
57+
58+
if paths.CAPath != "" {
59+
if err := w.watchCert(w.paths.CAPath); err != nil {
60+
return nil, err
61+
}
62+
}
63+
64+
return w, nil
65+
}
66+
67+
func (w *CertWatcher) Close() error {
68+
var errs []error
69+
for _, w := range w.watchers {
70+
errs = append(errs, w.Close())
71+
}
72+
return errors.Join(errs...)
73+
}
74+
75+
func (w *CertWatcher) certificate() *tls.Certificate {
76+
w.mu.RLock()
77+
defer w.mu.RUnlock()
78+
return w.cert
79+
}
80+
81+
func (w *CertWatcher) watchCertPair() error {
82+
watcher, err := newWatcher(
83+
[]string{w.paths.CertPath, w.paths.KeyPath},
84+
w.onCertPairChange, w.logger,
85+
)
86+
if err == nil {
87+
w.watchers = append(w.watchers, watcher)
88+
return nil
89+
}
90+
w.Close()
91+
return fmt.Errorf("failed to watch key pair %s and %s: %w", w.paths.KeyPath, w.paths.CertPath, err)
92+
}
93+
94+
func (w *CertWatcher) watchCert(certPath string) error {
95+
onCertChange := func() { w.onCertChange(certPath) }
96+
97+
watcher, err := newWatcher([]string{certPath}, onCertChange, w.logger)
98+
if err == nil {
99+
w.watchers = append(w.watchers, watcher)
100+
return nil
101+
}
102+
w.Close()
103+
return fmt.Errorf("failed to watch cert %s: %w", certPath, err)
104+
}
105+
106+
func (w *CertWatcher) onCertPairChange() {
107+
cert, err := tls.LoadX509KeyPair(filepath.Clean(w.paths.CertPath), filepath.Clean(w.paths.KeyPath))
108+
if err == nil {
109+
w.mu.Lock()
110+
w.cert = &cert
111+
w.mu.Unlock()
112+
level.Info(w.logger).Log("msg", "cert/key pair reloaded", "cert", w.paths.CertPath, "key", w.paths.KeyPath)
113+
} else {
114+
level.Error(w.logger).Log("msg", "error reloading cert/key pair",
115+
"cert", w.paths.CertPath, "key", w.paths.KeyPath, "error", err.Error())
116+
}
117+
}
118+
119+
func (w *CertWatcher) Pool() *x509.CertPool {
120+
return w.certPool
121+
}
122+
123+
func (w *CertWatcher) onCertChange(certPath string) {
124+
w.mu.Lock() // prevent concurrent updates to the same certPool
125+
if err := addCertToPool(certPath, w.certPool); err == nil {
126+
level.Info(w.logger).Log("msg", "certificate reloaded", certPath)
127+
} else {
128+
level.Error(w.logger).Log("msg", "error reloading certificate", "cert", certPath, "error", err.Error())
129+
}
130+
w.mu.Unlock()
131+
}

tls/config.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,31 @@ func NewClientConfig(upstreamCA []byte, upstreamCert *tls.Certificate) *tls.Conf
2727
return cfg
2828
}
2929

30+
func NewClientConfigFromWatcher(watcher *CertWatcher) *tls.Config {
31+
if watcher == nil {
32+
return nil
33+
}
34+
35+
cfg := &tls.Config{
36+
RootCAs: watcher.Pool(),
37+
}
38+
39+
cfg.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
40+
return watcher.certificate(), nil
41+
}
42+
43+
return cfg
44+
}
45+
3046
// NewServerConfig provides new server TLS configuration.
3147
func NewServerConfig(logger log.Logger, certFile, keyFile, minVersion, maxVersion, clientAuthType string, cipherSuites []string) (*tls.Config, error) {
3248
if certFile == "" && keyFile == "" {
3349
level.Info(logger).Log("msg", "TLS disabled; key and cert must be set to enable")
34-
3550
return nil, nil
3651
}
3752

3853
level.Info(logger).Log("msg", "enabling server side TLS")
3954

40-
tlsCert, err := tls.LoadX509KeyPair(certFile, keyFile)
41-
if err != nil {
42-
return nil, fmt.Errorf("server credentials: %w", err)
43-
}
44-
4555
tlsMinVersion, err := parseTLSVersion(minVersion)
4656
if err != nil {
4757
return nil, fmt.Errorf("cannot parse TLS Version: %w", err)
@@ -66,15 +76,25 @@ func NewServerConfig(logger log.Logger, certFile, keyFile, minVersion, maxVersio
6676
return nil, fmt.Errorf("can not parse TLS Client authentication policy: %w", err)
6777
}
6878

79+
watcher, err := NewCertWatcher(CertPaths{
80+
CertPath: certFile,
81+
KeyPath: keyFile,
82+
}, logger)
83+
if err != nil {
84+
return nil, fmt.Errorf("error starting certificate watcher: %w", err)
85+
}
86+
6987
tlsCfg := &tls.Config{
70-
Certificates: []tls.Certificate{tlsCert},
7188
// A list of supported cipher suites for TLS versions up to TLS 1.2.
7289
// If CipherSuites is nil, a default list of secure cipher suites is used.
7390
// Note that TLS 1.3 ciphersuites are not configurable.
7491
CipherSuites: cipherSuiteIDs,
7592
ClientAuth: tlsClientAuthType,
7693
MinVersion: tlsMinVersion,
7794
MaxVersion: tlsMaxVersion,
95+
GetCertificate: func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
96+
return watcher.certificate(), nil
97+
},
7898
}
7999

80100
return tlsCfg, nil

0 commit comments

Comments
 (0)