From 3ec7da2b1576407722dac70ce8efd491031f34e9 Mon Sep 17 00:00:00 2001 From: Adnan Hajdarevic Date: Fri, 22 Nov 2019 02:40:59 +0100 Subject: [PATCH] Add suport for `context-provider-command` hook option. The `context-provider-command` allows user to specify a command which will be run whenever the hook gets matched. Webhook will pass the command a JSON string via the STDIN containing the request context (matched hook id, method used to trigger the hook, remote address, requested host, requested URI, raw body, headers and query values). The output of the command must be a valid JSON string which will be mapped back into a special source named `context` that can be used with existing rules and directives. --- hook/hook.go | 51 ++++++++++++---------- webhook.go | 120 +++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 130 insertions(+), 41 deletions(-) diff --git a/hook/hook.go b/hook/hook.go index 98fb975..935afc5 100644 --- a/hook/hook.go +++ b/hook/hook.go @@ -32,6 +32,7 @@ const ( SourceQuery string = "url" SourceQueryAlias string = "query" SourcePayload string = "payload" + SourceContext string = "context" SourceString string = "string" SourceEntirePayload string = "entire-payload" SourceEntireQuery string = "entire-query" @@ -323,7 +324,7 @@ type Argument struct { // Get Argument method returns the value for the Argument's key name // based on the Argument's source -func (ha *Argument) Get(headers, query, payload *map[string]interface{}) (string, bool) { +func (ha *Argument) Get(headers, query, payload *map[string]interface{}, context *map[string]interface{}) (string, bool) { var source *map[string]interface{} key := ha.Name @@ -335,6 +336,8 @@ func (ha *Argument) Get(headers, query, payload *map[string]interface{}) (string source = query case SourcePayload: source = payload + case SourceContext: + source = context case SourceString: return ha.Name, true case SourceEntirePayload: @@ -424,6 +427,7 @@ func (h *HooksFiles) Set(value string) error { type Hook struct { ID string `json:"id,omitempty"` ExecuteCommand string `json:"execute-command,omitempty"` + ContextProviderCommand string `json:"context-provider-command,omitempty"` CommandWorkingDirectory string `json:"command-working-directory,omitempty"` ResponseMessage string `json:"response-message,omitempty"` ResponseHeaders ResponseHeaders `json:"response-headers,omitempty"` @@ -441,11 +445,11 @@ type Hook struct { // ParseJSONParameters decodes specified arguments to JSON objects and replaces the // string with the newly created object -func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface{}) []error { +func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface{}, context *map[string]interface{}) []error { errors := make([]error, 0) for i := range h.JSONStringParameters { - if arg, ok := h.JSONStringParameters[i].Get(headers, query, payload); ok { + if arg, ok := h.JSONStringParameters[i].Get(headers, query, payload, context); ok { var newArg map[string]interface{} decoder := json.NewDecoder(strings.NewReader(string(arg))) @@ -464,6 +468,8 @@ func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface source = headers case SourcePayload: source = payload + case SourceContext: + source = context case SourceQuery, SourceQueryAlias: source = query } @@ -493,14 +499,14 @@ func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface // ExtractCommandArguments creates a list of arguments, based on the // PassArgumentsToCommand property that is ready to be used with exec.Command() -func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]interface{}) ([]string, []error) { +func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]string, []error) { args := make([]string, 0) errors := make([]error, 0) args = append(args, h.ExecuteCommand) for i := range h.PassArgumentsToCommand { - if arg, ok := h.PassArgumentsToCommand[i].Get(headers, query, payload); ok { + if arg, ok := h.PassArgumentsToCommand[i].Get(headers, query, payload, context); ok { args = append(args, arg) } else { args = append(args, "") @@ -518,11 +524,11 @@ func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]inter // ExtractCommandArgumentsForEnv creates a list of arguments in key=value // format, based on the PassEnvironmentToCommand property that is ready to be used // with exec.Command(). -func (h *Hook) ExtractCommandArgumentsForEnv(headers, query, payload *map[string]interface{}) ([]string, []error) { +func (h *Hook) ExtractCommandArgumentsForEnv(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]string, []error) { args := make([]string, 0) errors := make([]error, 0) for i := range h.PassEnvironmentToCommand { - if arg, ok := h.PassEnvironmentToCommand[i].Get(headers, query, payload); ok { + if arg, ok := h.PassEnvironmentToCommand[i].Get(headers, query, payload, context); ok { if h.PassEnvironmentToCommand[i].EnvName != "" { // first try to use the EnvName if specified args = append(args, h.PassEnvironmentToCommand[i].EnvName+"="+arg) @@ -552,11 +558,11 @@ type FileParameter struct { // ExtractCommandArgumentsForFile creates a list of arguments in key=value // format, based on the PassFileToCommand property that is ready to be used // with exec.Command(). -func (h *Hook) ExtractCommandArgumentsForFile(headers, query, payload *map[string]interface{}) ([]FileParameter, []error) { +func (h *Hook) ExtractCommandArgumentsForFile(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]FileParameter, []error) { args := make([]FileParameter, 0) errors := make([]error, 0) for i := range h.PassFileToCommand { - if arg, ok := h.PassFileToCommand[i].Get(headers, query, payload); ok { + if arg, ok := h.PassFileToCommand[i].Get(headers, query, payload, context); ok { if h.PassFileToCommand[i].EnvName == "" { // if no environment-variable name is set, fall-back on the name @@ -664,16 +670,16 @@ type Rules struct { // Evaluate finds the first rule property that is not nil and returns the value // it evaluates to -func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { switch { case r.And != nil: - return r.And.Evaluate(headers, query, payload, body, remoteAddr) + return r.And.Evaluate(headers, query, payload, context, body, remoteAddr) case r.Or != nil: - return r.Or.Evaluate(headers, query, payload, body, remoteAddr) + return r.Or.Evaluate(headers, query, payload, context, body, remoteAddr) case r.Not != nil: - return r.Not.Evaluate(headers, query, payload, body, remoteAddr) + return r.Not.Evaluate(headers, query, payload, context, body, remoteAddr) case r.Match != nil: - return r.Match.Evaluate(headers, query, payload, body, remoteAddr) + return r.Match.Evaluate(headers, query, payload, context, body, remoteAddr) } return false, nil @@ -683,11 +689,11 @@ func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, body *[ type AndRule []Rules // Evaluate AndRule will return true if and only if all of ChildRules evaluate to true -func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { res := true for _, v := range r { - rv, err := v.Evaluate(headers, query, payload, body, remoteAddr) + rv, err := v.Evaluate(headers, query, payload, context, body, remoteAddr) if err != nil { return false, err } @@ -705,11 +711,11 @@ func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, body type OrRule []Rules // Evaluate OrRule will return true if any of ChildRules evaluate to true -func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { res := false for _, v := range r { - rv, err := v.Evaluate(headers, query, payload, body, remoteAddr) + rv, err := v.Evaluate(headers, query, payload, context, body, remoteAddr) if err != nil { return false, err } @@ -727,8 +733,8 @@ func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, body * type NotRule Rules // Evaluate NotRule will return true if and only if ChildRule evaluates to false -func (r NotRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { - rv, err := Rules(r).Evaluate(headers, query, payload, body, remoteAddr) +func (r NotRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { + rv, err := Rules(r).Evaluate(headers, query, payload, context, body, remoteAddr) return !rv, err } @@ -753,15 +759,16 @@ const ( ) // Evaluate MatchRule will return based on the type -func (r MatchRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r MatchRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { if r.Type == IPWhitelist { return CheckIPWhitelist(remoteAddr, r.IPRange) } + if r.Type == ScalrSignature { return CheckScalrSignature(*headers, *body, r.Secret, true) } - if arg, ok := r.Parameter.Get(headers, query, payload); ok { + if arg, ok := r.Parameter.Get(headers, query, payload, context); ok { switch r.Type { case MatchValue: return arg == r.Value, nil diff --git a/webhook.go b/webhook.go index 16cbcbe..f5a9cba 100644 --- a/webhook.go +++ b/webhook.go @@ -4,6 +4,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "io/ioutil" "log" "net/http" @@ -205,7 +206,6 @@ func main() { } func hookHandler(w http.ResponseWriter, r *http.Request) { - // generate a request id for logging rid := uuid.NewV4().String()[:6] @@ -231,8 +231,87 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { // parse query variables query := valuesToMap(r.URL.Query()) - // parse body - var payload map[string]interface{} + // parse context + var context map[string]interface{} + + if matchedHook.ContextProviderCommand != "" { + // check the command exists + cmdPath, err := exec.LookPath(matchedHook.ContextProviderCommand) + if err != nil { + // give a last chance, maybe it's a relative path + relativeToCwd := filepath.Join(matchedHook.CommandWorkingDirectory, matchedHook.ContextProviderCommand) + // check the command exists + cmdPath, err = exec.LookPath(relativeToCwd) + } + + if err != nil { + log.Printf("[%s] unable to locate context provider command: '%s', %+v\n", rid, matchedHook.ContextProviderCommand, err) + // check if parameters specified in context-provider-command by mistake + if strings.IndexByte(matchedHook.ContextProviderCommand, ' ') != -1 { + s := strings.Fields(matchedHook.ContextProviderCommand)[0] + log.Printf("[%s] please use a wrapper script to provide arguments to context provider command for '%s'\n", rid, s) + } + } else { + contextProviderCommandStdin := struct { + HookID string `json:"hookID"` + Method string `json:"method"` + Body string `json:"body"` + RemoteAddress string `json:"remoteAddress"` + URI string `json:"URI"` + Host string `json:"host"` + Headers http.Header `json:"headers"` + Query url.Values `json:"query"` + }{ + HookID: matchedHook.ID, + Method: r.Method, + Body: string(body), + RemoteAddress: r.RemoteAddr, + URI: r.RequestURI, + Host: r.Host, + Headers: r.Header, + Query: r.URL.Query(), + } + + stdinJSON, err := json.Marshal(contextProviderCommandStdin) + + if err != nil { + log.Printf("[%s] unable to encode context as JSON string for the context provider command: %+v\n", rid, err) + } else { + cmd := exec.Command(cmdPath) + cmd.Dir = matchedHook.CommandWorkingDirectory + cmd.Env = append(os.Environ()) + stdin, err := cmd.StdinPipe() + + if err != nil { + log.Printf("[%s] unable to acquire stdin pipe for the context provider command: %+v\n", rid, err) + } else { + _, err := io.WriteString(stdin, string(stdinJSON)) + stdin.Close() + if err != nil { + log.Printf("[%s] unable to write to context provider command stdin: %+v\n", rid, err) + } else { + log.Printf("[%s] executing context provider command %s (%s) using %s as cwd\n", rid, matchedHook.ContextProviderCommand, cmd.Path, cmd.Dir) + out, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("[%s] unable to execute context provider command: %+v\n", rid, err) + } else { + log.Printf("[%s] got context provider command output: %+v\n", rid, string(out)) + + decoder := json.NewDecoder(strings.NewReader(string(out))) + decoder.UseNumber() + + err := decoder.Decode(&context) + + if err != nil { + log.Printf("[%s] unable to parse context provider command output: %+v\n", rid, err) + } + } + } + } + } + } + } // set contentType to IncomingPayloadContentType or header value contentType := r.Header.Get("Content-Type") @@ -240,6 +319,9 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { contentType = matchedHook.IncomingPayloadContentType } + // parse body + var payload map[string]interface{} + if strings.Contains(contentType, "json") { decoder := json.NewDecoder(strings.NewReader(string(body))) decoder.UseNumber() @@ -259,7 +341,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } // handle hook - errors := matchedHook.ParseJSONParameters(&headers, &query, &payload) + errors := matchedHook.ParseJSONParameters(&headers, &query, &payload, &context) for _, err := range errors { log.Printf("[%s] error parsing JSON parameters: %s\n", rid, err) } @@ -269,7 +351,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { if matchedHook.TriggerRule == nil { ok = true } else { - ok, err = matchedHook.TriggerRule.Evaluate(&headers, &query, &payload, &body, r.RemoteAddr) + ok, err = matchedHook.TriggerRule.Evaluate(&headers, &query, &payload, &context, &body, r.RemoteAddr) if err != nil { msg := fmt.Sprintf("[%s] error evaluating hook: %s", rid, err) log.Print(msg) @@ -287,7 +369,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } if matchedHook.CaptureCommandOutput { - response, err := handleHook(matchedHook, rid, &headers, &query, &payload, &body) + response, err := handleHook(matchedHook, rid, &headers, &query, &payload, &context, &body) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -305,7 +387,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, response) } } else { - go handleHook(matchedHook, rid, &headers, &query, &payload, &body) + go handleHook(matchedHook, rid, &headers, &query, &payload, &context, &body) // Check if a success return code is configured for the hook if matchedHook.SuccessHttpResponseCode != 0 { @@ -332,25 +414,25 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } } -func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]interface{}, body *[]byte) (string, error) { +func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte) (string, error) { var errors []error // check the command exists cmdPath, err := exec.LookPath(h.ExecuteCommand) if err != nil { - // give a last chance, maybe is a relative path - relativeToCwd := filepath.Join(h.CommandWorkingDirectory, h.ExecuteCommand) + // give a last chance, maybe is a relative path + relativeToCwd := filepath.Join(h.CommandWorkingDirectory, h.ExecuteCommand) // check the command exists cmdPath, err = exec.LookPath(relativeToCwd) } if err != nil { - log.Printf("unable to locate command: '%s'", h.ExecuteCommand) + log.Printf("[%s] unable to locate command: '%s'\n", rid, h.ExecuteCommand) // check if parameters specified in execute-command by mistake if strings.IndexByte(h.ExecuteCommand, ' ') != -1 { s := strings.Fields(h.ExecuteCommand)[0] - log.Printf("use 'pass-arguments-to-command' to specify args for '%s'", s) + log.Printf("[%s] please use 'pass-arguments-to-command' to specify args for '%s'\n", rid, s) } return "", err @@ -359,19 +441,19 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in cmd := exec.Command(cmdPath) cmd.Dir = h.CommandWorkingDirectory - cmd.Args, errors = h.ExtractCommandArguments(headers, query, payload) + cmd.Args, errors = h.ExtractCommandArguments(headers, query, payload, context) for _, err := range errors { log.Printf("[%s] error extracting command arguments: %s\n", rid, err) } var envs []string - envs, errors = h.ExtractCommandArgumentsForEnv(headers, query, payload) + envs, errors = h.ExtractCommandArgumentsForEnv(headers, query, payload, context) for _, err := range errors { log.Printf("[%s] error extracting command arguments for environment: %s\n", rid, err) } - files, errors := h.ExtractCommandArgumentsForFile(headers, query, payload) + files, errors := h.ExtractCommandArgumentsForFile(headers, query, payload, context) for _, err := range errors { log.Printf("[%s] error extracting command arguments for file: %s\n", rid, err) @@ -380,16 +462,16 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in for i := range files { tmpfile, err := ioutil.TempFile(h.CommandWorkingDirectory, files[i].EnvName) if err != nil { - log.Printf("[%s] error creating temp file [%s]", rid, err) + log.Printf("[%s] error creating temp file [%s]\n", rid, err) continue } log.Printf("[%s] writing env %s file %s", rid, files[i].EnvName, tmpfile.Name()) if _, err := tmpfile.Write(files[i].Data); err != nil { - log.Printf("[%s] error writing file %s [%s]", rid, tmpfile.Name(), err) + log.Printf("[%s] error writing file %s [%s]\n", rid, tmpfile.Name(), err) continue } if err := tmpfile.Close(); err != nil { - log.Printf("[%s] error closing file %s [%s]", rid, tmpfile.Name(), err) + log.Printf("[%s] error closing file %s [%s]\n", rid, tmpfile.Name(), err) continue } @@ -414,7 +496,7 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in log.Printf("[%s] removing file %s\n", rid, files[i].File.Name()) err := os.Remove(files[i].File.Name()) if err != nil { - log.Printf("[%s] error removing file %s [%s]", rid, files[i].File.Name(), err) + log.Printf("[%s] error removing file %s [%s]\n", rid, files[i].File.Name(), err) } } }