Skip to content

Commit

Permalink
Socket.Windows: support ConnectAsync(SocketAsyncEventArgs) for UDP, a…
Browse files Browse the repository at this point in the history
…nd Unix sockets (#33674)

Socket.Windows: support ConnectAsync(SocketAsyncEventArgs) for non-stream protocols
  • Loading branch information
tmds authored Apr 3, 2020
1 parent 223b843 commit b755ba9
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 20 deletions.
22 changes: 18 additions & 4 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2090,11 +2090,17 @@ public IAsyncResult BeginConnect(EndPoint remoteEP, AsyncCallback? callback, obj

private bool CanUseConnectEx(EndPoint remoteEP)
{
return (_socketType == SocketType.Stream) &&
(_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint));
}
Debug.Assert(remoteEP.GetType() != typeof(DnsEndPoint));

// ConnectEx supports connection-oriented sockets.
// The socket must be bound before calling ConnectEx.
// In case of IPEndPoint, the Socket will be bound using WildcardBindForConnectIfNecessary.
// Unix sockets are not supported by ConnectEx.

return (_socketType == SocketType.Stream) &&
(_rightEndPoint != null || remoteEP.GetType() == typeof(IPEndPoint)) &&
(remoteEP.AddressFamily != AddressFamily.Unix);
}

internal IAsyncResult UnsafeBeginConnect(EndPoint remoteEP, AsyncCallback? callback, object? state, bool flowContext = false)
{
Expand Down Expand Up @@ -3817,7 +3823,15 @@ private bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket)
SocketError socketError = SocketError.Success;
try
{
socketError = e.DoOperationConnect(this, _handle);
if (CanUseConnectEx(endPointSnapshot))
{
socketError = e.DoOperationConnectEx(this, _handle);
}
else
{
// For connectionless protocols, Connect is not an I/O call.
socketError = e.DoOperationConnect(this, _handle);
}
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,6 @@ private void SetupMultipleBuffers()

private void CompleteCore() { }

private void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags)
{
Debug.Assert(socketError != SocketError.IOPending);

if (socketError == SocketError.Success)
{
FinishOperationSyncSuccess(bytesTransferred, flags);
}
else
{
FinishOperationSyncFailure(socketError, bytesTransferred, flags);
}
}

private void AcceptCompletionCallback(IntPtr acceptedFileDescriptor, byte[] socketAddress, int socketAddressSize, SocketError socketError)
{
CompleteAcceptOperation(acceptedFileDescriptor, socketAddress, socketAddressSize, socketError);
Expand Down Expand Up @@ -95,6 +81,9 @@ private void ConnectCompletionCallback(SocketError socketError)
CompletionCallback(0, SocketFlags.None, socketError);
}

internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle)
=> DoOperationConnect(socket, handle);

internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle handle)
{
SocketError socketError = handle.AsyncContext.ConnectAsync(_socketAddress!.Buffer, _socketAddress.Size, ConnectCompletionCallback);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha
}

internal unsafe SocketError DoOperationConnect(Socket socket, SafeSocketHandle handle)
{
// Called for connectionless protocols.
SocketError socketError = SocketPal.Connect(handle, _socketAddress!.Buffer, _socketAddress.Size);
FinishOperationSync(socketError, 0, SocketFlags.None);
return socketError;
}

internal unsafe SocketError DoOperationConnectEx(Socket socket, SafeSocketHandle handle)
{
// ConnectEx uses a sockaddr buffer containing the remote address to which to connect.
// It can also optionally take a single buffer of data to send after the connection is complete.
Expand Down Expand Up @@ -1160,6 +1168,13 @@ private unsafe SocketError FinishOperationConnect()
{
try
{
if (_currentSocket!.SocketType != SocketType.Stream)
{
// With connectionless sockets, regular connect is used instead of ConnectEx,
// attempting to set SO_UPDATE_CONNECT_CONTEXT will result in an error.
return SocketError.Success;
}

// Update the socket context.
SocketError socketError = Interop.Winsock.setsockopt(
_currentSocket!.SafeHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,5 +780,19 @@ internal void FinishOperationAsyncSuccess(int bytesTransferred, SocketFlags flag
ExecutionContext.Run(context, s_executionCallback, this);
}
}

private void FinishOperationSync(SocketError socketError, int bytesTransferred, SocketFlags flags)
{
Debug.Assert(socketError != SocketError.IOPending);

if (socketError == SocketError.Success)
{
FinishOperationSyncSuccess(bytesTransferred, flags);
}
else
{
FinishOperationSyncFailure(socketError, bytesTransferred, flags);
}
}
}
}
34 changes: 34 additions & 0 deletions src/libraries/System.Net.Sockets/tests/FunctionalTests/Connect.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,40 @@ public async Task Connect_Success(IPAddress listenAt)
}
}

[Theory]
[MemberData(nameof(Loopbacks))]
public async Task Connect_Udp_Success(IPAddress listenAt)
{
using Socket listener = new Socket(listenAt.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
using Socket client = new Socket(listenAt.AddressFamily, SocketType.Dgram, ProtocolType.Udp);
listener.Bind(new IPEndPoint(listenAt, 0));

await ConnectAsync(client, new IPEndPoint(listenAt, ((IPEndPoint)listener.LocalEndPoint).Port));
Assert.True(client.Connected);
}

[Theory]
[MemberData(nameof(Loopbacks))]
public async Task Connect_Dns_Success(IPAddress listenAt)
{
// On some systems (like Ubuntu 16.04 and Ubuntu 18.04) "localhost" doesn't resolve to '::1'.
if (Array.IndexOf(Dns.GetHostAddresses("localhost"), listenAt) == -1)
{
return;
}

int port;
using (SocketTestServer.SocketTestServerFactory(SocketImplementationType.Async, listenAt, out port))
{
using (Socket client = new Socket(listenAt.AddressFamily, SocketType.Stream, ProtocolType.Tcp))
{
Task connectTask = ConnectAsync(client, new DnsEndPoint("localhost", port));
await connectTask;
Assert.True(client.Connected);
}
}
}

[OuterLoop]
[Theory]
[MemberData(nameof(Loopbacks))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ public void OSSupportsUnixDomainSockets_ReturnsCorrectValue()
Assert.Equal(PlatformSupportsUnixDomainSockets, Socket.OSSupportsUnixDomainSockets);
}

[PlatformSpecific(~TestPlatforms.Windows)] // Windows doesn't currently support ConnectEx with domain sockets
[ConditionalFact(nameof(PlatformSupportsUnixDomainSockets))]
public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_Success()
{
Expand Down Expand Up @@ -100,7 +99,7 @@ public async Task Socket_ConnectAsyncUnixDomainSocketEndPoint_NotServer()
}

Assert.Equal(
RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.InvalidArgument : SocketError.AddressNotAvailable,
RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? SocketError.ConnectionRefused : SocketError.AddressNotAvailable,
args.SocketError);
}
}
Expand Down

0 comments on commit b755ba9

Please sign in to comment.