aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/Foreign/UAPI.hsc
blob: 7aff338c14bee30fa0186871fb39d9247ab898fa (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-missing-signatures #-}

module Network.WireGuard.Foreign.UAPI
  ( WgIpmask(..)
  , PeerFlags
  , peerFlagRemoveMe
  , peerFlagReplaceIpmasks
  , WgPeer(..)
  , DeviceFlags
  , deviceFlagReplacePeers
  , deviceFlagRemovePrivateKey
  , deviceFlagRemovePresharedKey
  , deviceFlagRemoveFwmark
  , WgDevice(..)
  , readConfig
  , writeConfig
  ) where

import           Control.Monad                     (unless)
import           Data.ByteString.Internal          (ByteString (..))
import           Network.Socket.Internal           (HostAddress, HostAddress6,
                                                    SockAddr, peekSockAddr,
                                                    pokeSockAddr)
import           System.IO.Unsafe                  (unsafePerformIO)

import           Data.Char
import           Data.Int
import           Data.Word
import           Foreign
import           Foreign.C.String

import qualified Network.WireGuard.Foreign.Key     as K
import           Network.WireGuard.Internal.Util   (zeroMemory)

import           Network.WireGuard.Foreign.In6Addr

#include "uapi.h"

-- | Haskell counterpart of @struct wgipmask@ from @uapi.h@: an IP address
-- together with the struct's @cidr@ value (presumably a prefix length —
-- confirm against @uapi.h@).
data WgIpmask = WgIpmask
              { ipmaskIp   :: ! (Either HostAddress HostAddress6)
                -- ^ 'Left' for an @AF_INET@ address, 'Right' for @AF_INET6@.
              , ipmaskCidr :: ! #{type typeof((struct wgipmask){0}.cidr)}
                -- ^ Raw @cidr@ field, using the C field's own type.
              }

-- | Type of the @flags@ field of @struct wgpeer@.
type PeerFlags = #{type typeof((struct wgpeer){0}.flags)}

-- | The @WGPEER_REMOVE_ME@ flag value, for use in 'peerFlags'.
peerFlagRemoveMe :: PeerFlags
peerFlagRemoveMe = #{const WGPEER_REMOVE_ME}

-- | The @WGPEER_REPLACE_IPMASKS@ flag value, for use in 'peerFlags'.
peerFlagReplaceIpmasks :: PeerFlags
peerFlagReplaceIpmasks = #{const WGPEER_REPLACE_IPMASKS}

-- | Haskell counterpart of @struct wgpeer@ from @uapi.h@.
data WgPeer = WgPeer
            { peerPubKey            :: ! ByteString  -- ^ @public_key@ field. TODO: use Bytes
            , peerFlags             :: ! PeerFlags
              -- ^ @flags@; see 'peerFlagRemoveMe' and 'peerFlagReplaceIpmasks'.
            , peerAddr              :: ! (Maybe SockAddr)
              -- ^ @endpoint@ socket address; 'Nothing' when its @sa_family@ is 0.
            , peerLastHandshakeTime :: ! #{type typeof((struct wgpeer){0}.last_handshake_time.tv_sec)}
              -- ^ Only the @tv_sec@ part of @last_handshake_time@ is marshalled.
            , peerReceivedBytes     :: ! #{type typeof((struct wgpeer){0}.rx_bytes)}
              -- ^ @rx_bytes@ field.
            , peerTransferredBytes  :: ! #{type typeof((struct wgpeer){0}.tx_bytes)}
              -- ^ @tx_bytes@ field.
            , peerKeepaliveInterval :: ! #{type typeof((struct wgpeer){0}.persistent_keepalive_interval)}
              -- ^ @persistent_keepalive_interval@ field.
            , peerNumIpmasks        :: ! #{type typeof((struct wgpeer){0}.num_ipmasks)}
              -- ^ @num_ipmasks@ field; presumably the number of @struct wgipmask@
              -- entries that follow this struct in the UAPI buffer (those
              -- entries are not marshalled here) — confirm against @uapi.h@.
            }

-- | Type of the @flags@ field of @struct wgdevice@.
type DeviceFlags = #{type typeof((struct wgdevice){0}.flags)}

