1 {-# LANGUAGE BangPatterns              #-}
    2 {-# LANGUAGE DeriveDataTypeable        #-}
    3 {-# LANGUAGE ExistentialQuantification #-}
    4 {-# LANGUAGE OverloadedStrings         #-}
    5 {-# LANGUAGE ScopedTypeVariables       #-}
    6 
    7 ------------------------------------------------------------------------------
    8 -- | This module contains primitives and helper functions for handling
    9 -- requests with @Content-type: multipart/form-data@, i.e. HTML forms and file
   10 -- uploads.
   11 --
   12 -- Typically most users will want to use 'handleFileUploads', which writes
   13 -- uploaded files to a temporary directory before sending them on to a handler
   14 -- specified by the user.
   15 --
   16 -- Users who wish to handle their file uploads differently can use the
   17 -- lower-level streaming 'Iteratee' interface called 'handleMultipart'. That
   18 -- function takes uploaded files and streams them to an 'Iteratee' consumer of
   19 -- the user's choosing.
   20 --
   21 -- Using these functions requires making \"policy\" decisions which Snap can't
   22 -- really make for users, such as \"what's the largest PDF file a user is
   23 -- allowed to upload?\" and \"should we read form inputs into the parameters
   24 -- mapping?\". Policy is specified on a \"global\" basis (using
   25 -- 'UploadPolicy'), and on a per-file basis (using 'PartUploadPolicy', which
   26 -- allows you to reject or limit the size of certain uploaded
   27 -- @Content-type@s).
   28 module Snap.Util.FileUploads
   29   ( -- * Functions
   30     handleFileUploads
   31   , handleMultipart
   32 
   33     -- * Uploaded parts
   34   , PartInfo(..)
   35 
   36     -- ** Policy
   37     -- *** General upload policy
   38   , UploadPolicy
   39   , defaultUploadPolicy
   40   , doProcessFormInputs
   41   , setProcessFormInputs
   42   , getMaximumFormInputSize
   43   , setMaximumFormInputSize
   44   , getMinimumUploadRate
   45   , setMinimumUploadRate
   46   , getMinimumUploadSeconds
   47   , setMinimumUploadSeconds
   48   , getUploadTimeout
   49   , setUploadTimeout
   50 
   51     -- *** Per-file upload policy
   52   , PartUploadPolicy
   53   , disallow
   54   , allowWithMaximumSize
   55 
   56     -- * Exceptions
   57   , FileUploadException
   58   , fileUploadExceptionReason
   59   , BadPartException
   60   , badPartExceptionReason
   61   , PolicyViolationException
   62   , policyViolationExceptionReason
   63   ) where
   64 
   65 ------------------------------------------------------------------------------
   66 import           Control.Arrow
   67 import           Control.Applicative
   68 import           Control.Exception (SomeException(..))
   69 import           Control.Monad
   70 import           Control.Monad.CatchIO
   71 import           Control.Monad.Trans
   72 import qualified Data.Attoparsec.Char8 as Atto
   73 import           Data.Attoparsec.Char8 hiding (many, Result(..))
   74 import           Data.Attoparsec.Enumerator
   75 import qualified Data.ByteString.Char8 as S
   76 import           Data.ByteString.Char8 (ByteString)
   77 import           Data.ByteString.Internal (c2w)
   78 import qualified Data.CaseInsensitive as CI
   79 import qualified Data.DList as D
   80 import           Data.Enumerator.Binary (iterHandle)
   81 import           Data.IORef
   82 import           Data.Int
   83 import           Data.List hiding (takeWhile)
   84 import qualified Data.Map as Map
   85 import           Data.Maybe
   86 import qualified Data.Text as T
   87 import           Data.Text (Text)
   88 import qualified Data.Text.Encoding as TE
   89 import           Data.Typeable
   90 import           Prelude hiding (catch, getLine, takeWhile)
   91 import           System.Directory
   92 import           System.IO hiding (isEOF)
   93 ------------------------------------------------------------------------------
   94 import           Snap.Iteratee hiding (map)
   95 import qualified Snap.Iteratee as I
   96 import           Snap.Internal.Debug
   97 import           Snap.Internal.Iteratee.Debug
   98 import           Snap.Internal.Iteratee.BoyerMooreHorspool
   99 import           Snap.Internal.Parsing
  100 import           Snap.Types
  101 
  102 
  103 ------------------------------------------------------------------------------
  104 -- | Reads uploaded files into a temporary directory and calls a user handler
  105 -- to process them.
  106 --
  107 -- Given a temporary directory, global and file-specific upload policies, and
  108 -- a user handler, this function consumes a request body uploaded with
  109 -- @Content-type: multipart/form-data@. Each file is read into the temporary
  110 -- directory, and then a list of the uploaded files is passed to the user
  111 -- handler. After the user handler runs (but before the 'Response' body
  112 -- 'Enumerator' is streamed to the client), the files are deleted from disk;
  113 -- so if you want to retain or use the uploaded files in the generated
  114 -- response, you would need to move or otherwise process them.
  115 --
  116 -- The argument passed to the user handler is a list of:
  117 --
  118 -- > (PartInfo, Either PolicyViolationException FilePath)
  119 --
  120 -- The first half of this tuple is a 'PartInfo', which contains the
  121 -- information the client browser sent about the given upload part (like
  122 -- filename, content-type, etc). The second half of this tuple is an 'Either'
  123 -- stipulating that either:
  124 --
  125 -- 1. the file was rejected on a policy basis because of the provided
  126 --    'PartUploadPolicy' handler
  127 --
  128 -- 2. the file was accepted and exists at the given path.
  129 --
-- If the request's @Content-type@ was not \"@multipart/form-data@\", this
  131 -- function skips processing using 'pass'.
  132 --
  133 -- If the client's upload rate passes below the configured minimum (see
  134 -- 'setMinimumUploadRate' and 'setMinimumUploadSeconds'), this function throws
  135 -- a 'RateTooSlowException'. This setting is there to protect the server
  136 -- against slowloris-style denial of service attacks.
  137 --
  138 -- If the given 'UploadPolicy' stipulates that you wish form inputs to be
  139 -- placed in the 'rqParams' parameter map (using 'setProcessFormInputs'), and
  140 -- a form input exceeds the maximum allowable size, this function will throw a
  141 -- 'PolicyViolationException'.
  142 --
  143 -- If an uploaded part contains MIME headers longer than a fixed internal
  144 -- threshold (currently 32KB), this function will throw a 'BadPartException'.
  145 
  146 handleFileUploads ::
  147        (MonadSnap m) =>
  148        FilePath                       -- ^ temporary directory
  149     -> UploadPolicy                   -- ^ general upload policy
  150     -> (PartInfo -> PartUploadPolicy) -- ^ per-part upload policy
  151     -> ([(PartInfo, Either PolicyViolationException FilePath)] -> m a)
  152                                       -- ^ user handler (see function
  153                                       -- description)
  154     -> m a
  155 handleFileUploads tmpdir uploadPolicy partPolicy handler = do
  156     uploadedFiles <- newUploadedFiles
  157 
  158     (do
  159         xs <- handleMultipart uploadPolicy (iter uploadedFiles)
  160         handler xs
  161         ) `finally` (cleanupUploadedFiles uploadedFiles)
  162 
  163   where
  164     iter uploadedFiles partInfo = maybe disallowed takeIt mbFs
  165       where
  166         ctText = partContentType partInfo
  167         fnText = fromMaybe "" $ partFileName partInfo
  168 
  169         ct = TE.decodeUtf8 ctText
  170         fn = TE.decodeUtf8 fnText
  171 
  172         (PartUploadPolicy mbFs) = partPolicy partInfo
  173 
  174         retVal (_,x) = (partInfo, Right x)
  175 
  176         takeIt maxSize = do
  177             let it = fmap retVal $
  178                      joinI' $
  179                      takeNoMoreThan maxSize $$
  180                      fileReader uploadedFiles tmpdir partInfo
  181 
  182             it `catches` [ Handler $ \(_ :: TooManyBytesReadException) ->
  183                                      (skipToEof >> tooMany maxSize)
  184                          , Handler $ \(e :: SomeException) -> throw e
  185                          ]
  186 
  187         tooMany maxSize =
  188             return ( partInfo
  189                    , Left $ PolicyViolationException
  190                           $ T.concat [ "File \""
  191                                      , fn
  192                                      , "\" exceeded maximum allowable size "
  193                                      , T.pack $ show maxSize ] )
  194 
  195         disallowed =
  196             return ( partInfo
  197                    , Left $ PolicyViolationException
  198                           $ T.concat [ "Policy disallowed upload of file \""
  199                                      , fn
  200                                      , "\" with content-type \""
  201                                      , ct
  202                                      , "\"" ] )
  203 
  204 
  205 ------------------------------------------------------------------------------
  206 -- | Given an upload policy and a function to consume uploaded \"parts\",
  207 -- consume a request body uploaded with @Content-type: multipart/form-data@.
  208 -- Normally most users will want to use 'handleFileUploads' (which writes
  209 -- uploaded files to a temporary directory and passes their names to a given
  210 -- handler) rather than this function; the lower-level 'handleMultipart'
  211 -- function should be used if you want to stream uploaded files to your own
  212 -- iteratee function.
  213 --
-- If the request's @Content-type@ was not \"@multipart/form-data@\", this
  215 -- function skips processing using 'pass'.
  216 --
  217 -- If the client's upload rate passes below the configured minimum (see
  218 -- 'setMinimumUploadRate' and 'setMinimumUploadSeconds'), this function throws
  219 -- a 'RateTooSlowException'. This setting is there to protect the server
  220 -- against slowloris-style denial of service attacks.
  221 --
  222 -- If the given 'UploadPolicy' stipulates that you wish form inputs to be
  223 -- placed in the 'rqParams' parameter map (using 'setProcessFormInputs'), and
  224 -- a form input exceeds the maximum allowable size, this function will throw a
  225 -- 'PolicyViolationException'.
  226 --
  227 -- If an uploaded part contains MIME headers longer than a fixed internal
  228 -- threshold (currently 32KB), this function will throw a 'BadPartException'.
  229 --
  230 handleMultipart ::
  231        (MonadSnap m) =>
  232        UploadPolicy                            -- ^ global upload policy
  233     -> (PartInfo -> Iteratee ByteString IO a)  -- ^ part processor
  234     -> m [a]
  235 handleMultipart uploadPolicy origPartHandler = do
  236     hdrs <- liftM headers getRequest
  237     let (ct, mbBoundary) = getContentType hdrs
  238 
  239     tickleTimeout <- getTimeoutAction
  240     let bumpTimeout = tickleTimeout $ uploadTimeout uploadPolicy
  241 
  242     let partHandler = if doProcessFormInputs uploadPolicy
  243                         then captureVariableOrReadFile
  244                                  (getMaximumFormInputSize uploadPolicy)
  245                                  origPartHandler
  246                         else (\p -> fmap File (origPartHandler p))
  247 
  248     -- not well-formed multipart? bomb out.
  249     when (ct /= "multipart/form-data") $ do
  250         debug $ "handleMultipart called with content-type=" ++ S.unpack ct
  251                   ++ ", passing"
  252         pass
  253 
  254     when (isNothing mbBoundary) $
  255          throw $ BadPartException $
  256          "got multipart/form-data without boundary"
  257 
  258     let boundary = fromJust mbBoundary
  259     captures <- runRequestBody (iter bumpTimeout boundary partHandler)
  260 
  261     procCaptures [] captures
  262 
  263   where
  264     iter bump boundary ph = iterateeDebugWrapper "killIfTooSlow" $
  265                             killIfTooSlow
  266                               bump
  267                               (minimumUploadRate uploadPolicy)
  268                               (minimumUploadSeconds uploadPolicy)
  269                               (internalHandleMultipart boundary ph)
  270 
  271     ins k v = Map.insertWith' (\a b -> Prelude.head a : b) k [v]
  272 
  273     maxFormVars = maximumNumberOfFormInputs uploadPolicy
  274 
  275     procCaptures l [] = return $ reverse l
  276     procCaptures l ((File x):xs) = procCaptures (x:l) xs
  277     procCaptures l ((Capture k v):xs) = do
  278         rq <- getRequest
  279         let n = Map.size $ rqParams rq
  280         when (n >= maxFormVars) $
  281           throw $ PolicyViolationException $
  282           T.concat [ "number of form inputs exceeded maximum of "
  283                    , T.pack $ show maxFormVars ]
  284         modifyRequest $ rqModifyParams (ins k v)
  285         procCaptures l xs
  286 
  287 
  288 ------------------------------------------------------------------------------
  289 -- | 'PartInfo' contains information about a \"part\" in a request uploaded
  290 -- with @Content-type: multipart/form-data@.
data PartInfo =
    PartInfo { partFieldName   :: !ByteString
               -- ^ value of the @name@ parameter of the part's
               -- @Content-Disposition@ header (empty if absent); for parts
               -- nested inside a @multipart/mixed@ part this is the field
               -- name of the enclosing part
             , partFileName    :: !(Maybe ByteString)
               -- ^ value of the @filename@ parameter of the part's
               -- @Content-Disposition@ header, if one was sent
             , partContentType :: !ByteString
               -- ^ the part's @Content-Type@, defaulting to @text/plain@
               -- when the header is missing or cannot be parsed
             }
  deriving (Show)
  297 
  298 
  299 ------------------------------------------------------------------------------
  300 -- | All of the exceptions defined in this package inherit from
  301 -- 'FileUploadException', so if you write
  302 --
  303 -- > foo `catch` \(e :: FileUploadException) -> ...
  304 --
  305 -- you can catch a 'BadPartException', a 'PolicyViolationException', etc.
data FileUploadException =
    -- A free-standing upload error that carries only a textual reason.
    GenericFileUploadException {
      _genericFileUploadExceptionReason :: Text
    }
    -- Wraps a more specific exception (any type with 'Exception' and 'Show'
    -- instances) together with its textual reason; the existential lets the
    -- specific exception types in this module be caught via this type.
  | forall e . (Exception e, Show e) =>
    WrappedFileUploadException {
      _wrappedFileUploadException :: e
    , _wrappedFileUploadExceptionReason :: Text
    }
  deriving (Typeable)
  316 
  317 
  318 ------------------------------------------------------------------------------
  319 instance Show FileUploadException where
  320     show (GenericFileUploadException r) = "File upload exception: " ++
  321                                           T.unpack r
  322     show (WrappedFileUploadException e _) = show e
  323 
  324 
  325 ------------------------------------------------------------------------------
-- No methods are overridden here, so the class defaults for 'toException'
-- and 'fromException' apply.
instance Exception FileUploadException
  327 
  328 
  329 ------------------------------------------------------------------------------
  330 fileUploadExceptionReason :: FileUploadException -> Text
  331 fileUploadExceptionReason (GenericFileUploadException r) = r
  332 fileUploadExceptionReason (WrappedFileUploadException _ r) = r
  333 
  334 
  335 ------------------------------------------------------------------------------
  336 uploadExceptionToException :: Exception e => e -> Text -> SomeException
  337 uploadExceptionToException e r =
  338     SomeException $ WrappedFileUploadException e r
  339 
  340 
  341 ------------------------------------------------------------------------------
  342 uploadExceptionFromException :: Exception e => SomeException -> Maybe e
  343 uploadExceptionFromException x = do
  344     WrappedFileUploadException e _ <- fromException x
  345     cast e
  346 
  347 
  348 ------------------------------------------------------------------------------
  349 data BadPartException = BadPartException { badPartExceptionReason :: Text }
  350   deriving (Typeable)
  351 
  352 instance Exception BadPartException where
  353     toException e@(BadPartException r) = uploadExceptionToException e r
  354     fromException = uploadExceptionFromException
  355 
  356 instance Show BadPartException where
  357   show (BadPartException s) = "Bad part: " ++ T.unpack s
  358 
  359 
  360 ------------------------------------------------------------------------------
  361 data PolicyViolationException = PolicyViolationException {
  362       policyViolationExceptionReason :: Text
  363     } deriving (Typeable)
  364 
  365 instance Exception PolicyViolationException where
  366     toException e@(PolicyViolationException r) =
  367         uploadExceptionToException e r
  368     fromException = uploadExceptionFromException
  369 
  370 instance Show PolicyViolationException where
  371   show (PolicyViolationException s) = "File upload policy violation: "
  372                                             ++ T.unpack s
  373 
  374 
  375 ------------------------------------------------------------------------------
  376 -- | 'UploadPolicy' controls overall policy decisions relating to
  377 -- @multipart/form-data@ uploads, specifically:
  378 --
  379 -- * whether to treat parts without filenames as form input (reading them into
  380 --   the 'rqParams' map)
  381 --
  382 -- * because form input is read into memory, the maximum size of a form input
  383 --   read in this manner, and the maximum number of form inputs
  384 --
  385 -- * the minimum upload rate a client must maintain before we kill the
  386 --   connection; if very low-bitrate uploads were allowed then a Snap server
  387 --   would be vulnerable to a trivial denial-of-service using a
  388 --   \"slowloris\"-type attack
  389 --
  390 -- * the minimum number of seconds which must elapse before we start killing
  391 --   uploads for having too low an upload rate.
  392 --
  393 -- * the amount of time we should wait before timing out the connection
  394 --   whenever we receive input from the client.
data UploadPolicy = UploadPolicy {
      processFormInputs         :: Bool
      -- ^ whether parts without a filename are read into 'rqParams'
    , maximumFormInputSize      :: Int
      -- ^ largest form input, in bytes, that will be read into memory
    , maximumNumberOfFormInputs :: Int
      -- ^ limit checked against the number of keys in 'rqParams' before a
      -- captured form input is added
    , minimumUploadRate         :: Double
      -- ^ minimum upload rate, in bytes per second
    , minimumUploadSeconds      :: Int
      -- ^ seconds that must elapse before the minimum rate is enforced
    , uploadTimeout             :: Int
      -- ^ seconds by which the connection timeout is extended whenever
      -- input arrives
} deriving (Show, Eq)
  403 
  404 
  405 ------------------------------------------------------------------------------
  406 -- | A reasonable set of defaults for upload policy. The default policy is:
  407 --
  408 --   [@maximum form input size@]                128kB
  409 --
  410 --   [@maximum number of form inputs@]          10
  411 --
  412 --   [@minimum upload rate@]                    1kB/s
  413 --
  414 --   [@seconds before rate limiting kicks in@]  10
  415 --
  416 --   [@inactivity timeout@]                     20 seconds
  417 --
  418 defaultUploadPolicy :: UploadPolicy
  419 defaultUploadPolicy = UploadPolicy True maxSize maxNum minRate minSeconds tout
  420   where
  421     maxSize    = 2^(17::Int)
  422     maxNum     = 10
  423     minRate    = 1000
  424     minSeconds = 10
  425     tout       = 20
  426 
  427 
  428 ------------------------------------------------------------------------------
  429 -- | Does this upload policy stipulate that we want to treat parts without
  430 -- filenames as form input?
  431 doProcessFormInputs :: UploadPolicy -> Bool
  432 doProcessFormInputs = processFormInputs
  433 
  434 
  435 ------------------------------------------------------------------------------
  436 -- | Set the upload policy for treating parts without filenames as form input.
  437 setProcessFormInputs :: Bool -> UploadPolicy -> UploadPolicy
  438 setProcessFormInputs b u = u { processFormInputs = b }
  439 
  440 
  441 ------------------------------------------------------------------------------
  442 -- | Get the maximum size of a form input which will be read into our
  443 --   'rqParams' map.
  444 getMaximumFormInputSize :: UploadPolicy -> Int
  445 getMaximumFormInputSize = maximumFormInputSize
  446 
  447 
  448 ------------------------------------------------------------------------------
  449 -- | Set the maximum size of a form input which will be read into our
  450 --   'rqParams' map.
  451 setMaximumFormInputSize :: Int -> UploadPolicy -> UploadPolicy
  452 setMaximumFormInputSize s u = u { maximumFormInputSize = s }
  453 
  454 
  455 ------------------------------------------------------------------------------
  456 -- | Get the minimum rate (in /bytes\/second/) a client must maintain before
  457 --   we kill the connection.
  458 getMinimumUploadRate :: UploadPolicy -> Double
  459 getMinimumUploadRate = minimumUploadRate
  460 
  461 
  462 ------------------------------------------------------------------------------
  463 -- | Set the minimum rate (in /bytes\/second/) a client must maintain before
  464 --   we kill the connection.
  465 setMinimumUploadRate :: Double -> UploadPolicy -> UploadPolicy
  466 setMinimumUploadRate s u = u { minimumUploadRate = s }
  467 
  468 
  469 ------------------------------------------------------------------------------
  470 -- | Get the amount of time which must elapse before we begin enforcing the
  471 --   upload rate minimum
  472 getMinimumUploadSeconds :: UploadPolicy -> Int
  473 getMinimumUploadSeconds = minimumUploadSeconds
  474 
  475 
  476 ------------------------------------------------------------------------------
  477 -- | Set the amount of time which must elapse before we begin enforcing the
  478 --   upload rate minimum
  479 setMinimumUploadSeconds :: Int -> UploadPolicy -> UploadPolicy
  480 setMinimumUploadSeconds s u = u { minimumUploadSeconds = s }
  481 
  482 
  483 ------------------------------------------------------------------------------
  484 -- | Get the \"upload timeout\". Whenever input is received from the client,
  485 --   the connection timeout is set this many seconds in the future.
  486 getUploadTimeout :: UploadPolicy -> Int
  487 getUploadTimeout = uploadTimeout
  488 
  489 
  490 ------------------------------------------------------------------------------
  491 -- | Set the upload timeout.
  492 setUploadTimeout :: Int -> UploadPolicy -> UploadPolicy
  493 setUploadTimeout s u = u { uploadTimeout = s }
  494 
  495 
  496 ------------------------------------------------------------------------------
-- | Upload policy can be set on a \"general\" basis (using 'UploadPolicy'),
--   but handlers can also make policy decisions on individual files\/parts
--   uploaded. For each part uploaded, handlers can decide:
--
-- * whether to allow the file upload at all
--
-- * the maximum size of uploaded files, if allowed
data PartUploadPolicy = PartUploadPolicy {
      _maximumFileSize :: Maybe Int64
      -- ^ 'Nothing' disallows the upload; @'Just' n@ allows it with a
      -- maximum size of /n/
} deriving (Show, Eq)
  507 
  508 
  509 ------------------------------------------------------------------------------
  510 -- | Disallows the file to be uploaded.
  511 disallow :: PartUploadPolicy
  512 disallow = PartUploadPolicy Nothing
  513 
  514 
  515 ------------------------------------------------------------------------------
  516 -- | Allows the file to be uploaded, with maximum size /n/.
  517 allowWithMaximumSize :: Int64 -> PartUploadPolicy
  518 allowWithMaximumSize = PartUploadPolicy . Just
  519 
  520 
  521 ------------------------------------------------------------------------------
  522 -- private exports follow. FIXME: organize
  523 ------------------------------------------------------------------------------
  524 
  525 ------------------------------------------------------------------------------
  526 captureVariableOrReadFile ::
  527        Int                                     -- ^ maximum size of form input
  528     -> (PartInfo -> Iteratee ByteString IO a)  -- ^ file reading code
  529     -> (PartInfo -> Iteratee ByteString IO (Capture a))
  530 captureVariableOrReadFile maxSize fileHandler partInfo =
  531     case partFileName partInfo of
  532       Nothing -> iter
  533       _       -> liftM File $ fileHandler partInfo
  534   where
  535     iter = varIter `catchError` handler
  536 
  537     fieldName = partFieldName partInfo
  538 
  539     varIter = do
  540         var <- liftM S.concat $
  541                joinI' $
  542                takeNoMoreThan (fromIntegral maxSize) $$ consume
  543         return $ Capture fieldName var
  544 
  545     handler e = do
  546         let m = fromException e :: Maybe TooManyBytesReadException
  547         case m of
  548           Nothing -> throwError e
  549           Just _  -> throwError $ PolicyViolationException $
  550                      T.concat [ "form input '"
  551                               , TE.decodeUtf8 fieldName
  552                               , "' exceeded maximum permissible size ("
  553                               , T.pack $ show maxSize
  554                               , " bytes)" ]
  555 
  556 
  557 ------------------------------------------------------------------------------
-- | Outcome of processing a single part.
--
-- 'Capture' holds a form input that was read into memory: the field name
-- followed by the field's value. 'File' holds the result produced by the
-- caller's part processor.
data Capture a = Capture ByteString ByteString
               | File a
  deriving (Show)
  561 
  562 
  563 ------------------------------------------------------------------------------
  564 fileReader :: UploadedFiles
  565            -> FilePath
  566            -> PartInfo
  567            -> Iteratee ByteString IO (PartInfo, FilePath)
  568 fileReader uploadedFiles tmpdir partInfo = do
  569     (fn, h) <- openFileForUpload uploadedFiles tmpdir
  570     let i = iterateeDebugWrapper "fileReader" $ iter fn h
  571     i `catch` \(e::SomeException) -> throwError e
  572 
  573   where
  574     iter fileName h = do
  575         iterHandle h
  576         debug "fileReader: closing active file"
  577         closeActiveFile uploadedFiles
  578         return (partInfo, fileName)
  579 
  580 
  581 ------------------------------------------------------------------------------
-- | Parse a multipart body delimited by the given boundary, run the part
-- processor on every part and collect the results in order.
--
-- Lines preceding the first boundary are discarded. A part whose content
-- type is @multipart/mixed@ is split again on its own boundary, and each
-- nested part is reported under the field name of the enclosing part. An
-- exception caught while parsing causes 'skipToEof' to be run before the
-- exception is re-raised with 'throwError'.
internalHandleMultipart ::
       ByteString                              -- ^ boundary value
    -> (PartInfo -> Iteratee ByteString IO a)  -- ^ part processor
    -> Iteratee ByteString IO [a]
internalHandleMultipart boundary clientHandler = go `catch` errorHandler

  where
    --------------------------------------------------------------------------
    -- run 'skipToEof' (presumably draining the remaining input -- it is
    -- defined elsewhere), then re-raise the exception as an iteratee error
    errorHandler :: SomeException -> Iteratee ByteString IO a
    errorHandler e = do
        skipToEof
        throwError e

    --------------------------------------------------------------------------
    -- top-level driver: each part yields a list of results (more than one
    -- for multipart/mixed), hence the final 'concat'
    go = do
        -- swallow the first boundary
        _ <- iterParser $ parseFirstBoundary boundary
        step <- iterateeDebugWrapper "boyer-moore" $
                (bmhEnumeratee (fullBoundary boundary) $$ processParts iter)
        liftM concat $ lift $ run_ $ returnI step

    --------------------------------------------------------------------------
    -- parser for "--" followed by the boundary string
    pBoundary b = Atto.try $ do
      _ <- string "--"
      string b

    --------------------------------------------------------------------------
    -- fullBoundary: the delimiter searched for between parts (CRLF, "--",
    -- boundary). parseFirstBoundary: skip whole lines until one starts with
    -- the boundary marker.
    fullBoundary b       = S.concat ["\r\n", "--", b]
    pLine                = takeWhile (not . isEndOfLine . c2w) <* eol
    takeLine             = pLine *> pure ()
    parseFirstBoundary b = pBoundary b <|> (takeLine *> parseFirstBoundary b)


    --------------------------------------------------------------------------
    -- parse a part's header block, reading at most mAX_HDRS_SIZE bytes;
    -- exceeding that limit is reported as a 'BadPartException'
    takeHeaders = hdrs `catchError` handler
      where
        hdrs = liftM toHeaders $
               iterateeDebugWrapper "header parser" $
               joinI' $
               takeNoMoreThan mAX_HDRS_SIZE $$
               iterParser pHeadersWithSeparator

        handler e = do
            let m = fromException e :: Maybe TooManyBytesReadException
            case m of
              Nothing -> throwError e
              Just _  -> throwError $ BadPartException $
                         "headers exceeded maximum size"

    --------------------------------------------------------------------------
    -- processor for one top-level part
    iter = do
        hdrs <- takeHeaders

        -- are we using mixed?
        let (contentType, mboundary) = getContentType hdrs

        let (fieldName, fileName) = getFieldName hdrs

        if contentType == "multipart/mixed"
          then maybe (throwError $ BadPartException $
                      "got multipart/mixed without boundary")
                     (processMixed fieldName)
                     mboundary
          else do
              let info = PartInfo fieldName fileName contentType
              liftM (:[]) $ clientHandler info


    --------------------------------------------------------------------------
    -- split a multipart/mixed part on its own boundary and process each
    -- nested part
    processMixed fieldName mixedBoundary = do
        -- swallow the first boundary
        _ <- iterParser $ parseFirstBoundary mixedBoundary
        step <- iterateeDebugWrapper "boyer-moore" $
                (bmhEnumeratee (fullBoundary mixedBoundary) $$
                 processParts (mixedIter fieldName))
        lift $ run_ $ returnI step


    --------------------------------------------------------------------------
    -- processor for a part nested inside multipart/mixed: the field name is
    -- inherited from the enclosing part, while the file name and content
    -- type come from the nested part's own headers
    mixedIter fieldName = do
        hdrs <- takeHeaders

        let (contentType, _) = getContentType hdrs
        let (_, fileName)    = getFieldName hdrs

        let info = PartInfo fieldName fileName contentType
        clientHandler info
  669 
  670 
  671 ------------------------------------------------------------------------------
  672 getContentType :: Headers
  673                -> (ByteString, Maybe ByteString)
  674 getContentType hdrs = (contentType, boundary)
  675   where
  676     contentTypeValue = fromMaybe "text/plain" $
  677                        getHeader "content-type" hdrs
  678 
  679     eCT = fullyParse contentTypeValue pContentTypeWithParameters
  680     (contentType, params) = either (const ("text/plain", [])) id eCT
  681 
  682     boundary = findParam "boundary" params
  683 
  684 
  685 ------------------------------------------------------------------------------
  686 getFieldName :: Headers -> (ByteString, Maybe ByteString)
  687 getFieldName hdrs = (fieldName, fileName)
  688   where
  689     contentDispositionValue = fromMaybe "" $
  690                               getHeader "content-disposition" hdrs
  691 
  692     eDisposition = fullyParse contentDispositionValue pValueWithParameters
  693 
  694     (_, dispositionParameters) =
  695         either (const ("", [])) id eDisposition
  696 
  697     fieldName = fromMaybe "" $ findParam "name" dispositionParameters
  698 
  699     fileName = findParam "filename" dispositionParameters
  700 
  701 
  702 ------------------------------------------------------------------------------
  703 findParam :: (Eq a) => a -> [(a, b)] -> Maybe b
  704 findParam p = fmap snd . find ((== p) . fst)
  705 
  706 
  707 ------------------------------------------------------------------------------
  708 -- | Given a 'MatchInfo' stream which is partitioned by boundary values, read
  709 -- up until the next boundary and send all of the chunks into the wrapped
  710 -- iteratee
processPart :: (Monad m) => Enumeratee MatchInfo ByteString m a
-- If the wrapped iteratee is not asking for more input, hand its step back
-- unchanged without reading anything from the outer stream.
processPart st = {-# SCC "pPart/outer" #-}
                   case st of
                     (Continue k) -> go k
                     _            -> yield st (Chunks [])
  where
    -- Feed the inner continuation one outer element at a time until a
    -- boundary match or the end of the outer stream is reached.
    go :: (Monad m) => (Stream ByteString -> Iteratee ByteString m a)
                    -> Iteratee MatchInfo m (Step ByteString m a)
    go !k = {-# SCC "pPart/go" #-}
            I.head >>= maybe finish process
      where
        -- called when outer stream is EOF
        finish = {-# SCC "pPart/finish" #-}
                 lift $ runIteratee $ k EOF

        -- no match ==> pass the stream chunk along
        process (NoMatch !s) = {-# SCC "pPart/noMatch" #-} do
          !step <- lift $ runIteratee $ k $ Chunks [s]
          case step of
            (Continue k') -> go k'
            _             -> yield step (Chunks [])

        -- boundary found ==> this part is over, so the inner iteratee is
        -- sent EOF
        process (Match _) = {-# SCC "pPart/match" #-}
                            lift $ runIteratee $ k EOF
  735 
  736 
  737 ------------------------------------------------------------------------------
  738 -- | Assuming we've already identified the boundary value and run
  739 -- 'bmhEnumeratee' to split the input up into parts which match and parts
  740 -- which don't, run the given 'ByteString' iteratee over each part and grab a
  741 -- list of the resulting values.
  742 processParts :: Iteratee ByteString IO a
  743              -> Iteratee MatchInfo IO [a]
  744 processParts partIter = iterateeDebugWrapper "processParts" $ go D.empty
  745   where
  746     iter = {-# SCC "processParts/iter" #-} do
  747         isLast <- bParser
  748         if isLast
  749           then return Nothing
  750           else do
  751             x <- partIter
  752             skipToEof
  753             return $ Just x
  754 
  755     go soFar = {-# SCC "processParts/go" #-} do
  756       b <- isEOF
  757 
  758       if b
  759         then return $ D.toList soFar
  760         else do
  761            -- processPart $$ iter
  762            --   :: Iteratee MatchInfo m (Step ByteString m a)
  763            innerStep <- processPart $$ iter
  764 
  765            -- output :: Maybe a
  766            output <- lift $ run_ $ returnI innerStep
  767 
  768            case output of
  769              Just x  -> go (D.append soFar $ D.singleton x)
  770              Nothing -> return $ D.toList soFar
  771 
  772     bParser = iterateeDebugWrapper "boundary debugger" $
  773                   iterParser $ pBoundaryEnd
  774 
  775     pBoundaryEnd = (eol *> pure False) <|> (string "--" *> pure True)
  776 
  777 
  778 ------------------------------------------------------------------------------
  779 eol :: Parser ByteString
  780 eol = (string "\n") <|> (string "\r\n")
  781 
  782 
  783 ------------------------------------------------------------------------------
  784 pHeadersWithSeparator :: Parser [(ByteString,ByteString)]
  785 pHeadersWithSeparator = pHeaders <* crlf
  786 
  787 
  788 ------------------------------------------------------------------------------
  789 toHeaders :: [(ByteString,ByteString)] -> Headers
  790 toHeaders kvps = foldl' f Map.empty kvps'
  791   where
  792     kvps'     = map (first CI.mk . second (:[])) kvps
  793     f m (k,v) = Map.insertWith' (flip (++)) k v m
  794 
  795 
  796 ------------------------------------------------------------------------------
  797 mAX_HDRS_SIZE :: Int64
  798 mAX_HDRS_SIZE = 32768
  799 
  800 
  801 ------------------------------------------------------------------------------
  802 -- We need some code to keep track of the files we have already successfully
  803 -- created in case an exception is thrown by the request body enumerator or
  804 -- one of the client iteratees.
  805 data UploadedFilesState = UploadedFilesState {
  806       -- | This is the file which is currently being written to. If the
  807       -- calling function gets an exception here, it is responsible for
  808       -- closing and deleting this file.
  809       _currentFile :: Maybe (FilePath, Handle)
  810 
  811       -- | .. and these files have already been successfully read and closed.
  812     , _alreadyReadFiles :: [FilePath]
  813 }
  814 
  815 
  816 ------------------------------------------------------------------------------
  817 emptyUploadedFilesState :: UploadedFilesState
  818 emptyUploadedFilesState = UploadedFilesState Nothing []
  819 
  820 
  821 ------------------------------------------------------------------------------
  822 newtype UploadedFiles = UploadedFiles (IORef UploadedFilesState)
  823 
  824 
  825 ------------------------------------------------------------------------------
  826 newUploadedFiles :: MonadIO m => m UploadedFiles
  827 newUploadedFiles = liftM UploadedFiles $
  828                    liftIO $ newIORef emptyUploadedFilesState
  829 
  830 
  831 ------------------------------------------------------------------------------
  832 cleanupUploadedFiles :: (MonadIO m) => UploadedFiles -> m ()
  833 cleanupUploadedFiles (UploadedFiles stateRef) = liftIO $ do
  834     state <- readIORef stateRef
  835     killOpenFile state
  836     mapM_ killFile $ _alreadyReadFiles state
  837     writeIORef stateRef emptyUploadedFilesState
  838 
  839   where
  840     killFile = eatException . removeFile
  841 
  842     killOpenFile state = maybe (return ())
  843                                (\(fp,h) -> do
  844                                     eatException $ hClose h
  845                                     eatException $ removeFile fp)
  846                                (_currentFile state)
  847 
  848 
  849 ------------------------------------------------------------------------------
  850 openFileForUpload :: (MonadIO m) =>
  851                      UploadedFiles
  852                   -> FilePath
  853                   -> m (FilePath, Handle)
  854 openFileForUpload ufs@(UploadedFiles stateRef) tmpdir = liftIO $ do
  855     state <- readIORef stateRef
  856 
  857     -- It should be an error to open a new file with this interface if there
  858     -- is already a file handle active.
  859     when (isJust $ _currentFile state) $ do
  860         cleanupUploadedFiles ufs
  861         throw $ GenericFileUploadException alreadyOpenMsg
  862 
  863     fph@(_,h) <- openBinaryTempFile tmpdir "snap-"
  864     hSetBuffering h NoBuffering
  865 
  866     writeIORef stateRef $ state { _currentFile = Just fph }
  867     return fph
  868 
  869   where
  870     alreadyOpenMsg =
  871         T.concat [ "Internal error! UploadedFiles: "
  872                  , "opened new file with pre-existing open handle" ]
  873 
  874 
  875 ------------------------------------------------------------------------------
  876 closeActiveFile :: (MonadIO m) => UploadedFiles -> m ()
  877 closeActiveFile (UploadedFiles stateRef) = liftIO $ do
  878     state <- readIORef stateRef
  879     let m = _currentFile state
  880     maybe (return ())
  881           (\(fp,h) -> do
  882                eatException $ hClose h
  883                writeIORef stateRef $
  884                  state { _currentFile = Nothing
  885                        , _alreadyReadFiles = fp:(_alreadyReadFiles state) })
  886           m
  887 
  888 
  889 ------------------------------------------------------------------------------
  890 eatException :: (MonadCatchIO m) => m a -> m ()
  891 eatException m =
  892     (m >> return ()) `catch` (\(_ :: SomeException) -> return ())