Skip to content

Commit

Permalink
fix TCP reset; deep proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
Zesen Qian committed Jan 2, 2017
1 parent 56e0abd commit ead2416
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 84 deletions.
97 changes: 53 additions & 44 deletions src/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@ import qualified ZhinaDNS as ZDNS
import IPSet
import Text.Parsec
import Parse
import qualified Server as S
import qualified Log

import System.IO
import System.Log.Logger

import Network.Socket hiding (recv, recvFrom, send, sendTo)
import Network.Socket.ByteString
import qualified Network.Socket.ByteString.Lazy as SL

import qualified Resolve.Types as R
import qualified Resolve.DNS.Transport.Helper.UDP as UDP
import qualified Resolve.DNS.Transport.Helper.LiveTCP as TCP
import qualified Resolve.DNS.Transport as Transport
import qualified Resolve.DNS.Lookup as L
import Resolve.DNS.Utils
import Resolve.DNS.Coding

import Resolve.Timeout
import qualified Resolve.Log as L

import qualified Data.ByteString.Lazy as BSL
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString as BS

import Data.ByteString.Builder
import Data.Maybe

Expand Down Expand Up @@ -111,35 +111,35 @@ main = do
r_world_tcp <- Transport.new t_world_tcp
infoM nameF $ "created client: " ++ (show c_world_tcp)

let log = L.log (debugM "Main") (debugM "Main")
let r_china_udp' = timeout zhina_udp_timeout' $ log $ decode $ R.resolve r_china_udp
let r_china_tcp' = timeout zhina_tcp_timeout' $ log $ decode $ R.resolve r_china_tcp
let r_world_tcp' = timeout world_tcp_timeout' $ log $ decode $ R.resolve r_world_tcp

let r_udp = ZDNS.resolve $ ZDNS.Config
{ ZDNS.china = r_china_udp'
, ZDNS.world = r_world_tcp'
, ZDNS.chinaIP = ips
}

