Skip to content

Commit

Permalink
Add Opt[IpAddress].
Browse files Browse the repository at this point in the history
Make IPv4 mapping to IPv6 space automatic.
  • Loading branch information
cheatfate committed Mar 30, 2024
1 parent 453990a commit 33476e3
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 67 deletions.
2 changes: 1 addition & 1 deletion chronos/transports/common.nim
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type
ServerFlags* = enum
## Server's flags
ReuseAddr, ReusePort, TcpNoDelay, NoAutoRead, GCUserData, FirstPipe,
NoPipeFlash, Broadcast
NoPipeFlash, Broadcast, V4Mapped

DualStackType* {.pure.} = enum
Auto, Enabled, Disabled, Default
Expand Down
210 changes: 150 additions & 60 deletions chronos/transports/datagram.nim
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@

import std/deques
when not(defined(windows)): import ".."/selectors2
import ".."/[asyncloop, osdefs, oserrno, osutils, handles]
import "."/common
import ".."/[asyncloop, config, osdefs, oserrno, osutils, handles]
import "."/[common, ipnet]
import stew/ptrops

type
Expand Down Expand Up @@ -60,29 +60,78 @@ type
const
DgramTransportTrackerName* = "datagram.transport"

proc remoteAddress*(transp: DatagramTransport): TransportAddress {.
raises: [TransportOsError].} =
proc getRemoteAddress(transp: DatagramTransport,
address: Sockaddr_storage, length: SockLen,
): TransportAddress =
var raddr: TransportAddress
fromSAddr(unsafeAddr address, length, raddr)
if ServerFlags.V4Mapped in transp.flags:
if raddr.isV4Mapped(): raddr.toIPv4() else: raddr
else:
raddr

proc getRemoteAddress(transp: DatagramTransport): TransportAddress =
transp.getRemoteAddress(transp.raddr, transp.ralen)

proc setRemoteAddress(transp: DatagramTransport,
address: TransportAddress): TransportAddress =
let
fixedAddress =
when defined(windows):
windowsAnyAddressFix(address)
else:
address
remoteAddress =
if ServerFlags.V4Mapped in transp.flags:
if address.family == AddressFamily.IPv4:
fixedAddress.toIPv6()
else:
fixedAddress
else:
fixedAddress
toSAddr(remoteAddress, transp.waddr, transp.walen)
remoteAddress

