aboutsummaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-18 23:34:02 +0100
committerMathias Hall-Andersen <mathias@hall-andersen.dk>2017-11-18 23:34:02 +0100
commitd10126f883ad39567248540347b5469956ab8b2e (patch)
treea83329093198bd5dd2c7770835a3851e6d23d880 /src
parentPorted remaining netns.sh (diff)
downloadwireguard-go-d10126f883ad39567248540347b5469956ab8b2e.tar.xz
wireguard-go-d10126f883ad39567248540347b5469956ab8b2e.zip
Moved endpoint into interface and simplified peer
Diffstat (limited to 'src')
-rw-r--r--src/conn.go20
-rw-r--r--src/conn_linux.go83
-rw-r--r--src/device.go6
-rw-r--r--src/peer.go19
-rw-r--r--src/receive.go29
-rw-r--r--src/uapi.go24
6 files changed, 101 insertions, 80 deletions
diff --git a/src/conn.go b/src/conn.go
index 3cf00ab..74bb075 100644
--- a/src/conn.go
+++ b/src/conn.go
@@ -7,26 +7,28 @@ import (
"net"
)
-type UDPBind interface {
+/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
+ */
+type Bind interface {
SetMark(value uint32) error
- ReceiveIPv6(buff []byte, end *Endpoint) (int, error)
- ReceiveIPv4(buff []byte, end *Endpoint) (int, error)
- Send(buff []byte, end *Endpoint) error
+ ReceiveIPv6(buff []byte) (int, Endpoint, error)
+ ReceiveIPv4(buff []byte) (int, Endpoint, error)
+ Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
- * dst : the remote address of a peer
+ * dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
- *
*/
-type UDPEndpoint interface {
+type Endpoint interface {
ClearSrc() // clears the source address
ClearDst() // clears the destination address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
+ SetDst(string) error // used for manually setting the endpoint (uapi)
DstIP() net.IP
SrcIP() net.IP
}
@@ -107,7 +109,9 @@ func UpdateUDPListener(device *Device) error {
for _, peer := range device.peers {
peer.mutex.Lock()
- peer.endpoint.value.ClearSrc()
+ if peer.endpoint != nil {
+ peer.endpoint.ClearSrc()
+ }
peer.mutex.Unlock()
}
diff --git a/src/conn_linux.go b/src/conn_linux.go
index fb576b1..46f873f 100644
--- a/src/conn_linux.go
+++ b/src/conn_linux.go
@@ -21,22 +21,24 @@ import (
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
-
-type Endpoint struct {
+type NativeEndpoint struct {
src unix.RawSockaddrInet6
dst unix.RawSockaddrInet6
}
-type IPv4Source struct {
- src unix.RawSockaddrInet4
- Ifindex int32
-}
-
type NativeBind struct {
sock4 int
sock6 int
}
+var _ Endpoint = (*NativeEndpoint)(nil)
+var _ Bind = NativeBind{}
+
+type IPv4Source struct {
+ src unix.RawSockaddrInet4
+ Ifindex int32
+}
+
func htons(val uint16) uint16 {
var out [unsafe.Sizeof(val)]byte
binary.BigEndian.PutUint16(out[:], val)
@@ -48,7 +50,11 @@ func ntohs(val uint16) uint16 {
return binary.BigEndian.Uint16((*tmp)[:])
}
-func CreateUDPBind(port uint16) (UDPBind, uint16, error) {
+func NewEndpoint() Endpoint {
+ return &NativeEndpoint{}
+}
+
+func CreateUDPBind(port uint16) (Bind, uint16, error) {
var err error
var bind NativeBind
@@ -99,28 +105,33 @@ func (bind NativeBind) Close() error {
return err2
}
-func (bind NativeBind) ReceiveIPv6(buff []byte, end *Endpoint) (int, error) {
- return receive6(
+func (bind NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ n, err := receive6(
bind.sock6,
buff,
- end,
+ &end,
)
+ return n, &end, err
}
-func (bind NativeBind) ReceiveIPv4(buff []byte, end *Endpoint) (int, error) {
- return receive4(
+func (bind NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
+ var end NativeEndpoint
+ n, err := receive4(
bind.sock4,
buff,
- end,
+ &end,
)
+ return n, &end, err
}
-func (bind NativeBind) Send(buff []byte, end *Endpoint) error {
- switch end.dst.Family {
+func (bind NativeBind) Send(buff []byte, end Endpoint) error {
+ nend := end.(*NativeEndpoint)
+ switch nend.dst.Family {
case unix.AF_INET6:
- return send6(bind.sock6, end, buff)
+ return send6(bind.sock6, nend, buff)
case unix.AF_INET:
- return send4(bind.sock4, end, buff)
+ return send4(bind.sock4, nend, buff)
default:
return errors.New("Unknown address family of destination")
}
@@ -151,12 +162,12 @@ func sockaddrToString(addr unix.RawSockaddrInet6) string {
}
}
-func (end *Endpoint) DstIP() net.IP {
- switch end.dst.Family {
+func rawAddrToIP(addr unix.RawSockaddrInet6) net.IP {
+ switch addr.Family {
case unix.AF_INET6:
- return end.dst.Addr[:]
+ return addr.Addr[:]
case unix.AF_INET:
- ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&end.dst))
+ ptr := (*unix.RawSockaddrInet4)(unsafe.Pointer(&addr))
return net.IPv4(
ptr.Addr[0],
ptr.Addr[1],
@@ -168,25 +179,33 @@ func (end *Endpoint) DstIP() net.IP {
}
}
-func (end *Endpoint) DstToBytes() []byte {
+func (end *NativeEndpoint) SrcIP() net.IP {
+ return rawAddrToIP(end.src)
+}
+
+func (end *NativeEndpoint) DstIP() net.IP {
+ return rawAddrToIP(end.dst)
+}
+
+func (end *NativeEndpoint) DstToBytes() []byte {
ptr := unsafe.Pointer(&end.src)
arr := (*[unix.SizeofSockaddrInet6]byte)(ptr)
return arr[:]
}
-func (end *Endpoint) SrcToString() string {
+func (end *NativeEndpoint) SrcToString() string {
return sockaddrToString(end.src)
}
-func (end *Endpoint) DstToString() string {
+func (end *NativeEndpoint) DstToString() string {
return sockaddrToString(end.dst)
}
-func (end *Endpoint) ClearDst() {
+func (end *NativeEndpoint) ClearDst() {
end.dst = unix.RawSockaddrInet6{}
}
-func (end *Endpoint) ClearSrc() {
+func (end *NativeEndpoint) ClearSrc() {
end.src = unix.RawSockaddrInet6{}
}
@@ -306,7 +325,7 @@ func create6(port uint16) (int, uint16, error) {
return fd, uint16(addr.Port), err
}
-func (end *Endpoint) SetDst(s string) error {
+func (end *NativeEndpoint) SetDst(s string) error {
addr, err := parseEndpoint(s)
if err != nil {
return err
@@ -342,7 +361,7 @@ func (end *Endpoint) SetDst(s string) error {
return errors.New("Failed to recognize IP address format")
}
-func send6(sock int, end *Endpoint, buff []byte) error {
+func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
@@ -404,7 +423,7 @@ func send6(sock int, end *Endpoint, buff []byte) error {
return errno
}
-func send4(sock int, end *Endpoint, buff []byte) error {
+func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
@@ -470,7 +489,7 @@ func send4(sock int, end *Endpoint, buff []byte) error {
return errno
}
-func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
+func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
@@ -518,7 +537,7 @@ func receive4(sock int, buff []byte, end *Endpoint) (int, error) {
return int(size), nil
}
-func receive6(sock int, buff []byte, end *Endpoint) (int, error) {
+func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
diff --git a/src/device.go b/src/device.go
index 0085cee..76235bd 100644
--- a/src/device.go
+++ b/src/device.go
@@ -22,9 +22,9 @@ type Device struct {
}
net struct {
mutex sync.RWMutex
- bind UDPBind // bind interface
- port uint16 // listening port
- fwmark uint32 // mark value (0 = disabled)
+ bind Bind // bind interface
+ port uint16 // listening port
+ fwmark uint32 // mark value (0 = disabled)
}
mutex sync.RWMutex
privateKey NoisePrivateKey
diff --git a/src/peer.go b/src/peer.go
index a98fc97..f3eb6c2 100644
--- a/src/peer.go
+++ b/src/peer.go
@@ -15,11 +15,8 @@ type Peer struct {
keyPairs KeyPairs
handshake Handshake
device *Device
- endpoint struct {
- set bool // has a known endpoint been discovered
- value Endpoint // source / destination cache
- }
- stats struct {
+ endpoint Endpoint
+ stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
@@ -110,9 +107,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
// reset endpoint
- peer.endpoint.set = false
- peer.endpoint.value.ClearDst()
- peer.endpoint.value.ClearSrc()
+ peer.endpoint = nil
// prepare queuing
@@ -143,16 +138,16 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
defer peer.device.net.mutex.RUnlock()
peer.mutex.RLock()
defer peer.mutex.RUnlock()
- if !peer.endpoint.set {
+ if peer.endpoint == nil {
return errors.New("No known endpoint for peer")
}
- return peer.device.net.bind.Send(buffer, &peer.endpoint.value)
+ return peer.device.net.bind.Send(buffer, peer.endpoint)
}
/* Returns a short string identification for logging
*/
func (peer *Peer) String() string {
- if !peer.endpoint.set {
+ if peer.endpoint == nil {
return fmt.Sprintf(
"peer(%d unknown %s)",
peer.id,
@@ -162,7 +157,7 @@ func (peer *Peer) String() string {
return fmt.Sprintf(
"peer(%d %s %s)",
peer.id,
- peer.endpoint.value.DstToString(),
+ peer.endpoint.DstToString(),
base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]),
)
}
diff --git a/src/receive.go b/src/receive.go
index b8b06f7..27fdb8a 100644
--- a/src/receive.go
+++ b/src/receive.go
@@ -93,7 +93,7 @@ func (device *Device) addToHandshakeQueue(
}
}
-func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) {
+func (device *Device) RoutineReceiveIncomming(IP int, bind Bind) {
logDebug := device.log.Debug
logDebug.Println("Routine, receive incomming, IP version:", IP)
@@ -104,20 +104,21 @@ func (device *Device) RoutineReceiveIncomming(IP int, bind UDPBind) {
buffer := device.GetMessageBuffer()
- var size int
- var err error
+ var (
+ err error
+ size int
+ endpoint Endpoint
+ )
for {
// read next datagram
- var endpoint Endpoint
-
switch IP {
case ipv4.Version:
- size, err = bind.ReceiveIPv4(buffer[:], &endpoint)
+ size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
- size, err = bind.ReceiveIPv6(buffer[:], &endpoint)
+ size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
return
}
@@ -339,10 +340,7 @@ func (device *Device) RoutineHandshake() {
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, reply)
- device.net.bind.Send(
- writer.Bytes(),
- &elem.endpoint,
- )
+ device.net.bind.Send(writer.Bytes(), elem.endpoint)
if err != nil {
logDebug.Println("Failed to send cookie reply:", err)
}
@@ -395,8 +393,7 @@ func (device *Device) RoutineHandshake() {
// update endpoint
peer.mutex.Lock()
- peer.endpoint.set = true
- peer.endpoint.value = elem.endpoint
+ peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// create response
@@ -452,8 +449,7 @@ func (device *Device) RoutineHandshake() {
// update endpoint
peer.mutex.Lock()
- peer.endpoint.set = true
- peer.endpoint.value = elem.endpoint
+ peer.endpoint = elem.endpoint
peer.mutex.Unlock()
logDebug.Println("Received handshake initation from", peer)
@@ -527,8 +523,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// update endpoint
peer.mutex.Lock()
- peer.endpoint.set = true
- peer.endpoint.value = elem.endpoint
+ peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// check for keep-alive
diff --git a/src/uapi.go b/src/uapi.go
index e1d0929..670ecc4 100644
--- a/src/uapi.go
+++ b/src/uapi.go
@@ -53,8 +53,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
defer peer.mutex.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
- if peer.endpoint.set {
- send("endpoint=" + peer.endpoint.value.DstToString())
+ if peer.endpoint != nil {
+ send("endpoint=" + peer.endpoint.DstToString())
}
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
@@ -255,17 +255,25 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "endpoint":
- // set endpoint destination and reset handshake timer
+ // set endpoint destination
+
+ err := func() error {
+ peer.mutex.Lock()
+ defer peer.mutex.Unlock()
+
+ endpoint := NewEndpoint()
+ if err := endpoint.SetDst(value); err != nil {
+ return err
+ }
+ peer.endpoint = endpoint
+ signalSend(peer.signal.handshakeReset)
+ return nil
+ }()
- peer.mutex.Lock()
- err := peer.endpoint.value.SetDst(value)
- peer.endpoint.set = (err == nil)
- peer.mutex.Unlock()
if err != nil {
logError.Println("Failed to set endpoint:", value)
return &IPCError{Code: ipcErrorInvalid}
}
- signalSend(peer.signal.handshakeReset)
case "persistent_keepalive_interval":