This commit is contained in:
Cameron Moore 2021-09-04 08:00:44 -04:00 committed by GitHub
commit 3dbab57740
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 405 additions and 128 deletions

View File

@ -4,7 +4,7 @@ jobs:
build: build:
strategy: strategy:
matrix: matrix:
go-version: [1.14.x, 1.15.x] go-version: [1.15.x]
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}

View File

@ -0,0 +1,182 @@
// Package security provides HTTP security management help to the webhook
// service.
package security
import (
"crypto/tls"
"fmt"
"io"
"log"
"strings"
"sync"
)
// KeyPairReloader contains the active TLS certificate. It can be used with
// the tls.Config.GetCertificate property to support live updating of the
// certificate.
type KeyPairReloader struct {
certMu sync.RWMutex
cert *tls.Certificate
certPath string
keyPath string
}
// NewKeyPairReloader creates a new KeyPairReloader given the certificate and
// key path.
func NewKeyPairReloader(certPath, keyPath string) (*KeyPairReloader, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}
res := &KeyPairReloader{
cert: &cert,
certPath: certPath,
keyPath: keyPath,
}
return res, nil
}
// Reload attempts to reload the TLS key pair.
func (kpr *KeyPairReloader) Reload() error {
cert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath)
if err != nil {
return err
}
kpr.certMu.Lock()
defer kpr.certMu.Unlock()
kpr.cert = &cert
return nil
}
// GetCertificateFunc provides a function for tls.Config.GetCertificate.
func (kpr *KeyPairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
kpr.certMu.RLock()
defer kpr.certMu.RUnlock()
return kpr.cert, nil
}
}
// WriteTLSSupportedCipherStrings writes a list of ciphers to w. The list is
// all supported TLS ciphers based upon min.
func WriteTLSSupportedCipherStrings(w io.Writer, min string) error {
m, err := GetTLSVersion(min)
if err != nil {
return err
}
for _, c := range tls.CipherSuites() {
var found bool
for _, v := range c.SupportedVersions {
if v >= m {
found = true
}
}
if !found {
continue
}
_, err := w.Write([]byte(c.Name + "\n"))
if err != nil {
return err
}
}
return nil
}
// GetTLSVersion converts a TLS version string, v, (e.g. "v1.3") into a TLS
// version ID.
func GetTLSVersion(v string) (uint16, error) {
switch v {
case "1.3", "v1.3", "tls1.3":
return tls.VersionTLS13, nil
case "1.2", "v1.2", "tls1.2", "":
return tls.VersionTLS12, nil
case "1.1", "v1.1", "tls1.1":
return tls.VersionTLS11, nil
case "1.0", "v1.0", "tls1.0":
return tls.VersionTLS10, nil
default:
return 0, fmt.Errorf("error: unknown TLS version: %s", v)
}
}
// GetTLSCipherSuites converts a comma separated list of cipher suites into a
// slice of TLS cipher suite IDs.
func GetTLSCipherSuites(v string) []uint16 {
supported := tls.CipherSuites()
if v == "" {
suites := make([]uint16, len(supported))
for _, cs := range supported {
suites = append(suites, cs.ID)
}
return suites
}
var found bool
txts := strings.Split(v, ",")
suites := make([]uint16, len(txts))
for _, want := range txts {
found = false
for _, cs := range supported {
if want == cs.Name {
suites = append(suites, cs.ID)
found = true
}
}
if !found {
log.Fatalln("error: unknown TLS cipher suite:", want)
}
}
return suites
}
// GetTLSCurves converts a comma separated list of curves into a
// slice of TLS curve IDs.
func GetTLSCurves(v string) []tls.CurveID {
supported := []tls.CurveID{
tls.CurveP256,
tls.CurveP384,
tls.CurveP521,
tls.X25519,
}
if v == "" {
return supported
}
var found bool
txts := strings.Split(v, ",")
res := make([]tls.CurveID, len(txts))
for _, want := range txts {
found = false
for _, c := range supported {
if want == c.String() {
res = append(res, c)
found = true
}
}
if !found {
log.Fatalln("error: unknown TLS curve:", want)
}
}
return res
}

163
internal/service/service.go Normal file
View File

