Add suport for context-provider-command hook option.

The `context-provider-command` allows user to specify a command which will be run whenever the hook gets matched.
Webhook will pass the command a JSON string via the STDIN containing the request context (matched hook id, method used to trigger the hook, remote address, requested host, requested URI, raw body, headers and query values).
The output of the command must be a valid JSON string which will be mapped back into a special source named `context` that can be used with existing rules and directives.
This commit is contained in:
Adnan Hajdarevic 2019-11-22 02:40:59 +01:00
parent 34ae132930
commit 3ec7da2b15
2 changed files with 130 additions and 41 deletions

View File

@ -32,6 +32,7 @@ const (
SourceQuery string = "url" SourceQuery string = "url"
SourceQueryAlias string = "query" SourceQueryAlias string = "query"
SourcePayload string = "payload" SourcePayload string = "payload"
SourceContext string = "context"
SourceString string = "string" SourceString string = "string"
SourceEntirePayload string = "entire-payload" SourceEntirePayload string = "entire-payload"
SourceEntireQuery string = "entire-query" SourceEntireQuery string = "entire-query"
@ -323,7 +324,7 @@ type Argument struct {
// Get Argument method returns the value for the Argument's key name // Get Argument method returns the value for the Argument's key name
// based on the Argument's source // based on the Argument's source
func (ha *Argument) Get(headers, query, payload *map[string]interface{}) (string, bool) { func (ha *Argument) Get(headers, query, payload *map[string]interface{}, context *map[string]interface{}) (string, bool) {
var source *map[string]interface{} var source *map[string]interface{}
key := ha.Name key := ha.Name
@ -335,6 +336,8 @@ func (ha *Argument) Get(headers, query, payload *map[string]interface{}) (string
source = query source = query
case SourcePayload: case SourcePayload:
source = payload source = payload
case SourceContext:
source = context
case SourceString: case SourceString:
return ha.Name, true return ha.Name, true
case SourceEntirePayload: case SourceEntirePayload:
@ -424,6 +427,7 @@ func (h *HooksFiles) Set(value string) error {
type Hook struct { type Hook struct {
ID string `json:"id,omitempty"` ID string `json:"id,omitempty"`
ExecuteCommand string `json:"execute-command,omitempty"` ExecuteCommand string `json:"execute-command,omitempty"`
ContextProviderCommand string `json:"context-provider-command,omitempty"`
CommandWorkingDirectory string `json:"command-working-directory,omitempty"` CommandWorkingDirectory string `json:"command-working-directory,omitempty"`
ResponseMessage string `json:"response-message,omitempty"` ResponseMessage string `json:"response-message,omitempty"`
ResponseHeaders ResponseHeaders `json:"response-headers,omitempty"` ResponseHeaders ResponseHeaders `json:"response-headers,omitempty"`
@ -441,11 +445,11 @@ type Hook struct {
// ParseJSONParameters decodes specified arguments to JSON objects and replaces the // ParseJSONParameters decodes specified arguments to JSON objects and replaces the
// string with the newly created object // string with the newly created object
func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface{}) []error { func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface{}, context *map[string]interface{}) []error {
errors := make([]error, 0) errors := make([]error, 0)
for i := range h.JSONStringParameters { for i := range h.JSONStringParameters {
if arg, ok := h.JSONStringParameters[i].Get(headers, query, payload); ok { if arg, ok := h.JSONStringParameters[i].Get(headers, query, payload, context); ok {
var newArg map[string]interface{} var newArg map[string]interface{}
decoder := json.NewDecoder(strings.NewReader(string(arg))) decoder := json.NewDecoder(strings.NewReader(string(arg)))
@ -464,6 +468,8 @@ func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface
source = headers source = headers
case SourcePayload: case SourcePayload:
source = payload source = payload
case SourceContext:
source = context
case SourceQuery, SourceQueryAlias: case SourceQuery, SourceQueryAlias:
source = query source = query
} }
@ -493,14 +499,14 @@ func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface
// ExtractCommandArguments creates a list of arguments, based on the // ExtractCommandArguments creates a list of arguments, based on the
// PassArgumentsToCommand property that is ready to be used with exec.Command() // PassArgumentsToCommand property that is ready to be used with exec.Command()
func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]interface{}) ([]string, []error) { func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]string, []error) {
args := make([]string, 0) args := make([]string, 0)
errors := make([]error, 0) errors := make([]error, 0)
args = append(args, h.ExecuteCommand) args = append(args, h.ExecuteCommand)
for i := range h.PassArgumentsToCommand { for i := range h.PassArgumentsToCommand {
if arg, ok := h.PassArgumentsToCommand[i].Get(headers, query, payload); ok { if arg, ok := h.PassArgumentsToCommand[i].Get(headers, query, payload, context); ok {
args = append(args, arg) args = append(args, arg)
} else { } else {
args = append(args, "") args = append(args, "")
@ -518,11 +524,11 @@ func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]inter
// ExtractCommandArgumentsForEnv creates a list of arguments in key=value // ExtractCommandArgumentsForEnv creates a list of arguments in key=value
// format, based on the PassEnvironmentToCommand property that is ready to be used // format, based on the PassEnvironmentToCommand property that is ready to be used
// with exec.Command(). // with exec.Command().
func (h *Hook) ExtractCommandArgumentsForEnv(headers, query, payload *map[string]interface{}) ([]string, []error) { func (h *Hook) ExtractCommandArgumentsForEnv(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]string, []error) {
args := make([]string, 0) args := make([]string, 0)
errors := make([]error, 0) errors := make([]error, 0)
for i := range h.PassEnvironmentToCommand { for i := range h.PassEnvironmentToCommand {
if arg, ok := h.PassEnvironmentToCommand[i].Get(headers, query, payload); ok { if arg, ok := h.PassEnvironmentToCommand[i].Get(headers, query, payload, context); ok {
if h.PassEnvironmentToCommand[i].EnvName != "" { if h.PassEnvironmentToCommand[i].EnvName != "" {
// first try to use the EnvName if specified // first try to use the EnvName if specified
args = append(args, h.PassEnvironmentToCommand[i].EnvName+"="+arg) args = append(args, h.PassEnvironmentToCommand[i].EnvName+"="+arg)
@ -552,11 +558,11 @@ type FileParameter struct {
// ExtractCommandArgumentsForFile creates a list of arguments in key=value // ExtractCommandArgumentsForFile creates a list of arguments in key=value
// format, based on the PassFileToCommand property that is ready to be used // format, based on the PassFileToCommand property that is ready to be used
// with exec.Command(). // with exec.Command().
func (h *Hook) ExtractCommandArgumentsForFile(headers, query, payload *map[string]interface{}) ([]FileParameter, []error) { func (h *Hook) ExtractCommandArgumentsForFile(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]FileParameter, []error) {
args := make([]FileParameter, 0) args := make([]FileParameter, 0)
errors := make([]error, 0) errors := make([]error, 0)
for i := range h.PassFileToCommand { for i := range h.PassFileToCommand {
if arg, ok := h.PassFileToCommand[i].Get(headers, query, payload); ok { if arg, ok := h.PassFileToCommand[i].Get(headers, query, payload, context); ok {
if h.PassFileToCommand[i].EnvName == "" { if h.PassFileToCommand[i].EnvName == "" {
// if no environment-variable name is set, fall-back on the name // if no environment-variable name is set, fall-back on the name
@ -664,16 +670,16 @@ type Rules struct {
// Evaluate finds the first rule property that is not nil and returns the value // Evaluate finds the first rule property that is not nil and returns the value
// it evaluates to // it evaluates to
func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) {
switch { switch {
case r.And != nil: case r.And != nil:
return r.And.Evaluate(headers, query, payload, body, remoteAddr) return r.And.Evaluate(headers, query, payload, context, body, remoteAddr)
case r.Or != nil: case r.Or != nil:
return r.Or.Evaluate(headers, query, payload, body, remoteAddr) return r.Or.Evaluate(headers, query, payload, context, body, remoteAddr)
case r.Not != nil: case r.Not != nil:
return r.Not.Evaluate(headers, query, payload, body, remoteAddr) return r.Not.Evaluate(headers, query, payload, context, body, remoteAddr)
case r.Match != nil: case r.Match != nil:
return r.Match.Evaluate(headers, query, payload, body, remoteAddr) return r.Match.Evaluate(headers, query, payload, context, body, remoteAddr)
} }
return false, nil return false, nil
@ -683,11 +689,11 @@ func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, body *[
type AndRule []Rules type AndRule []Rules
// Evaluate AndRule will return true if and only if all of ChildRules evaluate to true // Evaluate AndRule will return true if and only if all of ChildRules evaluate to true
func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) {
res := true res := true
for _, v := range r { for _, v := range r {
rv, err := v.Evaluate(headers, query, payload, body, remoteAddr) rv, err := v.Evaluate(headers, query, payload, context, body, remoteAddr)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -705,11 +711,11 @@ func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, body
type OrRule []Rules type OrRule []Rules
// Evaluate OrRule will return true if any of ChildRules evaluate to true // Evaluate OrRule will return true if any of ChildRules evaluate to true
func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) {
res := false res := false
for _, v := range r { for _, v := range r {
rv, err := v.Evaluate(headers, query, payload, body, remoteAddr) rv, err := v.Evaluate(headers, query, payload, context, body, remoteAddr)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -727,8 +733,8 @@ func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, body *
type NotRule Rules type NotRule Rules
// Evaluate NotRule will return true if and only if ChildRule evaluates to false // Evaluate NotRule will return true if and only if ChildRule evaluates to false
func (r NotRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { func (r NotRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) {
rv, err := Rules(r).Evaluate(headers, query, payload, body, remoteAddr) rv, err := Rules(r).Evaluate(headers, query, payload, context, body, remoteAddr)
return !rv, err return !rv, err
} }
@ -753,15 +759,16 @@ const (
) )
// Evaluate MatchRule will return based on the type // Evaluate MatchRule will return based on the type
func (r MatchRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { func (r MatchRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) {
if r.Type == IPWhitelist { if r.Type == IPWhitelist {
return CheckIPWhitelist(remoteAddr, r.IPRange) return CheckIPWhitelist(remoteAddr, r.IPRange)
} }
if r.Type == ScalrSignature { if r.Type == ScalrSignature {
return CheckScalrSignature(*headers, *body, r.Secret, true) return CheckScalrSignature(*headers, *body, r.Secret, true)
} }
if arg, ok := r.Parameter.Get(headers, query, payload); ok { if arg, ok := r.Parameter.Get(headers, query, payload, context); ok {
switch r.Type { switch r.Type {
case MatchValue: case MatchValue:
return arg == r.Value, nil return arg == r.Value, nil

View File

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
@ -205,7 +206,6 @@ func main() {
} }
func hookHandler(w http.ResponseWriter, r *http.Request) { func hookHandler(w http.ResponseWriter, r *http.Request) {
// generate a request id for logging // generate a request id for logging
rid := uuid.NewV4().String()[:6] rid := uuid.NewV4().String()[:6]
@ -231,8 +231,87 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
// parse query variables // parse query variables
query := valuesToMap(r.URL.Query()) query := valuesToMap(r.URL.Query())
// parse body // parse context
var payload map[string]interface{} var context map[string]interface{}
if matchedHook.ContextProviderCommand != "" {
// check the command exists
cmdPath, err := exec.LookPath(matchedHook.ContextProviderCommand)
if err != nil {
// give a last chance, maybe it's a relative path
relativeToCwd := filepath.Join(matchedHook.CommandWorkingDirectory, matchedHook.ContextProviderCommand)
// check the command exists
cmdPath, err = exec.LookPath(relativeToCwd)
}
if err != nil {
log.Printf("[%s] unable to locate context provider command: '%s', %+v\n", rid, matchedHook.ContextProviderCommand, err)
// check if parameters specified in context-provider-command by mistake
if strings.IndexByte(matchedHook.ContextProviderCommand, ' ') != -1 {
s := strings.Fields(matchedHook.ContextProviderCommand)[0]
log.Printf("[%s] please use a wrapper script to provide arguments to context provider command for '%s'\n", rid, s)
}
} else {
contextProviderCommandStdin := struct {
HookID string `json:"hookID"`
Method string `json:"method"`
Body string `json:"body"`
RemoteAddress string `json:"remoteAddress"`
URI string `json:"URI"`
Host string `json:"host"`
Headers http.Header `json:"headers"`
Query url.Values `json:"query"`
}{
HookID: matchedHook.ID,
Method: r.Method,
Body: string(body),
RemoteAddress: r.RemoteAddr,
URI: r.RequestURI,
Host: r.Host,
Headers: r.Header,
Query: r.URL.Query(),
}
stdinJSON, err := json.Marshal(contextProviderCommandStdin)
if err != nil {
log.Printf("[%s] unable to encode context as JSON string for the context provider command: %+v\n", rid, err)
} else {
cmd := exec.Command(cmdPath)
cmd.Dir = matchedHook.CommandWorkingDirectory
cmd.Env = append(os.Environ())
stdin, err := cmd.StdinPipe()
if err != nil {
log.Printf("[%s] unable to acquire stdin pipe for the context provider command: %+v\n", rid, err)
} else {
_, err := io.WriteString(stdin, string(stdinJSON))
stdin.Close()
if err != nil {
log.Printf("[%s] unable to write to context provider command stdin: %+v\n", rid, err)
} else {
log.Printf("[%s] executing context provider command %s (%s) using %s as cwd\n", rid, matchedHook.ContextProviderCommand, cmd.Path, cmd.Dir)
out, err := cmd.CombinedOutput()
if err != nil {
log.Printf("[%s] unable to execute context provider command: %+v\n", rid, err)
} else {
log.Printf("[%s] got context provider command output: %+v\n", rid, string(out))
decoder := json.NewDecoder(strings.NewReader(string(out)))
decoder.UseNumber()
err := decoder.Decode(&context)
if err != nil {
log.Printf("[%s] unable to parse context provider command output: %+v\n", rid, err)
}
}
}
}
}
}
}
// set contentType to IncomingPayloadContentType or header value // set contentType to IncomingPayloadContentType or header value
contentType := r.Header.Get("Content-Type") contentType := r.Header.Get("Content-Type")
@ -240,6 +319,9 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
contentType = matchedHook.IncomingPayloadContentType contentType = matchedHook.IncomingPayloadContentType
} }
// parse body
var payload map[string]interface{}
if strings.Contains(contentType, "json") { if strings.Contains(contentType, "json") {
decoder := json.NewDecoder(strings.NewReader(string(body))) decoder := json.NewDecoder(strings.NewReader(string(body)))
decoder.UseNumber() decoder.UseNumber()
@ -259,7 +341,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
} }
// handle hook // handle hook
errors := matchedHook.ParseJSONParameters(&headers, &query, &payload) errors := matchedHook.ParseJSONParameters(&headers, &query, &payload, &context)
for _, err := range errors { for _, err := range errors {
log.Printf("[%s] error parsing JSON parameters: %s\n", rid, err) log.Printf("[%s] error parsing JSON parameters: %s\n", rid, err)
} }
@ -269,7 +351,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
if matchedHook.TriggerRule == nil { if matchedHook.TriggerRule == nil {
ok = true ok = true
} else { } else {
ok, err = matchedHook.TriggerRule.Evaluate(&headers, &query, &payload, &body, r.RemoteAddr) ok, err = matchedHook.TriggerRule.Evaluate(&headers, &query, &payload, &context, &body, r.RemoteAddr)
if err != nil { if err != nil {
msg := fmt.Sprintf("[%s] error evaluating hook: %s", rid, err) msg := fmt.Sprintf("[%s] error evaluating hook: %s", rid, err)
log.Print(msg) log.Print(msg)
@ -287,7 +369,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
} }
if matchedHook.CaptureCommandOutput { if matchedHook.CaptureCommandOutput {
response, err := handleHook(matchedHook, rid, &headers, &query, &payload, &body) response, err := handleHook(matchedHook, rid, &headers, &query, &payload, &context, &body)
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
@ -305,7 +387,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, response) fmt.Fprint(w, response)
} }
} else { } else {
go handleHook(matchedHook, rid, &headers, &query, &payload, &body) go handleHook(matchedHook, rid, &headers, &query, &payload, &context, &body)
// Check if a success return code is configured for the hook // Check if a success return code is configured for the hook
if matchedHook.SuccessHttpResponseCode != 0 { if matchedHook.SuccessHttpResponseCode != 0 {
@ -332,7 +414,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]interface{}, body *[]byte) (string, error) { func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte) (string, error) {
var errors []error var errors []error
// check the command exists // check the command exists
@ -345,12 +427,12 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in
} }
if err != nil { if err != nil {
log.Printf("unable to locate command: '%s'", h.ExecuteCommand) log.Printf("[%s] unable to locate command: '%s'\n", rid, h.ExecuteCommand)
// check if parameters specified in execute-command by mistake // check if parameters specified in execute-command by mistake
if strings.IndexByte(h.ExecuteCommand, ' ') != -1 { if strings.IndexByte(h.ExecuteCommand, ' ') != -1 {
s := strings.Fields(h.ExecuteCommand)[0] s := strings.Fields(h.ExecuteCommand)[0]
log.Printf("use 'pass-arguments-to-command' to specify args for '%s'", s) log.Printf("[%s] please use 'pass-arguments-to-command' to specify args for '%s'\n", rid, s)
} }
return "", err return "", err
@ -359,19 +441,19 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in
cmd := exec.Command(cmdPath) cmd := exec.Command(cmdPath)
cmd.Dir = h.CommandWorkingDirectory cmd.Dir = h.CommandWorkingDirectory
cmd.Args, errors = h.ExtractCommandArguments(headers, query, payload) cmd.Args, errors = h.ExtractCommandArguments(headers, query, payload, context)
for _, err := range errors { for _, err := range errors {
log.Printf("[%s] error extracting command arguments: %s\n", rid, err) log.Printf("[%s] error extracting command arguments: %s\n", rid, err)
} }
var envs []string var envs []string
envs, errors = h.ExtractCommandArgumentsForEnv(headers, query, payload) envs, errors = h.ExtractCommandArgumentsForEnv(headers, query, payload, context)
for _, err := range errors { for _, err := range errors {
log.Printf("[%s] error extracting command arguments for environment: %s\n", rid, err) log.Printf("[%s] error extracting command arguments for environment: %s\n", rid, err)
} }
files, errors := h.ExtractCommandArgumentsForFile(headers, query, payload) files, errors := h.ExtractCommandArgumentsForFile(headers, query, payload, context)
for _, err := range errors { for _, err := range errors {
log.Printf("[%s] error extracting command arguments for file: %s\n", rid, err) log.Printf("[%s] error extracting command arguments for file: %s\n", rid, err)
@ -380,16 +462,16 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in
for i := range files { for i := range files {
tmpfile, err := ioutil.TempFile(h.CommandWorkingDirectory, files[i].EnvName) tmpfile, err := ioutil.TempFile(h.CommandWorkingDirectory, files[i].EnvName)
if err != nil { if err != nil {
log.Printf("[%s] error creating temp file [%s]", rid, err) log.Printf("[%s] error creating temp file [%s]\n", rid, err)
continue continue
} }
log.Printf("[%s] writing env %s file %s", rid, files[i].EnvName, tmpfile.Name()) log.Printf("[%s] writing env %s file %s", rid, files[i].EnvName, tmpfile.Name())
if _, err := tmpfile.Write(files[i].Data); err != nil { if _, err := tmpfile.Write(files[i].Data); err != nil {
log.Printf("[%s] error writing file %s [%s]", rid, tmpfile.Name(), err) log.Printf("[%s] error writing file %s [%s]\n", rid, tmpfile.Name(), err)
continue continue
} }
if err := tmpfile.Close(); err != nil { if err := tmpfile.Close(); err != nil {
log.Printf("[%s] error closing file %s [%s]", rid, tmpfile.Name(), err) log.Printf("[%s] error closing file %s [%s]\n", rid, tmpfile.Name(), err)
continue continue
} }
@ -414,7 +496,7 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in
log.Printf("[%s] removing file %s\n", rid, files[i].File.Name()) log.Printf("[%s] removing file %s\n", rid, files[i].File.Name())
err := os.Remove(files[i].File.Name()) err := os.Remove(files[i].File.Name())
if err != nil { if err != nil {
log.Printf("[%s] error removing file %s [%s]", rid, files[i].File.Name(), err) log.Printf("[%s] error removing file %s [%s]\n", rid, files[i].File.Name(), err)
} }
} }
} }