diff --git a/internal/service/security/tls.go b/internal/service/security/tls.go index 4c77474..e856829 100644 --- a/internal/service/security/tls.go +++ b/internal/service/security/tls.go @@ -38,6 +38,20 @@ func NewKeyPairReloader(certPath, keyPath string) (*KeyPairReloader, error) { return res, nil } +// Reload attempts to reload the TLS key pair. +func (kpr *KeyPairReloader) Reload() error { + cert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath) + if err != nil { + return err + } + + kpr.certMu.Lock() + defer kpr.certMu.Unlock() + + kpr.cert = &cert + return nil +} + // GetCertificateFunc provides a function for tls.Config.GetCertificate. func (kpr *KeyPairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { diff --git a/internal/service/service.go b/internal/service/service.go index 2c07c17..b85c280 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -112,6 +112,11 @@ func (s *Service) SetTLSEnabled() { s.enableTLS = true } +// TLSEnabled return true if TLS is enabled for the service. +func (s *Service) TLSEnabled() bool { + return s.enableTLS +} + // SetTLSKeyPair sets the TLS key pair for the service. func (s *Service) SetTLSKeyPair(certPath, keyPath string) error { if certPath == "" { @@ -132,6 +137,11 @@ func (s *Service) SetTLSKeyPair(certPath, keyPath string) error { return nil } +// ReloadTLSKeyPair attempts to reload the configured TLS certificate key pair. +func (s *Service) ReloadTLSKeyPair() error { + return s.kpr.Reload() +} + // SetTLSMinVersion sets the minimum support TLS version, such as "v1.3". func (s *Service) SetTLSMinVersion(ver string) (err error) { s.tlsMinVersion, err = security.GetTLSVersion(ver) diff --git a/signals.go b/signals.go index 5ad3fa4..2dd43b4 100644 --- a/signals.go +++ b/signals.go @@ -37,6 +37,16 @@ func watchForSignals(svc *service.Service) { log.Println("caught HUP signal") reloadAllHooks() + if svc.TLSEnabled() { + log.Println("attempting to reload TLS key pair") + err := svc.ReloadTLSKeyPair() + if err != nil { + log.Printf("failed to reload TLS key pair: %s\n", err) + } else { + log.Println("successfully reloaded TLS key pair") + } + } + case os.Interrupt, syscall.SIGTERM: log.Printf("caught %s signal; exiting\n", sig) err := svc.DeletePIDFile()