Reload TLS key pair on HUP signal

This commit is contained in:
Cameron Moore 2020-12-27 23:38:13 -06:00
parent 1c72898604
commit 50a690a5e4
No known key found for this signature in database
GPG Key ID: AF96E12468D7553E
3 changed files with 34 additions and 0 deletions

View File

@ -38,6 +38,20 @@ func NewKeyPairReloader(certPath, keyPath string) (*KeyPairReloader, error) {
return res, nil
}
// Reload attempts to reload the TLS key pair.
func (kpr *KeyPairReloader) Reload() error {
cert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath)
if err != nil {
return err
}
kpr.certMu.Lock()
defer kpr.certMu.Unlock()
kpr.cert = &cert
return nil
}
// GetCertificateFunc provides a function for tls.Config.GetCertificate.
func (kpr *KeyPairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {

View File

@ -112,6 +112,11 @@ func (s *Service) SetTLSEnabled() {
s.enableTLS = true
}
// TLSEnabled return true if TLS is enabled for the service.
func (s *Service) TLSEnabled() bool {
return s.enableTLS
}
// SetTLSKeyPair sets the TLS key pair for the service.
func (s *Service) SetTLSKeyPair(certPath, keyPath string) error {
if certPath == "" {
@ -132,6 +137,11 @@ func (s *Service) SetTLSKeyPair(certPath, keyPath string) error {
return nil
}
// ReloadTLSKeyPair attempts to reload the configured TLS certificate key pair.
func (s *Service) ReloadTLSKeyPair() error {
return s.kpr.Reload()
}
// SetTLSMinVersion sets the minimum support TLS version, such as "v1.3".
func (s *Service) SetTLSMinVersion(ver string) (err error) {
s.tlsMinVersion, err = security.GetTLSVersion(ver)

View File

@ -37,6 +37,16 @@ func watchForSignals(svc *service.Service) {
log.Println("caught HUP signal")
reloadAllHooks()
if svc.TLSEnabled() {
log.Println("attempting to reload TLS key pair")
err := svc.ReloadTLSKeyPair()
if err != nil {
log.Printf("failed to reload TLS key pair: %s\n", err)
} else {
log.Println("successfully reloaded TLS key pair")
}
}
case os.Interrupt, syscall.SIGTERM:
log.Printf("caught %s signal; exiting\n", sig)
err := svc.DeletePIDFile()