1 {-# LANGUAGE ForeignFunctionInterface #-}
    2 {-# LANGUAGE DeriveDataTypeable #-}
    3 {-# LANGUAGE EmptyDataDecls #-}
    4 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
    5 {-# LANGUAGE OverloadedStrings #-}
    6 {-# LANGUAGE CPP #-}
    7 
    8 module Snap.Internal.Http.Server.GnuTLS
    9   ( GnuTLSException(..)
   10   , initTLS
   11   , stopTLS
   12   , bindHttps
   13   , freePort
   14   , createSession
   15   , endSession
   16   , recv
   17   , send
   18   ) where
   19 
   20 
   21 ------------------------------------------------------------------------------
import           Control.Exception
import           Data.ByteString (ByteString)
import           Data.Dynamic
import           Foreign.C

import           Snap.Internal.Debug
import           Snap.Internal.Http.Server.Backend

#ifdef GNUTLS
import qualified Data.ByteString as B
import           Data.ByteString.Internal (w2c)
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BI
import           Foreign
import qualified Network.Socket as Socket
import           System.Posix.Types (CSsize)
#endif
   38 
   39 
   40 ------------------------------------------------------------------------------
-- | Exception thrown by this module when a GnuTLS call reports a negative
-- return code, or when an operation needs TLS but the server was built
-- without GnuTLS support. The 'String' is a human-readable description.
data GnuTLSException = GnuTLSException String
    deriving (Show, Typeable)
instance Exception GnuTLSException
   44 
   45 #ifndef GNUTLS
   46 
   47 initTLS :: IO ()
   48 initTLS = throwIO $ GnuTLSException "TLS is not supported"
   49 
   50 stopTLS :: IO ()
   51 stopTLS = return ()
   52 
   53 bindHttps :: ByteString -> Int -> FilePath -> FilePath -> IO ListenSocket
   54 bindHttps _ _ _ _ = throwIO $ GnuTLSException "TLS is not supported"
   55 
   56 freePort :: ListenSocket -> IO ()
   57 freePort _ = return ()
   58 
   59 createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
   60 createSession _ _ _ _ = throwIO $ GnuTLSException "TLS is not supported"
   61 
   62 endSession :: NetworkSession -> IO ()
   63 endSession _ = return ()
   64 
   65 send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
   66 send _ _ _ _ = return ()
   67 
   68 recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
   69 recv _ _ = throwIO $ GnuTLSException "TLS is not supported"
   70 
   71 #else
   72 
   73 
   74 ------------------------------------------------------------------------------
   75 -- | Init
   76 initTLS :: IO ()
   77 initTLS = gnutls_set_threading_helper >>
   78           throwErrorIf "TLS init" gnutls_global_init
   79 
   80 
   81 ------------------------------------------------------------------------------
   82 stopTLS :: IO ()
   83 stopTLS = gnutls_global_deinit
   84 
   85 
   86 ------------------------------------------------------------------------------
   87 -- | Binds ssl port
   88 bindHttps :: ByteString
   89           -> Int
   90           -> FilePath
   91           -> FilePath
   92           -> IO ListenSocket
   93 bindHttps bindAddress bindPort cert key = do
   94     sock <- Socket.socket Socket.AF_INET Socket.Stream 0
   95     addr <- getHostAddr bindPort bindAddress
   96     Socket.setSocketOption sock Socket.ReuseAddr 1
   97     Socket.bindSocket sock addr
   98     Socket.listen sock 150
   99 
  100     creds <- loadCredentials cert key
  101     dh <- regenerateDHParam creds
  102 
  103     return $ ListenHttps sock (castPtr creds) (castPtr dh)
  104 
  105 
  106 ------------------------------------------------------------------------------
  107 loadCredentials :: FilePath       --- ^ Path to certificate
  108                 -> FilePath       --- ^ Path to key
  109                 -> IO (Ptr GnuTLSCredentials)
  110 loadCredentials cert key = alloca $ \cPtr -> do
  111     throwErrorIf "TLS allocate" $ gnutls_certificate_allocate_credentials cPtr
  112     creds <- peek cPtr
  113 
  114     withCString cert $ \certstr -> withCString key $ \keystr ->
  115         throwErrorIf "TLS set Certificate" $
  116         gnutls_certificate_set_x509_key_file
  117             creds certstr keystr gnutls_x509_fmt_pem
  118 
  119     return creds
  120 
  121 
  122 ------------------------------------------------------------------------------
  123 regenerateDHParam :: Ptr GnuTLSCredentials -> IO (Ptr GnuTLSDHParam)
  124 regenerateDHParam creds = alloca $ \dhptr -> do
  125     throwErrorIf "TLS allocate" $ gnutls_dh_params_init dhptr
  126     dh <- peek dhptr
  127     throwErrorIf "TLS DHParm" $ gnutls_dh_params_generate2 dh 1024
  128     gnutls_certificate_set_dh_params creds dh
  129     return dh
  130 
  131 
  132 ------------------------------------------------------------------------------
  133 freePort :: ListenSocket -> IO ()
  134 freePort (ListenHttps _ creds dh) = do
  135     gnutls_certificate_free_credentials $ castPtr creds
  136     gnutls_dh_params_deinit $ castPtr dh
  137 freePort _ = return ()
  138 
  139 
  140 ------------------------------------------------------------------------------
  141 createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
  142 createSession (ListenHttps _ creds _) recvSize socket on_block =
  143     alloca $ \sPtr -> do
  144         throwErrorIf "TLS alloacte" $ gnutls_init sPtr 1
  145         session <- peek sPtr
  146         throwErrorIf "TLS session" $
  147             gnutls_credentials_set session 1 $ castPtr creds
  148         throwErrorIf "TLS session" $ gnutls_set_default_priority session
  149         gnutls_certificate_send_x509_rdn_sequence session 1
  150         gnutls_session_enable_compatibility_mode session
  151 
  152         let s = NetworkSession socket (castPtr session) $
  153                     fromIntegral recvSize
  154 
  155         gnutls_transport_set_ptr session $ intPtrToPtr $ fromIntegral $ socket
  156 
  157         handshake s on_block
  158 
  159         return s
  160 createSession _ _ _ _ = error "Invalid socket"
  161 
  162 
  163 ------------------------------------------------------------------------------
  164 endSession :: NetworkSession -> IO ()
  165 endSession (NetworkSession _ session _) = do
  166     throwErrorIf "TLS bye" $ gnutls_bye (castPtr session) 1 `finally` do
  167         gnutls_deinit $ castPtr session
  168 
  169 
  170 ------------------------------------------------------------------------------
  171 handshake :: NetworkSession -> IO () -> IO ()
  172 handshake s@(NetworkSession { _session = session}) on_block = do
  173     rc <- gnutls_handshake $ castPtr session
  174     case rc of
  175         x | x >= 0         -> return ()
  176           | isIntrCode x   -> handshake s on_block
  177           | isAgainCode x  -> on_block >> handshake s on_block
  178           | otherwise      -> throwError "TLS handshake" rc
  179 
  180 
  181 ------------------------------------------------------------------------------
  182 send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
  183 send tickleTimeout onBlock (NetworkSession { _session = session}) bs =
  184      BI.unsafeUseAsCStringLen bs $ uncurry loop
  185   where
  186     loop ptr len = do
  187         sent <- gnutls_record_send (castPtr session) ptr $ fromIntegral len
  188         let sent' = fromIntegral sent
  189         case sent' of
  190             x | x == 0 || x == len -> return ()
  191               | x > 0 && x < len   -> tickleTimeout >>
  192                                       loop (plusPtr ptr sent') (len - sent')
  193               | isIntrCode x       -> loop ptr len
  194               | isAgainCode x      -> onBlock >> loop ptr len
  195               | otherwise          -> throwError "TLS send" $
  196                                           fromIntegral sent'
  197 
  198 
  199 ------------------------------------------------------------------------------
  200 recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
  201 recv onBlock (NetworkSession _ session recvLen) = do
  202     fp <- BI.mallocByteString recvLen
  203     sz <- withForeignPtr fp loop
  204     if (sz :: Int) <= 0
  205        then return Nothing
  206        else return $ Just $ BI.fromForeignPtr fp 0 $ fromEnum sz
  207 
  208   where
  209     loop recvBuf = do
  210         debug $ "TLS: calling record_recv with recvLen=" ++ show recvLen
  211         size <- gnutls_record_recv (castPtr session) recvBuf $ toEnum recvLen
  212         debug $ "TLS: record_recv returned with size=" ++ show size
  213         let size' = fromIntegral size
  214         case size' of
  215             x | x >= 0        -> return x
  216               | isIntrCode x  -> loop recvBuf
  217               | isAgainCode x -> onBlock >> loop recvBuf
  218               | otherwise     -> (throwError "TLS recv" $ fromIntegral size')
  219 
  220 
  221 ------------------------------------------------------------------------------
  222 throwError :: String -> ReturnCode -> IO a
  223 throwError prefix rc = gnutls_strerror rc >>=
  224                        peekCString >>=
  225                        throwIO . GnuTLSException . (prefix'++)
  226   where
  227     prefix' = prefix ++ "<" ++ show rc ++ ">: "
  228 
  229 
  230 ------------------------------------------------------------------------------
  231 throwErrorIf :: String -> IO ReturnCode -> IO ()
  232 throwErrorIf prefix action = do
  233     rc <- action
  234     if (rc < 0)
  235         then throwError prefix rc
  236         else return ()
  237 
  238 
  239 ------------------------------------------------------------------------------
  240 isAgainCode :: (Integral a) => a -> Bool
  241 isAgainCode x = (fromIntegral x) == (-28 :: Int)
  242 
  243 
  244 ------------------------------------------------------------------------------
  245 isIntrCode :: (Integral a) => a -> Bool
  246 isIntrCode x = (fromIntegral x) == (-52 :: Int)
  247 
  248 
  249 ------------------------------------------------------------------------------
  250 getHostAddr :: Int
  251             -> ByteString
  252             -> IO Socket.SockAddr
  253 getHostAddr p s = do
  254     h <- if s == "*"
  255           then return Socket.iNADDR_ANY
  256           else Socket.inet_addr (map w2c . B.unpack $ s)
  257 
  258     return $ Socket.SockAddrInet (fromIntegral p) h
  259 
  260 -- Types
  261 
-- | Raw integer result of a GnuTLS call. Negative values are treated as
-- errors by 'throwErrorIf'; non-negative values as success.
newtype ReturnCode = ReturnCode CInt
    deriving (Show, Eq, Ord, Num, Real, Enum, Integral)
  264 
-- | Opaque GnuTLS C structures. They have no Haskell constructors and are only
-- handled through 'Ptr's passed to and returned from the foreign imports
-- below.
data GnuTLSCredentials
data GnuTLSSession
data GnuTLSDHParam
  268 
  269 ------------------------------------------------------------------------------
  270 -- Global init/errors
  271 
-- Project-supplied C helper (no gnutls.h header is named), called by
-- 'initTLS' before gnutls_global_init. NOTE(review): presumably it registers
-- thread-safety callbacks with GnuTLS, as the name suggests — confirm
-- against the accompanying C source.
foreign import ccall safe
    "gnutls_set_threading_helper"
    gnutls_set_threading_helper :: IO ()

-- Initialise GnuTLS' global state; the result is checked by 'initTLS'.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_global_init"
    gnutls_global_init :: IO ReturnCode

-- Release the global state set up by gnutls_global_init (see 'stopTLS').
foreign import ccall safe
    "gnutls/gnutls.h gnutls_global_deinit"
    gnutls_global_deinit :: IO ()

-- Describe a return code as a C string; 'throwError' reads it with
-- peekCString and does not free it.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_strerror"
    gnutls_strerror :: ReturnCode -> IO CString
  287 
  288 ------------------------------------------------------------------------------
  289 -- Sessions.  All functions here except handshake and bye just
  290 -- allocate memory or update members of structures, so they are ok with
  291 -- unsafe ccall.
  292 
-- Allocate a session into the out-pointer. The CInt is the init flags
-- argument; 'createSession' passes 1 (GNUTLS_SERVER).
foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_init"
    gnutls_init :: Ptr (Ptr GnuTLSSession) -> CInt -> IO ReturnCode

-- Free a session allocated by gnutls_init.
foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_deinit"
    gnutls_deinit :: Ptr GnuTLSSession -> IO ()

-- Perform one handshake step; 'handshake' retries it until it succeeds or
-- fails. Safe call because it talks to the peer.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_handshake"
    gnutls_handshake :: Ptr GnuTLSSession -> IO ReturnCode

-- Send the TLS close notification; the CInt selects the shutdown mode
-- ('endSession' passes 1).
foreign import ccall safe
    "gnutls/gnutls.h gnutls_bye"
    gnutls_bye :: Ptr GnuTLSSession -> CInt -> IO ReturnCode

-- Apply GnuTLS' default priorities to the session.
foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_set_default_priority"
    gnutls_set_default_priority :: Ptr GnuTLSSession -> IO ReturnCode

-- Enable GnuTLS' compatibility mode for this session.
foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_session_enable_compatibility_mode"
    gnutls_session_enable_compatibility_mode :: Ptr GnuTLSSession -> IO ()

-- NOTE(review): per the GnuTLS manual a non-zero status (here 1) stops the
-- server advertising its trusted CA names in certificate requests — confirm
-- for the linked GnuTLS version.
foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_certificate_send_x509_rdn_sequence"
    gnutls_certificate_send_x509_rdn_sequence
      :: Ptr GnuTLSSession -> CInt -> IO ()
  321 
  322 ------------------------------------------------------------------------------
  323 -- Certificates.  Perhaps these could be unsafe but they are not performance
  324 -- critical, since they are called only once during server startup.
  325 
-- Allocate an empty certificate credentials structure into the out-pointer.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_allocate_credentials"
    gnutls_certificate_allocate_credentials
      :: Ptr (Ptr GnuTLSCredentials) -> IO ReturnCode

-- Free credentials allocated by gnutls_certificate_allocate_credentials.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_free_credentials"
    gnutls_certificate_free_credentials
      :: Ptr GnuTLSCredentials -> IO ()

-- Value of GNUTLS_X509_FMT_PEM: the certificate and key files are PEM.
gnutls_x509_fmt_pem :: CInt
gnutls_x509_fmt_pem = 1

-- Load a certificate file and its private key file (paths as C strings),
-- in the given encoding, into the credentials.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_set_x509_key_file"
    gnutls_certificate_set_x509_key_file
      :: Ptr GnuTLSCredentials -> CString -> CString -> CInt -> IO ReturnCode
  343 
  344 
  345 ------------------------------------------------------------------------------
  346 -- Credentials.  This is ok as unsafe because it just sets members in the
  347 -- session structure.
  348 
-- Attach credentials of the given type to a session. 'createSession' uses
-- type 1 (GNUTLS_CRD_CERTIFICATE) with the listen socket's certificate
-- credentials.
foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_credentials_set"
    gnutls_credentials_set
        :: Ptr GnuTLSSession -> CInt -> Ptr a -> IO ReturnCode
  353 
  354 ------------------------------------------------------------------------------
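-- Transport and records.  gnutls_record_send/recv are unsafe calls because
-- they are performance critical.  That is acceptable only because the
-- sockets are non-blocking: instead of blocking, the calls return
-- GNUTLS_E_AGAIN, which 'send' and 'recv' handle by running their onBlock
-- action and retrying.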
  358 
  359 foreign import ccall unsafe
  360     "gnutls/gnutls.h gnutls_transport_set_ptr"
  361     gnutls_transport_set_ptr :: Ptr GnuTLSSession -> Ptr a -> IO ()
  362 
  363 foreign import ccall unsafe
  364     "gnutls/gnutls.h gnutls_record_recv"
  365     gnutls_record_recv :: Ptr GnuTLSSession -> Ptr a -> CSize -> IO CSize
  366 
  367 foreign import ccall unsafe
  368     "gnutls/gnutls.h gnutls_record_send"
  369     gnutls_record_send :: Ptr GnuTLSSession -> Ptr a -> CSize -> IO CSize
  370 
  371 ------------------------------------------------------------------------------
-- DHParam.  These stay safe calls: they are not performance critical, and
-- gnutls_dh_params_generate2 may take a noticeable time to generate primes.
  374 
-- Allocate empty Diffie-Hellman parameters into the out-pointer.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_dh_params_init"
    gnutls_dh_params_init :: Ptr (Ptr GnuTLSDHParam) -> IO ReturnCode

-- Free parameters allocated by gnutls_dh_params_init.
foreign import ccall safe
    "gnutls/gnutls.h gnutls_dh_params_deinit"
    gnutls_dh_params_deinit :: Ptr GnuTLSDHParam -> IO ()

-- Generate parameters with a prime of the given size in bits
-- ('regenerateDHParam' uses 1024).
foreign import ccall safe
    "gnutls/gnutls.h gnutls_dh_params_generate2"
    gnutls_dh_params_generate2 :: Ptr GnuTLSDHParam -> CUInt -> IO ReturnCode

-- Attach DH parameters to certificate credentials (see 'regenerateDHParam').
foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_set_dh_params"
    gnutls_certificate_set_dh_params
      :: Ptr GnuTLSCredentials -> Ptr GnuTLSDHParam -> IO ()
  391 
  392 #endif