diff options
Diffstat (limited to 'src/Network/WireGuard/Internal/IPPacket.hs')
-rw-r--r-- | src/Network/WireGuard/Internal/IPPacket.hs | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/src/Network/WireGuard/Internal/IPPacket.hs b/src/Network/WireGuard/Internal/IPPacket.hs new file mode 100644 index 0000000..56f4461 --- /dev/null +++ b/src/Network/WireGuard/Internal/IPPacket.hs @@ -0,0 +1,56 @@ +module Network.WireGuard.Internal.IPPacket + ( IPPacket(..) + , parseIPPacket + ) where + +import qualified Data.ByteArray as BA +import Data.IP (IPv4, IPv6, fromHostAddress, + fromHostAddress6) +import Foreign.Ptr (Ptr) +import Foreign.Storable (peekByteOff) + +import Data.Bits +import Data.Word + +data IPPacket = InvalidIPPacket + | IPv4Packet { src4 :: IPv4, dest4 :: IPv4 } + | IPv6Packet { src6 :: IPv6, dest6 :: IPv6 } + +parseIPPacket :: BA.ByteArrayAccess ba => ba -> IO IPPacket +parseIPPacket packet | BA.length packet < 20 = return InvalidIPPacket +parseIPPacket packet = BA.withByteArray packet $ \ptr -> do + firstByte <- peekByteOff ptr 0 :: IO Word8 + let version = firstByte `shiftR` 4 + parse4 = do + s4 <- peekByteOff ptr 12 + d4 <- peekByteOff ptr 16 + return (IPv4Packet (fromHostAddress s4) (fromHostAddress d4)) + parse6 + | BA.length packet < 40 = return InvalidIPPacket + | otherwise = do + s6a <- peek32be ptr 8 + s6b <- peek32be ptr 12 + s6c <- peek32be ptr 16 + s6d <- peek32be ptr 20 + d6a <- peek32be ptr 24 + d6b <- peek32be ptr 28 + d6c <- peek32be ptr 32 + d6d <- peek32be ptr 36 + let s6 = (s6a, s6b, s6c, s6d) + d6 = (d6a, d6b, d6c, d6d) + return (IPv6Packet (fromHostAddress6 s6) (fromHostAddress6 d6)) + case version of + 4 -> parse4 + 6 -> parse6 + _ -> return InvalidIPPacket + +peek32be :: Ptr a -> Int -> IO Word32 +peek32be ptr offset = do + a <- peekByteOff ptr offset :: IO Word8 + b <- peekByteOff ptr (offset + 1) :: IO Word8 + c <- peekByteOff ptr (offset + 2) :: IO Word8 + d <- peekByteOff ptr (offset + 3) :: IO Word8 + return $! (fromIntegral a `unsafeShiftL` 24) .|. + (fromIntegral b `unsafeShiftL` 16) .|. + (fromIntegral c `unsafeShiftL` 8) .|. + fromIntegral d |