Skip to content

Commit

Permalink
Update for recent C API changes (#50)
Browse files Browse the repository at this point in the history
* Update for recent C API changes

Updating these bindings to account for bytecodealliance/wasmtime#2579

* Appease go vet
  • Loading branch information
alexcrichton authored Jan 14, 2021
1 parent fa7456e commit 16266bb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 72 deletions.
93 changes: 46 additions & 47 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ type Caller struct {
// which can't be garbage collected.
type newMapEntry struct {
callback func(*Caller, []Val) ([]Val, *Trap)
nparams int
results []*ValType
}

Expand Down Expand Up @@ -73,7 +72,6 @@ func NewFunc(
idx := gNewMapSlab.allocate()
gNewMap[idx] = newMapEntry{
callback: f,
nparams: len(ty.Params()),
results: ty.Results(),
}
gLock.Unlock()
Expand All @@ -95,8 +93,8 @@ func goTrampolineNew(
caller_id C.size_t,
callerPtr *C.wasmtime_caller_t,
env C.size_t,
argsPtr *C.wasm_val_t,
resultsPtr *C.wasm_val_t,
argsPtr *C.wasm_val_vec_t,
resultsPtr *C.wasm_val_vec_t,
) *C.wasm_trap_t {
idx := int(env)
gLock.Lock()
Expand All @@ -107,9 +105,9 @@ func goTrampolineNew(
caller := &Caller{ptr: callerPtr, freelist: freelist}
defer func() { caller.ptr = nil }()

params := make([]Val, entry.nparams)
params := make([]Val, int(argsPtr.size))
var val C.wasm_val_t
base := unsafe.Pointer(argsPtr)
base := unsafe.Pointer(argsPtr.data)
for i := 0; i < len(params); i++ {
ptr := (*C.wasm_val_t)(unsafe.Pointer(uintptr(base) + uintptr(i)*unsafe.Sizeof(val)))
params[i] = mkVal(ptr, freelist)
Expand Down Expand Up @@ -144,7 +142,7 @@ func goTrampolineNew(
return trap.ptr()
}

base = unsafe.Pointer(resultsPtr)
base = unsafe.Pointer(resultsPtr.data)
for i := 0; i < len(results); i++ {
ptr := (*C.wasm_val_t)(unsafe.Pointer(uintptr(base) + uintptr(i)*unsafe.Sizeof(val)))
C.wasm_val_copy(ptr, results[i].ptr())
Expand Down Expand Up @@ -274,8 +272,8 @@ func goTrampolineWrap(
caller_id C.size_t,
callerPtr *C.wasmtime_caller_t,
env C.size_t,
argsPtr *C.wasm_val_t,
resultsPtr *C.wasm_val_t,
argsPtr *C.wasm_val_vec_t,
resultsPtr *C.wasm_val_vec_t,
) *C.wasm_trap_t {
// Convert all our parameters to `[]reflect.Value`, taking special care
// for `*Caller` but otherwise reading everything through `Val`.
Expand All @@ -291,7 +289,7 @@ func goTrampolineWrap(

ty := entry.callback.Type()
params := make([]reflect.Value, ty.NumIn())
base := unsafe.Pointer(argsPtr)
base := unsafe.Pointer(argsPtr.data)
var raw C.wasm_val_t
for i := 0; i < len(params); i++ {
if ty.In(i) == reflect.TypeOf(caller) {
Expand Down Expand Up @@ -321,7 +319,7 @@ func goTrampolineWrap(

// And now we write all the results into memory depending on the type
// of value that was returned.
base = unsafe.Pointer(resultsPtr)
base = unsafe.Pointer(resultsPtr.data)
for _, result := range results {
ptr := (*C.wasm_val_t)(base)
switch val := result.Interface().(type) {
Expand Down Expand Up @@ -446,51 +444,49 @@ func (f *Func) Call(args ...interface{}) (interface{}, error) {
if len(args) > len(params) {
return nil, errors.New("too many arguments provided")
}
paramsRaw := make([]C.wasm_val_t, len(args))
synthesizedParams := make([]Val, 0)
paramsVec := C.wasm_val_vec_t{}
C.wasm_val_vec_new_uninitialized(&paramsVec, C.size_t(len(args)))
for i, param := range args {
var rawVal Val
switch val := param.(type) {
case int:
switch params[i].Kind() {
case KindI32:
paramsRaw[i] = *ValI32(int32(val)).ptr()
rawVal = ValI32(int32(val))
case KindI64:
paramsRaw[i] = *ValI64(int64(val)).ptr()
rawVal = ValI64(int64(val))
default:
return nil, errors.New("integer provided for non-integer argument")
}
case int32:
paramsRaw[i] = *ValI32(val).ptr()
rawVal = ValI32(val)
case int64:
paramsRaw[i] = *ValI64(val).ptr()
rawVal = ValI64(val)
case float32:
paramsRaw[i] = *ValF32(val).ptr()
rawVal = ValF32(val)
case float64:
paramsRaw[i] = *ValF64(val).ptr()
rawVal = ValF64(val)
case *Func:
ffi := ValFuncref(val)
paramsRaw[i] = *ffi.ptr()
synthesizedParams = append(synthesizedParams, ffi)
rawVal = ValFuncref(val)
case Val:
paramsRaw[i] = *val.ptr()
rawVal = val

default:
ffi := ValExternref(val)
paramsRaw[i] = *ffi.ptr()
synthesizedParams = append(synthesizedParams, ffi)
rawVal = ValExternref(val)
}
}

resultsRaw := make([]C.wasm_val_t, f.ResultArity())
base := unsafe.Pointer(paramsVec.data)
ptr := rawVal.ptr()
C.wasm_val_copy(
(*C.wasm_val_t)(unsafe.Pointer(uintptr(base)+unsafe.Sizeof(*ptr)*uintptr(i))),
ptr,
)
runtime.KeepAlive(rawVal)
}

var paramsPtr, resultsPtr *C.wasm_val_t
resultsVec := C.wasm_val_vec_t{}
C.wasm_val_vec_new_uninitialized(&resultsVec, C.size_t(f.ResultArity()))
var trap *C.wasm_trap_t
if len(paramsRaw) > 0 {
paramsPtr = &paramsRaw[0]
}
if len(resultsRaw) > 0 {
resultsPtr = &resultsRaw[0]
}

// Use our `freelist` as an anchor to get an identifier which our C
// shim shoves into thread-local storage and then pops out on the
Expand All @@ -502,17 +498,14 @@ func (f *Func) Call(args ...interface{}) (interface{}, error) {

err := C.go_wasmtime_func_call(
f.ptr(),
paramsPtr,
C.size_t(len(paramsRaw)),
resultsPtr,
C.size_t(len(resultsRaw)),
&paramsVec,
&resultsVec,
&trap,
caller_id,
)
runtime.KeepAlive(f)
runtime.KeepAlive(paramsRaw)
runtime.KeepAlive(args)
runtime.KeepAlive(synthesizedParams)
C.wasm_val_vec_delete(&paramsVec)

// Clear our thread's caller id from the global maps now that the call
// is finished.
Expand Down Expand Up @@ -543,15 +536,21 @@ func (f *Func) Call(args ...interface{}) (interface{}, error) {
return nil, wrappedTrap
}

if len(resultsRaw) == 0 {
if resultsVec.size == 0 {
return nil, nil
} else if len(resultsRaw) == 1 {
return takeVal(&resultsRaw[0], f.freelist).Get(), nil
} else if resultsVec.size == 1 {
ret := mkVal(resultsVec.data, f.freelist).Get()
C.wasm_val_vec_delete(&resultsVec)
return ret, nil
} else {
results := make([]Val, len(resultsRaw))
for i := 0; i < len(resultsRaw); i++ {
results[i] = takeVal(&resultsRaw[i], f.freelist)
results := make([]Val, int(resultsVec.size))
base := unsafe.Pointer(resultsVec.data)
var val C.wasm_val_t
for i := 0; i < int(resultsVec.size); i++ {
ptr := (*C.wasm_val_t)(unsafe.Pointer(uintptr(base) + unsafe.Sizeof(val)*uintptr(i)))
results[i] = mkVal(ptr, f.freelist)
}
C.wasm_val_vec_delete(&resultsVec)
return results, nil
}

Expand Down
17 changes: 7 additions & 10 deletions instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,25 @@ type Instance struct {
// This will also run the `start` function of the instance, returning an error
// if it traps.
func NewInstance(store *Store, module *Module, imports []*Extern) (*Instance, error) {
importsRaw := make([]*C.wasm_extern_t, len(imports))
importsRaw := C.wasm_extern_vec_t{}
C.wasm_extern_vec_new_uninitialized(&importsRaw, C.size_t(len(imports)))
base := unsafe.Pointer(importsRaw.data)
for i, imp := range imports {
importsRaw[i] = imp.ptr()
}
var importsRawPtr **C.wasm_extern_t
if len(imports) > 0 {
importsRawPtr = &importsRaw[0]
ptr := C.wasm_extern_copy(imp.ptr())
*(**C.wasm_extern_t)(unsafe.Pointer(uintptr(base) + unsafe.Sizeof(ptr)*uintptr(i))) = ptr
}
var trap *C.wasm_trap_t
var ptr *C.wasm_instance_t
err := C.wasmtime_instance_new(
store.ptr(),
module.ptr(),
importsRawPtr,
C.size_t(len(imports)),
&importsRaw,
&ptr,
&trap,
)
runtime.KeepAlive(store)
runtime.KeepAlive(module)
runtime.KeepAlive(imports)
runtime.KeepAlive(importsRaw)
C.wasm_extern_vec_delete(&importsRaw)
if err != nil {
return nil, mkError(err)
}
Expand Down
20 changes: 9 additions & 11 deletions shims.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ __thread size_t caller_id;
static wasm_trap_t* trampoline(
const wasmtime_caller_t *caller,
void *env,
const wasm_val_t *args,
wasm_val_t *results
const wasm_val_vec_t *args,
wasm_val_vec_t *results
) {
return goTrampolineNew(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_t*) args, results);
return goTrampolineNew(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_vec_t*) args, results);
}

static wasm_trap_t* wrap_trampoline(
const wasmtime_caller_t *caller,
void *env,
const wasm_val_t *args,
wasm_val_t *results
const wasm_val_vec_t *args,
wasm_val_vec_t *results
) {
return goTrampolineWrap(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_t*) args, results);
return goTrampolineWrap(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_vec_t*) args, results);
}

wasm_func_t *c_func_new_with_env(wasm_store_t *store, wasm_functype_t *ty, size_t env, int wrap) {
Expand All @@ -29,16 +29,14 @@ wasm_func_t *c_func_new_with_env(wasm_store_t *store, wasm_functype_t *ty, size_

wasmtime_error_t *go_wasmtime_func_call(
wasm_func_t *func,
const wasm_val_t *args,
size_t num_args,
wasm_val_t *results,
size_t num_results,
const wasm_val_vec_t *args,
wasm_val_vec_t *results,
wasm_trap_t **trap,
size_t go_id
) {
size_t prev_caller_id = caller_id;
caller_id = go_id;
wasmtime_error_t *ret = wasmtime_func_call(func, args, num_args, results, num_results, trap);
wasmtime_error_t *ret = wasmtime_func_call(func, args, results, trap);
caller_id = prev_caller_id;
return ret;
}
Expand Down
6 changes: 2 additions & 4 deletions shims.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ void go_externref_new_with_finalizer(
);
wasmtime_error_t *go_wasmtime_func_call(
wasm_func_t *func,
const wasm_val_t *args,
size_t num_args,
wasm_val_t *results,
size_t num_results,
const wasm_val_vec_t *args,
wasm_val_vec_t *results,
wasm_trap_t **trap,
size_t go_id
);
Expand Down

0 comments on commit 16266bb

Please sign in to comment.