Skip to content

Commit

Permalink
small refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Mzack9999 committed Mar 27, 2023
1 parent c546046 commit 7d973ac
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 38 deletions.
30 changes: 19 additions & 11 deletions dsl.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ import (
"strings"
"time"

"github.com/pkg/errors"

"github.com/Knetic/govaluate"
"github.com/asaskevich/govalidator"
"github.com/hashicorp/go-version"
"github.com/kataras/jwt"
"github.com/logrusorgru/aurora"
"github.com/pkg/errors"
"github.com/spaolacci/murmur3"

"github.com/projectdiscovery/dsl/deserialization"
"github.com/projectdiscovery/dsl/randomip"
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/mapcidr"
stringsutil "github.com/projectdiscovery/utils/strings"
)

const (
Expand All @@ -57,8 +57,8 @@ var (
// DefaultHelperFunctions is a pre-compiled list of govaluate DSL functions
DefaultHelperFunctions map[string]govaluate.ExpressionFunction

functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+([.\w\d{}&*]+))?\)([\s.\w\d{}&*]+)?`)
dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
funcSignatureRegex = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+([.\w\d{}&*]+))?\)([\s.\w\d{}&*]+)?`)
dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
)

type dslFunction struct {
Expand All @@ -76,6 +76,8 @@ var defaultDateTimeLayouts = []string{
"2006-01-02",
}

var PrintDebugCallback func(args ...interface{}) error

func init() {
tempDslFunctions := map[string]func(string) dslFunction{
"len": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
Expand Down Expand Up @@ -179,7 +181,7 @@ func init() {
return strings.TrimSuffix(toString(args[0]), toString(args[1])), nil
}),
"reverse": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
return reverseString(toString(args[0])), nil
return stringsutil.Reverse(toString(args[0])), nil
}),
"base64": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
return base64.StdEncoding.EncodeToString([]byte(toString(args[0]))), nil
Expand Down Expand Up @@ -481,7 +483,7 @@ func init() {
return compiled.MatchString(toString(args[1])), nil
}),
"regex_all": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
for _, arg := range toSlice(args[1]) {
for _, arg := range toStringSlice(args[1]) {
compiled, err := Regex(toString(arg))
if err != nil {
return nil, err
Expand All @@ -494,7 +496,7 @@ func init() {
}),

"regex_any": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
for _, arg := range toSlice(args[1]) {
for _, arg := range toStringSlice(args[1]) {
compiled, err := Regex(toString(arg))
if err != nil {
return nil, err
Expand All @@ -507,7 +509,7 @@ func init() {
}),

"equals_any": makeDslFunction(2, func(args ...interface{}) (interface{}, error) {
for _, arg := range toSlice(args[1]) {
for _, arg := range toStringSlice(args[1]) {
if args[0] == arg {
return true, nil
}
Expand Down Expand Up @@ -747,7 +749,13 @@ func init() {
if len(args) < 1 {
return nil, ErrinvalidDslFunction
}
gologger.Info().Msgf("print_debug value: %s", fmt.Sprint(args))
if PrintDebugCallback != nil {
if err := PrintDebugCallback(args...); err != nil {
return nil, err
}
} else {
gologger.Info().Msgf("print_debug value: %s", fmt.Sprint(args))
}
return true, nil
},
),
Expand All @@ -760,7 +768,7 @@ func init() {
sint, err := strconv.ParseFloat(argStr, 64)
return float64(sint), err
}
return nil, errors.Errorf("%v could not be converted to int", argStr)
return nil, fmt.Errorf("%v could not be converted to int", argStr)
}),
"to_string": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
return toString(args[0]), nil
Expand Down Expand Up @@ -1087,7 +1095,7 @@ func colorizeDslFunctionSignatures() []string {
result := make([]string, 0, len(signatures))

for _, signature := range signatures {
subMatchSlices := functionSignaturePattern.FindAllStringSubmatch(signature, -1)
subMatchSlices := funcSignatureRegex.FindAllStringSubmatch(signature, -1)
if len(subMatchSlices) != 1 {
result = append(result, signature)
continue
Expand Down
33 changes: 16 additions & 17 deletions dsl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/Knetic/govaluate"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -81,11 +80,11 @@ func TestDslFunctionSignatures(t *testing.T) {
actualResult, err := helperFunctions[methodName](currentTestCase.arguments...)

if currentTestCase.err == "" {
assert.Nil(t, err)
require.Nil(t, err)
} else {
assert.Equal(t, err.Error(), currentTestCase.err)
require.Equal(t, err.Error(), currentTestCase.err)
}
assert.Equal(t, currentTestCase.expected, actualResult)
require.Equal(t, currentTestCase.expected, actualResult)
})
}
}
Expand Down Expand Up @@ -176,7 +175,7 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
`

signatures := GetPrintableDslFunctionSignatures(true)
assert.Equal(t, expected, signatures)
require.Equal(t, expected, signatures)

coloredSignatures := GetPrintableDslFunctionSignatures(false)
require.Contains(t, coloredSignatures, `[93maes_cbc(arg1, arg2, arg3 interface{}) interface{}`, "could not get colored signatures")
Expand Down Expand Up @@ -272,13 +271,13 @@ func TestDslExpressions(t *testing.T) {
`join(", ", split(hex_encode("abcdefg"), 2))`: "61, 62, 63, 64, 65, 66, 67",
`json_minify("{ \"name\": \"John Doe\", \"foo\": \"bar\" }")`: "{\"foo\":\"bar\",\"name\":\"John Doe\"}",
`json_prettify("{\"foo\":\"bar\",\"name\":\"John Doe\"}")`: "{\n \"foo\": \"bar\",\n \"name\": \"John Doe\"\n}",
`ip_format('127.0.0.1', '1')`: "127.0.0.1",
`ip_format('127.0.0.1', '3')`: "0177.0.0.01",
`ip_format('127.0.0.1', '5')`: "281472812449793",
`ip_format('127.0.1.0', '11')`: "127.0.256",
`ip_format('127.0.0.1', '1')`: "127.0.0.1",
`ip_format('127.0.0.1', '3')`: "0177.0.0.01",
`ip_format('127.0.0.1', '5')`: "281472812449793",
`ip_format('127.0.1.0', '11')`: "127.0.256",
}

testDslExpressionScenarios(t, dslExpressions)
testDslExpressions(t, dslExpressions)
}

func TestDateTimeDSLFunction(t *testing.T) {
Expand Down Expand Up @@ -331,7 +330,7 @@ func TestDateTimeDslExpressions(t *testing.T) {
`date_time("02-01-2006", 1642032000)`: time.Date(2022, 01, 13, 0, 0, 0, 0, time.UTC).Local().Format("02-01-2006"),
}

testDslExpressionScenarios(t, dslExpressions)
testDslExpressions(t, dslExpressions)
})

t.Run("to_unix_time(input string) int", func(t *testing.T) {
Expand Down Expand Up @@ -359,7 +358,7 @@ func TestDateTimeDslExpressions(t *testing.T) {
dslExpression := fmt.Sprintf(`to_unix_time("%s")`, dateTimeInput)
t.Run(dslExpression, func(t *testing.T) {
actual := evaluateExpression(t, dslExpression)
assert.Equal(t, expectedTime.Unix(), actual)
require.Equal(t, expectedTime.Unix(), actual)
})
}
})
Expand All @@ -384,7 +383,7 @@ func TestDateTimeDslExpressions(t *testing.T) {
dslExpression := fmt.Sprintf(`to_unix_time("%s", "%s")`, testScenario.inputDateTime, testScenario.layout)
t.Run(dslExpression, func(t *testing.T) {
actual := evaluateExpression(t, dslExpression)
assert.Equal(t, testScenario.expectedTime.Unix(), actual)
require.Equal(t, testScenario.expectedTime.Unix(), actual)
})
}
})
Expand Down Expand Up @@ -421,7 +420,7 @@ func TestRandDslExpressions(t *testing.T) {

stringResult := toString(actualResult)

assert.True(t, compiledTester.MatchString(stringResult), "The result '%s' of '%s' expression does not match the expected regex: '%s'", actualResult, randDslExpression, regexTester)
require.True(t, compiledTester.MatchString(stringResult), "The result '%s' of '%s' expression does not match the expected regex: '%s'", actualResult, randDslExpression, regexTester)
})
}
}
Expand All @@ -444,7 +443,7 @@ func TestRandIntDslExpressions(t *testing.T) {
actualResult := evaluateExpression(t, randIntDslExpression)

actualIntResult := actualResult.(int)
assert.True(t, tester(actualIntResult), "The '%d' result of the '%s' expression, does not match th expected validation function.", actualIntResult, randIntDslExpression)
require.True(t, tester(actualIntResult), "The '%d' result of the '%s' expression, does not match th expected validation function.", actualIntResult, randIntDslExpression)
})
}
}
Expand All @@ -463,13 +462,13 @@ func evaluateExpression(t *testing.T, dslExpression string) interface{} {
return actualResult
}

func testDslExpressionScenarios(t *testing.T, dslExpressions map[string]interface{}) {
func testDslExpressions(t *testing.T, dslExpressions map[string]interface{}) {
for dslExpression, expectedResult := range dslExpressions {
t.Run(dslExpression, func(t *testing.T) {
actualResult := evaluateExpression(t, dslExpression)

if expectedResult != nil {
assert.Equal(t, expectedResult, actualResult)
require.Equal(t, expectedResult, actualResult)
}

fmt.Printf("%s: \t %v\n", dslExpression, actualResult)
Expand Down
11 changes: 1 addition & 10 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func toString(data interface{}) string {
}
}

func toSlice(v interface{}) (m []string) {
func toStringSlice(v interface{}) (m []string) {
switch vv := v.(type) {
case []string:
for _, item := range vv {
Expand All @@ -70,15 +70,6 @@ func toSlice(v interface{}) (m []string) {
return
}

func reverseString(s string) string {
runes := []rune(s)
for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {
runes[i], runes[j] = runes[j], runes[i]
}

return string(runes)
}

func insertInto(s string, interval int, sep rune) string {
var buffer bytes.Buffer
before := interval - 1
Expand Down

0 comments on commit 7d973ac

Please sign in to comment.