Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IORef based socket #303

Merged
merged 12 commits into from
Feb 6, 2018
2 changes: 1 addition & 1 deletion Network/Socket.hs
Original file line number Diff line number Diff line change
Expand Up @@ -192,14 +192,14 @@ module Network.Socket

import Network.Socket.Buffer hiding (sendBufTo, recvBufFrom)
import Network.Socket.Cbits
import Network.Socket.Close
import Network.Socket.Fcntl
import Network.Socket.Handle
import Network.Socket.If
import Network.Socket.Info
import Network.Socket.Internal
import Network.Socket.Name hiding (getPeerName, getSocketName)
import Network.Socket.Options
import Network.Socket.Shutdown
import Network.Socket.SockAddr
import Network.Socket.Syscall hiding (connect, bind, accept)
import Network.Socket.Types
Expand Down
75 changes: 37 additions & 38 deletions Network/Socket/Buffer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,17 +34,20 @@ sendBufTo :: SocketAddress sa =>
-> sa
-> IO Int -- Number of Bytes sent
sendBufTo s ptr nbytes sa =
withSocketAddress sa $ \p_sa sz ->
fromIntegral <$>
throwSocketErrorWaitWrite s "Network.Socket.sendBufTo"
(c_sendto (fdSocket s) ptr (fromIntegral nbytes) 0{-flags-}
p_sa (fromIntegral sz))
withSocketAddress sa $ \p_sa siz -> fromIntegral <$> do
fd <- fdSocket s
let sz = fromIntegral siz
n = fromIntegral nbytes
flags = 0
throwSocketErrorWaitWrite s "Network.Socket.sendBufTo" $
c_sendto fd ptr n flags p_sa sz

#if defined(mingw32_HOST_OS)
socket2FD :: Socket -> FD
socket2FD s =
socket2FD :: Socket -> IO FD
socket2FD s = do
fd <- fdSocket s
-- HACK, 1 means True
FD{ fdFD = fdSocket s, fdIsSocket_ = 1 }
return $ FD{ fdFD = fd, fdIsSocket_ = 1 }
#endif

-- | Send data to the socket. The socket must be connected to a remote
Expand All @@ -56,24 +59,22 @@ sendBuf :: Socket -- Bound/Connected Socket
-> Ptr Word8 -- Pointer to the data to send
-> Int -- Length of the buffer
-> IO Int -- Number of Bytes sent
sendBuf s str len =
sendBuf s str len = fromIntegral <$> do
#if defined(mingw32_HOST_OS)
-- writeRawBufferPtr is supposed to handle checking for errors, but it's broken
-- on x86_64 because of GHC bug #12010 so we duplicate the check here. The call
-- to throwSocketErrorIfMinus1Retry can be removed when no GHC version with the
-- bug is supported.
fromIntegral <$>
throwSocketErrorIfMinus1Retry "Network.Socket.sendBuf"
(writeRawBufferPtr
"Network.Socket.sendBuf"
(socket2FD s)
(castPtr str)
0
(fromIntegral len))
fd <- socket2FD s
let clen = fromIntegral len
throwSocketErrorIfMinus1Retry "Network.Socket.sendBuf" $
writeRawBufferPtr "Network.Socket.sendBuf" fd (castPtr str) 0 clen
#else
fromIntegral <$>
throwSocketErrorWaitWrite s "Network.Socket.sendBuf"
(c_send (fdSocket s) str (fromIntegral len) 0{-flags-})
fd <- fdSocket s
let flags = 0
clen = fromIntegral len
throwSocketErrorWaitWrite s "Network.Socket.sendBuf" $
c_send fd str clen flags
#endif