proc remoteAddress2*(
transp: DatagramTransport
): Result[TransportAddress, OSErrorCode] =
## Returns ``transp`` remote socket address.
if transp.remote.family == AddressFamily.None:
var saddr: Sockaddr_storage
var slen = SockLen(sizeof(saddr))
var
saddr: Sockaddr_storage
slen = SockLen(sizeof(saddr))
if getpeername(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseTransportOsError(osLastError())
fromSAddr(addr saddr, slen, transp.remote)
transp.remote
return err(osLastError())
transp.remote = transp.getRemoteAddress(saddr, slen)
ok(transp.remote)

proc localAddress*(transp: DatagramTransport): TransportAddress {.
raises: [TransportOsError].} =
proc localAddress2*(
transp: DatagramTransport
): Result[TransportAddress, OSErrorCode] =
## Returns ``transp`` local socket address.
if transp.local.family == AddressFamily.None:
var saddr: Sockaddr_storage
var slen = SockLen(sizeof(saddr))
var
saddr: Sockaddr_storage
slen = SockLen(sizeof(saddr))
if getsockname(SocketHandle(transp.fd), cast[ptr SockAddr](addr saddr),
addr slen) != 0:
raiseTransportOsError(osLastError())
return err(osLastError())
fromSAddr(addr saddr, slen, transp.local)
transp.local
ok(transp.local)

func toException(v: OSErrorCode): ref TransportOsError =
getTransportOsError(v)

proc remoteAddress*(transp: DatagramTransport): TransportAddress {.
raises: [TransportOsError].} =
## Returns ``transp`` remote socket address.
remoteAddress2(transp).tryGet()

proc localAddress*(transp: DatagramTransport): TransportAddress {.
raises: [TransportOsError].} =
## Returns ``transp`` remote socket address.
localAddress2(transp).tryGet()

template setReadError(t, e: untyped) =
(t).state.incl(ReadError)
Expand Down Expand Up @@ -124,8 +173,8 @@ when defined(windows):
transp.setWriterWSABuffer(vector)
let ret =
if vector.kind == WithAddress:
var fixedAddress = windowsAnyAddressFix(vector.address)
toSAddr(fixedAddress, transp.waddr, transp.walen)
# We only need `Sockaddr_storage` data here, so result discarded.
discard transp.setRemoteAddress(vector.address)
wsaSendTo(fd, addr transp.wwsabuf, DWORD(1), addr bytesCount,
DWORD(0), cast[ptr SockAddr](addr transp.waddr),
cint(transp.walen),
Expand Down Expand Up @@ -159,22 +208,24 @@ when defined(windows):
proc readDatagramLoop(udata: pointer) =
var
bytesCount: uint32
raddr: TransportAddress
var ovl = cast[PtrCustomOverlapped](udata)
var transp = cast[DatagramTransport](ovl.data.udata)
ovl = cast[PtrCustomOverlapped](udata)

let transp = cast[DatagramTransport](ovl.data.udata)

while true:
if ReadPending in transp.state:
## Continuation
transp.state.excl(ReadPending)
let err = transp.rovl.data.errCode
let
err = transp.rovl.data.errCode
remoteAddress = transp.getRemoteAddress()
case err
of OSErrorCode(-1):
let bytesCount = transp.rovl.data.bytesCount
if bytesCount == 0:
transp.state.incl({ReadEof, ReadPaused})
fromSAddr(addr transp.raddr, transp.ralen, raddr)
transp.buflen = int(bytesCount)
asyncSpawn transp.function(transp, raddr)
asyncSpawn transp.function(transp, remoteAddress)
of ERROR_OPERATION_ABORTED:
# CancelIO() interrupt or closeSocket() call.
transp.state.incl(ReadPaused)
Expand All @@ -189,7 +240,7 @@ when defined(windows):
transp.setReadError(err)
transp.state.incl(ReadPaused)
transp.buflen = 0
asyncSpawn transp.function(transp, raddr)
asyncSpawn transp.function(transp, remoteAddress)
else:
## Initiation
if transp.state * {ReadEof, ReadClosed, ReadError} == {}:
Expand Down Expand Up @@ -220,7 +271,7 @@ when defined(windows):
transp.state.incl(ReadPaused)
transp.setReadError(err)
transp.buflen = 0
asyncSpawn transp.function(transp, raddr)
asyncSpawn transp.function(transp, transp.getRemoteAddress())
else:
# Transport closure happens in callback, and we not started new
# WSARecvFrom session.
Expand Down Expand Up @@ -341,18 +392,25 @@ when defined(windows):
closeSocket(localSock)
raiseTransportOsError(err)

res.flags =
block:
# Add `V4Mapped` flag when `::` address is used and dualstack is
# set to enabled or auto.
var res = flags
if (local.family == AddressFamily.IPv6) and local.isAnyLocal():
if dualstack in {DualStackType.Enabled, DualStackType.Auto}:
res.incl(ServerFlags.V4Mapped)
res

if remote.port != Port(0):
var fixedAddress = windowsAnyAddressFix(remote)
var saddr: Sockaddr_storage
var slen: SockLen
toSAddr(fixedAddress, saddr, slen)
if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let remoteAddress = res.setRemoteAddress(remote)
if connect(SocketHandle(localSock), cast[ptr SockAddr](addr res.waddr),
res.walen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeSocket(localSock)
raiseTransportOsError(err)
res.remote = fixedAddress
res.remote = remoteAddress

res.fd = localSock
res.function = cbproc
Expand All @@ -362,12 +420,12 @@ when defined(windows):
res.state = {ReadPaused, WritePaused}
res.future = Future[void].Raising([]).init(
"datagram.transport", {FutureFlag.OwnCancelSchedule})
res.rovl.data = CompletionData(cb: readDatagramLoop,
udata: cast[pointer](res))
res.wovl.data = CompletionData(cb: writeDatagramLoop,
udata: cast[pointer](res))
res.rwsabuf = WSABUF(buf: cast[cstring](baseAddr res.buffer),
len: ULONG(len(res.buffer)))
res.rovl.data = CompletionData(
cb: readDatagramLoop, udata: cast[pointer](res))
res.wovl.data = CompletionData(
cb: writeDatagramLoop, udata: cast[pointer](res))
res.rwsabuf = WSABUF(
buf: cast[cstring](baseAddr res.buffer), len: ULONG(len(res.buffer)))
GC_ref(res)
# Start tracking transport
trackCounter(DgramTransportTrackerName)
Expand All @@ -380,10 +438,10 @@ else:
# Linux/BSD/MacOS part

proc readDatagramLoop(udata: pointer) {.raises: [].}=
var raddr: TransportAddress
doAssert(not isNil(udata))
let transp = cast[DatagramTransport](udata)
let fd = SocketHandle(transp.fd)
let
transp = cast[DatagramTransport](udata)
fd = SocketHandle(transp.fd)
if int(fd) == 0:
## This situation can be happen, when there events present
## after transport was closed.
Expand All @@ -398,9 +456,8 @@ else:
cast[ptr SockAddr](addr transp.raddr),
addr transp.ralen)
if res >= 0:
fromSAddr(addr transp.raddr, transp.ralen, raddr)
transp.buflen = res
asyncSpawn transp.function(transp, raddr)
asyncSpawn transp.function(transp, transp.getRemoteAddress())
else:
let err = osLastError()
case err
Expand All @@ -409,14 +466,15 @@ else:
else:
transp.buflen = 0
transp.setReadError(err)
asyncSpawn transp.function(transp, raddr)
asyncSpawn transp.function(transp, transp.getRemoteAddress())
break

proc writeDatagramLoop(udata: pointer) =
var res: int
doAssert(not isNil(udata))
var transp = cast[DatagramTransport](udata)
let fd = SocketHandle(transp.fd)
let
transp = cast[DatagramTransport](udata)
fd = SocketHandle(transp.fd)
if int(fd) == 0:
## This situation can be happen, when there events present
## after transport was closed.
Expand All @@ -428,7 +486,8 @@ else:
let vector = transp.queue.popFirst()
while true:
if vector.kind == WithAddress:
toSAddr(vector.address, transp.waddr, transp.walen)
# We only need `Sockaddr_storage` data here, so result discarded.
discard transp.setRemoteAddress(vector.address)
res = osdefs.sendto(fd, vector.buf, vector.buflen, MSG_NOSIGNAL,
cast[ptr SockAddr](addr transp.waddr),
transp.walen)
Expand Down Expand Up @@ -551,21 +610,28 @@ else:
closeSocket(localSock)
raiseTransportOsError(err)

res.flags =
block:
# Add `V4Mapped` flag when `::` address is used and dualstack is
# set to enabled or auto.
var res = flags
if (local.family == AddressFamily.IPv6) and local.isAnyLocal():
if dualstack != DualStackType.Disabled:
res.incl(ServerFlags.V4Mapped)
res

if remote.port != Port(0):
var saddr: Sockaddr_storage
var slen: SockLen
toSAddr(remote, saddr, slen)
if connect(SocketHandle(localSock), cast[ptr SockAddr](addr saddr),
slen) != 0:
let remoteAddress = res.setRemoteAddress(remote)
if connect(SocketHandle(localSock), cast[ptr SockAddr](addr res.waddr),
res.walen) != 0:
let err = osLastError()
if sock == asyncInvalidSocket:
closeSocket(localSock)
raiseTransportOsError(err)
res.remote = remote
res.remote = remoteAddress

res.fd = localSock
res.function = cbproc
res.flags = flags
res.buffer = newSeq[byte](bufferSize)
res.queue = initDeque[GramVector]()
res.udata = udata
Expand Down Expand Up @@ -826,6 +892,7 @@ proc newDatagramTransport6*[T](cbproc: UnsafeDatagramCallback,

proc newDatagramTransport*(cbproc: DatagramCallback,
port: Port,
local: Opt[IpAddress] = Opt.none(IpAddress),
flags: set[ServerFlags] = {},
udata: pointer = nil,
child: DatagramTransport = nil,
Expand All @@ -848,14 +915,28 @@ proc newDatagramTransport*(cbproc: DatagramCallback,
## ``ttl`` - TTL for UDP datagram packet (only usable when flags has
## ``Broadcast`` option).
let
localHost = getAutoAddress(port)
remoteHost = getAutoAddress(Port(0))
(localHost, remoteHost) =
block:
let
lres =
if local.isSome():
initTAddress(local.get(), port)
else:
getAutoAddress(port)
rres =
if lres.family == AddressFamily.IPv4:
AnyAddress
else:
AnyAddress6
(lres, rres)

newDatagramTransportCommon(cbproc, remoteHost, localHost, asyncInvalidSocket,
flags, cast[pointer](udata), child, bufSize, ttl,
dualstack)
flags, cast[pointer](udata), child, bufSize,
ttl, dualstack)

proc newDatagramTransport*[T](cbproc: DatagramCallback,
port: Port,
local: Opt[IpAddress] = Opt.none(IpAddress),
flags: set[ServerFlags] = {},
udata: ref T,
child: DatagramTransport = nil,
Expand All @@ -865,8 +946,17 @@ proc newDatagramTransport*[T](cbproc: DatagramCallback,
): DatagramTransport {.
raises: [TransportOsError].} =
let
localHost = getAutoAddress(port)
remoteHost = getAutoAddress(Port(0))
(localHost, remoteHost) =
block:
let
lres = local.valueOr:
getAutoAddress(port)
rres =
if lres.family == AddressFamily.IPv4:
AnyAddress
else:
AnyAddress6
(lres, rres)
fflags = flags + {GCUserData}
GC_ref(udata)
newDatagramTransportCommon(cbproc, remoteHost, localHost, asyncInvalidSocket,
Expand Down
Loading

0 comments on commit 33476e3

Please sign in to comment.