1 {-# LANGUAGE BangPatterns             #-}
    2 {-# LANGUAGE CPP                      #-}
    3 {-# LANGUAGE DeriveDataTypeable       #-}
    4 {-# LANGUAGE ForeignFunctionInterface #-}
    5 {-# LANGUAGE OverloadedStrings        #-}
    6 {-# LANGUAGE PackageImports           #-}
    7 {-# LANGUAGE RankNTypes               #-}
    8 {-# LANGUAGE ScopedTypeVariables      #-}
    9 
   10 module Snap.Internal.Http.Server.SimpleBackend
   11   ( simpleEventLoop
   12   ) where
   13 
   14 
   15 ------------------------------------------------------------------------------
   16 import           Control.Monad.Trans
   17 
   18 import           Control.Concurrent hiding (yield)
   19 import           Control.Exception
   20 import           Control.Monad
   21 import           Data.ByteString (ByteString)
   22 import qualified Data.ByteString as S
   23 import           Data.ByteString.Internal (c2w)
   24 import           Data.Maybe
   25 import           Data.Typeable
   26 import           Foreign hiding (new)
   27 import           Foreign.C
   28 import           GHC.Conc (labelThread, forkOnIO)
   29 import           Network.Socket
   30 import           Prelude hiding (catch)
   31 ------------------------------------------------------------------------------
   32 import           Snap.Internal.Debug
   33 import           Snap.Internal.Http.Server.Date
   34 import qualified Snap.Internal.Http.Server.TimeoutManager as TM
   35 import           Snap.Internal.Http.Server.TimeoutManager (TimeoutManager)
   36 import           Snap.Internal.Http.Server.Backend
   37 import qualified Snap.Internal.Http.Server.ListenHelpers as Listen
   38 import           Snap.Iteratee hiding (map)
   39 
   40 #if defined(HAS_SENDFILE)
   41 import qualified System.SendFile as SF
   42 import           System.Posix.IO
   43 import           System.Posix.Types (Fd(..))
   44 #endif
   45 
   46 
   47 ------------------------------------------------------------------------------
   48 -- | For each cpu, we store:
   49 --    * A list of accept threads, one per port.
   50 --    * A TimeoutManager
   51 --    * An mvar to signal when the timeout thread is shutdown
   52 data EventLoopCpu = EventLoopCpu
   53     { _boundCpu        :: Int
   54     , _acceptThreads   :: [ThreadId]
   55     , _timeoutManager  :: TimeoutManager
   56     , _exitMVar        :: !(MVar ())
   57     }
   58 
   59 
   60 ------------------------------------------------------------------------------
   61 simpleEventLoop :: EventLoop
   62 simpleEventLoop defaultTimeout sockets cap elog handler = do
   63     loops <- Prelude.mapM (newLoop defaultTimeout sockets handler elog)
   64                           [0..(cap-1)]
   65 
   66     debug "simpleEventLoop: waiting for mvars"
   67 
   68     --wait for all threads to exit
   69     Prelude.mapM_ (takeMVar . _exitMVar) loops `finally` do
   70         debug "simpleEventLoop: killing all threads"
   71         _ <- mapM_ stopLoop loops
   72         mapM_ Listen.closeSocket sockets
   73 
   74 
   75 ------------------------------------------------------------------------------
   76 newLoop :: Int
   77         -> [ListenSocket]
   78         -> SessionHandler
   79         -> (S.ByteString -> IO ())
   80         -> Int
   81         -> IO EventLoopCpu
   82 newLoop defaultTimeout sockets handler elog cpu = do
   83     tmgr       <- TM.initialize defaultTimeout getCurrentDateTime
   84     exit       <- newEmptyMVar
   85     accThreads <- forM sockets $ \p -> forkOnIO cpu $
   86                   acceptThread defaultTimeout handler tmgr elog cpu p exit
   87 
   88     return $ EventLoopCpu cpu accThreads tmgr exit
   89 
   90 
   91 ------------------------------------------------------------------------------
   92 stopLoop :: EventLoopCpu -> IO ()
   93 stopLoop loop = block $ do
   94     TM.stop $ _timeoutManager loop
   95     Prelude.mapM_ killThread $ _acceptThreads loop
   96 
   97 
   98 ------------------------------------------------------------------------------
   99 acceptThread :: Int
  100              -> SessionHandler
  101              -> TimeoutManager
  102              -> (S.ByteString -> IO ())
  103              -> Int
  104              -> ListenSocket
  105              -> MVar ()
  106              -> IO ()
  107 acceptThread defaultTimeout handler tmgr elog cpu sock exitMVar =
  108     loop `finally` (tryPutMVar exitMVar () >> return ())
  109   where
  110     loop = do
  111         debug $ "acceptThread: calling accept() on socket " ++ show sock
  112         (s,addr) <- accept $ Listen.listenSocket sock
  113         debug $ "acceptThread: accepted connection from remote: " ++ show addr
  114         _ <- forkOnIO cpu (go s addr `catches` cleanup)
  115         loop
  116 
  117     go = runSession defaultTimeout handler tmgr sock
  118 
  119     cleanup =
  120         [
  121           Handler $ \(_ :: AsyncException) -> return ()
  122         , Handler $ \(e :: SomeException) -> elog
  123                   $ S.concat [ "SimpleBackend.acceptThread: "
  124                              , S.pack . map c2w $ show e]
  125         ]
  126 
  127 
  128 ------------------------------------------------------------------------------
  129 data AddressNotSupportedException = AddressNotSupportedException String
  130    deriving (Typeable)
  131 
  132 instance Show AddressNotSupportedException where
  133     show (AddressNotSupportedException x) = "Address not supported: " ++ x
  134 
  135 instance Exception AddressNotSupportedException
  136 
  137 
  138 ------------------------------------------------------------------------------
  139 runSession :: Int
  140            -> SessionHandler
  141            -> TimeoutManager
  142            -> ListenSocket
  143            -> Socket
  144            -> SockAddr -> IO ()
  145 runSession defaultTimeout handler tmgr lsock sock addr = do
  146     let fd = fdSocket sock
  147     curId <- myThreadId
  148 
  149     debug $ "Backend.withConnection: running session: " ++ show addr
  150     labelThread curId $ "connHndl " ++ show fd
  151 
  152     (rport,rhost) <-
  153         case addr of
  154           SockAddrInet p h -> do
  155              h' <- inet_ntoa h
  156              return (fromIntegral p, S.pack $ map c2w h')
  157           x -> throwIO $ AddressNotSupportedException $ show x
  158 
  159     laddr <- getSocketName sock
  160 
  161     (lport,lhost) <-
  162         case laddr of
  163           SockAddrInet p h -> do
  164              h' <- inet_ntoa h
  165              return (fromIntegral p, S.pack $ map c2w h')
  166           x -> throwIO $ AddressNotSupportedException $ show x
  167 
  168     let sinfo = SessionInfo lhost lport rhost rport $ Listen.isSecure lsock
  169 
  170     timeoutHandle <- TM.register (killThread curId) tmgr
  171     let timeout = TM.tickle timeoutHandle
  172 
  173     bracket (Listen.createSession lsock 8192 fd
  174               (threadWaitRead $ fromIntegral fd))
  175             (\session -> block $ do
  176                  debug "thread killed, closing socket"
  177 
  178                  -- cancel thread timeout
  179                  TM.cancel timeoutHandle
  180 
  181                  eatException $ Listen.endSession lsock session
  182                  eatException $ shutdown sock ShutdownBoth
  183                  eatException $ sClose sock
  184             )
  185             (\s -> let writeEnd = writeOut lsock s sock
  186                                       (timeout defaultTimeout)
  187                    in handler sinfo
  188                               (enumerate lsock s sock)
  189                               writeEnd
  190                               (sendFile lsock (timeout defaultTimeout) fd
  191                                         writeEnd)
  192                               timeout
  193             )
  194 
  195 
  196 ------------------------------------------------------------------------------
  197 eatException :: IO a -> IO ()
  198 eatException act = (act >> return ()) `catch` \(_::SomeException) -> return ()
  199 
  200 
  201 ------------------------------------------------------------------------------
  202 sendFile :: ListenSocket
  203          -> IO ()
  204          -> CInt
  205          -> Iteratee ByteString IO ()
  206          -> FilePath
  207          -> Int64
  208          -> Int64
  209          -> IO ()
  210 #if defined(HAS_SENDFILE)
  211 sendFile lsock tickle sock writeEnd fp start sz =
  212     case lsock of
  213         ListenHttp _ -> bracket (openFd fp ReadOnly Nothing defaultFileFlags)
  214                                 (closeFd)
  215                                 (go start sz)
  216         _            -> do
  217                    step <- runIteratee writeEnd
  218                    run_ $ enumFilePartial fp (start,start+sz) step
  219   where
  220     go off bytes fd
  221       | bytes == 0 = return ()
  222       | otherwise  = do
  223             sent <- SF.sendFile (threadWaitWrite $ fromIntegral sock)
  224                                 sfd fd off bytes
  225             if sent < bytes
  226               then tickle >> go (off+sent) (bytes-sent) fd
  227               else return ()
  228 
  229     sfd = Fd sock
  230 #else
  231 sendFile _ _ _ writeEnd fp start sz = do
  232     -- no need to count bytes
  233     step <- runIteratee writeEnd
  234     run_ $ enumFilePartial fp (start,start+sz) step
  235     return ()
  236 #endif
  237 
  238 
  239 ------------------------------------------------------------------------------
  240 enumerate :: (MonadIO m)
  241           => ListenSocket
  242           -> NetworkSession
  243           -> Socket
  244           -> Enumerator ByteString m a
  245 enumerate port session sock = loop
  246   where
  247     dbg s = debug $ "SimpleBackend.enumerate(" ++ show (_socket session)
  248             ++ "): " ++ s
  249 
  250     loop (Continue k) = do
  251         dbg "reading from socket"
  252         s <- liftIO $ timeoutRecv
  253         case s of
  254             Nothing -> do
  255                    dbg "got EOF from socket"
  256                    sendOne k ""
  257             Just s' -> do
  258                    dbg $ "got " ++ Prelude.show (S.length s')
  259                            ++ " bytes from read end"
  260                    sendOne k s'
  261 
  262     loop x = returnI x
  263 
  264 
  265     sendOne k s | S.null s  = do
  266         dbg "sending EOF to continuation"
  267         enumEOF $ Continue k
  268 
  269                 | otherwise = do
  270         dbg $ "sending " ++ show s ++ " to continuation"
  271         step <- lift $ runIteratee $ k $ Chunks [s]
  272         case step of
  273           (Yield x st)   -> do
  274                       dbg $ "got yield, remainder is " ++ show st
  275                       yield x st
  276           r@(Continue _) -> do
  277                       dbg $ "got continue"
  278                       loop r
  279           (Error e)      -> throwError e
  280 
  281     fd = fdSocket sock
  282 #ifdef PORTABLE
  283     timeoutRecv = Listen.recv port sock (threadWaitRead $
  284                   fromIntegral fd) session
  285 #else
  286     timeoutRecv = Listen.recv port (threadWaitRead $
  287                   fromIntegral fd) session
  288 #endif
  289 
  290 
  291 ------------------------------------------------------------------------------
  292 writeOut :: (MonadIO m)
  293          => ListenSocket
  294          -> NetworkSession
  295          -> Socket
  296          -> (IO ())
  297          -> Iteratee ByteString m ()
  298 writeOut port session sock tickle = loop
  299   where
  300     dbg s = debug $ "SimpleBackend.writeOut(" ++ show (_socket session)
  301             ++ "): " ++ s
  302 
  303     loop = continue k
  304 
  305     k EOF = yield () EOF
  306     k (Chunks xs) = do
  307         let s = S.concat xs
  308         let n = S.length s
  309         dbg $ "got chunk with " ++ show n ++ " bytes"
  310         ee <- liftIO $ try $ timeoutSend s
  311         case ee of
  312           (Left (e::SomeException)) -> do
  313               dbg $ "timeoutSend got error " ++ show e
  314               throwError e
  315           (Right _) -> do
  316               let last10 = S.drop (n-10) s
  317               dbg $ "wrote " ++ show n ++ " bytes, last 10=" ++ show last10
  318               loop
  319 
  320     fd = fdSocket sock
  321 #ifdef PORTABLE
  322     timeoutSend = Listen.send port sock tickle
  323                               (threadWaitWrite $ fromIntegral fd) session
  324 #else
  325     timeoutSend = Listen.send port tickle
  326                               (threadWaitWrite $ fromIntegral fd) session
  327 #endif