diff --git a/cmd/catalog/main.go b/cmd/catalog/main.go index 7117c3488c..1f9ba6f1dd 100644 --- a/cmd/catalog/main.go +++ b/cmd/catalog/main.go @@ -109,7 +109,7 @@ func main() { *catalogNamespace = catalogNamespaceEnvVarValue } - listenAndServe, err := server.GetListenAndServeFunc(logger, tlsCertPath, tlsKeyPath, clientCAPath) + listenAndServe, err := server.GetListenAndServeFunc(server.WithLogger(logger), server.WithTLS(tlsCertPath, tlsKeyPath, clientCAPath), server.WithDebug(*debug)) if err != nil { logger.Fatal("Error setting up health/metric/pprof service: %v", err) } diff --git a/cmd/olm/main.go b/cmd/olm/main.go index 72e2a4b429..4e2ed4ac55 100644 --- a/cmd/olm/main.go +++ b/cmd/olm/main.go @@ -118,7 +118,7 @@ func main() { } logger.Infof("log level %s", logger.Level) - listenAndServe, err := server.GetListenAndServeFunc(logger, tlsCertPath, tlsKeyPath, clientCAPath) + listenAndServe, err := server.GetListenAndServeFunc(server.WithLogger(logger), server.WithTLS(tlsCertPath, tlsKeyPath, clientCAPath), server.WithDebug(*debug)) if err != nil { logger.Fatal("Error setting up health/metric/pprof service: %v", err) } diff --git a/pkg/lib/profile/profile.go b/pkg/lib/profile/profile.go index 46cf07edb4..f6bafeb2ad 100644 --- a/pkg/lib/profile/profile.go +++ b/pkg/lib/profile/profile.go @@ -6,36 +6,39 @@ import ( ) type profileConfig struct { - pprof bool - cmdline bool - profile bool - symbol bool - trace bool + pprof bool + cmdline bool + profile bool + symbol bool + trace bool + enableTLS bool } // Option applies a configuration option to the given config. type Option func(p *profileConfig) func (p *profileConfig) apply(options []Option) { - if len(options) == 0 { - // If no options are given, default to all - p.pprof = true - p.cmdline = true - p.profile = true - p.symbol = true - p.trace = true - - return - } - for _, o := range options { o(p) } } +func WithTLS(enabled bool) Option { + return func(p *profileConfig) { + p.enableTLS = enabled + } +} + func defaultProfileConfig() *profileConfig { // Initialize config - return &profileConfig{} + return &profileConfig{ + pprof: true, + cmdline: true, + profile: true, + symbol: true, + trace: true, + enableTLS: true, + } } // RegisterHandlers registers profile Handlers with the given ServeMux. @@ -47,25 +50,25 @@ func RegisterHandlers(mux *http.ServeMux, options ...Option) { config.apply(options) if config.pprof { - mux.Handle("/debug/pprof/", requireVerifiedClientCertificate(http.HandlerFunc(pprof.Index))) + mux.Handle("/debug/pprof/", pprofHandlerFunc(http.HandlerFunc(pprof.Index), config.enableTLS)) } if config.cmdline { - mux.Handle("/debug/pprof/cmdline", requireVerifiedClientCertificate(http.HandlerFunc(pprof.Cmdline))) + mux.Handle("/debug/pprof/cmdline", pprofHandlerFunc(http.HandlerFunc(pprof.Cmdline), config.enableTLS)) } if config.profile { - mux.Handle("/debug/pprof/profile", requireVerifiedClientCertificate(http.HandlerFunc(pprof.Profile))) + mux.Handle("/debug/pprof/profile", pprofHandlerFunc(http.HandlerFunc(pprof.Profile), config.enableTLS)) } if config.symbol { - mux.Handle("/debug/pprof/symbol", requireVerifiedClientCertificate(http.HandlerFunc(pprof.Symbol))) + mux.Handle("/debug/pprof/symbol", pprofHandlerFunc(http.HandlerFunc(pprof.Symbol), config.enableTLS)) } if config.trace { - mux.Handle("/debug/pprof/trace", requireVerifiedClientCertificate(http.HandlerFunc(pprof.Trace))) + mux.Handle("/debug/pprof/trace", pprofHandlerFunc(http.HandlerFunc(pprof.Trace), config.enableTLS)) } } -func requireVerifiedClientCertificate(h http.Handler) http.Handler { +func pprofHandlerFunc(h http.Handler, enableTLS bool) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil || len(r.TLS.VerifiedChains) == 0 { + if enableTLS && (r.TLS == nil || len(r.TLS.VerifiedChains) == 0) { w.WriteHeader(http.StatusForbidden) return } diff --git a/pkg/lib/server/server.go b/pkg/lib/server/server.go index 79dd2bb588..3d79a192e0 100644 --- a/pkg/lib/server/server.go +++ b/pkg/lib/server/server.go @@ -13,68 +13,136 @@ import ( "github.com/sirupsen/logrus" ) -func GetListenAndServeFunc(logger *logrus.Logger, tlsCertPath, tlsKeyPath, clientCAPath *string) (func() error, error) { +// Option applies a configuration option to the given config. +type Option func(s *serverConfig) + +func GetListenAndServeFunc(options ...Option) (func() error, error) { + sc := defaultServerConfig() + sc.apply(options) + + return sc.getListenAndServeFunc() +} + +func WithTLS(tlsCertPath, tlsKeyPath, clientCAPath *string) Option { + return func(sc *serverConfig) { + sc.tlsCertPath = tlsCertPath + sc.tlsKeyPath = tlsKeyPath + sc.clientCAPath = clientCAPath + } +} + +func WithLogger(logger *logrus.Logger) Option { + return func(sc *serverConfig) { + sc.logger = logger + } +} + +func WithDebug(debug bool) Option { + return func(sc *serverConfig) { + sc.debug = debug + } +} + +type serverConfig struct { + logger *logrus.Logger + tlsCertPath *string + tlsKeyPath *string + clientCAPath *string + debug bool +} + +func (sc *serverConfig) apply(options []Option) { + for _, o := range options { + o(sc) + } +} + +func defaultServerConfig() serverConfig { + return serverConfig{ + tlsCertPath: nil, + tlsKeyPath: nil, + clientCAPath: nil, + logger: nil, + debug: false, + } +} +func (sc *serverConfig) tlsEnabled() (bool, error) { + if *sc.tlsCertPath != "" && *sc.tlsKeyPath != "" { + return true, nil + } + if *sc.tlsCertPath != "" || *sc.tlsKeyPath != "" { + return false, fmt.Errorf("both --tls-key and --tls-crt must be provided for TLS to be enabled") + } + return false, nil +} + +func (sc *serverConfig) getAddress(tlsEnabled bool) string { + if tlsEnabled { + return ":8443" + } + return ":8080" +} + +func (sc serverConfig) getListenAndServeFunc() (func() error, error) { + tlsEnabled, err := sc.tlsEnabled() + if err != nil { + return nil, fmt.Errorf("both --tls-key and --tls-crt must be provided for TLS to be enabled") + } + mux := http.NewServeMux() - profile.RegisterHandlers(mux) mux.Handle("/metrics", promhttp.Handler()) mux.HandleFunc("/healthz", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) }) + profile.RegisterHandlers(mux, profile.WithTLS(tlsEnabled || !sc.debug)) s := http.Server{ Handler: mux, - Addr: ":8080", - } - listenAndServe := s.ListenAndServe - - if *tlsCertPath != "" && *tlsKeyPath != "" { - logger.Info("TLS keys set, using https for metrics") - - certStore, err := filemonitor.NewCertStore(*tlsCertPath, *tlsKeyPath) - if err != nil { - return nil, fmt.Errorf("certificate monitoring for metrics (https) failed: %v", err) - } - - csw, err := filemonitor.NewWatch(logger, []string{filepath.Dir(*tlsCertPath), filepath.Dir(*tlsKeyPath)}, certStore.HandleFilesystemUpdate) - if err != nil { - return nil, fmt.Errorf("error creating cert file watcher: %v", err) - } - csw.Run(context.Background()) - certPoolStore, err := filemonitor.NewCertPoolStore(*clientCAPath) - if err != nil { - return nil, fmt.Errorf("certificate monitoring for client-ca failed: %v", err) - } - cpsw, err := filemonitor.NewWatch(logger, []string{filepath.Dir(*clientCAPath)}, certPoolStore.HandleCABundleUpdate) - if err != nil { - return nil, fmt.Errorf("error creating cert file watcher: %v", err) - } - cpsw.Run(context.Background()) - - s.Addr = ":8443" - s.TLSConfig = &tls.Config{ - GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { - return certStore.GetCertificate(), nil - }, - GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) { - var certs []tls.Certificate - if cert := certStore.GetCertificate(); cert != nil { - certs = append(certs, *cert) - } - return &tls.Config{ - Certificates: certs, - ClientCAs: certPoolStore.GetCertPool(), - ClientAuth: tls.VerifyClientCertIfGiven, - }, nil - }, - } - - listenAndServe = func() error { - return s.ListenAndServeTLS("", "") - } - } else if *tlsCertPath != "" || *tlsKeyPath != "" { - return nil, fmt.Errorf("both --tls-key and --tls-crt must be provided for TLS to be enabled") - } else { - logger.Info("TLS keys not set, using non-https for metrics") + Addr: sc.getAddress(tlsEnabled), + } + + if !tlsEnabled { + return s.ListenAndServe, nil + } + + sc.logger.Info("TLS keys set, using https for metrics") + certStore, err := filemonitor.NewCertStore(*sc.tlsCertPath, *sc.tlsKeyPath) + if err != nil { + return nil, fmt.Errorf("certificate monitoring for metrics (https) failed: %v", err) + } + + csw, err := filemonitor.NewWatch(sc.logger, []string{filepath.Dir(*sc.tlsCertPath), filepath.Dir(*sc.tlsKeyPath)}, certStore.HandleFilesystemUpdate) + if err != nil { + return nil, fmt.Errorf("error creating cert file watcher: %v", err) + } + csw.Run(context.Background()) + certPoolStore, err := filemonitor.NewCertPoolStore(*sc.clientCAPath) + if err != nil { + return nil, fmt.Errorf("certificate monitoring for client-ca failed: %v", err) + } + cpsw, err := filemonitor.NewWatch(sc.logger, []string{filepath.Dir(*sc.clientCAPath)}, certPoolStore.HandleCABundleUpdate) + if err != nil { + return nil, fmt.Errorf("error creating cert file watcher: %v", err) + } + cpsw.Run(context.Background()) + + s.TLSConfig = &tls.Config{ + GetCertificate: func(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { + return certStore.GetCertificate(), nil + }, + GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) { + var certs []tls.Certificate + if cert := certStore.GetCertificate(); cert != nil { + certs = append(certs, *cert) + } + return &tls.Config{ + Certificates: certs, + ClientCAs: certPoolStore.GetCertPool(), + ClientAuth: tls.VerifyClientCertIfGiven, + }, nil + }, } - return listenAndServe, nil + return func() error { + return s.ListenAndServeTLS("", "") + }, nil }