aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-10-06 22:56:01 +0200
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-10-06 22:56:01 +0200
commitc70f0c5da2a97715f5989f0d95ec795bdb085898 (patch)
tree64d41909fd4b57b21333d14fa5564a9916a422cd
parentSleep to close fd (diff)
downloadwireguard-go-c70f0c5da2a97715f5989f0d95ec795bdb085898.tar.xz
wireguard-go-c70f0c5da2a97715f5989f0d95ec795bdb085898.zip
Definition of platform specific socket bind
-rw-r--r--src/conn.go2
-rw-r--r--src/conn_default.go2
-rw-r--r--src/conn_linux.go230
-rw-r--r--src/uapi.go2
4 files changed, 197 insertions, 39 deletions
diff --git a/src/conn.go b/src/conn.go
index 2cf588d..60cd789 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -56,7 +56,7 @@ func updateUDPConn(device *Device) error {
// set fwmark
- err = setMark(netc.conn, netc.fwmark)
+ err = SetMark(netc.conn, netc.fwmark)
if err != nil {
return err
}
diff --git a/src/conn_default.go b/src/conn_default.go
index e7c60a8..279643e 100644
--- a/src/conn_default.go
+++ b/src/conn_default.go
@@ -6,6 +6,6 @@ import (
"net"
)
-func setMark(conn *net.UDPConn, value uint32) error {
+func SetMark(conn *net.UDPConn, value uint32) error {
return nil
}
diff --git a/src/conn_linux.go b/src/conn_linux.go
index a349a9e..64447a5 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -14,23 +14,30 @@ import (
"unsafe"
)
+import "fmt"
+
/* Supports source address caching
*
- * It is important that the endpoint is only updated after the packet content has been authenticated.
- *
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
+ * So this code is platform dependent.
+ *
+ * It is important that the endpoint is only updated after the packet content has been authenticated!
*/
+
type Endpoint struct {
// source (selected based on dst type)
// (could use RawSockaddrAny and unsafe)
- srcIPv6 unix.RawSockaddrInet6
- srcIPv4 unix.RawSockaddrInet4
- srcIf4 int32
+ src6 unix.RawSockaddrInet6
+ src4 unix.RawSockaddrInet4
+ src4if int32
dst unix.RawSockaddrAny
}
+type IPv4Socket int
+type IPv6Socket int
+
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
return 0, nil
@@ -42,10 +49,115 @@ func zoneToUint32(zone string) (uint32, error) {
return uint32(n), err
}
+func CreateIPv4Socket(port int) (IPv4Socket, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return -1, err
+ }
+
+ // set sockopts and bind
+
+ if err := func() error {
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IP,
+ unix.IP_PKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ addr := unix.SockaddrInet4{
+ Port: port,
+ }
+ return unix.Bind(fd, &addr)
+
+ }(); err != nil {
+ unix.Close(fd)
+ }
+
+ return IPv4Socket(fd), err
+}
+
+func CreateIPv6Socket(port int) (IPv6Socket, error) {
+
+ // create socket
+
+ fd, err := unix.Socket(
+ unix.AF_INET,
+ unix.SOCK_DGRAM,
+ 0,
+ )
+
+ if err != nil {
+ return -1, err
+ }
+
+ // set sockopts and bind
+
+ if err := func() error {
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.SOL_SOCKET,
+ unix.SO_REUSEADDR,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_RECVPKTINFO,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ if err := unix.SetsockoptInt(
+ fd,
+ unix.IPPROTO_IPV6,
+ unix.IPV6_V6ONLY,
+ 1,
+ ); err != nil {
+ return err
+ }
+
+ addr := unix.SockaddrInet6{
+ Port: port,
+ }
+ return unix.Bind(fd, &addr)
+
+ }(); err != nil {
+ unix.Close(fd)
+ }
+
+ return IPv6Socket(fd), err
+}
+
func (end *Endpoint) ClearSrc() {
- end.srcIf4 = 0
- end.srcIPv4 = unix.RawSockaddrInet4{}
- end.srcIPv6 = unix.RawSockaddrInet6{}
+ end.src4if = 0
+ end.src4 = unix.RawSockaddrInet4{}
+ end.src6 = unix.RawSockaddrInet6{}
}
func (end *Endpoint) Set(s string) error {
@@ -85,8 +197,10 @@ func (end *Endpoint) Set(s string) error {
}
func send6(sock uintptr, end *Endpoint, buff []byte) error {
- var iovec unix.Iovec
+ // construct message header
+
+ var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
@@ -100,8 +214,8 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
Len: unix.SizeofInet6Pktinfo,
},
unix.Inet6Pktinfo{
- Addr: end.srcIPv6.Addr,
- Ifindex: end.srcIPv6.Scope_id,
+ Addr: end.src6.Addr,
+ Ifindex: end.src6.Scope_id,
},
}
@@ -130,8 +244,10 @@ func send6(sock uintptr, end *Endpoint, buff []byte) error {
}
func send4(sock uintptr, end *Endpoint, buff []byte) error {
- var iovec unix.Iovec
+ // construct message header
+
+ var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
@@ -142,11 +258,11 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
- Len: unix.SizeofInet6Pktinfo,
+ Len: unix.SizeofInet4Pktinfo,
},
unix.Inet4Pktinfo{
- Spec_dst: end.srcIPv4.Addr,
- Ifindex: end.srcIf4,
+ Spec_dst: end.src4.Addr,
+ Ifindex: end.src4if,
},
}
@@ -174,7 +290,7 @@ func send4(sock uintptr, end *Endpoint, buff []byte) error {
return errno
}
-func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
+func (end *Endpoint) Send(c *net.UDPConn, buff []byte) error {
// extract underlying file descriptor
@@ -195,60 +311,102 @@ func send(c *net.UDPConn, end *Endpoint, buff []byte) error {
return errors.New("Unknown address family of source")
}
-func receiveIPv4(end *Endpoint, c *net.UDPConn, buff []byte) (error, *net.UDPAddr, *net.UDPAddr) {
+func (end *Endpoint) ReceiveIPv4(sock IPv4Socket, buff []byte) (int, error) {
- file, err := c.File()
- if err != nil {
- return err, nil, nil
+ // contruct message header
+
+ var iovec unix.Iovec
+ iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
+ iovec.SetLen(len(buff))
+
+ var cmsg struct {
+ cmsghdr unix.Cmsghdr
+ pktinfo unix.Inet4Pktinfo
}
+ var msghdr unix.Msghdr
+ msghdr.Iov = &iovec
+ msghdr.Iovlen = 1
+ msghdr.Name = (*byte)(unsafe.Pointer(&end.dst))
+ msghdr.Namelen = unix.SizeofSockaddrInet4
+ msghdr.Control = (*byte)(unsafe.Pointer(&cmsg))
+ msghdr.SetControllen(int(unsafe.Sizeof(cmsg)))
+
+ // recvmsg(sock, &mskhdr, 0)
+
+ size, _, errno := unix.Syscall(
+ unix.SYS_RECVMSG,
+ uintptr(sock),
+ uintptr(unsafe.Pointer(&msghdr)),
+ 0,
+ )
+
+ if errno != 0 {
+ return 0, errno
+ }
+
+ fmt.Println(msghdr)
+ fmt.Println(cmsg)
+
+ // update source cache
+
+ if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
+ cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
+ cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
+ end.src4.Addr = cmsg.pktinfo.Spec_dst
+ end.src4if = cmsg.pktinfo.Ifindex
+ }
+
+ return int(size), nil
+}
+
+func (end *Endpoint) ReceiveIPv6(sock IPv6Socket, buff []byte) error {
+
+ // contruct message header
+
var iovec unix.Iovec
iovec.Base = (*byte)(unsafe.Pointer(&buff[0]))
iovec.SetLen(len(buff))
var cmsg struct {
cmsghdr unix.Cmsghdr
- pktinfo unix.Inet6Pktinfo // big enough
+ pktinfo unix.Inet6Pktinfo
}
var msg unix.Msghdr
msg.Iov = &iovec
msg.Iovlen = 1
msg.Name = (*byte)(unsafe.Pointer(&end.dst))
- msg.Namelen = uint32(unix.SizeofSockaddrAny)
+ msg.Namelen = uint32(unix.SizeofSockaddrInet6)
msg.Control = (*byte)(unsafe.Pointer(&cmsg))
msg.SetControllen(int(unsafe.Sizeof(cmsg)))
+ // recvmsg(sock, &mskhdr, 0)
+
_, _, errno := unix.Syscall(
unix.SYS_RECVMSG,
- file.Fd(),
+ uintptr(sock),
uintptr(unsafe.Pointer(&msg)),
0,
)
if errno != 0 {
- return errno, nil, nil
+ return errno
}
+ // update source cache
+
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
-
- }
-
- if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
- cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
- cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
-
- info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&cmsg.pktinfo))
- println(info)
-
+ end.src6.Addr = cmsg.pktinfo.Addr
+ end.src6.Scope_id = cmsg.pktinfo.Ifindex
}
- return nil, nil, nil
+ return nil
}
-func setMark(conn *net.UDPConn, value uint32) error {
+func SetMark(conn *net.UDPConn, value uint32) error {
if conn == nil {
return nil
}
diff --git a/src/uapi.go b/src/uapi.go
index 326216b..7d08e56 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -166,7 +166,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.net.mutex.Lock()
if fwmark > 0 || device.net.fwmark > 0 {
device.net.fwmark = uint32(fwmark)
- err := setMark(
+ err := SetMark(
device.net.conn,
device.net.fwmark,
)