From 647d7b7157b6957b61ebe6be60f49828c025a4d7 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 1 Jul 2019 09:39:08 +0200 Subject: device: prepare for multiple send/receive --- device/conn.go | 3 +- device/conn_default.go | 6 +++- device/conn_linux.go | 6 +++- device/peer.go | 4 +-- device/receive.go | 8 ++--- device/send.go | 87 +++++++++++++++++++++++++++++++------------------- 6 files changed, 73 insertions(+), 41 deletions(-) diff --git a/device/conn.go b/device/conn.go index 7b341f6..94e73ac 100644 --- a/device/conn.go +++ b/device/conn.go @@ -24,7 +24,8 @@ type Bind interface { SetMark(value uint32) error ReceiveIPv6(buff []byte) (int, Endpoint, error) ReceiveIPv4(buff []byte) (int, Endpoint, error) - Send(buff []byte, end Endpoint) error + Send(buff []byte, end Endpoint, now bool) error + Flush() error Close() error } diff --git a/device/conn_default.go b/device/conn_default.go index 820bb96..777b0a0 100644 --- a/device/conn_default.go +++ b/device/conn_default.go @@ -152,7 +152,7 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { return n, (*NativeEndpoint)(endpoint), err } -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { +func (bind *nativeBind) Send(buff []byte, endpoint Endpoint, now bool) error { var err error nend := endpoint.(*NativeEndpoint) if nend.IP.To4() != nil { @@ -168,3 +168,7 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { } return err } + +func (bind *nativeBind) Flush() error { + return nil +} \ No newline at end of file diff --git a/device/conn_linux.go b/device/conn_linux.go index ebbbe11..ed2d2b3 100644 --- a/device/conn_linux.go +++ b/device/conn_linux.go @@ -259,7 +259,7 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { return n, &end, err } -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { +func (bind *nativeBind) Send(buff []byte, end Endpoint, now bool) error { nend := end.(*NativeEndpoint) if !nend.isV6 { if bind.sock4 == -1 { @@ -274,6 +274,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error { } } +func (bind *nativeBind) Flush() error { + return nil +} + func (end *NativeEndpoint) SrcIP() net.IP { if !end.isV6 { return net.IPv4( diff --git a/device/peer.go b/device/peer.go index 4e7f2da..ebe0f65 100644 --- a/device/peer.go +++ b/device/peer.go @@ -126,7 +126,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffer(buffer []byte, now bool) error { peer.device.net.RLock() defer peer.device.net.RUnlock() @@ -141,7 +141,7 @@ func (peer *Peer) SendBuffer(buffer []byte) error { return errors.New("no known endpoint for peer") } - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffer, peer.endpoint, now) if err == nil { atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) } diff --git a/device/receive.go b/device/receive.go index 62b5ef4..be0ee4e 100644 --- a/device/receive.go +++ b/device/receive.go @@ -485,7 +485,7 @@ func (device *Device) RoutineHandshake() { } } -func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) { +func (peer *Peer) receiveElementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) { if !*shouldFlush { select { case <-peer.routines.stop: @@ -505,9 +505,9 @@ func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, *shouldFlush = false err := peer.device.tun.device.Flush() if err != nil { - peer.device.log.Error.Printf("Unable to flush packets: %v", err) + peer.device.log.Error.Printf("Unable to flush receive packets: %v", err) } - return peer.elementStopOrFlush(shouldFlush) + return peer.receiveElementStopOrFlush(shouldFlush) } } } @@ -549,7 +549,7 @@ func (peer *Peer) RoutineSequentialReceiver() { elem = nil } - stop, ok, elem = peer.elementStopOrFlush(&shouldFlush) + stop, ok, elem = peer.receiveElementStopOrFlush(&shouldFlush) if stop || !ok { return } diff --git a/device/send.go b/device/send.go index c4aa5b9..edc58c0 100644 --- a/device/send.go +++ b/device/send.go @@ -160,7 +160,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet, true) if err != nil { peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) } @@ -198,7 +198,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffer(packet, true) if err != nil { peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) } @@ -219,7 +219,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) var buff [MessageCookieReplySize]byte writer := bytes.NewBuffer(buff[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint, true) if err != nil { device.log.Error.Println("Failed to send cookie reply:", err) } @@ -541,6 +541,33 @@ func (device *Device) RoutineEncryption() { } } +func (peer *Peer) sendElementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueOutboundElement) { + if !*shouldFlush { + select { + case <-peer.routines.stop: + stop = true + return + case elem, elemOk = <-peer.queue.outbound: + return + } + } else { + select { + case <-peer.routines.stop: + stop = true + return + case elem, elemOk = <-peer.queue.outbound: + return + default: + *shouldFlush = false + err := peer.device.net.bind.Flush() + if err != nil { + peer.device.log.Error.Printf("Unable to flush send packets: %v", err) + } + return peer.sendElementStopOrFlush(shouldFlush) + } + } +} + /* Sequentially reads packets from queue and sends to endpoint * * Obs. Single instance per peer. @@ -577,41 +604,37 @@ func (peer *Peer) RoutineSequentialSender() { peer.routines.starting.Done() + shouldFlush := false for { - select { - - case <-peer.routines.stop: + stop, ok, elem := peer.sendElementStopOrFlush(&shouldFlush) + if stop || !ok { return + } - case elem, ok := <-peer.queue.outbound: - - if !ok { - return - } - - elem.Lock() - if elem.IsDropped() { - device.PutOutboundElement(elem) - continue - } - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() + elem.Lock() + if elem.IsDropped() { + device.PutOutboundElement(elem) + continue + } - // send message and return buffer to pool + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() - err := peer.SendBuffer(elem.packet) - if len(elem.packet) != MessageKeepaliveSize { - peer.timersDataSent() - } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - if err != nil { - logError.Println(peer, "- Failed to send data packet", err) - continue - } + // send message and return buffer to pool - peer.keepKeyFreshSending() + err := peer.SendBuffer(elem.packet, false) + if len(elem.packet) != MessageKeepaliveSize { + peer.timersDataSent() } + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + if err != nil { + logError.Println(peer, "- Failed to send data packet", err) + continue + } else { + shouldFlush = true + } + + peer.keepKeyFreshSending() } } -- cgit v1.2.3-59-g8ed1b