Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix ping with TTL on Linux #99875

Merged
merged 6 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ internal static partial class Sys
{
internal unsafe struct IOVector
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
#pragma warning disable CS0649
public byte* Base;
public UIntPtr Count;
#pragma warning restore CS0649
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Net.Sockets;
using System.Runtime.InteropServices;

internal static partial class Interop
{
internal static partial class Sys
{
[LibraryImport(Libraries.SystemNative, EntryPoint = "SystemNative_ReceiveSocketError")]
internal static unsafe partial SocketError ReceiveSocketError(SafeHandle socket, MessageHeader* messageHeader);
}
}
15 changes: 15 additions & 0 deletions src/libraries/System.Net.Ping/src/System.Net.Ping.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,29 @@
Link="Common\System\Net\SocketProtocolSupportPal.Unix.cs" />
<Compile Include="$(CommonPath)System\Net\NetworkInformation\UnixCommandLinePing.cs"
Link="Common\System\Net\NetworkInformation\UnixCommandLinePing.cs" />
<Compile Include="$(CommonPath)System\Net\IPEndPointExtensions.cs"
Link="Common\System\Net\IPEndPointExtensions.cs" />
<Compile Include="$(CommonPath)System\Net\SocketAddressPal.Unix.cs"
Link="Common\System\Net\SocketAddressPal.Unix.cs" />
<Compile Include="$(CommonPath)System\Net\IPAddressParserStatics.cs"
Link="Common\System\Net\IPAddressParserStatics.cs" />
<Compile Include="$(CommonPath)System\Net\Sockets\SocketErrorPal.Unix.cs"
Link="Common\System\Net\Sockets\SocketErrorPal.Unix" />
<!-- Interop -->
<Compile Include="$(CommonPath)Interop\Unix\Interop.DefaultPathBufferSize.cs"
Link="Common\Interop\Unix\Interop.DefaultPathBufferSize.cs" />
<Compile Include="$(CommonPath)Interop\Unix\Interop.Errors.cs"
Link="Common\Interop\Unix\Interop.Errors.cs" />
<Compile Include="$(CommonPath)Interop\Unix\Interop.Libraries.cs"
Link="Common\Interop\Unix\Interop.Libraries.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.IOVector.cs"
Link="Common\Interop\Unix\System.Native\Interop.IOVector.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.ReceiveSocketError.cs"
Link="Common\Interop\Unix\System.Native\Interop.ReceiveSocketError.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Close.cs"
Link="Common\Interop\Unix\System.Native\Interop.Close.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.MessageHeader.cs"
Link="Common\Interop\Unix\System.Native\Interop.MessageHeader.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.Socket.cs"
Link="Common\Interop\Unix\System.Native\Interop.Socket.cs" />
<Compile Include="$(CommonPath)Interop\Unix\System.Native\Interop.SocketAddress.cs"
Expand Down Expand Up @@ -99,6 +113,7 @@

<ItemGroup>
<Reference Include="Microsoft.Win32.Primitives" />
<Reference Include="System.Collections" />
<Reference Include="System.ComponentModel.EventBasedAsync" />
<Reference Include="System.ComponentModel.Primitives" />
<Reference Include="System.Diagnostics.Tracing" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ private static Socket GetRawSocket(SocketConfig socketConfig)
{
// If it is not multicast, use Connect to scope responses only to the target address.
socket.Connect(socketConfig.EndPoint);
unsafe
{
int opt = 1;
// setsockopt(fd, IPPROTO_IP, IP_RECVERR, &value, sizeof(int))
socket.SetRawSocketOption(0, 11, new ReadOnlySpan<byte>(&opt, sizeof(int)));
}
}
#pragma warning restore 618

Expand Down Expand Up @@ -232,11 +238,12 @@ private static bool TryGetPingReply(
return true;
}

private static PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
private static unsafe PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byte[] buffer, int timeout, PingOptions? options)
{
SocketConfig socketConfig = GetSocketConfig(address, buffer, timeout, options);
using (Socket socket = GetRawSocket(socketConfig))
{
Span<byte> socketAddress = stackalloc byte[SocketAddress.GetMaximumAddressSize(address.AddressFamily)];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move this just before its first usage, at line 288.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ping.RawSocket.cs(287,48): error CS0255: stackalloc may not be used in a catch or finally block

int ipHeaderLength = socketConfig.IsIpv4 ? MinIpHeaderLengthInBytes : 0;
try
{
Expand Down Expand Up @@ -270,6 +277,29 @@ private static PingReply SendIcmpEchoRequestOverRawSocket(IPAddress address, byt
{
return CreatePingReply(IPStatus.PacketTooBig);
}
catch (SocketException ex) when (ex.SocketErrorCode == SocketError.HostUnreachable)
{
// This happens on Linux where we explicitly subscribed to error messages
// We should be able to get more info by getting extended socket error from error queue.

Interop.Sys.MessageHeader header = default;

SocketError result;
fixed (byte* sockAddr = &MemoryMarshal.GetReference(socketAddress))
{
header.SocketAddress = sockAddr;
header.SocketAddressLen = socketAddress.Length;
header.IOVectors = null;
header.IOVectorCount = 0;

result = Interop.Sys.ReceiveSocketError(socket.SafeHandle, &header);
}

if (result == SocketError.Success && header.SocketAddressLen > 0)
{
return CreatePingReply(IPStatus.TtlExpired, IPEndPointExtensions.GetIPAddress(socketAddress.Slice(0, header.SocketAddressLen)));
}
}

// We have exceeded our timeout duration, and no reply has been received.
return CreatePingReply(IPStatus.TimedOut);
Expand Down
1 change: 1 addition & 0 deletions src/native/libs/Common/pal_config.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
#cmakedefine01 HAVE_IOS_NET_IFMEDIA_H
#cmakedefine01 HAVE_LINUX_RTNETLINK_H
#cmakedefine01 HAVE_LINUX_CAN_H
#cmakedefine01 HAVE_LINUX_ERRQUEUE_H
#cmakedefine01 HAVE_GETDOMAINNAME_SIZET
#cmakedefine01 HAVE_INOTIFY
#cmakedefine01 HAVE_CLOCK_MONOTONIC
Expand Down
1 change: 1 addition & 0 deletions src/native/libs/System.Native/entrypoints.c
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ static const Entry s_sysNative[] =
DllImportEntry(SystemNative_SetSendTimeout)
DllImportEntry(SystemNative_Receive)
DllImportEntry(SystemNative_ReceiveMessage)
DllImportEntry(SystemNative_ReceiveSocketError)
DllImportEntry(SystemNative_Send)
DllImportEntry(SystemNative_SendMessage)
DllImportEntry(SystemNative_Accept)
Expand Down
63 changes: 61 additions & 2 deletions src/native/libs/System.Native/pal_networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@
#if HAVE_SYS_FILIO_H
#include <sys/filio.h>
#endif
#if HAVE_LINUX_ERRQUEUE_H
#include <linux/errqueue.h>
#endif


#if HAVE_KQUEUE
#if KEVENT_HAS_VOID_UDATA
Expand Down Expand Up @@ -1325,7 +1329,11 @@ int32_t SystemNative_SetSendTimeout(intptr_t socket, int32_t millisecondsTimeout

static int8_t ConvertSocketFlagsPalToPlatform(int32_t palFlags, int* platformFlags)
{
const int32_t SupportedFlagsMask = SocketFlags_MSG_OOB | SocketFlags_MSG_PEEK | SocketFlags_MSG_DONTROUTE | SocketFlags_MSG_TRUNC | SocketFlags_MSG_CTRUNC;
const int32_t SupportedFlagsMask =
#ifdef MSG_ERRQUEUE
SocketFlags_MSG_ERRQUEUE |
#endif
SocketFlags_MSG_OOB | SocketFlags_MSG_PEEK | SocketFlags_MSG_DONTROUTE | SocketFlags_MSG_TRUNC | SocketFlags_MSG_CTRUNC | SocketFlags_MSG_DONTWAIT;

if ((palFlags & ~SupportedFlagsMask) != 0)
{
Expand All @@ -1335,9 +1343,15 @@ static int8_t ConvertSocketFlagsPalToPlatform(int32_t palFlags, int* platformFla
*platformFlags = ((palFlags & SocketFlags_MSG_OOB) == 0 ? 0 : MSG_OOB) |
((palFlags & SocketFlags_MSG_PEEK) == 0 ? 0 : MSG_PEEK) |
((palFlags & SocketFlags_MSG_DONTROUTE) == 0 ? 0 : MSG_DONTROUTE) |
((palFlags & SocketFlags_MSG_DONTWAIT) == 0 ? 0 : MSG_DONTWAIT) |
((palFlags & SocketFlags_MSG_TRUNC) == 0 ? 0 : MSG_TRUNC) |
((palFlags & SocketFlags_MSG_CTRUNC) == 0 ? 0 : MSG_CTRUNC);

#ifdef MSG_ERRQUEUE
if ((palFlags & SocketFlags_MSG_ERRQUEUE) != 0)
{
*platformFlags |= MSG_ERRQUEUE;
}
#endif
return true;
}

Expand Down Expand Up @@ -1381,6 +1395,51 @@ int32_t SystemNative_Receive(intptr_t socket, void* buffer, int32_t bufferLen, i
return SystemNative_ConvertErrorPlatformToPal(errno);
}

int32_t SystemNative_ReceiveSocketError(intptr_t socket, MessageHeader* messageHeader)
{
int fd = ToFileDescriptor(socket);
ssize_t res;

#if HAVE_LINUX_ERRQUEUE_H
char buffer[sizeof(struct sock_extended_err) + sizeof(struct sockaddr_storage)];
messageHeader->ControlBufferLen = sizeof(buffer);
messageHeader->ControlBuffer = (void*)buffer;

struct msghdr header;
ConvertMessageHeaderToMsghdr(&header, messageHeader, fd);

while ((res = recvmsg(fd, &header, SocketFlags_MSG_DONTWAIT | SocketFlags_MSG_ERRQUEUE)) < 0 && errno == EINTR);

struct cmsghdr *cmsg;
for (cmsg = CMSG_FIRSTHDR(&header); cmsg; cmsg = GET_CMSG_NXTHDR(&header, cmsg))
{
if (cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR)
{
struct sock_extended_err *e = (struct sock_extended_err *)CMSG_DATA(cmsg);
if (e->ee_origin == SO_EE_ORIGIN_ICMP)
{
int size = (int)(cmsg->cmsg_len - sizeof(struct sock_extended_err));
messageHeader->SocketAddressLen = size < messageHeader->SocketAddressLen ? size : messageHeader->SocketAddressLen;
memcpy(messageHeader->SocketAddress, (struct sockaddr_in*)(e+1), (size_t)messageHeader->SocketAddressLen);
return Error_SUCCESS;
}
}
}
#else
res = -1;
errno = ENOTSUP;
#endif

messageHeader->SocketAddressLen = 0;

if (res != -1)
{
return Error_SUCCESS;
}

return SystemNative_ConvertErrorPlatformToPal(errno);
}

int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeader, int32_t flags, int64_t* received)
{
if (messageHeader == NULL || received == NULL || messageHeader->SocketAddressLen < 0 ||
Expand Down
4 changes: 4 additions & 0 deletions src/native/libs/System.Native/pal_networking.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ typedef enum
SocketFlags_MSG_DONTROUTE = 0x0004, // SocketFlags.DontRoute
SocketFlags_MSG_TRUNC = 0x0100, // SocketFlags.Truncated
SocketFlags_MSG_CTRUNC = 0x0200, // SocketFlags.ControlDataTruncated
SocketFlags_MSG_DONTWAIT = 0x1000, // used privately by Ping
SocketFlags_MSG_ERRQUEUE = 0x2000, // used privately by Ping
} SocketFlags;

/*
Expand Down Expand Up @@ -356,6 +358,8 @@ PALEXPORT int32_t SystemNative_Receive(intptr_t socket, void* buffer, int32_t bu

PALEXPORT int32_t SystemNative_ReceiveMessage(intptr_t socket, MessageHeader* messageHeader, int32_t flags, int64_t* received);

PALEXPORT int32_t SystemNative_ReceiveSocketError(intptr_t socket, MessageHeader* messageHeader);

PALEXPORT int32_t SystemNative_Send(intptr_t socket, void* buffer, int32_t bufferLen, int32_t flags, int32_t* sent);

PALEXPORT int32_t SystemNative_SendMessage(intptr_t socket, MessageHeader* messageHeader, int32_t flags, int64_t* sent);
Expand Down
4 changes: 4 additions & 0 deletions src/native/libs/configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,10 @@ check_include_files(
"sys/proc_info.h"
HAVE_SYS_PROCINFO_H)

check_include_files(
"time.h;linux/errqueue.h"
HAVE_LINUX_ERRQUEUE_H)

check_symbol_exists(
epoll_create1
sys/epoll.h
Expand Down
Loading