aboutsummaryrefslogtreecommitdiffstats
path: root/src/receive.go
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-08-07 15:25:04 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-08-07 15:25:04 +0200
commitcba1d6585ab9b12ae3e0897db85675ba452c3f09 (patch)
tree13d0975bf53a107c2760c833fd07f36d860a338a /src/receive.go
parentFirst set of code review patches (diff)
downloadwireguard-go-cba1d6585ab9b12ae3e0897db85675ba452c3f09.tar.xz
wireguard-go-cba1d6585ab9b12ae3e0897db85675ba452c3f09.zip
Number of fixes in response to code review
This version cannot complete a handshake. The program will panic upon receiving any message on the UDP socket.
Diffstat (limited to 'src/receive.go')
-rw-r--r--src/receive.go507
1 files changed, 250 insertions, 257 deletions
diff --git a/src/receive.go b/src/receive.go
index fb5c51f..5f46925 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -111,113 +111,84 @@ func (device *Device) RoutineBusyMonitor() {
func (device *Device) RoutineReceiveIncomming() {
- logInfo := device.log.Info
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, started")
- var buffer *[MaxMessageSize]byte
-
for {
- // check if stopped
+ // wait for new conn
+
+ var conn *net.UDPConn
select {
+ case <-device.signal.newUDPConn:
+ device.net.mutex.RLock()
+ conn = device.net.conn
+ device.net.mutex.RUnlock()
+
case <-device.signal.stop:
return
- default:
}
- // read next datagram
-
- if buffer == nil {
- buffer = device.GetMessageBuffer()
- }
-
- // TODO: Take writelock to sleep
- device.net.mutex.RLock()
- conn := device.net.conn
- device.net.mutex.RUnlock()
if conn == nil {
- time.Sleep(time.Second)
continue
}
- // TODO: Wait for new conn or message
- conn.SetReadDeadline(time.Now().Add(time.Second))
+ // receive datagrams until closed
- size, raddr, err := conn.ReadFromUDP(buffer[:])
- if err != nil || size < MinMessageSize {
- continue
- }
+ buffer := device.GetMessageBuffer()
- // handle packet
+ for {
- packet := buffer[:size]
- msgType := binary.LittleEndian.Uint32(packet[:4])
+ // read next datagram
- func() {
- switch msgType {
-
- case MessageInitiationType, MessageResponseType:
-
- // TODO: Check size early
+ size, raddr, err := conn.ReadFromUDP(buffer[:]) // TODO: This is broken
- // add to handshake queue
+ if err != nil {
+ break
+ }
- device.addToHandshakeQueue(
- device.queue.handshake,
- QueueHandshakeElement{
- msgType: msgType,
- buffer: buffer,
- packet: packet,
- source: raddr,
- },
- )
- buffer = nil
+ if size < MinMessageSize {
+ continue
+ }
- case MessageCookieReplyType:
+ // check size of packet
- // TODO: Queue all the things
+ packet := buffer[:size]
+ msgType := binary.LittleEndian.Uint32(packet[:4])
- // verify and update peer cookie state
+ var okay bool
- if len(packet) != MessageCookieReplySize {
- return
- }
+ switch msgType {
- var reply MessageCookieReply
- reader := bytes.NewReader(packet)
- err := binary.Read(reader, binary.LittleEndian, &reply)
- if err != nil {
- logDebug.Println("Failed to decode cookie reply")
- return
- }
- device.ConsumeMessageCookieReply(&reply)
+ // check if transport
case MessageTransportType:
- // lookup key pair
+ // check size
- if len(packet) < MessageTransportSize {
- return
+ if len(packet) < MessageTransportType {
+ continue
}
+ // lookup key pair
+
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indices.Lookup(receiver)
keyPair := value.keyPair
if keyPair == nil {
- return
+ continue
}
// check key-pair expiry
if keyPair.created.Add(RejectAfterTime).Before(time.Now()) {
- return
+ continue
}
- // add to peer queue
+ // create work element
peer := value.peer
elem := &QueueInboundElement{
@@ -233,11 +204,33 @@ func (device *Device) RoutineReceiveIncomming() {
device.addToInboundQueue(device.queue.decryption, elem)
device.addToInboundQueue(peer.queue.inbound, elem)
buffer = nil
+ continue
- default:
- logInfo.Println("Got unknown message from:", raddr)
+ // otherwise it is a handshake related packet
+
+ case MessageInitiationType:
+ okay = len(packet) == MessageInitiationSize
+
+ case MessageResponseType:
+ okay = len(packet) == MessageResponseSize
+
+ case MessageCookieReplyType:
+ okay = len(packet) == MessageCookieReplySize
}
- }()
+
+ if okay {
+ device.addToHandshakeQueue(
+ device.queue.handshake,
+ QueueHandshakeElement{
+ msgType: msgType,
+ buffer: buffer,
+ packet: packet,
+ source: raddr,
+ },
+ )
+ buffer = device.GetMessageBuffer()
+ }
+ }
}
}
@@ -306,154 +299,165 @@ func (device *Device) RoutineHandshake() {
return
}
- func() {
+ // handle cookie fields and ratelimiting
- // verify mac1
+ switch elem.msgType {
- if !device.mac.CheckMAC1(elem.packet) {
- logDebug.Println("Received packet with invalid mac1")
+ case MessageCookieReplyType:
+
+ // verify and update peer cookie state
+
+ var reply MessageCookieReply
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &reply)
+ if err != nil {
+ logDebug.Println("Failed to decode cookie reply")
return
}
+ device.ConsumeMessageCookieReply(&reply)
+ continue
- // verify mac2
+ case MessageInitiationType, MessageResponseType:
- busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
+ // check mac fields and ratelimit
- if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
- sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
- reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
- if err != nil {
- logError.Println("Failed to create cookie reply:", err)
- return
- }
- // TODO: Use temp
- writer := bytes.NewBuffer(elem.packet[:0])
- binary.Write(writer, binary.LittleEndian, reply)
- elem.packet = writer.Bytes()
- _, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
- if err != nil {
- logDebug.Println("Failed to send cookie reply:", err)
- }
+ if !device.mac.CheckMAC1(elem.packet) {
+ logDebug.Println("Received packet with invalid mac1")
return
}
- // ratelimit
-
- // TODO: Only ratelimit when busy
+ busy := atomic.LoadInt32(&device.underLoad) == AtomicTrue
- if !device.ratelimiter.Allow(elem.source.IP) {
- return
+ if busy {
+ if !device.mac.CheckMAC2(elem.packet, elem.source) {
+ sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
+ reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
+ if err != nil {
+ logError.Println("Failed to create cookie reply:", err)
+ return
+ }
+ writer := bytes.NewBuffer(temp[:0])
+ binary.Write(writer, binary.LittleEndian, reply)
+ _, err = device.net.conn.WriteToUDP(
+ writer.Bytes(),
+ elem.source,
+ )
+ if err != nil {
+ logDebug.Println("Failed to send cookie reply:", err)
+ }
+ continue
+ }
+ if !device.ratelimiter.Allow(elem.source.IP) {
+ continue
+ }
}
- // handle messages
+ default:
+ logError.Println("Invalid packet ended up in the handshake queue")
+ continue
+ }
- switch elem.msgType {
- case MessageInitiationType:
+ // handle handshake initation/response content
- // unmarshal
+ switch elem.msgType {
+ case MessageInitiationType:
- if len(elem.packet) != MessageInitiationSize {
- return
- }
+ // unmarshal
- var msg MessageInitiation
- reader := bytes.NewReader(elem.packet)
- err := binary.Read(reader, binary.LittleEndian, &msg)
- if err != nil {
- logError.Println("Failed to decode initiation message")
- return
- }
+ var msg MessageInitiation
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &msg)
+ if err != nil {
+ logError.Println("Failed to decode initiation message")
+ continue
+ }
- // consume initiation
+ // consume initiation
- peer := device.ConsumeMessageInitiation(&msg)
- if peer == nil {
- logInfo.Println(
- "Recieved invalid initiation message from",
- elem.source.IP.String(),
- elem.source.Port,
- )
- return
- }
+ peer := device.ConsumeMessageInitiation(&msg)
+ if peer == nil {
+ logInfo.Println(
+ "Recieved invalid initiation message from",
+ elem.source.IP.String(),
+ elem.source.Port,
+ )
+ continue
+ }
- // update timers
+ // update timers
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
- // update endpoint
- // TODO: Add a race condition \s
+ // update endpoint
+ // TODO: Discover destination address also, only update on change
- peer.mutex.Lock()
- peer.endpoint = elem.source
- peer.mutex.Unlock()
+ peer.mutex.Lock()
+ peer.endpoint = elem.source
+ peer.mutex.Unlock()
- // create response
+ // create response
- response, err := device.CreateMessageResponse(peer)
- if err != nil {
- logError.Println("Failed to create response message:", err)
- return
- }
+ response, err := device.CreateMessageResponse(peer)
+ if err != nil {
+ logError.Println("Failed to create response message:", err)
+ continue
+ }
- peer.TimerEphemeralKeyCreated()
- peer.NewKeyPair()
+ peer.TimerEphemeralKeyCreated()
+ peer.NewKeyPair()
- logDebug.Println("Creating response message for", peer.String())
+ logDebug.Println("Creating response message for", peer.String())
- writer := bytes.NewBuffer(temp[:0])
- binary.Write(writer, binary.LittleEndian, response)
- packet := writer.Bytes()
- peer.mac.AddMacs(packet)
+ writer := bytes.NewBuffer(temp[:0])
+ binary.Write(writer, binary.LittleEndian, response)
+ packet := writer.Bytes()
+ peer.mac.AddMacs(packet)
- // send response
+ // send response
- peer.SendBuffer(packet)
+ _, err = peer.SendBuffer(packet)
+ if err == nil {
peer.TimerAnyAuthenticatedPacketTraversal()
+ }
- case MessageResponseType:
+ case MessageResponseType:
- // unmarshal
+ // unmarshal
- if len(elem.packet) != MessageResponseSize {
- return
- }
-
- var msg MessageResponse
- reader := bytes.NewReader(elem.packet)
- err := binary.Read(reader, binary.LittleEndian, &msg)
- if err != nil {
- logError.Println("Failed to decode response message")
- return
- }
+ var msg MessageResponse
+ reader := bytes.NewReader(elem.packet)
+ err := binary.Read(reader, binary.LittleEndian, &msg)
+ if err != nil {
+ logError.Println("Failed to decode response message")
+ continue
+ }
- // consume response
+ // consume response
- peer := device.ConsumeMessageResponse(&msg)
- if peer == nil {
- logInfo.Println(
- "Recieved invalid response message from",
- elem.source.IP.String(),
- elem.source.Port,
- )
- return
- }
+ peer := device.ConsumeMessageResponse(&msg)
+ if peer == nil {
+ logInfo.Println(
+ "Recieved invalid response message from",
+ elem.source.IP.String(),
+ elem.source.Port,
+ )
+ continue
+ }
- // update timers
+ peer.TimerEphemeralKeyCreated()
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
- peer.TimerHandshakeComplete()
+ // update timers
- // derive key-pair
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
+ peer.TimerHandshakeComplete()
- peer.NewKeyPair()
- peer.SendKeepAlive()
+ // derive key-pair
- default:
- logError.Println("Invalid message type in handshake queue")
- }
- }()
+ peer.NewKeyPair()
+ peer.SendKeepAlive()
+ }
}
}
@@ -463,6 +467,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
logInfo := device.log.Info
+ logError := device.log.Error
logDebug := device.log.Debug
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
@@ -478,116 +483,104 @@ func (peer *Peer) RoutineSequentialReceiver() {
// process packet
- func() {
- if elem.IsDropped() {
- return
- }
-
- // check for replay
-
- if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
- return
- }
+ if elem.IsDropped() {
+ continue
+ }
- peer.TimerAnyAuthenticatedPacketTraversal()
- peer.TimerAnyAuthenticatedPacketReceived()
- peer.KeepKeyFreshReceiving()
+ // check for replay
- // check if using new key-pair
+ if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
+ continue
+ }
- kp := &peer.keyPairs
- kp.mutex.Lock()
- if kp.next == elem.keyPair {
- peer.TimerHandshakeComplete()
- kp.previous = kp.current
- kp.current = kp.next
- kp.next = nil
- }
- kp.mutex.Unlock()
+ peer.TimerAnyAuthenticatedPacketTraversal()
+ peer.TimerAnyAuthenticatedPacketReceived()
+ peer.KeepKeyFreshReceiving()
- // check for keep-alive
+ // check if using new key-pair
- if len(elem.packet) == 0 {
- logDebug.Println("Received keep-alive from", peer.String())
- return
- }
- peer.TimerDataReceived()
+ kp := &peer.keyPairs
+ kp.mutex.Lock()
+ if kp.next == elem.keyPair {
+ peer.TimerHandshakeComplete()
+ kp.previous = kp.current
+ kp.current = kp.next
+ kp.next = nil
+ }
+ kp.mutex.Unlock()
- // verify source and strip padding
+ // check for keep-alive
- switch elem.packet[0] >> 4 {
- case ipv4.Version:
+ if len(elem.packet) == 0 {
+ logDebug.Println("Received keep-alive from", peer.String())
+ continue
+ }
+ peer.TimerDataReceived()
- // strip padding
+ // verify source and strip padding
- if len(elem.packet) < ipv4.HeaderLen {
- return
- }
+ switch elem.packet[0] >> 4 {
+ case ipv4.Version:
- field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
- length := binary.BigEndian.Uint16(field)
- // TODO: check length of packet & NOT TOO SMALL either
- elem.packet = elem.packet[:length]
+ // strip padding
- // verify IPv4 source
+ if len(elem.packet) < ipv4.HeaderLen {
+ continue
+ }
- src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
- if device.routingTable.LookupIPv4(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
- return
- }
+ field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
+ length := binary.BigEndian.Uint16(field)
+ if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
+ continue
+ }
- case ipv6.Version:
+ elem.packet = elem.packet[:length]
- // strip padding
+ // verify IPv4 source
- if len(elem.packet) < ipv6.HeaderLen {
- return
- }
+ src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
+ if device.routingTable.LookupIPv4(src) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
+ continue
+ }
- field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
- length := binary.BigEndian.Uint16(field)
- length += ipv6.HeaderLen
- // TODO: check length of packet
- elem.packet = elem.packet[:length]
+ case ipv6.Version:
- // verify IPv6 source
+ // strip padding
- src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
- if device.routingTable.LookupIPv6(src) != peer {
- logInfo.Println("Packet with unallowed source IP from", peer.String())
- return
- }
+ if len(elem.packet) < ipv6.HeaderLen {
+ continue
+ }
- default:
- logInfo.Println("Packet with invalid IP version from", peer.String())
- return
+ field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
+ length := binary.BigEndian.Uint16(field)
+ length += ipv6.HeaderLen
+ if int(length) > len(elem.packet) {
+ continue
}
- atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
- device.addToInboundQueue(device.queue.inbound, elem)
+ elem.packet = elem.packet[:length]
- // TODO: move TUN write into per peer routine
- }()
- }
-}
+ // verify IPv6 source
-func (device *Device) RoutineWriteToTUN() {
+ src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
+ if device.routingTable.LookupIPv6(src) != peer {
+ logInfo.Println("Packet with unallowed source IP from", peer.String())
+ continue
+ }
- logError := device.log.Error
- logDebug := device.log.Debug
- logDebug.Println("Routine, sequential tun writer, started")
+ default:
+ logInfo.Println("Packet with invalid IP version from", peer.String())
+ continue
+ }
- for {
- select {
- case <-device.signal.stop:
- return
- case elem := <-device.queue.inbound:
- _, err := device.tun.Write(elem.packet)
- device.PutMessageBuffer(elem.buffer)
- if err != nil {
- logError.Println("Failed to write packet to TUN device:", err)
- }
+ // write to tun
+
+ atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
+ _, err := device.tun.Write(elem.packet)
+ device.PutMessageBuffer(elem.buffer)
+ if err != nil {
+ logError.Println("Failed to write packet to TUN device:", err)
}
}
}