diff options
author | Bin Jin <bjin@ctrl-d.org> | 2017-03-14 00:44:34 +0800 |
---|---|---|
committer | Bin Jin <bjin@ctrl-d.org> | 2017-03-14 00:44:34 +0800 |
commit | 8bc3504ebf0161f5553f924c3cac6445d46e5728 (patch) | |
tree | 8148814db4eda90dfd8ee5cdaee3af75a65b2170 /src/Network/WireGuard | |
parent | Use dhPubEq for publickey equality check (diff) | |
download | wireguard-hs-8bc3504ebf0161f5553f924c3cac6445d46e5728.tar.xz wireguard-hs-8bc3504ebf0161f5553f924c3cac6445d46e5728.zip |
Tun: use non-blocking fd
Diffstat (limited to 'src/Network/WireGuard')
-rw-r--r-- | src/Network/WireGuard/Foreign/Tun.hs | 55 | ||||
-rw-r--r-- | src/Network/WireGuard/TunListener.hs | 18 |
2 files changed, 34 insertions, 39 deletions
diff --git a/src/Network/WireGuard/Foreign/Tun.hs b/src/Network/WireGuard/Foreign/Tun.hs index c2b3a46..2d0f929 100644 --- a/src/Network/WireGuard/Foreign/Tun.hs +++ b/src/Network/WireGuard/Foreign/Tun.hs @@ -2,44 +2,51 @@ module Network.WireGuard.Foreign.Tun ( openTun - , fdReadBuf - , fdWriteBuf + , tunReadBuf + , tunWriteBuf ) where -import System.Posix.Types (Fd (..)) +import Control.Concurrent (threadWaitRead, threadWaitWrite) +import Control.Monad (forM_) +import System.Posix.Internals (setNonBlockingFD) +import System.Posix.Types (Fd (..)) import Foreign import Foreign.C -#ifdef OS_LINUX -import System.Posix.IO (fdReadBuf, fdWriteBuf) -#endif - openTun :: String -> Int -> IO (Maybe [Fd]) openTun intfName threads = withCString intfName $ \intf_name_c -> allocaArray threads $ \fds_c -> do res <- tun_alloc_c intf_name_c (fromIntegral threads) fds_c -- TODO: handle exception if res > 0 - then Just . map Fd <$> peekArray (fromIntegral res) fds_c - else return Nothing - -foreign import ccall safe "tun.h tun_alloc" tun_alloc_c :: CString -> CInt -> Ptr CInt -> IO CInt - -#ifdef OS_MACOS -fdReadBuf :: Fd -> Ptr Word8 -> CSize -> IO CSize -fdReadBuf _fd _buf 0 = return 0 -fdReadBuf fd buf nbytes = + then do + fds <- peekArray (fromIntegral res) fds_c + forM_ fds $ \fd -> setNonBlockingFD fd True + return (Just (map Fd fds)) + else return Nothing + +tunReadBuf :: Fd -> Ptr Word8 -> CSize -> IO CSize +tunReadBuf _fd _buf 0 = return 0 +tunReadBuf fd buf nbytes = fmap fromIntegral $ - throwErrnoIfMinus1Retry "fdReadBuf" $ - utun_read_c (fromIntegral fd) (castPtr buf) nbytes + throwErrnoIfMinus1RetryMayBlock "tunReadBuf" + (tun_read_c (fromIntegral fd) (castPtr buf) nbytes) + (threadWaitRead fd) -fdWriteBuf :: Fd -> Ptr Word8 -> CSize -> IO CSize -fdWriteBuf fd buf len = +tunWriteBuf :: Fd -> Ptr Word8 -> CSize -> IO CSize +tunWriteBuf fd buf len = fmap fromIntegral $ - throwErrnoIfMinus1Retry "fdWriteBuf" $ - utun_write_c (fromIntegral fd) (castPtr buf) len + throwErrnoIfMinus1RetryMayBlock "tunWriteBuf" + (tun_write_c (fromIntegral fd) (castPtr buf) len) + (threadWaitWrite fd) + +foreign import ccall unsafe "tun.h tun_alloc" tun_alloc_c :: CString -> CInt -> Ptr CInt -> IO CInt -foreign import ccall safe "tun.h utun_read" utun_read_c :: CInt -> Ptr CChar -> CSize -> IO CSize -foreign import ccall safe "tun.h utun_write" utun_write_c :: CInt -> Ptr CChar -> CSize -> IO CSize +#ifdef OS_MACOS +foreign import ccall unsafe "tun.h utun_read" tun_read_c :: CInt -> Ptr CChar -> CSize -> IO CSize +foreign import ccall unsafe "tun.h utun_write" tun_write_c :: CInt -> Ptr CChar -> CSize -> IO CSize +#else +foreign import ccall unsafe "read" tun_read_c :: CInt -> Ptr CChar -> CSize -> IO CSize +foreign import ccall unsafe "write" tun_write_c :: CInt -> Ptr CChar -> CSize -> IO CSize #endif diff --git a/src/Network/WireGuard/TunListener.hs b/src/Network/WireGuard/TunListener.hs index d57cf98..46c290f 100644 --- a/src/Network/WireGuard/TunListener.hs +++ b/src/Network/WireGuard/TunListener.hs @@ -14,18 +14,12 @@ import Foreign.Marshal.Alloc (allocaBytes) import Foreign.Ptr (Ptr) import System.Posix.Types (Fd) -import Network.WireGuard.Foreign.Tun (fdReadBuf, fdWriteBuf) - +import Network.WireGuard.Foreign.Tun import Network.WireGuard.Internal.Constant import Network.WireGuard.Internal.PacketQueue import Network.WireGuard.Internal.Types import Network.WireGuard.Internal.Util -#ifdef OS_LINUX -import Control.Concurrent (threadWaitRead, - threadWaitWrite) -#endif - runTunListener :: [Fd] -> PacketQueue TunPacket -> PacketQueue TunPacket -> IO () runTunListener fds readTunChan writeTunChan = loop fds [] where @@ -45,16 +39,10 @@ handleWrite writeTunChan fd = readFd :: BA.ByteArray ba => Ptr Word8 -> Fd -> IO ba readFd buf fd = do -#ifdef OS_LINUX - threadWaitRead fd -#endif - nbytes <- fdReadBuf fd buf (fromIntegral tunReadBufferLength) + nbytes <- tunReadBuf fd buf (fromIntegral tunReadBufferLength) snd <$> BA.allocRet (fromIntegral nbytes) (\ptr -> copyMemory ptr buf nbytes >> zeroMemory buf nbytes) writeFd :: BA.ByteArrayAccess ba => Fd -> ba -> IO () writeFd fd ba = BA.withByteArray ba $ \ptr -> do -#ifdef OS_LINUX - threadWaitWrite fd -#endif - void $ fdWriteBuf fd ptr (fromIntegral (BA.length ba)) + void $ tunWriteBuf fd ptr (fromIntegral (BA.length ba)) |