Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transparent Memoization via func Annotation #4742

Merged
merged 12 commits into from
Mar 1, 2024
Next Next commit
initial implementation with manual code generation
Mzack9999 committed Feb 7, 2024
commit a8c467a7fa712ce6af9d20a7cc7660b7e9ab0aaa
26 changes: 26 additions & 0 deletions cmd/memoize/gotest/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package gotest

import (
"errors"
"time"
)

// Test is a generator fixture with two string parameters and one string
// result. Both arguments are ignored; the result is always "something".
//
// @memo
func Test(a string, b string) string {
	const fixed = "something"
	return fixed
}

// TestNothing is a generator fixture without parameters or results; it only
// blocks the caller for one second.
//
// @memo
func TestNothing() {
	pause := time.Second
	time.Sleep(pause)
}

// TestWithOneReturn is a generator fixture without parameters that yields the
// constant string "a".
//
// @memo
func TestWithOneReturn() string {
	const value = "a"
	return value
}

// TestWithMultipleReturnValues is a generator fixture without parameters that
// yields "a", 2 and a non-nil error whose message is "test".
//
// @memo
func TestWithMultipleReturnValues() (string, int, error) {
	failure := errors.New("test")
	return "a", 2, failure
}
236 changes: 236 additions & 0 deletions cmd/memoize/memoize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/printer"
"go/token"
"io/fs"
"log"
"os"
"path/filepath"
"strings"

fileutil "github.com/projectdiscovery/utils/file"
"golang.org/x/tools/imports"
)

// Command-line flags of the memoize generator.
var (
	// srcFolder is the directory walked for .go files to process.
	srcFolder = flag.String("src", "", "source folder")
	// dstFolder is the directory the generated files are written to.
	dstFolder = flag.String("dst", "", "destination folder")
	// packageName is the package clause used in the generated files.
	packageName = flag.String("pkg", "memo", "destination package")
)

func main() {
flag.Parse()

_ = fileutil.CreateFolder(*dstFolder)

err := filepath.WalkDir(*srcFolder, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
if d.IsDir() {
return nil
}
if ext := filepath.Ext(path); strings.ToLower(ext) != ".go" {
return nil
}

return process(path)
})
if err != nil {
log.Fatal(err)
}
}

// process generates the memoized counterpart of a single Go source file.
//
// The file at path is parsed and every exported top-level function whose doc
// comment contains the exact line "// @memo" gets a wrapper with the same
// name and signature that forwards to <source package>.<name>:
//
//   - functions without parameters run once, guarded by a sync.Once, and keep
//     their results in package-level variables;
//   - functions with parameters are keyed by hash(name, params...) in the
//     shared cache and keep their results in a per-function struct.
//
// The wrappers, followed by the hashFunc and memoizeCache templates, are
// passed through imports.Process and format.Source and written to *dstFolder
// under the source file's base name, in package *packageName. A file without
// annotated functions produces no output.
//
// It returns the first parse, generation, import-resolution, formatting or
// write error.
func process(path string) error {
	// The cache template needs the generic gcache API. Importing it
	// explicitly keeps imports.Process from resolving the bare package name
	// to a different, non-generic module.
	const cacheImport = "github.com/Mzack9999/gcache"

	fset := token.NewFileSet()
	node, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
	if err != nil {
		return err
	}

	dstFile := filepath.Join(*dstFolder, filepath.Base(path))
	sourcePackage := node.Name.Name

	var content bytes.Buffer
	fmt.Fprintf(&content, "package %s\n\nimport %q\n\n", *packageName, cacheImport)

	var (
		generated int
		genErr    error
	)
	ast.Inspect(node, func(n ast.Node) bool {
		fn, ok := n.(*ast.FuncDecl)
		if !ok {
			return true
		}
		// Function bodies are never descended into.
		if genErr != nil || !fn.Name.IsExported() || !hasMemoAnnotation(fn) {
			return false
		}
		wrapper, err := memoWrapper(fset, sourcePackage, fn)
		if err != nil {
			genErr = fmt.Errorf("generating wrapper for %s: %w", fn.Name.Name, err)
			return false
		}
		content.WriteString(wrapper)
		content.WriteString("\n")
		generated++
		return false
	})
	if genErr != nil {
		return genErr
	}
	if generated == 0 {
		// Nothing to memoize: do not emit a file holding only the helpers.
		return nil
	}

	// inject std func
	content.WriteString("\n" + hashFunc)
	content.WriteString("\n" + memoizeCache)

	log.Println(content.String())

	out, err := imports.Process(dstFile, content.Bytes(), nil)
	if err != nil {
		return fmt.Errorf("resolving imports for %s: %w", dstFile, err)
	}

	out, err = format.Source(out)
	if err != nil {
		return fmt.Errorf("formatting %s: %w", dstFile, err)
	}

	// Generated sources are plain files; they must not be executable.
	return os.WriteFile(dstFile, out, 0o644)
}

