From 6e759337b06c186154acaf2fe1a7cb5a6b3da2ef Mon Sep 17 00:00:00 2001 From: Leorize Date: Tue, 13 Feb 2024 17:57:48 -0600 Subject: [PATCH] sockets(windows): implement IPv6 support --- src/sys/private/sockets_windows.nim | 114 +++++++++++++++++++++------- 1 file changed, 88 insertions(+), 26 deletions(-) diff --git a/src/sys/private/sockets_windows.nim b/src/sys/private/sockets_windows.nim index 07c2248..1ea7f68 100644 --- a/src/sys/private/sockets_windows.nim +++ b/src/sys/private/sockets_windows.nim @@ -157,11 +157,17 @@ proc `=destroy`(r: var ResolverResultImpl) = FreeAddrInfoW(r.info) r.info = nil -template ip4Resolve() {.dirty.} = +template ipResolve() {.dirty.} = result = new ResolverResultImpl let hints = AddrInfoW( - ai_family: AF_INET + ai_family: + if isNone(kind): + AF_UNSPEC + else: + case kind.get + of V4: AF_INET + of V6: AF_INET6 ) let err = GetAddrInfoW( @@ -191,8 +197,13 @@ template resolvedItems() {.dirty.} = var info = r.info while info != nil: if info.ai_addr != nil: - if info.ai_addr.sa_family == AF_INET: - yield cast[ptr IP4Endpoint](info.ai_addr)[] + case info.ai_addr.sa_family + of AF_INET: + yield IPEndpoint(kind: V4, v4: cast[ptr IP4Endpoint](info.ai_addr)[]) + of AF_INET6: + yield IPEndpoint(kind: V6, v6: cast[ptr IP6Endpoint](info.ai_addr)[]) + else: + discard "Should not be possible, but harmless even if it is" info = info.ai_next @@ -216,9 +227,15 @@ func toWSAFlags(flags: set[SockFlag]): DWORD {.inline.} = var ConnectEx: LPFN_CONNECTEX template tcpConnect() {.dirty.} = + const addressFamily = + when endpoint is IP4Endpoint: + AF_INET + elif endpoint is IP6Endpoint: + AF_INET6 + let sock = initHandle: SocketFD: - WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({})) + WSASocketW(addressFamily, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({})) if sock.fd == InvalidFD: raise newOSError(WSAGetLastError(), $Error.Connect) @@ -233,10 +250,16 @@ template tcpConnect() {.dirty.} = result = Conn[TCP] newSocket(sock) template tcpAsyncConnect() {.dirty.} = + const addressFamily = + when endpoint is IP4Endpoint: + AF_INET + elif endpoint is IP6Endpoint: + AF_INET6 + # Use a bare AsyncSocket for this, so that on failure the FD is unregistered. let sock = newAsyncSocket: SocketFD: - WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({sfOverlapped})) + WSASocketW(addressFamily, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({sfOverlapped})) if sock.fd == InvalidFD: raise newOSError(WSAGetLastError(), $Error.Connect) @@ -245,7 +268,11 @@ template tcpAsyncConnect() {.dirty.} = # # Bind an "any" address to this so that the system can choose the best local # address to use. - var empty = initEndpoint(IP4Any, PortNone) + var empty = + when addressFamily == AF_INET: + initEndpoint(IP4Any, PortNone) + elif addressFamily == AF_INET6: + initEndpoint(IP6Any, PortNone) if `bind`( wincore.Socket(sock.fd), cast[ptr sockaddr](addr empty), @@ -306,9 +333,15 @@ template tcpAsyncConnect() {.dirty.} = result = AsyncConn[TCP] sock template tcpListen() {.dirty.} = + const addressFamily = + when endpoint is IP4Endpoint: + AF_INET + elif endpoint is IP6Endpoint: + AF_INET6 + var sock = initHandle: SocketFD: - WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({})) + WSASocketW(addressFamily, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({})) if sock.fd == InvalidFD: raise newOSError(WSAGetLastError(), $Error.Listen) @@ -330,9 +363,15 @@ template tcpListen() {.dirty.} = result = Listener[TCP] newSocket(sock) template tcpAsyncListen() {.dirty.} = + const addressFamily = + when endpoint is IP4Endpoint: + AF_INET + elif endpoint is IP6Endpoint: + AF_INET6 + var sock = initHandle: SocketFD: - WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({sfOverlapped})) + WSASocketW(addressFamily, SOCK_STREAM, IPPROTO_TCP, nil, 0, toWSAFlags({sfOverlapped})) if sock.fd == InvalidFD: raise newOSError(WSAGetLastError(), $Error.Listen) @@ -362,13 +401,30 @@ const AcceptExBufferLength = AcceptExLocalLength + AcceptExRemoteLength template acceptCommon(listener: SocketFD, conn: var Handle[SocketFD], - remote: var IP4Endpoint, overlapped: static bool) = + remote: var IPEndpoint, overlapped: static bool) = # Common parts for dealing with AcceptEx, `overlapped` dictates whether the # operation should be done in an overlapped manner and yields an overlapped # socket. + + # TODO: cache this data + var socketInfo: WSAPROTOCOL_INFOW + var socketInfoSize = cint sizeof(socketInfo) + if getsockopt( + winsock.Socket(listener), + SOL_SOCKET, + SO_PROTOCOL_INFOW, + cast[cstring](addr socketInfo), + addr socketInfoSize + ) != 0: + raise newOSError(WSAGetLastError(), $Error.Accept) + assert socketInfoSize == sizeof(socketInfo): + "SO_PROTOCOL_INFOW output size is different than expected, this is a nim-sys bug" + + let addressFamily = socketInfo.iAddressFamily + conn = initHandle: SocketFD: - WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, nil, 0): + WSASocketW(addressFamily, SOCK_STREAM, IPPROTO_TCP, nil, 0): toWSAFlags: when overlapped: {sfOverlapped} @@ -439,16 +495,14 @@ template acceptCommon(listener: SocketFD, conn: var Handle[SocketFD], addr remoteAddr, addr remoteAddrLength ) - # TODO: Remove this once IPv6 support lands - # - # This is used to verify that we are getting IPv4 address. - assert remoteAddrLength == sizeof remote: - "The length of the endpoint structure does not match assumption. This is a nim-sys bug." - assert remoteAddr.sa_family == AF_INET: - "The address is not IPv4. This is a nim-sys bug." - # Copy the remote address - remote = cast[ptr IP4Endpoint](remoteAddr)[] + case remoteAddr.sa_family + of AF_INET: + remote = IPEndpoint(kind: V4, v4: cast[ptr IP4Endpoint](remoteAddr)[]) + of AF_INET6: + remote = IPEndpoint(kind: V6, v6: cast[ptr IP6Endpoint](remoteAddr)[]) + else: + doAssert false, "Unexpected remote address family: " & $remoteAddr.sa_family # Update the connection attributes so that other functions can be used on the # socket. @@ -473,19 +527,27 @@ template tcpAsyncAccept() {.dirty.} = result.conn = AsyncConn[TCP] newAsyncSocket(move conn) template tcpLocalEndpoint() {.dirty.} = - var endpointLen = cint sizeof(result) + var + saddr: SockaddrStorage + endpointLen = cint sizeof(saddr) if getsockname( wincore.Socket(l.fd), - cast[ptr sockaddr](addr result), + cast[ptr sockaddr](addr saddr), addr endpointLen ) == SocketError: raise newOSError(WSAGetLastError(), $Error.LocalEndpoint) - assert endpointLen == sizeof(result): - "The length of the endpoint structure does not match assumption. This is a nim-sys bug." - assert result.sin_family == AF_INET: - "The address is not IPv4. This is a nim-sys bug." + assert endpointLen <= sizeof(saddr): + "The length of the endpoint structure is bigger than expected size. This is a nim-sys bug." + + case saddr.ss_family + of AF_INET: + result = IPEndpoint(kind: V4, v4: cast[ptr IP4Endpoint](addr saddr)[]) + of AF_INET6: + result = IPEndpoint(kind: V6, v6: cast[ptr IP6Endpoint](addr saddr)[]) + else: + doAssert false, "Unexpected local address family: " & $saddr.ss_family proc initWinsock() = ## Initializes winsock for use with sys/sockets