{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE EmptyDataDecls #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

-- | TLS support for the Snap HTTP server, implemented as an FFI binding to
-- GnuTLS.
--
-- When built without the @GNUTLS@ CPP flag, 'initTLS', 'bindHttps',
-- 'createSession', 'send' and 'recv' throw a 'GnuTLSException'. The clean-up
-- functions ('stopTLS', 'freePort', 'endSession') are no-ops in that case.
module Snap.Internal.Http.Server.GnuTLS
  ( GnuTLSException(..)
  , initTLS
  , stopTLS
  , bindHttps
  , freePort
  , createSession
  , endSession
  , recv
  , send
  ) where


------------------------------------------------------------------------------
import           Control.Exception
import           Data.ByteString (ByteString)
import           Data.Dynamic
import           Foreign.C

import           Snap.Internal.Debug
import           Snap.Internal.Http.Server.Backend

#ifdef GNUTLS
import qualified Data.ByteString as B
import           Data.ByteString.Internal (w2c)
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BI
import           Foreign
import qualified Network.Socket as Socket
#endif


------------------------------------------------------------------------------
-- | Raised when a GnuTLS call fails, or when TLS is used in a build compiled
-- without GnuTLS support. The 'String' describes the failure.
data GnuTLSException = GnuTLSException String
    deriving (Show, Typeable)
instance Exception GnuTLSException

#ifndef GNUTLS

------------------------------------------------------------------------------
-- Stub implementation, used when GnuTLS support is compiled out.

-- | The failure raised by every operation that cannot work without GnuTLS.
tlsNotSupported :: IO a
tlsNotSupported = throwIO $ GnuTLSException "TLS is not supported"

initTLS :: IO ()
initTLS = tlsNotSupported

stopTLS :: IO ()
stopTLS = return ()

bindHttps :: ByteString -> Int -> FilePath -> FilePath -> IO ListenSocket
bindHttps _ _ _ _ = tlsNotSupported

freePort :: ListenSocket -> IO ()
freePort _ = return ()

createSession :: ListenSocket -> Int -> CInt -> IO () -> IO NetworkSession
createSession _ _ _ _ = tlsNotSupported

endSession :: NetworkSession -> IO ()
endSession _ = return ()

-- | Fails in the same way as 'recv'. It does not silently discard the data
-- the caller asked to send.
send :: IO () -> IO () -> NetworkSession -> ByteString -> IO ()
send _ _ _ _ = tlsNotSupported

recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
recv _ _ = tlsNotSupported

#else
------------------------------------------------------------------------------
-- | Initialises GnuTLS. It first runs the @gnutls_set_threading_helper@
-- import, then @gnutls_global_init@. A 'GnuTLSException' is thrown if the
-- latter returns a negative code.
initTLS :: IO ()
initTLS = gnutls_set_threading_helper >>
          throwErrorIf "TLS init" gnutls_global_init


------------------------------------------------------------------------------
-- | Releases the global GnuTLS state via @gnutls_global_deinit@.
stopTLS :: IO ()
stopTLS = gnutls_global_deinit


------------------------------------------------------------------------------
-- | Binds the SSL port.
--
-- Opens an IPv4 listening socket, loads the PEM certificate/key pair and
-- generates Diffie-Hellman parameters for it. If any step throws, everything
-- acquired so far is released before the exception propagates:
--
-- * the socket is closed;
-- * already-loaded credentials are freed.
bindHttps :: ByteString  -- ^ address to bind to, or @\"*\"@ for any address
          -> Int         -- ^ port
          -> FilePath    -- ^ path to the PEM certificate
          -> FilePath    -- ^ path to the PEM private key
          -> IO ListenSocket
bindHttps bindAddress bindPort cert key =
    bracketOnError (Socket.socket Socket.AF_INET Socket.Stream 0)
                   Socket.sClose $ \sock -> do
        addr <- getHostAddr bindPort bindAddress
        Socket.setSocketOption sock Socket.ReuseAddr 1
        Socket.bindSocket sock addr
        Socket.listen sock 150

        creds <- loadCredentials cert key
        -- The credentials must not outlive a failed DH generation.
        dh    <- regenerateDHParam creds
                     `onException` gnutls_certificate_free_credentials creds

        return $ ListenHttps sock (castPtr creds) (castPtr dh)


------------------------------------------------------------------------------
-- | Allocates a GnuTLS certificate-credentials structure and loads a
-- PEM-encoded certificate/private-key pair into it. If loading the files
-- fails, the structure is freed again before the exception propagates.
loadCredentials :: FilePath  -- ^ Path to certificate
                -> FilePath  -- ^ Path to key
                -> IO (Ptr GnuTLSCredentials)
loadCredentials cert key = alloca $ \cPtr -> do
    throwErrorIf "TLS allocate" $ gnutls_certificate_allocate_credentials cPtr
    creds <- peek cPtr

    (withCString cert $ \certstr -> withCString key $ \keystr ->
        throwErrorIf "TLS set Certificate" $
            gnutls_certificate_set_x509_key_file
                creds certstr keystr gnutls_x509_fmt_pem)
        `onException` gnutls_certificate_free_credentials creds

    return creds


------------------------------------------------------------------------------
-- | Creates Diffie-Hellman parameters with @gnutls_dh_params_generate2@
-- (1024 bits) and attaches them to the given credentials. If generation
-- fails, the parameter structure is released before the exception
-- propagates.
--
-- NOTE(review): 1024-bit DH parameters are generally considered too small
-- today. Consider a larger size, but confirm the effect on start-up time
-- first.
regenerateDHParam :: Ptr GnuTLSCredentials -> IO (Ptr GnuTLSDHParam)
regenerateDHParam creds = alloca $ \dhptr -> do
    throwErrorIf "TLS allocate" $ gnutls_dh_params_init dhptr
    dh <- peek dhptr
    throwErrorIf "TLS DHParm" (gnutls_dh_params_generate2 dh 1024)
        `onException` gnutls_dh_params_deinit dh
    gnutls_certificate_set_dh_params creds dh
    return dh
------------------------------------------------------------------------------
-- | Frees the GnuTLS credentials and Diffie-Hellman parameters owned by an
-- HTTPS listen socket. Any other kind of 'ListenSocket' is left untouched.
--
-- NOTE(review): the listening socket itself is not closed here. Confirm that
-- the caller closes it.
freePort :: ListenSocket -> IO ()
freePort (ListenHttps _ creds dh) = do
    gnutls_certificate_free_credentials $ castPtr creds
    gnutls_dh_params_deinit $ castPtr dh
freePort _ = return ()


------------------------------------------------------------------------------
-- | Creates a GnuTLS session for an accepted connection and runs the TLS
-- handshake on it.
--
-- The session is acquired with 'bracketOnError'. If configuring it or the
-- handshake throws (including an asynchronous exception such as a timeout),
-- the session is deinitialised before the exception propagates. It is not
-- leaked.
createSession :: ListenSocket  -- ^ listen socket; must be a 'ListenHttps'
              -> Int           -- ^ receive buffer size later used by 'recv'
              -> CInt          -- ^ file descriptor of the accepted connection
              -> IO ()         -- ^ action run when the handshake would block
              -> IO NetworkSession
createSession (ListenHttps _ creds _) recvSize socket on_block =
    bracketOnError newSession gnutls_deinit $ \session -> do
        -- The 1 passed here is presumably GNUTLS_CRD_CERTIFICATE; confirm
        -- against gnutls.h.
        throwErrorIf "TLS session" $
            gnutls_credentials_set session 1 $ castPtr creds
        throwErrorIf "TLS session" $ gnutls_set_default_priority session
        gnutls_certificate_send_x509_rdn_sequence session 1
        gnutls_session_enable_compatibility_mode session

        let s = NetworkSession socket (castPtr session) $
                    fromIntegral recvSize

        -- The connection's file descriptor is handed to GnuTLS as its
        -- transport pointer.
        gnutls_transport_set_ptr session $ intPtrToPtr $ fromIntegral socket

        handshake s on_block
        return s
  where
    -- Allocates the session. The 1 is presumably GNUTLS_SERVER; confirm.
    newSession = alloca $ \sPtr -> do
        throwErrorIf "TLS alloacte" $ gnutls_init sPtr 1
        peek sPtr
createSession _ _ _ _ = error "Invalid socket"


------------------------------------------------------------------------------
-- | Shuts the TLS session down with @gnutls_bye@, then frees it.
--
-- The session is deinitialised even when @gnutls_bye@ fails. A negative
-- result is reported afterwards as a 'GnuTLSException'.
endSession :: NetworkSession -> IO ()
endSession (NetworkSession _ session _) =
    throwErrorIf "TLS bye" $
        -- The 1 is presumably GNUTLS_SHUT_WR; confirm against gnutls.h.
        gnutls_bye sessionPtr 1 `finally` gnutls_deinit sessionPtr
  where
    sessionPtr = castPtr session


------------------------------------------------------------------------------
-- | Runs @gnutls_handshake@ until it succeeds.
--
-- * An interrupted call is retried immediately.
-- * A "try again" result runs @on_block@ and then retries.
-- * Any other negative code throws a 'GnuTLSException'.
handshake :: NetworkSession -> IO () -> IO ()
handshake s@(NetworkSession { _session = session}) on_block = do
    rc <- gnutls_handshake $ castPtr session
    case rc of
        x | x >= 0        -> return ()
          | isIntrCode x  -> handshake s on_block
          | isAgainCode x -> on_block >> handshake s on_block
          | otherwise     -> throwError "TLS handshake" rc


------------------------------------------------------------------------------
-- | Writes the whole 'ByteString' to the TLS session.
--
-- * A partial write runs @tickleTimeout@ and continues with the remainder.
-- * An interrupted call is retried.
-- * A would-block result runs @onBlock@ and retries.
-- * A result of 0 ends the loop.
-- * Any other negative result throws a 'GnuTLSException'.
send :: IO ()           -- ^ @tickleTimeout@: run after each partial write
     -> IO ()           -- ^ @onBlock@: run when the write would block
     -> NetworkSession
     -> ByteString
     -> IO ()
send tickleTimeout onBlock (NetworkSession { _session = session}) bs =
    BI.unsafeUseAsCStringLen bs $ uncurry loop
  where
    loop ptr len = do
        sent <- gnutls_record_send (castPtr session) ptr $ fromIntegral len
        let sent' = fromIntegral sent
        case sent' of
            x | x == 0 || x == len -> return ()
              | x > 0 && x < len   -> tickleTimeout >>
                                      loop (plusPtr ptr sent') (len - sent')
              | isIntrCode x       -> loop ptr len
              | isAgainCode x      -> onBlock >> loop ptr len
              | otherwise          -> throwError "TLS send" $
                                      fromIntegral sent'


------------------------------------------------------------------------------
-- | Reads at most the session's receive-buffer size of data from the TLS
-- session. Returns 'Nothing' when the read yields zero bytes.
--
-- Interrupted reads are retried. A would-block result runs @onBlock@ and
-- retries. Any other negative result throws a 'GnuTLSException'.
recv :: IO b -> NetworkSession -> IO (Maybe ByteString)
recv onBlock (NetworkSession _ session recvLen) = do
    fp <- BI.mallocByteString recvLen
    sz <- withForeignPtr fp loop
    if (sz :: Int) <= 0
        then return Nothing
        else return $ Just $ BI.fromForeignPtr fp 0 $ fromEnum sz
  where
    loop recvBuf = do
        debug $ "TLS: calling record_recv with recvLen=" ++ show recvLen
        size <- gnutls_record_recv (castPtr session) recvBuf $ toEnum recvLen
        debug $ "TLS: record_recv returned with size=" ++ show size
        let size' = fromIntegral size
        case size' of
            x | x >= 0        -> return x
              | isIntrCode x  -> loop recvBuf
              | isAgainCode x -> onBlock >> loop recvBuf
              | otherwise     -> throwError "TLS recv" $ fromIntegral size'


------------------------------------------------------------------------------
-- | Throws a 'GnuTLSException'. The message is built from @prefix@, the
-- shown return code and the text that @gnutls_strerror@ gives for that code.
throwError :: String -> ReturnCode -> IO a
throwError prefix rc = gnutls_strerror rc >>=
                       peekCString >>=
                       throwIO . GnuTLSException . (prefix'++)
  where
    prefix' = prefix ++ "<" ++ show rc ++ ">: "


------------------------------------------------------------------------------
-- | Runs a GnuTLS action and throws (via 'throwError') if it returns a
-- negative code.
throwErrorIf :: String -> IO ReturnCode -> IO ()
throwErrorIf prefix action = do
    rc <- action
    if (rc < 0)
        then throwError prefix rc
        else return ()


------------------------------------------------------------------------------
-- | Is this GnuTLS's "try again" code, -28 (presumably GNUTLS_E_AGAIN)?
isAgainCode :: (Integral a) => a -> Bool
isAgainCode x = (fromIntegral x) == (-28 :: Int)


------------------------------------------------------------------------------
-- | Is this GnuTLS's "interrupted" code, -52 (presumably
-- GNUTLS_E_INTERRUPTED)?
isIntrCode :: (Integral a) => a -> Bool
isIntrCode x = (fromIntegral x) == (-52 :: Int)


------------------------------------------------------------------------------
-- | Builds an IPv4 socket address for the given port. The address @\"*\"@
-- means 'Socket.iNADDR_ANY'; any other value is parsed with
-- 'Socket.inet_addr'.
getHostAddr :: Int
            -> ByteString
            -> IO Socket.SockAddr
getHostAddr p s = do
    h <- if s == "*"
            then return Socket.iNADDR_ANY
            else Socket.inet_addr (map w2c . B.unpack $ s)

    return $ Socket.SockAddrInet (fromIntegral p) h


------------------------------------------------------------------------------
-- Types

-- | Raw result of a GnuTLS call. Negative values are treated as errors by
-- 'throwErrorIf'.
newtype ReturnCode = ReturnCode CInt
    deriving (Show, Eq, Ord, Num, Real, Enum, Integral)

-- Opaque GnuTLS structures, only ever handled through 'Ptr'.
data GnuTLSCredentials
data GnuTLSSession
data GnuTLSDHParam


------------------------------------------------------------------------------
-- Global init/errors

-- NOTE(review): no header is given for this symbol. It is presumably a C
-- helper shipped with this package rather than part of GnuTLS; confirm.
foreign import ccall safe
    "gnutls_set_threading_helper"
    gnutls_set_threading_helper :: IO ()

foreign import ccall safe
    "gnutls/gnutls.h gnutls_global_init"
    gnutls_global_init :: IO ReturnCode

foreign import ccall safe
    "gnutls/gnutls.h gnutls_global_deinit"
    gnutls_global_deinit :: IO ()

foreign import ccall safe
    "gnutls/gnutls.h gnutls_strerror"
    gnutls_strerror :: ReturnCode -> IO CString
------------------------------------------------------------------------------
-- Sessions. Apart from gnutls_handshake and gnutls_bye, every function in
-- this section only allocates memory or updates members of structures. They
-- are therefore safe to import with the "unsafe" calling convention.

foreign import ccall unsafe "gnutls/gnutls.h gnutls_init"
    gnutls_init :: Ptr (Ptr GnuTLSSession) -> CInt -> IO ReturnCode

foreign import ccall unsafe "gnutls/gnutls.h gnutls_deinit"
    gnutls_deinit :: Ptr GnuTLSSession -> IO ()

-- Handshake and bye are imported "safe".
foreign import ccall safe "gnutls/gnutls.h gnutls_handshake"
    gnutls_handshake :: Ptr GnuTLSSession -> IO ReturnCode

foreign import ccall safe "gnutls/gnutls.h gnutls_bye"
    gnutls_bye :: Ptr GnuTLSSession -> CInt -> IO ReturnCode

foreign import ccall unsafe "gnutls/gnutls.h gnutls_set_default_priority"
    gnutls_set_default_priority :: Ptr GnuTLSSession -> IO ReturnCode

foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_session_enable_compatibility_mode"
    gnutls_session_enable_compatibility_mode :: Ptr GnuTLSSession -> IO ()

foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_certificate_send_x509_rdn_sequence"
    gnutls_certificate_send_x509_rdn_sequence
        :: Ptr GnuTLSSession -> CInt -> IO ()
------------------------------------------------------------------------------
-- Certificates. These could perhaps be unsafe calls. They are not
-- performance critical, since they are called only once during server
-- startup.

foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_allocate_credentials"
    gnutls_certificate_allocate_credentials
        :: Ptr (Ptr GnuTLSCredentials) -> IO ReturnCode

foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_free_credentials"
    gnutls_certificate_free_credentials
        :: Ptr GnuTLSCredentials -> IO ()

-- | Format code passed to @gnutls_certificate_set_x509_key_file@ for
-- PEM-encoded files (presumably GNUTLS_X509_FMT_PEM; confirm against
-- gnutls.h).
gnutls_x509_fmt_pem :: CInt
gnutls_x509_fmt_pem = 1

foreign import ccall safe
    "gnutls/gnutls.h gnutls_certificate_set_x509_key_file"
    gnutls_certificate_set_x509_key_file
        :: Ptr GnuTLSCredentials -> CString -> CString -> CInt -> IO ReturnCode


------------------------------------------------------------------------------
-- Credentials. This is ok as unsafe because it just sets members in the
-- session structure.

foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_credentials_set"
    gnutls_credentials_set
        :: Ptr GnuTLSSession -> CInt -> Ptr a -> IO ReturnCode


------------------------------------------------------------------------------
-- Records. These are marked unsafe because they are very performance
-- critical. Since we are using non-blocking sockets, send and recv will not
-- block.
--
-- gnutls_record_recv and gnutls_record_send return a signed ssize_t: the
-- byte count, or a negative error code. Their results are declared as the
-- signed CPtrdiff so that error codes arrive as negative numbers. With an
-- unsigned type they would be huge positive values that only become negative
-- again through fromIntegral wrap-around.
--
-- Foreign.C provides no CSsize; ptrdiff_t is assumed to have the same width
-- as ssize_t. NOTE(review): confirm this for any new target platform.

foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_transport_set_ptr"
    gnutls_transport_set_ptr :: Ptr GnuTLSSession -> Ptr a -> IO ()

foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_record_recv"
    gnutls_record_recv :: Ptr GnuTLSSession -> Ptr a -> CSize -> IO CPtrdiff

foreign import ccall unsafe
    "gnutls/gnutls.h gnutls_record_send"
    gnutls_record_send :: Ptr GnuTLSSession -> Ptr a -> CSize -> IO CPtrdiff
------------------------------------------------------------------------------
-- DHParam. These could perhaps be unsafe calls, but they are not
-- performance critical.

foreign import ccall safe "gnutls/gnutls.h gnutls_dh_params_init"
    gnutls_dh_params_init :: Ptr (Ptr GnuTLSDHParam) -> IO ReturnCode

foreign import ccall safe "gnutls/gnutls.h gnutls_dh_params_deinit"
    gnutls_dh_params_deinit :: Ptr GnuTLSDHParam -> IO ()

foreign import ccall safe "gnutls/gnutls.h gnutls_dh_params_generate2"
    gnutls_dh_params_generate2
        :: Ptr GnuTLSDHParam -> CUInt -> IO ReturnCode

foreign import ccall safe "gnutls/gnutls.h gnutls_certificate_set_dh_params"
    gnutls_certificate_set_dh_params
        :: Ptr GnuTLSCredentials -> Ptr GnuTLSDHParam -> IO ()

#endif