Compare commits

...

1 Commits

Author SHA1 Message Date
Adnan Hajdarevic
d97d94537a Add IsNull and Exists types to the Match rule 2021-01-26 22:19:46 +01:00
2 changed files with 34 additions and 27 deletions

View File

@ -394,7 +394,7 @@ func GetParameter(s string, params interface{}) (interface{}, error) {
return v, nil
}
// Checked for dotted references
// Check for dotted references
p := strings.SplitN(s, ".", 2)
if pValue, ok := params.(map[string]interface{})[p[0]]; ok {
if len(p) > 1 {
@ -411,23 +411,23 @@ func GetParameter(s string, params interface{}) (interface{}, error) {
// ExtractParameterAsString extracts value from interface{} as string based on
// the passed string. Complex data types are rendered as JSON instead of the Go
// Stringer format.
func ExtractParameterAsString(s string, params interface{}) (string, error) {
func ExtractParameterAsString(s string, params interface{}) (string, interface{}, error) {
pValue, err := GetParameter(s, params)
if err != nil {
return "", err
return "", nil, err
}
switch v := reflect.ValueOf(pValue); v.Kind() {
case reflect.Array, reflect.Map, reflect.Slice:
r, err := json.Marshal(pValue)
if err != nil {
return "", err
return "", pValue, err
}
return string(r), nil
return string(r), r, nil
default:
return fmt.Sprintf("%v", pValue), nil
return fmt.Sprintf("%v", pValue), pValue, nil
}
}
@ -442,7 +442,7 @@ type Argument struct {
// Get Argument method returns the value for the Argument's key name
// based on the Argument's source
func (ha *Argument) Get(r *Request) (string, error) {
func (ha *Argument) Get(r *Request) (string, interface{}, error) {
var source *map[string]interface{}
key := ha.Name
@ -458,55 +458,55 @@ func (ha *Argument) Get(r *Request) (string, error) {
source = &r.Payload
case SourceString:
return ha.Name, nil
return ha.Name, ha.Name, nil
case SourceRawRequestBody:
return string(r.Body), nil
return string(r.Body), r.Body, nil
case SourceRequest:
if r == nil || r.RawRequest == nil {
return "", errors.New("request is nil")
return "", nil, errors.New("request is nil")
}
switch strings.ToLower(ha.Name) {
case "remote-addr":
return r.RawRequest.RemoteAddr, nil
return r.RawRequest.RemoteAddr, r.RawRequest.RemoteAddr, nil
case "method":
return r.RawRequest.Method, nil
return r.RawRequest.Method, r.RawRequest.Method, nil
default:
return "", fmt.Errorf("unsupported request key: %q", ha.Name)
return "", nil, fmt.Errorf("unsupported request key: %q", ha.Name)
}
case SourceEntirePayload:
res, err := json.Marshal(&r.Payload)
if err != nil {
return "", err
return "", r.Payload, err
}
return string(res), nil
return string(res), r.Payload, nil
case SourceEntireHeaders:
res, err := json.Marshal(&r.Headers)
if err != nil {
return "", err
return "", r.Headers, err
}
return string(res), nil
return string(res), r.Headers, nil
case SourceEntireQuery:
res, err := json.Marshal(&r.Query)
if err != nil {
return "", err
return "", r.Query, err
}
return string(res), nil
return string(res), r.Query, nil
}
if source != nil {
return ExtractParameterAsString(key, *source)
}
return "", errors.New("no source for value retrieval")
return "", nil, errors.New("no source for value retrieval")
}
// Header is a structure containing header name and it's value
@ -589,7 +589,7 @@ func (h *Hook) ParseJSONParameters(r *Request) []error {
errors := make([]error, 0)
for i := range h.JSONStringParameters {
arg, err := h.JSONStringParameters[i].Get(r)
arg, _, err := h.JSONStringParameters[i].Get(r)
if err != nil {
errors = append(errors, &ArgumentError{h.JSONStringParameters[i]})
} else {
@ -645,7 +645,7 @@ func (h *Hook) ExtractCommandArguments(r *Request) ([]string, []error) {
args = append(args, h.ExecuteCommand)
for i := range h.PassArgumentsToCommand {
arg, err := h.PassArgumentsToCommand[i].Get(r)
arg, _, err := h.PassArgumentsToCommand[i].Get(r)
if err != nil {
args = append(args, "")
errors = append(errors, &ArgumentError{h.PassArgumentsToCommand[i]})
@ -669,7 +669,7 @@ func (h *Hook) ExtractCommandArgumentsForEnv(r *Request) ([]string, []error) {
args := make([]string, 0)
errors := make([]error, 0)
for i := range h.PassEnvironmentToCommand {
arg, err := h.PassEnvironmentToCommand[i].Get(r)
arg, _, err := h.PassEnvironmentToCommand[i].Get(r)
if err != nil {
errors = append(errors, &ArgumentError{h.PassEnvironmentToCommand[i]})
continue
@ -705,7 +705,7 @@ func (h *Hook) ExtractCommandArgumentsForFile(r *Request) ([]FileParameter, []er
args := make([]FileParameter, 0)
errors := make([]error, 0)
for i := range h.PassFileToCommand {
arg, err := h.PassFileToCommand[i].Get(r)
arg, _, err := h.PassFileToCommand[i].Get(r)
if err != nil {
errors = append(errors, &ArgumentError{h.PassFileToCommand[i]})
continue
@ -898,6 +898,8 @@ type MatchRule struct {
const (
MatchValue string = "value"
MatchRegex string = "regex"
MatchIsNull string = "is-null"
MatchExists string = "exists"
MatchHMACSHA1 string = "payload-hmac-sha1"
MatchHMACSHA256 string = "payload-hmac-sha256"
MatchHMACSHA512 string = "payload-hmac-sha512"
@ -917,13 +919,17 @@ func (r MatchRule) Evaluate(req *Request) (bool, error) {
return CheckScalrSignature(req, r.Secret, true)
}
arg, err := r.Parameter.Get(req)
arg, rawValue, err := r.Parameter.Get(req)
if err == nil {
switch r.Type {
case MatchValue:
return compare(arg, r.Value), nil
case MatchRegex:
return regexp.MatchString(r.Regex, arg)
case MatchIsNull:
return rawValue == nil, nil
case MatchExists:
return true, nil
case MatchHashSHA1:
log.Print(`warn: use of deprecated option payload-hash-sha1; use payload-hmac-sha1 instead`)
fallthrough
@ -944,6 +950,7 @@ func (r MatchRule) Evaluate(req *Request) (bool, error) {
return err == nil, err
}
}
return false, err
}

View File

@ -245,7 +245,7 @@ var extractParameterTests = []struct {
func TestExtractParameter(t *testing.T) {
for _, tt := range extractParameterTests {
value, err := ExtractParameterAsString(tt.s, tt.params)
value, _, err := ExtractParameterAsString(tt.s, tt.params)
if (err == nil) != tt.ok || value != tt.value {
t.Errorf("failed to extract parameter %q:\nexpected {value:%#v, ok:%#v},\ngot {value:%#v, err:%v}", tt.s, tt.value, tt.ok, value, err)
}
@ -281,7 +281,7 @@ func TestArgumentGet(t *testing.T) {
Payload: tt.payload,
RawRequest: tt.request,
}
value, err := a.Get(r)
value, _, err := a.Get(r)
if (err == nil) != tt.ok || value != tt.value {
t.Errorf("failed to get {%q, %q}:\nexpected {value:%#v, ok:%#v},\ngot {value:%#v, err:%v}", tt.source, tt.name, tt.value, tt.ok, value, err)
}