Skip to content

Commit

Permalink
Shorten datatype conversions in enums.go with generics. (#321)
Browse files Browse the repository at this point in the history
  • Loading branch information
thetorpedodog authored Jun 17, 2024
1 parent 02ede09 commit ff685d3
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 192 deletions.
4 changes: 1 addition & 3 deletions common.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package tiledb

import (
"reflect"
"unsafe"
)

Expand All @@ -17,6 +16,5 @@ type scalarType interface {

// slicePtr gives you an unsafe pointer to the start of a slice.
func slicePtr[T any](slc []T) unsafe.Pointer {
hdr := (*reflect.SliceHeader)(unsafe.Pointer(&slc))
return unsafe.Pointer(hdr.Data)
return unsafe.Pointer(unsafe.SliceData(slc))
}
254 changes: 65 additions & 189 deletions enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"fmt"
"reflect"
"strconv"
"time"
"unsafe"
)

Expand Down Expand Up @@ -265,240 +266,115 @@ func (d Datatype) Size() uint64 {
func (d Datatype) MakeSlice(numElements uint64) (interface{}, unsafe.Pointer, error) {
switch d {
case TILEDB_INT8:
slice := make([]int8, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[int8](numElements)
case TILEDB_INT16:
slice := make([]int16, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[int16](numElements)
case TILEDB_INT32:
slice := make([]int32, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[int32](numElements)
case TILEDB_INT64, TILEDB_DATETIME_YEAR, TILEDB_DATETIME_MONTH, TILEDB_DATETIME_WEEK, TILEDB_DATETIME_DAY, TILEDB_DATETIME_HR, TILEDB_DATETIME_MIN, TILEDB_DATETIME_SEC, TILEDB_DATETIME_MS, TILEDB_DATETIME_US, TILEDB_DATETIME_NS, TILEDB_DATETIME_PS, TILEDB_DATETIME_FS, TILEDB_DATETIME_AS, TILEDB_TIME_HR, TILEDB_TIME_MIN, TILEDB_TIME_SEC, TILEDB_TIME_MS, TILEDB_TIME_US, TILEDB_TIME_NS, TILEDB_TIME_PS, TILEDB_TIME_FS, TILEDB_TIME_AS:
slice := make([]int64, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[int64](numElements)
case TILEDB_UINT8, TILEDB_CHAR, TILEDB_STRING_ASCII, TILEDB_STRING_UTF8, TILEDB_BLOB, TILEDB_GEOM_WKB, TILEDB_GEOM_WKT:
slice := make([]uint8, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[uint8](numElements)
case TILEDB_UINT16, TILEDB_STRING_UTF16, TILEDB_STRING_UCS2:
slice := make([]uint16, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[uint16](numElements)
case TILEDB_UINT32, TILEDB_STRING_UTF32, TILEDB_STRING_UCS4:
slice := make([]uint32, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[uint32](numElements)
case TILEDB_UINT64:
slice := make([]uint64, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[uint64](numElements)
case TILEDB_FLOAT32:
slice := make([]float32, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[float32](numElements)
case TILEDB_FLOAT64:
slice := make([]float64, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[float64](numElements)
case TILEDB_BOOL:
slice := make([]bool, numElements)
return slice, unsafe.Pointer(&slice[0]), nil

return makeSlice[bool](numElements)
default:
return nil, nil, fmt.Errorf("error making datatype slice; unrecognized datatype: %d", d)
}
}

// makeSlice makes a slice and returns it as well as a pointer to its start.
// Its return type matches d.MakeSlice for convenience.
func makeSlice[T any](numElements uint64) (any, unsafe.Pointer, error) {
slice := make([]T, numElements)
return slice, slicePtr(slice), nil
}

// GetValue gets value stored in a void pointer for this data type.
func (d Datatype) GetValue(valueNum uint, cvalue unsafe.Pointer) (interface{}, error) {
switch d {
case TILEDB_INT8:
if cvalue == nil {
return int8(0), nil
}
if valueNum > 1 {
tmpValue := make([]int8, valueNum)
tmpslice := (*[1 << 46]C.int8_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = int8(s)
}
return tmpValue, nil
}
return *(*int8)(cvalue), nil
return getValueInternal[int8](valueNum, cvalue)
case TILEDB_INT16:
if cvalue == nil {
return int16(0), nil
}
if valueNum > 1 {
tmpValue := make([]int16, valueNum)
tmpslice := (*[1 << 46]C.int16_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = int16(s)
}
return tmpValue, nil
}
return *(*int16)(cvalue), nil
return getValueInternal[int16](valueNum, cvalue)
case TILEDB_INT32:
if cvalue == nil {
return int32(0), nil
}
if valueNum > 1 {
tmpValue := make([]int32, valueNum)
tmpslice := (*[1 << 46]C.int32_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = int32(s)
}
return tmpValue, nil
}
return *(*int32)(cvalue), nil
return getValueInternal[int32](valueNum, cvalue)
case TILEDB_INT64:
if cvalue == nil {
return int64(0), nil
}
if valueNum > 1 {
tmpValue := make([]int64, valueNum)
tmpslice := (*[1 << 46]C.int64_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = int64(s)
}
return tmpValue, nil
}
return *(*int64)(cvalue), nil
return getValueInternal[int64](valueNum, cvalue)
case TILEDB_UINT8, TILEDB_BLOB, TILEDB_GEOM_WKB, TILEDB_GEOM_WKT:
if cvalue == nil {
return uint8(0), nil
}
if valueNum > 1 {
tmpValue := make([]uint8, valueNum)
tmpslice := (*[1 << 46]C.uint8_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = uint8(s)
}
return tmpValue, nil
}
return *(*uint8)(cvalue), nil
return getValueInternal[uint8](valueNum, cvalue)
case TILEDB_UINT16:
if cvalue == nil {
return uint16(0), nil
}
if valueNum > 1 {
tmpValue := make([]uint16, valueNum)
tmpslice := (*[1 << 46]C.uint16_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = uint16(s)
}
return tmpValue, nil
}
return *(*uint16)(cvalue), nil
return getValueInternal[uint16](valueNum, cvalue)
case TILEDB_UINT32:
if cvalue == nil {
return uint32(0), nil
}
if valueNum > 1 {
tmpValue := make([]uint32, valueNum)
tmpslice := (*[1 << 46]C.uint32_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = uint32(s)
}
return tmpValue, nil
}
return *(*uint32)(cvalue), nil
return getValueInternal[uint32](valueNum, cvalue)
case TILEDB_UINT64:
if cvalue == nil {
return uint64(0), nil
}
if valueNum > 1 {
tmpValue := make([]uint64, valueNum)
tmpslice := (*[1 << 46]C.uint64_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = uint64(s)
}
return tmpValue, nil
}
return *(*uint64)(cvalue), nil
return getValueInternal[uint64](valueNum, cvalue)
case TILEDB_FLOAT32:
if cvalue == nil {
return float32(0), nil
}
if valueNum > 1 {
tmpValue := make([]float32, valueNum)
tmpslice := (*[1 << 46]C.float)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = float32(s)
}
return tmpValue, nil
}
return *(*float32)(cvalue), nil
return getValueInternal[float32](valueNum, cvalue)
case TILEDB_FLOAT64:
if cvalue == nil {
return float64(0), nil
}
if valueNum > 1 {
tmpValue := make([]float64, valueNum)
tmpslice := (*[1 << 46]C.double)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = float64(s)
}
return tmpValue, nil
}
return *(*float64)(cvalue), nil
case TILEDB_CHAR:
if cvalue == nil || valueNum == 0 {
return "", nil
}
tmpslice := (*[1 << 46]C.char)(cvalue)[:valueNum:valueNum]
// TODO: Handle overflow from unsigned conversion
return C.GoStringN(&tmpslice[0], C.int(valueNum))[0:valueNum], nil
case TILEDB_STRING_ASCII:
if cvalue == nil || valueNum == 0 {
return "", nil
}
tmpslice := (*[1 << 46]C.char)(cvalue)[:valueNum:valueNum]
// TODO: Handle overflow from unsigned conversion
return C.GoStringN(&tmpslice[0], C.int(valueNum))[0:valueNum], nil
case TILEDB_STRING_UTF8:
if cvalue == nil || valueNum == 0 {
return "", nil
}
tmpslice := (*[1 << 46]C.char)(cvalue)[:valueNum:valueNum]
// TODO: Handle overflow from unsigned conversion
return C.GoStringN(&tmpslice[0], C.int(valueNum))[0:valueNum], nil
return getValueInternal[float64](valueNum, cvalue)
case TILEDB_CHAR, TILEDB_STRING_ASCII, TILEDB_STRING_UTF8:
return C.GoStringN((*C.char)(cvalue), C.int(valueNum)), nil
case TILEDB_DATETIME_YEAR, TILEDB_DATETIME_MONTH, TILEDB_DATETIME_WEEK,
TILEDB_DATETIME_DAY, TILEDB_DATETIME_HR, TILEDB_DATETIME_MIN,
TILEDB_DATETIME_SEC, TILEDB_DATETIME_MS, TILEDB_DATETIME_US,
TILEDB_DATETIME_NS, TILEDB_DATETIME_PS, TILEDB_DATETIME_FS,
TILEDB_DATETIME_AS, TILEDB_TIME_HR, TILEDB_TIME_MIN, TILEDB_TIME_SEC, TILEDB_TIME_MS, TILEDB_TIME_US, TILEDB_TIME_NS, TILEDB_TIME_PS, TILEDB_TIME_FS, TILEDB_TIME_AS:
if valueNum > 1 {
return nil, fmt.Errorf("Unrecognized value type: %d", d)
} else {
if cvalue == nil {
return int64(0), nil
}
var timestamp interface{} = *(*int16)(cvalue)
return GetTimeFromTimestamp(d, timestamp.(int64)), nil
return nil, fmt.Errorf("only 1 timestamp may be returned, not %d", d)
}
if cvalue == nil {
return time.Time{}, nil
}
timestamp := *(*int64)(cvalue)
return GetTimeFromTimestamp(d, timestamp), nil
case TILEDB_BOOL:
// We handle this differently to ensure that our bools are always in the
// canonical form (true/1 or false/0).
if cvalue == nil {
return false, nil
}
if valueNum > 1 {
tmpValue := make([]bool, valueNum)
tmpslice := (*[1 << 46]C.int8_t)(cvalue)[:valueNum:valueNum]
for i, s := range tmpslice {
tmpValue[i] = s != 0
}
return tmpValue, nil
bytes := unsafeSlice[byte](cvalue, valueNum)
if valueNum == 1 {
return bytes[0] != 0, nil
}
bools := make([]bool, valueNum)
for i, b := range bytes {
bools[i] = b != 0
}
return *(*int8)(cvalue), nil
return bools, nil
default:
return nil, fmt.Errorf("Unrecognized value type: %d", d)
}
}

// getValueInternal handles the internals of Datatype.GetValue. It returns
// `valueNum` Ts located at `ptr`. As a special case, if valueNum == 1,
// it returns a T itself rather than a []T.
func getValueInternal[T any](valueNum uint, ptr unsafe.Pointer) (any, error) {
var singleValue T
if ptr == nil {
return singleValue, nil
}
if valueNum == 1 {
singleValue = *(*T)(ptr)
return singleValue, nil
}
out := make([]T, valueNum)
inSlice := unsafeSlice[T](ptr, valueNum)
copy(out, inSlice)
return out, nil
}

var tileDBInt, tileDBUint = intUintTypes() // The Datatypes of Go `int` and `uint`.

func intUintTypes() (Datatype, Datatype) {
Expand Down
9 changes: 9 additions & 0 deletions memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,12 @@ func (bb byteBuffer) subSlice(sliceStart unsafe.Pointer, sliceBytes uintptr) []b
startIdx := uintptr(sliceStart) - uintptr(bb.start())
return bb[startIdx:sliceBytes]
}

// unsafeSlice creates a slice pointing at the given memory.
func unsafeSlice[T any](ptr unsafe.Pointer, length uint) []T {
if ptr == nil {
return nil
}
typedPtr := (*T)(ptr)
return unsafe.Slice(typedPtr, length)
}

0 comments on commit ff685d3

Please sign in to comment.