diff --git a/Network/Socket.hs b/Network/Socket.hs index e6f46540..858d83d2 100644 --- a/Network/Socket.hs +++ b/Network/Socket.hs @@ -192,7 +192,6 @@ module Network.Socket import Network.Socket.Buffer hiding (sendBufTo, recvBufFrom) import Network.Socket.Cbits -import Network.Socket.Close import Network.Socket.Fcntl import Network.Socket.Handle import Network.Socket.If @@ -200,6 +199,7 @@ import Network.Socket.Info import Network.Socket.Internal import Network.Socket.Name hiding (getPeerName, getSocketName) import Network.Socket.Options +import Network.Socket.Shutdown import Network.Socket.SockAddr import Network.Socket.Syscall hiding (connect, bind, accept) import Network.Socket.Types diff --git a/Network/Socket/Buffer.hs b/Network/Socket/Buffer.hs index 900eca00..51a01df8 100644 --- a/Network/Socket/Buffer.hs +++ b/Network/Socket/Buffer.hs @@ -34,17 +34,20 @@ sendBufTo :: SocketAddress sa => -> sa -> IO Int -- Number of Bytes sent sendBufTo s ptr nbytes sa = - withSocketAddress sa $ \p_sa sz -> - fromIntegral <$> - throwSocketErrorWaitWrite s "Network.Socket.sendBufTo" - (c_sendto (fdSocket s) ptr (fromIntegral nbytes) 0{-flags-} - p_sa (fromIntegral sz)) + withSocketAddress sa $ \p_sa siz -> fromIntegral <$> do + fd <- fdSocket s + let sz = fromIntegral siz + n = fromIntegral nbytes + flags = 0 + throwSocketErrorWaitWrite s "Network.Socket.sendBufTo" $ + c_sendto fd ptr n flags p_sa sz #if defined(mingw32_HOST_OS) -socket2FD :: Socket -> FD -socket2FD s = +socket2FD :: Socket -> IO FD +socket2FD s = do + fd <- fdSocket s -- HACK, 1 means True - FD{ fdFD = fdSocket s, fdIsSocket_ = 1 } + return $ FD{ fdFD = fd, fdIsSocket_ = 1 } #endif -- | Send data to the socket. The socket must be connected to a remote @@ -56,24 +59,22 @@ sendBuf :: Socket -- Bound/Connected Socket -> Ptr Word8 -- Pointer to the data to send -> Int -- Length of the buffer -> IO Int -- Number of Bytes sent -sendBuf s str len = +sendBuf s str len = fromIntegral <$> do #if defined(mingw32_HOST_OS) -- writeRawBufferPtr is supposed to handle checking for errors, but it's broken -- on x86_64 because of GHC bug #12010 so we duplicate the check here. The call -- to throwSocketErrorIfMinus1Retry can be removed when no GHC version with the -- bug is supported. - fromIntegral <$> - throwSocketErrorIfMinus1Retry "Network.Socket.sendBuf" - (writeRawBufferPtr - "Network.Socket.sendBuf" - (socket2FD s) - (castPtr str) - 0 - (fromIntegral len)) + fd <- socket2FD s + let clen = fromIntegral len + throwSocketErrorIfMinus1Retry "Network.Socket.sendBuf" $ + writeRawBufferPtr "Network.Socket.sendBuf" fd (castPtr str) 0 clen #else - fromIntegral <$> - throwSocketErrorWaitWrite s "Network.Socket.sendBuf" - (c_send (fdSocket s) str (fromIntegral len) 0{-flags-}) + fd <- fdSocket s + let flags = 0 + clen = fromIntegral len + throwSocketErrorWaitWrite s "Network.Socket.sendBuf" $ + c_send fd str clen flags #endif -- | Receive data from the socket, writing it into buffer instead of @@ -90,16 +91,12 @@ recvBufFrom :: SocketAddress sa => Socket -> Ptr a -> Int -> IO (Int, sa) recvBufFrom s ptr nbytes | nbytes <= 0 = ioError (mkInvalidRecvArgError "Network.Socket.recvBufFrom") | otherwise = withNewSocketAddress $ \ptr_sa sz -> alloca $ \ptr_len -> do + fd <- fdSocket s poke ptr_len (fromIntegral sz) - len <- - throwSocketErrorWaitRead s "Network.Socket.recvBufFrom" - $ c_recvfrom - (fdSocket s) - ptr - (fromIntegral nbytes) - 0{-flags-} - ptr_sa - ptr_len + let cnbytes = fromIntegral nbytes + flags = 0 + len <- throwSocketErrorWaitRead s "Network.Socket.recvBufFrom" $ + c_recvfrom fd ptr cnbytes flags ptr_sa ptr_len let len' = fromIntegral len if len' == 0 then ioError (mkEOFError "Network.Socket.recvFrom") @@ -127,17 +124,19 @@ recvBuf s ptr nbytes | otherwise = do #if defined(mingw32_HOST_OS) -- see comment in sendBuf above. - len <- throwSocketErrorIfMinus1Retry "Network.Socket.recvBuf" $ - readRawBufferPtr "Network.Socket.recvBuf" - (socket2FD s) ptr 0 (fromIntegral nbytes) + fd <- socket2FD s + let cnbytes = fromIntegral nbytes + len <- throwSocketErrorIfMinus1Retry "Network.Socket.recvBuf" $ + readRawBufferPtr "Network.Socket.recvBuf" fd ptr 0 cnbytes #else - len <- throwSocketErrorWaitRead s "Network.Socket.recvBuf" $ - c_recv (fdSocket s) (castPtr ptr) (fromIntegral nbytes) 0{-flags-} + fd <- fdSocket s + len <- throwSocketErrorWaitRead s "Network.Socket.recvBuf" $ + c_recv fd (castPtr ptr) (fromIntegral nbytes) 0{-flags-} #endif - let len' = fromIntegral len - if len' == 0 - then ioError (mkEOFError "Network.Socket.recvBuf") - else return len' + let len' = fromIntegral len + if len' == 0 + then ioError (mkEOFError "Network.Socket.recvBuf") + else return len' mkInvalidRecvArgError :: String -> IOError mkInvalidRecvArgError loc = ioeSetErrorString (mkIOError diff --git a/Network/Socket/ByteString/IO.hsc b/Network/Socket/ByteString/IO.hsc index 5fe15871..c60434b5 100644 --- a/Network/Socket/ByteString/IO.hsc +++ b/Network/Socket/ByteString/IO.hsc @@ -142,10 +142,11 @@ sendMany s cs = do when (sent < totalLength cs) $ sendMany s (remainingChunks sent cs) where sendManyInner = - fmap fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) -> + fmap fromIntegral . withIOVec cs $ \(iovsPtr, iovsLen) -> do + fd <- fdSocket s + let len = fromIntegral $ min iovsLen (#const IOV_MAX) throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendMany" $ - c_writev (fdSocket s) iovsPtr - (fromIntegral (min iovsLen (#const IOV_MAX))) + c_writev fd iovsPtr len #else sendMany s = sendAll s . B.concat #endif @@ -173,9 +174,10 @@ sendManyTo s cs addr = do let msgHdr = MsgHdr addrPtr (fromIntegral addrSize) iovsPtr (fromIntegral iovsLen) + fd <- fdSocket s with msgHdr $ \msgHdrPtr -> throwSocketErrorWaitWrite s "Network.Socket.ByteString.sendManyTo" $ - c_sendmsg (fdSocket s) msgHdrPtr 0 + c_sendmsg fd msgHdrPtr 0 #else sendManyTo s cs = sendAllTo s (B.concat cs) #endif diff --git a/Network/Socket/ByteString/Lazy/Posix.hs b/Network/Socket/ByteString/Lazy/Posix.hs index 9a8a2ebc..a609e1a3 100644 --- a/Network/Socket/ByteString/Lazy/Posix.hs +++ b/Network/Socket/ByteString/Lazy/Posix.hs @@ -25,12 +25,10 @@ send send s lbs = do let cs = take maxNumChunks (L.toChunks lbs) len = length cs - siz <- - ( allocaArray len $ \ptr -> - withPokes cs ptr - $ \niovs -> throwSocketErrorWaitWrite s "writev" - $ c_writev (fdSocket s) ptr niovs - ) + fd <- fdSocket s + siz <- allocaArray len $ \ptr -> + withPokes cs ptr $ \niovs -> + throwSocketErrorWaitWrite s "writev" $ c_writev fd ptr niovs return $ fromIntegral siz where withPokes ss p f = loop ss p 0 0 diff --git a/Network/Socket/Handle.hs b/Network/Socket/Handle.hs index 2d7eaed7..1eb49b4d 100644 --- a/Network/Socket/Handle.hs +++ b/Network/Socket/Handle.hs @@ -16,8 +16,9 @@ import Network.Socket.Types -- on the 'Handle'. socketToHandle :: Socket -> IOMode -> IO Handle -socketToHandle s mode = do - let fd = fromIntegral $ fdSocket s - h <- fdToHandle' fd (Just GHC.IO.Device.Stream) True (show s) mode True{-bin-} +socketToHandle s mode = invalidateSocket s err $ \oldfd -> do + h <- fdToHandle' oldfd (Just GHC.IO.Device.Stream) True (show s) mode True{-bin-} hSetBuffering h NoBuffering return h + where + err _ = ioError $ userError $ "socketToHandle: socket is no longer valid" diff --git a/Network/Socket/Internal.hs b/Network/Socket/Internal.hs index 0e32e270..8a3147cd 100644 --- a/Network/Socket/Internal.hs +++ b/Network/Socket/Internal.hs @@ -191,19 +191,17 @@ throwSocketErrorCode loc errno = -- @EWOULDBLOCK@ or similar, wait for the socket to be read-ready, -- and try again. throwSocketErrorWaitRead :: (Eq a, Num a) => Socket -> String -> IO a -> IO a -throwSocketErrorWaitRead s name io = - throwSocketErrorIfMinus1RetryMayBlock name - (threadWaitRead $ fromIntegral $ fdSocket s) - io +throwSocketErrorWaitRead s name io = do + fd <- fromIntegral <$> fdSocket s + throwSocketErrorIfMinus1RetryMayBlock name (threadWaitRead fd) io -- | Like 'throwSocketErrorIfMinus1Retry', but if the action fails with -- @EWOULDBLOCK@ or similar, wait for the socket to be write-ready, -- and try again. throwSocketErrorWaitWrite :: (Eq a, Num a) => Socket -> String -> IO a -> IO a -throwSocketErrorWaitWrite s name io = - throwSocketErrorIfMinus1RetryMayBlock name - (threadWaitWrite $ fromIntegral $ fdSocket s) - io +throwSocketErrorWaitWrite s name io = do + fd <- fromIntegral <$> fdSocket s + throwSocketErrorIfMinus1RetryMayBlock name (threadWaitWrite fd) io -- --------------------------------------------------------------------------- -- WinSock support diff --git a/Network/Socket/Name.hs b/Network/Socket/Name.hs index 538b3ee5..c4cede8c 100644 --- a/Network/Socket/Name.hs +++ b/Network/Socket/Name.hs @@ -19,8 +19,9 @@ getPeerName :: SocketAddress sa => Socket -> IO sa getPeerName s = withNewSocketAddress $ \ptr sz -> with (fromIntegral sz) $ \int_star -> do + fd <- fdSocket s throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerName" $ - c_getpeername (fdSocket s) ptr int_star + c_getpeername fd ptr int_star _sz <- peek int_star peekSocketAddress ptr @@ -29,8 +30,9 @@ getSocketName :: SocketAddress sa => Socket -> IO sa getSocketName s = withNewSocketAddress $ \ptr sz -> with (fromIntegral sz) $ \int_star -> do + fd <- fdSocket s throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketName" $ - c_getsockname (fdSocket s) ptr int_star + c_getsockname fd ptr int_star peekSocketAddress ptr foreign import CALLCONV unsafe "getpeername" diff --git a/Network/Socket/Options.hsc b/Network/Socket/Options.hsc index 674edd74..717d2d49 100644 --- a/Network/Socket/Options.hsc +++ b/Network/Socket/Options.hsc @@ -189,21 +189,21 @@ setSocketOption :: Socket setSocketOption s Linger v = do (level, opt) <- packSocketOption' "setSocketOption" Linger let arg = if v == 0 then StructLinger 0 0 else StructLinger 1 (fromIntegral v) - with arg $ \ptr_arg -> do - throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ - c_setsockopt (fdSocket s) level opt - (ptr_arg :: Ptr StructLinger) - (fromIntegral (sizeOf (undefined :: StructLinger))) - return () + with arg $ \ptr_arg -> void $ do + let ptr = ptr_arg :: Ptr StructLinger + sz = fromIntegral $ sizeOf (undefined :: StructLinger) + fd <- fdSocket s + throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ + c_setsockopt fd level opt ptr sz #endif setSocketOption s so v = do (level, opt) <- packSocketOption' "setSocketOption" so - with (fromIntegral v) $ \ptr_v -> do - throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ - c_setsockopt (fdSocket s) level opt - (ptr_v :: Ptr CInt) - (fromIntegral (sizeOf (undefined :: CInt))) - return () + with (fromIntegral v) $ \ptr_v -> void $ do + let ptr = ptr_v :: Ptr CInt + sz = fromIntegral $ sizeOf (undefined :: CInt) + fd <- fdSocket s + throwSocketErrorIfMinus1_ "Network.Socket.setSocketOption" $ + c_setsockopt fd level opt ptr sz -- | Get a socket option that gives an Int value. -- There is currently no API to get e.g. the timeval socket options @@ -213,20 +213,26 @@ getSocketOption :: Socket #ifdef SO_LINGER getSocketOption s Linger = do (level, opt) <- packSocketOption' "getSocketOption" Linger - alloca $ \ptr_v -> - with (fromIntegral (sizeOf (undefined :: StructLinger))) $ \ptr_sz -> do + alloca $ \ptr_v -> do + let ptr = ptr_v :: Ptr StructLinger + sz = fromIntegral $ sizeOf (undefined :: StructLinger) + fd <- fdSocket s + with sz $ \ptr_sz -> do throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketOption" $ - c_getsockopt (fdSocket s) level opt (ptr_v :: Ptr StructLinger) ptr_sz - StructLinger onoff linger <- peek ptr_v + c_getsockopt fd level opt ptr ptr_sz + StructLinger onoff linger <- peek ptr return $ fromIntegral $ if onoff == 0 then 0 else linger #endif getSocketOption s so = do (level, opt) <- packSocketOption' "getSocketOption" so - alloca $ \ptr_v -> - with (fromIntegral (sizeOf (undefined :: CInt))) $ \ptr_sz -> do + alloca $ \ptr_v -> do + let ptr = ptr_v :: Ptr CInt + sz = fromIntegral $ sizeOf (undefined :: CInt) + fd <- fdSocket s + with sz $ \ptr_sz -> do throwSocketErrorIfMinus1Retry_ "Network.Socket.getSocketOption" $ - c_getsockopt (fdSocket s) level opt (ptr_v :: Ptr CInt) ptr_sz - fromIntegral <$> peek ptr_v + c_getsockopt fd level opt ptr ptr_sz + fromIntegral <$> peek ptr foreign import CALLCONV unsafe "getsockopt" c_getsockopt :: CInt -> CInt -> CInt -> Ptr a -> Ptr CInt -> IO CInt diff --git a/Network/Socket/Close.hs b/Network/Socket/Shutdown.hs similarity index 54% rename from Network/Socket/Close.hs rename to Network/Socket/Shutdown.hs index 52259225..9657ff30 100644 --- a/Network/Socket/Close.hs +++ b/Network/Socket/Shutdown.hs @@ -2,20 +2,15 @@ #include "HsNetDef.h" -module Network.Socket.Close ( +module Network.Socket.Shutdown ( ShutdownCmd(..) , shutdown - , close ) where -import GHC.Conc (closeFdWith) - import Network.Socket.Imports import Network.Socket.Internal import Network.Socket.Types --- ----------------------------------------------------------------------------- - data ShutdownCmd = ShutdownReceive | ShutdownSend | ShutdownBoth @@ -32,27 +27,10 @@ sdownCmdToInt ShutdownBoth = 2 -- 'ShutdownSend', further sends are disallowed. If it is -- 'ShutdownBoth', further sends and receives are disallowed. shutdown :: Socket -> ShutdownCmd -> IO () -shutdown s stype = void $ +shutdown s stype = void $ do + fd <- fdSocket s throwSocketErrorIfMinus1Retry_ "Network.Socket.shutdown" $ - c_shutdown (fdSocket s) (sdownCmdToInt stype) - --- ----------------------------------------------------------------------------- - --- | Close the socket. Sending data to or receiving data from closed socket --- may lead to undefined behaviour. -close :: Socket -> IO () -close s = closeFdWith (closeFd . fromIntegral) (fromIntegral $ fdSocket s) - -closeFd :: CInt -> IO () -closeFd fd = throwSocketErrorIfMinus1_ "Network.Socket.close" $ c_close fd + c_shutdown fd $ sdownCmdToInt stype foreign import CALLCONV unsafe "shutdown" c_shutdown :: CInt -> CInt -> IO CInt - -#if defined(mingw32_HOST_OS) -foreign import CALLCONV unsafe "closesocket" - c_close :: CInt -> IO CInt -#else -foreign import ccall unsafe "close" - c_close :: CInt -> IO CInt -#endif diff --git a/Network/Socket/Syscall.hs b/Network/Socket/Syscall.hs index c3a0e032..371a9fe2 100644 --- a/Network/Socket/Syscall.hs +++ b/Network/Socket/Syscall.hs @@ -14,7 +14,6 @@ import GHC.Conc (asyncDoProc) import Foreign.C.Error (getErrno, eINTR, eINPROGRESS) import GHC.Conc (threadWaitWrite) import GHC.IO (onException) -import Network.Socket.Close #endif #ifdef HAVE_ADVANCED_SOCKET_FLAGS @@ -84,7 +83,7 @@ socket family stype protocol = do #ifndef HAVE_ADVANCED_SOCKET_FLAGS setNonBlockIfNeeded fd #endif - let s = mkSocket fd + s <- mkSocket fd #if HAVE_DECL_IPV6_V6ONLY when (family == AF_INET6 && stype `elem` [Stream, Datagram]) $ # if defined(mingw32_HOST_OS) @@ -111,9 +110,10 @@ socket family stype protocol = do -- 'defaultPort' is passed then the system assigns the next available -- use port. bind :: SocketAddress sa => Socket -> sa -> IO () -bind s sa = withSocketAddress sa $ \p_sa sz -> void $ - throwSocketErrorIfMinus1Retry "Network.Socket.bind" $ - c_bind (fdSocket s) p_sa (fromIntegral sz) +bind s sa = withSocketAddress sa $ \p_sa siz -> void $ do + fd <- fdSocket s + let sz = fromIntegral siz + throwSocketErrorIfMinus1Retry "Network.Socket.bind" $ c_bind fd p_sa sz ----------------------------------------------------------------------------- -- Connecting a socket @@ -126,9 +126,9 @@ connect s sa = withSocketsDo $ withSocketAddress sa $ \p_sa sz -> connectLoop :: SocketAddress sa => Socket -> Ptr sa -> CInt -> IO () connectLoop s p_sa sz = loop where - fd = fdSocket s errLoc = "Network.Socket.connect: " ++ show s loop = do + fd <- fdSocket s r <- c_connect fd p_sa sz when (r == -1) $ do #if defined(mingw32_HOST_OS) @@ -142,7 +142,8 @@ connectLoop s p_sa sz = loop _otherwise -> throwSocketError errLoc connectBlocked = do - threadWaitWrite (fromIntegral fd) + fd <- fromIntegral <$> fdSocket s + threadWaitWrite fd err <- getSocketOption s SoError when (err == -1) $ throwSocketErrorCode errLoc (fromIntegral err) #endif @@ -154,9 +155,10 @@ connectLoop s p_sa sz = loop -- specifies the maximum number of queued connections and should be at -- least 1; the maximum value is system-dependent (usually 5). listen :: Socket -> Int -> IO () -listen s backlog = +listen s backlog = do + fd <- fdSocket s throwSocketErrorIfMinus1Retry_ "Network.Socket.listen" $ - c_listen (fdSocket s) (fromIntegral backlog) + c_listen fd $ fromIntegral backlog ----------------------------------------------------------------------------- -- Accept @@ -174,7 +176,7 @@ listen s backlog = -- to the socket on the other end of the connection. accept :: SocketAddress sa => Socket -> IO (Socket, sa) accept s = withNewSocketAddress $ \sa sz -> do - let fd = fdSocket s + fd <- fdSocket s #if defined(mingw32_HOST_OS) new_fd <- if threaded @@ -202,7 +204,7 @@ accept s = withNewSocketAddress $ \sa sz -> do # endif /* HAVE_ADVANCED_SOCKET_FLAGS */ #endif addr <- peekSocketAddress sa - let new_s = mkSocket new_fd + new_s <- mkSocket new_fd return (new_s, addr) foreign import CALLCONV unsafe "socket" diff --git a/Network/Socket/Types.hsc b/Network/Socket/Types.hsc index 51464b58..a6a27d03 100644 --- a/Network/Socket/Types.hsc +++ b/Network/Socket/Types.hsc @@ -5,12 +5,13 @@ #include "HsNet.h" ##include "HsNetDef.h" -module Network.Socket.Types - ( +module Network.Socket.Types ( -- * Socket type Socket , fdSocket , mkSocket + , invalidateSocket + , close -- * Types of socket , SocketType(..) , isSupportedSocketType @@ -58,8 +59,10 @@ module Network.Socket.Types , ntohl ) where +import Data.IORef (IORef, newIORef, readIORef, atomicModifyIORef', mkWeakIORef) import Data.Ratio import Foreign.Marshal.Alloc +import GHC.Conc (closeFdWith) #if defined(DOMAIN_SOCKET_SUPPORT) import Foreign.Marshal.Array @@ -70,14 +73,57 @@ import Network.Socket.Imports ----------------------------------------------------------------------------- -- | Basic type for a socket. -newtype Socket = Socket CInt deriving (Eq, Show) +data Socket = Socket (IORef CInt) CInt {- for Show -} + +instance Show Socket where + show (Socket _ ofd) = "" + +instance Eq Socket where + Socket ref1 _ == Socket ref2 _ = ref1 == ref2 -- | Getting a file descriptor from a socket. -fdSocket :: Socket -> CInt -fdSocket (Socket fd) = fd +fdSocket :: Socket -> IO CInt +fdSocket (Socket ref _) = readIORef ref + +-- | Creating a socket from a file descriptor. +mkSocket :: CInt -> IO Socket +mkSocket fd = do + ref <- newIORef fd + let s = Socket ref fd + void $ mkWeakIORef ref $ close s + return s + +invalidSocket :: CInt +#if defined(mingw32_HOST_OS) +invalidSocket = #const INVALID_SOCKET +#else +invalidSocket = -1 +#endif + +invalidateSocket :: + Socket + -> (CInt -> IO a) + -> (CInt -> IO a) + -> IO a +invalidateSocket (Socket ref _) errorAction normalAction = do + oldfd <- atomicModifyIORef' ref $ \cur -> (invalidSocket, cur) + if oldfd == invalidSocket then errorAction oldfd else normalAction oldfd -mkSocket :: CInt -> Socket -mkSocket = Socket +----------------------------------------------------------------------------- + +-- | Close the socket. Sending data to or receiving data from closed socket +-- may lead to undefined behaviour. +close :: Socket -> IO () +close s = invalidateSocket s (\_ -> return ()) $ \oldfd -> do + closeFdWith (void . c_close . fromIntegral) (fromIntegral oldfd) + +#if defined(mingw32_HOST_OS) +foreign import CALLCONV unsafe "closesocket" + c_close :: CInt -> IO CInt +#else +foreign import ccall unsafe "close" + c_close :: CInt -> IO CInt +#endif ----------------------------------------------------------------------------- diff --git a/Network/Socket/Unix.hsc b/Network/Socket/Unix.hsc index 99ddc733..e9a0767c 100644 --- a/Network/Socket/Unix.hsc +++ b/Network/Socket/Unix.hsc @@ -62,7 +62,7 @@ getPeerCred :: Socket -> IO (CUInt, CUInt, CUInt) #ifdef HAVE_STRUCT_UCRED_SO_PEERCRED getPeerCred s = do let sz = (#const sizeof(struct ucred)) - fd = fdSocket s + fd <- fdSocket s allocaBytes sz $ \ ptr_cr -> with (fromIntegral sz) $ \ ptr_sz -> do _ <- ($) throwSocketErrorIfMinus1Retry "Network.Socket.getPeerCred" $ @@ -85,8 +85,9 @@ getPeerEid :: Socket -> IO (CUInt, CUInt) getPeerEid s = do alloca $ \ ptr_uid -> alloca $ \ ptr_gid -> do + fd <- fdSocket s throwSocketErrorIfMinus1Retry_ "Network.Socket.getPeerEid" $ - c_getpeereid (fdSocket s) ptr_uid ptr_gid + c_getpeereid fd ptr_uid ptr_gid uid <- peek ptr_uid gid <- peek ptr_gid return (uid, gid) @@ -114,8 +115,9 @@ isUnixDomainSocketAvailable = False -- 'True'. sendFd :: Socket -> CInt -> IO () #if defined(DOMAIN_SOCKET_SUPPORT) -sendFd s outfd = void $ - throwSocketErrorWaitWrite s "Network.Socket.sendFd" $ c_sendFd (fdSocket s) outfd +sendFd s outfd = void $ do + fd <- fdSocket s + throwSocketErrorWaitWrite s "Network.Socket.sendFd" $ c_sendFd fd outfd foreign import ccall SAFE_ON_WIN "sendFd" c_sendFd :: CInt -> CInt -> IO CInt #else sendFd _ _ = error "Network.Socket.sendFd" @@ -128,8 +130,9 @@ sendFd _ _ = error "Network.Socket.sendFd" -- 'True'. recvFd :: Socket -> IO CInt #if defined(DOMAIN_SOCKET_SUPPORT) -recvFd s = - throwSocketErrorWaitRead s "Network.Socket.recvFd" $ c_recvFd (fdSocket s) +recvFd s = do + fd <- fdSocket s + throwSocketErrorWaitRead s "Network.Socket.recvFd" $ c_recvFd fd foreign import ccall SAFE_ON_WIN "recvFd" c_recvFd :: CInt -> IO CInt #else recvFd _ = error "Network.Socket.recvFd" @@ -152,8 +155,8 @@ socketPair family stype protocol = [fd1,fd2] <- peekArray 2 fdArr setNonBlockIfNeeded fd1 setNonBlockIfNeeded fd2 - let s1 = mkSocket fd1 - s2 = mkSocket fd2 + s1 <- mkSocket fd1 + s2 <- mkSocket fd2 return (s1, s2) foreign import ccall unsafe "socketpair" diff --git a/network.cabal b/network.cabal index c623d3ae..a1c9894d 100644 --- a/network.cabal +++ b/network.cabal @@ -49,7 +49,6 @@ library Network.Socket.ByteString.Internal Network.Socket.ByteString.IO Network.Socket.Cbits - Network.Socket.Close Network.Socket.Fcntl Network.Socket.Handle Network.Socket.Imports @@ -57,6 +56,7 @@ library Network.Socket.Info Network.Socket.Name Network.Socket.Options + Network.Socket.Shutdown Network.Socket.SockAddr Network.Socket.Syscall Network.Socket.Types diff --git a/tests/SimpleSpec.hs b/tests/SimpleSpec.hs index 73ddd706..c54502f6 100644 --- a/tests/SimpleSpec.hs +++ b/tests/SimpleSpec.hs @@ -174,8 +174,9 @@ tcpTest clientAct serverAct = do addr:_ <- getAddrInfo (Just hints) (Just serverAddr) (Just $ show serverPort) sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) #if !defined(mingw32_HOST_OS) - getNonBlock (fdSocket sock) `shouldReturn` True - getCloseOnExec (fdSocket sock) `shouldReturn` False + fd <- fdSocket sock + getNonBlock fd `shouldReturn` True + getCloseOnExec fd `shouldReturn` False #endif connect sock $ addrAddress addr return sock @@ -187,14 +188,15 @@ tcpTest clientAct serverAct = do } addr:_ <- getAddrInfo (Just hints) (Just serverAddr) Nothing sock <- socket (addrFamily addr) (addrSocketType addr) (addrProtocol addr) + fd <- fdSocket sock #if !defined(mingw32_HOST_OS) - getNonBlock (fdSocket sock) `shouldReturn` True - getCloseOnExec (fdSocket sock) `shouldReturn` False + getNonBlock fd `shouldReturn` True + getCloseOnExec fd `shouldReturn` False #endif setSocketOption sock ReuseAddr 1 - setCloseOnExecIfNeeded (fdSocket sock) + setCloseOnExecIfNeeded fd #if !defined(mingw32_HOST_OS) - getCloseOnExec (fdSocket sock) `shouldReturn` True + getCloseOnExec fd `shouldReturn` True #endif bind sock $ addrAddress addr listen sock 1 @@ -205,8 +207,9 @@ tcpTest clientAct serverAct = do server sock = do (clientSock, _) <- accept sock #if !defined(mingw32_HOST_OS) - getNonBlock (fdSocket clientSock) `shouldReturn` True - getCloseOnExec (fdSocket clientSock) `shouldReturn` True + fd <- fdSocket clientSock + getNonBlock fd `shouldReturn` True + getCloseOnExec fd `shouldReturn` True #endif _ <- serverAct clientSock close clientSock