// hasMemoAnnotation reports whether fn's doc comment contains the "// @memo"
// line. Functions without any doc comment are not annotated.
func hasMemoAnnotation(fn *ast.FuncDecl) bool {
	if fn.Doc == nil {
		return false
	}
	for _, comment := range fn.Doc.List {
		if comment.Text == "// @memo" {
			return true
		}
	}
	return false
}

// nodeString renders an AST node (an expression or a function type) back to
// Go source text.
func nodeString(fset *token.FileSet, node any) (string, error) {
	var sb strings.Builder
	if err := printer.Fprint(&sb, fset, node); err != nil {
		return "", err
	}
	return sb.String(), nil
}

// memoWrapper returns the unformatted source of the memoizing wrapper for fn,
// including the package-level declarations the wrapper needs.
func memoWrapper(fset *token.FileSet, sourcePackage string, fn *ast.FuncDecl) (string, error) {
	name := fn.Name.Name

	// Arguments for the cache key and for the forwarded call. Variadic
	// parameters are hashed as a slice but must be expanded when forwarded.
	hasParams := fn.Type.Params != nil && len(fn.Type.Params.List) > 0
	hashArgs := []string{fmt.Sprintf("%q", name)}
	var callArgs []string
	if hasParams {
		for _, field := range fn.Type.Params.List {
			_, variadic := field.Type.(*ast.Ellipsis)
			for _, id := range field.Names {
				hashArgs = append(hashArgs, id.Name)
				if variadic {
					callArgs = append(callArgs, id.Name+"...")
				} else {
					callArgs = append(callArgs, id.Name)
				}
			}
		}
	}

	// One storage slot per result value; a field such as "(a, b string)"
	// declares two values of the same type.
	var retNames, retTypes []string
	if fn.Type.Results != nil {
		for _, field := range fn.Type.Results.List {
			typ, err := nodeString(fset, field.Type)
			if err != nil {
				return "", err
			}
			count := len(field.Names)
			if count == 0 {
				count = 1
			}
			for i := 0; i < count; i++ {
				retNames = append(retNames, fmt.Sprintf("ret%d%s", len(retNames), name))
				retTypes = append(retTypes, typ)
			}
		}
	}
	hasReturnType := len(retNames) > 0

	signature, err := nodeString(fset, fn.Type)
	if err != nil {
		return "", err
	}
	// The printed type reads "func(...) ..."; insert the name after "func".
	header := strings.Replace(signature, "func", "func "+name, 1) + " {\n"
	call := fmt.Sprintf("%s.%s(%s)", sourcePackage, name, strings.Join(callArgs, ", "))

	var out strings.Builder

	if !hasParams {
		onceName := "once" + name
		results := strings.Join(retNames, ", ")

		fmt.Fprintf(&out, "var (\n%s sync.Once\n", onceName)
		for i := range retNames {
			fmt.Fprintf(&out, "%s %s\n", retNames[i], retTypes[i])
		}
		out.WriteString(")\n\n")

		out.WriteString(header)
		fmt.Fprintf(&out, "%s.Do(func() {\n", onceName)
		if hasReturnType {
			out.WriteString(results + " = ")
		}
		out.WriteString(call + "\n})\n")
		if hasReturnType {
			out.WriteString("return " + results + "\n")
		}
		out.WriteString("}\n")
		return out.String(), nil
	}

	hashStmt := fmt.Sprintf("h := hash(%s)\n", strings.Join(hashArgs, ", "))

	if !hasReturnType {
		// No results to store: the cache entry only records that the call
		// already happened for these arguments.
		out.WriteString(header)
		out.WriteString(hashStmt)
		out.WriteString("if _, err := cache.GetIFPresent(h); err == nil {\nreturn\n}\n")
		out.WriteString(call + "\n")
		out.WriteString("_ = cache.Set(h, struct{}{})\n")
		out.WriteString("}\n")
		return out.String(), nil
	}

	structName := "resStruct" + name
	instance := "ret" + structName
	fields := make([]string, 0, len(retNames))

	fmt.Fprintf(&out, "type %s struct {\n", structName)
	for i := range retNames {
		fmt.Fprintf(&out, "%s %s\n", retNames[i], retTypes[i])
		fields = append(fields, instance+"."+retNames[i])
	}
	out.WriteString("}\n\n")
	allFields := strings.Join(fields, ", ")

	out.WriteString(header)
	out.WriteString(hashStmt)
	// A hit returns the stored results right away instead of recomputing.
	fmt.Fprintf(&out, "if v, err := cache.GetIFPresent(h); err == nil {\nif %s, ok := v.(*%s); ok {\nreturn %s\n}\n}\n",
		instance, structName, allFields)
	// Allocate the holder before assigning into its fields.
	fmt.Fprintf(&out, "%s := &%s{}\n", instance, structName)
	fmt.Fprintf(&out, "%s = %s\n", allFields, call)
	fmt.Fprintf(&out, "_ = cache.Set(h, %s)\n", instance)
	fmt.Fprintf(&out, "return %s\n}\n", allFields)
	return out.String(), nil
}

