aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Internal/Util.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Network/WireGuard/Internal/Util.hs')
-rw-r--r--src/Network/WireGuard/Internal/Util.hs62
1 files changed, 62 insertions, 0 deletions
diff --git a/src/Network/WireGuard/Internal/Util.hs b/src/Network/WireGuard/Internal/Util.hs
new file mode 100644
index 0000000..f7ecde5
--- /dev/null
+++ b/src/Network/WireGuard/Internal/Util.hs
@@ -0,0 +1,62 @@
+{-# LANGUAGE ScopedTypeVariables #-}
+
+module Network.WireGuard.Internal.Util
+ ( retryWithBackoff
+ , ignoreSyncExceptions
+ , foreverWithBackoff
+ , catchIOExceptionAnd
+ , catchSomeExceptionAnd
+ , withJust
+ , zeroMemory
+ , copyMemory
+ ) where
+
+import Control.Concurrent (threadDelay)
+import Control.Exception (Exception (..),
+ IOException,
+ SomeAsyncException,
+ SomeException, throwIO)
+import Control.Monad.Catch (MonadCatch (..))
+import System.IO (hPutStrLn, stderr)
+
+import Foreign
+import Foreign.C
+
+import Network.WireGuard.Internal.Constant
+
+retryWithBackoff :: IO () -> IO ()
+retryWithBackoff = foreverWithBackoff . ignoreSyncExceptions
+
+ignoreSyncExceptions :: IO () -> IO ()
+ignoreSyncExceptions m = catch m handleExcept
+ where
+ handleExcept e = case fromException e of
+ Just asyncExcept -> throwIO (asyncExcept :: SomeAsyncException)
+ Nothing -> hPutStrLn stderr (displayException e) -- TODO: proper logging
+
+foreverWithBackoff :: IO () -> IO ()
+foreverWithBackoff m = loop 1
+ where
+ loop t = m >> threadDelay t >> loop (min (t * 2) retryMaxWaitTime)
+
+catchIOExceptionAnd :: MonadCatch m => m () -> m () -> m ()
+catchIOExceptionAnd what m = catch m $ \(_ :: IOException) -> what
+
+catchSomeExceptionAnd :: MonadCatch m => m () -> m () -> m ()
+catchSomeExceptionAnd what m = catch m $ \(_ :: SomeException) -> what
+
+withJust :: Monad m => m (Maybe a) -> (a -> m ()) -> m ()
+withJust mma func = do
+ ma <- mma
+ case ma of
+ Nothing -> return ()
+ Just a -> func a
+
+zeroMemory :: Ptr a -> CSize -> IO ()
+zeroMemory dest nbytes = memset dest 0 (fromIntegral nbytes)
+
+copyMemory :: Ptr a -> Ptr b -> CSize -> IO ()
+copyMemory dest src nbytes = memcpy dest src nbytes
+
+foreign import ccall unsafe "string.h" memset :: Ptr a -> CInt -> CSize -> IO ()
+foreign import ccall unsafe "string.h" memcpy :: Ptr a -> Ptr b -> CSize -> IO ()