Skip to content

Commit

Permalink
wasm: restrict supported types in go:wasmimport function signatures
Browse files Browse the repository at this point in the history
Changes the set of types supported in functions declared with the
go:wasmimport directive to only allow 32 bits and 64 bits integers
and floats, as well as unsafe.Pointer in parameters only. Both the
compiler code and the standard library are updated because the new
restrictions require modifying the use of go:wasmimport in the
syscall and runtime packages.

In preparation of enabling packages outside of the standard library
to use the go:wasmimport directive, the error messages are modified
to carry more context and use ErrorfAt instead of Fatalf to avoid
printing the compiler stack trace when a function with an invalid
signature is encountered.

Fixes #59156

Change-Id: Ied8317f8ead9c28f0297060ac35a5b5255ab49db
Reviewed-on: https://go-review.googlesource.com/c/go/+/483415
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Reviewed-by: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
  • Loading branch information
achille-roussel authored and johanbrandhorst committed Apr 13, 2023
1 parent 92a3cb9 commit 53c0158
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 73 deletions.
45 changes: 35 additions & 10 deletions src/cmd/compile/internal/ssagen/abi.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,23 +366,48 @@ func CreateWasmImportWrapper(fn *ir.Func) bool {
return true
}

func toWasmFields(result *abi.ABIParamResultInfo, abiParams []abi.ABIParamAssignment) []obj.WasmField {
func paramsToWasmFields(f *ir.Func, result *abi.ABIParamResultInfo, abiParams []abi.ABIParamAssignment) []obj.WasmField {
wfs := make([]obj.WasmField, len(abiParams))
for i, p := range abiParams {
t := p.Type
switch {
case t.IsInteger() && t.Size() == 4:
switch t.Kind() {
case types.TINT32, types.TUINT32:
wfs[i].Type = obj.WasmI32
case t.IsInteger() && t.Size() == 8:
case types.TINT64, types.TUINT64:
wfs[i].Type = obj.WasmI64
case t.IsFloat() && t.Size() == 4:
case types.TFLOAT32:
wfs[i].Type = obj.WasmF32
case t.IsFloat() && t.Size() == 8:
case types.TFLOAT64:
wfs[i].Type = obj.WasmF64
case t.IsPtr():
case types.TUNSAFEPTR:
wfs[i].Type = obj.WasmPtr
default:
base.Fatalf("wasm import has bad function signature")
base.ErrorfAt(f.Pos(), 0, "go:wasmimport %s %s: unsupported parameter type %s", f.WasmImport.Module, f.WasmImport.Name, t.String())
}
wfs[i].Offset = p.FrameOffset(result)
}
return wfs
}

func resultsToWasmFields(f *ir.Func, result *abi.ABIParamResultInfo, abiParams []abi.ABIParamAssignment) []obj.WasmField {
if len(abiParams) > 1 {
base.ErrorfAt(f.Pos(), 0, "go:wasmimport %s %s: too many return values", f.WasmImport.Module, f.WasmImport.Name)
return nil
}
wfs := make([]obj.WasmField, len(abiParams))
for i, p := range abiParams {
t := p.Type
switch t.Kind() {
case types.TINT32, types.TUINT32:
wfs[i].Type = obj.WasmI32
case types.TINT64, types.TUINT64:
wfs[i].Type = obj.WasmI64
case types.TFLOAT32:
wfs[i].Type = obj.WasmF32
case types.TFLOAT64:
wfs[i].Type = obj.WasmF64
default:
base.ErrorfAt(f.Pos(), 0, "go:wasmimport %s %s: unsupported result type %s", f.WasmImport.Module, f.WasmImport.Name, t.String())
}
wfs[i].Offset = p.FrameOffset(result)
}
Expand Down Expand Up @@ -419,8 +444,8 @@ func setupWasmABI(f *ir.Func) {
// (import "a_module" "add" (func (param i32 i32) (result i32)))
abiConfig := AbiForBodylessFuncStackMap(f)
abiInfo := abiConfig.ABIAnalyzeFuncType(f.Type().FuncType())
wi.Params = toWasmFields(abiInfo, abiInfo.InParams())
wi.Results = toWasmFields(abiInfo, abiInfo.OutParams())
wi.Params = paramsToWasmFields(f, abiInfo, abiInfo.InParams())
wi.Results = resultsToWasmFields(f, abiInfo, abiInfo.OutParams())
}
f.LSym.Func().WasmImport = &wi
}
34 changes: 17 additions & 17 deletions src/runtime/os_wasip1.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,31 +48,31 @@ func exit(code int32)

//go:wasmimport wasi_snapshot_preview1 args_get
//go:noescape
func args_get(argv *uintptr32, argvBuf *byte) errno
func args_get(argv, argvBuf unsafe.Pointer) errno

//go:wasmimport wasi_snapshot_preview1 args_sizes_get
//go:noescape
func args_sizes_get(argc *size, argvBufLen *size) errno
func args_sizes_get(argc, argvBufLen unsafe.Pointer) errno

//go:wasmimport wasi_snapshot_preview1 clock_time_get
//go:noescape
func clock_time_get(clock_id clockid, precision timestamp, time *timestamp) errno
func clock_time_get(clock_id clockid, precision timestamp, time unsafe.Pointer) errno

//go:wasmimport wasi_snapshot_preview1 environ_get
//go:noescape
func environ_get(environ *uintptr32, environBuf *byte) errno
func environ_get(environ, environBuf unsafe.Pointer) errno

//go:wasmimport wasi_snapshot_preview1 environ_sizes_get
//go:noescape
func environ_sizes_get(environCount *size, environBufLen *size) errno
func environ_sizes_get(environCount, environBufLen unsafe.Pointer) errno

//go:wasmimport wasi_snapshot_preview1 fd_write
//go:noescape
func fd_write(fd int32, iovs *iovec, iovsLen size, nwritten *size) errno
func fd_write(fd int32, iovs unsafe.Pointer, iovsLen size, nwritten unsafe.Pointer) errno

//go:wasmimport wasi_snapshot_preview1 random_get
//go:noescape
func random_get(buf *byte, bufLen size) errno
func random_get(buf unsafe.Pointer, bufLen size) errno

type eventtype = uint8

Expand Down Expand Up @@ -140,15 +140,15 @@ func (u *subscriptionUnion) subscriptionClock() *subscriptionClock {

//go:wasmimport wasi_snapshot_preview1 poll_oneoff
//go:noescape
func poll_oneoff(in *subscription, out *event, nsubscriptions size, nevents *size) errno
func poll_oneoff(in, out unsafe.Pointer, nsubscriptions size, nevents unsafe.Pointer) errno

func write1(fd uintptr, p unsafe.Pointer, n int32) int32 {
iov := iovec{
buf: uintptr32(uintptr(p)),
bufLen: size(n),
}
var nwritten size
if fd_write(int32(fd), &iov, 1, &nwritten) != 0 {
if fd_write(int32(fd), unsafe.Pointer(&iov), 1, unsafe.Pointer(&nwritten)) != 0 {
throw("fd_write failed")
}
return int32(nwritten)
Expand All @@ -167,13 +167,13 @@ func usleep(usec uint32) {
subscription.timeout = timestamp(usec) * 1e3
subscription.precision = 1e3

if poll_oneoff(&in, &out, 1, &nevents) != 0 {
if poll_oneoff(unsafe.Pointer(&in), unsafe.Pointer(&out), 1, unsafe.Pointer(&nevents)) != 0 {
throw("wasi_snapshot_preview1.poll_oneoff")
}
}

func getRandomData(r []byte) {
if random_get(&r[0], size(len(r))) != 0 {
if random_get(unsafe.Pointer(&r[0]), size(len(r))) != 0 {
throw("random_get failed")
}
}
Expand All @@ -182,15 +182,15 @@ func goenvs() {
// arguments
var argc size
var argvBufLen size
if args_sizes_get(&argc, &argvBufLen) != 0 {
if args_sizes_get(unsafe.Pointer(&argc), unsafe.Pointer(&argvBufLen)) != 0 {
throw("args_sizes_get failed")
}

argslice = make([]string, argc)
if argc > 0 {
argv := make([]uintptr32, argc)
argvBuf := make([]byte, argvBufLen)
if args_get(&argv[0], &argvBuf[0]) != 0 {
if args_get(unsafe.Pointer(&argv[0]), unsafe.Pointer(&argvBuf[0])) != 0 {
throw("args_get failed")
}

Expand All @@ -207,15 +207,15 @@ func goenvs() {
// environment
var environCount size
var environBufLen size
if environ_sizes_get(&environCount, &environBufLen) != 0 {
if environ_sizes_get(unsafe.Pointer(&environCount), unsafe.Pointer(&environBufLen)) != 0 {
throw("environ_sizes_get failed")
}

envs = make([]string, environCount)
if environCount > 0 {
environ := make([]uintptr32, environCount)
environBuf := make([]byte, environBufLen)
if environ_get(&environ[0], &environBuf[0]) != 0 {
if environ_get(unsafe.Pointer(&environ[0]), unsafe.Pointer(&environBuf[0])) != 0 {
throw("environ_get failed")
}

Expand All @@ -236,15 +236,15 @@ func walltime() (sec int64, nsec int32) {

func walltime1() (sec int64, nsec int32) {
var time timestamp
if clock_time_get(clockRealtime, 0, &time) != 0 {
if clock_time_get(clockRealtime, 0, unsafe.Pointer(&time)) != 0 {
throw("clock_time_get failed")
}
return int64(time / 1000000000), int32(time % 1000000000)
}

func nanotime1() int64 {
var time timestamp
if clock_time_get(clockMonotonic, 0, &time) != 0 {
if clock_time_get(clockMonotonic, 0, unsafe.Pointer(&time)) != 0 {
throw("clock_time_get failed")
}
return int64(time)
Expand Down
Loading

0 comments on commit 53c0158

Please sign in to comment.