r_tcp = ZDNS.resolve $ ZDNS.Config
{ ZDNS.china = r_china_tcp'
, ZDNS.world = r_world_tcp'
, ZDNS.chinaIP = ips
}
l_china <- L.new $ L.Config { L.udp = Just (timeout zhina_udp_timeout' $ R.resolve r_china_udp, (return 1024))
, L.tcp = Just (timeout zhina_tcp_timeout' $ R.resolve r_china_tcp)
}

l_world <- L.new $ L.Config { L.udp = Nothing
, L.tcp = Just (timeout world_tcp_timeout' $ R.resolve r_world_tcp)}

let r = ZDNS.resolve $ ZDNS.Config
{ ZDNS.china = R.resolve l_china
, ZDNS.world = R.resolve l_world
, ZDNS.chinaIP = ips
}

void $ forkIO $ udp $ Config { resolve = encode $ r_udp
, host = host'
, port = port'
}
void $ forkIO $ udp $ Config { resolve = S.server $ S.Config { S.back = r
, S.is_udp = True
}
, host = host'
, port = port'
}

void $ forkIO $ tcp_listen $ Config { resolve = encode $ r_tcp
, host = host'
, port = port'
}
void $ forkIO $ tcp_listen $ Config { resolve = S.server $ S.Config { S.back = r
, S.is_udp = False
}
, host = host'
, port = port'
}
forever $ threadDelay 1000000

data Config = Config { resolve :: R.Resolve ByteString ByteString
data Config = Config { resolve :: R.Resolve BSL.ByteString BSL.ByteString
, host :: String
, port :: String
}
Expand Down Expand Up @@ -185,11 +185,12 @@ tcp_listen c = do
(\(sock', _) -> close sock')
(\(sock', sa) -> do
let nameConn = nameF ++ "." ++ (show sa)
forkFinally (tcp sock' nameConn (resolve c)) (\_ -> close sock'))
forkFinally (tcp sock' nameConn (resolve c)) (\_ -> debugM nameF "closing the socket" >> close sock'))
)


tcp sock' _ r = do
let nameF = nameM ++ ".tcp"
qi <- newEmptyTMVarIO
qo <- newEmptyTMVarIO
si <- newTVarIO False
Expand All @@ -200,39 +201,45 @@ tcp sock' _ r = do
-- thread receiving messages to qi
ti <- forkFinally
(do
let recvAll' n = do
bs <- SL.recv sock' n
when (BSL.length bs == 0) $ throwIO ThreadKilled
mappend (lazyByteString bs) <$> (recvAll' $ n - (BSL.length bs))
recvAll n = toLazyByteString <$> recvAll' n
let recvAll' l n = if n == 0 then return l
else do
bs <- recv sock' n
debugM nameF $ "recv: " ++ (show $BS.length bs) ++ "B = " ++ (show bs)
when (BS.null bs) $ throwIO ThreadKilled
recvAll' (mappend l $ byteString bs) (n - (BS.length bs))
recvAll n = toLazyByteString <$> recvAll' mempty n

forever $ runMaybeT $ do
n <- lift $ recvAll 2
let n' = toWord16 (BSL.toStrict n)
lift $ do
bs <- recvAll $ fromIntegral n'
atomically $ putTMVar qi bs)
(\_ -> atomically $ writeTVar si True)
(\x -> do
debugM nameF $ "recv exited: " ++ either (\e -> show (e :: SomeException)) (\_ -> " elegantly") x
atomically $ writeTVar si True)

-- thread sending messages from qo
to <- forkFinally
(do
let sendAll bs = if BSL.null bs then
let sendAll bs = if BS.null bs then
return ()
else do
n <- SL.send sock' bs
sendAll (BSL.drop n bs)
n <- send sock' bs
sendAll (BS.drop n bs)
forever $ do
bs <- atomically $ takeTMVar qo
case safeFromIntegral $ BSL.length bs of
Nothing -> return ()
Just n -> do
sendAll $ BSL.fromStrict $ fromWord16 n
sendAll bs)
(\_ -> atomically $ writeTVar so True)
sendAll $ fromWord16 n
sendAll (BSL.toStrict bs))
(\x -> do
debugM nameF $ "send exited: " ++ either (\e -> show (e :: SomeException)) (\_ -> " elegantly") x
atomically $ writeTVar so True)
return (ti, to)
)
(\(ti, to) -> do
(\(ti, to) -> uninterruptibleMask_ $ do
killThread ti
killThread to
)
Expand All @@ -242,7 +249,9 @@ tcp sock' _ r = do
if x then tryTakeTMVar qi
else Just <$> takeTMVar qi
case a of
Nothing -> throwIO ThreadKilled
Nothing -> do
debugM nameF "my friend is dead, I want to kill myself"
throwIO ThreadKilled
Just a' -> forkIO $ do
b <- r a'
atomically $ do
Expand Down
113 changes: 113 additions & 0 deletions src/Server.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
module Server where

import Resolve.Types
import Resolve.DNS.Types
import qualified Resolve.DNS.Lookup as L

import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BSL

import Control.Monad.Trans.Except
import Control.Monad.Trans.Class

import qualified Resolve.DNS.Encode as E
import qualified Resolve.DNS.Decode as D

import Control.Monad
import Control.Exception

import Data.Word
import Data.Typeable

import System.Log.Logger

data Error = WierdQuery String
| ResponseTooLong
deriving (Typeable, Show)

instance Exception Error

nameM = "Server"

data Config = Config { back :: Resolve L.Query L.Response
, is_udp :: Bool
}

server :: Config -> Resolve ByteString ByteString
server c bs_a = do
m_a <- case D.decodeMessage (BSL.toStrict bs_a) of
Left e -> throwIO $ D.Error e
Right m -> return m

let h = header m_a
when (qr h /= Query ||
opcode h /= STD ||
aa h == True ||
tc h == True ||
ra h == True ||
rcode h /= NoErr
) $ do
throwIO $ WierdQuery "Some fields are wierd"


when ((not $ null $ answer m_a) ||
(not $ null $ authority m_a) ||
(not $ null $ additional m_a) ) $ do
throwIO $ WierdQuery "some sections should be empty"

let a = L.Query { L.qquestion = question m_a
, L.qopt = []
}

b' <- try (back c a)
m_b <- case b' of
Left e -> do
debugM nameM $ show (e :: SomeException)
return $ Message { header = (header m_a) { qr = Response
, rcode = ServFail
, ra = True
, zero = 0
}
, question = question m_a
, answer = []
, authority = []
, additional = []
, opt = Nothing
}
Right b -> return $ Message { header = (header m_a) { qr = Response
, aa = False
, ra = True
, zero = 0
}
, question = question m_a
, answer = L.ranswer b
, authority = L.rauthority b
, additional = L.radditional b
, opt = Nothing
}

bs_b <- case E.encode E.message m_b of
Left e -> throwIO e
Right bs -> return bs

bs_b' <- runExceptT $ do
this <- case is_udp c of
False -> throwE bs_b
True -> return (512 :: Word16)
lift $ debugM nameM $ "payload on this side is " ++ (show this)
when (BSL.null (BSL.drop (fromIntegral this) bs_b)) $ throwE bs_b
lift $ debugM nameM "response too long, setting TC bit.."
let m_b' = m_b { header = (header m_b) { tc = True}
, answer = []
, authority = []
, additional = []
, opt = Nothing
}
bs_b' <- case E.encode E.message m_b' of
Left e -> lift $ throwIO e
Right bs -> return bs
when (BSL.null $ BSL.drop (fromIntegral this) bs_b') $ throwE bs_b'

case bs_b' of
Left bs_b -> return bs_b
Right _ -> throwIO $ ResponseTooLong
58 changes: 20 additions & 38 deletions src/ZhinaDNS.hs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module ZhinaDNS where

import qualified Resolve.Types as R
import Resolve.DNS.Types
import Resolve.DNS.Types hiding (Query, Response)
import Resolve.DNS.Lookup hiding (Config)
import IPSet

import Control.Monad
import Control.Monad.Trans.Except
import Control.Concurrent
import Control.Exception
Expand All @@ -14,37 +14,22 @@ import System.Log.Logger

nameM = "ZhinaDNS"

data Config = Config { china :: R.Resolve Message Message
, world :: R.Resolve Message Message
data Config = Config { china :: R.Resolve Query Response
, world :: R.Resolve Query Response
, chinaIP :: IPSet IPv4
}


resolve :: Config -> R.Resolve Message Message
resolve :: Config -> R.Resolve Query Response
resolve c a = do
let a' = a { opt = Nothing } -- strip OPT RR

let nameF = nameM ++ ".resolve"
m_china <- newEmptyMVar
m_world <- newEmptyMVar

let errRep m = Message { header = (header m) { qr = Response
, tc = False
, ra = True
, zero = 0
, rcode = ServFail
}
, question = question m
, answer = []
, authority = []
, additional = []
, opt = Nothing
}

bracket
(do
t_china <- forkIO $ putMVar m_china =<< try (china c a')
t_world <- forkIO $ putMVar m_world =<< try (world c a')
t_china <- forkIO $ putMVar m_china =<< try (china c a)
t_world <- forkIO $ putMVar m_world =<< try (world c a)
return (t_china, t_world)
)
(\(t_china, t_world) -> do
Expand All @@ -53,31 +38,28 @@ resolve c a = do
)
(\_ -> either id id <$> (runExceptT $
do
b_china <- lift $ takeMVar m_china
b_china' <- case b_china of
b_china' <- lift $ takeMVar m_china
b_china <- case b_china' of
Left e -> do
lift $ errorM nameF $ show (e :: SomeException)
throwE $ errRep a
lift $ debugM nameF $ "zhina: " ++ show (e :: SomeException)
lift $ throwIO e
Right b' -> return b'

when ((rcode $ header b_china') /= NoErr) $ throwE $ errRep a

let isForeign rdata' = case rdata' of
RR_A ip -> not $ test (chinaIP c) ip
_ -> False
b_final <- if any (\rr -> isForeign (rdata rr)) (answer b_china') then do
b_final <- if any (\rr -> isForeign (rdata rr)) (ranswer b_china) then do
lift $ debugM nameF "foreign results detected, waiting for foreign DNS"
b_world <- lift $ takeMVar m_world
b_world' <- case b_world of
b_world' <- lift $ takeMVar m_world
b_world <- case b_world' of
Left e -> do
lift $ errorM nameF $ show (e :: SomeException)
throwE $ errRep a
lift $ debugM nameF $ "world: " ++ show (e :: SomeException)
lift $ throwIO e
Right b' -> return b'
when ((rcode $ header b_world') /= NoErr) $ throwE $ errRep a

return b_world'

return b_world
else
return b_china'
return b_china

return $ b_final {header = (header b_final) {ident = ident $ header $ a}}
return $ b_final
))
Loading

0 comments on commit ead2416

Please sign in to comment.