@ -0,0 +1,163 @@
// Package service manages the webhook HTTP service.
package service
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"github.com/adnanh/webhook/internal/pidfile"
"github.com/adnanh/webhook/internal/service/security"
"github.com/gorilla/mux"
)
// Service is the webhook HTTP service.
type Service struct {
// Address is the listener address for the service (e.g. "127.0.0.1:9000")
Address string
// TLS settings
enableTLS bool
tlsCiphers []uint16
tlsMinVersion uint16
kpr *security.KeyPairReloader
// Future TLS settings to consider:
// - tlsMaxVersion
// - configurable TLS curves
// - modern and intermediate helpers that follows Mozilla guidelines
// - ca root and intermediate certs
listener net.Listener
server *http.Server
pidFile *pidfile.PIDFile
// Hooks map[string]hook.Hooks
}
// New creates a new webhook HTTP service for the given address and port.
func New(ip string, port int) *Service {
return &Service{
Address: fmt.Sprintf("%s:%d", ip, port),
server: &http.Server{},
tlsMinVersion: tls.VersionTLS12,
}
}
// Listen announces the TCP service on the local network.
//
// To enable TLS, ensure that SetTLSEnabled is called prior to Listen.
//
// After calling Listen, Serve must be called to begin serving HTTP requests.
// The steps are separated so that we can drop privileges, if necessary, after
// opening the listening port.
func (s *Service) Listen() error {
ln, err := net.Listen("tcp", s.Address)
if err != nil {
return err
}
if !s.enableTLS {
s.listener = ln
return nil
}
if s.kpr == nil {
panic("Listen called with TLS enabled but KPR is nil")
}
c := &tls.Config{
GetCertificate: s.kpr.GetCertificateFunc(),
CipherSuites: s.tlsCiphers,
CurvePreferences: security.GetTLSCurves(""),
MinVersion: s.tlsMinVersion,
PreferServerCipherSuites: true,
}
s.listener = tls.NewListener(ln, c)
return nil
}
// Serve begins accepting incoming HTTP connections.
func (s *Service) Serve() error {
s.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
if s.listener == nil {
err := s.Listen()
if err != nil {
return err
}
}
defer s.listener.Close()
return s.server.Serve(s.listener)
}
// SetHTTPHandler sets the underly HTTP server Handler.
func (s *Service) SetHTTPHandler(r *mux.Router) {
s.server.Handler = r
}
// SetTLSCiphers sets the supported TLS ciphers.
func (s *Service) SetTLSCiphers(suites string) {
s.tlsCiphers = security.GetTLSCipherSuites(suites)
}
// SetTLSEnabled enables TLS for the service. Must be called prior to Listen.
func (s *Service) SetTLSEnabled() {
s.enableTLS = true
}
// TLSEnabled return true if TLS is enabled for the service.
func (s *Service) TLSEnabled() bool {
return s.enableTLS
}
// SetTLSKeyPair sets the TLS key pair for the service.
func (s *Service) SetTLSKeyPair(certPath, keyPath string) error {
if certPath == "" {
return fmt.Errorf("error: certificate required for TLS")
}
if keyPath == "" {
return fmt.Errorf("error: key required for TLS")
}
var err error
s.kpr, err = security.NewKeyPairReloader(certPath, keyPath)
if err != nil {
return err
}
return nil
}
// ReloadTLSKeyPair attempts to reload the configured TLS certificate key pair.
func (s *Service) ReloadTLSKeyPair() error {
return s.kpr.Reload()
}
// SetTLSMinVersion sets the minimum support TLS version, such as "v1.3".
func (s *Service) SetTLSMinVersion(ver string) (err error) {
s.tlsMinVersion, err = security.GetTLSVersion(ver)
return err
}
// CreatePIDFile creates a new PID file at path p.
func (s *Service) CreatePIDFile(p string) (err error) {
s.pidFile, err = pidfile.New(p)
return err
}
// DeletePIDFile deletes a previously created PID file.
func (s *Service) DeletePIDFile() error {
if s.pidFile != nil {
return s.pidFile.Remove()
}
return nil
}

View File

