From 5b6a938443d7fd0948c7b4734f732a8f013b6083 Mon Sep 17 00:00:00 2001 From: Alexander Klauer Date: Sun, 14 Jun 2020 16:08:45 +0200 Subject: [PATCH] reflect: allow creation of recursive types Introduce two new functions, Named and Bind, to the reflect API to create possibly recursive, named types. Fixes #39528. --- src/cmd/compile/internal/gc/reflect.go | 1 + src/cmd/link/internal/ld/decodesym.go | 5 +- src/internal/reflectlite/type.go | 4 + src/reflect/all_test.go | 567 +++++++++++++++++++++++++ src/reflect/type.go | 241 ++++++++++- src/runtime/type.go | 1 + 6 files changed, 806 insertions(+), 13 deletions(-) diff --git a/src/cmd/compile/internal/gc/reflect.go b/src/cmd/compile/internal/gc/reflect.go index f614b60685e4b..be791b4edb18a 100644 --- a/src/cmd/compile/internal/gc/reflect.go +++ b/src/cmd/compile/internal/gc/reflect.go @@ -812,6 +812,7 @@ const ( tflagExtraStar = 1 << 1 tflagNamed = 1 << 2 tflagRegularMemory = 1 << 3 + tflagIncomplete = 1 << 4 ) var ( diff --git a/src/cmd/link/internal/ld/decodesym.go b/src/cmd/link/internal/ld/decodesym.go index e9c87efe3796b..7e089b750de60 100644 --- a/src/cmd/link/internal/ld/decodesym.go +++ b/src/cmd/link/internal/ld/decodesym.go @@ -24,8 +24,9 @@ import ( // reflect/type.go // runtime/type.go const ( - tflagUncommon = 1 << 0 - tflagExtraStar = 1 << 1 + tflagUncommon = 1 << 0 + tflagExtraStar = 1 << 1 + tflagIncomplete = 1 << 4 ) func decodeInuxi(arch *sys.Arch, p []byte, sz int) uint64 { diff --git a/src/internal/reflectlite/type.go b/src/internal/reflectlite/type.go index eb7f1a4b78e71..389f76c1c5283 100644 --- a/src/internal/reflectlite/type.go +++ b/src/internal/reflectlite/type.go @@ -142,6 +142,10 @@ const ( // tflagRegularMemory means that equal and hash functions can treat // this type as a single region of t.size bytes. tflagRegularMemory tflag = 1 << 3 + + // tflagIncomplete means this type was created by Named and Bind has not + // been called on it yet. + tflagIncomplete tflag = 1 << 4 ) // rtype is the common implementation of most values. diff --git a/src/reflect/all_test.go b/src/reflect/all_test.go index e87d1d27cd068..8048a7a32f38e 100644 --- a/src/reflect/all_test.go +++ b/src/reflect/all_test.go @@ -5650,6 +5650,573 @@ func TestMapOfGCValues(t *testing.T) { } } +func TestBind(t *testing.T) { + // invalid type name "1nvalid" + shouldPanic("invalid type name", func() { + Named("1nvalid") + }) + + // invalid type name "+" + shouldPanic("invalid type name", func() { + Named("+") + }) + + // no type name + shouldPanic("invalid type name", func() { + Named("") + }) + + // verify creation of incomplete type with valid type name + valid := Named("valid") + if valid == nil { + t.Error("valid incomplete type is nil") + } + + // bind simple types + simple := []struct { + Name string + Underlying Type + }{ + {"sBool", TypeOf(true)}, + {"sInt", TypeOf(int(0))}, + {"sInt8", TypeOf(int8(0))}, + {"sInt16", TypeOf(int16(0))}, + {"sInt32", TypeOf(int32(0))}, + {"sInt64", TypeOf(int64(0))}, + {"sUint", TypeOf(uint(0))}, + {"sUint8", TypeOf(uint8(0))}, + {"sUint16", TypeOf(uint16(0))}, + {"sUint32", TypeOf(uint32(0))}, + {"sUint64", TypeOf(uint64(0))}, + {"sUintptr", TypeOf(uintptr(0))}, + {"sFloat32", TypeOf(float32(0))}, + {"sFloat64", TypeOf(float64(0))}, + {"sComplex64", TypeOf(complex64(0))}, + {"sComplex128", TypeOf(complex128(0))}, + {"sString", TypeOf("")}, + {"sUnsafePointer", TypeOf(unsafe.Pointer(uintptr(0)))}, + } + for _, testcase := range simple { + incomplete := Named(testcase.Name) + complete := Bind(incomplete, testcase.Underlying) + if complete == testcase.Underlying { + t.Errorf("simple type %s equal to underlying type", testcase.Name) + } + if complete.Name() != testcase.Name { + t.Errorf("simple type name = %s, want %s", complete.Name(), testcase.Name) + } + if complete.Kind() != testcase.Underlying.Kind() { + t.Errorf("simple type kind = %s, want %s", complete.Kind(), + testcase.Underlying.Kind()) + } + } + + // invalid array bind + incomplete := Named("incompleteArray") + shouldPanic("incomplete element type", func() { + ArrayOf(42, incomplete) + }) + + // nonrecursive array bind + incomplete = Named("nonrecursiveArray") + underlying := ArrayOf(42, TypeOf(0)) + complete := Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursiveArray complete = underlying") + } + if complete.Name() != "nonrecursiveArray" { + t.Errorf("bound array name = %s, want nonrecursiveArray", complete.Name()) + } + if complete.Kind() != Array { + t.Errorf("bound array kind = %s", complete.Kind()) + } + if complete.Elem() != TypeOf(0) { + t.Errorf("bound array element type = %s, want %s", + complete.Elem(), TypeOf(0)) + } + if complete.Len() != 42 { + t.Errorf("bound array len = %d, want 42", complete.Len()) + } + + // nonrecursive chan bind + incomplete = Named("nonrecursiveChan") + underlying = ChanOf(BothDir, TypeOf(0)) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursiveChan complete = underlying") + } + if complete.Name() != "nonrecursiveChan" { + t.Errorf("bound chan name = %s, want nonrecursiveChan", complete.Name()) + } + if complete.Kind() != Chan { + t.Errorf("bound chan kind = %s", complete.Kind()) + } + if complete.Elem() != TypeOf(0) { + t.Errorf("bound chan element type = %s, want %s", + complete.Elem(), TypeOf(0)) + } + if complete.ChanDir() != BothDir { + t.Errorf("bound chan dir = %s, want BothDir", complete.ChanDir()) + } + + // nonrecursive func bind + incomplete = Named("nonrecursiveFunc") + in := []Type{TypeOf(true), TypeOf(0), TypeOf("")} + out := []Type{TypeOf(byte(0)), TypeOf(complex128(0))} + underlying = FuncOf(in, out, false) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursiveFunc complete = underlying") + } + if complete.Name() != "nonrecursiveFunc" { + t.Errorf("bound func name = %s, want nonrecursiveFunc", complete.Name()) + } + if complete.Kind() != Func { + t.Errorf("bound func kind = %s", complete.Kind()) + } + if complete.NumIn() != len(in) { + t.Errorf("bound func num args = %d, want %d", complete.NumIn(), len(in)) + } + if complete.NumOut() != len(out) { + t.Errorf("bound func num rets = %d, want %d", complete.NumOut(), len(out)) + } + for i, expect := range in { + if complete.In(i) != expect { + t.Errorf("bound func arg %d type = %s, want %s", + i, complete.In(i), expect) + } + } + for i, expect := range out { + if complete.Out(i) != expect { + t.Errorf("bound func ret %d type = %s, want %d", + i, complete.Out(i), expect) + } + } + if complete.IsVariadic() { + t.Error("bound func is variadic, expected non-variadic") + } + + // unsupported interface bind + incomplete = Named("unsupportedInterface") + underlying = TypeOf(new(interface{})).Elem() + shouldPanic("binding to interface type is not supported", func() { + Bind(incomplete, underlying) + }) + + // nonrecursive map bind + incomplete = Named("nonrecursiveMap") + underlying = MapOf(TypeOf(""), TypeOf(0)) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursiveMap complete = underlying") + } + if complete.Name() != "nonrecursiveMap" { + t.Errorf("bound map name = %s, want nonrecursiveMap", complete.Name()) + } + if complete.Kind() != Map { + t.Errorf("bound map kind = %s", complete.Kind()) + } + if complete.Key() != TypeOf("") { + t.Errorf("bound map key = %s, want %s", complete.Key(), TypeOf("")) + } + if complete.Elem() != TypeOf(0) { + t.Errorf("bound map elem = %s, want %s", complete.Elem(), TypeOf(0)) + } + + // nonrecursive pointer bind + incomplete = Named("nonrecursivePtr") + underlying = PtrTo(TypeOf(0)) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursivePtr complete = underlying") + } + if complete.Name() != "nonrecursivePtr" { + t.Errorf("bound pointer name = %s, want nonrecursivePtr", complete.Name()) + } + if complete.Kind() != Ptr { + t.Errorf("bound pointer kind = %s", complete.Kind()) + } + if complete.Elem() != TypeOf(0) { + t.Errorf("bound pointer elem = %s, want %s", complete.Elem(), TypeOf(0)) + } + + // nonrecursive slice bind + incomplete = Named("nonrecursiveSlice") + underlying = SliceOf(TypeOf(0)) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursiveSlice complete = underlying") + } + if complete.Name() != "nonrecursiveSlice" { + t.Errorf("bound slice name = %s, want nonrecursiveSlice", complete.Name()) + } + if complete.Kind() != Slice { + t.Errorf("bound slice kind = %s", complete.Kind()) + } + if complete.Elem() != TypeOf(0) { + t.Errorf("bound slice elem = %s, want %s", complete.Elem(), TypeOf(0)) + } + + // nonrecursive struct bind + incomplete = Named("nonrecursiveStruct") + fields := []StructField{ + {Name: "Bool", Type: TypeOf(true)}, + {Name: "Int", Type: TypeOf(0)}, + } + underlying = StructOf(fields) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("nonrecursiveStruct complete = underlying") + } + if complete.Name() != "nonrecursiveStruct" { + t.Errorf("bound struct name = %s, want nonrecursiveStruct", complete.Name()) + } + if complete.Kind() != Struct { + t.Errorf("bound struct kind = %s", complete.Kind()) + } + if complete.NumField() != len(fields) { + t.Errorf("bound struct num fields = %d, want %d", + complete.NumField(), len(fields)) + } + for i, field := range fields { + boundField := complete.Field(i) + if boundField.Name != field.Name { + t.Errorf("bound struct field %d name = %s, want %s", + i, boundField.Name, field.Name) + } + if boundField.Type != field.Type { + t.Errorf("bound struct field %d type = %s, want %s", + i, boundField.Type, field.Type) + } + boundField, ok := complete.FieldByName(field.Name) + if !ok { + t.Errorf("bound struct field %s missing", field.Name) + } + if boundField.Type != field.Type { + t.Errorf("bound struct field %s type = %s, want %s", + field.Name, boundField.Type, field.Type) + } + } + if _, ok := complete.FieldByName("Missing"); ok { + t.Error("bound struct has missing field") + } + + // recursive array bind + incomplete = Named("recursiveArray") + underlying = ArrayOf(42, PtrTo(incomplete)) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursiveArray complete = underlying") + } + if complete.Elem().Elem() != complete { + t.Errorf("recursive array elem = %s, want %s", + complete.Elem().Elem(), complete) + } + + // recursive chan bind + incomplete = Named("recursiveChan") + underlying = ChanOf(SendDir, incomplete) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursiveChan complete = underlying") + } + if complete.Elem() != complete { + t.Errorf("recursive chan elem = %s, want %s", complete.Elem(), complete) + } + + // oversized recursive chan bind + incomplete = Named("oversizedChan") + underlying = StructOf([]StructField{ + {Name: "Chan", Type: ChanOf(RecvDir, incomplete)}, + {Name: "Big", Type: ArrayOf(1<<16, PtrTo(incomplete))}, + }) + shouldPanic("element size too large", func() { + Bind(incomplete, underlying) + }) + + // recursive func bind + incomplete = Named("recursiveFunc") + underlying = FuncOf([]Type{incomplete}, []Type{incomplete}, false) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursiveFunc complete = underlying") + } + if complete.In(0) != complete { + t.Errorf("recursive func arg = %s, want %s", complete.In(0), complete) + } + if complete.Out(0) != complete { + t.Errorf("recursive func ret = %s, want %s", complete.Out(0), complete) + } + + // recursive map bind + incomplete = Named("recursiveMap") + underlying = MapOf(PtrTo(incomplete), incomplete) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursiveMap complete = underlying") + } + if complete.Key().Elem() != complete { + t.Errorf("recursive map key = %s, want %s", complete.Key().Elem(), complete) + } + if complete.Elem() != complete { + t.Errorf("recursive map elem = %s, want %s", complete.Elem(), complete) + } + + // recursive ptr bind + incomplete = Named("recursivePtr") + underlying = PtrTo(incomplete) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursivePtr complete = underlying") + } + if complete.Elem() != complete { + t.Errorf("recursive ptr elem = %s, want %s", complete.Elem(), complete) + } + + // recursive slice bind + incomplete = Named("recursiveSlice") + underlying = SliceOf(incomplete) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursiveSlice complete = underlying") + } + if complete.Elem() != complete { + t.Errorf("recursive slice elem = %s, want %s", complete.Elem(), complete) + } + + // recursive struct bind + incomplete = Named("recursiveStruct") + underlying = StructOf([]StructField{ + {Name: "A", Type: ArrayOf(42, PtrTo(incomplete))}, + {Name: "C", Type: ChanOf(BothDir, incomplete)}, + {Name: "F", Type: FuncOf([]Type{incomplete, SliceOf(incomplete)}, []Type{}, + true)}, + {Name: "M", Type: MapOf(PtrTo(incomplete), incomplete)}, + {Name: "P", Type: PtrTo(incomplete)}, + {Name: "S", Type: SliceOf(incomplete)}, + }) + complete = Bind(incomplete, underlying) + if complete == underlying { + t.Error("recursiveStruct complete = underlying") + } + if complete.Field(0).Type.Elem().Elem() != complete || + complete.Field(1).Type.Elem() != complete || + complete.Field(2).Type.In(0) != complete || + complete.Field(2).Type.In(1).Elem() != complete || + complete.Field(3).Type.Key().Elem() != complete || + complete.Field(3).Type.Elem() != complete || + complete.Field(4).Type.Elem() != complete || + complete.Field(5).Type.Elem() != complete { + t.Error("recursive struct incomplete type substitution error") + } + + // invalid self-bind + incomplete = Named("invalidSelf") + shouldPanic("incomplete underlying type", func() { + Bind(incomplete, incomplete) + }) + + // invalid cross-bind + incomplete1 := Named("incomplete1") + incomplete2 := Named("incomplete2") + underlying = PtrTo(incomplete1) + shouldPanic("incomplete underlying type", func() { + Bind(incomplete2, underlying) + }) + + // instantiate a recursive type and play around with it + incomplete = Named("Recursive") + underlying = StructOf([]StructField{ + {Name: "Bool", Type: TypeOf(true)}, + {Name: "Int", Type: TypeOf(0)}, + {Name: "Int8", Type: TypeOf(int8(0))}, + {Name: "Int16", Type: TypeOf(int16(0))}, + {Name: "Int32", Type: TypeOf(int32(0))}, + {Name: "Int64", Type: TypeOf(int64(0))}, + {Name: "Uint", Type: TypeOf(uint(0))}, + {Name: "Uint8", Type: TypeOf(uint8(0))}, + {Name: "Uint16", Type: TypeOf(uint16(0))}, + {Name: "Uint32", Type: TypeOf(uint32(0))}, + {Name: "Uint64", Type: TypeOf(uint64(0))}, + {Name: "Uintptr", Type: TypeOf(uintptr(0))}, + {Name: "Float32", Type: TypeOf(float32(0))}, + {Name: "Float64", Type: TypeOf(float64(0))}, + {Name: "Complex64", Type: TypeOf(complex64(0))}, + {Name: "Complex128", Type: TypeOf(complex128(0))}, + {Name: "Array", Type: ArrayOf(42, PtrTo(incomplete))}, + {Name: "Chan", Type: ChanOf(BothDir, incomplete)}, + {Name: "Func", Type: FuncOf([]Type{incomplete}, []Type{}, false)}, + {Name: "Interface", Type: TypeOf(new(interface{})).Elem()}, + {Name: "Map", Type: MapOf(PtrTo(incomplete), incomplete)}, + {Name: "Ptr", Type: PtrTo(incomplete)}, + {Name: "Slice", Type: SliceOf(incomplete)}, + {Name: "String", Type: TypeOf("")}, + {Name: "Struct", Type: StructOf([]StructField{ + {Name: "Subfield", Type: PtrTo(incomplete)}, + })}, + {Name: "UnsafePointer", Type: TypeOf(unsafe.Pointer(uintptr(0)))}, + }) + complete = Bind(incomplete, underlying) + value := New(complete).Elem() + value.FieldByName("Bool").SetBool(true) + if !value.FieldByName("Bool").Bool() { + t.Error("set Bool field failed") + } + value.FieldByName("Int").SetInt(-132) + if value.FieldByName("Int").Int() != -132 { + t.Error("set Int field failed") + } + value.FieldByName("Int8").SetInt(-42) + if value.FieldByName("Int8").Int() != -42 { + t.Error("set Int8 field failed") + } + value.FieldByName("Int16").SetInt(-4242) + if value.FieldByName("Int16").Int() != -4242 { + t.Error("set Int16 field failed") + } + value.FieldByName("Int32").SetInt(-424242) + if value.FieldByName("Int32").Int() != -424242 { + t.Error("set Int32 field failed") + } + value.FieldByName("Int64").SetInt(-424242424242) + if value.FieldByName("Int64").Int() != -424242424242 { + t.Error("set Int64 field failed") + } + value.FieldByName("Uint").SetUint(132) + if value.FieldByName("Uint").Uint() != 132 { + t.Error("set Uint field failed") + } + value.FieldByName("Uint8").SetUint(42) + if value.FieldByName("Uint8").Uint() != 42 { + t.Error("set Uint8 field failed") + } + value.FieldByName("Uint16").SetUint(4242) + if value.FieldByName("Uint16").Uint() != 4242 { + t.Error("set Uint16 field failed") + } + value.FieldByName("Uint32").SetUint(424242) + if value.FieldByName("Uint32").Uint() != 424242 { + t.Error("set Uint32 field failed") + } + value.FieldByName("Uint64").SetUint(424242424242) + if value.FieldByName("Uint64").Uint() != 424242424242 { + t.Error("set Uint64 field failed") + } + value.FieldByName("Uintptr").SetUint(0xdeadbeef) + if value.FieldByName("Uintptr").Uint() != 0xdeadbeef { + t.Error("set Uintptr field failed") + } + value.FieldByName("Float32").SetFloat(3.141) + if float32(value.FieldByName("Float32").Float()) != float32(3.141) { + t.Error("set Float32 field failed") + } + value.FieldByName("Float64").SetFloat(2.71828) + if value.FieldByName("Float64").Float() != 2.71828 { + t.Error("set Float64 field failed") + } + value.FieldByName("Complex64").SetComplex(3.141 + 1.6i) + if complex64(value.FieldByName("Complex64").Complex()) != + complex64(3.141+1.6i) { + t.Error("set Complex64 field failed") + } + value.FieldByName("Complex128").SetComplex(1.6 + 2.71828i) + if value.FieldByName("Complex128").Complex() != 1.6+2.71828i { + t.Error("set Complex128 field failed") + } + value.FieldByName("Array").Index(21).Set(value.Addr()) + if value.FieldByName("Array").Index(21).Pointer() != + value.Addr().Pointer() { + t.Errorf("Array[21] = %#x, want %#x", + value.FieldByName("Array").Index(21).Pointer(), value.Addr().Pointer()) + } + ch := MakeChan(ChanOf(BothDir, complete), 1) + done := make(chan interface{}, 1) + go func() { + defer func() { + if r := recover(); r != nil { + done <- r + } + close(done) + }() + v, ok := ch.Recv() + if !ok { + panic("Chan closed unexpectedly") + } + if !DeepEqual(v.Interface(), value.Interface()) { + panic("Chan received value not equal to original") + } + }() + value.FieldByName("Chan").Set(ch) + if value.FieldByName("Chan").Pointer() != ch.Pointer() { + t.Errorf("Chan = %#x, want %#x", + value.FieldByName("Chan").Pointer(), ch.Pointer()) + } + value.FieldByName("Chan").Send(value) + value.FieldByName("Chan").Close() + if r := <-done; r != nil { + panic(r) + } + f := MakeFunc(FuncOf( + []Type{complete}, []Type{}, false, + ), func(args []Value) (results []Value) { + if len(args) != 1 { + t.Errorf("Func bad arg len = %d, want 1", len(args)) + } + if args[0].Type() != value.Type() { + t.Error("Func arg value type not equal to original") + } + return []Value{} + }) + value.FieldByName("Func").Set(f) + if r := value.FieldByName("Func").Call([]Value{value}); len(r) != 0 { + t.Errorf("Func returned %d values, want 0", len(r)) + } + value.FieldByName("Func").Set(Zero(f.Type())) // Reset so we can DeepEqual + value.FieldByName("Interface").Set(value) + if value.FieldByName("Interface").Elem().Type() != value.Type() { + t.Error("set Interface field failed") + } + value.FieldByName("Interface").Set(value.Addr()) + if !DeepEqual(value.FieldByName("Interface").Elem().Elem().Interface(), + value.Interface()) { + t.Error("Interface value pointer content not equal to original") + } + value.FieldByName("Map").Set(MakeMap(MapOf(PtrTo(complete), complete))) + value.FieldByName("Map").SetMapIndex(value.Addr(), value) + if !DeepEqual(value.FieldByName("Map").MapIndex(value.Addr()).Interface(), + value.Interface()) { + t.Error("Map value not equal to original") + } + value.FieldByName("Ptr").Set(value.Addr()) + if value.FieldByName("Ptr").Pointer() != value.Addr().Pointer() { + t.Errorf("Ptr = %#x, want %#x", + value.FieldByName("Ptr").Pointer(), value.Addr().Pointer()) + } + value.FieldByName("Slice").Set(MakeSlice(SliceOf(complete), 1, 1)) + value.FieldByName("Slice").Index(0).Set(value) + if !DeepEqual(value.FieldByName("Slice").Index(0).Interface(), + value.Interface()) { + t.Error("Slice[0] not equal to original") + } + value.FieldByName("String").SetString("test") + if value.FieldByName("String").String() != "test" { + t.Errorf("String = %s, want test", value.FieldByName("String").String()) + } + value.FieldByName("Struct").Field(0).Set(value.Addr()) + if value.FieldByName("Struct").FieldByName("Subfield").Pointer() != + value.Addr().Pointer() { + t.Errorf("Struct.Subfield = %#x, want %#x", + value.FieldByName("Struct").FieldByName("Subfield").Pointer(), + value.Addr().Pointer()) + } + value.FieldByName("UnsafePointer"). + SetPointer(unsafe.Pointer(uintptr(0xdeadbeef))) + if value.FieldByName("UnsafePointer").Pointer() != 0xdeadbeef { + t.Errorf("UnsafePointer = %#x, want 0xdeadbeef", + value.FieldByName("UnsafePointer").Pointer()) + } +} + func TestTypelinksSorted(t *testing.T) { var last string for i, n := range TypeLinks() { diff --git a/src/reflect/type.go b/src/reflect/type.go index 32cc1ce0b2315..993b7f0fa10ba 100644 --- a/src/reflect/type.go +++ b/src/reflect/type.go @@ -294,6 +294,10 @@ const ( // tflagRegularMemory means that equal and hash functions can treat // this type as a single region of t.size bytes. tflagRegularMemory tflag = 1 << 3 + + // tflagIncomplete means this type was created by Named and Bind has not + // been called on it yet. + tflagIncomplete tflag = 1 << 4 ) // rtype is the common implementation of most values. @@ -1863,24 +1867,27 @@ func MapOf(key, elem Type) Type { mt.hash = fnv1(etyp.hash, 'm', byte(ktyp.hash>>24), byte(ktyp.hash>>16), byte(ktyp.hash>>8), byte(ktyp.hash)) mt.key = ktyp mt.elem = etyp - mt.bucket = bucketOf(ktyp, etyp) - mt.hasher = func(p unsafe.Pointer, seed uintptr) uintptr { - return typehash(ktyp, p, seed) - } mt.flags = 0 + if etyp.tflag&tflagIncomplete == 0 { + // For incomplete element types, these fields are set in Bind. + mt.bucket = bucketOf(ktyp, etyp) + mt.hasher = func(p unsafe.Pointer, seed uintptr) uintptr { + return typehash(ktyp, p, seed) + } + if etyp.size > maxValSize { + mt.valuesize = uint8(ptrSize) + mt.flags |= 2 // indirect value + } else { + mt.valuesize = uint8(etyp.size) + } + mt.bucketsize = uint16(mt.bucket.size) + } if ktyp.size > maxKeySize { mt.keysize = uint8(ptrSize) mt.flags |= 1 // indirect key } else { mt.keysize = uint8(ktyp.size) } - if etyp.size > maxValSize { - mt.valuesize = uint8(ptrSize) - mt.flags |= 2 // indirect value - } else { - mt.valuesize = uint8(etyp.size) - } - mt.bucketsize = uint16(mt.bucket.size) if isReflexive(ktyp) { mt.flags |= 4 } @@ -2353,6 +2360,23 @@ func isValidFieldName(fieldName string) bool { return isIdentifier(fieldName) } +// isExportableTypeName checks if a string is an exportable type name or not. +// isExportableTypeName panics if the string is not a valid type name at all. +// +// According to the language spec, a type name should be an identifier. +// According to the language spec, a type name is exported if the first +// character of the identifier's name is a Unicode upper case letter and +// the declaration is in the package block. For the purposes of the reflect +// package, this last requirement is ignored. +func isExportableTypeName(typeName string) bool { + if !isIdentifier(typeName) { + panic("reflect: invalid type name") + } + + r, _ := utf8.DecodeRuneInString(typeName) + return unicode.IsUpper(r) +} + // StructOf returns the struct type containing fields. // The Offset and Index fields are ignored and computed as they would be // by the compiler. @@ -2826,6 +2850,9 @@ const maxPtrmaskBytes = 2048 // ArrayOf panics. func ArrayOf(count int, elem Type) Type { typ := elem.(*rtype) + if typ.tflag&tflagIncomplete != 0 { + panic("reflect.ArrayOf: incomplete element type") + } // Look in cache. ckey := cacheKey{Array, typ, nil, uintptr(count)} @@ -2952,6 +2979,198 @@ func ArrayOf(count int, elem Type) Type { return ti.(Type) } +// Named creates a new incomplete type with the specified valid name. +// An invalid name causes Named to panic. +// +// The resulting type can be used as an argument to ArrayOf, ChanOf, FuncOf, +// MapOf, PtrTo, SliceOf, and StructOf whenever the analogous compile-time type +// specification in a type declaration would result in a valid recursive type. +// Invalid combinations cause Named to panic. E. g., passing an incomplete type +// to ArrayOf is invalid, but passing an incomplete type to PtrTo and then that +// type to ArrayOf is valid. +// +// Any such type, including the incomplete named type itself, may not be used +// in any other context until it has been completed with Bind. A future version +// may relax this requirement (e. g., allow use in MakeFunc before calling +// Bind). +func Named(name string) Type { + return &rtype{ + tflag: tflagNamed | tflagIncomplete, + hash: fnv1(0, []byte(name)...), + str: resolveReflectName(newName(name, "", isExportableTypeName(name))), + } +} + +// Bind binds the specified incomplete named type to underlying akin to a valid +// type declaration and returns a completed version of named. The underlying +// type may indirectly reference the named type, but no +// other incomplete type created with Named. A Bind which represents an illegal +// type declaration panics. +func Bind(named Type, underlying Type) Type { + stub := named.(*rtype) + typ := underlying.(*rtype) + if typ.tflag&tflagIncomplete != 0 { + panic("reflect.Bind: incomplete underlying type") + } + + var complete *rtype + switch typ.Kind() { + case Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, + Uint64, Uintptr, Float32, Float64, Complex64, Complex128, String, + UnsafePointer: + t := *typ + complete = &t + case Array: + t := *(*arrayType)(unsafe.Pointer(typ)) + complete = (*rtype)(unsafe.Pointer(&t)) + case Chan: + t := *(*chanType)(unsafe.Pointer(typ)) + complete = (*rtype)(unsafe.Pointer(&t)) + case Func: + // Model complete as a funcType followed directly in memory by an array of + // rtype pointers for the arg/return types. + tp := (*funcType)(unsafe.Pointer(typ)) + in := tp.in() + out := tp.out() + n := len(in) + len(out) + ft := New(StructOf([]StructField{ + {Name: "B", Type: TypeOf(funcType{})}, + {Name: "A", Type: ArrayOf(n, TypeOf((*rtype)(nil)))}, + })) + base := (*funcType)(unsafe.Pointer(ft.Elem().Field(0).UnsafeAddr())) + *base = *tp + copy(ft.Elem().Field(1).Slice(0, len(in)).Interface().([]*rtype), in) + copy(ft.Elem().Field(1).Slice(len(in), n).Interface().([]*rtype), out) + complete = (*rtype)(unsafe.Pointer(base)) + case Interface: + // There is no InterfaceOf yet, see #4146. + panic("reflect.Bind: binding to interface type is not supported") + case Map: + t := *(*mapType)(unsafe.Pointer(typ)) + complete = (*rtype)(unsafe.Pointer(&t)) + case Ptr: + t := *(*ptrType)(unsafe.Pointer(typ)) + complete = (*rtype)(unsafe.Pointer(&t)) + case Slice: + t := *(*sliceType)(unsafe.Pointer(typ)) + complete = (*rtype)(unsafe.Pointer(&t)) + case Struct: + t := *(*structType)(unsafe.Pointer(typ)) + complete = (*rtype)(unsafe.Pointer(&t)) + } + complete.hash = stub.hash + complete.tflag &^= tflagExtraStar | tflagUncommon + complete.tflag |= tflagNamed + complete.str = stub.str + + updateTypes(make(map[*rtype]struct{}), stub, complete, complete) + + return complete +} + +// updateTypes replaces each occurrence of find in typ with repl. The seen map +// is used to avoid infinite recursion. +func updateTypes(seen map[*rtype]struct{}, find, repl, typ *rtype) { + if _, ok := seen[typ]; ok { + return + } + seen[typ] = struct{}{} + if typ.tflag&tflagIncomplete != 0 { + panic("reflect.Bind: incomplete underlying type") + } + + switch typ.Kind() { + case Bool, Int, Int8, Int16, Int32, Int64, Uint, Uint8, Uint16, Uint32, + Uint64, Uintptr, Float32, Float64, Complex64, Complex128, Interface, String, + UnsafePointer: + case Array: + tp := (*arrayType)(unsafe.Pointer(typ)) + updateTypes(seen, find, repl, tp.elem) + case Chan: + tp := (*chanType)(unsafe.Pointer(typ)) + if tp.elem == find { + // Size restriction, see ChanOf. + if repl.size >= 1<<16 { + panic("reflect.Bind: channel element size too large") + } + lookupCache.Delete(cacheKey{Chan, tp.elem, nil, tp.dir}) + tp.elem = repl + lookupCache.Store(cacheKey{Chan, tp.elem, nil, tp.dir}, &tp.rtype) + } else { + updateTypes(seen, find, repl, tp.elem) + } + case Func: + tp := (*funcType)(unsafe.Pointer(typ)) + in := tp.in() + out := tp.out() + for i, t := range in { + if t == find { + in[i] = repl + } else { + updateTypes(seen, find, repl, t) + } + } + for i, t := range out { + if t == find { + out[i] = repl + } else { + updateTypes(seen, find, repl, t) + } + } + // No need to update function cache as hash should remain the same + case Map: + tp := (*mapType)(unsafe.Pointer(typ)) + updateTypes(seen, find, repl, tp.key) + tp.hasher = func(p unsafe.Pointer, seed uintptr) uintptr { + return typehash(tp.key, p, seed) + } + if tp.elem == find { + lookupCache.Delete(cacheKey{Map, tp.key, tp.elem, 0}) + tp.elem = repl + // Set some delayed fields, see MapOf + tp.bucket = bucketOf(tp.key, tp.elem) + if tp.elem.size > maxValSize { + tp.valuesize = uint8(ptrSize) + tp.flags |= 2 + } else { + tp.valuesize = uint8(tp.elem.size) + } + tp.bucketsize = uint16(tp.bucket.size) + lookupCache.Store(cacheKey{Map, tp.key, tp.elem, 0}, &tp.rtype) + } else { + updateTypes(seen, find, repl, tp.elem) + } + case Ptr: + tp := (*ptrType)(unsafe.Pointer(typ)) + if tp.elem == find { + ptrMap.Delete(tp.elem) + tp.elem = repl + ptrMap.Store(tp.elem, tp) + } else { + updateTypes(seen, find, repl, tp.elem) + } + case Slice: + tp := (*sliceType)(unsafe.Pointer(typ)) + if tp.elem == find { + lookupCache.Delete(cacheKey{Slice, tp.elem, nil, 0}) + tp.elem = repl + lookupCache.Store(cacheKey{Slice, tp.elem, nil, 0}, &tp.rtype) + } else { + updateTypes(seen, find, repl, tp.elem) + } + case Struct: + tp := (*structType)(unsafe.Pointer(typ)) + for i, f := range tp.fields { + if f.typ == find { + tp.fields[i].typ = repl + } else { + updateTypes(seen, find, repl, f.typ) + } + } + // No need to update struct cache as hash should remain the same + } +} + func appendVarint(x []byte, v uintptr) []byte { for ; v >= 0x80; v >>= 7 { x = append(x, byte(v|0x80)) diff --git a/src/runtime/type.go b/src/runtime/type.go index 52b6cb30b445f..28b006cd5b96d 100644 --- a/src/runtime/type.go +++ b/src/runtime/type.go @@ -22,6 +22,7 @@ const ( tflagExtraStar tflag = 1 << 1 tflagNamed tflag = 1 << 2 tflagRegularMemory tflag = 1 << 3 // equal and hash can treat values of this type as a single region of t.size bytes + tflagIncomplete tflag = 1 << 4 ) // Needs to be in sync with ../cmd/link/internal/ld/decodesym.go:/^func.commonsize,