diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 75f347d..6f6c47f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,7 +4,7 @@ jobs: build: strategy: matrix: - go-version: [1.14.x, 1.15.x] + go-version: [1.15.x] os: [ubuntu-latest, macos-latest, windows-latest] runs-on: ${{ matrix.os }} diff --git a/internal/service/security/tls.go b/internal/service/security/tls.go new file mode 100644 index 0000000..e856829 --- /dev/null +++ b/internal/service/security/tls.go @@ -0,0 +1,182 @@ +// Package security provides HTTP security management help to the webhook +// service. +package security + +import ( + "crypto/tls" + "fmt" + "io" + "log" + "strings" + "sync" +) + +// KeyPairReloader contains the active TLS certificate. It can be used with +// the tls.Config.GetCertificate property to support live updating of the +// certificate. +type KeyPairReloader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string +} + +// NewKeyPairReloader creates a new KeyPairReloader given the certificate and +// key path. +func NewKeyPairReloader(certPath, keyPath string) (*KeyPairReloader, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + + res := &KeyPairReloader{ + cert: &cert, + certPath: certPath, + keyPath: keyPath, + } + + return res, nil +} + +// Reload attempts to reload the TLS key pair. +func (kpr *KeyPairReloader) Reload() error { + cert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath) + if err != nil { + return err + } + + kpr.certMu.Lock() + defer kpr.certMu.Unlock() + + kpr.cert = &cert + return nil +} + +// GetCertificateFunc provides a function for tls.Config.GetCertificate. +func (kpr *KeyPairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + kpr.certMu.RLock() + defer kpr.certMu.RUnlock() + return kpr.cert, nil + } +} + +// WriteTLSSupportedCipherStrings writes a list of ciphers to w. The list is +// all supported TLS ciphers based upon min. +func WriteTLSSupportedCipherStrings(w io.Writer, min string) error { + m, err := GetTLSVersion(min) + if err != nil { + return err + } + + for _, c := range tls.CipherSuites() { + var found bool + + for _, v := range c.SupportedVersions { + if v >= m { + found = true + } + } + + if !found { + continue + } + + _, err := w.Write([]byte(c.Name + "\n")) + if err != nil { + return err + } + } + + return nil +} + +// GetTLSVersion converts a TLS version string, v, (e.g. "v1.3") into a TLS +// version ID. +func GetTLSVersion(v string) (uint16, error) { + switch v { + case "1.3", "v1.3", "tls1.3": + return tls.VersionTLS13, nil + case "1.2", "v1.2", "tls1.2", "": + return tls.VersionTLS12, nil + case "1.1", "v1.1", "tls1.1": + return tls.VersionTLS11, nil + case "1.0", "v1.0", "tls1.0": + return tls.VersionTLS10, nil + default: + return 0, fmt.Errorf("error: unknown TLS version: %s", v) + } +} + +// GetTLSCipherSuites converts a comma separated list of cipher suites into a +// slice of TLS cipher suite IDs. +func GetTLSCipherSuites(v string) []uint16 { + supported := tls.CipherSuites() + + if v == "" { + suites := make([]uint16, len(supported)) + + for _, cs := range supported { + suites = append(suites, cs.ID) + } + + return suites + } + + var found bool + txts := strings.Split(v, ",") + suites := make([]uint16, len(txts)) + + for _, want := range txts { + found = false + + for _, cs := range supported { + if want == cs.Name { + suites = append(suites, cs.ID) + found = true + } + } + + if !found { + log.Fatalln("error: unknown TLS cipher suite:", want) + } + } + + return suites +} + +// GetTLSCurves converts a comma separated list of curves into a +// slice of TLS curve IDs. +func GetTLSCurves(v string) []tls.CurveID { + supported := []tls.CurveID{ + tls.CurveP256, + tls.CurveP384, + tls.CurveP521, + tls.X25519, + } + + if v == "" { + return supported + } + + var found bool + txts := strings.Split(v, ",") + res := make([]tls.CurveID, len(txts)) + + for _, want := range txts { + found = false + + for _, c := range supported { + if want == c.String() { + res = append(res, c) + found = true + } + } + + if !found { + log.Fatalln("error: unknown TLS curve:", want) + } + } + + return res +} diff --git a/internal/service/service.go b/internal/service/service.go new file mode 100644 index 0000000..b85c280 --- /dev/null +++ b/internal/service/service.go @@ -0,0 +1,163 @@ +// Package service manages the webhook HTTP service. +package service + +import ( + "crypto/tls" + "fmt" + "net" + "net/http" + + "github.com/adnanh/webhook/internal/pidfile" + "github.com/adnanh/webhook/internal/service/security" + + "github.com/gorilla/mux" +) + +// Service is the webhook HTTP service. +type Service struct { + // Address is the listener address for the service (e.g. "127.0.0.1:9000") + Address string + + // TLS settings + enableTLS bool + tlsCiphers []uint16 + tlsMinVersion uint16 + kpr *security.KeyPairReloader + + // Future TLS settings to consider: + // - tlsMaxVersion + // - configurable TLS curves + // - modern and intermediate helpers that follows Mozilla guidelines + // - ca root and intermediate certs + + listener net.Listener + server *http.Server + + pidFile *pidfile.PIDFile + + // Hooks map[string]hook.Hooks +} + +// New creates a new webhook HTTP service for the given address and port. +func New(ip string, port int) *Service { + return &Service{ + Address: fmt.Sprintf("%s:%d", ip, port), + server: &http.Server{}, + tlsMinVersion: tls.VersionTLS12, + } +} + +// Listen announces the TCP service on the local network. +// +// To enable TLS, ensure that SetTLSEnabled is called prior to Listen. +// +// After calling Listen, Serve must be called to begin serving HTTP requests. +// The steps are separated so that we can drop privileges, if necessary, after +// opening the listening port. +func (s *Service) Listen() error { + ln, err := net.Listen("tcp", s.Address) + if err != nil { + return err + } + + if !s.enableTLS { + s.listener = ln + return nil + } + + if s.kpr == nil { + panic("Listen called with TLS enabled but KPR is nil") + } + + c := &tls.Config{ + GetCertificate: s.kpr.GetCertificateFunc(), + CipherSuites: s.tlsCiphers, + CurvePreferences: security.GetTLSCurves(""), + MinVersion: s.tlsMinVersion, + PreferServerCipherSuites: true, + } + + s.listener = tls.NewListener(ln, c) + + return nil +} + +// Serve begins accepting incoming HTTP connections. +func (s *Service) Serve() error { + s.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2 + + if s.listener == nil { + err := s.Listen() + if err != nil { + return err + } + } + + defer s.listener.Close() + return s.server.Serve(s.listener) +} + +// SetHTTPHandler sets the underly HTTP server Handler. +func (s *Service) SetHTTPHandler(r *mux.Router) { + s.server.Handler = r +} + +// SetTLSCiphers sets the supported TLS ciphers. +func (s *Service) SetTLSCiphers(suites string) { + s.tlsCiphers = security.GetTLSCipherSuites(suites) +} + +// SetTLSEnabled enables TLS for the service. Must be called prior to Listen. +func (s *Service) SetTLSEnabled() { + s.enableTLS = true +} + +// TLSEnabled return true if TLS is enabled for the service. +func (s *Service) TLSEnabled() bool { + return s.enableTLS +} + +// SetTLSKeyPair sets the TLS key pair for the service. +func (s *Service) SetTLSKeyPair(certPath, keyPath string) error { + if certPath == "" { + return fmt.Errorf("error: certificate required for TLS") + } + + if keyPath == "" { + return fmt.Errorf("error: key required for TLS") + } + + var err error + + s.kpr, err = security.NewKeyPairReloader(certPath, keyPath) + if err != nil { + return err + } + + return nil +} + +// ReloadTLSKeyPair attempts to reload the configured TLS certificate key pair. +func (s *Service) ReloadTLSKeyPair() error { + return s.kpr.Reload() +} + +// SetTLSMinVersion sets the minimum support TLS version, such as "v1.3". +func (s *Service) SetTLSMinVersion(ver string) (err error) { + s.tlsMinVersion, err = security.GetTLSVersion(ver) + return err +} + +// CreatePIDFile creates a new PID file at path p. +func (s *Service) CreatePIDFile(p string) (err error) { + s.pidFile, err = pidfile.New(p) + return err +} + +// DeletePIDFile deletes a previously created PID file. +func (s *Service) DeletePIDFile() error { + if s.pidFile != nil { + return s.pidFile.Remove() + } + return nil +} diff --git a/signals.go b/signals.go index 23f0a74..2dd43b4 100644 --- a/signals.go +++ b/signals.go @@ -7,9 +7,11 @@ import ( "os" "os/signal" "syscall" + + "github.com/adnanh/webhook/internal/service" ) -func setupSignals() { +func setupSignals(svc *service.Service) { log.Printf("setting up os signal watcher\n") signals = make(chan os.Signal, 1) @@ -18,10 +20,10 @@ func setupSignals() { signal.Notify(signals, syscall.SIGTERM) signal.Notify(signals, os.Interrupt) - go watchForSignals() + go watchForSignals(svc) } -func watchForSignals() { +func watchForSignals(svc *service.Service) { log.Println("os signal watcher ready") for { @@ -35,13 +37,21 @@ func watchForSignals() { log.Println("caught HUP signal") reloadAllHooks() + if svc.TLSEnabled() { + log.Println("attempting to reload TLS key pair") + err := svc.ReloadTLSKeyPair() + if err != nil { + log.Printf("failed to reload TLS key pair: %s\n", err) + } else { + log.Println("successfully reloaded TLS key pair") + } + } + case os.Interrupt, syscall.SIGTERM: log.Printf("caught %s signal; exiting\n", sig) - if pidFile != nil { - err := pidFile.Remove() - if err != nil { - log.Print(err) - } + err := svc.DeletePIDFile() + if err != nil { + log.Print(err) } os.Exit(0) diff --git a/signals_windows.go b/signals_windows.go index e7a2a1d..bcf941e 100644 --- a/signals_windows.go +++ b/signals_windows.go @@ -2,6 +2,8 @@ package main -func setupSignals() { +import "github.com/adnanh/webhook/internal/service" + +func setupSignals(_ *service.Service) { // NOOP: Windows doesn't have signals equivalent to the Unix world. } diff --git a/tls.go b/tls.go deleted file mode 100644 index 526fd36..0000000 --- a/tls.go +++ /dev/null @@ -1,85 +0,0 @@ -package main - -import ( - "crypto/tls" - "io" - "log" - "strings" -) - -func writeTLSSupportedCipherStrings(w io.Writer, min uint16) error { - for _, c := range tls.CipherSuites() { - var found bool - - for _, v := range c.SupportedVersions { - if v >= min { - found = true - } - } - - if !found { - continue - } - - _, err := w.Write([]byte(c.Name + "\n")) - if err != nil { - return err - } - } - - return nil -} - -// getTLSMinVersion converts a version string into a TLS version ID. -func getTLSMinVersion(v string) uint16 { - switch v { - case "1.0": - return tls.VersionTLS10 - case "1.1": - return tls.VersionTLS11 - case "1.2", "": - return tls.VersionTLS12 - case "1.3": - return tls.VersionTLS13 - default: - log.Fatalln("error: unknown minimum TLS version:", v) - return 0 - } -} - -// getTLSCipherSuites converts a comma separated list of cipher suites into a -// slice of TLS cipher suite IDs. -func getTLSCipherSuites(v string) []uint16 { - supported := tls.CipherSuites() - - if v == "" { - suites := make([]uint16, len(supported)) - - for _, cs := range supported { - suites = append(suites, cs.ID) - } - - return suites - } - - var found bool - txts := strings.Split(v, ",") - suites := make([]uint16, len(txts)) - - for _, want := range txts { - found = false - - for _, cs := range supported { - if want == cs.Name { - suites = append(suites, cs.ID) - found = true - } - } - - if !found { - log.Fatalln("error: unknown TLS cipher suite:", want) - } - } - - return suites -} diff --git a/webhook.go b/webhook.go index 94f0500..82e3936 100644 --- a/webhook.go +++ b/webhook.go @@ -1,13 +1,11 @@ package main import ( - "crypto/tls" "encoding/json" "flag" "fmt" "io/ioutil" "log" - "net" "net/http" "os" "os/exec" @@ -17,7 +15,8 @@ import ( "github.com/adnanh/webhook/internal/hook" "github.com/adnanh/webhook/internal/middleware" - "github.com/adnanh/webhook/internal/pidfile" + "github.com/adnanh/webhook/internal/service" + "github.com/adnanh/webhook/internal/service/security" chimiddleware "github.com/go-chi/chi/middleware" "github.com/gorilla/mux" @@ -60,7 +59,8 @@ var ( watcher *fsnotify.Watcher signals chan os.Signal - pidFile *pidfile.PIDFile + + Service *service.Service ) func matchLoadedHook(id string) *hook.Hook { @@ -94,7 +94,7 @@ func main() { } if *justListCiphers { - err := writeTLSSupportedCipherStrings(os.Stdout, getTLSMinVersion(*tlsMinVersion)) + err := security.WriteTLSSupportedCipherStrings(os.Stdout, *tlsMinVersion) if err != nil { fmt.Println(err) os.Exit(1) @@ -111,8 +111,26 @@ func main() { *verbose = true } - if len(hooksFiles) == 0 { - hooksFiles = append(hooksFiles, "hooks.json") + // Setup a new Service instance + Service = service.New(*ip, *port) + + // We must setup TLS prior to opening a listening port. + if *secure { + Service.SetTLSEnabled() + + err := Service.SetTLSKeyPair(*cert, *key) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + err = Service.SetTLSMinVersion(*tlsMinVersion) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + Service.SetTLSCiphers(*tlsCipherSuites) } // logQueue is a queue for log messages encountered during startup. We need @@ -120,10 +138,8 @@ func main() { // log file opening prior to writing our first log message. var logQueue []string - addr := fmt.Sprintf("%s:%d", *ip, *port) - // Open listener early so we can drop privileges. - ln, err := net.Listen("tcp", addr) + err := Service.Listen() if err != nil { logQueue = append(logQueue, fmt.Sprintf("error listening on port: %s", err)) // we'll bail out below @@ -166,7 +182,7 @@ func main() { if *pidPath != "" { var err error - pidFile, err = pidfile.New(*pidPath) + err = Service.CreatePIDFile(*pidPath) if err != nil { log.Fatalf("Error creating pidfile: %v", err) } @@ -174,7 +190,7 @@ func main() { defer func() { // NOTE(moorereason): my testing shows that this doesn't work with // ^C, so we also do a Remove in the signal handler elsewhere. - if nerr := pidFile.Remove(); nerr != nil { + if nerr := Service.DeletePIDFile(); nerr != nil { log.Print(nerr) } }() @@ -183,9 +199,13 @@ func main() { log.Println("version " + version + " starting") // set os signal watcher - setupSignals() + setupSignals(Service) // load and parse hooks + if len(hooksFiles) == 0 { + hooksFiles = append(hooksFiles, "hooks.json") + } + for _, hooksFilePath := range hooksFiles { log.Printf("attempting to load hooks from %s\n", hooksFilePath) @@ -275,30 +295,15 @@ func main() { r.HandleFunc(hooksURL, hookHandler) // Create common HTTP server settings - svr := &http.Server{ - Addr: addr, - Handler: r, - } + Service.SetHTTPHandler(r) - // Serve HTTP if !*secure { - log.Printf("serving hooks on http://%s%s", addr, makeHumanPattern(hooksURLPrefix)) - log.Print(svr.Serve(ln)) - - return + log.Printf("serving hooks on http://%s%s", Service.Address, makeHumanPattern(hooksURLPrefix)) + } else { + log.Printf("serving hooks on https://%s%s", Service.Address, makeHumanPattern(hooksURLPrefix)) } - // Server HTTPS - svr.TLSConfig = &tls.Config{ - CipherSuites: getTLSCipherSuites(*tlsCipherSuites), - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, - MinVersion: getTLSMinVersion(*tlsMinVersion), - PreferServerCipherSuites: true, - } - svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2 - - log.Printf("serving hooks on https://%s%s", addr, makeHumanPattern(hooksURLPrefix)) - log.Print(svr.ServeTLS(ln, *cert, *key)) + log.Print(Service.Serve()) } func hookHandler(w http.ResponseWriter, r *http.Request) {