-- Source: git blob 56f4461097afeb92083fef52b9dae913e28524bb
module Network.WireGuard.Internal.IPPacket
( IPPacket(..)
, parseIPPacket
) where
import qualified Data.ByteArray as BA
import Data.IP (IPv4, IPv6, fromHostAddress,
fromHostAddress6)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peekByteOff)
import Data.Bits
import Data.Word
-- | Addressing information pulled out of the header of a raw IP packet
-- by 'parseIPPacket'.
data IPPacket
  = -- | The buffer was too short for its version, or the version nibble
    -- was neither 4 nor 6.
    InvalidIPPacket
  | -- | An IPv4 packet, identified by its two header addresses.
    IPv4Packet
      { src4 :: IPv4
        -- ^ source address (header bytes 12-15)
      , dest4 :: IPv4
        -- ^ destination address (header bytes 16-19)
      }
  | -- | An IPv6 packet, identified by its two header addresses.
    IPv6Packet
      { src6 :: IPv6
        -- ^ source address (header bytes 8-23)
      , dest6 :: IPv6
        -- ^ destination address (header bytes 24-39)
      }
-- | Extract the source and destination addresses from a raw IP packet.
--
-- The IP version is the high nibble of the first byte. The result is
-- 'InvalidIPPacket' for:
--
--   * buffers shorter than 20 bytes;
--   * version-6 buffers shorter than 40 bytes;
--   * any version other than 4 or 6.
--
-- Only the address fields are read. Header length, checksum and payload
-- length are not inspected.
parseIPPacket :: BA.ByteArrayAccess ba => ba -> IO IPPacket
parseIPPacket packet
  | len < 20 = return InvalidIPPacket
  | otherwise = BA.withByteArray packet $ \ptr -> do
      versionByte <- peekByteOff ptr 0 :: IO Word8
      case versionByte `shiftR` 4 of
        4 -> readV4 ptr
        6 | len >= 40 -> readV6 ptr
          | otherwise -> return InvalidIPPacket
        _ -> return InvalidIPPacket
  where
    len = BA.length packet

    -- IPv4: source at bytes 12-15, destination at bytes 16-19.
    -- The word is passed to 'fromHostAddress' exactly as it sits in memory,
    -- with no byte swap.
    -- NOTE(review): this is a native 32-bit load. Foreign.Storable warns
    -- that peeks may need aligned addresses on some architectures; confirm
    -- that callers' buffers are 4-byte aligned.
    readV4 p = do
      srcWord <- peekByteOff p 12
      dstWord <- peekByteOff p 16
      return $ IPv4Packet (fromHostAddress srcWord) (fromHostAddress dstWord)

    -- IPv6: source at bytes 8-23, destination at bytes 24-39.
    -- Each address is decoded as four big-endian words for
    -- 'fromHostAddress6'.
    readV6 p = do
      srcWords <- quad p 8
      dstWords <- quad p 24
      return $ IPv6Packet (fromHostAddress6 srcWords) (fromHostAddress6 dstWords)

    -- Four consecutive big-endian 32-bit words starting at byte offset @off@.
    quad p off =
      (,,,) <$> peek32be p off
            <*> peek32be p (off + 4)
            <*> peek32be p (off + 8)
            <*> peek32be p (off + 12)
-- | Read a big-endian (network byte order) 32-bit word that starts
-- @offset@ bytes past @ptr@.
--
-- The four bytes are fetched one at a time as 'Word8's, most significant
-- first. The result therefore does not depend on host endianness, and no
-- 32-bit load is issued at a possibly unaligned address. The returned word
-- is already evaluated.
peek32be :: Ptr a -> Int -> IO Word32
peek32be ptr offset = go 0 0
  where
    -- Shift the accumulator left one byte and or in the next byte, for
    -- byte indices 0..3. The accumulator is forced at each step so no
    -- thunk chain is built.
    go :: Int -> Word32 -> IO Word32
    go i acc
      | i == 4 = return acc
      | otherwise = do
          byte <- peekByteOff ptr (offset + i) :: IO Word8
          let acc' = (acc `unsafeShiftL` 8) .|. fromIntegral byte
          acc' `seq` go (i + 1) acc'
|