-- | Receive data from the socket, writing it into buffer instead of
Expand All @@ -90,16 +91,12 @@ recvBufFrom :: SocketAddress sa => Socket -> Ptr a -> Int -> IO (Int, sa)
recvBufFrom s ptr nbytes
| nbytes <= 0 = ioError (mkInvalidRecvArgError "Network.Socket.recvBufFrom")
| otherwise = withNewSocketAddress $ \ptr_sa sz -> alloca $ \ptr_len -> do
fd <- fdSocket s
poke ptr_len (fromIntegral sz)
len <-
throwSocketErrorWaitRead s "Network.Socket.recvBufFrom"
$ c_recvfrom
(fdSocket s)
ptr
(fromIntegral nbytes)
0{-flags-}
ptr_sa
ptr_len
let cnbytes = fromIntegral nbytes
flags = 0
len <- throwSocketErrorWaitRead s "Network.Socket.recvBufFrom" $
c_recvfrom fd ptr cnbytes flags ptr_sa ptr_len
let len' = fromIntegral len
if len' == 0
then ioError (mkEOFError "Network.Socket.recvFrom")
Expand Down Expand Up @@ -127,17 +124,19 @@ recvBuf s ptr nbytes
| otherwise = do
#if defined(mingw32_HOST_OS)
-- see comment in sendBuf above.
len <- throwSocketErrorIfMinus1Retry "Network.Socket.recvBuf" $
readRawBufferPtr "Network.Socket.recvBuf"
(socket2FD s) ptr 0 (fromIntegral nbytes)
fd <- socket2FD s
let cnbytes = fromIntegral nbytes
len <- throwSocketErrorIfMinus1Retry "Network.Socket.recvBuf" $
readRawBufferPtr "Network.Socket.recvBuf" fd ptr 0 cnbytes
#else
len <- throwSocketErrorWaitRead s "Network.Socket.recvBuf" $
c_recv (fdSocket s) (castPtr ptr) (fromIntegral nbytes) 0{-flags-}
fd <- fdSocket s
len <- throwSocketErrorWaitRead s "Network.Socket.recvBuf" $
c_recv fd (castPtr ptr) (fromIntegral nbytes) 0{-flags-}
#endif
let len' = fromIntegral len
if len' == 0
then ioError (mkEOFError "Network.Socket.recvBuf")
else return len'
let len' = fromIntegral len
if len' == 0
then ioError (mkEOFError "Network.Socket.recvBuf")
else return len'

