aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorFlorent Daigniere <nextgens@freenetproject.org>2019-02-23 21:50:04 +0100
committerFlorent Daigniere <nextgens@freenetproject.org>2019-02-25 18:20:23 +0100
commit0c2d06d8a5a6bb61b42857ac2c21c579b11a6f1c (patch)
treeabcd5992aaa3f02f0c0b5e14a4673317b6749fca
parentsend: propagate DSCP bits to the outer tunnel (diff)
downloadwireguard-go-0c2d06d8a5a6bb61b42857ac2c21c579b11a6f1c.tar.xz
wireguard-go-0c2d06d8a5a6bb61b42857ac2c21c579b11a6f1c.zip
net: implement ECN handling, rfc6040 stylefd/propagate-DSCP-bits
To decide whether we should use the compatibility mode or the normal mode with a peer, we use the handshake messages as a signaling channel. If we receive the expected ECN bits, it most likely means they're running a compatible version. Signed-off-by: Florent Daigniere <nextgens@freenetproject.org>
-rw-r--r--conn.go4
-rw-r--r--conn_default.go17
-rw-r--r--conn_linux.go85
-rw-r--r--misc.go59
-rw-r--r--peer.go9
-rw-r--r--receive.go19
-rw-r--r--send.go14
7 files changed, 164 insertions, 43 deletions
diff --git a/conn.go b/conn.go
index b8970e7..e38160a 100644
--- a/conn.go
+++ b/conn.go
@@ -20,8 +20,8 @@ const (
*/
type Bind interface {
SetMark(value uint32) error
- ReceiveIPv6(buff []byte) (int, Endpoint, error)
- ReceiveIPv4(buff []byte) (int, Endpoint, error)
+ ReceiveIPv6(buff []byte) (int, Endpoint, byte, error)
+ ReceiveIPv4(buff []byte) (int, Endpoint, byte, error)
Send(buff []byte, end Endpoint, tos byte) error
Close() error
}
diff --git a/conn_default.go b/conn_default.go
index 6f17de5..1b25863 100644
--- a/conn_default.go
+++ b/conn_default.go
@@ -133,26 +133,29 @@ func (bind *NativeBind) Close() error {
return err2
}
-func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+// TODO: implement TOS
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) {
if bind.ipv4 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
+ return 0, nil, 0, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
- return n, (*NativeEndpoint)(endpoint), err
+ return n, (*NativeEndpoint)(endpoint), 0, err
}
-func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+// TODO: implement TOS
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) {
if bind.ipv6 == nil {
- return 0, nil, syscall.EAFNOSUPPORT
+ return 0, nil, 0, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
- return n, (*NativeEndpoint)(endpoint), err
+ return n, (*NativeEndpoint)(endpoint), 0, err
}
-func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
+// TODO: implement TOS
+func (bind *NativeBind) Send(buff []byte, endpoint Endpoint, tos byte) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil {
diff --git a/conn_linux.go b/conn_linux.go
index 83cf1a2..cc1ce2e 100644
--- a/conn_linux.go
+++ b/conn_linux.go
@@ -232,30 +232,32 @@ func (bind *NativeBind) Close() error {
return err3
}
-func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, byte, error) {
var end NativeEndpoint
+ var tos byte
if bind.sock6 == -1 {
- return 0, nil, syscall.EAFNOSUPPORT
+ return 0, nil, tos, syscall.EAFNOSUPPORT
}
- n, err := receive6(
+ n, tos, err := receive6(
bind.sock6,
buff,
&end,
)
- return n, &end, err
+ return n, &end, tos, err
}
-func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, byte, error) {
var end NativeEndpoint
+ var tos byte
if bind.sock4 == -1 {
- return 0, nil, syscall.EAFNOSUPPORT
+ return 0, nil, tos, syscall.EAFNOSUPPORT
}
- n, err := receive4(
+ n, tos, err := receive4(
bind.sock4,
buff,
&end,
)
- return n, &end, err
+ return n, &end, tos, err
}
func (bind *NativeBind) Send(buff []byte, end Endpoint, tos byte) error {
@@ -384,6 +386,15 @@ func create4(port uint16) (int, uint16, error) {
return err
}
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IP,
+ unix.IP_RECVTOS,
+ 1,
+ ); err != nil {
+ return err
+ }
+
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
@@ -442,6 +453,15 @@ func create6(port uint16) (int, uint16, error) {
return err
}
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_RECVTCLASS,
+ 1,
+ ); err != nil {
+ return err
+ }
+
return unix.Bind(fd, &addr)
}(); err != nil {
@@ -452,12 +472,13 @@ func create6(port uint16) (int, uint16, error) {
return fd, uint16(addr.Port), err
}
+type ipTos struct {
+ tos byte
+}
+
func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
// construct message header
- type ipTos struct {
- tos byte
- }
cmsg := struct {
cmsghdr unix.Cmsghdr
@@ -505,9 +526,6 @@ func send4(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
// construct message header
- type ipTos struct {
- tos byte
- }
cmsg := struct {
cmsghdr unix.Cmsghdr
@@ -555,19 +573,21 @@ func send6(sock int, end *NativeEndpoint, buff []byte, tos byte) error {
return err
}
-func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) {
// contruct message header
var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet4Pktinfo
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ cmsghdr2 unix.Cmsghdr
+ iptos ipTos
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
- return 0, err
+ return 0, 0, err
}
end.isV6 = false
@@ -576,7 +596,6 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
}
// update source cache
-
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
@@ -584,22 +603,31 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
end.src4().ifindex = cmsg.pktinfo.Ifindex
}
- return size, nil
+ tos := byte(0)
+ if cmsg.cmsghdr2.Level == unix.IPPROTO_IP &&
+ cmsg.cmsghdr2.Type == unix.IP_TOS &&
+ cmsg.cmsghdr2.Len >= 1 {
+ tos = cmsg.iptos.tos
+ }
+
+ return size, tos, nil
}
-func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, byte, error) {
// contruct message header
var cmsg struct {
- cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet6Pktinfo
+ cmsghdr2 unix.Cmsghdr
+ iptos ipTos
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
- return 0, err
+ return 0, 0, err
}
end.isV6 = true
@@ -616,7 +644,14 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
- return size, nil
+ tos := byte(0)
+ if cmsg.cmsghdr2.Level == unix.IPPROTO_IPV6 &&
+ cmsg.cmsghdr2.Type == unix.IPV6_TCLASS &&
+ cmsg.cmsghdr2.Len >= 1 {
+ tos = cmsg.iptos.tos
+ }
+
+ return size, tos, nil
}
func (bind *NativeBind) routineRouteListener(device *Device) {
diff --git a/misc.go b/misc.go
index 6786cb5..e5688a5 100644
--- a/misc.go
+++ b/misc.go
@@ -46,3 +46,62 @@ func min(a, b uint) uint {
}
return a
}
+
+// called from receive
+func ecn_rfc6040_egress(inner byte, outer byte) (byte, bool) {
+ /*
+ +---------+------------------------------------------------+
+ |Arriving | Arriving Outer Header |
+ | Inner +---------+------------+------------+------------+
+ | Header | Not-ECT | ECT(0) | ECT(1) | CE |
+ +---------+---------+------------+------------+------------+
+ | Not-ECT | Not-ECT |Not-ECT(!!!)|Not-ECT(!!!)| <drop>(!!!)|
+ | ECT(0) | ECT(0) | ECT(0) | ECT(1) | CE |
+ | ECT(1) | ECT(1) | ECT(1) (!) | ECT(1) | CE |
+ | CE | CE | CE | CE(!!!)| CE |
+ +---------+---------+------------+------------+------------+
+ */
+ innerECN := CongestionExperienced & inner
+ outerECN := CongestionExperienced & outer
+
+ switch outerECN {
+ case CongestionExperienced:
+ switch innerECN {
+ case NotECNTransport:
+ return 0, true
+ }
+ return (inner & (CongestionExperienced ^ 255)) | CongestionExperienced, false
+ case ECNTransport1:
+ switch innerECN {
+ case ECNTransport0:
+ return (inner & (CongestionExperienced ^ 255)) | ECNTransport1, false
+ }
+ }
+ return inner, false
+}
+
+// called from send
+func ecn_rfc6040_ingress(inner byte, useNormalMode bool) byte {
+ /*
+ +-----------------+-------------------------------+
+ | Incoming Header | Departing Outer Header |
+ | (also equal to +---------------+---------------+
+ | departing Inner | Compatibility | Normal |
+ | Header) | Mode | Mode |
+ +-----------------+---------------+---------------+
+ | Not-ECT | Not-ECT | Not-ECT |
+ | ECT(0) | Not-ECT | ECT(0) |
+ | ECT(1) | Not-ECT | ECT(1) |
+ | CE | Not-ECT | CE |
+ +-----------------+---------------+---------------+
+ */
+ if !useNormalMode {
+ inner &= (CongestionExperienced ^ 255)
+ }
+
+ return inner
+}
+
+func ecn_rfc6040_enabled(tos byte) bool {
+ return (CongestionExperienced & tos) == ECNTransport0
+}
diff --git a/peer.go b/peer.go
index 96cfa61..642a0ee 100644
--- a/peer.go
+++ b/peer.go
@@ -15,6 +15,14 @@ import (
const (
PeerRoutineNumber = 3
+
+ DiffServAF41 = 0x88 // AF41
+ NotECNTransport = 0x00 // Not-ECT (Not ECN-Capable Transport)
+ ECNTransport1 = 0x01 // ECT(1) (ECN-Capable Transport(1))
+ ECNTransport0 = 0x02 // ECT(0) (ECN-Capable Transport(0))
+ CongestionExperienced = 0x03 // CE (Congestion Experienced)
+
+ HandshakeDSCP = DiffServAF41 | ECNTransport0 // AF41, plus 10 ECN
)
type Peer struct {
@@ -25,6 +33,7 @@ type Peer struct {
device *Device
endpoint Endpoint
persistentKeepaliveInterval uint16
+ isECNConfirmed AtomicBool
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
stats struct {
diff --git a/receive.go b/receive.go
index fb848eb..03dbd4b 100644
--- a/receive.go
+++ b/receive.go
@@ -23,6 +23,7 @@ type QueueHandshakeElement struct {
packet []byte
endpoint Endpoint
buffer *[MaxMessageSize]byte
+ isECNCompatible bool
}
type QueueInboundElement struct {
@@ -33,6 +34,7 @@ type QueueInboundElement struct {
counter uint64
keypair *Keypair
endpoint Endpoint
+ tos byte
}
func (elem *QueueInboundElement) Drop() {
@@ -108,6 +110,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
err error
size int
endpoint Endpoint
+ outerTOS byte
)
for {
@@ -116,9 +119,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
switch IP {
case ipv4.Version:
- size, endpoint, err = bind.ReceiveIPv4(buffer[:])
+ size, endpoint, outerTOS, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
- size, endpoint, err = bind.ReceiveIPv6(buffer[:])
+ size, endpoint, outerTOS, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
}
@@ -178,6 +181,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
+ elem.tos = outerTOS
elem.Lock()
// add to decryption queues
@@ -213,6 +217,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
buffer: buffer,
packet: packet,
endpoint: endpoint,
+ isECNCompatible: ecn_rfc6040_enabled(outerTOS),
},
)) {
buffer = device.GetMessageBuffer()
@@ -426,7 +431,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint)
logDebug.Println(peer, "- Received handshake initiation")
-
+ peer.isECNConfirmed.Set(elem.isECNCompatible)
peer.SendHandshakeResponse()
case MessageResponseType:
@@ -473,6 +478,7 @@ func (device *Device) RoutineHandshake() {
peer.timersSessionDerived()
peer.timersHandshakeComplete()
+ peer.isECNConfirmed.Set(elem.isECNCompatible)
peer.SendKeepalive()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
@@ -565,6 +571,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
}
peer.timersDataReceived()
+ var shouldDrop bool
// verify source and strip padding
switch elem.packet[0] >> 4 {
@@ -595,6 +602,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
+ elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos)
case ipv6.Version:
// strip padding
@@ -623,10 +631,15 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
+ elem.tos, shouldDrop = ecn_rfc6040_egress(elem.packet[1], elem.tos);
default:
logInfo.Println("Packet with invalid IP version from", peer)
continue
}
+ if shouldDrop {
+ logInfo.Println("ECN/Congestion detected, dropping packet from", peer)
+ continue
+ }
// write to tun device
diff --git a/send.go b/send.go
index 57bb67b..f787027 100644
--- a/send.go
+++ b/send.go
@@ -41,10 +41,6 @@ import (
* (to allow the construction of transport messages in-place)
*/
-const (
- HandshakeDSCP = 0x88 // AF41, plus 00 ECN
-)
-
type QueueOutboundElement struct {
dropped int32
sync.Mutex
@@ -299,14 +295,20 @@ func (device *Device) RoutineReadFromTUN() {
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst)
- elem.tos = elem.packet[1];
+ if peer == nil {
+ continue
+ }
+ elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get())
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst)
- elem.tos = elem.packet[1];
+ if peer == nil {
+ continue
+ }
+ elem.tos = ecn_rfc6040_ingress(elem.packet[1], peer.isECNConfirmed.Get())
default:
logDebug.Println("Received packet with unknown IP version")
}