Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make[T](): add support for field and type overrides #72

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 86 additions & 31 deletions make.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,30 @@ import (
"reflect"
)

type MakeConfig struct {
// Types, if specified, provides Generators for a given type that will
// be used in favor of the automatic reflection based generation.
Types map[reflect.Type]*Generator[any]
// Fields, if specified, provides Generators for a given field of a
// given type that will be used in favor of automatic reflection based
// generation.
Fields map[reflect.Type]map[string]*Generator[any]
}

// Make creates a generator of values of type V, using reflection to infer the required structure.
// Currently, Make may be unable to terminate generation of values of some recursive types, thus using
// Make with recursive types requires extra care.
func Make[V any]() *Generator[V] {
return MakeCustom[V](MakeConfig{})
}

// MakeCustom creates a generator of values of type V, using reflection and
// overrides from MakeConfig to infer the required structure.
// Currently, Make may be unable to terminate generation of values of some recursive types, thus using
// Make with recursive types requires extra care.
func MakeCustom[V any](cfg MakeConfig) *Generator[V] {
var zero V
gen := newMakeGen(reflect.TypeOf(zero))
gen := cfg.newMakeGen(reflect.TypeOf(zero))
return newGenerator[V](&makeGen[V]{
gen: gen,
})
Expand All @@ -35,8 +53,8 @@ func (g *makeGen[V]) value(t *T) V {
return g.gen.value(t).(V)
}

func newMakeGen(typ reflect.Type) *Generator[any] {
gen, mayNeedCast := newMakeKindGen(typ)
func (c *MakeConfig) newMakeGen(typ reflect.Type) *Generator[any] {
gen, mayNeedCast := c.newMakeKindGen(typ)
if !mayNeedCast || typ.String() == typ.Kind().String() {
return gen // fast path with less reflect
}
Expand All @@ -54,10 +72,19 @@ func (g *castGen) String() string {

func (g *castGen) value(t *T) any {
v := g.gen.value(t)
if v == nil {
return nil
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this check here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was seeing nil pointer exceptions in cases where v was nil. After double checking reflect.ValueOf I suspect the issue stems from reflect.Value.Convert as reflect.ValueOf(nil) returns a zero reflect.Value.

I'll add a comment and/or I could change this to:

v := reflect.ValueOf(g.gen.value(t))
if v.IsNil() {
    return nil
}
return v.Convert(g.type).Interface()

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In what cases .value() is nil? It should not be, ever.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah! I figured it out. I was attempting to skip over a pointer field but missed the * so I had overridden a struct generator with rapid.Just[any](nil). I'll see if I can make the error message a bit more friendly and add a regression test.

return reflect.ValueOf(v).Convert(g.typ).Interface()
}

func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) {
func (c *MakeConfig) newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) {
if c.Types != nil {
if gen, ok := c.Types[typ]; ok {
return gen, true
}
}

switch typ.Kind() {
case reflect.Bool:
return Bool().AsAny(), true
Expand Down Expand Up @@ -88,28 +115,28 @@ func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) {
case reflect.Float64:
return Float64().AsAny(), true
case reflect.Array:
return genAnyArray(typ), false
return c.genAnyArray(typ), false
case reflect.Map:
return genAnyMap(typ), false
return c.genAnyMap(typ), false
case reflect.Pointer:
return Deferred(func() *Generator[any] { return genAnyPointer(typ) }), false
return Deferred(func() *Generator[any] { return c.genAnyPointer(typ) }), false
case reflect.Slice:
return genAnySlice(typ), false
return c.genAnySlice(typ), false
case reflect.String:
return String().AsAny(), true
case reflect.Struct:
return genAnyStruct(typ), false
return c.genAnyStruct(typ), false
default:
panic(fmt.Sprintf("unsupported type kind for Make: %v", typ.Kind()))
}
}

func genAnyPointer(typ reflect.Type) *Generator[any] {
func (c *MakeConfig) genAnyPointer(typ reflect.Type) *Generator[any] {
elem := typ.Elem()
elemGen := newMakeGen(elem)
elemGen := c.newMakeGen(elem)
const pNonNil = 0.5

return Custom[any](func(t *T) any {
return Custom(func(t *T) any {
if flipBiasedCoin(t.s, pNonNil) {
val := elemGen.value(t)
ptr := reflect.New(elem)
Expand All @@ -121,11 +148,11 @@ func genAnyPointer(typ reflect.Type) *Generator[any] {
})
}

func genAnyArray(typ reflect.Type) *Generator[any] {
func (c *MakeConfig) genAnyArray(typ reflect.Type) *Generator[any] {
count := typ.Len()
elemGen := newMakeGen(typ.Elem())
elemGen := c.newMakeGen(typ.Elem())

return Custom[any](func(t *T) any {
return Custom(func(t *T) any {
a := reflect.Indirect(reflect.New(typ))
if count == 0 {
t.s.drawBits(0)
Expand All @@ -139,10 +166,10 @@ func genAnyArray(typ reflect.Type) *Generator[any] {
})
}

func genAnySlice(typ reflect.Type) *Generator[any] {
elemGen := newMakeGen(typ.Elem())
func (c *MakeConfig) genAnySlice(typ reflect.Type) *Generator[any] {
elemGen := c.newMakeGen(typ.Elem())

return Custom[any](func(t *T) any {
return Custom(func(t *T) any {
repeat := newRepeat(-1, -1, -1, elemGen.String())
sl := reflect.MakeSlice(typ, 0, repeat.avg())
for repeat.more(t.s) {
Expand All @@ -153,11 +180,11 @@ func genAnySlice(typ reflect.Type) *Generator[any] {
})
}

func genAnyMap(typ reflect.Type) *Generator[any] {
keyGen := newMakeGen(typ.Key())
valGen := newMakeGen(typ.Elem())
func (c *MakeConfig) genAnyMap(typ reflect.Type) *Generator[any] {
keyGen := c.newMakeGen(typ.Key())
valGen := c.newMakeGen(typ.Elem())

return Custom[any](func(t *T) any {
return Custom(func(t *T) any {
label := keyGen.String() + "," + valGen.String()
repeat := newRepeat(-1, -1, -1, label)
m := reflect.MakeMapWithSize(typ, repeat.avg())
Expand All @@ -174,23 +201,51 @@ func genAnyMap(typ reflect.Type) *Generator[any] {
})
}

func genAnyStruct(typ reflect.Type) *Generator[any] {
func (c *MakeConfig) genAnyStruct(typ reflect.Type) *Generator[any] {
customFields := map[string]*Generator[any]{}
if c.Fields != nil {
if custom, ok := c.Fields[typ]; ok {
customFields = custom
}
}

numFields := typ.NumField()
fieldGens := make([]*Generator[any], numFields)
for i := 0; i < numFields; i++ {
fieldGens[i] = newMakeGen(typ.Field(i).Type)
field := typ.Field(i)
if !field.IsExported() {
continue
}

if gen, ok := customFields[field.Name]; ok {
fieldGens[i] = gen
} else {
fieldGens[i] = c.newMakeGen(field.Type)
}
}

return Custom[any](func(t *T) any {
return Custom(func(t *T) any {
s := reflect.Indirect(reflect.New(typ))
if numFields == 0 {
t.s.drawBits(0)
} else {
for i := 0; i < numFields; i++ {
f := reflect.ValueOf(fieldGens[i].value(t))
s.Field(i).Set(f)

fieldsSet := 0
for i := 0; i < numFields; i++ {
if fieldGens[i] == nil {
continue
}

value := fieldGens[i].value(t)
if value == nil {
continue
}

s.Field(i).Set(reflect.ValueOf(value))
fieldsSet++
}

if fieldsSet == 0 {
t.s.drawBits(0)
}

return s.Interface()
})
}
35 changes: 35 additions & 0 deletions make_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright 2022 Gregory Petrosyan <gregory.petrosyan@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

package rapid_test

import (
"reflect"
"testing"

"pgregory.net/rapid"
)

type PrivateFields struct {
private bool
}

func TestMake(t *testing.T) {
// Private fields are ignored (and don't panic).
rapid.Make[PrivateFields]().Example()
}

func TestMakeCustom(t *testing.T) {
ex := rapid.MakeCustom[PrivateFields](rapid.MakeConfig{
Types: map[reflect.Type]*rapid.Generator[any]{
reflect.TypeOf(PrivateFields{}): rapid.Just(PrivateFields{private: true}).AsAny(),
},
}).Example()

if !ex.private {
t.Errorf(".private should be true. got: %#v", ex)
}
}
Loading