// hashFunc is the source of the hash helper appended to every generated file.
// The helper derives a hex-encoded SHA-256 cache key from the function name
// and its arguments. Each argument is written with its type and the length of
// its rendered form, so that different argument lists such as ("ab", "c") and
// ("a", "bc") cannot collapse into the same key.
var hashFunc = `
func hash(functionName string, args ...any) string {
	var b bytes.Buffer
	b.WriteString(functionName + ":")
	for _, arg := range args {
		s := fmt.Sprint(arg)
		fmt.Fprintf(&b, "%T:%d:%s", arg, len(s), s)
	}
	h := sha256.Sum256(b.Bytes())
	return hex.EncodeToString(h[:])
}
`

// memoizeCache is the source appended after hashFunc to every generated file.
// It declares the package-level cache variable used by the generated wrappers
// and an init function that assigns it from gcache.New(1000).Build().
// NOTE(review): the template uses type arguments on gcache, so the generated
// file must import a generic gcache package — confirm which one gets resolved.
var memoizeCache = `
var cache gcache.Cache[string, interface{}]

func init() {
cache = gcache.New[string, interface{}](1000).Build()
}`
51 changes: 51 additions & 0 deletions cmd/memoize/memoize/store.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package memoize

import (
"errors"

"github.com/Mzack9999/gcache"
"golang.org/x/sync/singleflight"
)

// Memoizer stores function results by key and routes computations through a
// singleflight group.
type Memoizer struct {
// cache holds the stored results; it is assigned by the WithMaxSize option.
cache gcache.Cache[string, interface{}]
// group is the singleflight group Do runs computations through.
group singleflight.Group
}

// MemoizeOption configures a Memoizer inside New; a non-nil error aborts
// construction.
type MemoizeOption func(m *Memoizer) error

// WithMaxSize returns an option that replaces the Memoizer's cache with a new
// one built from gcache.New(size). The option itself never fails.
func WithMaxSize(size int) MemoizeOption {
	apply := func(m *Memoizer) error {
		built := gcache.New[string, interface{}](size).Build()
		m.cache = built
		return nil
	}
	return apply
}

// New builds a Memoizer and applies the given options in order. The first
// option that returns an error aborts construction and that error is
// returned. Nil options are skipped.
//
// When no option installs a cache, a default one is created so that Do can be
// called on the result without a nil cache.
func New(options ...MemoizeOption) (*Memoizer, error) {
	// defaultMaxSize matches the size used by the generator's cache template.
	const defaultMaxSize = 1000

	m := &Memoizer{}
	for _, option := range options {
		if option == nil {
			continue
		}
		if err := option(m); err != nil {
			return nil, err
		}
	}

	if m.cache == nil {
		m.cache = gcache.New[string, interface{}](defaultMaxSize).Build()
	}

	return m, nil
}

func (m *Memoizer) Do(funcHash string, fn func() (interface{}, error)) (interface{}, error, bool) {
if value, err := m.cache.GetIFPresent(funcHash); !errors.Is(err, gcache.KeyNotFoundError) {
return value, err, true
}

value, err, _ := m.group.Do(funcHash, func() (interface{}, error) {
data, innerErr := fn()

if innerErr == nil {
m.cache.Set(funcHash, data)

Check failure on line 44 in cmd/memoize/memoize/store.go

GitHub Actions / Lint Test

Error return value of `m.cache.Set` is not checked (errcheck)
}

return data, innerErr
})

return value, err, false
}
83 changes: 83 additions & 0 deletions cmd/memoize/test/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package memo

Check failure on line 1 in cmd/memoize/test/test.go

GitHub Actions / Lint Test

