diff --git a/hook/hook.go b/hook/hook.go index 6fcda3e..a271630 100644 --- a/hook/hook.go +++ b/hook/hook.go @@ -381,6 +381,7 @@ type Hook struct { PassEnvironmentToCommand []Argument `json:"pass-environment-to-command,omitempty"` PassArgumentsToCommand []Argument `json:"pass-arguments-to-command,omitempty"` JSONStringParameters []Argument `json:"parse-parameters-as-json,omitempty"` + MaxConcurrency int `json:"max-concurrency,omiempty"` TriggerRule *Rules `json:"trigger-rule,omitempty"` TriggerRuleMismatchHttpResponseCode int `json:"trigger-rule-mismatch-http-response-code,omitempty"` } diff --git a/webhook.go b/webhook.go index 1034290..05989da 100644 --- a/webhook.go +++ b/webhook.go @@ -43,6 +43,8 @@ var ( watcher *fsnotify.Watcher signals chan os.Signal + + limits = make(map[string]chan struct{}) ) func matchLoadedHook(id string) *hook.Hook { @@ -108,7 +110,15 @@ func main() { if matchLoadedHook(hook.ID) != nil { log.Fatalf("error: hook with the id %s has already been loaded!\nplease check your hooks file for duplicate hooks ids!\n", hook.ID) } - log.Printf("\tloaded: %s\n", hook.ID) + + msg := fmt.Sprintf("\tloaded: %s", hook.ID) + + // initialize concurrency map + if hook.MaxConcurrency > 0 { + limits[hook.ID] = make(chan struct{}, hook.MaxConcurrency) + msg = fmt.Sprintf("%s (max: %d)", msg, hook.MaxConcurrency) + } + log.Println(msg) } loadedHooksFromFiles[hooksFilePath] = newHooks @@ -208,6 +218,18 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { if matchedHook := matchLoadedHook(id); matchedHook != nil { log.Printf("%s got matched\n", id) + // check if we have concurrency limits + if _, ok := limits[id]; ok { + if len(limits[id]) == cap(limits[id]) { + log.Printf("reached concurrency limit for: %s (max=%d)", id, len(limits[id])) + w.WriteHeader(http.StatusTooManyRequests) + fmt.Fprintf(w, "Error occurred while evaluating hook rules.") + return + } + defer func() { <-limits[id] }() + limits[id] <- struct{}{} + } + body, err := ioutil.ReadAll(r.Body) if err != nil { log.Printf("error reading the request body. %+v\n", err)