aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2018-04-27 05:21:45 +0200
committerJason A. Donenfeld <Jason@zx2c4.com>2018-04-27 05:41:07 +0200
commitb34604245ec4dfb50846d0ba28d022be5b756c25 (patch)
tree5a175eeabada57b652189db8a576cb136ff3e4a0
parentFix error handling and cleanup of netlink listener (diff)
downloadwireguard-go-b34604245ec4dfb50846d0ba28d022be5b756c25.tar.xz
wireguard-go-b34604245ec4dfb50846d0ba28d022be5b756c25.zip
Clear src cache if route changes to new ifindex
-rw-r--r--conn_linux.go160
-rw-r--r--tun_linux.go5
2 files changed, 151 insertions, 14 deletions
diff --git a/conn_linux.go b/conn_linux.go
index 88b9ef4..ff3c483 100644
--- a/conn_linux.go
+++ b/conn_linux.go
@@ -53,12 +53,15 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
}
type NativeBind struct {
- sock4 int
- sock6 int
+ sock4 int
+ sock6 int
+ netlinkSock int
+ lastEndpoint *NativeEndpoint
+ lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
-var _ Bind = NativeBind{}
+var _ Bind = (*NativeBind)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
@@ -95,23 +98,50 @@ func CreateEndpoint(s string) (Endpoint, error) {
return nil, errors.New("Invalid IP address")
}
-func CreateBind(port uint16) (Bind, uint16, error) {
+func createNetlinkRouteSocket() (int, error) {
+ sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
+ if err != nil {
+ return -1, err
+ }
+ saddr := &unix.SockaddrNetlink{
+ Family: unix.AF_NETLINK,
+ Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
+ }
+ err = unix.Bind(sock, saddr)
+ if err != nil {
+ unix.Close(sock)
+ return -1, err
+ }
+ return sock, nil
+
+}
+
+func CreateBind(port uint16) (*NativeBind, uint16, error) {
var err error
var bind NativeBind
+ bind.netlinkSock, err = createNetlinkRouteSocket()
+ if err != nil {
+ return nil, 0, err
+ }
+
+ go bind.routineRouteListener()
+
bind.sock6, port, err = create6(port)
if err != nil {
+ unix.Close(bind.netlinkSock)
return nil, port, err
}
bind.sock4, port, err = create4(port)
if err != nil {
+ unix.Close(bind.netlinkSock)
unix.Close(bind.sock6)
}
- return bind, port, err
+ return &bind, port, err
}
-func (bind NativeBind) SetMark(value uint32) error {
+func (bind *NativeBind) SetMark(value uint32) error {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
@@ -123,12 +153,19 @@ func (bind NativeBind) SetMark(value uint32) error {
return err
}
- return unix.SetsockoptInt(
+ err = unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
+
+ if err != nil {
+ return err
+ }
+
+ bind.lastMark = value
+ return nil
}
func closeUnblock(fd int) error {
@@ -137,16 +174,20 @@ func closeUnblock(fd int) error {
return unix.Close(fd)
}
-func (bind NativeBind) Close() error {
+func (bind *NativeBind) Close() error {
err1 := closeUnblock(bind.sock6)
err2 := closeUnblock(bind.sock4)
+ err3 := closeUnblock(bind.netlinkSock)
if err1 != nil {
return err1
}
- return err2
+ if err2 != nil {
+ return err2
+ }
+ return err3
}
-func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive6(
bind.sock6,
@@ -156,17 +197,18 @@ func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
return n, &end, err
}
-func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
n, err := receive4(
bind.sock4,
buff,
&end,
)
+ bind.lastEndpoint = &end
return n, &end, err
}
-func (bind NativeBind) Send(buff []byte, end Endpoint) error {
+func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
if !nend.isV6 {
return send4(bind.sock4, nend, buff)
@@ -506,3 +548,97 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
return size, nil
}
+
+func (bind *NativeBind) routineRouteListener() {
+ // TODO: this function doesn't lock the endpoint it modifies
+
+ for msg := make([]byte, 1<<16); ; {
+ msgn, _, _, _, err := unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
+ if err != nil {
+ return
+ }
+
+ for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
+
+ hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
+
+ if uint(hdr.Len) > uint(len(remain)) {
+ break
+ }
+
+ switch hdr.Type {
+ case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
+
+ if bind.lastEndpoint == nil || bind.lastEndpoint.isV6 || bind.lastEndpoint.src4().ifindex == 0 {
+ break
+ }
+
+ if hdr.Seq == 0xff {
+ if uint(len(remain)) < uint(hdr.Len) {
+ break
+ }
+ if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
+ attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
+ for {
+ if uint(len(attr)) < uint(unix.SizeofRtAttr) {
+ break
+ }
+ attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
+ if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
+ break
+ }
+ if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
+ ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
+ if uint32(bind.lastEndpoint.src4().ifindex) != ifidx {
+ bind.lastEndpoint.ClearSrc()
+ }
+ }
+ attr = attr[attrhdr.Len:]
+ }
+ }
+ break
+ }
+
+ nlmsg := struct {
+ hdr unix.NlMsghdr
+ msg unix.RtMsg
+ dsthdr unix.RtAttr
+ dst [4]byte
+ srchdr unix.RtAttr
+ src [4]byte
+ markhdr unix.RtAttr
+ mark uint32
+ }{
+ unix.NlMsghdr{
+ Type: uint16(unix.RTM_GETROUTE),
+ Flags: unix.NLM_F_REQUEST,
+ Seq: 0xff,
+ },
+ unix.RtMsg{
+ Family: unix.AF_INET,
+ Dst_len: 32,
+ Src_len: 32,
+ },
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_DST,
+ },
+ bind.lastEndpoint.dst4().Addr,
+ unix.RtAttr{
+ Len: 8,
+ Type: unix.RTA_SRC,
+ },
+ bind.lastEndpoint.src4().src,
+ unix.RtAttr{
+ Len: 8,
+ Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
+ },
+ uint32(bind.lastMark),
+ }
+ nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
+ unix.Write(bind.netlinkSock, (*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
+ }
+ remain = remain[hdr.Len:]
+ }
+ }
+}
diff --git a/tun_linux.go b/tun_linux.go
index 0672b5e..b0ffa00 100644
--- a/tun_linux.go
+++ b/tun_linux.go
@@ -79,7 +79,6 @@ func (tun *NativeTun) RoutineNetlinkListener() {
defer unix.Close(sock)
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
- Pid: uint32(os.Getpid()),
Groups: uint32(groups),
}
err = unix.Bind(sock, saddr)
@@ -90,7 +89,9 @@ func (tun *NativeTun) RoutineNetlinkListener() {
// TODO: This function never actually exits in response to anything,
// a go routine that goes forever. We'll want to fix that if this is
- // to ever be used as any sort of library.
+ // to ever be used as any sort of library. See what we've done with
+ // calling shutdown() on the netlink socket in conn_linux.go, and
+ // change this to be more like that.
for msg := make([]byte, 1<<16); ; {