-- | The @WGDEVICE_REPLACE_PEERS@ flag value, for use in 'deviceFlags'.
deviceFlagReplacePeers :: DeviceFlags
deviceFlagReplacePeers = #{const WGDEVICE_REPLACE_PEERS}

-- | The @WGDEVICE_REMOVE_PRIVATE_KEY@ flag value, for use in 'deviceFlags'.
deviceFlagRemovePrivateKey :: DeviceFlags
deviceFlagRemovePrivateKey = #{const WGDEVICE_REMOVE_PRIVATE_KEY}

-- | The @WGDEVICE_REMOVE_PRESHARED_KEY@ flag value, for use in 'deviceFlags'.
deviceFlagRemovePresharedKey :: DeviceFlags
deviceFlagRemovePresharedKey = #{const WGDEVICE_REMOVE_PRESHARED_KEY}

-- | The @WGDEVICE_REMOVE_FWMARK@ flag value, for use in 'deviceFlags'.
deviceFlagRemoveFwmark :: DeviceFlags
deviceFlagRemoveFwmark = #{const WGDEVICE_REMOVE_FWMARK}

-- | Type of the @version_magic@ field of @struct wgdevice@.
type VersionMagicType = #{type typeof((struct wgdevice){0}.version_magic)}

-- | The @WG_API_VERSION_MAGIC@ value. Marshalling a 'WgDevice' writes it into
-- @version_magic@, and unmarshalling rejects any struct with a different value.
apiVersionMagic :: VersionMagicType
apiVersionMagic = #{const WG_API_VERSION_MAGIC }

-- | Haskell counterpart of @struct wgdevice@ from @uapi.h@.
data WgDevice = WgDevice
              { deviceInterface :: ! String
                -- ^ @interface@ name; marshalling rejects names of IFNAMSIZ
                -- characters or more (room is kept for the NUL terminator).
              , deviceFlags     :: ! DeviceFlags
                -- ^ @flags@; see the @deviceFlag*@ constants.
              , devicePubkey    :: ! ByteString  -- ^ @public_key@ field. TODO: use Bytes
              , devicePrivkey   :: ! ByteString  -- ^ @private_key@ field. TODO: use ScrubbedBytes
              , devicePSK       :: ! ByteString  -- ^ @preshared_key@ field. TODO: use ScrubbedBytes
              , deviceFwmark    :: ! #{type typeof((struct wgdevice){0}.fwmark)}
                -- ^ @fwmark@ field.
              , devicePort      :: ! #{type typeof((struct wgdevice){0}.port)}
                -- ^ @port@ field.
              , deviceNumPeers  :: ! #{type typeof((struct wgdevice){0}.num_peers)}
                -- ^ @num_peers@ field; presumably the number of @struct wgpeer@
                -- entries that follow in the UAPI buffer — confirm against @uapi.h@.
              }

-- | Type of the @family@ field of @struct wgipmask@ (compared against
-- @AF_INET@ / @AF_INET6@ when marshalling 'WgIpmask').
type IpmaskIpFamilyType = #{type typeof((struct wgipmask){0}.family)}

-- | Marshals 'WgIpmask' to and from @struct wgipmask@.
--
-- The @family@ field selects the address: @AF_INET@ maps to 'Left' and
-- @AF_INET6@ to 'Right'; any other family makes 'peek' fail.
instance Storable WgIpmask where
    sizeOf    _ = #{size      struct wgipmask}
    alignment _ = #{alignment struct wgipmask}

    peek ptr = do
        addrFamily <- #{peek struct wgipmask, family} ptr :: IO IpmaskIpFamilyType
        addr       <- peekAddr addrFamily
        WgIpmask addr <$> #{peek struct wgipmask, cidr} ptr
      where
        -- Read the address member that the family tag says is populated.
        peekAddr #{const AF_INET}  = Left <$> #{peek struct wgipmask, ip4.s_addr} ptr
        peekAddr #{const AF_INET6} = Right . fromIn6Addr <$> #{peek struct wgipmask, ip6} ptr
        peekAddr _                 = fail "WgIpmask.peek: unknown ipfamily"

    poke ptr self@WgIpmask{..} = do
        -- Start from an all-zero struct so unused union bytes are cleared.
        zeroMemory ptr (fromIntegral (sizeOf self))
        either pokeV4 pokeV6 ipmaskIp
        #{poke struct wgipmask, cidr} ptr ipmaskCidr
      where
        pokeFamily :: IpmaskIpFamilyType -> IO ()
        pokeFamily = #{poke struct wgipmask, family} ptr

        pokeV4 addr4 = do
            pokeFamily #{const AF_INET}
            #{poke struct wgipmask, ip4.s_addr} ptr addr4

        pokeV6 addr6 = do
            pokeFamily #{const AF_INET6}
            #{poke struct wgipmask, ip6} ptr (In6Addr addr6)

