diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go index 5197052fab..7386112395 100644 --- a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_other.go @@ -1,4 +1,4 @@ -//go:build !unix +//go:build !unix && !windows package sampledconn diff --git a/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go new file mode 100644 index 0000000000..46b0617996 --- /dev/null +++ b/p2p/transport/tcpreuse/internal/sampledconn/sampledconn_windows.go @@ -0,0 +1,49 @@ +//go:build windows + +package sampledconn + +import ( + "errors" + "golang.org/x/sys/windows" + "syscall" +) + +func OSPeekConn(conn syscall.Conn) (PeekedBytes, error) { + s := PeekedBytes{} + + rawConn, err := conn.SyscallConn() + if err != nil { + return s, err + } + + readBytes := 0 + var readErr error + err = rawConn.Read(func(fd uintptr) bool { + for readBytes < peekSize { + var n uint32 + flags := uint32(windows.MSG_PEEK) + wsabuf := windows.WSABuf{ + Len: uint32(len(s) - readBytes), + Buf: &s[readBytes], + } + + readErr = windows.WSARecv(windows.Handle(fd), &wsabuf, 1, &n, &flags, nil, nil) + if errors.Is(readErr, windows.WSAEWOULDBLOCK) { + return false + } + if readErr != nil { + return true + } + readBytes += int(n) + } + return true + }) + if readErr != nil { + return s, readErr + } + if err != nil { + return s, err + } + + return s, nil +}