Skip to content

Commit

Permalink
fix ping with TTL on Linux (#99875)
Browse files Browse the repository at this point in the history
* fix ping with TTL on Linux

* feedback

* feedback
  • Loading branch information
wfurt authored Apr 16, 2024
1 parent 92b9dca commit 1e42214
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[StructLayout(LayoutKind.Sequential)]
internal unsafe struct IOVector
{
public byte* Base;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Net.Sockets;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_ReceiveSocketError")]
internal static unsafe partial SocketError ReceiveSocketError(SafeHandle socket, MessageHeader* messageHeader);
}
}
15 changes: 15 additions & 0 deletions src/libraries/System.Net.Ping/src/System.Net.Ping.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,29 @@
Link="Common\System\Net\SocketProtocolSupportPal.Unix.cs" />
<Compile Include="$(CommonPath)System\Net\NetworkInformation\UnixCommandLinePing.cs"
Link="Common\System\Net\NetworkInformation\UnixCommandLinePing.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointExtensions.cs"
Link="Common\System\Net\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\SocketAddressPal.Unix.cs"
Link="Common\System\Net\SocketAddressPal.Unix.cs" />
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\Sockets\SocketErrorPal.Unix.cs"
Link="Common\System\Net\Sockets\SocketErrorPal.Unix" />
<!-- Interop -->
<Compile Include="$(CommonPath)Interop\Unix\Interop.DefaultPathBufferSize.cs"
Link="Common\Interop\Unix\Interop.DefaultPathBufferSize.cs" />
<Compile Include="$(CommonPath)Interop\Unix\Interop.Errors.cs"
Link="Common\Interop\Unix\Interop.Errors.cs" />
<Compile Include="$(CommonPath)Interop\Unix\Interop.Libraries.cs"
Link="Common\Interop\Unix\Interop.Libraries.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.IOVector.cs"
Link="Common\Interop\Unix\System.Native\Interop.IOVector.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.ReceiveSocketError.cs"
Link="Common\Interop\Unix\System.Native\Interop.ReceiveSocketError.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Close.cs"
Link="Common\Interop\Unix\System.Native\Interop.Close.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.MessageHeader.cs"
Link="Common\Interop\Unix\System.Native\Interop.MessageHeader.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Socket.cs"
Link="Common\Interop\Unix\System.Native\Interop.Socket.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SocketAddress.cs"
Expand Down Expand Up @@ -99,6 +113,7 @@

<ItemGroup>
<Reference Include="Microsoft.Win32.Primitives" />
<Reference Include="System.Collections" />
<Reference Include="System.ComponentModel.EventBasedAsync" />
<Reference Include="System.ComponentModel.Primitives" />
<Reference Include="System.Diagnostics.Tracing" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ private static Socket GetRawSocket(SocketConfig socketConfig)
{
// If it is not multicast, use Connect to scope responses only to the target address.
socket.Connect(socketConfig.EndPoint);
unsafe
{
int opt = 1;
// setsockopt(fd, IPPROTO_IP, IP_RECVERR, &value, sizeof(int))
socket.SetRawSocketOption(0, 11, new ReadOnlySpan<byte>(&opt, sizeof(int)));
}
}
#pragma warning restore 618

Expand Down Expand Up @@ -232,11 +238,12 @@ private static bool TryGetPingReply(
return true;
}

