Skip to content

Commit

Permalink
Switch from sycall to windows (#295)
Browse files Browse the repository at this point in the history
Where ever possible, use `golang.org/x/sys/windows` instead of `syscall`
(which has been deprecated since go1.11).

Using `windows.LocalFree` requires using `unsafe.Pointer`, which ensures
that the Go garbage collector does not try to free memory pre-maturely
if it was previously declared as a pointer.

Since `syscall.Handle` is part of API for `vhd` package, it was left
unchanged.

For security descriptor functions, switch to using
`windows.SECURITY_DESCRIPTOR` to avoid unnecessary byte manipulation and
panics due to missing input validation and error checking.

Signed-off-by: Hamza El-Saawy <hamzaelsaawy@microsoft.com>
  • Loading branch information
helsaawy authored Aug 7, 2023
1 parent fec52bd commit 9f0d5dc
Show file tree
Hide file tree
Showing 20 changed files with 291 additions and 251 deletions.
33 changes: 15 additions & 18 deletions backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ import (
"io"
"os"
"runtime"
"syscall"
"unicode/utf16"

"github.com/Microsoft/go-winio/internal/fs"
"golang.org/x/sys/windows"
)

//sys backupRead(h syscall.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
//sys backupWrite(h syscall.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite
//sys backupRead(h windows.Handle, b []byte, bytesRead *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupRead
//sys backupWrite(h windows.Handle, b []byte, bytesWritten *uint32, abort bool, processSecurity bool, context *uintptr) (err error) = BackupWrite

const (
BackupData = uint32(iota + 1)
Expand Down Expand Up @@ -104,7 +104,7 @@ func (r *BackupStreamReader) Next() (*BackupHeader, error) {
if err := binary.Read(r.r, binary.LittleEndian, name); err != nil {
return nil, err
}
hdr.Name = syscall.UTF16ToString(name)
hdr.Name = windows.UTF16ToString(name)
}
if wsi.StreamID == BackupSparseBlock {
if err := binary.Read(r.r, binary.LittleEndian, &hdr.Offset); err != nil {
Expand Down Expand Up @@ -205,7 +205,7 @@ func NewBackupFileReader(f *os.File, includeSecurity bool) *BackupFileReader {
// Read reads a backup stream from the file by calling the Win32 API BackupRead().
func (r *BackupFileReader) Read(b []byte) (int, error) {
var bytesRead uint32
err := backupRead(syscall.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
err := backupRead(windows.Handle(r.f.Fd()), b, &bytesRead, false, r.includeSecurity, &r.ctx)
if err != nil {
return 0, &os.PathError{Op: "BackupRead", Path: r.f.Name(), Err: err}
}
Expand All @@ -220,7 +220,7 @@ func (r *BackupFileReader) Read(b []byte) (int, error) {
// the underlying file.
func (r *BackupFileReader) Close() error {
if r.ctx != 0 {
_ = backupRead(syscall.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
_ = backupRead(windows.Handle(r.f.Fd()), nil, nil, true, false, &r.ctx)
runtime.KeepAlive(r.f)
r.ctx = 0
}
Expand All @@ -244,7 +244,7 @@ func NewBackupFileWriter(f *os.File, includeSecurity bool) *BackupFileWriter {
// Write restores a portion of the file using the provided backup stream.
func (w *BackupFileWriter) Write(b []byte) (int, error) {
var bytesWritten uint32
err := backupWrite(syscall.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
err := backupWrite(windows.Handle(w.f.Fd()), b, &bytesWritten, false, w.includeSecurity, &w.ctx)
if err != nil {
return 0, &os.PathError{Op: "BackupWrite", Path: w.f.Name(), Err: err}
}
Expand All @@ -259,7 +259,7 @@ func (w *BackupFileWriter) Write(b []byte) (int, error) {
// close the underlying file.
func (w *BackupFileWriter) Close() error {
if w.ctx != 0 {
_ = backupWrite(syscall.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
_ = backupWrite(windows.Handle(w.f.Fd()), nil, nil, true, false, &w.ctx)
runtime.KeepAlive(w.f)
w.ctx = 0
}
Expand All @@ -271,17 +271,14 @@ func (w *BackupFileWriter) Close() error {
//
// If the file opened was a directory, it cannot be used with Readdir().
func OpenForBackup(path string, access uint32, share uint32, createmode uint32) (*os.File, error) {
winPath, err := syscall.UTF16FromString(path)
if err != nil {
return nil, err
}
h, err := syscall.CreateFile(&winPath[0],
access,
share,
h, err := fs.CreateFile(path,
fs.AccessMask(access),
fs.FileShareMode(share),
nil,
createmode,
syscall.FILE_FLAG_BACKUP_SEMANTICS|syscall.FILE_FLAG_OPEN_REPARSE_POINT,
0)
fs.FileCreationDisposition(createmode),
fs.FILE_FLAG_BACKUP_SEMANTICS|fs.FILE_FLAG_OPEN_REPARSE_POINT,
0,
)
if err != nil {
err = &os.PathError{Op: "open", Path: path, Err: err}
return nil, err
Expand Down
3 changes: 1 addition & 2 deletions backup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package winio
import (
"io"
"os"
"syscall"
"testing"

"golang.org/x/sys/windows"
Expand Down Expand Up @@ -202,7 +201,7 @@ func makeSparseFile() error {
}
defer f.Close()

err = syscall.DeviceIoControl(syscall.Handle(f.Fd()), windows.FSCTL_SET_SPARSE, nil, 0, nil, 0, nil, nil)
err = windows.DeviceIoControl(windows.Handle(f.Fd()), windows.FSCTL_SET_SPARSE, nil, 0, nil, 0, nil, nil)
if err != nil {
return err
}
Expand Down
5 changes: 2 additions & 3 deletions backuptar/tar.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"path/filepath"
"strconv"
"strings"
"syscall"
"time"

"github.com/Microsoft/go-winio"
Expand Down Expand Up @@ -106,7 +105,7 @@ func BasicInfoHeader(name string, size int64, fileInfo *winio.FileBasicInfo) *ta
hdr.PAXRecords[hdrFileAttributes] = fmt.Sprintf("%d", fileInfo.FileAttributes)
hdr.PAXRecords[hdrCreationTime] = formatPAXTime(time.Unix(0, fileInfo.CreationTime.Nanoseconds()))

if (fileInfo.FileAttributes & syscall.FILE_ATTRIBUTE_DIRECTORY) != 0 {
if (fileInfo.FileAttributes & windows.FILE_ATTRIBUTE_DIRECTORY) != 0 {
hdr.Mode |= cISDIR
hdr.Size = 0
hdr.Typeflag = tar.TypeDir
Expand Down Expand Up @@ -396,7 +395,7 @@ func FileInfoFromHeader(hdr *tar.Header) (name string, size int64, fileInfo *win
fileInfo.FileAttributes = uint32(attr)
} else {
if hdr.Typeflag == tar.TypeDir {
fileInfo.FileAttributes |= syscall.FILE_ATTRIBUTE_DIRECTORY
fileInfo.FileAttributes |= windows.FILE_ATTRIBUTE_DIRECTORY
}
}
if creationTimeStr, ok := hdr.PAXRecords[hdrCreationTime]; ok {
Expand Down
5 changes: 3 additions & 2 deletions ea_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@ package winio
import (
"os"
"reflect"
"syscall"
"testing"
"unsafe"

"golang.org/x/sys/windows"
)

var (
Expand Down Expand Up @@ -82,7 +83,7 @@ func Test_SetFileEa(t *testing.T) {
}
defer os.Remove(f.Name())
defer f.Close()
ntdll := syscall.MustLoadDLL("ntdll.dll")
ntdll := windows.MustLoadDLL("ntdll.dll")
ntSetEaFile := ntdll.MustFindProc("NtSetEaFile")
var iosb [2]uintptr
r, _, _ := ntSetEaFile.Call(f.Fd(),
Expand Down
45 changes: 25 additions & 20 deletions file.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import (
"golang.org/x/sys/windows"
)

//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult

//todo (go1.19): switch to [atomic.Bool]

Expand Down Expand Up @@ -52,7 +52,7 @@ func (*timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{}

var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle
var ioCompletionPort windows.Handle

// ioResult contains the result of an asynchronous IO operation.
type ioResult struct {
Expand All @@ -62,12 +62,12 @@ type ioResult struct {

// ioOperation represents an outstanding asynchronous Win32 IO.
type ioOperation struct {
o syscall.Overlapped
o windows.Overlapped
ch chan ioResult
}

func initIO() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
if err != nil {
panic(err)
}
Expand All @@ -78,7 +78,7 @@ func initIO() {
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct {
handle syscall.Handle
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomicBool
Expand All @@ -96,7 +96,7 @@ type deadlineHandler struct {
}

// makeWin32File makes a new win32File from an existing file handle.
func makeWin32File(h syscall.Handle) (*win32File, error) {
func makeWin32File(h windows.Handle) (*win32File, error) {
f := &win32File{handle: h}
ioInitOnce.Do(initIO)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
Expand All @@ -112,7 +112,12 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
return f, nil
}

// Deprecated: use NewOpenFile instead.
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return NewOpenFile(windows.Handle(h))
}

func NewOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
// If we return the result of makeWin32File directly, it can result in an
// interface-wrapped nil, rather than a nil interface value.
f, err := makeWin32File(h)
Expand All @@ -132,7 +137,7 @@ func (f *win32File) closeHandle() {
_ = cancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
syscall.Close(f.handle)
windows.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
Expand Down Expand Up @@ -166,12 +171,12 @@ func (f *win32File) prepareIO() (*ioOperation, error) {
}

// ioCompletionProcessor processes completed async IOs forever.
func ioCompletionProcessor(h syscall.Handle) {
func ioCompletionProcessor(h windows.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
if op == nil {
panic(err)
}
Expand All @@ -184,7 +189,7 @@ func ioCompletionProcessor(h syscall.Handle) {
// asyncIO processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
if err != windows.ERROR_IO_PENDING { //nolint:errorlint // err is Errno
return int(bytes), err
}

Expand All @@ -203,7 +208,7 @@ func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, er
select {
case r = <-c.ch:
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
if f.closing.isSet() {
err = ErrFileClosed
}
Expand All @@ -216,7 +221,7 @@ func (f *win32File) asyncIO(c *ioOperation, d *deadlineHandler, bytes uint32, er
_ = cancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
if err == windows.ERROR_OPERATION_ABORTED { //nolint:errorlint // err is Errno
err = ErrTimeout
}
}
Expand All @@ -242,14 +247,14 @@ func (f *win32File) Read(b []byte) (int, error) {
}

var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIO(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)

// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
} else if err == windows.ERROR_BROKEN_PIPE { //nolint:errorlint // err is Errno
return 0, io.EOF
} else {
return n, err
Expand All @@ -269,7 +274,7 @@ func (f *win32File) Write(b []byte) (int, error) {
}

var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIO(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
Expand All @@ -284,7 +289,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
}

func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle)
return windows.FlushFileBuffers(f.handle)
}

func (f *win32File) Fd() uintptr {
Expand Down
Loading

0 comments on commit 9f0d5dc

Please sign in to comment.