summaryrefslogtreecommitdiffstats
path: root/conn
diff options
context:
space:
mode:
authorDavid Crawshaw <crawshaw@tailscale.com>2019-11-07 11:13:05 -0500
committerJason A. Donenfeld <Jason@zx2c4.com>2020-05-02 01:46:42 -0600
commit203554620dc8114de1ff70bb30b80f828e9e26ad (patch)
tree49f36961f2090dc07d15cad3204a1a3531dc233d /conn
parentwintun: split error message for create vs open namespace. (diff)
downloadwireguard-go-203554620dc8114de1ff70bb30b80f828e9e26ad.tar.xz
wireguard-go-203554620dc8114de1ff70bb30b80f828e9e26ad.zip
conn: introduce new package that splits out the Bind and Endpoint types
The sticky socket code stays in the device package for now, as it reaches deeply into the peer list. This is the first step in an effort to split some code out of the very busy device package. Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
Diffstat (limited to 'conn')
-rw-r--r--conn/boundif_windows.go59
-rw-r--r--conn/conn.go101
-rw-r--r--conn/conn_default.go177
-rw-r--r--conn/conn_linux.go571
-rw-r--r--conn/mark_default.go12
-rw-r--r--conn/mark_unix.go65
6 files changed, 985 insertions, 0 deletions
diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go
new file mode 100644
index 0000000..fe38d05
--- /dev/null
+++ b/conn/boundif_windows.go
@@ -0,0 +1,59 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "encoding/binary"
+ "unsafe"
+
+ "golang.org/x/sys/windows"
+)
+
+const (
+ sockoptIP_UNICAST_IF = 31
+ sockoptIPV6_UNICAST_IF = 31
+)
+
+func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
+ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
+ bytes := make([]byte, 4)
+ binary.BigEndian.PutUint32(bytes, interfaceIndex)
+ interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
+
+ sysconn, err := bind.ipv4.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex))
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ bind.blackhole4 = blackhole
+ return nil
+}
+
+func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
+ sysconn, err := bind.ipv6.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err2 := sysconn.Control(func(fd uintptr) {
+ err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex))
+ })
+ if err2 != nil {
+ return err2
+ }
+ if err != nil {
+ return err
+ }
+ bind.blackhole6 = blackhole
+ return nil
+}
diff --git a/conn/conn.go b/conn/conn.go
new file mode 100644
index 0000000..6b7db12
--- /dev/null
+++ b/conn/conn.go
@@ -0,0 +1,101 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package conn implements WireGuard's network connections.
+package conn
+
+import (
+ "errors"
+ "net"
+ "strings"
+)
+
+// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
+type Bind interface {
+ // LastMark reports the last mark set for this Bind.
+ LastMark() uint32
+
+ // SetMark sets the mark for each packet sent through this Bind.
+ // This mark is passed to the kernel as the socket option SO_MARK.
+ SetMark(mark uint32) error
+
+ // ReceiveIPv6 reads an IPv6 UDP packet into b.
+ //
+ // It reports the number of bytes read, n,
+ // the packet source address ep,
+ // and any error.
+ ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
+
+ // ReceiveIPv4 reads an IPv4 UDP packet into b.
+ //
+ // It reports the number of bytes read, n,
+ // the packet source address ep,
+ // and any error.
+ ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
+
+ // Send writes a packet b to address ep.
+ Send(b []byte, ep Endpoint) error
+
+ // Close closes the Bind connection.
+ Close() error
+}
+
+// CreateBind creates a Bind bound to a port.
+//
+// The value actualPort reports the actual port number the Bind
+// object gets bound to.
+func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
+ return createBind(port)
+}
+
+// BindToInterface is implemented by Bind objects that support being
+// tied to a single network interface.
+type BindToInterface interface {
+ BindToInterface4(interfaceIndex uint32, blackhole bool) error
+ BindToInterface6(interfaceIndex uint32, blackhole bool) error
+}
+
+// An Endpoint maintains the source/destination caching for a peer.
+//
+// dst : the remote address of a peer ("endpoint" in uapi terminology)
+// src : the local address from which datagrams originate going to the peer
+type Endpoint interface {
+ ClearSrc() // clears the source address
+ SrcToString() string // returns the local source address (ip:port)
+ DstToString() string // returns the destination address (ip:port)
+ DstToBytes() []byte // used for mac2 cookie calculations
+ DstIP() net.IP
+ SrcIP() net.IP
+}
+
+func parseEndpoint(s string) (*net.UDPAddr, error) {
+ // ensure that the host is an IP address
+
+ host, _, err := net.SplitHostPort(s)
+ if err != nil {
+ return nil, err
+ }
+ if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
+ // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
+ // trying to make sure with a small sanity test that this is a real IP address and
+ // not something that's likely to incur DNS lookups.
+ host = host[:i]
+ }
+ if ip := net.ParseIP(host); ip == nil {
+ return nil, errors.New("Failed to parse IP address: " + host)
+ }
+
+ // parse address and port
+
+ addr, err := net.ResolveUDPAddr("udp", s)
+ if err != nil {
+ return nil, err
+ }
+ ip4 := addr.IP.To4()
+ if ip4 != nil {
+ addr.IP = ip4
+ }
+ return addr, err
+}
diff --git a/conn/conn_default.go b/conn/conn_default.go
new file mode 100644
index 0000000..bad9d4d
--- /dev/null
+++ b/conn/conn_default.go
@@ -0,0 +1,177 @@
+// +build !linux android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "net"
+ "os"
+ "syscall"
+)
+
+/* This code is meant to be a temporary solution
+ * on platforms for which the sticky socket / source caching behavior
+ * has not yet been implemented.
+ *
+ * See conn_linux.go for an implementation on the linux platform.
+ */
+
+type nativeBind struct {
+ ipv4 *net.UDPConn
+ ipv6 *net.UDPConn
+ blackhole4 bool
+ blackhole6 bool
+}
+
+type NativeEndpoint net.UDPAddr
+
+var _ Bind = (*nativeBind)(nil)
+var _ Endpoint = (*NativeEndpoint)(nil)
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ addr, err := parseEndpoint(s)
+ return (*NativeEndpoint)(addr), err
+}
+
+func (_ *NativeEndpoint) ClearSrc() {}
+
+func (e *NativeEndpoint) DstIP() net.IP {
+ return (*net.UDPAddr)(e).IP
+}
+
+func (e *NativeEndpoint) SrcIP() net.IP {
+ return nil // not supported
+}
+
+func (e *NativeEndpoint) DstToBytes() []byte {
+ addr := (*net.UDPAddr)(e)
+ out := addr.IP.To4()
+ if out == nil {
+ out = addr.IP
+ }
+ out = append(out, byte(addr.Port&0xff))
+ out = append(out, byte((addr.Port>>8)&0xff))
+ return out
+}
+
+func (e *NativeEndpoint) DstToString() string {
+ return (*net.UDPAddr)(e).String()
+}
+
+func (e *NativeEndpoint) SrcToString() string {
+ return ""
+}
+
+func listenNet(network string, port int) (*net.UDPConn, int, error) {
+ conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
+ if err != nil {
+ return nil, 0, err
+ }
+
+ // Retrieve port.
+ // TODO(crawshaw): under what circumstances is this necessary?
+ laddr := conn.LocalAddr()
+ uaddr, err := net.ResolveUDPAddr(
+ laddr.Network(),
+ laddr.String(),
+ )
+ if err != nil {
+ return nil, 0, err
+ }
+ return conn, uaddr.Port, nil
+}
+
+func extractErrno(err error) error {
+ opErr, ok := err.(*net.OpError)
+ if !ok {
+ return nil
+ }
+ syscallErr, ok := opErr.Err.(*os.SyscallError)
+ if !ok {
+ return nil
+ }
+ return syscallErr.Err
+}
+
+func createBind(uport uint16) (Bind, uint16, error) {
+ var err error
+ var bind nativeBind
+
+ port := int(uport)
+
+ bind.ipv4, port, err = listenNet("udp4", port)
+ if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
+ return nil, 0, err
+ }
+
+ bind.ipv6, port, err = listenNet("udp6", port)
+ if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
+ bind.ipv4.Close()
+ bind.ipv4 = nil
+ return nil, 0, err
+ }
+
+ return &bind, uint16(port), nil
+}
+
+func (bind *nativeBind) Close() error {
+ var err1, err2 error
+ if bind.ipv4 != nil {
+ err1 = bind.ipv4.Close()
+ }
+ if bind.ipv6 != nil {
+ err2 = bind.ipv6.Close()
+ }
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind *nativeBind) LastMark() uint32 { return 0 }
+
+func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ if bind.ipv4 == nil {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
+ if endpoint != nil {
+ endpoint.IP = endpoint.IP.To4()
+ }
+ return n, (*NativeEndpoint)(endpoint), err
+}
+
+func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ if bind.ipv6 == nil {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
+ return n, (*NativeEndpoint)(endpoint), err
+}
+
+func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
+ var err error
+ nend := endpoint.(*NativeEndpoint)
+ if nend.IP.To4() != nil {
+ if bind.ipv4 == nil {
+ return syscall.EAFNOSUPPORT
+ }
+ if bind.blackhole4 {
+ return nil
+ }
+ _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ } else {
+ if bind.ipv6 == nil {
+ return syscall.EAFNOSUPPORT
+ }
+ if bind.blackhole6 {
+ return nil
+ }
+ _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
+ }
+ return err
+}
diff --git a/conn/conn_linux.go b/conn/conn_linux.go
new file mode 100644
index 0000000..523da4a
--- /dev/null
+++ b/conn/conn_linux.go
@@ -0,0 +1,571 @@
+// +build !android
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "errors"
+ "net"
+ "strconv"
+ "sync"
+ "syscall"
+ "unsafe"
+
+ "golang.org/x/sys/unix"
+)
+
+const (
+ FD_ERR = -1
+)
+
+type IPv4Source struct {
+ Src [4]byte
+ Ifindex int32
+}
+
+type IPv6Source struct {
+ src [16]byte
+ //ifindex belongs in dst.ZoneId
+}
+
+type NativeEndpoint struct {
+ sync.Mutex
+ dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
+ src [unsafe.Sizeof(IPv6Source{})]byte
+ isV6 bool
+}
+
+func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
+func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
+func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
+
+func (endpoint *NativeEndpoint) src4() *IPv4Source {
+ return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
+}
+
+func (endpoint *NativeEndpoint) src6() *IPv6Source {
+ return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
+}
+
+func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
+ return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
+}
+
+func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
+ return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
+}
+
+type nativeBind struct {
+ sock4 int
+ sock6 int
+ lastMark uint32
+}
+
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = (*nativeBind)(nil)
+
+func CreateEndpoint(s string) (Endpoint, error) {
+ var end NativeEndpoint
+ addr, err := parseEndpoint(s)
+ if err != nil {
+ return nil, err
+ }
+
+ ipv4 := addr.IP.To4()
+ if ipv4 != nil {
+ dst := end.dst4()
+ end.isV6 = false
+ dst.Port = addr.Port
+ copy(dst.Addr[:], ipv4)
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ ipv6 := addr.IP.To16()
+ if ipv6 != nil {
+ zone, err := zoneToUint32(addr.Zone)
+ if err != nil {
+ return nil, err
+ }
+ dst := end.dst6()
+ end.isV6 = true
+ dst.Port = addr.Port
+ dst.ZoneId = zone
+ copy(dst.Addr[:], ipv6[:])
+ end.ClearSrc()
+ return &end, nil
+ }
+
+ return nil, errors.New("Invalid IP address")
+}
+
+func createBind(port uint16) (Bind, uint16, error) {
+ var err error
+ var bind nativeBind
+ var newPort uint16
+
+ // Attempt ipv6 bind, update port if successful.
+ bind.sock6, newPort, err = create6(port)
+ if err != nil {
+ if err != syscall.EAFNOSUPPORT {
+ return nil, 0, err
+ }
+ } else {
+ port = newPort
+ }
+
+ // Attempt ipv4 bind, update port if successful.
+ bind.sock4, newPort, err = create4(port)
+ if err != nil {
+ if err != syscall.EAFNOSUPPORT {
+ unix.Close(bind.sock6)
+ return nil, 0, err
+ }
+ } else {
+ port = newPort
+ }
+
+ if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
+ return nil, 0, errors.New("ipv4 and ipv6 not supported")
+ }
+
+ return &bind, port, nil
+}
+
+func (bind *nativeBind) LastMark() uint32 {
+ return bind.lastMark
+}
+
+func (bind *nativeBind) SetMark(value uint32) error {
+ if bind.sock6 != -1 {
+ err := unix.SetsockoptInt(
+ bind.sock6,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+
+ if err != nil {
+ return err
+ }
+ }
+
+ if bind.sock4 != -1 {
+ err := unix.SetsockoptInt(
+ bind.sock4,
+ unix.SOL_SOCKET,
+ unix.SO_MARK,
+ int(value),
+ )
+
+ if err != nil {
+ return err
+ }
+ }
+
+ bind.lastMark = value
+ return nil
+}
+
+func closeUnblock(fd int) error {
+ // shutdown to unblock readers and writers
+ unix.Shutdown(fd, unix.SHUT_RDWR)
+ return unix.Close(fd)
+}
+
+func (bind *nativeBind) Close() error {
+ var err1, err2 error
+ if bind.sock6 != -1 {
+ err1 = closeUnblock(bind.sock6)
+ }
+ if bind.sock4 != -1 {
+ err2 = closeUnblock(bind.sock4)
+ }
+
+ if err1 != nil {
+ return err1
+ }
+ return err2
+}
+
+func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ if bind.sock6 == -1 {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, err := receive6(
+ bind.sock6,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ if bind.sock4 == -1 {
+ return 0, nil, syscall.EAFNOSUPPORT
+ }
+ n, err := receive4(
+ bind.sock4,
+ buff,
+ &end,
+ )
+ return n, &end, err
+}
+
+func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
+ nend := end.(*NativeEndpoint)
+ if !nend.isV6 {
+ if bind.sock4 == -1 {
+ return syscall.EAFNOSUPPORT
+ }
+ return send4(bind.sock4, nend, buff)
+ } else {
+ if bind.sock6 == -1 {
+ return syscall.EAFNOSUPPORT
+ }
+ return send6(bind.sock6, nend, buff)
+ }
+}
+
+func (end *NativeEndpoint) SrcIP() net.IP {
+ if !end.isV6 {
+ return net.IPv4(
+ end.src4().Src[0],
+ end.src4().Src[1],
+ end.src4().Src[2],
+ end.src4().Src[3],
+ )
+ } else {
+ return end.src6().src[:]
+ }
+}
+
+func (end *NativeEndpoint) DstIP() net.IP {
+ if !end.isV6 {
+ return net.IPv4(
+ end.dst4().Addr[0],
+ end.dst4().Addr[1],
+ end.dst4().Addr[2],
+ end.dst4().Addr[3],
+ )
+ } else {
+ return end.dst6().Addr[:]
+ }
+}
+
+func (end *NativeEndpoint) DstToBytes() []byte {
+ if !end.isV6 {
+ return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
+ } else {
+ return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
+ }
+}
+
+func (end *NativeEndpoint) SrcToString() string {
+ return end.SrcIP().String()
+}
+
+func (end *NativeEndpoint) DstToString() string {
+ var udpAddr net.UDPAddr
+ udpAddr.IP = end.DstIP()
+ if !end.isV6 {
+ udpAddr.Port = end.dst4().Port
+ } else {
+ udpAddr.Port = end.dst6().Port
+ }
+ return udpAddr.String()
+}
+
+func (end *NativeEndpoint) ClearDst() {
+ for i := range end.dst {
+ end.dst[i] = 0
+ }
+}
+
+func (end *NativeEndpoint) ClearSrc() {
+ for i := range end.src {
+ end.src[i] = 0
+ }
+}
+
+func zoneToUint32(zone string) (uint32, error) {
+ if zone == "" {
+ return 0, nil
+ }
+ if intr, err := net.InterfaceByName(zone); err == nil {
+ return uint32(intr.Index), nil
+ }
+ n, err := strconv.ParseUint(zone, 10, 32)
+ return uint32(n), err
+}
+
+func create4(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return FD_ERR, 0, err
+ }
+
+ addr := unix.SockaddrInet4{
+ Port: int(port),
+ }
+
+ // set sockopts and bind
+
+ if err := func() error {
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IP,
+ unix.IP_PKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+ }(); err != nil {
+ unix.Close(fd)
+ return FD_ERR, 0, err
+ }
+
+ sa, err := unix.Getsockname(fd)
+ if err == nil {
+ addr.Port = sa.(*unix.SockaddrInet4).Port
+ }
+
+ return fd, uint16(addr.Port), err
+}
+
+func create6(port uint16) (int, uint16, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET6,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return FD_ERR, 0, err
+ }
+
+ // set sockopts and bind
+
+ addr := unix.SockaddrInet6{
+ Port: int(port),
+ }
+
+ if err := func() error {
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_RECVPKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_V6ONLY,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ return unix.Bind(fd, &addr)
+
+ }(); err != nil {
+ unix.Close(fd)
+ return FD_ERR, 0, err
+ }
+
+ sa, err := unix.Getsockname(fd)
+ if err == nil {
+ addr.Port = sa.(*unix.SockaddrInet6).Port
+ }
+
+ return fd, uint16(addr.Port), err
+}
+
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
+
+ // construct message header
+
+ cmsg := struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ }{
+ unix.Cmsghdr{
+ Level: unix.IPPROTO_IP,
+ Type: unix.IP_PKTINFO,
+ Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
+ },
+ unix.Inet4Pktinfo{
+ Spec_dst: end.src4().Src,
+ Ifindex: end.src4().Ifindex,
+ },
+ }
+
+ end.Lock()
+ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+ end.Unlock()
+
+ if err == nil {
+ return nil
+ }
+
+ // clear src and retry
+
+ if err == unix.EINVAL {
+ end.ClearSrc()
+ cmsg.pktinfo = unix.Inet4Pktinfo{}
+ end.Lock()
+ _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
+ end.Unlock()
+ }
+
+ return err
+}
+
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
+
+ // construct message header
+
+ cmsg := struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet6Pktinfo
+ }{
+ unix.Cmsghdr{
+ Level: unix.IPPROTO_IPV6,
+ Type: unix.IPV6_PKTINFO,
+ Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
+ },
+ unix.Inet6Pktinfo{
+ Addr: end.src6().src,
+ Ifindex: end.dst6().ZoneId,
+ },
+ }
+
+ if cmsg.pktinfo.Addr == [16]byte{} {
+ cmsg.pktinfo.Ifindex = 0
+ }
+
+ end.Lock()
+ _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+ end.Unlock()
+
+ if err == nil {
+ return nil
+ }
+
+ // clear src and retry
+
+ if err == unix.EINVAL {
+ end.ClearSrc()
+ cmsg.pktinfo = unix.Inet6Pktinfo{}
+ end.Lock()
+ _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
+ end.Unlock()
+ }
+
+ return err
+}
+
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+
+ // construct message header
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
+ }
+
+ size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
+
+ if err != nil {
+ return 0, err
+ }
+ end.isV6 = false
+
+ if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
+ *end.dst4() = *newDst4
+ }
+
+ // update source cache
+
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+ cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+ end.src4().Src = cmsg.pktinfo.Spec_dst
+ end.src4().Ifindex = cmsg.pktinfo.Ifindex
+ }
+
+ return size, nil
+}
+
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
+
+ // construct message header
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet6Pktinfo
+ }
+
+ size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
+
+ if err != nil {
+ return 0, err
+ }
+ end.isV6 = true
+
+ if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
+ *end.dst6() = *newDst6
+ }
+
+ // update source cache
+
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
+ cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
+ end.src6().src = cmsg.pktinfo.Addr
+ end.dst6().ZoneId = cmsg.pktinfo.Ifindex
+ }
+
+ return size, nil
+}
diff --git a/conn/mark_default.go b/conn/mark_default.go
new file mode 100644
index 0000000..fc41ba9
--- /dev/null
+++ b/conn/mark_default.go
@@ -0,0 +1,12 @@
+// +build !linux,!openbsd,!freebsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+func (bind *nativeBind) SetMark(mark uint32) error {
+ return nil
+}
diff --git a/conn/mark_unix.go b/conn/mark_unix.go
new file mode 100644
index 0000000..5334582
--- /dev/null
+++ b/conn/mark_unix.go
@@ -0,0 +1,65 @@
+// +build android openbsd freebsd
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
+ */
+
+package conn
+
+import (
+ "runtime"
+
+ "golang.org/x/sys/unix"
+)
+
+var fwmarkIoctl int
+
+func init() {
+ switch runtime.GOOS {
+ case "linux", "android":
+ fwmarkIoctl = 36 /* unix.SO_MARK */
+ case "freebsd":
+ fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
+ case "openbsd":
+ fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
+ }
+}
+
+func (bind *nativeBind) SetMark(mark uint32) error {
+ var operr error
+ if fwmarkIoctl == 0 {
+ return nil
+ }
+ if bind.ipv4 != nil {
+ fd, err := bind.ipv4.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err = fd.Control(func(fd uintptr) {
+ operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+ })
+ if err == nil {
+ err = operr
+ }
+ if err != nil {
+ return err
+ }
+ }
+ if bind.ipv6 != nil {
+ fd, err := bind.ipv6.SyscallConn()
+ if err != nil {
+ return err
+ }
+ err = fd.Control(func(fd uintptr) {
+ operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
+ })
+ if err == nil {
+ err = operr
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}