From a4f8e83d5d9f477554971e90e9ab85922f506ea9 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Mon, 22 Feb 2021 02:01:50 +0100 Subject: conn: make binds replacable Signed-off-by: Jason A. Donenfeld --- conn/bind_linux.go | 580 ++++++++++++++++++++++++++++++++++++++++++++++++ conn/bind_std.go | 180 +++++++++++++++ conn/boundif_android.go | 4 +- conn/boundif_windows.go | 4 +- conn/conn.go | 43 ++-- conn/conn_default.go | 171 -------------- conn/conn_linux.go | 576 ----------------------------------------------- conn/default.go | 10 + conn/mark_default.go | 2 +- conn/mark_unix.go | 4 +- 10 files changed, 796 insertions(+), 778 deletions(-) create mode 100644 conn/bind_linux.go create mode 100644 conn/bind_std.go delete mode 100644 conn/conn_default.go delete mode 100644 conn/conn_linux.go create mode 100644 conn/default.go (limited to 'conn') diff --git a/conn/bind_linux.go b/conn/bind_linux.go new file mode 100644 index 0000000..4199809 --- /dev/null +++ b/conn/bind_linux.go @@ -0,0 +1,580 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "net" + "strconv" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/unix" +) + +type ipv4Source struct { + Src [4]byte + Ifindex int32 +} + +type ipv6Source struct { + src [16]byte + // ifindex belongs in dst.ZoneId +} + +type LinuxSocketEndpoint struct { + sync.Mutex + dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte + src [unsafe.Sizeof(ipv6Source{})]byte + isV6 bool +} + +func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source { return endpoint.src4() } +func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } +func (endpoint *LinuxSocketEndpoint) IsV6() bool { return endpoint.isV6 } + +func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source { + return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0])) +} + +func (endpoint *LinuxSocketEndpoint) src6() *ipv6Source { + return (*ipv6Source)(unsafe.Pointer(&endpoint.src[0])) +} + +func (endpoint *LinuxSocketEndpoint) dst4() *unix.SockaddrInet4 { + return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) +} + +func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 { + return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) +} + +// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. +type LinuxSocketBind struct { + sock4 int + sock6 int + lastMark uint32 + closing sync.RWMutex +} + +func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } +func NewDefaultBind() Bind { return NewLinuxSocketBind() } + +var _ Endpoint = (*LinuxSocketEndpoint)(nil) +var _ Bind = (*LinuxSocketBind)(nil) + +func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) { + var end LinuxSocketEndpoint + addr, err := parseEndpoint(s) + if err != nil { + return nil, err + } + + ipv4 := addr.IP.To4() + if ipv4 != nil { + dst := end.dst4() + end.isV6 = false + dst.Port = addr.Port + copy(dst.Addr[:], ipv4) + end.ClearSrc() + return &end, nil + } + + ipv6 := addr.IP.To16() + if ipv6 != nil { + zone, err := zoneToUint32(addr.Zone) + if err != nil { + return nil, err + } + dst := end.dst6() + end.isV6 = true + dst.Port = addr.Port + dst.ZoneId = zone + copy(dst.Addr[:], ipv6[:]) + end.ClearSrc() + return &end, nil + } + + return nil, errors.New("invalid IP address") +} + +func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { + var err error + var newPort uint16 + var tries int + + if bind.sock4 != -1 || bind.sock6 != -1 { + return 0, ErrBindAlreadyOpen + } + + originalPort := port + +again: + port = originalPort + // Attempt ipv6 bind, update port if successful. + bind.sock6, newPort, err = create6(port) + if err != nil { + if err != syscall.EAFNOSUPPORT { + return 0, err + } + } else { + port = newPort + } + + // Attempt ipv4 bind, update port if successful. + bind.sock4, newPort, err = create4(port) + if err != nil { + if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { + unix.Close(bind.sock6) + tries++ + goto again + } + if err != syscall.EAFNOSUPPORT { + unix.Close(bind.sock6) + return 0, err + } + } else { + port = newPort + } + + if bind.sock4 == -1 && bind.sock6 == -1 { + return 0, syscall.EAFNOSUPPORT + } + return port, nil +} + +func (bind *LinuxSocketBind) SetMark(value uint32) error { + bind.closing.RLock() + defer bind.closing.RUnlock() + + if bind.sock6 != -1 { + err := unix.SetsockoptInt( + bind.sock6, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + } + + if bind.sock4 != -1 { + err := unix.SetsockoptInt( + bind.sock4, + unix.SOL_SOCKET, + unix.SO_MARK, + int(value), + ) + + if err != nil { + return err + } + } + + bind.lastMark = value + return nil +} + +func (bind *LinuxSocketBind) Close() error { + var err1, err2 error + bind.closing.RLock() + if bind.sock6 != -1 { + unix.Shutdown(bind.sock6, unix.SHUT_RDWR) + } + if bind.sock4 != -1 { + unix.Shutdown(bind.sock4, unix.SHUT_RDWR) + } + bind.closing.RUnlock() + bind.closing.Lock() + if bind.sock6 != -1 { + err1 = unix.Close(bind.sock6) + bind.sock6 = -1 + } + if bind.sock4 != -1 { + err2 = unix.Close(bind.sock4) + bind.sock4 = -1 + } + bind.closing.Unlock() + + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + bind.closing.RLock() + defer bind.closing.RUnlock() + + var end LinuxSocketEndpoint + if bind.sock6 == -1 { + return 0, nil, net.ErrClosed + } + n, err := receive6( + bind.sock6, + buff, + &end, + ) + return n, &end, err +} + +func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + bind.closing.RLock() + defer bind.closing.RUnlock() + + var end LinuxSocketEndpoint + if bind.sock4 == -1 { + return 0, nil, net.ErrClosed + } + n, err := receive4( + bind.sock4, + buff, + &end, + ) + return n, &end, err +} + +func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { + bind.closing.RLock() + defer bind.closing.RUnlock() + + nend, ok := end.(*LinuxSocketEndpoint) + if !ok { + return ErrWrongEndpointType + } + if !nend.isV6 { + if bind.sock4 == -1 { + return net.ErrClosed + } + return send4(bind.sock4, nend, buff) + } else { + if bind.sock6 == -1 { + return net.ErrClosed + } + return send6(bind.sock6, nend, buff) + } +} + +func (end *LinuxSocketEndpoint) SrcIP() net.IP { + if !end.isV6 { + return net.IPv4( + end.src4().Src[0], + end.src4().Src[1], + end.src4().Src[2], + end.src4().Src[3], + ) + } else { + return end.src6().src[:] + } +} + +func (end *LinuxSocketEndpoint) DstIP() net.IP { + if !end.isV6 { + return net.IPv4( + end.dst4().Addr[0], + end.dst4().Addr[1], + end.dst4().Addr[2], + end.dst4().Addr[3], + ) + } else { + return end.dst6().Addr[:] + } +} + +func (end *LinuxSocketEndpoint) DstToBytes() []byte { + if !end.isV6 { + return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] + } else { + return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] + } +} + +func (end *LinuxSocketEndpoint) SrcToString() string { + return end.SrcIP().String() +} + +func (end *LinuxSocketEndpoint) DstToString() string { + var udpAddr net.UDPAddr + udpAddr.IP = end.DstIP() + if !end.isV6 { + udpAddr.Port = end.dst4().Port + } else { + udpAddr.Port = end.dst6().Port + } + return udpAddr.String() +} + +func (end *LinuxSocketEndpoint) ClearDst() { + for i := range end.dst { + end.dst[i] = 0 + } +} + +func (end *LinuxSocketEndpoint) ClearSrc() { + for i := range end.src { + end.src[i] = 0 + } +} + +func zoneToUint32(zone string) (uint32, error) { + if zone == "" { + return 0, nil + } + if intr, err := net.InterfaceByName(zone); err == nil { + return uint32(intr.Index), nil + } + n, err := strconv.ParseUint(zone, 10, 32) + return uint32(n), err +} + +func create4(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, 0, err + } + + addr := unix.SockaddrInet4{ + Port: int(port), + } + + // set sockopts and bind + + if err := func() error { + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IP, + unix.IP_PKTINFO, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + }(); err != nil { + unix.Close(fd) + return -1, 0, err + } + + sa, err := unix.Getsockname(fd) + if err == nil { + addr.Port = sa.(*unix.SockaddrInet4).Port + } + + return fd, uint16(addr.Port), err +} + +func create6(port uint16) (int, uint16, error) { + + // create socket + + fd, err := unix.Socket( + unix.AF_INET6, + unix.SOCK_DGRAM, + 0, + ) + + if err != nil { + return -1, 0, err + } + + // set sockopts and bind + + addr := unix.SockaddrInet6{ + Port: int(port), + } + + if err := func() error { + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_RECVPKTINFO, + 1, + ); err != nil { + return err + } + + if err := unix.SetsockoptInt( + fd, + unix.IPPROTO_IPV6, + unix.IPV6_V6ONLY, + 1, + ); err != nil { + return err + } + + return unix.Bind(fd, &addr) + + }(); err != nil { + unix.Close(fd) + return -1, 0, err + } + + sa, err := unix.Getsockname(fd) + if err == nil { + addr.Port = sa.(*unix.SockaddrInet6).Port + } + + return fd, uint16(addr.Port), err +} + +func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error { + + // construct message header + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet4Pktinfo{ + Spec_dst: end.src4().Src, + Ifindex: end.src4().Ifindex, + }, + } + + end.Lock() + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) + end.Unlock() + + if err == nil { + return nil + } + + // clear src and retry + + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet4Pktinfo{} + end.Lock() + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) + end.Unlock() + } + + return err +} + +func send6(sock int, end *LinuxSocketEndpoint, buff []byte) error { + + // construct message header + + cmsg := struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + }{ + unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, + }, + unix.Inet6Pktinfo{ + Addr: end.src6().src, + Ifindex: end.dst6().ZoneId, + }, + } + + if cmsg.pktinfo.Addr == [16]byte{} { + cmsg.pktinfo.Ifindex = 0 + } + + end.Lock() + _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) + end.Unlock() + + if err == nil { + return nil + } + + // clear src and retry + + if err == unix.EINVAL { + end.ClearSrc() + cmsg.pktinfo = unix.Inet6Pktinfo{} + end.Lock() + _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) + end.Unlock() + } + + return err +} + +func receive4(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { + + // construct message header + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet4Pktinfo + } + + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) + + if err != nil { + return 0, err + } + end.isV6 = false + + if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { + *end.dst4() = *newDst4 + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IP && + cmsg.cmsghdr.Type == unix.IP_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { + end.src4().Src = cmsg.pktinfo.Spec_dst + end.src4().Ifindex = cmsg.pktinfo.Ifindex + } + + return size, nil +} + +func receive6(sock int, buff []byte, end *LinuxSocketEndpoint) (int, error) { + + // construct message header + + var cmsg struct { + cmsghdr unix.Cmsghdr + pktinfo unix.Inet6Pktinfo + } + + size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) + + if err != nil { + return 0, err + } + end.isV6 = true + + if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { + *end.dst6() = *newDst6 + } + + // update source cache + + if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && + cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && + cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { + end.src6().src = cmsg.pktinfo.Addr + end.dst6().ZoneId = cmsg.pktinfo.Ifindex + } + + return size, nil +} diff --git a/conn/bind_std.go b/conn/bind_std.go new file mode 100644 index 0000000..193c4fe --- /dev/null +++ b/conn/bind_std.go @@ -0,0 +1,180 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "net" + "syscall" +) + +// StdNetBind is meant to be a temporary solution on platforms for which +// the sticky socket / source caching behavior has not yet been implemented. +// It uses the Go's net package to implement networking. +// See LinuxSocketBind for a proper implementation on the Linux platform. +type StdNetBind struct { + ipv4 *net.UDPConn + ipv6 *net.UDPConn + blackhole4 bool + blackhole6 bool +} + +func NewStdNetBind() Bind { return &StdNetBind{} } + +type StdNetEndpoint net.UDPAddr + +var _ Bind = (*StdNetBind)(nil) +var _ Endpoint = (*StdNetEndpoint)(nil) + +func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { + addr, err := parseEndpoint(s) + return (*StdNetEndpoint)(addr), err +} + +func (*StdNetEndpoint) ClearSrc() {} + +func (e *StdNetEndpoint) DstIP() net.IP { + return (*net.UDPAddr)(e).IP +} + +func (e *StdNetEndpoint) SrcIP() net.IP { + return nil // not supported +} + +func (e *StdNetEndpoint) DstToBytes() []byte { + addr := (*net.UDPAddr)(e) + out := addr.IP.To4() + if out == nil { + out = addr.IP + } + out = append(out, byte(addr.Port&0xff)) + out = append(out, byte((addr.Port>>8)&0xff)) + return out +} + +func (e *StdNetEndpoint) DstToString() string { + return (*net.UDPAddr)(e).String() +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) + if err != nil { + return nil, 0, err + } + + // Retrieve port. + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn, uaddr.Port, nil +} + +func (bind *StdNetBind) Open(uport uint16) (uint16, error) { + var err error + var tries int + + if bind.ipv4 != nil || bind.ipv6 != nil { + return 0, ErrBindAlreadyOpen + } + +again: + port := int(uport) + + bind.ipv4, port, err = listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + bind.ipv4 = nil + return 0, err + } + + bind.ipv6, port, err = listenNet("udp6", port) + if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + bind.ipv4.Close() + bind.ipv4 = nil + bind.ipv6 = nil + tries++ + goto again + } + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + bind.ipv4.Close() + bind.ipv4 = nil + bind.ipv6 = nil + return 0, err + } + if bind.ipv4 == nil && bind.ipv6 == nil { + return 0, syscall.EAFNOSUPPORT + } + return uint16(port), nil +} + +func (bind *StdNetBind) Close() error { + var err1, err2 error + if bind.ipv4 != nil { + err1 = bind.ipv4.Close() + bind.ipv4 = nil + } + if bind.ipv6 != nil { + err2 = bind.ipv6.Close() + bind.ipv6 = nil + } + if err1 != nil { + return err1 + } + return err2 +} + +func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { + if bind.ipv4 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } + n, endpoint, err := bind.ipv4.ReadFromUDP(buff) + if endpoint != nil { + endpoint.IP = endpoint.IP.To4() + } + return n, (*StdNetEndpoint)(endpoint), err +} + +func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { + if bind.ipv6 == nil { + return 0, nil, syscall.EAFNOSUPPORT + } + n, endpoint, err := bind.ipv6.ReadFromUDP(buff) + return n, (*StdNetEndpoint)(endpoint), err +} + +func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { + var err error + nend, ok := endpoint.(*StdNetEndpoint) + if !ok { + return ErrWrongEndpointType + } + if nend.IP.To4() != nil { + if bind.ipv4 == nil { + return syscall.EAFNOSUPPORT + } + if bind.blackhole4 { + return nil + } + _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } else { + if bind.ipv6 == nil { + return syscall.EAFNOSUPPORT + } + if bind.blackhole6 { + return nil + } + _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) + } + return err +} diff --git a/conn/boundif_android.go b/conn/boundif_android.go index 2c68d57..8b82bfc 100644 --- a/conn/boundif_android.go +++ b/conn/boundif_android.go @@ -5,7 +5,7 @@ package conn -func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) { +func (bind *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { sysconn, err := bind.ipv4.SyscallConn() if err != nil { return -1, err @@ -19,7 +19,7 @@ func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) { return } -func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) { +func (bind *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { sysconn, err := bind.ipv6.SyscallConn() if err != nil { return -1, err diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go index e425d23..6f6fdd8 100644 --- a/conn/boundif_windows.go +++ b/conn/boundif_windows.go @@ -17,7 +17,7 @@ const ( sockoptIPV6_UNICAST_IF = 31 ) -func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { +func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ bytes := make([]byte, 4) binary.BigEndian.PutUint32(bytes, interfaceIndex) @@ -40,7 +40,7 @@ func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole return nil } -func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { +func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { sysconn, err := bind.ipv6.SyscallConn() if err != nil { return err diff --git a/conn/conn.go b/conn/conn.go index 6e7939c..6fd232f 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -17,40 +17,30 @@ import ( // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, // depending on the platform-specific implementation. type Bind interface { - // LastMark reports the last mark set for this Bind. - LastMark() uint32 + // Open puts the Bind into a listening state on a given port and reports the actual + // port that it bound to. Passing zero results in a random selection. + Open(port uint16) (actualPort uint16, err error) + + // Close closes the Bind listener. + Close() error // SetMark sets the mark for each packet sent through this Bind. // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // ReceiveIPv6 reads an IPv6 UDP packet into b. - // - // It reports the number of bytes read, n, - // the packet source address ep, - // and any error. + // ReceiveIPv6 reads an IPv6 UDP packet into b. It reports the number of bytes read, + // n, the packet source address ep, and any error. ReceiveIPv6(b []byte) (n int, ep Endpoint, err error) - // ReceiveIPv4 reads an IPv4 UDP packet into b. - // - // It reports the number of bytes read, n, - // the packet source address ep, - // and any error. + // ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read, + // n, the packet source address ep, and any error. ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) // Send writes a packet b to address ep. Send(b []byte, ep Endpoint) error - // Close closes the Bind connection. - Close() error -} - -// CreateBind creates a Bind bound to a port. -// -// The value actualPort reports the actual port number the Bind -// object gets bound to. -func CreateBind(port uint16) (b Bind, actualPort uint16, err error) { - return createBind(port) + // ParseEndpoint creates a new endpoint from a string. + ParseEndpoint(s string) (Endpoint, error) } // BindSocketToInterface is implemented by Bind objects that support being @@ -69,8 +59,8 @@ type PeekLookAtSocketFd interface { // An Endpoint maintains the source/destination caching for a peer. // -// dst : the remote address of a peer ("endpoint" in uapi terminology) -// src : the local address from which datagrams originate going to the peer +// dst: the remote address of a peer ("endpoint" in uapi terminology) +// src: the local address from which datagrams originate going to the peer type Endpoint interface { ClearSrc() // clears the source address SrcToString() string // returns the local source address (ip:port) @@ -109,3 +99,8 @@ func parseEndpoint(s string) (*net.UDPAddr, error) { } return addr, err } + +var ( + ErrBindAlreadyOpen = errors.New("bind is already open") + ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") +) diff --git a/conn/conn_default.go b/conn/conn_default.go deleted file mode 100644 index 82a1e42..0000000 --- a/conn/conn_default.go +++ /dev/null @@ -1,171 +0,0 @@ -// +build !linux android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "errors" - "net" - "syscall" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type nativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*nativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (*NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP.To4() - if out == nil { - out = addr.IP - } - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // Retrieve port. - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func createBind(uport uint16) (Bind, uint16, error) { - var err error - var bind nativeBind - var tries int - -again: - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { - bind.ipv4.Close() - tries++ - goto again - } - if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { - bind.ipv4.Close() - bind.ipv4 = nil - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - } - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) LastMark() uint32 { return 0 } - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To4() != nil { - if bind.ipv4 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole4 { - return nil - } - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - if bind.ipv6 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole6 { - return nil - } - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} diff --git a/conn/conn_linux.go b/conn/conn_linux.go deleted file mode 100644 index 58b7de1..0000000 --- a/conn/conn_linux.go +++ /dev/null @@ -1,576 +0,0 @@ -// +build !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "errors" - "net" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -type IPv4Source struct { - Src [4]byte - Ifindex int32 -} - -type IPv6Source struct { - src [16]byte - //ifindex belongs in dst.ZoneId -} - -type NativeEndpoint struct { - sync.Mutex - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte - isV6 bool -} - -func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } -func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } -func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } - -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -type nativeBind struct { - sock4 int - sock6 int - lastMark uint32 - closing sync.RWMutex -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*nativeBind)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := end.dst4() - end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = addr.Port - dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Invalid IP address") -} - -func createBind(port uint16) (Bind, uint16, error) { - var err error - var bind nativeBind - var newPort uint16 - var tries int - originalPort := port - -again: - port = originalPort - // Attempt ipv6 bind, update port if successful. - bind.sock6, newPort, err = create6(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - return nil, 0, err - } - } else { - port = newPort - } - - // Attempt ipv4 bind, update port if successful. - bind.sock4, newPort, err = create4(port) - if err != nil { - if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { - unix.Close(bind.sock6) - tries++ - goto again - } - if err != syscall.EAFNOSUPPORT { - unix.Close(bind.sock6) - return nil, 0, err - } - } else { - port = newPort - } - - if bind.sock4 == -1 && bind.sock6 == -1 { - return nil, 0, errors.New("ipv4 and ipv6 not supported") - } - - return &bind, port, nil -} - -func (bind *nativeBind) LastMark() uint32 { - return bind.lastMark -} - -func (bind *nativeBind) SetMark(value uint32) error { - bind.closing.RLock() - defer bind.closing.RUnlock() - - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - bind.lastMark = value - return nil -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - bind.closing.RLock() - if bind.sock6 != -1 { - unix.Shutdown(bind.sock6, unix.SHUT_RDWR) - } - if bind.sock4 != -1 { - unix.Shutdown(bind.sock4, unix.SHUT_RDWR) - } - bind.closing.RUnlock() - bind.closing.Lock() - if bind.sock6 != -1 { - err1 = unix.Close(bind.sock6) - bind.sock6 = -1 - } - if bind.sock4 != -1 { - err2 = unix.Close(bind.sock4) - bind.sock4 = -1 - } - bind.closing.Unlock() - - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - bind.closing.RLock() - defer bind.closing.RUnlock() - - var end NativeEndpoint - if bind.sock6 == -1 { - return 0, nil, net.ErrClosed - } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - bind.closing.RLock() - defer bind.closing.RUnlock() - - var end NativeEndpoint - if bind.sock4 == -1 { - return 0, nil, net.ErrClosed - } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { - bind.closing.RLock() - defer bind.closing.RUnlock() - - nend := end.(*NativeEndpoint) - if !nend.isV6 { - if bind.sock4 == -1 { - return net.ErrClosed - } - return send4(bind.sock4, nend, buff) - } else { - if bind.sock6 == -1 { - return net.ErrClosed - } - return send6(bind.sock6, nend, buff) - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.src4().Src[0], - end.src4().Src[1], - end.src4().Src[2], - end.src4().Src[3], - ) - } else { - return end.src6().src[:] - } -} - -func (end *NativeEndpoint) DstIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) - } else { - return end.dst6().Addr[:] - } -} - -func (end *NativeEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *NativeEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *NativeEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() - if !end.isV6 { - udpAddr.Port = end.dst4().Port - } else { - udpAddr.Port = end.dst6().Port - } - return udpAddr.String() -} - -func (end *NativeEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *NativeEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return -1, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return -1, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return -1, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - return -1, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().Src, - Ifindex: end.src4().Ifindex, - }, - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - } - - return err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - } - - return err -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().Src = cmsg.pktinfo.Spec_dst - end.src4().Ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} diff --git a/conn/default.go b/conn/default.go new file mode 100644 index 0000000..cd9bfb0 --- /dev/null +++ b/conn/default.go @@ -0,0 +1,10 @@ +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019-2021 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func NewDefaultBind() Bind { return NewStdNetBind() } diff --git a/conn/mark_default.go b/conn/mark_default.go index 0f00f6f..c315f4b 100644 --- a/conn/mark_default.go +++ b/conn/mark_default.go @@ -7,6 +7,6 @@ package conn -func (bind *nativeBind) SetMark(mark uint32) error { +func (bind *StdNetBind) SetMark(mark uint32) error { return nil } diff --git a/conn/mark_unix.go b/conn/mark_unix.go index c29f247..18eb581 100644 --- a/conn/mark_unix.go +++ b/conn/mark_unix.go @@ -1,4 +1,4 @@ -// +build android openbsd freebsd +// +build linux openbsd freebsd /* SPDX-License-Identifier: MIT * @@ -26,7 +26,7 @@ func init() { } } -func (bind *nativeBind) SetMark(mark uint32) error { +func (bind *StdNetBind) SetMark(mark uint32) error { var operr error if fwmarkIoctl == 0 { return nil -- cgit v1.2.3-59-g8ed1b