aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Internal/Util.hs
blob: f7ecde544538c056e6228adf28f488df2bea70ca (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
{-# LANGUAGE ScopedTypeVariables #-}

module Network.WireGuard.Internal.Util
  ( retryWithBackoff
  , ignoreSyncExceptions
  , foreverWithBackoff
  , catchIOExceptionAnd
  , catchSomeExceptionAnd
  , withJust
  , zeroMemory
  , copyMemory
  ) where

import           Control.Concurrent                  (threadDelay)
import           Control.Exception                   (Exception (..),
                                                      IOException,
                                                      SomeAsyncException,
                                                      SomeException, throwIO)
import           Control.Monad.Catch                 (MonadCatch (..))
import           System.IO                           (hPutStrLn, stderr)

import           Foreign
import           Foreign.C

import           Network.WireGuard.Internal.Constant

-- | Keep running an action forever.
--
-- Synchronous exceptions thrown by the action are printed to 'stderr' and
-- swallowed ('ignoreSyncExceptions'). Asynchronous exceptions propagate and
-- terminate the loop. Runs are separated by a sleep that starts at 1
-- microsecond and doubles up to 'retryMaxWaitTime' ('foreverWithBackoff').
-- The sleep grows after every run, whether or not that run failed.
retryWithBackoff :: IO () -> IO ()
retryWithBackoff action = foreverWithBackoff (ignoreSyncExceptions action)

-- | Run an action, reporting and discarding any synchronous exception.
--
-- A synchronous exception is rendered with 'displayException' and written to
-- 'stderr'; the call then returns @()@ normally. Exceptions that are
-- 'SomeAsyncException's (e.g. 'ThreadKilled') are re-thrown unchanged, so the
-- surrounding thread can still be interrupted.
ignoreSyncExceptions :: IO () -> IO ()
ignoreSyncExceptions action = action `catch` handler
  where
    handler :: SomeException -> IO ()
    handler err = maybe (report err) rethrowAsync (fromException err)

    rethrowAsync :: SomeAsyncException -> IO ()
    rethrowAsync = throwIO

    report :: SomeException -> IO ()
    report = hPutStrLn stderr . displayException  -- TODO: proper logging

-- | Run an action repeatedly, sleeping between runs.
--
-- The first sleep is 1 microsecond ('threadDelay' units). Each later sleep is
-- twice the previous one, capped at 'retryMaxWaitTime'. The delay is never
-- reset, because the action's result carries no success signal. The loop
-- never returns normally; it only ends if the action (or the delay) throws.
foreverWithBackoff :: IO () -> IO ()
foreverWithBackoff action = go 1
  where
    go delay = do
        action
        threadDelay delay
        go (grow delay)

    grow d = min (d * 2) retryMaxWaitTime

-- | @catchIOExceptionAnd what m@ runs @m@.
--
-- If @m@ throws an 'IOException', the exception is discarded and @what@ runs
-- instead. Any other exception type propagates untouched.
catchIOExceptionAnd :: MonadCatch m => m () -> m () -> m ()
catchIOExceptionAnd onFailure action = catch action handler
  where
    handler (_ :: IOException) = onFailure

-- | @catchSomeExceptionAnd what m@ runs @m@.
--
-- If @m@ throws any exception at all, the exception is discarded and @what@
-- runs instead.
--
-- Because the handler matches 'SomeException', asynchronous exceptions are
-- caught too. Unlike 'ignoreSyncExceptions', control then continues after
-- @what@ instead of re-throwing.
-- NOTE(review): confirm callers rely on that before narrowing it.
catchSomeExceptionAnd :: MonadCatch m => m () -> m () -> m ()
catchSomeExceptionAnd onFailure action = catch action handler
  where
    handler (_ :: SomeException) = onFailure

-- | Run a monadic lookup and, if it yields @'Just' a@, continue with
-- @func a@. If it yields 'Nothing', do nothing and return @()@.
withJust :: Monad m => m (Maybe a) -> (a -> m ()) -> m ()
withJust mma func = maybe (return ()) func =<< mma

-- | @zeroMemory dest nbytes@ sets @nbytes@ bytes starting at @dest@ to zero
-- by calling C @memset@. The caller must ensure the whole range is valid,
-- writable memory.
zeroMemory :: Ptr a -> CSize -> IO ()
zeroMemory dest = memset dest 0

-- | @copyMemory dest src nbytes@ copies @nbytes@ bytes from @src@ to @dest@
-- by calling C @memcpy@.
--
-- As with @memcpy@, the two ranges must not overlap and must both be valid
-- for @nbytes@ bytes.
copyMemory :: Ptr a -> Ptr b -> CSize -> IO ()
copyMemory = memcpy

-- | Raw binding to C @memset@ from @string.h@: fill @n@ bytes at the pointer
-- with the given byte value. Imported as an @unsafe@ call, so it must never
-- call back into Haskell.
-- NOTE(review): C's memset returns its destination pointer; this binding
-- declares the result as @IO ()@ and discards it. Confirm that is intended.
foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO ()
-- | Raw binding to C @memcpy@ from @string.h@: copy @n@ bytes from the second
-- pointer to the first. The source and destination ranges must not overlap.
-- NOTE(review): as with memset, the returned pointer is discarded via @IO ()@.
foreign import ccall unsafe "string.h" memcpy :: Ptr a -> Ptr b -> CSize -> IO ()