: # github.com/projectdiscovery/nuclei/v3/cmd/memoize/test

import (
"bytes"
"crypto/sha256"
"encoding/hex"
"fmt"
"sync"

"github.com/bluele/gcache"
"github.com/projectdiscovery/nuclei/v3/cmd/memoize/gotest"
)

// resStructTest is the value stored in the cache for Test; it carries the
// single string result of gotest.Test.
type resStructTest struct {
ret0Test string
}

var ()

func Test(a string, b string) string {
var retresStructTest *resStructTest
h := hash("Test", a, b)
if v, err := cache.GetIFPresent(h); err == nil {
retresStructTest = v.(*resStructTest)
}
retresStructTest.ret0Test = gotest.Test(a, b)
cache.Set(h, retresStructTest)
return retresStructTest.ret0Test
}

// onceTestNothing guards the forwarded call made by TestNothing.
var (
onceTestNothing sync.Once
)

// TestNothing forwards to gotest.TestNothing inside onceTestNothing.Do, so the
// wrapped function is invoked by the first call only.
func TestNothing() {
onceTestNothing.Do(func() {
gotest.TestNothing()
})
}

// State for TestWithOneReturn: the once guard and the stored result.
var (
onceTestWithOneReturn sync.Once
ret0TestWithOneReturn string
)

// TestWithOneReturn calls gotest.TestWithOneReturn inside
// onceTestWithOneReturn.Do, stores its result in ret0TestWithOneReturn and
// returns that stored value on every call.
func TestWithOneReturn() string {
onceTestWithOneReturn.Do(func() {
ret0TestWithOneReturn = gotest.TestWithOneReturn()
})
return ret0TestWithOneReturn
}

// State for TestWithMultipleReturnValues: the once guard and one variable per
// result value.
var (
onceTestWithMultipleReturnValues sync.Once
ret0TestWithMultipleReturnValues string

ret1TestWithMultipleReturnValues int

ret2TestWithMultipleReturnValues error
)

// TestWithMultipleReturnValues calls gotest.TestWithMultipleReturnValues inside
// onceTestWithMultipleReturnValues.Do, stores all three results (including the
// error) and returns the stored values on every call.
func TestWithMultipleReturnValues() (string, int, error) {
onceTestWithMultipleReturnValues.Do(func() {
ret0TestWithMultipleReturnValues, ret1TestWithMultipleReturnValues, ret2TestWithMultipleReturnValues = gotest.TestWithMultipleReturnValues()
})
return ret0TestWithMultipleReturnValues, ret1TestWithMultipleReturnValues, ret2TestWithMultipleReturnValues
}

// hash derives a cache key from a function name and its arguments and returns
// it as a hex-encoded SHA-256 digest.
//
// Each argument is written with its type and the length of its rendered form,
// so that different argument lists such as ("ab", "c") and ("a", "bc") cannot
// collapse into the same key.
func hash(functionName string, args ...any) string {
	var b bytes.Buffer
	b.WriteString(functionName + ":")
	for _, arg := range args {
		s := fmt.Sprint(arg)
		fmt.Fprintf(&b, "%T:%d:%s", arg, len(s), s)
	}
	h := sha256.Sum256(b.Bytes())
	return hex.EncodeToString(h[:])
}

// cache holds the memoized results, keyed by the strings produced by hash; it
// is assigned in init.
var cache gcache.Cache[string, interface{}]

Check failure on line 79 in cmd/memoize/test/test.go

GitHub Actions / Lint Test

invalid operation: gcache.Cache[string, interface{}] (gcache.Cache is not a generic type)

Check failure on line 79 in cmd/memoize/test/test.go

GitHub Actions / Test Builds (1.21.x, ubuntu-latest)

invalid operation: gcache.Cache[string, interface{}] (gcache.Cache is not a generic type)

// init assigns the package-level cache from gcache.New(1000).Build().
// NOTE(review): the call uses type arguments, so it needs a generic gcache
// package; the check failures recorded below report that the imported one is
// not generic — confirm the import path.
func init() {
cache = gcache.New[string, interface{}](1000).Build()

Check failure on line 82 in cmd/memoize/test/test.go

GitHub Actions / Lint Test

invalid operation: cannot index gcache.New (value of type func(size int) *gcache.CacheBuilder) (typecheck)

Check failure on line 82 in cmd/memoize/test/test.go

GitHub Actions / Test Builds (1.21.x, ubuntu-latest)

invalid operation: cannot index gcache.New (value of type func(size int) *gcache.CacheBuilder)
}
Loading