@ -7,9 +7,11 @@ import (
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/adnanh/webhook/internal/service"
) )
func setupSignals() { func setupSignals(svc *service.Service) {
log.Printf("setting up os signal watcher\n") log.Printf("setting up os signal watcher\n")
signals = make(chan os.Signal, 1) signals = make(chan os.Signal, 1)
@ -18,10 +20,10 @@ func setupSignals() {
signal.Notify(signals, syscall.SIGTERM) signal.Notify(signals, syscall.SIGTERM)
signal.Notify(signals, os.Interrupt) signal.Notify(signals, os.Interrupt)
go watchForSignals() go watchForSignals(svc)
} }
func watchForSignals() { func watchForSignals(svc *service.Service) {
log.Println("os signal watcher ready") log.Println("os signal watcher ready")
for { for {
@ -35,14 +37,22 @@ func watchForSignals() {
log.Println("caught HUP signal") log.Println("caught HUP signal")
reloadAllHooks() reloadAllHooks()
if svc.TLSEnabled() {
log.Println("attempting to reload TLS key pair")
err := svc.ReloadTLSKeyPair()
if err != nil {
log.Printf("failed to reload TLS key pair: %s\n", err)
} else {
log.Println("successfully reloaded TLS key pair")
}
}
case os.Interrupt, syscall.SIGTERM: case os.Interrupt, syscall.SIGTERM:
log.Printf("caught %s signal; exiting\n", sig) log.Printf("caught %s signal; exiting\n", sig)
if pidFile != nil { err := svc.DeletePIDFile()
err := pidFile.Remove()
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
}
os.Exit(0) os.Exit(0)
default: default:

View File

@ -2,6 +2,8 @@
package main package main
func setupSignals() { import "github.com/adnanh/webhook/internal/service"
func setupSignals(_ *service.Service) {
// NOOP: Windows doesn't have signals equivalent to the Unix world. // NOOP: Windows doesn't have signals equivalent to the Unix world.
} }

85
tls.go
View File

@ -1,85 +0,0 @@
package main
import (
"crypto/tls"
"io"
"log"
"strings"
)
func writeTLSSupportedCipherStrings(w io.Writer, min uint16) error {
for _, c := range tls.CipherSuites() {
var found bool
for _, v := range c.SupportedVersions {
if v >= min {
found = true
}
}
if !found {
continue
}
_, err := w.Write([]byte(c.Name + "\n"))
if err != nil {
return err
}
}
return nil
}
// getTLSMinVersion converts a version string into a TLS version ID.
func getTLSMinVersion(v string) uint16 {
switch v {
case "1.0":
return tls.VersionTLS10
case "1.1":
return tls.VersionTLS11
case "1.2", "":
return tls.VersionTLS12
case "1.3":
return tls.VersionTLS13
default:
log.Fatalln("error: unknown minimum TLS version:", v)
return 0
}
}
// getTLSCipherSuites converts a comma separated list of cipher suites into a
// slice of TLS cipher suite IDs.
func getTLSCipherSuites(v string) []uint16 {
supported := tls.CipherSuites()
if v == "" {
suites := make([]uint16, len(supported))
for _, cs := range supported {
suites = append(suites, cs.ID)
}
return suites
}
var found bool
txts := strings.Split(v, ",")
suites := make([]uint16, len(txts))
for _, want := range txts {
found = false
for _, cs := range supported {
if want == cs.Name {
suites = append(suites, cs.ID)
found = true
}
}
if !found {
log.Fatalln("error: unknown TLS cipher suite:", want)
}
}
return suites
}

View File

@ -1,13 +1,11 @@
package main package main
import ( import (
"crypto/tls"
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"net/http" "net/http"
"os" "os"
"os/exec" "os/exec"
@ -17,7 +15,8 @@ import (
"github.com/adnanh/webhook/internal/hook" "github.com/adnanh/webhook/internal/hook"
"github.com/adnanh/webhook/internal/middleware" "github.com/adnanh/webhook/internal/middleware"
"github.com/adnanh/webhook/internal/pidfile" "github.com/adnanh/webhook/internal/service"
"github.com/adnanh/webhook/internal/service/security"
chimiddleware "github.com/go-chi/chi/middleware" chimiddleware "github.com/go-chi/chi/middleware"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -60,7 +59,8 @@ var (
watcher *fsnotify.Watcher watcher *fsnotify.Watcher
signals chan os.Signal signals chan os.Signal
pidFile *pidfile.PIDFile
Service *service.Service
) )
func matchLoadedHook(id string) *hook.Hook { func matchLoadedHook(id string) *hook.Hook {
@ -94,7 +94,7 @@ func main() {
} }
if *justListCiphers { if *justListCiphers {
err := writeTLSSupportedCipherStrings(os.Stdout, getTLSMinVersion(*tlsMinVersion)) err := security.WriteTLSSupportedCipherStrings(os.Stdout, *tlsMinVersion)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)
@ -111,8 +111,26 @@ func main() {
*verbose = true *verbose = true
} }
if len(hooksFiles) == 0 { // Setup a new Service instance
hooksFiles = append(hooksFiles, "hooks.json") Service = service.New(*ip, *port)
// We must setup TLS prior to opening a listening port.
if *secure {
Service.SetTLSEnabled()
err := Service.SetTLSKeyPair(*cert, *key)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
err = Service.SetTLSMinVersion(*tlsMinVersion)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
Service.SetTLSCiphers(*tlsCipherSuites)
} }
// logQueue is a queue for log messages encountered during startup. We need // logQueue is a queue for log messages encountered during startup. We need
@ -120,10 +138,8 @@ func main() {
// log file opening prior to writing our first log message. // log file opening prior to writing our first log message.
var logQueue []string var logQueue []string
addr := fmt.Sprintf("%s:%d", *ip, *port)
// Open listener early so we can drop privileges. // Open listener early so we can drop privileges.
ln, err := net.Listen("tcp", addr) err := Service.Listen()
if err != nil { if err != nil {
logQueue = append(logQueue, fmt.Sprintf("error listening on port: %s", err)) logQueue = append(logQueue, fmt.Sprintf("error listening on port: %s", err))
// we'll bail out below // we'll bail out below
@ -166,7 +182,7 @@ func main() {
if *pidPath != "" { if *pidPath != "" {
var err error var err error
pidFile, err = pidfile.New(*pidPath) err = Service.CreatePIDFile(*pidPath)
if err != nil { if err != nil {
log.Fatalf("Error creating pidfile: %v", err) log.Fatalf("Error creating pidfile: %v", err)
} }
@ -174,7 +190,7 @@ func main() {
defer func() { defer func() {
// NOTE(moorereason): my testing shows that this doesn't work with // NOTE(moorereason): my testing shows that this doesn't work with
// ^C, so we also do a Remove in the signal handler elsewhere. // ^C, so we also do a Remove in the signal handler elsewhere.
if nerr := pidFile.Remove(); nerr != nil { if nerr := Service.DeletePIDFile(); nerr != nil {
log.Print(nerr) log.Print(nerr)
} }
}() }()
@ -183,9 +199,13 @@ func main() {
log.Println("version " + version + " starting") log.Println("version " + version + " starting")
// set os signal watcher // set os signal watcher
setupSignals() setupSignals(Service)
// load and parse hooks // load and parse hooks
if len(hooksFiles) == 0 {
hooksFiles = append(hooksFiles, "hooks.json")
}
for _, hooksFilePath := range hooksFiles { for _, hooksFilePath := range hooksFiles {
log.Printf("attempting to load hooks from %s\n", hooksFilePath) log.Printf("attempting to load hooks from %s\n", hooksFilePath)
@ -275,30 +295,15 @@ func main() {
r.HandleFunc(hooksURL, hookHandler) r.HandleFunc(hooksURL, hookHandler)
// Create common HTTP server settings // Create common HTTP server settings
svr := &http.Server{ Service.SetHTTPHandler(r)
Addr: addr,
Handler: r,
}
// Serve HTTP
if !*secure { if !*secure {
log.Printf("serving hooks on http://%s%s", addr, makeHumanPattern(hooksURLPrefix)) log.Printf("serving hooks on http://%s%s", Service.Address, makeHumanPattern(hooksURLPrefix))
log.Print(svr.Serve(ln)) } else {
log.Printf("serving hooks on https://%s%s", Service.Address, makeHumanPattern(hooksURLPrefix))
return
} }
// Server HTTPS log.Print(Service.Serve())
svr.TLSConfig = &tls.Config{
CipherSuites: getTLSCipherSuites(*tlsCipherSuites),
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
MinVersion: getTLSMinVersion(*tlsMinVersion),
PreferServerCipherSuites: true,
}
svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
log.Printf("serving hooks on https://%s%s", addr, makeHumanPattern(hooksURLPrefix))
log.Print(svr.ServeTLS(ln, *cert, *key))
} }
func hookHandler(w http.ResponseWriter, r *http.Request) { func hookHandler(w http.ResponseWriter, r *http.Request) {