private static PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
private static unsafe PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
{
SocketConfig socketConfig = GetSocketConfig(address, buffer, timeout, options);
using (Socket socket = GetRawSocket(socketConfig))
{
Span<byte> socketAddress = stackalloc byte[SocketAddress.GetMaximumAddressSize(address.AddressFamily)];
int ipHeaderLength = socketConfig.IsIpv4 ? MinIpHeaderLengthInBytes : 0;
try
{
Expand Down Expand Up @@ -270,6 +277,29 @@ private static PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byt
{
return CreatePingReply(IPStatus.PacketTooBig);
}
catch (SocketException ex) when (ex.SocketErrorCode == SocketError.HostUnreachable)
{
// This happens on Linux where we explicitly subscribed to error messages
// We should be able to get more info by getting extended socket error from error queue.

Interop.Sys.MessageHeader header = default;

SocketError result;
fixed (byte* sockAddr = &MemoryMarshal.GetReference(socketAddress))
{
header.SocketAddress = sockAddr;
header.SocketAddressLen = socketAddress.Length;
header.IOVectors = null;
header.IOVectorCount = 0;

result = Interop.Sys.ReceiveSocketError(socket.SafeHandle, &header);
}

if (result == SocketError.Success && header.SocketAddressLen > 0)
{
return CreatePingReply(IPStatus.TtlExpired, IPEndPointExtensions.GetIPAddress(socketAddress.Slice(0, header.SocketAddressLen)));
}
}

// We have exceeded our timeout duration, and no reply has been received.
return CreatePingReply(IPStatus.TimedOut);
Expand Down
1 change: 1 addition & 0 deletions src/native/libs/Common/pal_config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
#cmakedefine01 HAVE_IOS_NET_IFMEDIA_H
#cmakedefine01 HAVE_LINUX_RTNETLINK_H
#cmakedefine01 HAVE_LINUX_CAN_H
#cmakedefine01 HAVE_LINUX_ERRQUEUE_H
#cmakedefine01 HAVE_GETDOMAINNAME_SIZET
#cmakedefine01 HAVE_INOTIFY
#cmakedefine01 HAVE_CLOCK_MONOTONIC
Expand Down
1 change: 1 addition & 0 deletions src/native/libs/System.Native/entrypoints.c
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ static const Entry s_sysNative[] =
DllImportEntry(SystemNative_SetSendTimeout)
DllImportEntry(SystemNative_Receive)
DllImportEntry(SystemNative_ReceiveMessage)
DllImportEntry(SystemNative_ReceiveSocketError)
DllImportEntry(SystemNative_Send)
DllImportEntry(SystemNative_SendMessage)
DllImportEntry(SystemNative_Accept)
Expand Down
63 changes: 61 additions & 2 deletions src/native/libs/System.Native/pal_networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
#if HAVE_SYS_FILIO_H
#include <sys/filio.h>
#endif
#if HAVE_LINUX_ERRQUEUE_H
#include <linux/errqueue.h>
#endif


#if HAVE_KQUEUE
#if KEVENT_HAS_VOID_UDATA
Expand Down Expand Up @@ -1325,7 +1329,11 @@ int32_t SystemNative_SetSendTimeout(intptr_t socket, int32_t millisecondsTimeout

static int8_t ConvertSocketFlagsPalToPlatform(int32_t palFlags, int* platformFlags)
{
const int32_t SupportedFlagsMask = SocketFlags_MSG_OOB | SocketFlags_MSG_PEEK | SocketFlags_MSG_DONTROUTE | SocketFlags_MSG_TRUNC | SocketFlags_MSG_CTRUNC;
const int32_t SupportedFlagsMask =
#ifdef MSG_ERRQUEUE
SocketFlags_MSG_ERRQUEUE |
#endif
SocketFlags_MSG_OOB | SocketFlags_MSG_PEEK | SocketFlags_MSG_DONTROUTE | SocketFlags_MSG_TRUNC | SocketFlags_MSG_CTRUNC | SocketFlags_MSG_DONTWAIT;

if ((palFlags & ~SupportedFlagsMask) != 0)
{
Expand All @@ -1335,9 +1343,15 @@ static int8_t ConvertSocketFlagsPalToPlatform(int32_t palFlags, int* platformFla
*platformFlags = ((palFlags & SocketFlags_MSG_OOB) == 0 ? 0 : MSG_OOB) |
((palFlags & SocketFlags_MSG_PEEK) == 0 ? 0 : MSG_PEEK) |
((palFlags & SocketFlags_MSG_DONTROUTE) == 0 ? 0 : MSG_DONTROUTE) |
((palFlags & SocketFlags_MSG_DONTWAIT) == 0 ? 0 : MSG_DONTWAIT) |
((palFlags & SocketFlags_MSG_TRUNC) == 0 ? 0 : MSG_TRUNC) |
((palFlags & SocketFlags_MSG_CTRUNC) == 0 ? 0 : MSG_CTRUNC);

#ifdef MSG_ERRQUEUE
if ((palFlags & SocketFlags_MSG_ERRQUEUE) != 0)
{
*platformFlags |= MSG_ERRQUEUE;
}
#endif
return true;
}

Expand Down Expand Up @@ -1381,6 +1395,51 @@ int32_t SystemNative_Receive(intptr_t socket, void* buffer, int32_t bufferLen, i
return SystemNative_ConvertErrorPlatformToPal(errno);
}

int32_t SystemNative_ReceiveSocketError(intptr_t socket, MessageHeader* messageHeader)
{
int fd = ToFileDescriptor(socket);
ssize_t res;

#if HAVE_LINUX_ERRQUEUE_H
char buffer[sizeof(struct sock_extended_err) + sizeof(struct sockaddr_storage)];
messageHeader->ControlBufferLen = sizeof(buffer);
messageHeader->ControlBuffer = (void*)buffer;

struct msghdr header;
ConvertMessageHeaderToMsghdr(&header, messageHeader, fd);

while ((res = recvmsg(fd, &header, SocketFlags_MSG_DONTWAIT | SocketFlags_MSG_ERRQUEUE)) < 0 && errno == EINTR);

struct cmsghdr *cmsg;
for (cmsg = CMSG_FIRSTHDR(&header); cmsg; cmsg = GET_CMSG_NXTHDR(&header, cmsg))
{
if (cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR)
{
struct sock_extended_err *e = (struct sock_extended_err *)CMSG_DATA(cmsg);
if (e->ee_origin == SO_EE_ORIGIN_ICMP)
{
int size = (int)(cmsg->cmsg_len - sizeof(struct sock_extended_err));
messageHeader->SocketAddressLen = size < messageHeader->SocketAddressLen ? size : messageHeader->SocketAddressLen;
memcpy(messageHeader->SocketAddress, (struct sockaddr_in*)(e+1), (size_t)messageHeader->SocketAddressLen);
return Error_SUCCESS;
}
}
}
#else
res = -1;
errno = ENOTSUP;
#endif

messageHeader->SocketAddressLen = 0;

if (res != -1)
{
return Error_SUCCESS;
}

return SystemNative_ConvertErrorPlatformToPal(errno);
}

int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeader, int32_t flags, int64_t* received)
{
if (messageHeader == NULL || received == NULL || messageHeader->SocketAddressLen < 0 ||
Expand Down
4 changes: 4 additions & 0 deletions src/native/libs/System.Native/pal_networking.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ typedef enum
SocketFlags_MSG_DONTROUTE = 0x0004, // SocketFlags.DontRoute
SocketFlags_MSG_TRUNC = 0x0100, // SocketFlags.Truncated
SocketFlags_MSG_CTRUNC = 0x0200, // SocketFlags.ControlDataTruncated
SocketFlags_MSG_DONTWAIT = 0x1000, // used privately by Ping
SocketFlags_MSG_ERRQUEUE = 0x2000, // used privately by Ping
} SocketFlags;

/*
Expand Down Expand Up @@ -356,6 +358,8 @@ PALEXPORT int32_t SystemNative_Receive(intptr_t socket, void* buffer, int32_t bu

PALEXPORT int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeader, int32_t flags, int64_t* received);

PALEXPORT int32_t SystemNative_ReceiveSocketError(intptr_t socket, MessageHeader* messageHeader);

PALEXPORT int32_t SystemNative_Send(intptr_t socket, void* buffer, int32_t bufferLen, int32_t flags, int32_t* sent);

PALEXPORT int32_t SystemNative_SendMessage(intptr_t socket, MessageHeader* messageHeader, int32_t flags, int64_t* sent);
Expand Down
4 changes: 4 additions & 0 deletions src/native/libs/configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,10 @@ check_include_files(
"sys/proc_info.h"
HAVE_SYS_PROCINFO_H)

check_include_files(
"time.h;linux/errqueue.h"
HAVE_LINUX_ERRQUEUE_H)

check_symbol_exists(
epoll_create1
sys/epoll.h
Expand Down

0 comments on commit 1e42214

Please sign in to comment.