mkInvalidRecvArgError :: String -> IOError
mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError
Expand Down
10 changes: 6 additions & 4 deletions Network/Socket/ByteString/IO.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,11 @@ sendMany s cs = do
when (sent < totalLength cs) $ sendMany s (remainingChunks sent cs)
where
sendManyInner =
fmap fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) ->
fmap fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) -> do
fd <- fdSocket s
let len = fromIntegral $ min iovsLen (#const IOV_MAX)
throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMany" $
c_writev (fdSocket s) iovsPtr
(fromIntegral (min iovsLen (#const IOV_MAX)))
c_writev fd iovsPtr len
#else
sendMany s = sendAll s . B.concat
#endif
Expand Down Expand Up @@ -173,9 +174,10 @@ sendManyTo s cs addr = do
let msgHdr = MsgHdr
addrPtr (fromIntegral addrSize)
iovsPtr (fromIntegral iovsLen)
fd <- fdSocket s
with msgHdr $ \msgHdrPtr ->
throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendManyTo" $
c_sendmsg (fdSocket s) msgHdrPtr 0
c_sendmsg fd msgHdrPtr 0
#else
sendManyTo s cs = sendAllTo s (B.concat cs)
#endif
Expand Down
10 changes: 4 additions & 6 deletions Network/Socket/ByteString/Lazy/Posix.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@ send
send s lbs = do
let cs = take maxNumChunks (L.toChunks lbs)
len = length cs
siz <-
( allocaArray len $ \ptr ->
withPokes cs ptr
$ \niovs -> throwSocketErrorWaitWrite s "writev"
$ c_writev (fdSocket s) ptr niovs
)
fd <- fdSocket s
siz <- allocaArray len $ \ptr ->
withPokes cs ptr $ \niovs ->
throwSocketErrorWaitWrite s "writev" $ c_writev fd ptr niovs
return $ fromIntegral siz
where
withPokes ss p f = loop ss p 0 0
Expand Down
3 changes: 2 additions & 1 deletion Network/Socket/Handle.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import qualified GHC.IO.Device (IODeviceType(Stream))
import GHC.IO.Handle.FD (fdToHandle')
import System.IO (IOMode(..), Handle, BufferMode(..), hSetBuffering)

import Network.Socket.Imports
import Network.Socket.Types

-- | Turns a Socket into an 'Handle'. By default, the new handle is
Expand All @@ -17,7 +18,7 @@ import Network.Socket.Types

socketToHandle :: Socket -> IOMode -> IO Handle
socketToHandle s mode = do
let fd = fromIntegral $ fdSocket s
fd <- fromIntegral <$> fdSocket s

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should invalidate the socket's file descriptor making shoud avoid behaviour impossible. Basically, emulate close, but without closing the socket's file descriptor. Or if you prefer (where available) dup(2) the socket's file descriptor, and close the socket.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, replace that should with a MUST. Given that Sockets (in this branch of the code) have finalizers that automatically close them, and doing so would then break the Handle, and trigger a second close outside the IORef guard in sockets, this is NOT optional. You must either invalidate the socket, or at least dup the file descriptor and pass that to the handle, but invalidation is I think more appropriate here, since continued I/O via the socket is not supported or expected.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks.
And done in 7e6eea8

h <- fdToHandle' fd (Just GHC.IO.Device.Stream) True (show s) mode True{-bin-}
hSetBuffering h NoBuffering
return h
14 changes: 6 additions & 8 deletions Network/Socket/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,17 @@ throwSocketErrorCode loc errno =
-- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready,
-- and try again.
throwSocketErrorWaitRead :: (Eq a, Num a) => Socket -> String -> IO a -> IO a
throwSocketErrorWaitRead s name io =
throwSocketErrorIfMinus1RetryMayBlock name
(threadWaitRead $ fromIntegral $ fdSocket s)
io
throwSocketErrorWaitRead s name io = do
fd <- fromIntegral <$> fdSocket s
throwSocketErrorIfMinus1RetryMayBlock name (threadWaitRead fd) io

-- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with
-- @EWOULDBLOCK@ or similar, wait for the socket to be write-ready,
-- and try again.
throwSocketErrorWaitWrite :: (Eq a, Num a) => Socket -> String -> IO a -> IO a
throwSocketErrorWaitWrite s name io =
throwSocketErrorIfMinus1RetryMayBlock name
(threadWaitWrite $ fromIntegral $ fdSocket s)
io
throwSocketErrorWaitWrite s name io = do
fd <- fromIntegral <$> fdSocket s
throwSocketErrorIfMinus1RetryMayBlock name (threadWaitWrite fd) io

-- ---------------------------------------------------------------------------
-- WinSock support
Expand Down
6 changes: 4 additions & 2 deletions Network/Socket/Name.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ getPeerName :: SocketAddress sa => Socket -> IO sa
getPeerName s =
withNewSocketAddress $ \ptr sz ->
with (fromIntegral sz) $ \int_star -> do
fd <- fdSocket s
throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerName" $
c_getpeername (fdSocket s) ptr int_star
c_getpeername fd ptr int_star
_sz <- peek int_star
peekSocketAddress ptr

Expand All @@ -29,8 +30,9 @@ getSocketName :: SocketAddress sa => Socket -> IO sa
getSocketName s =
withNewSocketAddress $ \ptr sz ->
with (fromIntegral sz) $ \int_star -> do
fd <- fdSocket s
throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketName" $
c_getsockname (fdSocket s) ptr int_star
c_getsockname fd ptr int_star
peekSocketAddress ptr

foreign import CALLCONV unsafe "getpeername"
Expand Down
46 changes: 26 additions & 20 deletions Network/Socket/Options.hsc
Original file line number Diff line number Diff line change
Expand Up @@ -189,21 +189,21 @@ setSocketOption :: Socket
setSocketOption s Linger v = do
(level, opt) <- packSocketOption' "setSocketOption" Linger
let arg = if v == 0 then StructLinger 0 0 else StructLinger 1 (fromIntegral v)
with arg $ \ptr_arg -> do
throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $
c_setsockopt (fdSocket s) level opt
(ptr_arg :: Ptr StructLinger)
(fromIntegral (sizeOf (undefined :: StructLinger)))
return ()
with arg $ \ptr_arg -> void $ do
let ptr = ptr_arg :: Ptr StructLinger
sz = fromIntegral $ sizeOf (undefined :: StructLinger)
fd <- fdSocket s
throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $
c_setsockopt fd level opt ptr sz
#endif
setSocketOption s so v = do
(level, opt) <- packSocketOption' "setSocketOption" so
with (fromIntegral v) $ \ptr_v -> do
throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $
c_setsockopt (fdSocket s) level opt
(ptr_v :: Ptr CInt)
(fromIntegral (sizeOf (undefined :: CInt)))
return ()
with (fromIntegral v) $ \ptr_v -> void $ do
let ptr = ptr_v :: Ptr CInt
sz = fromIntegral $ sizeOf (undefined :: CInt)
fd <- fdSocket s
throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $
c_setsockopt fd level opt ptr sz

-- | Get a socket option that gives an Int value.
-- There is currently no API to get e.g. the timeval socket options
Expand All @@ -213,20 +213,26 @@ getSocketOption :: Socket
#ifdef SO_LINGER
getSocketOption s Linger = do
(level, opt) <- packSocketOption' "getSocketOption" Linger
alloca $ \ptr_v ->
with (fromIntegral (sizeOf (undefined :: StructLinger))) $ \ptr_sz -> do
alloca $ \ptr_v -> do
let ptr = ptr_v :: Ptr StructLinger
sz = fromIntegral $ sizeOf (undefined :: StructLinger)
fd <- fdSocket s
with sz $ \ptr_sz -> do
throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketOption" $
c_getsockopt (fdSocket s) level opt (ptr_v :: Ptr StructLinger) ptr_sz
StructLinger onoff linger <- peek ptr_v
c_getsockopt fd level opt ptr ptr_sz
StructLinger onoff linger <- peek ptr
return $ fromIntegral $ if onoff == 0 then 0 else linger
#endif
getSocketOption s so = do
(level, opt) <- packSocketOption' "getSocketOption" so
alloca $ \ptr_v ->
with (fromIntegral (sizeOf (undefined :: CInt))) $ \ptr_sz -> do
alloca $ \ptr_v -> do
let ptr = ptr_v :: Ptr CInt
sz = fromIntegral $ sizeOf (undefined :: CInt)
fd <- fdSocket s
with sz $ \ptr_sz -> do
throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketOption" $
c_getsockopt (fdSocket s) level opt (ptr_v :: Ptr CInt) ptr_sz
fromIntegral <$> peek ptr_v
c_getsockopt fd level opt ptr ptr_sz
fromIntegral <$> peek ptr

foreign import CALLCONV unsafe "getsockopt"
c_getsockopt :: CInt -> CInt -> CInt -> Ptr a -> Ptr CInt -> IO CInt
Expand Down
30 changes: 4 additions & 26 deletions Network/Socket/Close.hs → Network/Socket/Shutdown.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,15 @@

#include "HsNetDef.h"

module Network.Socket.Close (
module Network.Socket.Shutdown (
ShutdownCmd(..)
, shutdown
, close
) where

import GHC.Conc (closeFdWith)

import Network.Socket.Imports
import Network.Socket.Internal
import Network.Socket.Types

-- -----------------------------------------------------------------------------

data ShutdownCmd = ShutdownReceive
| ShutdownSend
| ShutdownBoth
Expand All @@ -32,27 +27,10 @@ sdownCmdToInt ShutdownBoth = 2
-- 'ShutdownSend', further sends are disallowed. If it is
-- 'ShutdownBoth', further sends and receives are disallowed.
shutdown :: Socket -> ShutdownCmd -> IO ()
shutdown s stype = void $
shutdown s stype = void $ do
fd <- fdSocket s
throwSocketErrorIfMinus1Retry_ "Network.Socket.shutdown" $
c_shutdown (fdSocket s) (sdownCmdToInt stype)

-- -----------------------------------------------------------------------------

-- | Close the socket. Sending data to or receiving data from closed socket
-- may lead to undefined behaviour.
close :: Socket -> IO ()
close s = closeFdWith (closeFd . fromIntegral) (fromIntegral $ fdSocket s)

closeFd :: CInt -> IO ()
closeFd fd = throwSocketErrorIfMinus1_ "Network.Socket.close" $ c_close fd
c_shutdown fd $ sdownCmdToInt stype

foreign import CALLCONV unsafe "shutdown"
c_shutdown :: CInt -> CInt -> IO CInt

#if defined(mingw32_HOST_OS)
foreign import CALLCONV unsafe "closesocket"
c_close :: CInt -> IO CInt
#else
foreign import ccall unsafe "close"
c_close :: CInt -> IO CInt
#endif
Loading