Skip to content

Commit

Permalink
Transparent Memoization via func Annotation (#4742)
Browse files Browse the repository at this point in the history
* initial implementation with manual code generation

* testing generation

* refactor to package methods + auto memoize

* more memos

* fixing signatures

* refactor

* adding gen util

* adding util

* regenerate memoized files

---------

Co-authored-by: Tarun Koyalwar <tarun@projectdiscovery.io>
  • Loading branch information
Mzack9999 and tarunKoyalwar authored Mar 1, 2024
1 parent e7252a4 commit 4c7a0f4
Show file tree
Hide file tree
Showing 37 changed files with 787 additions and 18 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ jsupdate:
ts:
$(GOBUILD) $(GOFLAGS) -ldflags '$(LDFLAGS)' -o "tsgen" pkg/js/devtools/tsgen/cmd/tsgen/main.go
./tsgen -dir pkg/js/libs -out pkg/js/generated/ts
memogen:
$(GOBUILD) $(GOFLAGS) -ldflags '$(LDFLAGS)' -o "memogen" cmd/memogen/memogen.go
./memogen -src pkg/js/libs -tpl cmd/memogen/function.tpl

28 changes: 28 additions & 0 deletions cmd/memogen/function.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Warning - This is generated code
package {{.SourcePackage}}

import (
"github.com/projectdiscovery/utils/memoize"

{{range .Imports}}
{{.Name}} {{.Path}}
{{end}}
)

{{range .Functions}}
{{ .SignatureWithPrefix "memoized" }} {
hash := "{{ .Name }}" {{range .Params}} + ":" + fmt.Sprint({{.Name}}) {{end}}

v, err, _ := protocolstate.Memoizer.Do(hash, func() (interface{}, error) {
return {{.Name}}({{.ParamsNames}})
})
if err != nil {
return {{.ResultFirstFieldDefaultValue}}, err
}
if value, ok := v.({{.ResultFirstFieldType}}); ok {
return value, nil
}

return {{.ResultFirstFieldDefaultValue}}, errors.New("could not convert cached result")
}
{{end}}
77 changes: 77 additions & 0 deletions cmd/memogen/memogen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// this small cli tool is specific for those functions with arbitrary parameters and with result-error tuple as return values
// func(x,y) => result, error
// it works by creating a new memoized version of the functions in the same path as memo.original.file.go
// some parts are specific for nuclei and hardcoded within the template
package main

import (
"flag"
"io/fs"
"log"
"os"
"path/filepath"

"github.com/projectdiscovery/utils/memoize"
stringsutil "github.com/projectdiscovery/utils/strings"
)

var (
srcPath = flag.String("src", "", "nuclei source path")
tplPath = flag.String("tpl", "function.tpl", "template path")
tplSrc []byte
)

func main() {
flag.Parse()

var err error
tplSrc, err = os.ReadFile(*tplPath)
if err != nil {
log.Fatal(err)
}

err = filepath.Walk(*srcPath, walk)
if err != nil {
log.Fatal(err)
}
}

func walk(path string, info fs.FileInfo, err error) error {
if info.IsDir() {
return nil
}

if err != nil {
return err
}

ext := filepath.Ext(path)
base := filepath.Base(path)

if !stringsutil.EqualFoldAny(ext, ".go") {
return nil
}

basePath := filepath.Dir(path)
outPath := filepath.Join(basePath, "memo."+base)

// filename := filepath.Base(path)
data, err := os.ReadFile(path)
if err != nil {
return err
}
if !stringsutil.ContainsAnyI(string(data), "@memo") {
return nil
}
log.Println("processing:", path)
out, err := memoize.Src(string(tplSrc), path, data, "")
if err != nil {
return err
}

if err := os.WriteFile(outPath, out, os.ModePerm); err != nil {
return err
}

return nil
}
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ require (
github.com/projectdiscovery/tlsx v1.1.6
github.com/projectdiscovery/uncover v1.0.7
github.com/projectdiscovery/useragent v0.0.39
github.com/projectdiscovery/utils v0.0.80
github.com/projectdiscovery/utils v0.0.81
github.com/projectdiscovery/wappalyzergo v0.0.111
github.com/redis/go-redis/v9 v9.1.0
github.com/sashabaranov/go-openai v1.15.3
Expand Down Expand Up @@ -123,6 +123,7 @@ require (
github.com/bits-and-blooms/bloom/v3 v3.5.0 // indirect
github.com/bytedance/sonic v1.9.1 // indirect
github.com/cenkalti/backoff/v4 v4.2.1 // indirect
github.com/cespare/xxhash v1.1.0 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/cheggaaa/pb/v3 v3.1.4 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
Expand All @@ -132,7 +133,6 @@ require (
github.com/corpix/uarand v0.2.0 // indirect
github.com/cyphar/filepath-securejoin v0.2.4 // indirect
github.com/davidmz/go-pageant v1.0.2 // indirect
github.com/denisbrodbeck/machineid v1.0.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/docker/cli v24.0.5+incompatible // indirect
Expand Down Expand Up @@ -188,6 +188,7 @@ require (
github.com/projectdiscovery/asnmap v1.0.6 // indirect
github.com/projectdiscovery/cdncheck v1.0.9 // indirect
github.com/projectdiscovery/freeport v0.0.5 // indirect
github.com/projectdiscovery/machineid v0.0.0-20240226150047-2e2c51e35983 // indirect
github.com/projectdiscovery/stringsutil v0.0.2 // indirect
github.com/quic-go/quic-go v0.40.1 // indirect
github.com/refraction-networking/utls v1.6.1 // indirect
Expand Down
13 changes: 9 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ github.com/Mzack9999/ldapserver v1.0.2-0.20211229000134-b44a0d6ad0dd h1:RTWs+wEY
github.com/Mzack9999/ldapserver v1.0.2-0.20211229000134-b44a0d6ad0dd/go.mod h1:AqtPw7WNT0O69k+AbPKWVGYeW94TqgMW/g+Ppc8AZr4=
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw=
github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk=
github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/ProtonMail/go-crypto v0.0.0-20230217124315-7d5c6f04bbb8/go.mod h1:I0gYDMZ6Z5GRU7l58bNFSkPTFN6Yl12dsUlAZ8xy98g=
github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371 h1:kkhsdkhsCvIsutKu5zLMgWtgh9YxGCNAw8Ad8hjwfYg=
github.com/ProtonMail/go-crypto v0.0.0-20230828082145-3c4c8a2d2371/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0=
Expand Down Expand Up @@ -211,6 +213,8 @@ github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QH
github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM=
github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
Expand Down Expand Up @@ -260,8 +264,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davidmz/go-pageant v1.0.2 h1:bPblRCh5jGU+Uptpz6LgMZGD5hJoOt7otgT454WvHn0=
github.com/davidmz/go-pageant v1.0.2/go.mod h1:P2EDDnMqIwG5Rrp05dTRITj9z2zpGcD9efWSkTNKLIE=
github.com/denisbrodbeck/machineid v1.0.1 h1:geKr9qtkB876mXguW2X6TU4ZynleN6ezuMSRhl4D7AQ=
github.com/denisbrodbeck/machineid v1.0.1/go.mod h1:dJUwb7PTidGDeYyUBmXZ2GphQBbjJCrnectwCyxcUSI=
github.com/denisenkom/go-mssqldb v0.12.3 h1:pBSGx9Tq67pBOTLmxNuirNTeB8Vjmf886Kx+8Y+8shw=
github.com/denisenkom/go-mssqldb v0.12.3/go.mod h1:k0mtMFOnU+AihqFxPMiF05rtiDrorD1Vrm1KEz5hxDo=
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
Expand Down Expand Up @@ -836,6 +838,8 @@ github.com/projectdiscovery/httpx v1.3.9 h1:jDdoGH+5VVU/jI6dnai1DKNw9USPyCcw+tDh
github.com/projectdiscovery/httpx v1.3.9/go.mod h1:a/a5X6e2NLnS/+b3buFadGUpZSolnVkMA7KZdpCdg58=
github.com/projectdiscovery/interactsh v1.1.8 h1:mDD+f/oo2tV4Z1WyUync0tgYeJyuiS89Un64Gm6Pvgk=
github.com/projectdiscovery/interactsh v1.1.8/go.mod h1:E20ywFb7bL01GcOOk+6VZF48XZ8AZvYvBpULoBUSTbg=
github.com/projectdiscovery/machineid v0.0.0-20240226150047-2e2c51e35983 h1:ZScLodGSezQVwsQDtBSMFp72WDq0nNN+KE/5DHKY5QE=
github.com/projectdiscovery/machineid v0.0.0-20240226150047-2e2c51e35983/go.mod h1:3G3BRKui7nMuDFAZKR/M2hiOLtaOmyukT20g88qRQjI=
github.com/projectdiscovery/mapcidr v1.1.16 h1:rjj1w5D6hbTsUQXYClLcGdfBEy9bryclgi70t0vBggo=
github.com/projectdiscovery/mapcidr v1.1.16/go.mod h1:rGqpBhStdwOQ2uS62QM9qPsybwMwIhT7CTd2bxoHs8Q=
github.com/projectdiscovery/n3iwf v0.0.0-20230523120440-b8cd232ff1f5 h1:L/e8z8yw1pfT6bg35NiN7yd1XKtJap5Nk6lMwQ0RNi8=
Expand All @@ -862,8 +866,8 @@ github.com/projectdiscovery/uncover v1.0.7 h1:ut+2lTuvmftmveqF5RTjMWAgyLj8ltPQC7
github.com/projectdiscovery/uncover v1.0.7/go.mod h1:HFXgm1sRPuoN0D4oATljPIdmbo/EEh1wVuxQqo/dwFE=
github.com/projectdiscovery/useragent v0.0.39 h1:s2jyXdtjVo0MfYYkifx7irrOIoA0JhzhZaBkpcoWgV4=
github.com/projectdiscovery/useragent v0.0.39/go.mod h1:wO6GQImJ2IQ5K+GDggS/Rhg6IV9Z2Du6NbqC/um0g0w=
github.com/projectdiscovery/utils v0.0.80 h1:daFuQwhVRtQ14JZs3DnI9ubaX273S8V1dZ+x/vr+YbI=
github.com/projectdiscovery/utils v0.0.80/go.mod h1:WXm3MIzKhgqUtTMwxDIW5bWe5nWkCYqRlZeqin0FqTc=
github.com/projectdiscovery/utils v0.0.81 h1:Cqz6uFncCKWRLqpVHWlnHXaRE3whzH32yZJa/1zOEzU=
github.com/projectdiscovery/utils v0.0.81/go.mod h1:pTGvF08EXa07e2OM+tu8IcnxTeAT34bzAhSW/Efcens=
github.com/projectdiscovery/wappalyzergo v0.0.111 h1:A1fLEycJ1zwvIVh3jCL1n5OaKn7KP+pHAsFqQYMXPZM=
github.com/projectdiscovery/wappalyzergo v0.0.111/go.mod h1:hc/o+fgM8KtdpFesjfBTmHTwsR+yBd+4kYZW/DGy/x8=
github.com/projectdiscovery/yamldoc-go v1.0.4 h1:eZoESapnMw6WAHiVgRwNqvbJEfNHEH148uthhFbG5jE=
Expand Down Expand Up @@ -959,6 +963,7 @@ github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIK
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/sony/gobreaker v0.4.1/go.mod h1:ZKptC7FHNvhBz7dN2LGjPVBz2sZJmc0/PkyDJOjmxWY=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/cast v1.5.1 h1:R+kOtfhWQE6TVQzY+4D7wJLBgkdVasCEFxSUBYBYIlA=
Expand Down
Binary file added memogen
Binary file not shown.
43 changes: 43 additions & 0 deletions pkg/js/libs/mssql/memo.mssql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Warning - This is generated code
package mssql

import (
"errors"
"fmt"

_ "github.com/denisenkom/go-mssqldb"

"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
)

func memoizedconnect(host string, port int, username string, password string, dbName string) (bool, error) {
hash := "connect" + ":" + fmt.Sprint(host) + ":" + fmt.Sprint(port) + ":" + fmt.Sprint(username) + ":" + fmt.Sprint(password) + ":" + fmt.Sprint(dbName)

v, err, _ := protocolstate.Memoizer.Do(hash, func() (interface{}, error) {
return connect(host, port, username, password, dbName)
})
if err != nil {
return false, err
}
if value, ok := v.(bool); ok {
return value, nil
}

return false, errors.New("could not convert cached result")
}

func memoizedisMssql(host string, port int) (bool, error) {
hash := "isMssql" + ":" + fmt.Sprint(host) + ":" + fmt.Sprint(port)

v, err, _ := protocolstate.Memoizer.Do(hash, func() (interface{}, error) {
return isMssql(host, port)
})
if err != nil {
return false, err
}
if value, ok := v.(bool); ok {
return value, nil
}

return false, errors.New("could not convert cached result")
}
12 changes: 9 additions & 3 deletions pkg/js/libs/mssql/mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type (
// const connected = client.Connect('acme.com', 1433, 'username', 'password');
// ```
func (c *MSSQLClient) Connect(host string, port int, username, password string) (bool, error) {
return connect(host, port, username, password, "master")
return memoizedconnect(host, port, username, password, "master")
}

// ConnectWithDB connects to MS SQL database using given credentials and database name.
Expand All @@ -50,10 +50,11 @@ func (c *MSSQLClient) Connect(host string, port int, username, password string)
// const connected = client.ConnectWithDB('acme.com', 1433, 'username', 'password', 'master');
// ```
func (c *MSSQLClient) ConnectWithDB(host string, port int, username, password, dbName string) (bool, error) {
return connect(host, port, username, password, dbName)
return memoizedconnect(host, port, username, password, dbName)
}

func connect(host string, port int, username, password, dbName string) (bool, error) {
// @memo
func connect(host string, port int, username string, password string, dbName string) (bool, error) {
if host == "" || port <= 0 {
return false, fmt.Errorf("invalid host or port")
}
Expand Down Expand Up @@ -104,6 +105,11 @@ func connect(host string, port int, username, password, dbName string) (bool, er
// const isMssql = mssql.IsMssql('acme.com', 1433);
// ```
func (c *MSSQLClient) IsMssql(host string, port int) (bool, error) {
return memoizedisMssql(host, port)
}

// @memo
func isMssql(host string, port int) (bool, error) {
if !protocolstate.IsHostAllowed(host) {
// host is not valid according to network policy
return false, protocolstate.ErrHostDenied.Msgf(host)
Expand Down
41 changes: 41 additions & 0 deletions pkg/js/libs/mysql/memo.mysql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Warning - This is generated code
package mysql

import (
"errors"
"fmt"

"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
)

func memoizedisMySQL(host string, port int) (bool, error) {
hash := "isMySQL" + ":" + fmt.Sprint(host) + ":" + fmt.Sprint(port)

v, err, _ := protocolstate.Memoizer.Do(hash, func() (interface{}, error) {
return isMySQL(host, port)
})
if err != nil {
return false, err
}
if value, ok := v.(bool); ok {
return value, nil
}

return false, errors.New("could not convert cached result")
}

func memoizedfingerprintMySQL(host string, port int) (MySQLInfo, error) {
hash := "fingerprintMySQL" + ":" + fmt.Sprint(host) + ":" + fmt.Sprint(port)

v, err, _ := protocolstate.Memoizer.Do(hash, func() (interface{}, error) {
return fingerprintMySQL(host, port)
})
if err != nil {
return MySQLInfo{}, err
}
if value, ok := v.(MySQLInfo); ok {
return value, nil
}

return MySQLInfo{}, errors.New("could not convert cached result")
}
25 changes: 25 additions & 0 deletions pkg/js/libs/mysql/memo.mysql_private.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Warning - This is generated code
package mysql

import (
"errors"
"fmt"

"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
)

func memoizedconnectWithDSN(dsn string) (bool, error) {
hash := "connectWithDSN" + ":" + fmt.Sprint(dsn)

v, err, _ := protocolstate.Memoizer.Do(hash, func() (interface{}, error) {
return connectWithDSN(dsn)
})
if err != nil {
return false, err
}
if value, ok := v.(bool); ok {
return value, nil
}

return false, errors.New("could not convert cached result")
}
12 changes: 11 additions & 1 deletion pkg/js/libs/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ type (
// const isMySQL = mysql.IsMySQL('acme.com', 3306);
// ```
func (c *MySQLClient) IsMySQL(host string, port int) (bool, error) {
return memoizedisMySQL(host, port)
}

// @memo
func isMySQL(host string, port int) (bool, error) {
if !protocolstate.IsHostAllowed(host) {
// host is not valid according to network policy
return false, protocolstate.ErrHostDenied.Msgf(host)
Expand Down Expand Up @@ -110,6 +115,11 @@ type (
// log(to_json(info));
// ```
func (c *MySQLClient) FingerprintMySQL(host string, port int) (MySQLInfo, error) {
return memoizedfingerprintMySQL(host, port)
}

// @memo
func fingerprintMySQL(host string, port int) (MySQLInfo, error) {
info := MySQLInfo{}
if !protocolstate.IsHostAllowed(host) {
// host is not valid according to network policy
Expand Down Expand Up @@ -153,7 +163,7 @@ func (c *MySQLClient) FingerprintMySQL(host string, port int) (MySQLInfo, error)
// const connected = client.ConnectWithDSN('username:password@tcp(acme.com:3306)/');
// ```
func (c *MySQLClient) ConnectWithDSN(dsn string) (bool, error) {
return connectWithDSN(dsn)
return memoizedconnectWithDSN(dsn)
}

// ExecuteQueryWithOpts connects to Mysql database using given credentials
Expand Down
1 change: 1 addition & 0 deletions pkg/js/libs/mysql/mysql_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func BuildDSN(opts MySQLOptions) (string, error) {
return dsn.String(), nil
}

// @memo
func connectWithDSN(dsn string) (bool, error) {
db, err := sql.Open("mysql", dsn)
if err != nil {
Expand Down
Loading

0 comments on commit 4c7a0f4

Please sign in to comment.