Add service package
To help facilitate new features, begin moving the main webhook service properties to a Service struct.
This commit is contained in:
parent
194a9c4b3f
commit
b82e15e836
168
internal/service/security/tls.go
Normal file
168
internal/service/security/tls.go
Normal file
@ -0,0 +1,168 @@
|
||||
// Package security provides HTTP security management help to the webhook
|
||||
// service.
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// KeyPairReloader contains the active TLS certificate. It can be used with
|
||||
// the tls.Config.GetCertificate property to support live updating of the
|
||||
// certificate.
|
||||
type KeyPairReloader struct {
|
||||
certMu sync.RWMutex
|
||||
cert *tls.Certificate
|
||||
certPath string
|
||||
keyPath string
|
||||
}
|
||||
|
||||
// NewKeyPairReloader creates a new KeyPairReloader given the certificate and
|
||||
// key path.
|
||||
func NewKeyPairReloader(certPath, keyPath string) (*KeyPairReloader, error) {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := &KeyPairReloader{
|
||||
cert: &cert,
|
||||
certPath: certPath,
|
||||
keyPath: keyPath,
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// GetCertificateFunc provides a function for tls.Config.GetCertificate.
|
||||
func (kpr *KeyPairReloader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
kpr.certMu.RLock()
|
||||
defer kpr.certMu.RUnlock()
|
||||
return kpr.cert, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTLSSupportedCipherStrings writes a list of ciphers to w. The list is
|
||||
// all supported TLS ciphers based upon min.
|
||||
func WriteTLSSupportedCipherStrings(w io.Writer, min string) error {
|
||||
m, err := GetTLSVersion(min)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range tls.CipherSuites() {
|
||||
var found bool
|
||||
|
||||
for _, v := range c.SupportedVersions {
|
||||
if v >= m {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err := w.Write([]byte(c.Name + "\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTLSVersion converts a TLS version string, v, (e.g. "v1.3") into a TLS
|
||||
// version ID.
|
||||
func GetTLSVersion(v string) (uint16, error) {
|
||||
switch v {
|
||||
case "1.3", "v1.3", "tls1.3":
|
||||
return tls.VersionTLS13, nil
|
||||
case "1.2", "v1.2", "tls1.2", "":
|
||||
return tls.VersionTLS12, nil
|
||||
case "1.1", "v1.1", "tls1.1":
|
||||
return tls.VersionTLS11, nil
|
||||
case "1.0", "v1.0", "tls1.0":
|
||||
return tls.VersionTLS10, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("error: unknown TLS version: %s", v)
|
||||
}
|
||||
}
|
||||
|
||||
// GetTLSCipherSuites converts a comma separated list of cipher suites into a
|
||||
// slice of TLS cipher suite IDs.
|
||||
func GetTLSCipherSuites(v string) []uint16 {
|
||||
supported := tls.CipherSuites()
|
||||
|
||||
if v == "" {
|
||||
suites := make([]uint16, len(supported))
|
||||
|
||||
for _, cs := range supported {
|
||||
suites = append(suites, cs.ID)
|
||||
}
|
||||
|
||||
return suites
|
||||
}
|
||||
|
||||
var found bool
|
||||
txts := strings.Split(v, ",")
|
||||
suites := make([]uint16, len(txts))
|
||||
|
||||
for _, want := range txts {
|
||||
found = false
|
||||
|
||||
for _, cs := range supported {
|
||||
if want == cs.Name {
|
||||
suites = append(suites, cs.ID)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
log.Fatalln("error: unknown TLS cipher suite:", want)
|
||||
}
|
||||
}
|
||||
|
||||
return suites
|
||||
}
|
||||
|
||||
// GetTLSCurves converts a comma separated list of curves into a
|
||||
// slice of TLS curve IDs.
|
||||
func GetTLSCurves(v string) []tls.CurveID {
|
||||
supported := []tls.CurveID{
|
||||
tls.CurveP256,
|
||||
tls.CurveP384,
|
||||
tls.CurveP521,
|
||||
tls.X25519,
|
||||
}
|
||||
|
||||
if v == "" {
|
||||
return supported
|
||||
}
|
||||
|
||||
var found bool
|
||||
txts := strings.Split(v, ",")
|
||||
res := make([]tls.CurveID, len(txts))
|
||||
|
||||
for _, want := range txts {
|
||||
found = false
|
||||
|
||||
for _, c := range supported {
|
||||
if want == c.String() {
|
||||
res = append(res, c)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
log.Fatalln("error: unknown TLS curve:", want)
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
153
internal/service/service.go
Normal file
153
internal/service/service.go
Normal file
@ -0,0 +1,153 @@
|
||||
// Package service manages the webhook HTTP service.
|
||||
package service
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/adnanh/webhook/internal/pidfile"
|
||||
"github.com/adnanh/webhook/internal/service/security"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// Service is the webhook HTTP service.
|
||||
type Service struct {
|
||||
// Address is the listener address for the service (e.g. "127.0.0.1:9000")
|
||||
Address string
|
||||
|
||||
// TLS settings
|
||||
enableTLS bool
|
||||
tlsCiphers []uint16
|
||||
tlsMinVersion uint16
|
||||
kpr *security.KeyPairReloader
|
||||
|
||||
// Future TLS settings to consider:
|
||||
// - tlsMaxVersion
|
||||
// - configurable TLS curves
|
||||
// - modern and intermediate helpers that follows Mozilla guidelines
|
||||
// - ca root and intermediate certs
|
||||
|
||||
listener net.Listener
|
||||
server *http.Server
|
||||
|
||||
pidFile *pidfile.PIDFile
|
||||
|
||||
// Hooks map[string]hook.Hooks
|
||||
}
|
||||
|
||||
// New creates a new webhook HTTP service for the given address and port.
|
||||
func New(ip string, port int) *Service {
|
||||
return &Service{
|
||||
Address: fmt.Sprintf("%s:%d", ip, port),
|
||||
server: &http.Server{},
|
||||
tlsMinVersion: tls.VersionTLS12,
|
||||
}
|
||||
}
|
||||
|
||||
// Listen announces the TCP service on the local network.
|
||||
//
|
||||
// To enable TLS, ensure that SetTLSEnabled is called prior to Listen.
|
||||
//
|
||||
// After calling Listen, Serve must be called to begin serving HTTP requests.
|
||||
// The steps are separated so that we can drop privileges, if necessary, after
|
||||
// opening the listening port.
|
||||
func (s *Service) Listen() error {
|
||||
ln, err := net.Listen("tcp", s.Address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !s.enableTLS {
|
||||
s.listener = ln
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.kpr == nil {
|
||||
panic("Listen called with TLS enabled but KPR is nil")
|
||||
}
|
||||
|
||||
c := &tls.Config{
|
||||
GetCertificate: s.kpr.GetCertificateFunc(),
|
||||
CipherSuites: s.tlsCiphers,
|
||||
CurvePreferences: security.GetTLSCurves(""),
|
||||
MinVersion: s.tlsMinVersion,
|
||||
PreferServerCipherSuites: true,
|
||||
}
|
||||
|
||||
s.listener = tls.NewListener(ln, c)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Serve begins accepting incoming HTTP connections.
|
||||
func (s *Service) Serve() error {
|
||||
s.server.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
|
||||
|
||||
if s.listener == nil {
|
||||
err := s.Listen()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
defer s.listener.Close()
|
||||
return s.server.Serve(s.listener)
|
||||
}
|
||||
|
||||
// SetHTTPHandler sets the underly HTTP server Handler.
|
||||
func (s *Service) SetHTTPHandler(r *mux.Router) {
|
||||
s.server.Handler = r
|
||||
}
|
||||
|
||||
// SetTLSCiphers sets the supported TLS ciphers.
|
||||
func (s *Service) SetTLSCiphers(suites string) {
|
||||
s.tlsCiphers = security.GetTLSCipherSuites(suites)
|
||||
}
|
||||
|
||||
// SetTLSEnabled enables TLS for the service. Must be called prior to Listen.
|
||||
func (s *Service) SetTLSEnabled() {
|
||||
s.enableTLS = true
|
||||
}
|
||||
|
||||
// SetTLSKeyPair sets the TLS key pair for the service.
|
||||
func (s *Service) SetTLSKeyPair(certPath, keyPath string) error {
|
||||
if certPath == "" {
|
||||
return fmt.Errorf("error: certificate required for TLS")
|
||||
}
|
||||
|
||||
if keyPath == "" {
|
||||
return fmt.Errorf("error: key required for TLS")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
s.kpr, err = security.NewKeyPairReloader(certPath, keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetTLSMinVersion sets the minimum support TLS version, such as "v1.3".
|
||||
func (s *Service) SetTLSMinVersion(ver string) (err error) {
|
||||
s.tlsMinVersion, err = security.GetTLSVersion(ver)
|
||||
return err
|
||||
}
|
||||
|
||||
// CreatePIDFile creates a new PID file at path p.
|
||||
func (s *Service) CreatePIDFile(p string) (err error) {
|
||||
s.pidFile, err = pidfile.New(p)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeletePIDFile deletes a previously created PID file.
|
||||
func (s *Service) DeletePIDFile() error {
|
||||
if s.pidFile != nil {
|
||||
return s.pidFile.Remove()
|
||||
}
|
||||
return nil
|
||||
}
|
16
signals.go
16
signals.go
@ -7,9 +7,11 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/adnanh/webhook/internal/service"
|
||||
)
|
||||
|
||||
func setupSignals() {
|
||||
func setupSignals(svc *service.Service) {
|
||||
log.Printf("setting up os signal watcher\n")
|
||||
|
||||
signals = make(chan os.Signal, 1)
|
||||
@ -18,10 +20,10 @@ func setupSignals() {
|
||||
signal.Notify(signals, syscall.SIGTERM)
|
||||
signal.Notify(signals, os.Interrupt)
|
||||
|
||||
go watchForSignals()
|
||||
go watchForSignals(svc)
|
||||
}
|
||||
|
||||
func watchForSignals() {
|
||||
func watchForSignals(svc *service.Service) {
|
||||
log.Println("os signal watcher ready")
|
||||
|
||||
for {
|
||||
@ -37,11 +39,9 @@ func watchForSignals() {
|
||||
|
||||
case os.Interrupt, syscall.SIGTERM:
|
||||
log.Printf("caught %s signal; exiting\n", sig)
|
||||
if pidFile != nil {
|
||||
err := pidFile.Remove()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
err := svc.DeletePIDFile()
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
os.Exit(0)
|
||||
|
||||
|
@ -2,6 +2,8 @@
|
||||
|
||||
package main
|
||||
|
||||
func setupSignals() {
|
||||
import "github.com/adnanh/webhook/internal/service"
|
||||
|
||||
func setupSignals(_ *service.Service) {
|
||||
// NOOP: Windows doesn't have signals equivalent to the Unix world.
|
||||
}
|
||||
|
85
tls.go
85
tls.go
@ -1,85 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func writeTLSSupportedCipherStrings(w io.Writer, min uint16) error {
|
||||
for _, c := range tls.CipherSuites() {
|
||||
var found bool
|
||||
|
||||
for _, v := range c.SupportedVersions {
|
||||
if v >= min {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
continue
|
||||
}
|
||||
|
||||
_, err := w.Write([]byte(c.Name + "\n"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getTLSMinVersion converts a version string into a TLS version ID.
|
||||
func getTLSMinVersion(v string) uint16 {
|
||||
switch v {
|
||||
case "1.0":
|
||||
return tls.VersionTLS10
|
||||
case "1.1":
|
||||
return tls.VersionTLS11
|
||||
case "1.2", "":
|
||||
return tls.VersionTLS12
|
||||
case "1.3":
|
||||
return tls.VersionTLS13
|
||||
default:
|
||||
log.Fatalln("error: unknown minimum TLS version:", v)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// getTLSCipherSuites converts a comma separated list of cipher suites into a
|
||||
// slice of TLS cipher suite IDs.
|
||||
func getTLSCipherSuites(v string) []uint16 {
|
||||
supported := tls.CipherSuites()
|
||||
|
||||
if v == "" {
|
||||
suites := make([]uint16, len(supported))
|
||||
|
||||
for _, cs := range supported {
|
||||
suites = append(suites, cs.ID)
|
||||
}
|
||||
|
||||
return suites
|
||||
}
|
||||
|
||||
var found bool
|
||||
txts := strings.Split(v, ",")
|
||||
suites := make([]uint16, len(txts))
|
||||
|
||||
for _, want := range txts {
|
||||
found = false
|
||||
|
||||
for _, cs := range supported {
|
||||
if want == cs.Name {
|
||||
suites = append(suites, cs.ID)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
log.Fatalln("error: unknown TLS cipher suite:", want)
|
||||
}
|
||||
}
|
||||
|
||||
return suites
|
||||
}
|
71
webhook.go
71
webhook.go
@ -1,13 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
@ -17,7 +15,8 @@ import (
|
||||
|
||||
"github.com/adnanh/webhook/internal/hook"
|
||||
"github.com/adnanh/webhook/internal/middleware"
|
||||
"github.com/adnanh/webhook/internal/pidfile"
|
||||
"github.com/adnanh/webhook/internal/service"
|
||||
"github.com/adnanh/webhook/internal/service/security"
|
||||
|
||||
chimiddleware "github.com/go-chi/chi/middleware"
|
||||
"github.com/gorilla/mux"
|
||||
@ -60,7 +59,8 @@ var (
|
||||
|
||||
watcher *fsnotify.Watcher
|
||||
signals chan os.Signal
|
||||
pidFile *pidfile.PIDFile
|
||||
|
||||
S *service.Service
|
||||
)
|
||||
|
||||
func matchLoadedHook(id string) *hook.Hook {
|
||||
@ -94,7 +94,7 @@ func main() {
|
||||
}
|
||||
|
||||
if *justListCiphers {
|
||||
err := writeTLSSupportedCipherStrings(os.Stdout, getTLSMinVersion(*tlsMinVersion))
|
||||
err := security.WriteTLSSupportedCipherStrings(os.Stdout, *tlsMinVersion)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
@ -111,8 +111,26 @@ func main() {
|
||||
*verbose = true
|
||||
}
|
||||
|
||||
if len(hooksFiles) == 0 {
|
||||
hooksFiles = append(hooksFiles, "hooks.json")
|
||||
// Setup a new Service instance
|
||||
S = service.New(*ip, *port)
|
||||
|
||||
// We must setup TLS prior to opening a listening port.
|
||||
if *secure {
|
||||
S.SetTLSEnabled()
|
||||
|
||||
err := S.SetTLSKeyPair(*cert, *key)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = S.SetTLSMinVersion(*tlsMinVersion)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
S.SetTLSCiphers(*tlsCipherSuites)
|
||||
}
|
||||
|
||||
// logQueue is a queue for log messages encountered during startup. We need
|
||||
@ -120,10 +138,8 @@ func main() {
|
||||
// log file opening prior to writing our first log message.
|
||||
var logQueue []string
|
||||
|
||||
addr := fmt.Sprintf("%s:%d", *ip, *port)
|
||||
|
||||
// Open listener early so we can drop privileges.
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
err := S.Listen()
|
||||
if err != nil {
|
||||
logQueue = append(logQueue, fmt.Sprintf("error listening on port: %s", err))
|
||||
// we'll bail out below
|
||||
@ -166,7 +182,7 @@ func main() {
|
||||
if *pidPath != "" {
|
||||
var err error
|
||||
|
||||
pidFile, err = pidfile.New(*pidPath)
|
||||
err = S.CreatePIDFile(*pidPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Error creating pidfile: %v", err)
|
||||
}
|
||||
@ -174,7 +190,7 @@ func main() {
|
||||
defer func() {
|
||||
// NOTE(moorereason): my testing shows that this doesn't work with
|
||||
// ^C, so we also do a Remove in the signal handler elsewhere.
|
||||
if nerr := pidFile.Remove(); nerr != nil {
|
||||
if nerr := S.DeletePIDFile(); nerr != nil {
|
||||
log.Print(nerr)
|
||||
}
|
||||
}()
|
||||
@ -183,9 +199,13 @@ func main() {
|
||||
log.Println("version " + version + " starting")
|
||||
|
||||
// set os signal watcher
|
||||
setupSignals()
|
||||
setupSignals(S)
|
||||
|
||||
// load and parse hooks
|
||||
if len(hooksFiles) == 0 {
|
||||
hooksFiles = append(hooksFiles, "hooks.json")
|
||||
}
|
||||
|
||||
for _, hooksFilePath := range hooksFiles {
|
||||
log.Printf("attempting to load hooks from %s\n", hooksFilePath)
|
||||
|
||||
@ -271,30 +291,15 @@ func main() {
|
||||
r.HandleFunc(hooksURL, hookHandler)
|
||||
|
||||
// Create common HTTP server settings
|
||||
svr := &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r,
|
||||
}
|
||||
S.SetHTTPHandler(r)
|
||||
|
||||
// Serve HTTP
|
||||
if !*secure {
|
||||
log.Printf("serving hooks on http://%s%s", addr, makeHumanPattern(hooksURLPrefix))
|
||||
log.Print(svr.Serve(ln))
|
||||
|
||||
return
|
||||
log.Printf("serving hooks on http://%s%s", S.Address, makeHumanPattern(hooksURLPrefix))
|
||||
} else {
|
||||
log.Printf("serving hooks on https://%s%s", S.Address, makeHumanPattern(hooksURLPrefix))
|
||||
}
|
||||
|
||||
// Server HTTPS
|
||||
svr.TLSConfig = &tls.Config{
|
||||
CipherSuites: getTLSCipherSuites(*tlsCipherSuites),
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256},
|
||||
MinVersion: getTLSMinVersion(*tlsMinVersion),
|
||||
PreferServerCipherSuites: true,
|
||||
}
|
||||
svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
|
||||
|
||||
log.Printf("serving hooks on https://%s%s", addr, makeHumanPattern(hooksURLPrefix))
|
||||
log.Print(svr.ServeTLS(ln, *cert, *key))
|
||||
log.Print(S.Serve())
|
||||
}
|
||||
|
||||
func hookHandler(w http.ResponseWriter, r *http.Request) {
|
||||
|
Loading…
Reference in New Issue
Block a user