aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-01-07 17:00:21 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-01-07 17:08:58 +0100
commit3b3de758ec898e47aef609fbf16d78e97dac2000 (patch)
treeff243a1b7f76a47843fb2132afbd19cb760754b5
parentdevice: receive: drain decryption queue before exiting RoutineDecryption (diff)
downloadwireguard-go-3b3de758ec898e47aef609fbf16d78e97dac2000.tar.xz
wireguard-go-3b3de758ec898e47aef609fbf16d78e97dac2000.zip
conn: linux: do not allow ReceiveIPvX to race with Close
If Close is called after ReceiveIPvX, then ReceiveIPvX will block on an invalid or potentially reused fd. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--conn/conn_linux.go49
1 files changed, 32 insertions, 17 deletions
diff --git a/conn/conn_linux.go b/conn/conn_linux.go
index ef98100..ef5c0ba 100644
--- a/conn/conn_linux.go
+++ b/conn/conn_linux.go
@@ -18,10 +18,6 @@ import (
"golang.org/x/sys/unix"
)
-const (
- FD_ERR = -1
-)
-
type IPv4Source struct {
Src [4]byte
Ifindex int32
@@ -63,6 +59,7 @@ type nativeBind struct {
sock4 int
sock6 int
lastMark uint32
+ closing sync.RWMutex
}
var _ Endpoint = (*NativeEndpoint)(nil)
@@ -129,7 +126,7 @@ func createBind(port uint16) (Bind, uint16, error) {
port = newPort
}
- if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
+ if bind.sock4 == -1 && bind.sock6 == -1 {
return nil, 0, errors.New("ipv4 and ipv6 not supported")
}
@@ -141,6 +138,9 @@ func (bind *nativeBind) LastMark() uint32 {
}
func (bind *nativeBind) SetMark(value uint32) error {
+ bind.closing.RLock()
+ defer bind.closing.RUnlock()
+
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
bind.sock6,
@@ -171,20 +171,26 @@ func (bind *nativeBind) SetMark(value uint32) error {
return nil
}
-func closeUnblock(fd int) error {
- // shutdown to unblock readers and writers
- unix.Shutdown(fd, unix.SHUT_RDWR)
- return unix.Close(fd)
-}
-
func (bind *nativeBind) Close() error {
var err1, err2 error
+ bind.closing.RLock()
+ if bind.sock6 != -1 {
+ unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
+ }
+ if bind.sock4 != -1 {
+ unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
+ }
+ bind.closing.RUnlock()
+ bind.closing.Lock()
if bind.sock6 != -1 {
- err1 = closeUnblock(bind.sock6)
+ err1 = unix.Close(bind.sock6)
+ bind.sock6 = -1
}
if bind.sock4 != -1 {
- err2 = closeUnblock(bind.sock4)
+ err2 = unix.Close(bind.sock4)
+ bind.sock4 = -1
}
+ bind.closing.Unlock()
if err1 != nil {
return err1
@@ -193,6 +199,9 @@ func (bind *nativeBind) Close() error {
}
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ bind.closing.RLock()
+ defer bind.closing.RUnlock()
+
var end NativeEndpoint
if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
@@ -206,6 +215,9 @@ func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
}
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ bind.closing.RLock()
+ defer bind.closing.RUnlock()
+
var end NativeEndpoint
if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
@@ -219,6 +231,9 @@ func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
}
func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
+ bind.closing.RLock()
+ defer bind.closing.RUnlock()
+
nend := end.(*NativeEndpoint)
if !nend.isV6 {
if bind.sock4 == -1 {
@@ -316,7 +331,7 @@ func create4(port uint16) (int, uint16, error) {
)
if err != nil {
- return FD_ERR, 0, err
+ return -1, 0, err
}
addr := unix.SockaddrInet4{
@@ -338,7 +353,7 @@ func create4(port uint16) (int, uint16, error) {
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
- return FD_ERR, 0, err
+ return -1, 0, err
}
sa, err := unix.Getsockname(fd)
@@ -360,7 +375,7 @@ func create6(port uint16) (int, uint16, error) {
)
if err != nil {
- return FD_ERR, 0, err
+ return -1, 0, err
}
// set sockopts and bind
@@ -392,7 +407,7 @@ func create6(port uint16) (int, uint16, error) {
}(); err != nil {
unix.Close(fd)
- return FD_ERR, 0, err
+ return -1, 0, err
}
sa, err := unix.Getsockname(fd)