-- | C type of the @sa_family@ member of a socket address.
type IpFamilyType = #{type sa_family_t}

-- | Byte offset of the endpoint socket address within @struct wgpeer@.
sockaddrOffset :: Int
sockaddrOffset = #{offset struct wgpeer, endpoint.addr}

-- | Byte offset of the endpoint's @sa_family@ member within @struct wgpeer@.
ipfamilyOffset :: Int
ipfamilyOffset = #{offset struct wgpeer, endpoint.addr.sa_family}

-- | Marshals 'WgPeer' to and from @struct wgpeer@.
--
-- An endpoint whose @sa_family@ is 0 stands for "no endpoint": 'peek' yields
-- 'Nothing' for it and 'poke' writes family 0 for 'Nothing'.
instance Storable WgPeer where
    sizeOf    _ = #{size      struct wgpeer}
    alignment _ = #{alignment struct wgpeer}

    peek ptr = do
        addrFamily <- peek (ptr `plusPtr` ipfamilyOffset) :: IO IpFamilyType
        pubKey     <- K.toByteString <$> #{peek struct wgpeer, public_key} ptr
        flags      <- #{peek struct wgpeer, flags} ptr
        endpoint   <- if addrFamily == 0
                          then pure Nothing
                          else Just <$> peekSockAddr (ptr `plusPtr` sockaddrOffset)
        handshake  <- #{peek struct wgpeer, last_handshake_time.tv_sec} ptr
        rxBytes    <- #{peek struct wgpeer, rx_bytes} ptr
        txBytes    <- #{peek struct wgpeer, tx_bytes} ptr
        keepalive  <- #{peek struct wgpeer, persistent_keepalive_interval} ptr
        nIpmasks   <- #{peek struct wgpeer, num_ipmasks} ptr
        pure WgPeer
            { peerPubKey            = pubKey
            , peerFlags             = flags
            , peerAddr              = endpoint
            , peerLastHandshakeTime = handshake
            , peerReceivedBytes     = rxBytes
            , peerTransferredBytes  = txBytes
            , peerKeepaliveInterval = keepalive
            , peerNumIpmasks        = nIpmasks
            }

    poke ptr self@WgPeer{..} = do
        zeroMemory ptr (fromIntegral (sizeOf self))
        #{poke struct wgpeer, public_key} ptr (K.fromByteString peerPubKey)
        #{poke struct wgpeer, flags} ptr peerFlags
        maybe clearFamily (pokeSockAddr (ptr `plusPtr` sockaddrOffset)) peerAddr
        #{poke struct wgpeer, last_handshake_time.tv_sec} ptr peerLastHandshakeTime
        #{poke struct wgpeer, rx_bytes} ptr peerReceivedBytes
        #{poke struct wgpeer, tx_bytes} ptr peerTransferredBytes
        #{poke struct wgpeer, persistent_keepalive_interval} ptr peerKeepaliveInterval
        #{poke struct wgpeer, num_ipmasks} ptr peerNumIpmasks
      where
        -- Explicitly mark "no endpoint" (zeroMemory presumably cleared it
        -- already; kept so the intent does not depend on that).
        clearFamily = poke (ptr `plusPtr` ipfamilyOffset) (0 :: IpFamilyType)

