aboutsummaryrefslogtreecommitdiffstats
path: root/src/Network/WireGuard/RPC.hs
blob: 3aa4e4400383fd6d695df49e3294050716ec6512 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
{-# LANGUAGE RecordWildCards #-}

module Network.WireGuard.RPC
  ( runRPC,
    serveConduit
  ) where

import           Control.Concurrent.STM              (STM, atomically,
                                                      modifyTVar', readTVar,
                                                      writeTVar)
import           Control.Monad                       (replicateM, sequence,
                                                      when)
import           Control.Monad.IO.Class              (liftIO)
import qualified Crypto.Noise.DH                     as DH
import qualified Data.ByteArray                      as BA
import qualified Data.ByteString                     as BS
import qualified Data.Conduit.Binary                 as CB
import           Data.Conduit.Network.Unix           (appSink, appSource,
                                                      runUnixServer,
                                                      serverSettings)
import qualified Data.HashMap.Strict                 as HM
import           Data.Int                            (Int32)
import           Data.List                           (genericLength)
import           Foreign.C.Types                     (CTime (..))

import           Data.Bits
import           Data.Conduit
import           Data.IP
import           Data.Maybe

import           Network.WireGuard.Foreign.UAPI
import           Network.WireGuard.Internal.Constant
import           Network.WireGuard.Internal.State
import           Network.WireGuard.Internal.Types
import           Network.WireGuard.Internal.Util     (catchIOExceptionAnd,
                                                      catchSomeExceptionAnd)

import Debug.Trace
-- | Run RPC service over a unix socket
runRPC :: FilePath -> Device -> IO ()
runRPC sockPath device = runUnixServer (serverSettings sockPath) $ \app ->
    catchIOExceptionAnd (return ()) $ runConduit (appSource app .| serveConduit device .| appSink app)
    
-- TODO: ensure that all bytestring over sockets will be erased
serveConduit :: Device -> ConduitM BS.ByteString BS.ByteString IO ()
serveConduit device = do
        h <- CB.head
        traceM $  "Received " ++ show h
        case h of
            Just 0 -> showDevice device
            Just byte -> do
                leftover (BS.singleton byte)
                mWgdev <- CB.sinkStorable
                case mWgdev of
                    Just wgdev -> catchSomeExceptionAnd returnError (updateDevice wgdev)
                    Nothing    -> mempty
            Nothing   -> mempty
  where
    returnError = yield $ writeConfig (-invalidValueError)

    showDevice Device{..} = do
        (wgdevice, peers') <- liftIO buildWgDevice
        yield (writeConfig wgdevice)
        mapM_ showPeer peers'
      where
        buildWgDevice = atomically $ do
            localKey' <- readTVar localKey
            let (pub, priv) = case localKey' of
                    Nothing          -> (emptyKey, emptyKey)
                    Just (sec, pub') -> (pubToBytes pub', privToBytes sec)
            psk' <- fmap pskToBytes <$> readTVar presharedKey
            fwmark' <- fromIntegral <$> readTVar fwmark
            port' <- fromIntegral <$> readTVar port
            peers' <- readTVar peers
            return (WgDevice intfName 0 pub priv (fromMaybe emptyKey psk')
                fwmark' port' (fromIntegral $ HM.size peers'), peers')

    showPeer Peer{..} = do
        (wgpeer, ipmasks') <- liftIO buildWgPeer
        yield (writeConfig wgpeer)
        yield $ BS.concat (map (writeConfig . ipRangeToWgIpmask) ipmasks')
      where
        extractTime Nothing          = 0
        extractTime (Just (CTime t)) = fromIntegral t

        buildWgPeer = atomically $ do
            ipmasks' <- readTVar ipmasks
            wgpeer <- WgPeer (pubToBytes remotePub)
                          <$> return 0
                          <*> readTVar endPoint
                          <*> (extractTime <$> readTVar lastHandshakeTime)
                          <*> (fromIntegral <$> readTVar receivedBytes)
                          <*> (fromIntegral <$> readTVar transferredBytes)
                          <*> (fromIntegral <$> readTVar keepaliveInterval)
                          <*> return (genericLength ipmasks')
            return (wgpeer, ipmasks')

    updateDevice wgdevice = do
        setPeerMs <- replicateM (fromIntegral $ deviceNumPeers wgdevice) $ do
            Just wgpeer <- CB.sinkStorable
            -- TODO: replace fromJust
            ipranges <- replicateM (fromIntegral $ peerNumIpmasks wgpeer)
                (wgIpmaskToIpRange . fromJust <$> CB.sinkStorable)
            return $ setPeer device wgpeer ipranges
        liftIO $ atomically $ do
            setDevice device wgdevice
            anyIpMaskChanged <- or <$> sequence setPeerMs
            -- TODO: modify routetable incrementally
            when anyIpMaskChanged $ buildRouteTables device
        yield $ writeConfig (0 :: Int32)

-- | implementation of config.c::set_peer()
setPeer :: Device -> WgPeer -> [IPRange] -> STM Bool
setPeer Device{..} WgPeer{..} ipranges
    | peerPubKey == emptyKey              = return False
    | testFlag peerFlags peerFlagRemoveMe = modifyTVar' peers (HM.delete peerPubKey) >> return False
    | otherwise                           = do
        peers' <- readTVar peers
        Peer{..} <- case HM.lookup peerPubKey peers' of
            Nothing -> do
                newPeer <- createPeer (fromJust $ bytesToPub peerPubKey) -- TODO: replace fromJust
                modifyTVar' peers (HM.insert peerPubKey newPeer)
                return newPeer
            Just p  -> return p
        when (isJust peerAddr) $ writeTVar endPoint peerAddr
        let replaceIpmasks = testFlag peerFlags peerFlagReplaceIpmasks
            changeIpmasks = replaceIpmasks || not (null ipranges)
        when changeIpmasks $
            if replaceIpmasks
              then writeTVar ipmasks ipranges
              else modifyTVar' ipmasks (++ipranges)
        when (peerKeepaliveInterval /= complement 0) $
            writeTVar keepaliveInterval (fromIntegral peerKeepaliveInterval)
        return changeIpmasks

-- | implementation of config.c::config_set_device()
setDevice :: Device -> WgDevice -> STM ()
setDevice device@Device{..} WgDevice{..} = do
    when (deviceFwmark /= 0 || deviceFwmark == 0 && testFlag deviceFlags deviceFlagRemoveFwmark) $
        writeTVar fwmark (fromIntegral deviceFwmark)
    when (devicePort /= 0) $ writeTVar port (fromIntegral devicePort)
    when (testFlag deviceFlags deviceFlagReplacePeers) $ writeTVar peers HM.empty

    let removeLocalKey = testFlag deviceFlags deviceFlagRemovePrivateKey
        changeLocalKey = removeLocalKey || devicePrivkey /= emptyKey
        changeLocalKeyTo = if removeLocalKey then Nothing else bytesToPair devicePrivkey
    when changeLocalKey $ writeTVar localKey changeLocalKeyTo

    let removePSK = testFlag deviceFlags deviceFlagRemovePresharedKey
        changePSK = removePSK || devicePSK /= emptyKey
        changePSKTo = if removePSK then Nothing else Just (bytesToPSK devicePSK)
    when changePSK $ writeTVar presharedKey changePSKTo

    when (changeLocalKey || changePSK) $ invalidateSessions device

ipRangeToWgIpmask :: IPRange -> WgIpmask
ipRangeToWgIpmask (IPv4Range ipv4range) = case addrRangePair ipv4range of
    (ipv4, prefix) -> WgIpmask (Left (toHostAddress ipv4)) (fromIntegral prefix)
ipRangeToWgIpmask (IPv6Range ipv6range) = case addrRangePair ipv6range of
    (ipv6, prefix) -> WgIpmask (Right (toHostAddress6 ipv6)) (fromIntegral prefix)

wgIpmaskToIpRange :: WgIpmask -> IPRange
wgIpmaskToIpRange (WgIpmask ip cidr) = case ip of
    Left ipv4  -> IPv4Range $ makeAddrRange (fromHostAddress ipv4) (fromIntegral cidr)
    Right ipv6 -> IPv6Range $ makeAddrRange (fromHostAddress6 ipv6) (fromIntegral cidr)

invalidValueError :: Int32
invalidValueError = 22  -- TODO: report back actual error

emptyKey :: BS.ByteString
emptyKey = BS.replicate keyLength 0

testFlag :: Bits a => a -> a -> Bool
testFlag a flag = (a .&. flag) /= zeroBits

pubToBytes :: PublicKey -> BS.ByteString
pubToBytes = BA.convert . DH.dhPubToBytes

privToBytes :: PrivateKey -> BS.ByteString
privToBytes = BA.convert . DH.dhSecToBytes

pskToBytes :: PresharedKey -> BS.ByteString
pskToBytes = BA.convert

bytesToPair :: BS.ByteString -> Maybe KeyPair
bytesToPair = DH.dhBytesToPair . BA.convert

bytesToPub :: BS.ByteString -> Maybe PublicKey
bytesToPub = DH.dhBytesToPub . BA.convert

bytesToPSK :: BS.ByteString -> PresharedKey
bytesToPSK = BA.convert