diff --git a/make.go b/make.go index b93e4f8..a4de3e9 100644 --- a/make.go +++ b/make.go @@ -11,12 +11,30 @@ import ( "reflect" ) +type MakeConfig struct { + // Types, if specified, provides Generators for a given type that will + // be used in favor of the automatic reflection based generation. + Types map[reflect.Type]*Generator[any] + // Fields, if specified, provides Generators for a given field of a + // given type that will be used in favor of automatic reflection based + // generation. + Fields map[reflect.Type]map[string]*Generator[any] +} + // Make creates a generator of values of type V, using reflection to infer the required structure. // Currently, Make may be unable to terminate generation of values of some recursive types, thus using // Make with recursive types requires extra care. func Make[V any]() *Generator[V] { + return MakeCustom[V](MakeConfig{}) +} + +// MakeCustom creates a generator of values of type V, using reflection and +// overrides from MakeConfig to infer the required structure. +// Currently, Make may be unable to terminate generation of values of some recursive types, thus using +// Make with recursive types requires extra care. +func MakeCustom[V any](cfg MakeConfig) *Generator[V] { var zero V - gen := newMakeGen(reflect.TypeOf(zero)) + gen := cfg.newMakeGen(reflect.TypeOf(zero)) return newGenerator[V](&makeGen[V]{ gen: gen, }) @@ -35,8 +53,8 @@ func (g *makeGen[V]) value(t *T) V { return g.gen.value(t).(V) } -func newMakeGen(typ reflect.Type) *Generator[any] { - gen, mayNeedCast := newMakeKindGen(typ) +func (c *MakeConfig) newMakeGen(typ reflect.Type) *Generator[any] { + gen, mayNeedCast := c.newMakeKindGen(typ) if !mayNeedCast || typ.String() == typ.Kind().String() { return gen // fast path with less reflect } @@ -54,10 +72,19 @@ func (g *castGen) String() string { func (g *castGen) value(t *T) any { v := g.gen.value(t) + if v == nil { + return nil + } return reflect.ValueOf(v).Convert(g.typ).Interface() } -func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) { +func (c *MakeConfig) newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) { + if c.Types != nil { + if gen, ok := c.Types[typ]; ok { + return gen, true + } + } + switch typ.Kind() { case reflect.Bool: return Bool().AsAny(), true @@ -88,28 +115,28 @@ func newMakeKindGen(typ reflect.Type) (gen *Generator[any], mayNeedCast bool) { case reflect.Float64: return Float64().AsAny(), true case reflect.Array: - return genAnyArray(typ), false + return c.genAnyArray(typ), false case reflect.Map: - return genAnyMap(typ), false + return c.genAnyMap(typ), false case reflect.Pointer: - return Deferred(func() *Generator[any] { return genAnyPointer(typ) }), false + return Deferred(func() *Generator[any] { return c.genAnyPointer(typ) }), false case reflect.Slice: - return genAnySlice(typ), false + return c.genAnySlice(typ), false case reflect.String: return String().AsAny(), true case reflect.Struct: - return genAnyStruct(typ), false + return c.genAnyStruct(typ), false default: panic(fmt.Sprintf("unsupported type kind for Make: %v", typ.Kind())) } } -func genAnyPointer(typ reflect.Type) *Generator[any] { +func (c *MakeConfig) genAnyPointer(typ reflect.Type) *Generator[any] { elem := typ.Elem() - elemGen := newMakeGen(elem) + elemGen := c.newMakeGen(elem) const pNonNil = 0.5 - return Custom[any](func(t *T) any { + return Custom(func(t *T) any { if flipBiasedCoin(t.s, pNonNil) { val := elemGen.value(t) ptr := reflect.New(elem) @@ -121,11 +148,11 @@ func genAnyPointer(typ reflect.Type) *Generator[any] { }) } -func genAnyArray(typ reflect.Type) *Generator[any] { +func (c *MakeConfig) genAnyArray(typ reflect.Type) *Generator[any] { count := typ.Len() - elemGen := newMakeGen(typ.Elem()) + elemGen := c.newMakeGen(typ.Elem()) - return Custom[any](func(t *T) any { + return Custom(func(t *T) any { a := reflect.Indirect(reflect.New(typ)) if count == 0 { t.s.drawBits(0) @@ -139,10 +166,10 @@ func genAnyArray(typ reflect.Type) *Generator[any] { }) } -func genAnySlice(typ reflect.Type) *Generator[any] { - elemGen := newMakeGen(typ.Elem()) +func (c *MakeConfig) genAnySlice(typ reflect.Type) *Generator[any] { + elemGen := c.newMakeGen(typ.Elem()) - return Custom[any](func(t *T) any { + return Custom(func(t *T) any { repeat := newRepeat(-1, -1, -1, elemGen.String()) sl := reflect.MakeSlice(typ, 0, repeat.avg()) for repeat.more(t.s) { @@ -153,11 +180,11 @@ func genAnySlice(typ reflect.Type) *Generator[any] { }) } -func genAnyMap(typ reflect.Type) *Generator[any] { - keyGen := newMakeGen(typ.Key()) - valGen := newMakeGen(typ.Elem()) +func (c *MakeConfig) genAnyMap(typ reflect.Type) *Generator[any] { + keyGen := c.newMakeGen(typ.Key()) + valGen := c.newMakeGen(typ.Elem()) - return Custom[any](func(t *T) any { + return Custom(func(t *T) any { label := keyGen.String() + "," + valGen.String() repeat := newRepeat(-1, -1, -1, label) m := reflect.MakeMapWithSize(typ, repeat.avg()) @@ -174,23 +201,51 @@ func genAnyMap(typ reflect.Type) *Generator[any] { }) } -func genAnyStruct(typ reflect.Type) *Generator[any] { +func (c *MakeConfig) genAnyStruct(typ reflect.Type) *Generator[any] { + customFields := map[string]*Generator[any]{} + if c.Fields != nil { + if custom, ok := c.Fields[typ]; ok { + customFields = custom + } + } + numFields := typ.NumField() fieldGens := make([]*Generator[any], numFields) for i := 0; i < numFields; i++ { - fieldGens[i] = newMakeGen(typ.Field(i).Type) + field := typ.Field(i) + if !field.IsExported() { + continue + } + + if gen, ok := customFields[field.Name]; ok { + fieldGens[i] = gen + } else { + fieldGens[i] = c.newMakeGen(field.Type) + } } - return Custom[any](func(t *T) any { + return Custom(func(t *T) any { s := reflect.Indirect(reflect.New(typ)) - if numFields == 0 { - t.s.drawBits(0) - } else { - for i := 0; i < numFields; i++ { - f := reflect.ValueOf(fieldGens[i].value(t)) - s.Field(i).Set(f) + + fieldsSet := 0 + for i := 0; i < numFields; i++ { + if fieldGens[i] == nil { + continue } + + value := fieldGens[i].value(t) + if value == nil { + continue + } + + s.Field(i).Set(reflect.ValueOf(value)) + fieldsSet++ + } + + if fieldsSet == 0 { + t.s.drawBits(0) } + return s.Interface() }) } diff --git a/make_test.go b/make_test.go new file mode 100644 index 0000000..c366c54 --- /dev/null +++ b/make_test.go @@ -0,0 +1,35 @@ +// Copyright 2022 Gregory Petrosyan +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package rapid_test + +import ( + "reflect" + "testing" + + "pgregory.net/rapid" +) + +type PrivateFields struct { + private bool +} + +func TestMake(t *testing.T) { + // Private fields are ignored (and don't panic). + rapid.Make[PrivateFields]().Example() +} + +func TestMakeCustom(t *testing.T) { + ex := rapid.MakeCustom[PrivateFields](rapid.MakeConfig{ + Types: map[reflect.Type]*rapid.Generator[any]{ + reflect.TypeOf(PrivateFields{}): rapid.Just(PrivateFields{private: true}).AsAny(), + }, + }).Example() + + if !ex.private { + t.Errorf(".private should be true. got: %#v", ex) + } +}