-- | Marshals 'WgDevice' to and from @struct wgdevice@.
--
-- The interface name is handled as raw bytes, one 'Char' per byte (Latin-1),
-- in both directions. A name read by 'peek' is therefore written back
-- byte-for-byte by 'poke', whatever encoding the bytes were in. Previously
-- 'peek' decoded with the locale encoding while 'poke' wrote @ord c@
-- truncated to a byte, so non-ASCII names did not round-trip.
instance Storable WgDevice where
    sizeOf _              = #{size struct wgdevice}
    alignment _           = #{alignment struct wgdevice}
    peek ptr              = do
        magic <- #{peek struct wgdevice, version_magic} ptr
        unless (magic == apiVersionMagic) $ fail "unexpected version_magic"
        -- Read the fixed-size name field and stop at the first NUL. Reading is
        -- bounded by IFNAMSIZ, so a missing terminator cannot run past the
        -- field. Bytes map straight to Chars (no locale decoding) to mirror poke.
        ifname <- peekArray #{const IFNAMSIZ} (ptr `plusPtr` #{offset struct wgdevice, interface}) :: IO [Word8]
        WgDevice (map (chr . fromIntegral) (takeWhile (/= 0) ifname))
                 <$> #{peek struct wgdevice, flags} ptr
                 <*> (K.toByteString <$> #{peek struct wgdevice, public_key} ptr)
                 <*> (K.toByteString <$> #{peek struct wgdevice, private_key} ptr)
                 <*> (K.toByteString <$> #{peek struct wgdevice, preshared_key} ptr)
                 <*> #{peek struct wgdevice, fwmark} ptr
                 <*> #{peek struct wgdevice, port} ptr
                 <*> #{peek struct wgdevice, num_peers} ptr
    poke ptr self@WgDevice{..}
        -- One byte per Char plus the NUL terminator must fit in IFNAMSIZ.
        | length deviceInterface >= #{const IFNAMSIZ} = fail "interface name is too long"
        -- Chars above U+00FF would be silently truncated, and an embedded NUL
        -- would cut the name short; either way another interface would be
        -- addressed, so refuse instead.
        | any unencodable deviceInterface             = fail "interface name contains unsupported characters"
        | otherwise                                   = do
            zeroMemory ptr $ fromIntegral $ sizeOf self
            #{poke struct wgdevice, version_magic} ptr apiVersionMagic
            pokeArray0 (0 :: Word8) (ptr `plusPtr` #{offset struct wgdevice, interface}) (map (fromIntegral.ord) deviceInterface)
            #{poke struct wgdevice, flags} ptr deviceFlags
            #{poke struct wgdevice, public_key} ptr (K.fromByteString devicePubkey)
            #{poke struct wgdevice, private_key} ptr (K.fromByteString devicePrivkey)
            #{poke struct wgdevice, preshared_key} ptr (K.fromByteString devicePSK)
            #{poke struct wgdevice, fwmark} ptr deviceFwmark
            #{poke struct wgdevice, port} ptr devicePort
            #{poke struct wgdevice, num_peers} ptr deviceNumPeers
      where
        unencodable c = c == '\0' || ord c > 0xFF

-- | Decode a UAPI struct from its raw byte representation.
--
-- The input must be exactly @'sizeOf'@ the target type long. Any other length
-- is treated as a programming error and raises 'error'. Decoding runs under
-- 'unsafePerformIO', so exceptions from the instance's 'peek' (e.g. an
-- unexpected @version_magic@) surface when the result is forced.
--
-- The bytes are first copied into a fresh buffer allocated with the target
-- type's alignment. A 'ByteString' slice may begin at any offset, and reading
-- multi-byte fields through a misaligned pointer is not portable.
readConfig :: Storable a => ByteString -> a
readConfig (PS fptr off len)
    | len == sizeOf output = output
    | otherwise            = error "UAPI.readConfig: length mismatch"
  where
    -- Only demanded when the guard above has confirmed len == sizeOf output,
    -- so the copy never overruns the alloca'd buffer.
    output = unsafePerformIO $
        withForeignPtr fptr $ \src ->
            alloca $ \buf -> do
                copyBytes (castPtr buf) (src `plusPtr` off) len
                peek buf

-- | Encode a value as the raw bytes of its C struct.
--
-- The result is exactly @'sizeOf' input@ bytes long. It is backed by a buffer
-- from 'mallocForeignPtr', which is sized and aligned for the type. Errors
-- raised by the instance's 'poke' surface when the result is forced, because
-- encoding runs under 'unsafePerformIO'.
writeConfig :: Storable a => a -> ByteString
writeConfig input = unsafePerformIO (encode input)
  where
    encode value = do
        buf <- mallocForeignPtr
        withForeignPtr buf (`poke` value)
        pure (PS (castForeignPtr buf) 0 (sizeOf value))