diff options
155 files changed, 11128 insertions, 10522 deletions
@@ -1,4 +1 @@ wireguard-go -vendor -.gopath -ireallywantobuildon_linux.go @@ -10,10 +10,10 @@ MAKEFLAGS += --no-print-directory generate-version-and-build: @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \ - ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \ - [ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \ - echo "$$ver" > device/version.go && \ - git update-index --assume-unchanged device/version.go || true + ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \ + [ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \ + echo "$$ver" > version.go && \ + git update-index --assume-unchanged version.go || true @$(MAKE) wireguard-go wireguard-go: $(wildcard *.go) $(wildcard */*.go) @@ -22,7 +22,10 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go) install: wireguard-go @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go" +test: + go test ./... + clean: rm -f wireguard-go -.PHONY: all clean install generate-version-and-build +.PHONY: all clean test install generate-version-and-build @@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`. ### Linux -This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions. +This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions. ### macOS @@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp ## Building -This requires an installation of [go](https://golang.org) ā„ 1.12. +This requires an installation of the latest version of [Go](https://go.dev/). ``` $ git clone https://git.zx2c4.com/wireguard-go @@ -56,7 +56,7 @@ $ make ## License - Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/conn/bind_std.go b/conn/bind_std.go new file mode 100644 index 0000000..46df7fd --- /dev/null +++ b/conn/bind_std.go @@ -0,0 +1,544 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "runtime" + "strconv" + "sync" + "syscall" + + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +var ( + _ Bind = (*StdNetBind)(nil) +) + +// StdNetBind implements Bind for all platforms. While Windows has its own Bind +// (see bind_windows.go), it may fall back to StdNetBind. +// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable +// methods for sending and receiving multiple datagrams per-syscall. See the +// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564. +type StdNetBind struct { + mu sync.Mutex // protects all fields except as specified + ipv4 *net.UDPConn + ipv6 *net.UDPConn + ipv4PC *ipv4.PacketConn // will be nil on non-Linux + ipv6PC *ipv6.PacketConn // will be nil on non-Linux + ipv4TxOffload bool + ipv4RxOffload bool + ipv6TxOffload bool + ipv6RxOffload bool + + // these two fields are not guarded by mu + udpAddrPool sync.Pool + msgsPool sync.Pool + + blackhole4 bool + blackhole6 bool +} + +func NewStdNetBind() Bind { + return &StdNetBind{ + udpAddrPool: sync.Pool{ + New: func() any { + return &net.UDPAddr{ + IP: make([]byte, 16), + } + }, + }, + + msgsPool: sync.Pool{ + New: func() any { + // ipv6.Message and ipv4.Message are interchangeable as they are + // both aliases for x/net/internal/socket.Message. + msgs := make([]ipv6.Message, IdealBatchSize) + for i := range msgs { + msgs[i].Buffers = make(net.Buffers, 1) + msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize) + } + return &msgs + }, + }, + } +} + +type StdNetEndpoint struct { + // AddrPort is the endpoint destination. + netip.AddrPort + // src is the current sticky source address and interface index, if + // supported. Typically this is a PKTINFO structure from/for control + // messages, see unix.PKTINFO for an example. + src []byte +} + +var ( + _ Bind = (*StdNetBind)(nil) + _ Endpoint = &StdNetEndpoint{} +) + +func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) { + e, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return &StdNetEndpoint{ + AddrPort: e, + }, nil +} + +func (e *StdNetEndpoint) ClearSrc() { + if e.src != nil { + // Truncate src, no need to reallocate. + e.src = e.src[:0] + } +} + +func (e *StdNetEndpoint) DstIP() netip.Addr { + return e.AddrPort.Addr() +} + +// See control_default,linux, etc for implementations of SrcIP and SrcIfidx. + +func (e *StdNetEndpoint) DstToBytes() []byte { + b, _ := e.AddrPort.MarshalBinary() + return b +} + +func (e *StdNetEndpoint) DstToString() string { + return e.AddrPort.String() +} + +func listenNet(network string, port int) (*net.UDPConn, int, error) { + conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port)) + if err != nil { + return nil, 0, err + } + + // Retrieve port. + laddr := conn.LocalAddr() + uaddr, err := net.ResolveUDPAddr( + laddr.Network(), + laddr.String(), + ) + if err != nil { + return nil, 0, err + } + return conn.(*net.UDPConn), uaddr.Port, nil +} + +func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + var err error + var tries int + + if s.ipv4 != nil || s.ipv6 != nil { + return nil, 0, ErrBindAlreadyOpen + } + + // Attempt to open ipv4 and ipv6 listeners on the same port. + // If uport is 0, we can retry on failure. +again: + port := int(uport) + var v4conn, v6conn *net.UDPConn + var v4pc *ipv4.PacketConn + var v6pc *ipv6.PacketConn + + v4conn, port, err = listenNet("udp4", port) + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + return nil, 0, err + } + + // Listen on the same port as we're using for ipv4. + v6conn, port, err = listenNet("udp6", port) + if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { + v4conn.Close() + tries++ + goto again + } + if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { + v4conn.Close() + return nil, 0, err + } + var fns []ReceiveFunc + if v4conn != nil { + s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v4pc = ipv4.NewPacketConn(v4conn) + s.ipv4PC = v4pc + } + fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload)) + s.ipv4 = v4conn + } + if v6conn != nil { + s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + v6pc = ipv6.NewPacketConn(v6conn) + s.ipv6PC = v6pc + } + fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload)) + s.ipv6 = v6conn + } + if len(fns) == 0 { + return nil, 0, syscall.EAFNOSUPPORT + } + + return fns, uint16(port), nil +} + +func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) { + for i := range *msgs { + (*msgs)[i].OOB = (*msgs)[i].OOB[:0] + (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB} + } + s.msgsPool.Put(msgs) +} + +func (s *StdNetBind) getMessages() *[]ipv6.Message { + return s.msgsPool.Get().(*[]ipv6.Message) +} + +var ( + // If compilation fails here these are no longer the same underlying type. + _ ipv6.Message = ipv4.Message{} +) + +type batchReader interface { + ReadBatch([]ipv6.Message, int) (int, error) +} + +type batchWriter interface { + WriteBatch([]ipv6.Message, int) (int, error) +} + +func (s *StdNetBind) receiveIP( + br batchReader, + conn *net.UDPConn, + rxOffload bool, + bufs [][]byte, + sizes []int, + eps []Endpoint, +) (n int, err error) { + msgs := s.getMessages() + for i := range bufs { + (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] + } + defer s.putMessages(msgs) + var numMsgs int + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + if rxOffload { + readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams) + numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0) + if err != nil { + return 0, err + } + numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize) + if err != nil { + return 0, err + } + } else { + numMsgs, err = br.ReadBatch(*msgs, 0) + if err != nil { + return 0, err + } + } + } else { + msg := &(*msgs)[0] + msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB) + if err != nil { + return 0, err + } + numMsgs = 1 + } + for i := 0; i < numMsgs; i++ { + msg := &(*msgs)[i] + sizes[i] = msg.N + if sizes[i] == 0 { + continue + } + addrPort := msg.Addr.(*net.UDPAddr).AddrPort() + ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation + getSrcFromControl(msg.OOB[:msg.NN], ep) + eps[i] = ep + } + return numMsgs, nil +} + +func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + } +} + +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (s *StdNetBind) BatchSize() int { + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + return IdealBatchSize + } + return 1 +} + +func (s *StdNetBind) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + var err1, err2 error + if s.ipv4 != nil { + err1 = s.ipv4.Close() + s.ipv4 = nil + s.ipv4PC = nil + } + if s.ipv6 != nil { + err2 = s.ipv6.Close() + s.ipv6 = nil + s.ipv6PC = nil + } + s.blackhole4 = false + s.blackhole6 = false + s.ipv4TxOffload = false + s.ipv4RxOffload = false + s.ipv6TxOffload = false + s.ipv6RxOffload = false + if err1 != nil { + return err1 + } + return err2 +} + +type ErrUDPGSODisabled struct { + onLaddr string + RetryErr error +} + +func (e ErrUDPGSODisabled) Error() string { + return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr) +} + +func (e ErrUDPGSODisabled) Unwrap() error { + return e.RetryErr +} + +func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error { + s.mu.Lock() + blackhole := s.blackhole4 + conn := s.ipv4 + offload := s.ipv4TxOffload + br := batchWriter(s.ipv4PC) + is6 := false + if endpoint.DstIP().Is6() { + blackhole = s.blackhole6 + conn = s.ipv6 + br = s.ipv6PC + is6 = true + offload = s.ipv6TxOffload + } + s.mu.Unlock() + + if blackhole { + return nil + } + if conn == nil { + return syscall.EAFNOSUPPORT + } + + msgs := s.getMessages() + defer s.putMessages(msgs) + ua := s.udpAddrPool.Get().(*net.UDPAddr) + defer s.udpAddrPool.Put(ua) + if is6 { + as16 := endpoint.DstIP().As16() + copy(ua.IP, as16[:]) + ua.IP = ua.IP[:16] + } else { + as4 := endpoint.DstIP().As4() + copy(ua.IP, as4[:]) + ua.IP = ua.IP[:4] + } + ua.Port = int(endpoint.(*StdNetEndpoint).Port()) + var ( + retried bool + err error + ) +retry: + if offload { + n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize) + err = s.send(conn, br, (*msgs)[:n]) + if err != nil && offload && errShouldDisableUDPGSO(err) { + offload = false + s.mu.Lock() + if is6 { + s.ipv6TxOffload = false + } else { + s.ipv4TxOffload = false + } + s.mu.Unlock() + retried = true + goto retry + } + } else { + for i := range bufs { + (*msgs)[i].Addr = ua + (*msgs)[i].Buffers[0] = bufs[i] + setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint)) + } + err = s.send(conn, br, (*msgs)[:len(bufs)]) + } + if retried { + return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err} + } + return err +} + +func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error { + var ( + n int + err error + start int + ) + if runtime.GOOS == "linux" || runtime.GOOS == "android" { + for { + n, err = pc.WriteBatch(msgs[start:], 0) + if err != nil || n == len(msgs[start:]) { + break + } + start += n + } + } else { + for _, msg := range msgs { + _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr)) + if err != nil { + break + } + } + } + return err +} + +const ( + // Exceeding these values results in EMSGSIZE. They account for layer3 and + // layer4 headers. IPv6 does not need to account for itself as the payload + // length field is self excluding. + maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 + maxIPv6PayloadLen = 1<<16 - 1 - 8 + + // This is a hard limit imposed by the kernel. + udpSegmentMaxDatagrams = 64 +) + +type setGSOFunc func(control *[]byte, gsoSize uint16) + +func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int { + var ( + base = -1 // index of msg we are currently coalescing into + gsoSize int // segmentation size of msgs[base] + dgramCnt int // number of dgrams coalesced into msgs[base] + endBatch bool // tracking flag to start a new batch on next iteration of bufs + ) + maxPayloadLen := maxIPv4PayloadLen + if ep.DstIP().Is6() { + maxPayloadLen = maxIPv6PayloadLen + } + for i, buf := range bufs { + if i > 0 { + msgLen := len(buf) + baseLenBefore := len(msgs[base].Buffers[0]) + freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore + if msgLen+baseLenBefore <= maxPayloadLen && + msgLen <= gsoSize && + msgLen <= freeBaseCap && + dgramCnt < udpSegmentMaxDatagrams && + !endBatch { + msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...) + if i == len(bufs)-1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + dgramCnt++ + if msgLen < gsoSize { + // A smaller than gsoSize packet on the tail is legal, but + // it must end the batch. + endBatch = true + } + continue + } + } + if dgramCnt > 1 { + setGSO(&msgs[base].OOB, uint16(gsoSize)) + } + // Reset prior to incrementing base since we are preparing to start a + // new potential batch. + endBatch = false + base++ + gsoSize = len(buf) + setSrcControl(&msgs[base].OOB, ep) + msgs[base].Buffers[0] = buf + msgs[base].Addr = addr + dgramCnt = 1 + } + return base + 1 +} + +type getGSOFunc func(control []byte) (int, error) + +func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) { + for i := firstMsgAt; i < len(msgs); i++ { + msg := &msgs[i] + if msg.N == 0 { + return n, err + } + var ( + gsoSize int + start int + end = msg.N + numToSplit = 1 + ) + gsoSize, err = getGSO(msg.OOB[:msg.NN]) + if err != nil { + return n, err + } + if gsoSize > 0 { + numToSplit = (msg.N + gsoSize - 1) / gsoSize + end = gsoSize + } + for j := 0; j < numToSplit; j++ { + if n > i { + return n, errors.New("splitting coalesced packet resulted in overflow") + } + copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end]) + msgs[n].N = copied + msgs[n].Addr = msg.Addr + start = end + end += gsoSize + if end > msg.N { + end = msg.N + } + n++ + } + if i != n-1 { + // It is legal for bytes to move within msg.Buffers[0] as a result + // of splitting, so we only zero the source msg len when it is not + // the destination of the last split operation above. + msg.N = 0 + } + } + return n, nil +} diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go new file mode 100644 index 0000000..34a3c9a --- /dev/null +++ b/conn/bind_std_test.go @@ -0,0 +1,250 @@ +package conn + +import ( + "encoding/binary" + "net" + "testing" + + "golang.org/x/net/ipv6" +) + +func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { + bind := NewStdNetBind().(*StdNetBind) + fns, _, err := bind.Open(0) + if err != nil { + t.Fatal(err) + } + bind.Close() + bufs := make([][]byte, 1) + bufs[0] = make([]byte, 1) + sizes := make([]int, 1) + eps := make([]Endpoint, 1) + for _, fn := range fns { + // The ReceiveFuncs must not access conn-related fields on StdNetBind + // unguarded. Close() nils the conn-related fields resulting in a panic + // if they violate the mutex. + fn(bufs, sizes, eps) + } +} + +func mockSetGSOSize(control *[]byte, gsoSize uint16) { + *control = (*control)[:cap(*control)] + binary.LittleEndian.PutUint16(*control, gsoSize) +} + +func Test_coalesceMessages(t *testing.T) { + cases := []struct { + name string + buffs [][]byte + wantLens []int + wantGSO []int + }{ + { + name: "one message no coalesce", + buffs: [][]byte{ + make([]byte, 1, 1), + }, + wantLens: []int{1}, + wantGSO: []int{0}, + }, + { + name: "two messages equal len coalesce", + buffs: [][]byte{ + make([]byte, 1, 2), + make([]byte, 1, 1), + }, + wantLens: []int{2}, + wantGSO: []int{1}, + }, + { + name: "two messages unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + }, + wantLens: []int{3}, + wantGSO: []int{2}, + }, + { + name: "three messages second unequal len coalesce", + buffs: [][]byte{ + make([]byte, 2, 3), + make([]byte, 1, 1), + make([]byte, 2, 2), + }, + wantLens: []int{3, 2}, + wantGSO: []int{2, 0}, + }, + { + name: "three messages limited cap coalesce", + buffs: [][]byte{ + make([]byte, 2, 4), + make([]byte, 2, 2), + make([]byte, 2, 2), + }, + wantLens: []int{4, 2}, + wantGSO: []int{2, 0}, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1").To4(), + Port: 1, + } + msgs := make([]ipv6.Message, len(tt.buffs)) + for i := range msgs { + msgs[i].Buffers = make([][]byte, 1) + msgs[i].OOB = make([]byte, 0, 2) + } + got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize) + if got != len(tt.wantLens) { + t.Fatalf("got len %d want: %d", got, len(tt.wantLens)) + } + for i := 0; i < got; i++ { + if msgs[i].Addr != addr { + t.Errorf("msgs[%d].Addr != passed addr", i) + } + gotLen := len(msgs[i].Buffers[0]) + if gotLen != tt.wantLens[i] { + t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i]) + } + gotGSO, err := mockGetGSOSize(msgs[i].OOB) + if err != nil { + t.Fatalf("msgs[%d] getGSOSize err: %v", i, err) + } + if gotGSO != tt.wantGSO[i] { + t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i]) + } + } + }) + } +} + +func mockGetGSOSize(control []byte) (int, error) { + if len(control) < 2 { + return 0, nil + } + return int(binary.LittleEndian.Uint16(control)), nil +} + +func Test_splitCoalescedMessages(t *testing.T) { + newMsg := func(n, gso int) ipv6.Message { + msg := ipv6.Message{ + Buffers: [][]byte{make([]byte, 1<<16-1)}, + N: n, + OOB: make([]byte, 2), + } + binary.LittleEndian.PutUint16(msg.OOB, uint16(gso)) + if gso > 0 { + msg.NN = 2 + } + return msg + } + + cases := []struct { + name string + msgs []ipv6.Message + firstMsgAt int + wantNumEval int + wantMsgLens []int + wantErr bool + }{ + { + name: "second last split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(3, 1), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 3, + wantMsgLens: []int{1, 1, 1, 0}, + wantErr: false, + }, + { + name: "second last no split last empty", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(0, 0), + }, + firstMsgAt: 2, + wantNumEval: 1, + wantMsgLens: []int{1, 0, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last no split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(1, 0), + }, + firstMsgAt: 2, + wantNumEval: 2, + wantMsgLens: []int{1, 1, 0, 0}, + wantErr: false, + }, + { + name: "second last no split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(3, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last split last split", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(2, 1), + newMsg(2, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: false, + }, + { + name: "second last no split last split overflow", + msgs: []ipv6.Message{ + newMsg(0, 0), + newMsg(0, 0), + newMsg(1, 0), + newMsg(4, 1), + }, + firstMsgAt: 2, + wantNumEval: 4, + wantMsgLens: []int{1, 1, 1, 1}, + wantErr: true, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize) + if err != nil && !tt.wantErr { + t.Fatalf("err: %v", err) + } + if got != tt.wantNumEval { + t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval) + } + for i, msg := range tt.msgs { + if msg.N != tt.wantMsgLens[i] { + t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i]) + } + } + }) + } +} diff --git a/conn/bind_windows.go b/conn/bind_windows.go new file mode 100644 index 0000000..d5095e0 --- /dev/null +++ b/conn/bind_windows.go @@ -0,0 +1,601 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "encoding/binary" + "io" + "net" + "net/netip" + "strconv" + "sync" + "sync/atomic" + "unsafe" + + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/conn/winrio" +) + +const ( + packetsPerRing = 1024 + bytesPerPacket = 2048 - 32 + receiveSpins = 15 +) + +type ringPacket struct { + addr WinRingEndpoint + data [bytesPerPacket]byte +} + +type ringBuffer struct { + packets uintptr + head, tail uint32 + id winrio.BufferId + iocp windows.Handle + isFull bool + cq winrio.Cq + mu sync.Mutex + overlapped windows.Overlapped +} + +func (rb *ringBuffer) Push() *ringPacket { + for rb.isFull { + panic("ring is full") + } + ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) + rb.tail += 1 + if rb.tail%packetsPerRing == rb.head%packetsPerRing { + rb.isFull = true + } + return ret +} + +func (rb *ringBuffer) Return(count uint32) { + if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull { + return + } + rb.head += count + rb.isFull = false +} + +type afWinRingBind struct { + sock windows.Handle + rx, tx ringBuffer + rq winrio.Rq + mu sync.Mutex + blackhole bool +} + +// WinRingBind uses Windows registered I/O for fast ring buffered networking. +type WinRingBind struct { + v4, v6 afWinRingBind + mu sync.RWMutex + isOpen atomic.Uint32 // 0, 1, or 2 +} + +func NewDefaultBind() Bind { return NewWinRingBind() } + +func NewWinRingBind() Bind { + if !winrio.Initialize() { + return NewStdNetBind() + } + return new(WinRingBind) +} + +type WinRingEndpoint struct { + family uint16 + data [30]byte +} + +var ( + _ Bind = (*WinRingBind)(nil) + _ Endpoint = (*WinRingEndpoint)(nil) +) + +func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { + host, port, err := net.SplitHostPort(s) + if err != nil { + return nil, err + } + host16, err := windows.UTF16PtrFromString(host) + if err != nil { + return nil, err + } + port16, err := windows.UTF16PtrFromString(port) + if err != nil { + return nil, err + } + hints := windows.AddrinfoW{ + Flags: windows.AI_NUMERICHOST, + Family: windows.AF_UNSPEC, + Socktype: windows.SOCK_DGRAM, + Protocol: windows.IPPROTO_UDP, + } + var addrinfo *windows.AddrinfoW + err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo) + if err != nil { + return nil, err + } + defer windows.FreeAddrInfoW(addrinfo) + if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { + return nil, windows.ERROR_INVALID_ADDRESS + } + var dst [unsafe.Sizeof(WinRingEndpoint{})]byte + copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) + return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil +} + +func (*WinRingEndpoint) ClearSrc() {} + +func (e *WinRingEndpoint) DstIP() netip.Addr { + switch e.family { + case windows.AF_INET: + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) + case windows.AF_INET6: + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) + } + return netip.Addr{} +} + +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported +} + +func (e *WinRingEndpoint) DstToBytes() []byte { + switch e.family { + case windows.AF_INET: + b := make([]byte, 0, 6) + b = append(b, e.data[2:6]...) + b = append(b, e.data[1], e.data[0]) + return b + case windows.AF_INET6: + b := make([]byte, 0, 18) + b = append(b, e.data[6:22]...) + b = append(b, e.data[1], e.data[0]) + return b + } + return nil +} + +func (e *WinRingEndpoint) DstToString() string { + switch e.family { + case windows.AF_INET: + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() + case windows.AF_INET6: + var zone string + if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { + zone = strconv.FormatUint(uint64(scope), 10) + } + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() + } + return "" +} + +func (e *WinRingEndpoint) SrcToString() string { + return "" +} + +func (ring *ringBuffer) CloseAndZero() { + if ring.cq != 0 { + winrio.CloseCompletionQueue(ring.cq) + ring.cq = 0 + } + if ring.iocp != 0 { + windows.CloseHandle(ring.iocp) + ring.iocp = 0 + } + if ring.id != 0 { + winrio.DeregisterBuffer(ring.id) + ring.id = 0 + } + if ring.packets != 0 { + windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) + ring.packets = 0 + } + ring.head = 0 + ring.tail = 0 + ring.isFull = false +} + +func (bind *afWinRingBind) CloseAndZero() { + bind.rx.CloseAndZero() + bind.tx.CloseAndZero() + if bind.sock != 0 { + windows.CloseHandle(bind.sock) + bind.sock = 0 + } + bind.blackhole = false +} + +func (bind *WinRingBind) closeAndZero() { + bind.isOpen.Store(0) + bind.v4.CloseAndZero() + bind.v6.CloseAndZero() +} + +func (ring *ringBuffer) Open() error { + var err error + packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing + ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) + if err != nil { + return err + } + ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen)) + if err != nil { + return err + } + ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return err + } + ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped) + if err != nil { + return err + } + return nil +} + +func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) { + var err error + bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return nil, err + } + err = bind.rx.Open() + if err != nil { + return nil, err + } + err = bind.tx.Open() + if err != nil { + return nil, err + } + bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0) + if err != nil { + return nil, err + } + err = windows.Bind(bind.sock, sa) + if err != nil { + return nil, err + } + sa, err = windows.Getsockname(bind.sock) + if err != nil { + return nil, err + } + return sa, nil +} + +func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) { + bind.mu.Lock() + defer bind.mu.Unlock() + defer func() { + if err != nil { + bind.closeAndZero() + } + }() + if bind.isOpen.Load() != 0 { + return nil, 0, ErrBindAlreadyOpen + } + var sa windows.Sockaddr + sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) + if err != nil { + return nil, 0, err + } + sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) + if err != nil { + return nil, 0, err + } + selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) + for i := 0; i < packetsPerRing; i++ { + err = bind.v4.InsertReceiveRequest() + if err != nil { + return nil, 0, err + } + err = bind.v6.InsertReceiveRequest() + if err != nil { + return nil, 0, err + } + } + bind.isOpen.Store(1) + return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err +} + +func (bind *WinRingBind) Close() error { + bind.mu.RLock() + if bind.isOpen.Load() != 1 { + bind.mu.RUnlock() + return nil + } + bind.isOpen.Store(2) + windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) + windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil) + bind.mu.RUnlock() + bind.mu.Lock() + defer bind.mu.Unlock() + bind.closeAndZero() + return nil +} + +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + +func (bind *WinRingBind) SetMark(mark uint32) error { + return nil +} + +func (bind *afWinRingBind) InsertReceiveRequest() error { + packet := bind.rx.Push() + dataBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets), + Length: uint32(len(packet.data)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.rx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet))) +} + +//go:linkname procyield runtime.procyield +func procyield(cycles uint32) + +func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + bind.rx.mu.Lock() + defer bind.rx.mu.Unlock() + + var err error + var count uint32 + var results [1]winrio.Result +retry: + count = 0 + for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if tries > 0 { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + procyield(1) + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + } + if count == 0 { + err = winrio.Notify(bind.rx.cq) + if err != nil { + return 0, nil, err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return 0, nil, err + } + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + count = winrio.DequeueCompletion(bind.rx.cq, results[:]) + if count == 0 { + return 0, nil, io.ErrNoProgress + } + } + bind.rx.Return(1) + err = bind.InsertReceiveRequest() + if err != nil { + return 0, nil, err + } + // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us + // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to + // attacker bandwidth, just like the rest of the receive path. + if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { + if isOpen.Load() != 1 { + return 0, nil, net.ErrClosed + } + goto retry + } + if results[0].Status != 0 { + return 0, nil, windows.Errno(results[0].Status) + } + packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext))) + ep := packet.addr + n := copy(buf, packet.data[:results[0].BytesTransferred]) + return n, &ep, nil +} + +func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err +} + +func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { + bind.mu.RLock() + defer bind.mu.RUnlock() + n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err +} + +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { + if isOpen.Load() != 1 { + return net.ErrClosed + } + if len(buf) > bytesPerPacket { + return io.ErrShortBuffer + } + bind.tx.mu.Lock() + defer bind.tx.mu.Unlock() + var results [packetsPerRing]winrio.Result + count := winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 && bind.tx.isFull { + err := winrio.Notify(bind.tx.cq) + if err != nil { + return err + } + var bytes uint32 + var key uintptr + var overlapped *windows.Overlapped + err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE) + if err != nil { + return err + } + if isOpen.Load() != 1 { + return net.ErrClosed + } + count = winrio.DequeueCompletion(bind.tx.cq, results[:]) + if count == 0 { + return io.ErrNoProgress + } + } + if count > 0 { + bind.tx.Return(count) + } + packet := bind.tx.Push() + packet.addr = *nend + copy(packet.data[:], buf) + dataBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets), + Length: uint32(len(buf)), + } + addressBuffer := &winrio.Buffer{ + Id: bind.tx.id, + Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets), + Length: uint32(unsafe.Sizeof(packet.addr)), + } + bind.mu.Lock() + defer bind.mu.Unlock() + return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) +} + +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { + nend, ok := endpoint.(*WinRingEndpoint) + if !ok { + return ErrWrongEndpointType + } + bind.mu.RLock() + defer bind.mu.RUnlock() + for _, buf := range bufs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + } + } + return nil +} + +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + s.blackhole4 = blackhole + return nil +} + +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() + if err != nil { + return err + } + err2 := sysconn.Control(func(fd uintptr) { + err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex) + }) + if err2 != nil { + return err2 + } + if err != nil { + return err + } + s.blackhole6 = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if bind.isOpen.Load() != 1 { + return net.ErrClosed + } + err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) + if err != nil { + return err + } + bind.v4.blackhole = blackhole + return nil +} + +func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + bind.mu.RLock() + defer bind.mu.RUnlock() + if bind.isOpen.Load() != 1 { + return net.ErrClosed + } + err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) + if err != nil { + return err + } + bind.v6.blackhole = blackhole + return nil +} + +func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error { + const IP_UNICAST_IF = 31 + /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ + var bytes [4]byte + binary.BigEndian.PutUint32(bytes[:], interfaceIndex) + interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) + err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex)) + if err != nil { + return err + } + return nil +} + +func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error { + const IPV6_UNICAST_IF = 31 + return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) +} diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go new file mode 100644 index 0000000..74e7add --- /dev/null +++ b/conn/bindtest/bindtest.go @@ -0,0 +1,136 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package bindtest + +import ( + "fmt" + "math/rand" + "net" + "net/netip" + "os" + + "golang.zx2c4.com/wireguard/conn" +) + +type ChannelBind struct { + rx4, tx4 *chan []byte + rx6, tx6 *chan []byte + closeSignal chan bool + source4, source6 ChannelEndpoint + target4, target6 ChannelEndpoint +} + +type ChannelEndpoint uint16 + +var ( + _ conn.Bind = (*ChannelBind)(nil) + _ conn.Endpoint = (*ChannelEndpoint)(nil) +) + +func NewChannelBinds() [2]conn.Bind { + arx4 := make(chan []byte, 8192) + brx4 := make(chan []byte, 8192) + arx6 := make(chan []byte, 8192) + brx6 := make(chan []byte, 8192) + var binds [2]ChannelBind + binds[0].rx4 = &arx4 + binds[0].tx4 = &brx4 + binds[1].rx4 = &brx4 + binds[1].tx4 = &arx4 + binds[0].rx6 = &arx6 + binds[0].tx6 = &brx6 + binds[1].rx6 = &brx6 + binds[1].tx6 = &arx6 + binds[0].target4 = ChannelEndpoint(1) + binds[1].target4 = ChannelEndpoint(2) + binds[0].target6 = ChannelEndpoint(3) + binds[1].target6 = ChannelEndpoint(4) + binds[0].source4 = binds[1].target4 + binds[0].source6 = binds[1].target6 + binds[1].source4 = binds[0].target4 + binds[1].source6 = binds[0].target6 + return [2]conn.Bind{&binds[0], &binds[1]} +} + +func (c ChannelEndpoint) ClearSrc() {} + +func (c ChannelEndpoint) SrcToString() string { return "" } + +func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) } + +func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} } + +func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } + +func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} } + +func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + c.closeSignal = make(chan bool) + fns = append(fns, c.makeReceiveFunc(*c.rx4)) + fns = append(fns, c.makeReceiveFunc(*c.rx6)) + if rand.Uint32()&1 == 0 { + return fns, uint16(c.source4), nil + } else { + return fns, uint16(c.source6), nil + } +} + +func (c *ChannelBind) Close() error { + if c.closeSignal != nil { + select { + case <-c.closeSignal: + default: + close(c.closeSignal) + } + } + return nil +} + +func (c *ChannelBind) BatchSize() int { return 1 } + +func (c *ChannelBind) SetMark(mark uint32) error { return nil } + +func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + select { + case <-c.closeSignal: + return 0, net.ErrClosed + case rx := <-ch: + copied := copy(bufs[0], rx) + sizes[0] = copied + eps[0] = c.target6 + return 1, nil + } + } +} + +func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error { + for _, b := range bufs { + select { + case <-c.closeSignal: + return net.ErrClosed + default: + bc := make([]byte, len(b)) + copy(bc, b) + if ep.(ChannelEndpoint) == c.target4 { + *c.tx4 <- bc + } else if ep.(ChannelEndpoint) == c.target6 { + *c.tx6 <- bc + } else { + return os.ErrInvalid + } + } + } + return nil +} + +func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) { + addr, err := netip.ParseAddrPort(s) + if err != nil { + return nil, err + } + return ChannelEndpoint(addr.Port()), nil +} diff --git a/conn/boundif_android.go b/conn/boundif_android.go new file mode 100644 index 0000000..dd3ca5b --- /dev/null +++ b/conn/boundif_android.go @@ -0,0 +1,34 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) { + sysconn, err := s.ipv4.SyscallConn() + if err != nil { + return -1, err + } + err = sysconn.Control(func(f uintptr) { + fd = int(f) + }) + if err != nil { + return -1, err + } + return +} + +func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) { + sysconn, err := s.ipv6.SyscallConn() + if err != nil { + return -1, err + } + err = sysconn.Control(func(f uintptr) { + fd = int(f) + }) + if err != nil { + return -1, err + } + return +} diff --git a/conn/boundif_windows.go b/conn/boundif_windows.go deleted file mode 100644 index fe38d05..0000000 --- a/conn/boundif_windows.go +++ /dev/null @@ -1,59 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "encoding/binary" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - sockoptIP_UNICAST_IF = 31 - sockoptIPV6_UNICAST_IF = 31 -) - -func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ - bytes := make([]byte, 4) - binary.BigEndian.PutUint32(bytes, interfaceIndex) - interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) - - sysconn, err := bind.ipv4.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IP, sockoptIP_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - bind.blackhole4 = blackhole - return nil -} - -func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - sysconn, err := bind.ipv6.SyscallConn() - if err != nil { - return err - } - err2 := sysconn.Control(func(fd uintptr) { - err = windows.SetsockoptInt(windows.Handle(fd), windows.IPPROTO_IPV6, sockoptIPV6_UNICAST_IF, int(interfaceIndex)) - }) - if err2 != nil { - return err2 - } - if err != nil { - return err - } - bind.blackhole6 = blackhole - return nil -} diff --git a/conn/conn.go b/conn/conn.go index 6b7db12..a1f57d2 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ // Package conn implements WireGuard's network connections. @@ -8,94 +8,126 @@ package conn import ( "errors" - "net" + "fmt" + "net/netip" + "reflect" + "runtime" "strings" ) +const ( + IdealBatchSize = 128 // maximum number of packets handled per read and write +) + +// A ReceiveFunc receives at least one packet from the network and writes them +// into packets. On a successful read it returns the number of elements of +// sizes, packets, and endpoints that should be evaluated. Some elements of +// sizes may be zero, and callers should ignore them. Callers must pass a sizes +// and eps slice with a length greater than or equal to the length of packets. +// These lengths must not exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) + // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. +// +// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, +// depending on the platform-specific implementation. type Bind interface { - // LastMark reports the last mark set for this Bind. - LastMark() uint32 + // Open puts the Bind into a listening state on a given port and reports the actual + // port that it bound to. Passing zero results in a random selection. + // fns is the set of functions that will be called to receive packets. + Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error) + + // Close closes the Bind listener. + // All fns returned by Open must return net.ErrClosed after a call to Close. + Close() error // SetMark sets the mark for each packet sent through this Bind. // This mark is passed to the kernel as the socket option SO_MARK. SetMark(mark uint32) error - // ReceiveIPv6 reads an IPv6 UDP packet into b. - // - // It reports the number of bytes read, n, - // the packet source address ep, - // and any error. - ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error) - - // ReceiveIPv4 reads an IPv4 UDP packet into b. - // - // It reports the number of bytes read, n, - // the packet source address ep, - // and any error. - ReceiveIPv4(b []byte) (n int, ep Endpoint, err error) + // Send writes one or more packets in bufs to address ep. The length of + // bufs must not exceed BatchSize(). + Send(bufs [][]byte, ep Endpoint) error - // Send writes a packet b to address ep. - Send(b []byte, ep Endpoint) error + // ParseEndpoint creates a new endpoint from a string. + ParseEndpoint(s string) (Endpoint, error) - // Close closes the Bind connection. - Close() error + // BatchSize is the number of buffers expected to be passed to + // the ReceiveFuncs, and the maximum expected to be passed to SendBatch. + BatchSize() int } -// CreateBind creates a Bind bound to a port. -// -// The value actualPort reports the actual port number the Bind -// object gets bound to. -func CreateBind(port uint16) (b Bind, actualPort uint16, err error) { - return createBind(port) +// BindSocketToInterface is implemented by Bind objects that support being +// tied to a single network interface. Used by wireguard-windows. +type BindSocketToInterface interface { + BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error + BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error } -// BindToInterface is implemented by Bind objects that support being -// tied to a single network interface. -type BindToInterface interface { - BindToInterface4(interfaceIndex uint32, blackhole bool) error - BindToInterface6(interfaceIndex uint32, blackhole bool) error +// PeekLookAtSocketFd is implemented by Bind objects that support having their +// file descriptor peeked at. Used by wireguard-android. +type PeekLookAtSocketFd interface { + PeekLookAtSocketFd4() (fd int, err error) + PeekLookAtSocketFd6() (fd int, err error) } // An Endpoint maintains the source/destination caching for a peer. // -// dst : the remote address of a peer ("endpoint" in uapi terminology) -// src : the local address from which datagrams originate going to the peer +// dst: the remote address of a peer ("endpoint" in uapi terminology) +// src: the local address from which datagrams originate going to the peer type Endpoint interface { ClearSrc() // clears the source address SrcToString() string // returns the local source address (ip:port) DstToString() string // returns the destination address (ip:port) DstToBytes() []byte // used for mac2 cookie calculations - DstIP() net.IP - SrcIP() net.IP + DstIP() netip.Addr + SrcIP() netip.Addr } -func parseEndpoint(s string) (*net.UDPAddr, error) { - // ensure that the host is an IP address +var ( + ErrBindAlreadyOpen = errors.New("bind is already open") + ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type") +) - host, _, err := net.SplitHostPort(s) - if err != nil { - return nil, err +func (fn ReceiveFunc) PrettyName() string { + name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() + // 0. cheese/taco.beansIPv6.func12.func21218-fm + name = strings.TrimSuffix(name, "-fm") + // 1. cheese/taco.beansIPv6.func12.func21218 + if idx := strings.LastIndexByte(name, '/'); idx != -1 { + name = name[idx+1:] + // 2. taco.beansIPv6.func12.func21218 } - if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 { - // Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just - // trying to make sure with a small sanity test that this is a real IP address and - // not something that's likely to incur DNS lookups. - host = host[:i] + for { + var idx int + for idx = len(name) - 1; idx >= 0; idx-- { + if name[idx] < '0' || name[idx] > '9' { + break + } + } + if idx == len(name)-1 { + break + } + const dotFunc = ".func" + if !strings.HasSuffix(name[:idx+1], dotFunc) { + break + } + name = name[:idx+1-len(dotFunc)] + // 3. taco.beansIPv6.func12 + // 4. taco.beansIPv6 } - if ip := net.ParseIP(host); ip == nil { - return nil, errors.New("Failed to parse IP address: " + host) + if idx := strings.LastIndexByte(name, '.'); idx != -1 { + name = name[idx+1:] + // 5. beansIPv6 } - - // parse address and port - - addr, err := net.ResolveUDPAddr("udp", s) - if err != nil { - return nil, err + if name == "" { + return fmt.Sprintf("%p", fn) + } + if strings.HasSuffix(name, "IPv4") { + return "v4" } - ip4 := addr.IP.To4() - if ip4 != nil { - addr.IP = ip4 + if strings.HasSuffix(name, "IPv6") { + return "v6" } - return addr, err + return name } diff --git a/conn/conn_default.go b/conn/conn_default.go deleted file mode 100644 index bad9d4d..0000000 --- a/conn/conn_default.go +++ /dev/null @@ -1,177 +0,0 @@ -// +build !linux android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "net" - "os" - "syscall" -) - -/* This code is meant to be a temporary solution - * on platforms for which the sticky socket / source caching behavior - * has not yet been implemented. - * - * See conn_linux.go for an implementation on the linux platform. - */ - -type nativeBind struct { - ipv4 *net.UDPConn - ipv6 *net.UDPConn - blackhole4 bool - blackhole6 bool -} - -type NativeEndpoint net.UDPAddr - -var _ Bind = (*nativeBind)(nil) -var _ Endpoint = (*NativeEndpoint)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - addr, err := parseEndpoint(s) - return (*NativeEndpoint)(addr), err -} - -func (_ *NativeEndpoint) ClearSrc() {} - -func (e *NativeEndpoint) DstIP() net.IP { - return (*net.UDPAddr)(e).IP -} - -func (e *NativeEndpoint) SrcIP() net.IP { - return nil // not supported -} - -func (e *NativeEndpoint) DstToBytes() []byte { - addr := (*net.UDPAddr)(e) - out := addr.IP.To4() - if out == nil { - out = addr.IP - } - out = append(out, byte(addr.Port&0xff)) - out = append(out, byte((addr.Port>>8)&0xff)) - return out -} - -func (e *NativeEndpoint) DstToString() string { - return (*net.UDPAddr)(e).String() -} - -func (e *NativeEndpoint) SrcToString() string { - return "" -} - -func listenNet(network string, port int) (*net.UDPConn, int, error) { - conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port}) - if err != nil { - return nil, 0, err - } - - // Retrieve port. - // TODO(crawshaw): under what circumstances is this necessary? - laddr := conn.LocalAddr() - uaddr, err := net.ResolveUDPAddr( - laddr.Network(), - laddr.String(), - ) - if err != nil { - return nil, 0, err - } - return conn, uaddr.Port, nil -} - -func extractErrno(err error) error { - opErr, ok := err.(*net.OpError) - if !ok { - return nil - } - syscallErr, ok := opErr.Err.(*os.SyscallError) - if !ok { - return nil - } - return syscallErr.Err -} - -func createBind(uport uint16) (Bind, uint16, error) { - var err error - var bind nativeBind - - port := int(uport) - - bind.ipv4, port, err = listenNet("udp4", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - return nil, 0, err - } - - bind.ipv6, port, err = listenNet("udp6", port) - if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT { - bind.ipv4.Close() - bind.ipv4 = nil - return nil, 0, err - } - - return &bind, uint16(port), nil -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.ipv4 != nil { - err1 = bind.ipv4.Close() - } - if bind.ipv6 != nil { - err2 = bind.ipv6.Close() - } - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) LastMark() uint32 { return 0 } - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - if bind.ipv4 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv4.ReadFromUDP(buff) - if endpoint != nil { - endpoint.IP = endpoint.IP.To4() - } - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - if bind.ipv6 == nil { - return 0, nil, syscall.EAFNOSUPPORT - } - n, endpoint, err := bind.ipv6.ReadFromUDP(buff) - return n, (*NativeEndpoint)(endpoint), err -} - -func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error { - var err error - nend := endpoint.(*NativeEndpoint) - if nend.IP.To4() != nil { - if bind.ipv4 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole4 { - return nil - } - _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } else { - if bind.ipv6 == nil { - return syscall.EAFNOSUPPORT - } - if bind.blackhole6 { - return nil - } - _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) - } - return err -} diff --git a/conn/conn_linux.go b/conn/conn_linux.go deleted file mode 100644 index 523da4a..0000000 --- a/conn/conn_linux.go +++ /dev/null @@ -1,571 +0,0 @@ -// +build !android - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package conn - -import ( - "errors" - "net" - "strconv" - "sync" - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -const ( - FD_ERR = -1 -) - -type IPv4Source struct { - Src [4]byte - Ifindex int32 -} - -type IPv6Source struct { - src [16]byte - //ifindex belongs in dst.ZoneId -} - -type NativeEndpoint struct { - sync.Mutex - dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte - src [unsafe.Sizeof(IPv6Source{})]byte - isV6 bool -} - -func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() } -func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() } -func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 } - -func (endpoint *NativeEndpoint) src4() *IPv4Source { - return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) src6() *IPv6Source { - return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0])) -} - -func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 { - return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0])) -} - -func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 { - return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0])) -} - -type nativeBind struct { - sock4 int - sock6 int - lastMark uint32 -} - -var _ Endpoint = (*NativeEndpoint)(nil) -var _ Bind = (*nativeBind)(nil) - -func CreateEndpoint(s string) (Endpoint, error) { - var end NativeEndpoint - addr, err := parseEndpoint(s) - if err != nil { - return nil, err - } - - ipv4 := addr.IP.To4() - if ipv4 != nil { - dst := end.dst4() - end.isV6 = false - dst.Port = addr.Port - copy(dst.Addr[:], ipv4) - end.ClearSrc() - return &end, nil - } - - ipv6 := addr.IP.To16() - if ipv6 != nil { - zone, err := zoneToUint32(addr.Zone) - if err != nil { - return nil, err - } - dst := end.dst6() - end.isV6 = true - dst.Port = addr.Port - dst.ZoneId = zone - copy(dst.Addr[:], ipv6[:]) - end.ClearSrc() - return &end, nil - } - - return nil, errors.New("Invalid IP address") -} - -func createBind(port uint16) (Bind, uint16, error) { - var err error - var bind nativeBind - var newPort uint16 - - // Attempt ipv6 bind, update port if successful. - bind.sock6, newPort, err = create6(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - return nil, 0, err - } - } else { - port = newPort - } - - // Attempt ipv4 bind, update port if successful. - bind.sock4, newPort, err = create4(port) - if err != nil { - if err != syscall.EAFNOSUPPORT { - unix.Close(bind.sock6) - return nil, 0, err - } - } else { - port = newPort - } - - if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR { - return nil, 0, errors.New("ipv4 and ipv6 not supported") - } - - return &bind, port, nil -} - -func (bind *nativeBind) LastMark() uint32 { - return bind.lastMark -} - -func (bind *nativeBind) SetMark(value uint32) error { - if bind.sock6 != -1 { - err := unix.SetsockoptInt( - bind.sock6, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - if bind.sock4 != -1 { - err := unix.SetsockoptInt( - bind.sock4, - unix.SOL_SOCKET, - unix.SO_MARK, - int(value), - ) - - if err != nil { - return err - } - } - - bind.lastMark = value - return nil -} - -func closeUnblock(fd int) error { - // shutdown to unblock readers and writers - unix.Shutdown(fd, unix.SHUT_RDWR) - return unix.Close(fd) -} - -func (bind *nativeBind) Close() error { - var err1, err2 error - if bind.sock6 != -1 { - err1 = closeUnblock(bind.sock6) - } - if bind.sock4 != -1 { - err2 = closeUnblock(bind.sock4) - } - - if err1 != nil { - return err1 - } - return err2 -} - -func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock6 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive6( - bind.sock6, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { - var end NativeEndpoint - if bind.sock4 == -1 { - return 0, nil, syscall.EAFNOSUPPORT - } - n, err := receive4( - bind.sock4, - buff, - &end, - ) - return n, &end, err -} - -func (bind *nativeBind) Send(buff []byte, end Endpoint) error { - nend := end.(*NativeEndpoint) - if !nend.isV6 { - if bind.sock4 == -1 { - return syscall.EAFNOSUPPORT - } - return send4(bind.sock4, nend, buff) - } else { - if bind.sock6 == -1 { - return syscall.EAFNOSUPPORT - } - return send6(bind.sock6, nend, buff) - } -} - -func (end *NativeEndpoint) SrcIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.src4().Src[0], - end.src4().Src[1], - end.src4().Src[2], - end.src4().Src[3], - ) - } else { - return end.src6().src[:] - } -} - -func (end *NativeEndpoint) DstIP() net.IP { - if !end.isV6 { - return net.IPv4( - end.dst4().Addr[0], - end.dst4().Addr[1], - end.dst4().Addr[2], - end.dst4().Addr[3], - ) - } else { - return end.dst6().Addr[:] - } -} - -func (end *NativeEndpoint) DstToBytes() []byte { - if !end.isV6 { - return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:] - } else { - return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:] - } -} - -func (end *NativeEndpoint) SrcToString() string { - return end.SrcIP().String() -} - -func (end *NativeEndpoint) DstToString() string { - var udpAddr net.UDPAddr - udpAddr.IP = end.DstIP() - if !end.isV6 { - udpAddr.Port = end.dst4().Port - } else { - udpAddr.Port = end.dst6().Port - } - return udpAddr.String() -} - -func (end *NativeEndpoint) ClearDst() { - for i := range end.dst { - end.dst[i] = 0 - } -} - -func (end *NativeEndpoint) ClearSrc() { - for i := range end.src { - end.src[i] = 0 - } -} - -func zoneToUint32(zone string) (uint32, error) { - if zone == "" { - return 0, nil - } - if intr, err := net.InterfaceByName(zone); err == nil { - return uint32(intr.Index), nil - } - n, err := strconv.ParseUint(zone, 10, 32) - return uint32(n), err -} - -func create4(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - addr := unix.SockaddrInet4{ - Port: int(port), - } - - // set sockopts and bind - - if err := func() error { - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IP, - unix.IP_PKTINFO, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet4).Port - } - - return fd, uint16(addr.Port), err -} - -func create6(port uint16) (int, uint16, error) { - - // create socket - - fd, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { - return FD_ERR, 0, err - } - - // set sockopts and bind - - addr := unix.SockaddrInet6{ - Port: int(port), - } - - if err := func() error { - - if err := unix.SetsockoptInt( - fd, - unix.SOL_SOCKET, - unix.SO_REUSEADDR, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_RECVPKTINFO, - 1, - ); err != nil { - return err - } - - if err := unix.SetsockoptInt( - fd, - unix.IPPROTO_IPV6, - unix.IPV6_V6ONLY, - 1, - ); err != nil { - return err - } - - return unix.Bind(fd, &addr) - - }(); err != nil { - unix.Close(fd) - return FD_ERR, 0, err - } - - sa, err := unix.Getsockname(fd) - if err == nil { - addr.Port = sa.(*unix.SockaddrInet6).Port - } - - return fd, uint16(addr.Port), err -} - -func send4(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IP, - Type: unix.IP_PKTINFO, - Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet4Pktinfo{ - Spec_dst: end.src4().Src, - Ifindex: end.src4().Ifindex, - }, - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet4Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) - end.Unlock() - } - - return err -} - -func send6(sock int, end *NativeEndpoint, buff []byte) error { - - // construct message header - - cmsg := struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - }{ - unix.Cmsghdr{ - Level: unix.IPPROTO_IPV6, - Type: unix.IPV6_PKTINFO, - Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr, - }, - unix.Inet6Pktinfo{ - Addr: end.src6().src, - Ifindex: end.dst6().ZoneId, - }, - } - - if cmsg.pktinfo.Addr == [16]byte{} { - cmsg.pktinfo.Ifindex = 0 - } - - end.Lock() - _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - - if err == nil { - return nil - } - - // clear src and retry - - if err == unix.EINVAL { - end.ClearSrc() - cmsg.pktinfo = unix.Inet6Pktinfo{} - end.Lock() - _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) - end.Unlock() - } - - return err -} - -func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet4Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = false - - if newDst4, ok := newDst.(*unix.SockaddrInet4); ok { - *end.dst4() = *newDst4 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IP && - cmsg.cmsghdr.Type == unix.IP_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo { - end.src4().Src = cmsg.pktinfo.Spec_dst - end.src4().Ifindex = cmsg.pktinfo.Ifindex - } - - return size, nil -} - -func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { - - // construct message header - - var cmsg struct { - cmsghdr unix.Cmsghdr - pktinfo unix.Inet6Pktinfo - } - - size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0) - - if err != nil { - return 0, err - } - end.isV6 = true - - if newDst6, ok := newDst.(*unix.SockaddrInet6); ok { - *end.dst6() = *newDst6 - } - - // update source cache - - if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 && - cmsg.cmsghdr.Type == unix.IPV6_PKTINFO && - cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo { - end.src6().src = cmsg.pktinfo.Addr - end.dst6().ZoneId = cmsg.pktinfo.Ifindex - } - - return size, nil -} diff --git a/conn/conn_test.go b/conn/conn_test.go new file mode 100644 index 0000000..c6194ee --- /dev/null +++ b/conn/conn_test.go @@ -0,0 +1,24 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "testing" +) + +func TestPrettyName(t *testing.T) { + var ( + recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + ) + + const want = "TestPrettyName" + + t.Run("ReceiveFunc.PrettyName", func(t *testing.T) { + if got := recvFunc.PrettyName(); got != want { + t.Errorf("PrettyName() = %v, want %v", got, want) + } + }) +} diff --git a/conn/controlfns.go b/conn/controlfns.go new file mode 100644 index 0000000..4f7d90f --- /dev/null +++ b/conn/controlfns.go @@ -0,0 +1,43 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + "syscall" +) + +// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is +// the max supported by a default configuration of macOS. Some platforms will +// silently clamp the value to other maximums, such as linux clamping to +// net.core.{r,w}mem_max (see _linux.go for additional implementation that works +// around this limitation) +const socketBufferSize = 7 << 20 + +// controlFn is the callback function signature from net.ListenConfig.Control. +// It is used to apply platform specific configuration to the socket prior to +// bind. +type controlFn func(network, address string, c syscall.RawConn) error + +// controlFns is a list of functions that are called from the listen config +// that can apply socket options. +var controlFns = []controlFn{} + +// listenConfig returns a net.ListenConfig that applies the controlFns to the +// socket prior to bind. This is used to apply socket buffer sizing and packet +// information OOB configuration for sticky sockets. +func listenConfig() *net.ListenConfig { + return &net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) error { + for _, fn := range controlFns { + if err := fn(network, address, c); err != nil { + return err + } + } + return nil + }, + } +} diff --git a/conn/controlfns_linux.go b/conn/controlfns_linux.go new file mode 100644 index 0000000..f6ab1d2 --- /dev/null +++ b/conn/controlfns_linux.go @@ -0,0 +1,69 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "runtime" + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + + // Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by + // using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to + // fail silently - the result of failure is lower performance on very fast + // links or high latency links. + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + // Set up to *mem_max + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) + // Set beyond *mem_max if CAP_NET_ADMIN + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize) + }) + }, + + // Enable receiving of the packet information (IP_PKTINFO for IPv4, + // IPV6_PKTINFO for IPv6) that is used to implement sticky socket support. + func(network, address string, c syscall.RawConn) error { + var err error + switch network { + case "udp4": + if runtime.GOOS != "android" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1) + }) + } + case "udp6": + c.Control(func(fd uintptr) { + if runtime.GOOS != "android" { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1) + if err != nil { + return + } + } + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + default: + err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL) + } + return err + }, + + // Attempt to enable UDP_GRO + func(network, address string, c syscall.RawConn) error { + c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1) + }) + return nil + }, + ) +} diff --git a/conn/controlfns_unix.go b/conn/controlfns_unix.go new file mode 100644 index 0000000..91692c0 --- /dev/null +++ b/conn/controlfns_unix.go @@ -0,0 +1,35 @@ +//go:build !windows && !linux && !wasm + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/unix" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize) + _ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize) + }) + }, + + func(network, address string, c syscall.RawConn) error { + var err error + if network == "udp6" { + c.Control(func(fd uintptr) { + err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1) + }) + } + return err + }, + ) +} diff --git a/conn/controlfns_windows.go b/conn/controlfns_windows.go new file mode 100644 index 0000000..c3bdf7d --- /dev/null +++ b/conn/controlfns_windows.go @@ -0,0 +1,23 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "syscall" + + "golang.org/x/sys/windows" +) + +func init() { + controlFns = append(controlFns, + func(network, address string, c syscall.RawConn) error { + return c.Control(func(fd uintptr) { + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize) + _ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize) + }) + }, + ) +} diff --git a/conn/default.go b/conn/default.go new file mode 100644 index 0000000..b6f761b --- /dev/null +++ b/conn/default.go @@ -0,0 +1,10 @@ +//go:build !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func NewDefaultBind() Bind { return NewStdNetBind() } diff --git a/conn/errors_default.go b/conn/errors_default.go new file mode 100644 index 0000000..f1e5b90 --- /dev/null +++ b/conn/errors_default.go @@ -0,0 +1,12 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +func errShouldDisableUDPGSO(err error) bool { + return false +} diff --git a/conn/errors_linux.go b/conn/errors_linux.go new file mode 100644 index 0000000..8e61000 --- /dev/null +++ b/conn/errors_linux.go @@ -0,0 +1,26 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +func errShouldDisableUDPGSO(err error) bool { + var serr *os.SyscallError + if errors.As(err, &serr) { + // EIO is returned by udp_send_skb() if the device driver does not have + // tx checksumming enabled, which is a hard requirement of UDP_SEGMENT. + // See: + // https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228 + // https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942 + return serr.Err == unix.EIO + } + return false +} diff --git a/conn/features_default.go b/conn/features_default.go new file mode 100644 index 0000000..d53ff5f --- /dev/null +++ b/conn/features_default.go @@ -0,0 +1,15 @@ +//go:build !linux +// +build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net" + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + return +} diff --git a/conn/features_linux.go b/conn/features_linux.go new file mode 100644 index 0000000..8959d93 --- /dev/null +++ b/conn/features_linux.go @@ -0,0 +1,29 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net" + + "golang.org/x/sys/unix" +) + +func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) { + rc, err := conn.SyscallConn() + if err != nil { + return + } + err = rc.Control(func(fd uintptr) { + _, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT) + txOffload = errSyscall == nil + opt, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO) + rxOffload = errSyscall == nil && opt == 1 + }) + if err != nil { + return false, false + } + return txOffload, rxOffload +} diff --git a/conn/gso_default.go b/conn/gso_default.go new file mode 100644 index 0000000..57780db --- /dev/null +++ b/conn/gso_default.go @@ -0,0 +1,21 @@ +//go:build !linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. +func setGSOSize(control *[]byte, gsoSize uint16) { +} + +// gsoControlSize returns the recommended buffer size for pooling sticky and UDP +// offloading control data. +const gsoControlSize = 0 diff --git a/conn/gso_linux.go b/conn/gso_linux.go new file mode 100644 index 0000000..8596b29 --- /dev/null +++ b/conn/gso_linux.go @@ -0,0 +1,65 @@ +//go:build linux + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "fmt" + "unsafe" + + "golang.org/x/sys/unix" +) + +const ( + sizeOfGSOData = 2 +) + +// getGSOSize parses control for UDP_GRO and if found returns its GSO size data. +func getGSOSize(control []byte) (int, error) { + var ( + hdr unix.Cmsghdr + data []byte + rem = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return 0, fmt.Errorf("error parsing socket control message: %w", err) + } + if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData { + var gso uint16 + copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData]) + return int(gso), nil + } + } + return 0, nil +} + +// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing +// data in control untouched. +func setGSOSize(control *[]byte, gsoSize uint16) { + existingLen := len(*control) + avail := cap(*control) - existingLen + space := unix.CmsgSpace(sizeOfGSOData) + if avail < space { + return + } + *control = (*control)[:cap(*control)] + gsoControl := (*control)[existingLen:] + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0])) + hdr.Level = unix.SOL_UDP + hdr.Type = unix.UDP_SEGMENT + hdr.SetLen(unix.CmsgLen(sizeOfGSOData)) + copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData)) + *control = (*control)[:existingLen+space] +} + +// gsoControlSize returns the recommended buffer size for pooling UDP +// offloading control data. +var gsoControlSize = unix.CmsgSpace(sizeOfGSOData) diff --git a/conn/mark_default.go b/conn/mark_default.go index fc41ba9..3102384 100644 --- a/conn/mark_default.go +++ b/conn/mark_default.go @@ -1,12 +1,12 @@ -// +build !linux,!openbsd,!freebsd +//go:build !linux && !openbsd && !freebsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn -func (bind *nativeBind) SetMark(mark uint32) error { +func (s *StdNetBind) SetMark(mark uint32) error { return nil } diff --git a/conn/mark_unix.go b/conn/mark_unix.go index 5334582..d9e46ee 100644 --- a/conn/mark_unix.go +++ b/conn/mark_unix.go @@ -1,8 +1,8 @@ -// +build android openbsd freebsd +//go:build linux || openbsd || freebsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn @@ -26,13 +26,13 @@ func init() { } } -func (bind *nativeBind) SetMark(mark uint32) error { +func (s *StdNetBind) SetMark(mark uint32) error { var operr error if fwmarkIoctl == 0 { return nil } - if bind.ipv4 != nil { - fd, err := bind.ipv4.SyscallConn() + if s.ipv4 != nil { + fd, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -46,8 +46,8 @@ func (bind *nativeBind) SetMark(mark uint32) error { return err } } - if bind.ipv6 != nil { - fd, err := bind.ipv6.SyscallConn() + if s.ipv6 != nil { + fd, err := s.ipv6.SyscallConn() if err != nil { return err } diff --git a/conn/sticky_default.go b/conn/sticky_default.go new file mode 100644 index 0000000..0b21386 --- /dev/null +++ b/conn/sticky_default.go @@ -0,0 +1,42 @@ +//go:build !linux || android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import "net/netip" + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return "" +} + +// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets +// {get,set}srcControl feature set, but use alternatively named flags and need +// ports and require testing. + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { +} + +// setSrcControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +const stickyControlSize = 0 + +const StdNetSupportsStickySockets = false diff --git a/conn/sticky_linux.go b/conn/sticky_linux.go new file mode 100644 index 0000000..8e206e9 --- /dev/null +++ b/conn/sticky_linux.go @@ -0,0 +1,112 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "net/netip" + "unsafe" + + "golang.org/x/sys/unix" +) + +func (e *StdNetEndpoint) SrcIP() netip.Addr { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return netip.AddrFrom4(info.Spec_dst) + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + // TODO: set zone. in order to do so we need to check if the address is + // link local, and if it is perform a syscall to turn the ifindex into a + // zone string because netip uses string zones. + return netip.AddrFrom16(info.Addr) + } + return netip.Addr{} +} + +func (e *StdNetEndpoint) SrcIfidx() int32 { + switch len(e.src) { + case unix.CmsgSpace(unix.SizeofInet4Pktinfo): + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return info.Ifindex + case unix.CmsgSpace(unix.SizeofInet6Pktinfo): + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)])) + return int32(info.Ifindex) + } + return 0 +} + +func (e *StdNetEndpoint) SrcToString() string { + return e.SrcIP().String() +} + +// getSrcFromControl parses the control for PKTINFO and if found updates ep with +// the source information found. +func getSrcFromControl(control []byte, ep *StdNetEndpoint) { + ep.ClearSrc() + + var ( + hdr unix.Cmsghdr + data []byte + rem []byte = control + err error + ) + + for len(rem) > unix.SizeofCmsghdr { + hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem) + if err != nil { + return + } + + if hdr.Level == unix.IPPROTO_IP && + hdr.Type == unix.IP_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + } + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + + if hdr.Level == unix.IPPROTO_IPV6 && + hdr.Type == unix.IPV6_PKTINFO { + + if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) { + ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + } + + ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)] + + hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr) + copy(ep.src, hdrBuf) + copy(ep.src[unix.CmsgLen(0):], data) + return + } + } +} + +// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address +// and source ifindex found in ep. control's len will be set to 0 in the event +// that ep is a default value. +func setSrcControl(control *[]byte, ep *StdNetEndpoint) { + if cap(*control) < len(ep.src) { + return + } + *control = (*control)[:0] + *control = append(*control, ep.src...) +} + +// stickyControlSize returns the recommended buffer size for pooling sticky +// offloading control data. +var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo) + +const StdNetSupportsStickySockets = true diff --git a/conn/sticky_linux_test.go b/conn/sticky_linux_test.go new file mode 100644 index 0000000..d2bd584 --- /dev/null +++ b/conn/sticky_linux_test.go @@ -0,0 +1,266 @@ +//go:build linux && !android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package conn + +import ( + "context" + "net" + "net/netip" + "runtime" + "testing" + "unsafe" + + "golang.org/x/sys/unix" +) + +func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { + var buf []byte + if addr.Is4() { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IP, + Type: unix.IP_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet4Pktinfo{ + Ifindex: ifidx, + Spec_dst: addr.As4(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) + } else { + buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) + hdr := unix.Cmsghdr{ + Level: unix.IPPROTO_IPV6, + Type: unix.IPV6_PKTINFO, + } + hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) + copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) + + info := unix.Inet6Pktinfo{ + Ifindex: uint32(ifidx), + Addr: addr.As16(), + } + copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) + } + + ep.src = buf +} + +func Test_setSrcControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), + } + setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) + + control := make([]byte, stickyControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IP { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IP_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { + t.Errorf("unexpected address: %v", info.Spec_dst) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("IPv6", func(t *testing.T) { + ep := &StdNetEndpoint{ + AddrPort: netip.MustParseAddrPort("[::1]:1234"), + } + setSrc(ep, netip.MustParseAddr("::1"), 5) + + control := make([]byte, stickyControlSize) + + setSrcControl(&control, ep) + + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + if hdr.Level != unix.IPPROTO_IPV6 { + t.Errorf("unexpected level: %d", hdr.Level) + } + if hdr.Type != unix.IPV6_PKTINFO { + t.Errorf("unexpected type: %d", hdr.Type) + } + if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { + t.Errorf("unexpected length: %d", hdr.Len) + } + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + if info.Addr != ep.SrcIP().As16() { + t.Errorf("unexpected address: %v", info.Addr) + } + if info.Ifindex != 5 { + t.Errorf("unexpected ifindex: %d", info.Ifindex) + } + }) + + t.Run("ClearOnNoSrc", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = 1 + hdr.Type = 2 + hdr.Len = 3 + + setSrcControl(&control, &StdNetEndpoint{}) + + if len(control) != 0 { + t.Errorf("unexpected control: %v", control) + } + }) +} + +func Test_getSrcFromControl(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("IPv6", func(t *testing.T) { + control := make([]byte, stickyControlSize) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IPV6 + hdr.Type = unix.IPV6_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) + info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} + info.Ifindex = 5 + + ep := &StdNetEndpoint{} + getSrcFromControl(control, ep) + + if ep.SrcIP() != netip.MustParseAddr("::1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("ClearOnEmpty", func(t *testing.T) { + var control []byte + ep := &StdNetEndpoint{} + setSrc(ep, netip.MustParseAddr("::1"), 5) + + getSrcFromControl(control, ep) + if ep.SrcIP().IsValid() { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 0 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) + t.Run("Multiple", func(t *testing.T) { + zeroControl := make([]byte, unix.CmsgSpace(0)) + zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) + zeroHdr.SetLen(unix.CmsgLen(0)) + + control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) + hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) + hdr.Level = unix.IPPROTO_IP + hdr.Type = unix.IP_PKTINFO + hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) + info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) + info.Spec_dst = [4]byte{127, 0, 0, 1} + info.Ifindex = 5 + + combined := make([]byte, 0) + combined = append(combined, zeroControl...) + combined = append(combined, control...) + + ep := &StdNetEndpoint{} + getSrcFromControl(combined, ep) + + if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { + t.Errorf("unexpected address: %v", ep.SrcIP()) + } + if ep.SrcIfidx() != 5 { + t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) + } + }) +} + +func Test_listenConfig(t *testing.T) { + t.Run("IPv4", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IP_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) + t.Run("IPv6", func(t *testing.T) { + conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") + if err != nil { + t.Fatal(err) + } + sc, err := conn.(*net.UDPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + + if runtime.GOOS == "linux" { + var i int + sc.Control(func(fd uintptr) { + i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) + }) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Error("IPV6_PKTINFO not set!") + } + } else { + t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) + } + }) +} diff --git a/conn/winrio/rio_windows.go b/conn/winrio/rio_windows.go new file mode 100644 index 0000000..d1037bb --- /dev/null +++ b/conn/winrio/rio_windows.go @@ -0,0 +1,254 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package winrio + +import ( + "log" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +const ( + MsgDontNotify = 1 + MsgDefer = 2 + MsgWaitAll = 4 + MsgCommitOnly = 8 + + MaxCqSize = 0x8000000 + + invalidBufferId = 0xFFFFFFFF + invalidCq = 0 + invalidRq = 0 + corruptCq = 0xFFFFFFFF +) + +var extensionFunctionTable struct { + cbSize uint32 + rioReceive uintptr + rioReceiveEx uintptr + rioSend uintptr + rioSendEx uintptr + rioCloseCompletionQueue uintptr + rioCreateCompletionQueue uintptr + rioCreateRequestQueue uintptr + rioDequeueCompletion uintptr + rioDeregisterBuffer uintptr + rioNotify uintptr + rioRegisterBuffer uintptr + rioResizeCompletionQueue uintptr + rioResizeRequestQueue uintptr +} + +type Cq uintptr + +type Rq uintptr + +type BufferId uintptr + +type Buffer struct { + Id BufferId + Offset uint32 + Length uint32 +} + +type Result struct { + Status int32 + BytesTransferred uint32 + SocketContext uint64 + RequestContext uint64 +} + +type notificationCompletionType uint32 + +const ( + eventCompletion notificationCompletionType = 1 + iocpCompletion notificationCompletionType = 2 +) + +type eventNotificationCompletion struct { + completionType notificationCompletionType + event windows.Handle + notifyReset uint32 +} + +type iocpNotificationCompletion struct { + completionType notificationCompletionType + iocp windows.Handle + key uintptr + overlapped *windows.Overlapped +} + +var ( + initialized sync.Once + available bool +) + +func Initialize() bool { + initialized.Do(func() { + var ( + err error + socket windows.Handle + cq Cq + ) + defer func() { + if err == nil { + return + } + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 { + return + } + log.Printf("Registered I/O is unavailable: %v", err) + }() + socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP) + if err != nil { + return + } + defer windows.CloseHandle(socket) + WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}} + const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024 + ob := uint32(0) + err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER, + (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)), + (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)), + &ob, nil, 0) + if err != nil { + return + } + + // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes + // failures in RIOCreateRequestQueue, so keep going to be certain this is supported. + var iocp windows.Handle + iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) + if err != nil { + return + } + defer windows.CloseHandle(iocp) + var overlapped windows.Overlapped + cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped) + if err != nil { + return + } + defer CloseCompletionQueue(cq) + _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0) + if err != nil { + return + } + available = true + }) + return available +} + +func Socket(af, typ, proto int32) (windows.Handle, error) { + return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO) +} + +func CloseCompletionQueue(cq Cq) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0) +} + +func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) { + notificationCompletion := &eventNotificationCompletion{ + completionType: eventCompletion, + event: event, + } + if notifyReset { + notificationCompletion.notifyReset = 1 + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) { + notificationCompletion := &iocpNotificationCompletion{ + completionType: iocpCompletion, + iocp: iocp, + key: key, + overlapped: overlapped, + } + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0) + if ret == invalidCq { + return 0, err + } + return Cq(ret), nil +} + +func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0) + if ret == invalidRq { + return 0, err + } + return Rq(ret), nil +} + +func DequeueCompletion(cq Cq, results []Result) uint32 { + var array uintptr + if len(results) > 0 { + array = uintptr(unsafe.Pointer(&results[0])) + } + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results))) + if ret == corruptCq { + panic("cq is corrupt") + } + return uint32(ret) +} + +func DeregisterBuffer(id BufferId) { + _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0) +} + +func RegisterBuffer(buffer []byte) (BufferId, error) { + var buf unsafe.Pointer + if len(buffer) > 0 { + buf = unsafe.Pointer(&buffer[0]) + } + return RegisterPointer(buf, uint32(len(buffer))) +} + +func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) { + ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0) + if ret == invalidBufferId { + return 0, err + } + return BufferId(ret), nil +} + +func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error { + ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext) + if ret == 0 { + return err + } + return nil +} + +func Notify(cq Cq) error { + ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0) + if ret != 0 { + return windows.Errno(ret) + } + return nil +} diff --git a/device/allowedips.go b/device/allowedips.go index efc27c0..fa46f97 100644 --- a/device/allowedips.go +++ b/device/allowedips.go @@ -1,173 +1,201 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( + "container/list" + "encoding/binary" "errors" "math/bits" "net" + "net/netip" "sync" "unsafe" ) -type trieEntry struct { - cidr uint - child [2]*trieEntry - bits net.IP - peer *Peer - - // index of "branching" bit - - bit_at_byte uint - bit_at_shift uint -} - -func isLittleEndian() bool { - one := uint32(1) - return *(*byte)(unsafe.Pointer(&one)) != 0 +type parentIndirection struct { + parentBit **trieEntry + parentBitType uint8 } -func swapU32(i uint32) uint32 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes32(i) -} - -func swapU64(i uint64) uint64 { - if !isLittleEndian() { - return i - } - - return bits.ReverseBytes64(i) +type trieEntry struct { + peer *Peer + child [2]*trieEntry + parent parentIndirection + cidr uint8 + bitAtByte uint8 + bitAtShift uint8 + bits []byte + perPeerElem *list.Element } -func commonBits(ip1 net.IP, ip2 net.IP) uint { +func commonBits(ip1, ip2 []byte) uint8 { size := len(ip1) if size == net.IPv4len { - a := (*uint32)(unsafe.Pointer(&ip1[0])) - b := (*uint32)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b - return uint(bits.LeadingZeros32(swapU32(x))) + a := binary.BigEndian.Uint32(ip1) + b := binary.BigEndian.Uint32(ip2) + x := a ^ b + return uint8(bits.LeadingZeros32(x)) } else if size == net.IPv6len { - a := (*uint64)(unsafe.Pointer(&ip1[0])) - b := (*uint64)(unsafe.Pointer(&ip2[0])) - x := *a ^ *b + a := binary.BigEndian.Uint64(ip1) + b := binary.BigEndian.Uint64(ip2) + x := a ^ b if x != 0 { - return uint(bits.LeadingZeros64(swapU64(x))) + return uint8(bits.LeadingZeros64(x)) } - a = (*uint64)(unsafe.Pointer(&ip1[8])) - b = (*uint64)(unsafe.Pointer(&ip2[8])) - x = *a ^ *b - return 64 + uint(bits.LeadingZeros64(swapU64(x))) + a = binary.BigEndian.Uint64(ip1[8:]) + b = binary.BigEndian.Uint64(ip2[8:]) + x = a ^ b + return 64 + uint8(bits.LeadingZeros64(x)) } else { panic("Wrong size bit string") } } -func (node *trieEntry) removeByPeer(p *Peer) *trieEntry { - if node == nil { - return node - } - - // walk recursively - - node.child[0] = node.child[0].removeByPeer(p) - node.child[1] = node.child[1].removeByPeer(p) +func (node *trieEntry) addToPeerEntries() { + node.perPeerElem = node.peer.trieEntries.PushBack(node) +} - if node.peer != p { - return node +func (node *trieEntry) removeFromPeerEntries() { + if node.perPeerElem != nil { + node.peer.trieEntries.Remove(node.perPeerElem) + node.perPeerElem = nil } +} - // remove peer & merge +func (node *trieEntry) choose(ip []byte) byte { + return (ip[node.bitAtByte] >> node.bitAtShift) & 1 +} - node.peer = nil - if node.child[0] == nil { - return node.child[1] +func (node *trieEntry) maskSelf() { + mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) + for i := 0; i < len(mask); i++ { + node.bits[i] &= mask[i] } - return node.child[0] } -func (node *trieEntry) choose(ip net.IP) byte { - return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1 +func (node *trieEntry) zeroizePointers() { + // Make the garbage collector's life slightly easier + node.peer = nil + node.child[0] = nil + node.child[1] = nil + node.parent.parentBit = nil } -func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry { - - // at leaf - - if node == nil { - return &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), +func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) { + for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr { + parent = node + if parent.cidr == cidr { + exact = true + return } + bit := node.choose(ip) + node = node.child[bit] } + return +} - // traverse deeper - - common := commonBits(node.bits, ip) - if node.cidr <= cidr && common >= node.cidr { - if node.cidr == cidr { - node.peer = peer - return node +func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) { + if *trie.parentBit == nil { + node := &trieEntry{ + peer: peer, + parent: trie, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } - bit := node.choose(ip) - node.child[bit] = node.child[bit].insert(ip, cidr, peer) - return node + node.maskSelf() + node.addToPeerEntries() + *trie.parentBit = node + return + } + node, exact := (*trie.parentBit).nodePlacement(ip, cidr) + if exact { + node.removeFromPeerEntries() + node.peer = peer + node.addToPeerEntries() + return } - - // split node newNode := &trieEntry{ - bits: ip, - peer: peer, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + peer: peer, + bits: ip, + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), } + newNode.maskSelf() + newNode.addToPeerEntries() - cidr = min(cidr, common) - - // check for shorter prefix + var down *trieEntry + if node == nil { + down = *trie.parentBit + } else { + bit := node.choose(ip) + down = node.child[bit] + if down == nil { + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + return + } + } + common := commonBits(down.bits, ip) + if common < cidr { + cidr = common + } + parent := node if newNode.cidr == cidr { - bit := newNode.choose(node.bits) - newNode.child[bit] = node - return newNode + bit := newNode.choose(down.bits) + down.parent = parentIndirection{&newNode.child[bit], bit} + newNode.child[bit] = down + if parent == nil { + newNode.parent = trie + *trie.parentBit = newNode + } else { + bit := parent.choose(newNode.bits) + newNode.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = newNode + } + return } - // create new parent for node & newNode - - parent := &trieEntry{ - bits: ip, - peer: nil, - cidr: cidr, - bit_at_byte: cidr / 8, - bit_at_shift: 7 - (cidr % 8), + node = &trieEntry{ + bits: append([]byte{}, newNode.bits...), + cidr: cidr, + bitAtByte: cidr / 8, + bitAtShift: 7 - (cidr % 8), + } + node.maskSelf() + + bit := node.choose(down.bits) + down.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = down + bit = node.choose(newNode.bits) + newNode.parent = parentIndirection{&node.child[bit], bit} + node.child[bit] = newNode + if parent == nil { + node.parent = trie + *trie.parentBit = node + } else { + bit := parent.choose(node.bits) + node.parent = parentIndirection{&parent.child[bit], bit} + parent.child[bit] = node } - - bit := parent.choose(ip) - parent.child[bit] = newNode - parent.child[bit^1] = node - - return parent } -func (node *trieEntry) lookup(ip net.IP) *Peer { +func (node *trieEntry) lookup(ip []byte) *Peer { var found *Peer - size := uint(len(ip)) + size := uint8(len(ip)) for node != nil && commonBits(node.bits, ip) >= node.cidr { if node.peer != nil { found = node.peer } - if node.bit_at_byte == size { + if node.bitAtByte == size { break } bit := node.choose(ip) @@ -176,76 +204,91 @@ func (node *trieEntry) lookup(ip net.IP) *Peer { return found } -func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet { - if node == nil { - return results - } - if node.peer == p { - mask := net.CIDRMask(int(node.cidr), len(node.bits)*8) - results = append(results, net.IPNet{ - Mask: mask, - IP: node.bits.Mask(mask), - }) - } - results = node.child[0].entriesForPeer(p, results) - results = node.child[1].entriesForPeer(p, results) - return results -} - type AllowedIPs struct { IPv4 *trieEntry IPv6 *trieEntry mutex sync.RWMutex } -func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet { +func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) { table.mutex.RLock() defer table.mutex.RUnlock() - allowed := make([]net.IPNet, 0, 10) - allowed = table.IPv4.entriesForPeer(peer, allowed) - allowed = table.IPv6.entriesForPeer(peer, allowed) - return allowed -} - -func (table *AllowedIPs) Reset() { - table.mutex.Lock() - defer table.mutex.Unlock() - - table.IPv4 = nil - table.IPv6 = nil + for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() { + node := elem.Value.(*trieEntry) + a, _ := netip.AddrFromSlice(node.bits) + if !cb(netip.PrefixFrom(a, int(node.cidr))) { + return + } + } } func (table *AllowedIPs) RemoveByPeer(peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - table.IPv4 = table.IPv4.removeByPeer(peer) - table.IPv6 = table.IPv6.removeByPeer(peer) + var next *list.Element + for elem := peer.trieEntries.Front(); elem != nil; elem = next { + next = elem.Next() + node := elem.Value.(*trieEntry) + + node.removeFromPeerEntries() + node.peer = nil + if node.child[0] != nil && node.child[1] != nil { + continue + } + bit := 0 + if node.child[0] == nil { + bit = 1 + } + child := node.child[bit] + if child != nil { + child.parent = node.parent + } + *node.parent.parentBit = child + if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 { + node.zeroizePointers() + continue + } + parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType))) + if parent.peer != nil { + node.zeroizePointers() + continue + } + child = parent.child[node.parent.parentBitType^1] + if child != nil { + child.parent = parent.parent + } + *parent.parent.parentBit = child + node.zeroizePointers() + parent.zeroizePointers() + } } -func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) { +func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) { table.mutex.Lock() defer table.mutex.Unlock() - switch len(ip) { - case net.IPv6len: - table.IPv6 = table.IPv6.insert(ip, cidr, peer) - case net.IPv4len: - table.IPv4 = table.IPv4.insert(ip, cidr, peer) - default: + if prefix.Addr().Is6() { + ip := prefix.Addr().As16() + parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else if prefix.Addr().Is4() { + ip := prefix.Addr().As4() + parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer) + } else { panic(errors.New("inserting unknown address type")) } } -func (table *AllowedIPs) LookupIPv4(address []byte) *Peer { +func (table *AllowedIPs) Lookup(ip []byte) *Peer { table.mutex.RLock() defer table.mutex.RUnlock() - return table.IPv4.lookup(address) -} - -func (table *AllowedIPs) LookupIPv6(address []byte) *Peer { - table.mutex.RLock() - defer table.mutex.RUnlock() - return table.IPv6.lookup(address) + switch len(ip) { + case net.IPv6len: + return table.IPv6.lookup(ip) + case net.IPv4len: + return table.IPv4.lookup(ip) + default: + panic(errors.New("looking up unknown address type")) + } } diff --git a/device/allowedips_rand_test.go b/device/allowedips_rand_test.go index 59c10f7..07065c3 100644 --- a/device/allowedips_rand_test.go +++ b/device/allowedips_rand_test.go @@ -1,25 +1,28 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" + "net" + "net/netip" "sort" "testing" ) const ( - NumberOfPeers = 100 - NumberOfAddresses = 250 - NumberOfTests = 10000 + NumberOfPeers = 100 + NumberOfPeerRemovals = 4 + NumberOfAddresses = 250 + NumberOfTests = 10000 ) type SlowNode struct { peer *Peer - cidr uint + cidr uint8 bits []byte } @@ -37,7 +40,7 @@ func (r SlowRouter) Swap(i, j int) { r[i], r[j] = r[j], r[i] } -func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter { +func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter { for _, t := range r { if t.cidr == cidr && commonBits(t.bits, addr) >= cidr { t.peer = peer @@ -64,68 +67,75 @@ func (r SlowRouter) Lookup(addr []byte) *Peer { return nil } -func TestTrieRandomIPv4(t *testing.T) { - var trie *trieEntry - var slow SlowRouter - var peers []*Peer - - rand.Seed(1) - - const AddressLength = 4 - - for n := 0; n < NumberOfPeers; n += 1 { - peers = append(peers, &Peer{}) - } - - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) - } - - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) +func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter { + n := 0 + for _, x := range r { + if x.peer != peer { + r[n] = x + n++ } } + return r[:n] } -func TestTrieRandomIPv6(t *testing.T) { - var trie *trieEntry - var slow SlowRouter +func TestTrieRandom(t *testing.T) { + var slow4, slow6 SlowRouter var peers []*Peer + var allowedIPs AllowedIPs rand.Seed(1) - const AddressLength = 16 - - for n := 0; n < NumberOfPeers; n += 1 { + for n := 0; n < NumberOfPeers; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < NumberOfAddresses; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) - index := rand.Int() % NumberOfPeers - trie = trie.insert(addr[:], cidr, peers[index]) - slow = slow.Insert(addr[:], cidr, peers[index]) + for n := 0; n < NumberOfAddresses; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + cidr := uint8(rand.Intn(32) + 1) + index := rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index]) + slow4 = slow4.Insert(addr4[:], cidr, peers[index]) + + var addr6 [16]byte + rand.Read(addr6[:]) + cidr = uint8(rand.Intn(128) + 1) + index = rand.Intn(NumberOfPeers) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index]) + slow6 = slow6.Insert(addr6[:], cidr, peers[index]) } - for n := 0; n < NumberOfTests; n += 1 { - var addr [AddressLength]byte - rand.Read(addr[:]) - peer1 := slow.Lookup(addr[:]) - peer2 := trie.lookup(addr[:]) - if peer1 != peer2 { - t.Error("Trie did not match naive implementation, for:", addr) + var p int + for p = 0; ; p++ { + for n := 0; n < NumberOfTests; n++ { + var addr4 [4]byte + rand.Read(addr4[:]) + peer1 := slow4.Lookup(addr4[:]) + peer2 := allowedIPs.Lookup(addr4[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2) + } + + var addr6 [16]byte + rand.Read(addr6[:]) + peer1 = slow6.Lookup(addr6[:]) + peer2 = allowedIPs.Lookup(addr6[:]) + if peer1 != peer2 { + t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2) + } + } + if p >= len(peers) || p >= NumberOfPeerRemovals { + break } + allowedIPs.RemoveByPeer(peers[p]) + slow4 = slow4.RemoveByPeer(peers[p]) + slow6 = slow6.RemoveByPeer(peers[p]) + } + for ; p < len(peers); p++ { + allowedIPs.RemoveByPeer(peers[p]) + } + + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Failed to remove all nodes from trie by peer") } } diff --git a/device/allowedips_test.go b/device/allowedips_test.go index 075ff06..cde068e 100644 --- a/device/allowedips_test.go +++ b/device/allowedips_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,40 +8,17 @@ package device import ( "math/rand" "net" + "net/netip" "testing" ) -/* Todo: More comprehensive - */ - type testPairCommonBits struct { s1 []byte s2 []byte - match uint -} - -type testPairTrieInsert struct { - key []byte - cidr uint - peer *Peer -} - -type testPairTrieLookup struct { - key []byte - peer *Peer -} - -func printTrie(t *testing.T, p *trieEntry) { - if p == nil { - return - } - t.Log(p) - printTrie(t, p.child[0]) - printTrie(t, p.child[1]) + match uint8 } func TestCommonBits(t *testing.T) { - tests := []testPairCommonBits{ {s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7}, {s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13}, @@ -62,27 +39,28 @@ func TestCommonBits(t *testing.T) { } } -func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) { +func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) { var trie *trieEntry var peers []*Peer + root := parentIndirection{&trie, 2} rand.Seed(1) const AddressLength = 4 - for n := 0; n < peerNumber; n += 1 { + for n := 0; n < peerNumber; n++ { peers = append(peers, &Peer{}) } - for n := 0; n < addressNumber; n += 1 { + for n := 0; n < addressNumber; n++ { var addr [AddressLength]byte rand.Read(addr[:]) - cidr := uint(rand.Uint32() % (AddressLength * 8)) + cidr := uint8(rand.Uint32() % (AddressLength * 8)) index := rand.Int() % peerNumber - trie = trie.insert(addr[:], cidr, peers[index]) + root.insert(addr[:], cidr, peers[index]) } - for n := 0; n < b.N; n += 1 { + for n := 0; n < b.N; n++ { var addr [AddressLength]byte rand.Read(addr[:]) trie.lookup(addr[:]) @@ -117,21 +95,21 @@ func TestTrieIPv4(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs - insert := func(peer *Peer, a, b, c, d byte, cidr uint) { - trie = trie.insert([]byte{a, b, c, d}, cidr, peer) + insert := func(peer *Peer, a, b, c, d byte, cidr uint8) { + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p != peer { t.Error("Assert EQ failed") } } assertNEQ := func(peer *Peer, a, b, c, d byte) { - p := trie.lookup([]byte{a, b, c, d}) + p := allowedIPs.Lookup([]byte{a, b, c, d}) if p == peer { t.Error("Assert NEQ failed") } @@ -173,7 +151,7 @@ func TestTrieIPv4(t *testing.T) { assertEQ(a, 192, 0, 0, 0) assertEQ(a, 255, 0, 0, 0) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 1, 0, 0, 0) assertNEQ(a, 64, 0, 0, 0) @@ -181,12 +159,21 @@ func TestTrieIPv4(t *testing.T) { assertNEQ(a, 192, 0, 0, 0) assertNEQ(a, 255, 0, 0, 0) - trie = nil + allowedIPs.RemoveByPeer(a) + allowedIPs.RemoveByPeer(b) + allowedIPs.RemoveByPeer(c) + allowedIPs.RemoveByPeer(d) + allowedIPs.RemoveByPeer(e) + allowedIPs.RemoveByPeer(g) + allowedIPs.RemoveByPeer(h) + if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil { + t.Error("Expected removing all the peers to empty trie, but it did not") + } insert(a, 192, 168, 0, 0, 16) insert(a, 192, 168, 0, 0, 24) - trie = trie.removeByPeer(a) + allowedIPs.RemoveByPeer(a) assertNEQ(a, 192, 168, 0, 1) } @@ -204,7 +191,7 @@ func TestTrieIPv6(t *testing.T) { g := &Peer{} h := &Peer{} - var trie *trieEntry + var allowedIPs AllowedIPs expand := func(a uint32) []byte { var out [4]byte @@ -215,13 +202,13 @@ func TestTrieIPv6(t *testing.T) { return out[:] } - insert := func(peer *Peer, a, b, c, d uint32, cidr uint) { + insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) { var addr []byte addr = append(addr, expand(a)...) addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - trie = trie.insert(addr, cidr, peer) + allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer) } assertEQ := func(peer *Peer, a, b, c, d uint32) { @@ -230,7 +217,7 @@ func TestTrieIPv6(t *testing.T) { addr = append(addr, expand(b)...) addr = append(addr, expand(c)...) addr = append(addr, expand(d)...) - p := trie.lookup(addr) + p := allowedIPs.Lookup(addr) if p != peer { t.Error("Assert EQ failed") } diff --git a/device/bind_test.go b/device/bind_test.go index c5f7f68..302a521 100644 --- a/device/bind_test.go +++ b/device/bind_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -14,14 +14,11 @@ import ( type DummyDatagram struct { msg []byte endpoint conn.Endpoint - world bool // better type } type DummyBind struct { in6 chan DummyDatagram - ou6 chan DummyDatagram in4 chan DummyDatagram - ou4 chan DummyDatagram closed bool } @@ -29,21 +26,21 @@ func (b *DummyBind) SetMark(v uint32) error { return nil } -func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) { +func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in6 if !ok { return 0, nil, errors.New("closed") } - copy(buff, datagram.msg) + copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } -func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) { +func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) { datagram, ok := <-b.in4 if !ok { return 0, nil, errors.New("closed") } - copy(buff, datagram.msg) + copy(buf, datagram.msg) return len(datagram.msg), datagram.endpoint, nil } @@ -54,6 +51,6 @@ func (b *DummyBind) Close() error { return nil } -func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error { +func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error { return nil } diff --git a/device/bindsocketshim.go b/device/bindsocketshim.go deleted file mode 100644 index c4dd4ef..0000000 --- a/device/bindsocketshim.go +++ /dev/null @@ -1,36 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - - "golang.zx2c4.com/wireguard/conn" -) - -// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. -func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - if iface, ok := device.net.bind.(conn.BindToInterface); ok { - return iface.BindToInterface4(interfaceIndex, blackhole) - } - return nil -} - -// TODO(crawshaw): this method is a compatibility shim. Replace with direct use of conn. -func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - if device.net.bind == nil { - return errors.New("Bind is not yet initialized") - } - - if iface, ok := device.net.bind.(conn.BindToInterface); ok { - return iface.BindToInterface6(interfaceIndex, blackhole) - } - return nil -} diff --git a/device/boundif_android.go b/device/boundif_android.go deleted file mode 100644 index 6d0fecf..0000000 --- a/device/boundif_android.go +++ /dev/null @@ -1,44 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import "errors" - -func (device *Device) PeekLookAtSocketFd4() (fd int, err error) { - nb, ok := device.net.bind.(*nativeBind) - if !ok { - return 0, errors.New("no socket exists") - } - sysconn, err := nb.ipv4.SyscallConn() - if err != nil { - return - } - err = sysconn.Control(func(f uintptr) { - fd = int(f) - }) - if err != nil { - return - } - return -} - -func (device *Device) PeekLookAtSocketFd6() (fd int, err error) { - nb, ok := device.net.bind.(*nativeBind) - if !ok { - return 0, errors.New("no socket exists") - } - sysconn, err := nb.ipv6.SyscallConn() - if err != nil { - return - } - err = sysconn.Control(func(f uintptr) { - fd = int(f) - }) - if err != nil { - return - } - return -} diff --git a/device/channels.go b/device/channels.go new file mode 100644 index 0000000..e526f6b --- /dev/null +++ b/device/channels.go @@ -0,0 +1,137 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "runtime" + "sync" +) + +// An outboundQueue is a channel of QueueOutboundElements awaiting encryption. +// An outboundQueue is ref-counted using its wg field. +// An outboundQueue created with newOutboundQueue has one reference. +// Every additional writer must call wg.Add(1). +// Every completed writer must call wg.Done(). +// When no further writers will be added, +// call wg.Done to remove the initial reference. +// When the refcount hits 0, the queue's channel is closed. +type outboundQueue struct { + c chan *QueueOutboundElementsContainer + wg sync.WaitGroup +} + +func newOutboundQueue() *outboundQueue { + q := &outboundQueue{ + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +// A inboundQueue is similar to an outboundQueue; see those docs. +type inboundQueue struct { + c chan *QueueInboundElementsContainer + wg sync.WaitGroup +} + +func newInboundQueue() *inboundQueue { + q := &inboundQueue{ + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +// A handshakeQueue is similar to an outboundQueue; see those docs. +type handshakeQueue struct { + c chan QueueHandshakeElement + wg sync.WaitGroup +} + +func newHandshakeQueue() *handshakeQueue { + q := &handshakeQueue{ + c: make(chan QueueHandshakeElement, QueueHandshakeSize), + } + q.wg.Add(1) + go func() { + q.wg.Wait() + close(q.c) + }() + return q +} + +type autodrainingInboundQueue struct { + c chan *QueueInboundElementsContainer +} + +// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { + q := &autodrainingInboundQueue{ + c: make(chan *QueueInboundElementsContainer, QueueInboundSize), + } + runtime.SetFinalizer(q, device.flushInboundQueue) + return q +} + +func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { + for { + select { + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) + default: + return + } + } +} + +type autodrainingOutboundQueue struct { + c chan *QueueOutboundElementsContainer +} + +// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd. +// It is useful in cases in which is it hard to manage the lifetime of the channel. +// The returned channel must not be closed. Senders should signal shutdown using +// some other means, such as sending a sentinel nil values. +// All sends to the channel must be best-effort, because there may be no receivers. +func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { + q := &autodrainingOutboundQueue{ + c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize), + } + runtime.SetFinalizer(q, device.flushOutboundQueue) + return q +} + +func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { + for { + select { + case elemsContainer := <-q.c: + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsContainer) + default: + return + } + } +} diff --git a/device/constants.go b/device/constants.go index 586ec9e..59854a1 100644 --- a/device/constants.go +++ b/device/constants.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -13,7 +13,7 @@ import ( const ( RekeyAfterMessages = (1 << 60) - RejectAfterMessages = (1 << 64) - (1 << 4) - 1 + RejectAfterMessages = (1 << 64) - (1 << 13) - 1 RekeyAfterTime = time.Second * 120 RekeyAttemptTime = time.Second * 90 RekeyTimeout = time.Second * 5 @@ -35,7 +35,6 @@ const ( /* Implementation constants */ const ( - UnderLoadQueueSize = QueueHandshakeSize / 8 UnderLoadAfterTime = time.Second // how long does the device remain under load after detected MaxPeers = 1 << 16 // maximum number of configured peers ) diff --git a/device/cookie.go b/device/cookie.go index ec54f61..876f05d 100644 --- a/device/cookie.go +++ b/device/cookie.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -13,7 +13,6 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" - "golang.zx2c4.com/wireguard/wgcfg" ) type CookieChecker struct { @@ -42,7 +41,7 @@ type CookieGenerator struct { } } -func (st *CookieChecker) Init(pk wgcfg.Key) { +func (st *CookieChecker) Init(pk NoisePublicKey) { st.Lock() defer st.Unlock() @@ -84,7 +83,7 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool { return hmac.Equal(mac1[:], msg[smac1:smac2]) } -func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool { +func (st *CookieChecker) CheckMAC2(msg, src []byte) bool { st.RLock() defer st.RUnlock() @@ -120,7 +119,6 @@ func (st *CookieChecker) CreateReply( recv uint32, src []byte, ) (*MessageCookieReply, error) { - st.RLock() // refresh cookie secret @@ -172,7 +170,7 @@ func (st *CookieChecker) CreateReply( return reply, nil } -func (st *CookieGenerator) Init(pk wgcfg.Key) { +func (st *CookieGenerator) Init(pk NoisePublicKey) { st.Lock() defer st.Unlock() @@ -205,7 +203,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:]) _, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:]) - if err != nil { return false } @@ -216,7 +213,6 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool { } func (st *CookieGenerator) AddMacs(msg []byte) { - size := len(msg) smac2 := size - blake2s.Size128 diff --git a/device/cookie_test.go b/device/cookie_test.go index ef01d46..4f1e50a 100644 --- a/device/cookie_test.go +++ b/device/cookie_test.go @@ -1,18 +1,15 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "testing" - - "golang.zx2c4.com/wireguard/wgcfg" ) func TestCookieMAC1(t *testing.T) { - // setup generator / checker var ( @@ -20,11 +17,11 @@ func TestCookieMAC1(t *testing.T) { checker CookieChecker ) - sk, err := wgcfg.NewPrivateKey() + sk, err := newPrivateKey() if err != nil { t.Fatal(err) } - pk := sk.Public() + pk := sk.publicKey() generator.Init(pk) checker.Init(pk) @@ -134,12 +131,12 @@ func TestCookieMAC1(t *testing.T) { msg[5] ^= 0x20 - srcBad1 := []byte{192, 168, 13, 37, 40, 01} + srcBad1 := []byte{192, 168, 13, 37, 40, 1} if checker.CheckMAC2(msg, srcBad1) { t.Fatal("MAC2 generation/verification failed") } - srcBad2 := []byte{192, 168, 13, 38, 40, 01} + srcBad2 := []byte{192, 168, 13, 38, 40, 1} if checker.CheckMAC2(msg, srcBad2) { t.Fatal("MAC2 generation/verification failed") } diff --git a/device/device.go b/device/device.go index 081d59f..83c33ee 100644 --- a/device/device.go +++ b/device/device.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -11,203 +11,228 @@ import ( "sync/atomic" "time" - "golang.org/x/net/ipv4" - "golang.org/x/net/ipv6" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/wgcfg" ) type Device struct { - isUp AtomicBool // device is (going) up - isClosed AtomicBool // device is closed? (acting as guard) - log *Logger - - // synchronized resources (locks acquired in order) - state struct { - starting sync.WaitGroup + // state holds the device's state. It is accessed atomically. + // Use the device.deviceState method to read it. + // device.deviceState does not acquire the mutex, so it captures only a snapshot. + // During state transitions, the state variable is updated before the device itself. + // The state is thus either the current state of the device or + // the intended future state of the device. + // For example, while executing a call to Up, state will be deviceStateUp. + // There is no guarantee that that intended future state of the device + // will become the actual state; Up can fail. + // The device can also change state multiple times between time of check and time of use. + // Unsynchronized uses of state must therefore be advisory/best-effort only. + state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience + // stopping blocks until all inputs to Device have been closed. stopping sync.WaitGroup + // mu protects state changes. sync.Mutex - changing AtomicBool - current bool } net struct { - starting sync.WaitGroup stopping sync.WaitGroup sync.RWMutex bind conn.Bind // bind interface netlinkCancel *rwcancel.RWCancel port uint16 // listening port fwmark uint32 // mark value (0 = disabled) + brokenRoaming bool } staticIdentity struct { sync.RWMutex - privateKey wgcfg.PrivateKey - publicKey wgcfg.Key + privateKey NoisePrivateKey + publicKey NoisePublicKey } peers struct { - sync.RWMutex - keyMap map[wgcfg.Key]*Peer + sync.RWMutex // protects keyMap + keyMap map[NoisePublicKey]*Peer } - // unprotected / "self-synchronising resources" + rate struct { + underLoadUntil atomic.Int64 + limiter ratelimiter.Ratelimiter + } allowedips AllowedIPs indexTable IndexTable cookieChecker CookieChecker - rate struct { - underLoadUntil atomic.Value - limiter ratelimiter.Ratelimiter - } - pool struct { - messageBufferPool *sync.Pool - messageBufferReuseChan chan *[MaxMessageSize]byte - inboundElementPool *sync.Pool - inboundElementReuseChan chan *QueueInboundElement - outboundElementPool *sync.Pool - outboundElementReuseChan chan *QueueOutboundElement + inboundElementsContainer *WaitPool + outboundElementsContainer *WaitPool + messageBuffers *WaitPool + inboundElements *WaitPool + outboundElements *WaitPool } queue struct { - encryption chan *QueueOutboundElement - decryption chan *QueueInboundElement - handshake chan QueueHandshakeElement - } - - signals struct { - stop chan struct{} + encryption *outboundQueue + decryption *inboundQueue + handshake *handshakeQueue } tun struct { device tun.Device - mtu int32 + mtu atomic.Int32 } + + ipcMutex sync.RWMutex + closed chan struct{} + log *Logger } -/* Converts the peer into a "zombie", which remains in the peer map, - * but processes no packets and does not exists in the routing table. - * - * Must hold device.peers.Mutex - */ -func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) { +// deviceState represents the state of a Device. +// There are three states: down, up, closed. +// Transitions: +// +// down -----+ +// āā ā +// up -> closed +type deviceState uint32 + +//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState +const ( + deviceStateDown deviceState = iota + deviceStateUp + deviceStateClosed +) - // stop routing and processing of packets +// deviceState returns device.state.state as a deviceState +// See those docs for how to interpret this value. +func (device *Device) deviceState() deviceState { + return deviceState(device.state.state.Load()) +} + +// isClosed reports whether the device is closed (or is closing). +// See device.state.state comments for how to interpret this value. +func (device *Device) isClosed() bool { + return device.deviceState() == deviceStateClosed +} + +// isUp reports whether the device is up (or is attempting to come up). +// See device.state.state comments for how to interpret this value. +func (device *Device) isUp() bool { + return device.deviceState() == deviceStateUp +} +// Must hold device.peers.Lock() +func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) { + // stop routing and processing of packets device.allowedips.RemoveByPeer(peer) peer.Stop() // remove from peer map - delete(device.peers.keyMap, key) } -func deviceUpdateState(device *Device) { - - // check if state already being updated (guard) - - if device.state.changing.Swap(true) { - return - } - - // compare to current state of device - +// changeState attempts to change the device state to match want. +func (device *Device) changeState(want deviceState) (err error) { device.state.Lock() - - newIsUp := device.isUp.Get() - - if newIsUp == device.state.current { - device.state.changing.Set(false) - device.state.Unlock() - return + defer device.state.Unlock() + old := device.deviceState() + if old == deviceStateClosed { + // once closed, always closed + device.log.Verbosef("Interface closed, ignored requested state %s", want) + return nil } - - // change state of device - - switch newIsUp { - case true: - if err := device.BindUpdate(); err != nil { - device.log.Error.Printf("Unable to update bind: %v\n", err) - device.isUp.Set(false) + switch want { + case old: + return nil + case deviceStateUp: + device.state.state.Store(uint32(deviceStateUp)) + err = device.upLocked() + if err == nil { break } - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Start() - if peer.persistentKeepaliveInterval > 0 { - peer.SendKeepalive() - } - } - device.peers.RUnlock() - - case false: - device.BindClose() - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Stop() + fallthrough // up failed; bring the device all the way back down + case deviceStateDown: + device.state.state.Store(uint32(deviceStateDown)) + errDown := device.downLocked() + if err == nil { + err = errDown } - device.peers.RUnlock() } + device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState()) + return +} - // update state variables - - device.state.current = newIsUp - device.state.changing.Set(false) - device.state.Unlock() +// upLocked attempts to bring the device up and reports whether it succeeded. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) upLocked() error { + if err := device.BindUpdate(); err != nil { + device.log.Errorf("Unable to update bind: %v", err) + return err + } - // check for state change in the mean time + // The IPC set operation waits for peers to be created before calling Start() on them, + // so if there's a concurrent IPC set request happening, we should wait for it to complete. + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - deviceUpdateState(device) + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Start() + if peer.persistentKeepaliveInterval.Load() > 0 { + peer.SendKeepalive() + } + } + device.peers.RUnlock() + return nil } -func (device *Device) Up() { - - // closed device cannot be brought up +// downLocked attempts to bring the device down. +// The caller must hold device.state.mu and is responsible for updating device.state.state. +func (device *Device) downLocked() error { + err := device.BindClose() + if err != nil { + device.log.Errorf("Bind close failed: %v", err) + } - if device.isClosed.Get() { - return + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.Stop() } + device.peers.RUnlock() + return err +} - device.isUp.Set(true) - deviceUpdateState(device) +func (device *Device) Up() error { + return device.changeState(deviceStateUp) } -func (device *Device) Down() { - device.isUp.Set(false) - deviceUpdateState(device) +func (device *Device) Down() error { + return device.changeState(deviceStateDown) } func (device *Device) IsUnderLoad() bool { - // check if currently under load - now := time.Now() - underLoad := len(device.queue.handshake) >= UnderLoadQueueSize + underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8 if underLoad { - device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime)) + device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano()) return true } - // check if recently under load - - until := device.rate.underLoadUntil.Load().(time.Time) - return until.After(now) + return device.rate.underLoadUntil.Load() > now.UnixNano() } -func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { +func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { // lock required resources device.staticIdentity.Lock() defer device.staticIdentity.Unlock() - if sk.Equal(device.staticIdentity.privateKey) { + if sk.Equals(device.staticIdentity.privateKey) { return nil } @@ -222,10 +247,12 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { // remove peers with matching public keys - publicKey := sk.Public() + publicKey := sk.publicKey() for key, peer := range device.peers.keyMap { - if peer.handshake.remoteStatic.Equal(publicKey) { - unsafeRemovePeer(device, peer, key) + if peer.handshake.remoteStatic.Equals(publicKey) { + peer.handshake.mutex.RUnlock() + removePeerLocked(device, peer, key) + peer.handshake.mutex.RLock() } } @@ -240,7 +267,7 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { handshake := &peer.handshake - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) expiredPeers = append(expiredPeers, peer) } @@ -254,85 +281,78 @@ func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error { return nil } -func NewDevice(tunDevice tun.Device, logger *Logger) *Device { +func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device { device := new(Device) - - device.isUp.Set(false) - device.isClosed.Set(false) - + device.state.state.Store(uint32(deviceStateDown)) + device.closed = make(chan struct{}) device.log = logger - + device.net.bind = bind device.tun.device = tunDevice mtu, err := device.tun.device.MTU() if err != nil { - logger.Error.Println("Trouble determining MTU, assuming default:", err) + device.log.Errorf("Trouble determining MTU, assuming default: %v", err) mtu = DefaultMTU } - device.tun.mtu = int32(mtu) - - device.peers.keyMap = make(map[wgcfg.Key]*Peer) - + device.tun.mtu.Store(int32(mtu)) + device.peers.keyMap = make(map[NoisePublicKey]*Peer) device.rate.limiter.Init() - device.rate.underLoadUntil.Store(time.Time{}) - device.indexTable.Init() - device.allowedips.Reset() device.PopulatePools() // create queues - device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize) - device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize) - device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize) - - // prepare signals - - device.signals.stop = make(chan struct{}) - - // prepare net - - device.net.port = 0 - device.net.bind = nil + device.queue.handshake = newHandshakeQueue() + device.queue.encryption = newOutboundQueue() + device.queue.decryption = newInboundQueue() // start workers cpus := runtime.NumCPU() - device.state.starting.Wait() device.state.stopping.Wait() - for i := 0; i < cpus; i += 1 { - device.state.starting.Add(3) - device.state.stopping.Add(3) - go device.RoutineEncryption() - go device.RoutineDecryption() - go device.RoutineHandshake() + device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake + for i := 0; i < cpus; i++ { + go device.RoutineEncryption(i + 1) + go device.RoutineDecryption(i + 1) + go device.RoutineHandshake(i + 1) } - device.state.starting.Add(2) - device.state.stopping.Add(2) + device.state.stopping.Add(1) // RoutineReadFromTUN + device.queue.encryption.wg.Add(1) // RoutineReadFromTUN go device.RoutineReadFromTUN() go device.RoutineTUNEventReader() - device.state.starting.Wait() - return device } -func (device *Device) LookupPeer(pk wgcfg.Key) *Peer { +// BatchSize returns the BatchSize for the device as a whole which is the max of +// the bind batch size and the tun batch size. The batch size reported by device +// is the size used to construct memory pools, and is the allowed batch size for +// the lifetime of the device. +func (device *Device) BatchSize() int { + size := device.net.bind.BatchSize() + dSize := device.tun.device.BatchSize() + if size < dSize { + size = dSize + } + return size +} + +func (device *Device) LookupPeer(pk NoisePublicKey) *Peer { device.peers.RLock() defer device.peers.RUnlock() return device.peers.keyMap[pk] } -func (device *Device) RemovePeer(key wgcfg.Key) { +func (device *Device) RemovePeer(key NoisePublicKey) { device.peers.Lock() defer device.peers.Unlock() // stop peer and remove from routing peer, ok := device.peers.keyMap[key] if ok { - unsafeRemovePeer(device, peer, key) + removePeerLocked(device, peer, key) } } @@ -341,67 +361,50 @@ func (device *Device) RemoveAllPeers() { defer device.peers.Unlock() for key, peer := range device.peers.keyMap { - unsafeRemovePeer(device, peer, key) - } - - device.peers.keyMap = make(map[wgcfg.Key]*Peer) -} - -func (device *Device) FlushPacketQueues() { - for { - select { - case elem, ok := <-device.queue.decryption: - if ok { - elem.Drop() - } - case elem, ok := <-device.queue.encryption: - if ok { - elem.Drop() - } - case <-device.queue.handshake: - default: - return - } + removePeerLocked(device, peer, key) } + device.peers.keyMap = make(map[NoisePublicKey]*Peer) } func (device *Device) Close() { - if device.isClosed.Swap(true) { - return - } - - device.state.starting.Wait() - - device.log.Info.Println("Device closing") - device.state.changing.Set(true) device.state.Lock() defer device.state.Unlock() + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() + if device.isClosed() { + return + } + device.state.state.Store(uint32(deviceStateClosed)) + device.log.Verbosef("Device closing") device.tun.device.Close() - device.BindClose() - - device.isUp.Set(false) - - close(device.signals.stop) + device.downLocked() + // Remove peers before closing queues, + // because peers assume that queues are active. device.RemoveAllPeers() + // We kept a reference to the encryption and decryption queues, + // in case we started any new peers that might write to them. + // No new peers are coming; we are done with these queues. + device.queue.encryption.wg.Done() + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.state.stopping.Wait() - device.FlushPacketQueues() device.rate.limiter.Close() - device.state.changing.Set(false) - device.log.Info.Println("Interface closed") + device.log.Verbosef("Device closed") + close(device.closed) } func (device *Device) Wait() chan struct{} { - return device.signals.stop + return device.closed } func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { - if device.isClosed.Get() { + if !device.isUp() { return } @@ -417,7 +420,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { device.peers.RUnlock() } -func unsafeCloseBind(device *Device) error { +// closeBindLocked closes the device's net.bind. +// The caller must hold the net mutex. +func closeBindLocked(device *Device) error { var err error netc := &device.net if netc.netlinkCancel != nil { @@ -425,41 +430,38 @@ func unsafeCloseBind(device *Device) error { } if netc.bind != nil { err = netc.bind.Close() - netc.bind = nil } netc.stopping.Wait() return err } -func (device *Device) BindSetMark(mark uint32) error { +func (device *Device) Bind() conn.Bind { + device.net.Lock() + defer device.net.Unlock() + return device.net.bind +} +func (device *Device) BindSetMark(mark uint32) error { device.net.Lock() defer device.net.Unlock() // check if modified - if device.net.fwmark == mark { return nil } // update fwmark on existing bind - device.net.fwmark = mark - if device.isUp.Get() && device.net.bind != nil { + if device.isUp() && device.net.bind != nil { if err := device.net.bind.SetMark(mark); err != nil { return err } } // clear cached source addresses - device.peers.RLock() for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } + peer.markEndpointSrcForClearing() } device.peers.RUnlock() @@ -467,76 +469,68 @@ func (device *Device) BindSetMark(mark uint32) error { } func (device *Device) BindUpdate() error { - device.net.Lock() defer device.net.Unlock() // close existing sockets - - if err := unsafeCloseBind(device); err != nil { + if err := closeBindLocked(device); err != nil { return err } // open new sockets + if !device.isUp() { + return nil + } - if device.isUp.Get() { + // bind to new port + var err error + var recvFns []conn.ReceiveFunc + netc := &device.net - // bind to new port + recvFns, netc.port, err = netc.bind.Open(netc.port) + if err != nil { + netc.port = 0 + return err + } - var err error - netc := &device.net - netc.bind, netc.port, err = conn.CreateBind(netc.port) - if err != nil { - netc.bind = nil - netc.port = 0 - return err - } - netc.netlinkCancel, err = device.startRouteListener(netc.bind) + netc.netlinkCancel, err = device.startRouteListener(netc.bind) + if err != nil { + netc.bind.Close() + netc.port = 0 + return err + } + + // set fwmark + if netc.fwmark != 0 { + err = netc.bind.SetMark(netc.fwmark) if err != nil { - netc.bind.Close() - netc.bind = nil - netc.port = 0 return err } + } - // set fwmark - - if netc.fwmark != 0 { - err = netc.bind.SetMark(netc.fwmark) - if err != nil { - return err - } - } - - // clear cached source addresses - - device.peers.RLock() - for _, peer := range device.peers.keyMap { - peer.Lock() - defer peer.Unlock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - } - device.peers.RUnlock() - - // start receiving routines - - device.net.starting.Add(2) - device.net.stopping.Add(2) - go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) - go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) - device.net.starting.Wait() + // clear cached source addresses + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.markEndpointSrcForClearing() + } + device.peers.RUnlock() - device.log.Debug.Println("UDP bind has been updated") + // start receiving routines + device.net.stopping.Add(len(recvFns)) + device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption + device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake + batchSize := netc.bind.BatchSize() + for _, fn := range recvFns { + go device.RoutineReceiveIncoming(batchSize, fn) } + device.log.Verbosef("UDP bind has been updated") return nil } func (device *Device) BindClose() error { device.net.Lock() - err := unsafeCloseBind(device) + err := closeBindLocked(device) device.net.Unlock() return err } diff --git a/device/device_test.go b/device/device_test.go index 925d2b1..fff172b 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -1,103 +1,476 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "bufio" "bytes" - "net" - "strings" + "encoding/hex" + "fmt" + "io" + "math/rand" + "net/netip" + "os" + "runtime" + "runtime/pprof" + "sync" + "sync/atomic" "testing" "time" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/conn/bindtest" + "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/tuntest" - "golang.zx2c4.com/wireguard/wgcfg" ) -func TestTwoDevicePing(t *testing.T) { - // TODO(crawshaw): pick unused ports on localhost - cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58 -listen_port=53511 -replace_peers=true -public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.2/32 -endpoint=127.0.0.1:53512` - tun1 := tuntest.NewChannelTUN() - dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: ")) - dev1.Up() - defer dev1.Close() - if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil { - t.Fatal(err) - } - - cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768 -listen_port=53512 -replace_peers=true -public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427 -protocol_version=1 -replace_allowed_ips=true -allowed_ip=1.0.0.1/32 -endpoint=127.0.0.1:53511` - tun2 := tuntest.NewChannelTUN() - dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: ")) - dev2.Up() - defer dev2.Close() - if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil { - t.Fatal(err) +// uapiCfg returns a string that contains cfg formatted use with IpcSet. +// cfg is a series of alternating key/value strings. +// uapiCfg exists because editors and humans like to insert +// whitespace into configs, which can cause failures, some of which are silent. +// For example, a leading blank newline causes the remainder +// of the config to be silently ignored. +func uapiCfg(cfg ...string) string { + if len(cfg)%2 != 0 { + panic("odd number of args to uapiReader") + } + buf := new(bytes.Buffer) + for i, s := range cfg { + buf.WriteString(s) + sep := byte('\n') + if i%2 == 0 { + sep = '=' + } + buf.WriteByte(sep) } + return buf.String() +} - t.Run("ping 1.0.0.1", func(t *testing.T) { - msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) - tun2.Outbound <- msg2to1 +// genConfigs generates a pair of configs that connect to each other. +// The configs use distinct, probably-usable ports. +func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) { + var key1, key2 NoisePrivateKey + _, err := rand.Read(key1[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + _, err = rand.Read(key2[:]) + if err != nil { + tb.Errorf("unable to generate private key random bytes: %v", err) + } + pub1, pub2 := key1.publicKey(), key2.publicKey() + + cfgs[0] = uapiCfg( + "private_key", hex.EncodeToString(key1[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub2[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.2/32", + ) + endpointCfgs[0] = uapiCfg( + "public_key", hex.EncodeToString(pub2[:]), + "endpoint", "127.0.0.1:%d", + ) + cfgs[1] = uapiCfg( + "private_key", hex.EncodeToString(key2[:]), + "listen_port", "0", + "replace_peers", "true", + "public_key", hex.EncodeToString(pub1[:]), + "protocol_version", "1", + "replace_allowed_ips", "true", + "allowed_ip", "1.0.0.1/32", + ) + endpointCfgs[1] = uapiCfg( + "public_key", hex.EncodeToString(pub1[:]), + "endpoint", "127.0.0.1:%d", + ) + return +} + +// A testPair is a pair of testPeers. +type testPair [2]testPeer + +// A testPeer is a peer used for testing. +type testPeer struct { + tun *tuntest.ChannelTUN + dev *Device + ip netip.Addr +} + +type SendDirection bool + +const ( + Ping SendDirection = true + Pong SendDirection = false +) + +func (d SendDirection) String() string { + if d == Ping { + return "ping" + } + return "pong" +} + +func (pair *testPair) Send(tb testing.TB, ping SendDirection, done chan struct{}) { + tb.Helper() + p0, p1 := pair[0], pair[1] + if !ping { + // pong is the new ping + p0, p1 = p1, p0 + } + msg := tuntest.Ping(p0.ip, p1.ip) + p1.tun.Outbound <- msg + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + var err error + select { + case msgRecv := <-p0.tun.Inbound: + if !bytes.Equal(msg, msgRecv) { + err = fmt.Errorf("%s did not transit correctly", ping) + } + case <-timer.C: + err = fmt.Errorf("%s did not transit", ping) + case <-done: + } + if err != nil { + // The error may have occurred because the test is done. select { - case msgRecv := <-tun1.Inbound: - if !bytes.Equal(msg2to1, msgRecv) { - t.Error("ping did not transit correctly") + case <-done: + return + default: + } + // Real error. + tb.Error(err) + } +} + +// genTestPair creates a testPair. +func genTestPair(tb testing.TB, realSocket bool) (pair testPair) { + cfg, endpointCfg := genConfigs(tb) + var binds [2]conn.Bind + if realSocket { + binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind() + } else { + binds = bindtest.NewChannelBinds() + } + // Bring up a ChannelTun for each config. + for i := range pair { + p := &pair[i] + p.tun = tuntest.NewChannelTUN() + p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)}) + level := LogLevelVerbose + if _, ok := tb.(*testing.B); ok && !testing.Verbose() { + level = LogLevelError + } + p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i))) + if err := p.dev.IpcSet(cfg[i]); err != nil { + tb.Errorf("failed to configure device %d: %v", i, err) + p.dev.Close() + continue + } + if err := p.dev.Up(); err != nil { + tb.Errorf("failed to bring up device %d: %v", i, err) + p.dev.Close() + continue + } + endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port) + } + for i := range pair { + p := &pair[i] + if err := p.dev.IpcSet(endpointCfg[i]); err != nil { + tb.Errorf("failed to configure device endpoint %d: %v", i, err) + p.dev.Close() + continue + } + // The device is ready. Close it when the test completes. + tb.Cleanup(p.dev.Close) + } + return +} + +func TestTwoDevicePing(t *testing.T) { + goroutineLeakCheck(t) + pair := genTestPair(t, true) + t.Run("ping 1.0.0.1", func(t *testing.T) { + pair.Send(t, Ping, nil) + }) + t.Run("ping 1.0.0.2", func(t *testing.T) { + pair.Send(t, Pong, nil) + }) +} + +func TestUpDown(t *testing.T) { + goroutineLeakCheck(t) + const itrials = 50 + const otrials = 10 + + for n := 0; n < otrials; n++ { + pair := genTestPair(t, false) + for i := range pair { + for k := range pair[i].dev.peers.keyMap { + pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:]))) + } + } + var wg sync.WaitGroup + wg.Add(len(pair)) + for i := range pair { + go func(d *Device) { + defer wg.Done() + for i := 0; i < itrials; i++ { + if err := d.Up(); err != nil { + t.Errorf("failed up bring up device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + if err := d.Down(); err != nil { + t.Errorf("failed to bring down device: %v", err) + } + time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1))))) + } + }(pair[i].dev) + } + wg.Wait() + for i := range pair { + pair[i].dev.Up() + pair[i].dev.Close() + } + } +} + +// TestConcurrencySafety does other things concurrently with tunnel use. +// It is intended to be used with the race detector to catch data races. +func TestConcurrencySafety(t *testing.T) { + pair := genTestPair(t, true) + done := make(chan struct{}) + + const warmupIters = 10 + var warmup sync.WaitGroup + warmup.Add(warmupIters) + go func() { + // Send data continuously back and forth until we're done. + // Note that we may continue to attempt to send data + // even after done is closed. + i := warmupIters + for ping := Ping; ; ping = !ping { + pair.Send(t, ping, done) + select { + case <-done: + return + default: + } + if i > 0 { + warmup.Done() + i-- } - case <-time.After(300 * time.Millisecond): - t.Error("ping did not transit") + } + }() + warmup.Wait() + + applyCfg := func(cfg string) { + err := pair[0].dev.IpcSet(cfg) + if err != nil { + t.Fatal(err) + } + } + + // Change persistent_keepalive_interval concurrently with tunnel use. + t.Run("persistentKeepaliveInterval", func(t *testing.T) { + var pub NoisePublicKey + for key := range pair[0].dev.peers.keyMap { + pub = key + break + } + cfg := uapiCfg( + "public_key", hex.EncodeToString(pub[:]), + "persistent_keepalive_interval", "1", + ) + for i := 0; i < 1000; i++ { + applyCfg(cfg) } }) - t.Run("ping 1.0.0.2", func(t *testing.T) { - msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) - tun1.Outbound <- msg1to2 - select { - case msgRecv := <-tun2.Inbound: - if !bytes.Equal(msg1to2, msgRecv) { - t.Error("return ping did not transit correctly") + // Change private keys concurrently with tunnel use. + t.Run("privateKey", func(t *testing.T) { + bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777") + good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:])) + // Set iters to a large number like 1000 to flush out data races quickly. + // Don't leave it large. That can cause logical races + // in which the handshake is interleaved with key changes + // such that the private key appears to be unchanging but + // other state gets reset, which can cause handshake failures like + // "Received packet with invalid mac1". + const iters = 1 + for i := 0; i < iters; i++ { + applyCfg(bad) + applyCfg(good) + } + }) + + // Perform bind updates and keepalive sends concurrently with tunnel use. + t.Run("bindUpdate and keepalive", func(t *testing.T) { + const iters = 10 + for i := 0; i < iters; i++ { + for _, peer := range pair { + peer.dev.BindUpdate() + peer.dev.SendKeepalivesToPeersWithCurrentKeypair() } - case <-time.After(300 * time.Millisecond): - t.Error("return ping did not transit") } }) + + close(done) } -func assertNil(t *testing.T, err error) { - if err != nil { - t.Fatal(err) +func BenchmarkLatency(b *testing.B) { + pair := genTestPair(b, true) + + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) } } -func assertEqual(t *testing.T, a, b []byte) { - if !bytes.Equal(a, b) { - t.Fatal(a, "!=", b) +func BenchmarkThroughput(b *testing.B) { + pair := genTestPair(b, true) + + // Establish a connection. + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + + // Measure how long it takes to receive b.N packets, + // starting when we receive the first packet. + var recv atomic.Uint64 + var elapsed time.Duration + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + var start time.Time + for { + <-pair[0].tun.Inbound + new := recv.Add(1) + if new == 1 { + start = time.Now() + } + // Careful! Don't change this to else if; b.N can be equal to 1. + if new == uint64(b.N) { + elapsed = time.Since(start) + return + } + } + }() + + // Send packets as fast as we can until we've received enough. + ping := tuntest.Ping(pair[0].ip, pair[1].ip) + pingc := pair[1].tun.Outbound + var sent uint64 + for recv.Load() != uint64(b.N) { + sent++ + pingc <- ping } + wg.Wait() + + b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op") + b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss") } -func randDevice(t *testing.T) *Device { - sk, err := wgcfg.NewPrivateKey() - if err != nil { - t.Fatal(err) +func BenchmarkUAPIGet(b *testing.B) { + pair := genTestPair(b, true) + pair.Send(b, Ping, nil) + pair.Send(b, Pong, nil) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + pair[0].dev.IpcGetOperation(io.Discard) + } +} + +func goroutineLeakCheck(t *testing.T) { + goroutines := func() (int, []byte) { + p := pprof.Lookup("goroutine") + b := new(bytes.Buffer) + p.WriteTo(b, 1) + return p.Count(), b.Bytes() + } + + startGoroutines, startStacks := goroutines() + t.Cleanup(func() { + if t.Failed() { + return + } + // Give goroutines time to exit, if they need it. + for i := 0; i < 10000; i++ { + if runtime.NumGoroutine() <= startGoroutines { + return + } + time.Sleep(1 * time.Millisecond) + } + endGoroutines, endStacks := goroutines() + t.Logf("starting stacks:\n%s\n", startStacks) + t.Logf("ending stacks:\n%s\n", endStacks) + t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines) + }) +} + +type fakeBindSized struct { + size int +} + +func (b *fakeBindSized) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { + return nil, 0, nil +} +func (b *fakeBindSized) Close() error { return nil } +func (b *fakeBindSized) SetMark(mark uint32) error { return nil } +func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil } +func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil } +func (b *fakeBindSized) BatchSize() int { return b.size } + +type fakeTUNDeviceSized struct { + size int +} + +func (t *fakeTUNDeviceSized) File() *os.File { return nil } +func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + return 0, nil +} +func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil } +func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil } +func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil } +func (t *fakeTUNDeviceSized) Close() error { return nil } +func (t *fakeTUNDeviceSized) BatchSize() int { return t.size } + +func TestBatchSize(t *testing.T) { + d := Device{} + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 1, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{1} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{1} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) + } + + d.net.bind = &fakeBindSized{128} + d.tun.device = &fakeTUNDeviceSized{128} + if want, got := 128, d.BatchSize(); got != want { + t.Errorf("expected batch size %d, got %d", want, got) } - tun := newDummyTUN("dummy") - logger := NewLogger(LogLevelError, "") - device := NewDevice(tun, logger) - device.SetPrivateKey(sk) - return device } diff --git a/device/devicestate_string.go b/device/devicestate_string.go new file mode 100644 index 0000000..6577dd4 --- /dev/null +++ b/device/devicestate_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT. + +package device + +import "strconv" + +const _deviceState_name = "DownUpClosed" + +var _deviceState_index = [...]uint8{0, 4, 6, 12} + +func (i deviceState) String() string { + if i >= deviceState(len(_deviceState_index)-1) { + return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]] +} diff --git a/device/endpoint_test.go b/device/endpoint_test.go index 1896790..93a4998 100644 --- a/device/endpoint_test.go +++ b/device/endpoint_test.go @@ -1,53 +1,49 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "math/rand" - "net" + "net/netip" ) type DummyEndpoint struct { - src [16]byte - dst [16]byte + src, dst netip.Addr } func CreateDummyEndpoint() (*DummyEndpoint, error) { - var end DummyEndpoint - if _, err := rand.Read(end.src[:]); err != nil { + var src, dst [16]byte + if _, err := rand.Read(src[:]); err != nil { return nil, err } - _, err := rand.Read(end.dst[:]) - return &end, err + _, err := rand.Read(dst[:]) + return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err } func (e *DummyEndpoint) ClearSrc() {} func (e *DummyEndpoint) SrcToString() string { - var addr net.UDPAddr - addr.IP = e.SrcIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.SrcIP(), 1000).String() } func (e *DummyEndpoint) DstToString() string { - var addr net.UDPAddr - addr.IP = e.DstIP() - addr.Port = 1000 - return addr.String() + return netip.AddrPortFrom(e.DstIP(), 1000).String() } -func (e *DummyEndpoint) SrcToBytes() []byte { - return e.src[:] +func (e *DummyEndpoint) DstToBytes() []byte { + out := e.DstIP().AsSlice() + out = append(out, byte(1000&0xff)) + out = append(out, byte((1000>>8)&0xff)) + return out } -func (e *DummyEndpoint) DstIP() net.IP { - return e.dst[:] +func (e *DummyEndpoint) DstIP() netip.Addr { + return e.dst } -func (e *DummyEndpoint) SrcIP() net.IP { - return e.src[:] +func (e *DummyEndpoint) SrcIP() netip.Addr { + return e.src } diff --git a/device/indextable.go b/device/indextable.go index 4cba970..00ade7d 100644 --- a/device/indextable.go +++ b/device/indextable.go @@ -1,14 +1,14 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/rand" + "encoding/binary" "sync" - "unsafe" ) type IndexTableEntry struct { @@ -25,7 +25,8 @@ type IndexTable struct { func randUint32() (uint32, error) { var integer [4]byte _, err := rand.Read(integer[:]) - return *(*uint32)(unsafe.Pointer(&integer[0])), err + // Arbitrary endianness; both are intrinsified by the Go compiler. + return binary.LittleEndian.Uint32(integer[:]), err } func (table *IndexTable) Init() { diff --git a/device/ip.go b/device/ip.go index 9d4fb74..eaf2363 100644 --- a/device/ip.go +++ b/device/ip.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device diff --git a/device/kdf_test.go b/device/kdf_test.go index cb8dbab..f9c76d6 100644 --- a/device/kdf_test.go +++ b/device/kdf_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -20,7 +20,7 @@ type KDFTest struct { t2 string } -func assertEquals(t *testing.T, a string, b string) { +func assertEquals(t *testing.T, a, b string) { if a != b { t.Fatal("expected", a, "=", b) } diff --git a/device/keypair.go b/device/keypair.go index 9c78fa9..e3540d7 100644 --- a/device/keypair.go +++ b/device/keypair.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,6 +8,7 @@ package device import ( "crypto/cipher" "sync" + "sync/atomic" "time" "golang.zx2c4.com/wireguard/replay" @@ -21,10 +22,10 @@ import ( */ type Keypair struct { - sendNonce uint64 + sendNonce atomic.Uint64 send cipher.AEAD receive cipher.AEAD - replayFilter replay.ReplayFilter + replayFilter replay.Filter isInitiator bool created time.Time localIndex uint32 @@ -35,7 +36,7 @@ type Keypairs struct { sync.RWMutex current *Keypair previous *Keypair - next *Keypair + next atomic.Pointer[Keypair] } func (kp *Keypairs) Current() *Keypair { diff --git a/device/logger.go b/device/logger.go index 7c8b704..22b0df0 100644 --- a/device/logger.go +++ b/device/logger.go @@ -1,59 +1,48 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "io" - "io/ioutil" "log" "os" ) +// A Logger provides logging for a Device. +// The functions are Printf-style functions. +// They must be safe for concurrent use. +// They do not require a trailing newline in the format. +// If nil, that level of logging will be silent. +type Logger struct { + Verbosef func(format string, args ...any) + Errorf func(format string, args ...any) +} + +// Log levels for use with NewLogger. const ( LogLevelSilent = iota LogLevelError - LogLevelInfo - LogLevelDebug + LogLevelVerbose ) -type Logger struct { - Debug *log.Logger - Info *log.Logger - Error *log.Logger -} +// Function for use in Logger for discarding logged lines. +func DiscardLogf(format string, args ...any) {} +// NewLogger constructs a Logger that writes to stdout. +// It logs at the specified log level and above. +// It decorates log lines with the log level, date, time, and prepend. func NewLogger(level int, prepend string) *Logger { - output := os.Stdout - logger := new(Logger) - - logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) { - if level >= LogLevelDebug { - return output, output, output - } - if level >= LogLevelInfo { - return output, output, ioutil.Discard - } - if level >= LogLevelError { - return output, ioutil.Discard, ioutil.Discard - } - return ioutil.Discard, ioutil.Discard, ioutil.Discard - }() - - logger.Debug = log.New(logDebug, - "DEBUG: "+prepend, - log.Ldate|log.Ltime, - ) - - logger.Info = log.New(logInfo, - "INFO: "+prepend, - log.Ldate|log.Ltime, - ) - logger.Error = log.New(logErr, - "ERROR: "+prepend, - log.Ldate|log.Ltime, - ) + logger := &Logger{DiscardLogf, DiscardLogf} + logf := func(prefix string) func(string, ...any) { + return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf + } + if level >= LogLevelVerbose { + logger.Verbosef = logf("DEBUG") + } + if level >= LogLevelError { + logger.Errorf = logf("ERROR") + } return logger } diff --git a/device/misc.go b/device/misc.go deleted file mode 100644 index a38d1c1..0000000 --- a/device/misc.go +++ /dev/null @@ -1,48 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "sync/atomic" -) - -/* Atomic Boolean */ - -const ( - AtomicFalse = int32(iota) - AtomicTrue -) - -type AtomicBool struct { - int32 -} - -func (a *AtomicBool) Get() bool { - return atomic.LoadInt32(&a.int32) == AtomicTrue -} - -func (a *AtomicBool) Swap(val bool) bool { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - return atomic.SwapInt32(&a.int32, flag) == AtomicTrue -} - -func (a *AtomicBool) Set(val bool) { - flag := AtomicFalse - if val { - flag = AtomicTrue - } - atomic.StoreInt32(&a.int32, flag) -} - -func min(a, b uint) uint { - if a > b { - return b - } - return a -} diff --git a/device/mobilequirks.go b/device/mobilequirks.go new file mode 100644 index 0000000..0a0080e --- /dev/null +++ b/device/mobilequirks.go @@ -0,0 +1,19 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created, +// though it will try to deal with it, and race maybe, if called after. +func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() { + device.net.brokenRoaming = true + device.peers.RLock() + for _, peer := range device.peers.keyMap { + peer.endpoint.Lock() + peer.endpoint.disableRoaming = peer.endpoint.val != nil + peer.endpoint.Unlock() + } + device.peers.RUnlock() +} diff --git a/device/noise-helpers.go b/device/noise-helpers.go index ae52a7d..c2f356b 100644 --- a/device/noise-helpers.go +++ b/device/noise-helpers.go @@ -1,16 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "crypto/hmac" + "crypto/rand" "crypto/subtle" + "errors" "hash" "golang.org/x/crypto/blake2s" + "golang.org/x/crypto/curve25519" ) /* KDF related functions. @@ -73,3 +76,33 @@ func setZero(arr []byte) { arr[i] = 0 } } + +func (sk *NoisePrivateKey) clamp() { + sk[0] &= 248 + sk[31] = (sk[31] & 127) | 64 +} + +func newPrivateKey() (sk NoisePrivateKey, err error) { + _, err = rand.Read(sk[:]) + sk.clamp() + return +} + +func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarBaseMult(apk, ask) + return +} + +var errInvalidPublicKey = errors.New("invalid public key") + +func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) { + apk := (*[NoisePublicKeySize]byte)(&pk) + ask := (*[NoisePrivateKeySize]byte)(sk) + curve25519.ScalarMult(&ss, ask, apk) + if isZero(ss[:]) { + return ss, errInvalidPublicKey + } + return ss, nil +} diff --git a/device/noise-protocol.go b/device/noise-protocol.go index dbb6f93..e8f6145 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -14,13 +14,12 @@ import ( "golang.org/x/crypto/blake2s" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/poly1305" + "golang.zx2c4.com/wireguard/tai64n" - "golang.zx2c4.com/wireguard/wgcfg" ) type handshakeState int -// TODO(crawshaw): add commentary describing each state and the transitions const ( handshakeZeroed = handshakeState(iota) handshakeInitiationCreated @@ -85,8 +84,8 @@ const ( type MessageInitiation struct { Type uint32 Sender uint32 - Ephemeral wgcfg.Key - Static [wgcfg.KeySize + poly1305.TagSize]byte + Ephemeral NoisePublicKey + Static [NoisePublicKeySize + poly1305.TagSize]byte Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte @@ -96,7 +95,7 @@ type MessageResponse struct { Type uint32 Sender uint32 Receiver uint32 - Ephemeral wgcfg.Key + Ephemeral NoisePublicKey Empty [poly1305.TagSize]byte MAC1 [blake2s.Size128]byte MAC2 [blake2s.Size128]byte @@ -119,15 +118,15 @@ type MessageCookieReply struct { type Handshake struct { state handshakeState mutex sync.RWMutex - hash [blake2s.Size]byte // hash value - chainKey [blake2s.Size]byte // chain key - presharedKey wgcfg.SymmetricKey // psk - localEphemeral wgcfg.PrivateKey // ephemeral secret key - localIndex uint32 // used to clear hash-table - remoteIndex uint32 // index for sending - remoteStatic wgcfg.Key // long term key - remoteEphemeral wgcfg.Key // ephemeral public key - precomputedStaticStatic [wgcfg.KeySize]byte // precomputed shared secret + hash [blake2s.Size]byte // hash value + chainKey [blake2s.Size]byte // chain key + presharedKey NoisePresharedKey // psk + localEphemeral NoisePrivateKey // ephemeral secret key + localIndex uint32 // used to clear hash-table + remoteIndex uint32 // index for sending + remoteStatic NoisePublicKey // long term key + remoteEphemeral NoisePublicKey // ephemeral public key + precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret lastTimestamp tai64n.Timestamp lastInitiationConsumption time.Time lastSentHandshake time.Time @@ -139,11 +138,11 @@ var ( ZeroNonce [chacha20poly1305.NonceSize]byte ) -func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) { +func mixKey(dst, c *[blake2s.Size]byte, data []byte) { KDF1(dst, c[:], data) } -func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) { +func mixHash(dst, h *[blake2s.Size]byte, data []byte) { hash, _ := blake2s.New256(nil) hash.Write(h[:]) hash.Write(data) @@ -176,8 +175,6 @@ func init() { } func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { - var errZeroECDHResult = errors.New("ECDH returned all zeros") - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -189,7 +186,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e var err error handshake.hash = InitialHash handshake.chainKey = InitialChainKey - handshake.localEphemeral, err = wgcfg.NewPrivateKey() + handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } @@ -198,16 +195,16 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e msg := MessageInitiation{ Type: MessageInitiationType, - Ephemeral: handshake.localEphemeral.Public(), + Ephemeral: handshake.localEphemeral.publicKey(), } handshake.mixKey(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:]) // encrypt static key - ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic) - if isZero(ss[:]) { - return nil, errZeroECDHResult + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err } var key [chacha20poly1305.KeySize]byte KDF2( @@ -222,7 +219,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e // encrypt timestamp if isZero(handshake.precomputedStaticStatic[:]) { - return nil, errZeroECDHResult + return nil, errInvalidPublicKey } KDF2( &handshake.chainKey, @@ -265,11 +262,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key - var err error - var peerPK wgcfg.Key + var peerPK NoisePublicKey var key [chacha20poly1305.KeySize]byte - ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) - if isZero(ss[:]) { + ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { return nil } KDF2(&chainKey, &key, chainKey[:], ss[:]) @@ -283,7 +279,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { // lookup peer peer := device.LookupPeer(peerPK) - if peer == nil { + if peer == nil || !peer.isRunning.Load() { return nil } @@ -319,11 +315,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer { flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate handshake.mutex.RUnlock() if replay { - device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp) + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp) return nil } if flood { - device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer) + device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer) return nil } @@ -377,20 +373,24 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error // create ephemeral key - handshake.localEphemeral, err = wgcfg.NewPrivateKey() + handshake.localEphemeral, err = newPrivateKey() if err != nil { return nil, err } - msg.Ephemeral = handshake.localEphemeral.Public() + msg.Ephemeral = handshake.localEphemeral.publicKey() handshake.mixHash(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral) - handshake.mixKey(ss[:]) - ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic) - handshake.mixKey(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) + ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic) + if err != nil { + return nil, err + } + handshake.mixKey(ss[:]) // add preshared key @@ -407,11 +407,9 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error handshake.mixHash(tau[:]) - func() { - aead, _ := chacha20poly1305.New(key[:]) - aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) - handshake.mixHash(msg.Empty[:]) - }() + aead, _ := chacha20poly1305.New(key[:]) + aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:]) + handshake.mixHash(msg.Empty[:]) handshake.state = handshakeResponseCreated @@ -437,7 +435,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { ) ok := func() bool { - // lock handshake state handshake.mutex.RLock() @@ -457,17 +454,19 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:]) - func() { - ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) - func() { - ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral) - mixKey(&chainKey, &chainKey, ss[:]) - setZero(ss[:]) - }() + ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + if err != nil { + return false + } + mixKey(&chainKey, &chainKey, ss[:]) + setZero(ss[:]) // add preshared key (psk) @@ -485,7 +484,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { // authenticate transcript aead, _ := chacha20poly1305.New(key[:]) - _, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) + _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:]) if err != nil { return false } @@ -566,8 +565,7 @@ func (peer *Peer) BeginSymmetricSession() error { setZero(recvKey[:]) keypair.created = time.Now() - keypair.sendNonce = 0 - keypair.replayFilter.Init() + keypair.replayFilter.Reset() keypair.isInitiator = isInitiator keypair.localIndex = peer.handshake.localIndex keypair.remoteIndex = peer.handshake.remoteIndex @@ -584,12 +582,12 @@ func (peer *Peer) BeginSymmetricSession() error { defer keypairs.Unlock() previous := keypairs.previous - next := keypairs.next + next := keypairs.next.Load() current := keypairs.current if isInitiator { if next != nil { - keypairs.next = nil + keypairs.next.Store(nil) keypairs.previous = next device.DeleteKeypair(current) } else { @@ -598,7 +596,7 @@ func (peer *Peer) BeginSymmetricSession() error { device.DeleteKeypair(previous) keypairs.current = keypair } else { - keypairs.next = keypair + keypairs.next.Store(keypair) device.DeleteKeypair(next) keypairs.previous = nil device.DeleteKeypair(previous) @@ -609,18 +607,19 @@ func (peer *Peer) BeginSymmetricSession() error { func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool { keypairs := &peer.keypairs - if keypairs.next != receivedKeypair { + + if keypairs.next.Load() != receivedKeypair { return false } keypairs.Lock() defer keypairs.Unlock() - if keypairs.next != receivedKeypair { + if keypairs.next.Load() != receivedKeypair { return false } old := keypairs.previous keypairs.previous = keypairs.current peer.device.DeleteKeypair(old) - keypairs.current = keypairs.next - keypairs.next = nil + keypairs.current = keypairs.next.Load() + keypairs.next.Store(nil) return true } diff --git a/device/noise-types.go b/device/noise-types.go new file mode 100644 index 0000000..e850359 --- /dev/null +++ b/device/noise-types.go @@ -0,0 +1,78 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "crypto/subtle" + "encoding/hex" + "errors" +) + +const ( + NoisePublicKeySize = 32 + NoisePrivateKeySize = 32 + NoisePresharedKeySize = 32 +) + +type ( + NoisePublicKey [NoisePublicKeySize]byte + NoisePrivateKey [NoisePrivateKeySize]byte + NoisePresharedKey [NoisePresharedKeySize]byte + NoiseNonce uint64 // padded to 12-bytes +) + +func loadExactHex(dst []byte, src string) error { + slice, err := hex.DecodeString(src) + if err != nil { + return err + } + if len(slice) != len(dst) { + return errors.New("hex string does not fit the slice") + } + copy(dst, slice) + return nil +} + +func (key NoisePrivateKey) IsZero() bool { + var zero NoisePrivateKey + return key.Equals(zero) +} + +func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoisePrivateKey) FromHex(src string) (err error) { + err = loadExactHex(key[:], src) + key.clamp() + return +} + +func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) { + err = loadExactHex(key[:], src) + if key.IsZero() { + return + } + key.clamp() + return +} + +func (key *NoisePublicKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} + +func (key NoisePublicKey) IsZero() bool { + var zero NoisePublicKey + return key.Equals(zero) +} + +func (key NoisePublicKey) Equals(tar NoisePublicKey) bool { + return subtle.ConstantTimeCompare(key[:], tar[:]) == 1 +} + +func (key *NoisePresharedKey) FromHex(src string) error { + return loadExactHex(key[:], src) +} diff --git a/device/noise_test.go b/device/noise_test.go index e431588..2dd5324 100644 --- a/device/noise_test.go +++ b/device/noise_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -9,8 +9,53 @@ import ( "bytes" "encoding/binary" "testing" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun/tuntest" ) +func TestCurveWrappers(t *testing.T) { + sk1, err := newPrivateKey() + assertNil(t, err) + + sk2, err := newPrivateKey() + assertNil(t, err) + + pk1 := sk1.publicKey() + pk2 := sk2.publicKey() + + ss1, err1 := sk1.sharedSecret(pk2) + ss2, err2 := sk2.sharedSecret(pk1) + + if ss1 != ss2 || err1 != nil || err2 != nil { + t.Fatal("Failed to compute shared secet") + } +} + +func randDevice(t *testing.T) *Device { + sk, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + tun := tuntest.NewChannelTUN() + logger := NewLogger(LogLevelError, "") + device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger) + device.SetPrivateKey(sk) + return device +} + +func assertNil(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func assertEqual(t *testing.T, a, b []byte) { + if !bytes.Equal(a, b) { + t.Fatal(a, "!=", b) + } +} + func TestNoiseHandshake(t *testing.T) { dev1 := randDevice(t) dev2 := randDevice(t) @@ -18,14 +63,16 @@ func TestNoiseHandshake(t *testing.T) { defer dev1.Close() defer dev2.Close() - peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public()) + peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey()) if err != nil { t.Fatal(err) } - peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public()) + peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey()) if err != nil { t.Fatal(err) } + peer1.Start() + peer2.Start() assertEqual( t, @@ -101,7 +148,7 @@ func TestNoiseHandshake(t *testing.T) { t.Fatal("failed to derive keypair for peer 2", err) } - key1 := peer1.keypairs.next + key1 := peer1.keypairs.next.Load() key2 := peer2.keypairs.current // encrypting / decryption test diff --git a/device/peer.go b/device/peer.go index 3ec625f..47a2f14 100644 --- a/device/peer.go +++ b/device/peer.go @@ -1,48 +1,36 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "encoding/base64" + "container/list" "errors" - "fmt" "sync" "sync/atomic" "time" "golang.zx2c4.com/wireguard/conn" - "golang.zx2c4.com/wireguard/wgcfg" -) - -const ( - PeerRoutineNumber = 3 ) type Peer struct { - // These fields are accessed with atomic operations, which must be - // 64-bit aligned even on 32-bit platforms. Go guarantees that an - // allocated struct will be 64-bit aligned. So we place - // atomically-accessed fields up front, so that they can share in - // this alignment before smaller fields throw it off. - stats struct { - txBytes uint64 // bytes send to peer (endpoint) - rxBytes uint64 // bytes received from peer - lastHandshakeNano int64 // nano seconds since epoch + isRunning atomic.Bool + keypairs Keypairs + handshake Handshake + device *Device + stopping sync.WaitGroup // routines pending stop + txBytes atomic.Uint64 // bytes send to peer (endpoint) + rxBytes atomic.Uint64 // bytes received from peer + lastHandshakeNano atomic.Int64 // nano seconds since epoch + + endpoint struct { + sync.Mutex + val conn.Endpoint + clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission + disableRoaming bool } - // This field is only 32 bits wide, but is still aligned to 64 - // bits. Don't place other atomic fields after this one. - isRunning AtomicBool - - // Mostly protects endpoint, but is generally taken whenever we modify peer - sync.RWMutex - keypairs Keypairs - handshake Handshake - device *Device - endpoint conn.Endpoint - persistentKeepaliveInterval uint16 timers struct { retransmitHandshake *Timer @@ -50,41 +38,32 @@ type Peer struct { newHandshake *Timer zeroKeyMaterial *Timer persistentKeepalive *Timer - handshakeAttempts uint32 - needAnotherKeepalive AtomicBool - sentLastMinuteHandshake AtomicBool + handshakeAttempts atomic.Uint32 + needAnotherKeepalive atomic.Bool + sentLastMinuteHandshake atomic.Bool } - signals struct { - newKeypairArrived chan struct{} - flushNonceQueue chan struct{} + state struct { + sync.Mutex // protects against concurrent Start/Stop } queue struct { - nonce chan *QueueOutboundElement // nonce / pre-handshake queue - outbound chan *QueueOutboundElement // sequential ordering of work - inbound chan *QueueInboundElement // sequential ordering of work - packetInNonceQueueIsAwaitingKey AtomicBool + staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available + outbound *autodrainingOutboundQueue // sequential ordering of udp transmission + inbound *autodrainingInboundQueue // sequential ordering of tun writing } - routines struct { - sync.Mutex // held when stopping / starting routines - starting sync.WaitGroup // routines pending start - stopping sync.WaitGroup // routines pending stop - stop chan struct{} // size 0, stop all go routines in peer - } - - cookieGenerator CookieGenerator + cookieGenerator CookieGenerator + trieEntries list.List + persistentKeepaliveInterval atomic.Uint32 } -func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) { - - if device.isClosed.Get() { +func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { + if device.isClosed() { return nil, errors.New("device closed") } // lock resources - device.staticIdentity.RLock() defer device.staticIdentity.RUnlock() @@ -92,131 +71,144 @@ func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) { defer device.peers.Unlock() // check if over limit - if len(device.peers.keyMap) >= MaxPeers { return nil, errors.New("too many peers") } // create peer - peer := new(Peer) - peer.Lock() - defer peer.Unlock() peer.cookieGenerator.Init(pk) peer.device = device - peer.isRunning.Set(false) + peer.queue.outbound = newAutodrainingOutboundQueue(device) + peer.queue.inbound = newAutodrainingInboundQueue(device) + peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize) // map public key - _, ok := device.peers.keyMap[pk] if ok { return nil, errors.New("adding existing peer") } // pre-compute DH - handshake := &peer.handshake handshake.mutex.Lock() - handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk) + handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk) handshake.remoteStatic = pk handshake.mutex.Unlock() // reset endpoint + peer.endpoint.Lock() + peer.endpoint.val = nil + peer.endpoint.disableRoaming = false + peer.endpoint.clearSrcOnTx = false + peer.endpoint.Unlock() - peer.endpoint = nil + // init timers + peer.timersInit() // add - device.peers.keyMap[pk] = peer - // start peer - - if peer.device.isUp.Get() { - peer.Start() - } - return peer, nil } -func (peer *Peer) SendBuffer(buffer []byte) error { +func (peer *Peer) SendBuffers(buffers [][]byte) error { peer.device.net.RLock() defer peer.device.net.RUnlock() - if peer.device.net.bind == nil { - return errors.New("no bind") + if peer.device.isClosed() { + return nil } - peer.RLock() - defer peer.RUnlock() - - if peer.endpoint == nil { + peer.endpoint.Lock() + endpoint := peer.endpoint.val + if endpoint == nil { + peer.endpoint.Unlock() return errors.New("no known endpoint for peer") } + if peer.endpoint.clearSrcOnTx { + endpoint.ClearSrc() + peer.endpoint.clearSrcOnTx = false + } + peer.endpoint.Unlock() - err := peer.device.net.bind.Send(buffer, peer.endpoint) + err := peer.device.net.bind.Send(buffers, endpoint) if err == nil { - atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer))) + var totalLen uint64 + for _, b := range buffers { + totalLen += uint64(len(b)) + } + peer.txBytes.Add(totalLen) } return err } func (peer *Peer) String() string { - base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) - abbreviatedKey := "invalid" - if len(base64Key) == 44 { - abbreviatedKey = base64Key[0:4] + "ā¦" + base64Key[39:43] + // The awful goo that follows is identical to: + // + // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:]) + // abbreviatedKey := base64Key[0:4] + "ā¦" + base64Key[39:43] + // return fmt.Sprintf("peer(%s)", abbreviatedKey) + // + // except that it is considerably more efficient. + src := peer.handshake.remoteStatic + b64 := func(input byte) byte { + return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3) } - return fmt.Sprintf("peer(%s)", abbreviatedKey) + b := []byte("peer(____ā¦____)") + const first = len("peer(") + const second = len("peer(____ā¦") + b[first+0] = b64((src[0] >> 2) & 63) + b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63) + b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63) + b[first+3] = b64(src[2] & 63) + b[second+0] = b64(src[29] & 63) + b[second+1] = b64((src[30] >> 2) & 63) + b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63) + b[second+3] = b64((src[31] << 2) & 63) + return string(b) } func (peer *Peer) Start() { - // should never start a peer on a closed device - - if peer.device.isClosed.Get() { + if peer.device.isClosed() { return } // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() - peer.routines.Lock() - defer peer.routines.Unlock() - - if peer.isRunning.Get() { + if peer.isRunning.Load() { return } device := peer.device - device.log.Debug.Println(peer, "- Starting...") + device.log.Verbosef("%v - Starting", peer) // reset routine state + peer.stopping.Wait() + peer.stopping.Add(2) - peer.routines.starting.Wait() - peer.routines.stopping.Wait() - peer.routines.stop = make(chan struct{}) - peer.routines.starting.Add(PeerRoutineNumber) - peer.routines.stopping.Add(PeerRoutineNumber) + peer.handshake.mutex.Lock() + peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + peer.handshake.mutex.Unlock() - // prepare queues + peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes - peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize) - peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize) + peer.timersStart() - peer.timersInit() - peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) - peer.signals.newKeypairArrived = make(chan struct{}, 1) - peer.signals.flushNonceQueue = make(chan struct{}, 1) + device.flushInboundQueue(peer.queue.inbound) + device.flushOutboundQueue(peer.queue.outbound) - // wait for routines to start + // Use the device batch size, not the bind batch size, as the device size is + // the size of the batch pools. + batchSize := peer.device.BatchSize() + go peer.RoutineSequentialSender(batchSize) + go peer.RoutineSequentialReceiver(batchSize) - go peer.RoutineNonce() - go peer.RoutineSequentialSender() - go peer.RoutineSequentialReceiver() - - peer.routines.starting.Wait() - peer.isRunning.Set(true) + peer.isRunning.Store(true) } func (peer *Peer) ZeroAndFlushAll() { @@ -228,10 +220,10 @@ func (peer *Peer) ZeroAndFlushAll() { keypairs.Lock() device.DeleteKeypair(keypairs.previous) device.DeleteKeypair(keypairs.current) - device.DeleteKeypair(keypairs.next) + device.DeleteKeypair(keypairs.next.Load()) keypairs.previous = nil keypairs.current = nil - keypairs.next = nil + keypairs.next.Store(nil) keypairs.Unlock() // clear handshake state @@ -242,7 +234,7 @@ func (peer *Peer) ZeroAndFlushAll() { handshake.Clear() handshake.mutex.Unlock() - peer.FlushNonceQueue() + peer.FlushStagedPackets() } func (peer *Peer) ExpireCurrentKeypairs() { @@ -250,58 +242,55 @@ func (peer *Peer) ExpireCurrentKeypairs() { handshake.mutex.Lock() peer.device.indexTable.Delete(handshake.localIndex) handshake.Clear() - handshake.mutex.Unlock() peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second)) + handshake.mutex.Unlock() keypairs := &peer.keypairs keypairs.Lock() if keypairs.current != nil { - keypairs.current.sendNonce = RejectAfterMessages + keypairs.current.sendNonce.Store(RejectAfterMessages) } - if keypairs.next != nil { - keypairs.next.sendNonce = RejectAfterMessages + if next := keypairs.next.Load(); next != nil { + next.sendNonce.Store(RejectAfterMessages) } keypairs.Unlock() } func (peer *Peer) Stop() { - - // prevent simultaneous start/stop operations + peer.state.Lock() + defer peer.state.Unlock() if !peer.isRunning.Swap(false) { return } - peer.routines.starting.Wait() - - peer.routines.Lock() - defer peer.routines.Unlock() - - peer.device.log.Debug.Println(peer, "- Stopping...") + peer.device.log.Verbosef("%v - Stopping", peer) peer.timersStop() - - // stop & wait for ongoing peer routines - - close(peer.routines.stop) - peer.routines.stopping.Wait() - - // close queues - - close(peer.queue.nonce) - close(peer.queue.outbound) - close(peer.queue.inbound) + // Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit. + peer.queue.inbound.c <- nil + peer.queue.outbound.c <- nil + peer.stopping.Wait() + peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us peer.ZeroAndFlushAll() } -var RoamingDisabled bool - func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) { - if RoamingDisabled { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.disableRoaming { + return + } + peer.endpoint.clearSrcOnTx = false + peer.endpoint.val = endpoint +} + +func (peer *Peer) markEndpointSrcForClearing() { + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + if peer.endpoint.val == nil { return } - peer.Lock() - peer.endpoint = endpoint - peer.Unlock() + peer.endpoint.clearSrcOnTx = true } diff --git a/device/peer_test.go b/device/peer_test.go deleted file mode 100644 index de87ab6..0000000 --- a/device/peer_test.go +++ /dev/null @@ -1,29 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "testing" - "unsafe" -) - -func checkAlignment(t *testing.T, name string, offset uintptr) { - t.Helper() - if offset%8 != 0 { - t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) - } -} - -// TestPeerAlignment checks that atomically-accessed fields are -// aligned to 64-bit boundaries, as required by the atomic package. -// -// Unfortunately, violating this rule on 32-bit platforms results in a -// hard segfault at runtime. -func TestPeerAlignment(t *testing.T) { - var p Peer - checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats)) - checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning)) -} diff --git a/device/pools.go b/device/pools.go index 98f4ef1..94f3dc7 100644 --- a/device/pools.go +++ b/device/pools.go @@ -1,89 +1,120 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -import "sync" +import ( + "sync" + "sync/atomic" +) -func (device *Device) PopulatePools() { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool = &sync.Pool{ - New: func() interface{} { - return new([MaxMessageSize]byte) - }, - } - device.pool.inboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueInboundElement) - }, - } - device.pool.outboundElementPool = &sync.Pool{ - New: func() interface{} { - return new(QueueOutboundElement) - }, - } - } else { - device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte) - } - device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.inboundElementReuseChan <- new(QueueInboundElement) - } - device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool) - for i := 0; i < PreallocatedBuffersPerPool; i += 1 { - device.pool.outboundElementReuseChan <- new(QueueOutboundElement) +type WaitPool struct { + pool sync.Pool + cond sync.Cond + lock sync.Mutex + count atomic.Uint32 + max uint32 +} + +func NewWaitPool(max uint32, new func() any) *WaitPool { + p := &WaitPool{pool: sync.Pool{New: new}, max: max} + p.cond = sync.Cond{L: &p.lock} + return p +} + +func (p *WaitPool) Get() any { + if p.max != 0 { + p.lock.Lock() + for p.count.Load() >= p.max { + p.cond.Wait() } + p.count.Add(1) + p.lock.Unlock() } + return p.pool.Get() } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - if PreallocatedBuffersPerPool == 0 { - return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte) - } else { - return <-device.pool.messageBufferReuseChan +func (p *WaitPool) Put(x any) { + p.pool.Put(x) + if p.max == 0 { + return } + p.count.Add(^uint32(0)) + p.cond.Signal() } -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - if PreallocatedBuffersPerPool == 0 { - device.pool.messageBufferPool.Put(msg) - } else { - device.pool.messageBufferReuseChan <- msg - } +func (device *Device) PopulatePools() { + device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueInboundElement, 0, device.BatchSize()) + return &QueueInboundElementsContainer{elems: s} + }) + device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + s := make([]*QueueOutboundElement, 0, device.BatchSize()) + return &QueueOutboundElementsContainer{elems: s} + }) + device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new([MaxMessageSize]byte) + }) + device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueInboundElement) + }) + device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + return new(QueueOutboundElement) + }) } -func (device *Device) GetInboundElement() *QueueInboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.inboundElementPool.Get().(*QueueInboundElement) - } else { - return <-device.pool.inboundElementReuseChan +func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer { + c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer) + c.Mutex = sync.Mutex{} + return c +} + +func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.inboundElementsContainer.Put(c) +} + +func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer { + c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer) + c.Mutex = sync.Mutex{} + return c } -func (device *Device) PutInboundElement(msg *QueueInboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.inboundElementPool.Put(msg) - } else { - device.pool.inboundElementReuseChan <- msg +func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) { + for i := range c.elems { + c.elems[i] = nil } + c.elems = c.elems[:0] + device.pool.outboundElementsContainer.Put(c) +} + +func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { + return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) +} + +func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { + device.pool.messageBuffers.Put(msg) +} + +func (device *Device) GetInboundElement() *QueueInboundElement { + return device.pool.inboundElements.Get().(*QueueInboundElement) +} + +func (device *Device) PutInboundElement(elem *QueueInboundElement) { + elem.clearPointers() + device.pool.inboundElements.Put(elem) } func (device *Device) GetOutboundElement() *QueueOutboundElement { - if PreallocatedBuffersPerPool == 0 { - return device.pool.outboundElementPool.Get().(*QueueOutboundElement) - } else { - return <-device.pool.outboundElementReuseChan - } + return device.pool.outboundElements.Get().(*QueueOutboundElement) } -func (device *Device) PutOutboundElement(msg *QueueOutboundElement) { - if PreallocatedBuffersPerPool == 0 { - device.pool.outboundElementPool.Put(msg) - } else { - device.pool.outboundElementReuseChan <- msg - } +func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { + elem.clearPointers() + device.pool.outboundElements.Put(elem) } diff --git a/device/pools_test.go b/device/pools_test.go new file mode 100644 index 0000000..82d7493 --- /dev/null +++ b/device/pools_test.go @@ -0,0 +1,139 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "math/rand" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestWaitPool(t *testing.T) { + t.Skip("Currently disabled") + var wg sync.WaitGroup + var trials atomic.Int32 + startTrials := int32(100000) + if raceEnabled { + // This test can be very slow with -race. + startTrials /= 10 + } + trials.Store(startTrials) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + t.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + wg.Add(workers) + var max atomic.Uint32 + updateMax := func() { + count := p.count.Load() + if count > p.max { + t.Errorf("count (%d) > max (%d)", count, p.max) + } + for { + old := max.Load() + if count <= old { + break + } + if max.CompareAndSwap(old, count) { + break + } + } + } + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + updateMax() + x := p.Get() + updateMax() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + updateMax() + p.Put(x) + updateMax() + } + }() + } + wg.Wait() + if max.Load() != p.max { + t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max) + } +} + +func BenchmarkWaitPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkWaitPoolEmpty(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := NewWaitPool(0, func() any { return make([]byte, 16) }) + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} + +func BenchmarkSyncPool(b *testing.B) { + var wg sync.WaitGroup + var trials atomic.Int32 + trials.Store(int32(b.N)) + workers := runtime.NumCPU() + 2 + if workers-4 <= 0 { + b.Skip("Not enough cores") + } + p := sync.Pool{New: func() any { return make([]byte, 16) }} + wg.Add(workers) + b.ResetTimer() + for i := 0; i < workers; i++ { + go func() { + defer wg.Done() + for trials.Add(-1) > 0 { + x := p.Get() + time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond) + p.Put(x) + } + }() + } + wg.Wait() +} diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index f5c042d..25f700a 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -1,16 +1,19 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device +import "golang.zx2c4.com/wireguard/conn" + /* Reduce memory consumption for Android */ const ( + QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 + MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram PreallocatedBuffersPerPool = 4096 ) diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index cf86ba1..ea763d0 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -1,13 +1,16 @@ -// +build !android,!ios +//go:build !android && !ios && !windows /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device +import "golang.zx2c4.com/wireguard/conn" + const ( + QueueStagedSize = conn.IdealBatchSize QueueOutboundSize = 1024 QueueInboundSize = 1024 QueueHandshakeSize = 1024 diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index 589b0aa..acd3cec 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -1,18 +1,21 @@ -// +build ios +//go:build ios /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device -/* Fit within memory limits for iOS's Network Extension API, which has stricter requirements */ - -const ( - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 1700 - PreallocatedBuffersPerPool = 1024 +// Fit within memory limits for iOS's Network Extension API, which has stricter requirements. +// These are vars instead of consts, because heavier network extensions might want to reduce +// them further. +var ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + PreallocatedBuffersPerPool uint32 = 1024 ) + +const MaxSegmentSize = 1700 diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go new file mode 100644 index 0000000..1eee32b --- /dev/null +++ b/device/queueconstants_windows.go @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const ( + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 + MaxSegmentSize = 2048 - 32 // largest possible UDP datagram + PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth +) diff --git a/device/race_disabled_test.go b/device/race_disabled_test.go new file mode 100644 index 0000000..bb5c450 --- /dev/null +++ b/device/race_disabled_test.go @@ -0,0 +1,10 @@ +//go:build !race + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const raceEnabled = false diff --git a/device/race_enabled_test.go b/device/race_enabled_test.go new file mode 100644 index 0000000..4e9daea --- /dev/null +++ b/device/race_enabled_test.go @@ -0,0 +1,10 @@ +//go:build race + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +const raceEnabled = true diff --git a/device/receive.go b/device/receive.go index 4818d64..1ab3e29 100644 --- a/device/receive.go +++ b/device/receive.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,10 +8,9 @@ package device import ( "bytes" "encoding/binary" + "errors" "net" - "strconv" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" @@ -28,8 +27,6 @@ type QueueHandshakeElement struct { } type QueueInboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte packet []byte counter uint64 @@ -37,38 +34,20 @@ type QueueInboundElement struct { endpoint conn.Endpoint } -func (elem *QueueInboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueInboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool { - select { - case inboundQueue <- element: - select { - case decryptionQueue <- element: - return true - default: - element.Drop() - element.Unlock() - return false - } - default: - device.PutInboundElement(element) - return false - } +type QueueInboundElementsContainer struct { + sync.Mutex + elems []*QueueInboundElement } -func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool { - select { - case queue <- element: - return true - default: - return false - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueInboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.endpoint = nil } /* Called when a new authenticated message has been received @@ -76,12 +55,12 @@ func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem * NOTE: Not thread safe, but called by sequential receiver! */ func (peer *Peer) keepKeyFreshReceiving() { - if peer.timers.sentLastMinuteHandshake.Get() { + if peer.timers.sentLastMinuteHandshake.Load() { return } keypair := peer.keypairs.Current() if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { - peer.timers.sentLastMinuteHandshake.Set(true) + peer.timers.sentLastMinuteHandshake.Store(true) peer.SendHandshakeInitiation(false) } } @@ -91,188 +70,189 @@ func (peer *Peer) keepKeyFreshReceiving() { * Every time the bind is updated a new routine is started for * IPv4 and IPv6 (separately) */ -func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { - - logDebug := device.log.Debug +func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) { + recvName := recv.PrettyName() defer func() { - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") + device.log.Verbosef("Routine: receive incoming %s - stopped", recvName) + device.queue.decryption.wg.Done() + device.queue.handshake.wg.Done() device.net.stopping.Done() }() - logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started") - device.net.starting.Done() + device.log.Verbosef("Routine: receive incoming %s - started", recvName) // receive datagrams until conn is closed - buffer := device.GetMessageBuffer() - var ( - err error - size int - endpoint conn.Endpoint + bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) + bufs = make([][]byte, maxBatchSize) + err error + sizes = make([]int, maxBatchSize) + count int + endpoints = make([]conn.Endpoint, maxBatchSize) + deathSpiral int + elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for { - - // read next datagram + for i := range bufsArrs { + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + } - switch IP { - case ipv4.Version: - size, endpoint, err = bind.ReceiveIPv4(buffer[:]) - case ipv6.Version: - size, endpoint, err = bind.ReceiveIPv6(buffer[:]) - default: - panic("invalid IP version") + defer func() { + for i := 0; i < maxBatchSize; i++ { + if bufsArrs[i] != nil { + device.PutMessageBuffer(bufsArrs[i]) + } } + }() + for { + count, err = recv(bufs, sizes, endpoints) if err != nil { - device.PutMessageBuffer(buffer) + if errors.Is(err, net.ErrClosed) { + return + } + device.log.Verbosef("Failed to receive %s packet: %v", recvName, err) + if neterr, ok := err.(net.Error); ok && !neterr.Temporary() { + return + } + if deathSpiral < 10 { + deathSpiral++ + time.Sleep(time.Second / 3) + continue + } return } + deathSpiral = 0 - if size < MinMessageSize { - continue - } - - // check size of packet + // handle each packet in the batch + for i, size := range sizes[:count] { + if size < MinMessageSize { + continue + } - packet := buffer[:size] - msgType := binary.LittleEndian.Uint32(packet[:4]) + // check size of packet - var okay bool + packet := bufsArrs[i][:size] + msgType := binary.LittleEndian.Uint32(packet[:4]) - switch msgType { + switch msgType { - // check if transport + // check if transport - case MessageTransportType: + case MessageTransportType: - // check size + // check size - if len(packet) < MessageTransportSize { - continue - } + if len(packet) < MessageTransportSize { + continue + } - // lookup key pair + // lookup key pair - receiver := binary.LittleEndian.Uint32( - packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], - ) - value := device.indexTable.Lookup(receiver) - keypair := value.keypair - if keypair == nil { - continue - } + receiver := binary.LittleEndian.Uint32( + packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter], + ) + value := device.indexTable.Lookup(receiver) + keypair := value.keypair + if keypair == nil { + continue + } - // check keypair expiry + // check keypair expiry - if keypair.created.Add(RejectAfterTime).Before(time.Now()) { - continue - } + if keypair.created.Add(RejectAfterTime).Before(time.Now()) { + continue + } - // create work element - peer := value.peer - elem := device.GetInboundElement() - elem.packet = packet - elem.buffer = buffer - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.endpoint = endpoint - elem.counter = 0 - elem.Mutex = sync.Mutex{} - elem.Lock() - - // add to decryption queues - - if peer.isRunning.Get() { - if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) { - buffer = device.GetMessageBuffer() + // create work element + peer := value.peer + elem := device.GetInboundElement() + elem.packet = packet + elem.buffer = bufsArrs[i] + elem.keypair = keypair + elem.endpoint = endpoints[i] + elem.counter = 0 + + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetInboundElementsContainer() + elemsForPeer.Lock() + elemsByPeer[peer] = elemsForPeer } - } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + continue - continue + // otherwise it is a fixed size & handshake related packet - // otherwise it is a fixed size & handshake related packet + case MessageInitiationType: + if len(packet) != MessageInitiationSize { + continue + } - case MessageInitiationType: - okay = len(packet) == MessageInitiationSize + case MessageResponseType: + if len(packet) != MessageResponseSize { + continue + } - case MessageResponseType: - okay = len(packet) == MessageResponseSize + case MessageCookieReplyType: + if len(packet) != MessageCookieReplySize { + continue + } - case MessageCookieReplyType: - okay = len(packet) == MessageCookieReplySize + default: + device.log.Verbosef("Received message with unknown type") + continue + } - default: - logDebug.Println("Received message with unknown type") + select { + case device.queue.handshake.c <- QueueHandshakeElement{ + msgType: msgType, + buffer: bufsArrs[i], + packet: packet, + endpoint: endpoints[i], + }: + bufsArrs[i] = device.GetMessageBuffer() + bufs[i] = bufsArrs[i][:] + default: + } } - - if okay { - if (device.addToHandshakeQueue( - device.queue.handshake, - QueueHandshakeElement{ - msgType: msgType, - buffer: buffer, - packet: packet, - endpoint: endpoint, - }, - )) { - buffer = device.GetMessageBuffer() + for peer, elemsContainer := range elemsByPeer { + if peer.isRunning.Load() { + peer.queue.inbound.c <- elemsContainer + device.queue.decryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) + } + device.PutInboundElementsContainer(elemsContainer) } + delete(elemsByPeer, peer) } } } -func (device *Device) RoutineDecryption() { - +func (device *Device) RoutineDecryption(id int) { var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - defer func() { - logDebug.Println("Routine: decryption worker - stopped") - device.state.stopping.Done() - }() - logDebug.Println("Routine: decryption worker - started") - device.state.starting.Done() - - for { - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.decryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: decryption worker %d - stopped", id) + device.log.Verbosef("Routine: decryption worker %d - started", id) + for elemsContainer := range device.queue.decryption.c { + for _, elem := range elemsContainer.elems { // split message into fields - counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent] content := elem.packet[MessageTransportOffsetContent:] - // expand nonce - - nonce[0x4] = counter[0x0] - nonce[0x5] = counter[0x1] - nonce[0x6] = counter[0x2] - nonce[0x7] = counter[0x3] - - nonce[0x8] = counter[0x4] - nonce[0x9] = counter[0x5] - nonce[0xa] = counter[0x6] - nonce[0xb] = counter[0x7] - // decrypt and release to consumer - var err error elem.counter = binary.LittleEndian.Uint64(counter) + // copy counter to nonce + binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter) elem.packet, err = elem.keypair.receive.Open( content[:0], nonce[:], @@ -280,51 +260,23 @@ func (device *Device) RoutineDecryption() { nil, ) if err != nil { - elem.Drop() - device.PutMessageBuffer(elem.buffer) + elem.packet = nil } - elem.Unlock() } + elemsContainer.Unlock() } } /* Handles incoming packets related to handshake */ -func (device *Device) RoutineHandshake() { - - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem QueueHandshakeElement - var ok bool - +func (device *Device) RoutineHandshake(id int) { defer func() { - logDebug.Println("Routine: handshake worker - stopped") - device.state.stopping.Done() - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - } + device.log.Verbosef("Routine: handshake worker %d - stopped", id) + device.queue.encryption.wg.Done() }() + device.log.Verbosef("Routine: handshake worker %d - started", id) - logDebug.Println("Routine: handshake worker - started") - device.state.starting.Done() - - for { - if elem.buffer != nil { - device.PutMessageBuffer(elem.buffer) - elem.buffer = nil - } - - select { - case elem, ok = <-device.queue.handshake: - case <-device.signals.stop: - return - } - - if !ok { - return - } + for elem := range device.queue.handshake.c { // handle cookie fields and ratelimiting @@ -338,8 +290,8 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &reply) if err != nil { - logDebug.Println("Failed to decode cookie reply") - return + device.log.Verbosef("Failed to decode cookie reply") + goto skip } // lookup peer from index @@ -347,27 +299,27 @@ func (device *Device) RoutineHandshake() { entry := device.indexTable.Lookup(reply.Receiver) if entry.peer == nil { - continue + goto skip } // consume reply - if peer := entry.peer; peer.isRunning.Get() { - logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString()) + if peer := entry.peer; peer.isRunning.Load() { + device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString()) if !peer.cookieGenerator.ConsumeReply(&reply) { - logDebug.Println("Could not decrypt invalid cookie response") + device.log.Verbosef("Could not decrypt invalid cookie response") } } - continue + goto skip case MessageInitiationType, MessageResponseType: // check mac fields and maybe ratelimit if !device.cookieChecker.CheckMAC1(elem.packet) { - logDebug.Println("Received packet with invalid mac1") - continue + device.log.Verbosef("Received packet with invalid mac1") + goto skip } // endpoints destination address is the source of the datagram @@ -378,19 +330,19 @@ func (device *Device) RoutineHandshake() { if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) { device.SendHandshakeCookie(&elem) - continue + goto skip } // check ratelimiter if !device.rate.limiter.Allow(elem.endpoint.DstIP()) { - continue + goto skip } } default: - logError.Println("Invalid packet ended up in the handshake queue") - continue + device.log.Errorf("Invalid packet ended up in the handshake queue") + goto skip } // handle handshake initiation/response content @@ -404,19 +356,16 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { - logError.Println("Failed to decode initiation message") - continue + device.log.Errorf("Failed to decode initiation message") + goto skip } // consume initiation peer := device.ConsumeMessageInitiation(&msg) if peer == nil { - logInfo.Println( - "Received invalid initiation message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString()) + goto skip } // update timers @@ -427,8 +376,8 @@ func (device *Device) RoutineHandshake() { // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake initiation") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake initiation", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) peer.SendHandshakeResponse() @@ -440,26 +389,23 @@ func (device *Device) RoutineHandshake() { reader := bytes.NewReader(elem.packet) err := binary.Read(reader, binary.LittleEndian, &msg) if err != nil { - logError.Println("Failed to decode response message") - continue + device.log.Errorf("Failed to decode response message") + goto skip } // consume response peer := device.ConsumeMessageResponse(&msg) if peer == nil { - logInfo.Println( - "Received invalid response message from", - elem.endpoint.DstToString(), - ) - continue + device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString()) + goto skip } // update endpoint peer.SetEndpointFromPacket(elem.endpoint) - logDebug.Println(peer, "- Received handshake response") - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) + device.log.Verbosef("%v - Received handshake response", peer) + peer.rxBytes.Add(uint64(len(elem.packet))) // update timers @@ -471,178 +417,124 @@ func (device *Device) RoutineHandshake() { err = peer.BeginSymmetricSession() if err != nil { - logError.Println(peer, "- Failed to derive keypair:", err) - continue + device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) + goto skip } peer.timersSessionDerived() peer.timersHandshakeComplete() peer.SendKeepalive() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } } + skip: + device.PutMessageBuffer(elem.buffer) } } -func (peer *Peer) RoutineSequentialReceiver() { - +func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { device := peer.device - logInfo := device.log.Info - logError := device.log.Error - logDebug := device.log.Debug - - var elem *QueueInboundElement - defer func() { - logDebug.Println(peer, "- Routine: sequential receiver - stopped") - peer.routines.stopping.Done() - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - } + device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - logDebug.Println(peer, "- Routine: sequential receiver - started") - - peer.routines.starting.Done() - - for { - if elem != nil { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - } - device.PutInboundElement(elem) - elem = nil - } + bufs := make([][]byte, 0, maxBatchSize) - var elemOk bool - select { - case <-peer.routines.stop: + for elemsContainer := range peer.queue.inbound.c { + if elemsContainer == nil { return - case elem, elemOk = <-peer.queue.inbound: - if !elemOk { - return - } - } - - // wait for decryption - - elem.Lock() - - if elem.IsDropped() { - continue - } - - // check for replay - - if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { - continue - } - - // update endpoint - peer.SetEndpointFromPacket(elem.endpoint) - - // check if using new keypair - if peer.ReceivedWithKeypair(elem.keypair) { - peer.timersHandshakeComplete() - select { - case peer.signals.newKeypairArrived <- struct{}{}: - default: - } - } - - peer.keepKeyFreshReceiving() - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketReceived() - atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize)) - - // check for keepalive - - if len(elem.packet) == 0 { - logDebug.Println(peer, "- Receiving keepalive packet") - continue } - peer.timersDataReceived() - - // verify source and strip padding - - switch elem.packet[0] >> 4 { - case ipv4.Version: - - // strip padding - - if len(elem.packet) < ipv4.HeaderLen { + elemsContainer.Lock() + validTailPacket := -1 + dataPacketReceived := false + rxBytesLen := uint64(0) + for i, elem := range elemsContainer.elems { + if elem.packet == nil { + // decryption failed continue } - field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] - length := binary.BigEndian.Uint16(field) - if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) { continue } - elem.packet = elem.packet[:length] - - // verify IPv4 source - - src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] - if device.allowedips.LookupIPv4(src) != peer { - logInfo.Println( - "IPv4 packet with disallowed source address from", - peer, - ) - continue + validTailPacket = i + if peer.ReceivedWithKeypair(elem.keypair) { + peer.SetEndpointFromPacket(elem.endpoint) + peer.timersHandshakeComplete() + peer.SendStagedPackets() } + rxBytesLen += uint64(len(elem.packet) + MinMessageSize) - case ipv6.Version: - - // strip padding - - if len(elem.packet) < ipv6.HeaderLen { + if len(elem.packet) == 0 { + device.log.Verbosef("%v - Receiving keepalive packet", peer) continue } + dataPacketReceived = true - field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] - length := binary.BigEndian.Uint16(field) - length += ipv6.HeaderLen - if int(length) > len(elem.packet) { - continue - } - - elem.packet = elem.packet[:length] + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2] + length := binary.BigEndian.Uint16(field) + if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer) + continue + } - // verify IPv6 source + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2] + length := binary.BigEndian.Uint16(field) + length += ipv6.HeaderLen + if int(length) > len(elem.packet) { + continue + } + elem.packet = elem.packet[:length] + src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] + if device.allowedips.Lookup(src) != peer { + device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer) + continue + } - src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] - if device.allowedips.LookupIPv6(src) != peer { - logInfo.Println( - "IPv6 packet with disallowed source address from", - peer, - ) + default: + device.log.Verbosef("Packet with invalid IP version from %v", peer) continue } - default: - logInfo.Println("Packet with invalid IP version from", peer) - continue + bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) } - // write to tun device - - offset := MessageTransportOffsetContent - _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) - if len(peer.queue.inbound) == 0 { - err = device.tun.device.Flush() - if err != nil { - peer.device.log.Error.Printf("Unable to flush packets: %v", err) + peer.rxBytes.Add(rxBytesLen) + if validTailPacket >= 0 { + peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint) + peer.keepKeyFreshReceiving() + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketReceived() + } + if dataPacketReceived { + peer.timersDataReceived() + } + if len(bufs) > 0 { + _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent) + if err != nil && !device.isClosed() { + device.log.Errorf("Failed to write packets to TUN device: %v", err) } } - if err != nil && !device.isClosed.Get() { - logError.Println("Failed to write packet to TUN device:", err) + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutInboundElement(elem) } + bufs = bufs[:0] + device.PutInboundElementsContainer(elemsContainer) } } diff --git a/device/send.go b/device/send.go index 9e29d77..769720a 100644 --- a/device/send.go +++ b/device/send.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device @@ -8,14 +8,17 @@ package device import ( "bytes" "encoding/binary" + "errors" "net" + "os" "sync" - "sync/atomic" "time" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/tun" ) /* Outbound flow @@ -43,8 +46,6 @@ import ( */ type QueueOutboundElement struct { - dropped int32 - sync.Mutex buffer *[MaxMessageSize]byte // slice holding the packet data packet []byte // slice of "buffer" (always!) nonce uint64 // nonce for encryption @@ -52,80 +53,52 @@ type QueueOutboundElement struct { peer *Peer // related peer } +type QueueOutboundElementsContainer struct { + sync.Mutex + elems []*QueueOutboundElement +} + func (device *Device) NewOutboundElement() *QueueOutboundElement { elem := device.GetOutboundElement() - elem.dropped = AtomicFalse elem.buffer = device.GetMessageBuffer() - elem.Mutex = sync.Mutex{} elem.nonce = 0 - elem.keypair = nil - elem.peer = nil + // keypair and peer were cleared (if necessary) by clearPointers. return elem } -func (elem *QueueOutboundElement) Drop() { - atomic.StoreInt32(&elem.dropped, AtomicTrue) -} - -func (elem *QueueOutboundElement) IsDropped() bool { - return atomic.LoadInt32(&elem.dropped) == AtomicTrue -} - -func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) { - for { - select { - case queue <- element: - return - default: - select { - case old := <-queue: - device.PutMessageBuffer(old.buffer) - device.PutOutboundElement(old) - default: - } - } - } +// clearPointers clears elem fields that contain pointers. +// This makes the garbage collector's life easier and +// avoids accidentally keeping other objects around unnecessarily. +// It also reduces the possible collateral damage from use-after-free bugs. +func (elem *QueueOutboundElement) clearPointers() { + elem.buffer = nil + elem.packet = nil + elem.keypair = nil + elem.peer = nil } -func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) { - select { - case outboundQueue <- element: +/* Queues a keepalive if no packets are queued for peer + */ +func (peer *Peer) SendKeepalive() { + if len(peer.queue.staged) == 0 && peer.isRunning.Load() { + elem := peer.device.NewOutboundElement() + elemsContainer := peer.device.GetOutboundElementsContainer() + elemsContainer.elems = append(elemsContainer.elems, elem) select { - case encryptionQueue <- element: - return + case peer.queue.staged <- elemsContainer: + peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - element.Drop() - element.peer.device.PutMessageBuffer(element.buffer) - element.Unlock() + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + peer.device.PutOutboundElementsContainer(elemsContainer) } - default: - element.peer.device.PutMessageBuffer(element.buffer) - element.peer.device.PutOutboundElement(element) - } -} - -/* Queues a keepalive if no packets are queued for peer - */ -func (peer *Peer) SendKeepalive() bool { - if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() { - return false - } - elem := peer.device.NewOutboundElement() - elem.packet = nil - select { - case peer.queue.nonce <- elem: - peer.device.log.Debug.Println(peer, "- Sending keepalive packet") - return true - default: - peer.device.PutMessageBuffer(elem.buffer) - peer.device.PutOutboundElement(elem) - return false } + peer.SendStagedPackets() } func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { if !isRetry { - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) + peer.timers.handshakeAttempts.Store(0) } peer.handshake.mutex.RLock() @@ -143,16 +116,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() - peer.device.log.Debug.Println(peer, "- Sending handshake initiation") + peer.device.log.Verbosef("%v - Sending handshake initiation", peer) msg, err := peer.device.CreateMessageInitiation(peer) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err) + peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err) return err } - var buff [MessageInitiationSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageInitiationSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, msg) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) @@ -160,9 +133,9 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + err = peer.SendBuffers([][]byte{packet}) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err) + peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } peer.timersHandshakeInitiated() @@ -174,23 +147,23 @@ func (peer *Peer) SendHandshakeResponse() error { peer.handshake.lastSentHandshake = time.Now() peer.handshake.mutex.Unlock() - peer.device.log.Debug.Println(peer, "- Sending handshake response") + peer.device.log.Verbosef("%v - Sending handshake response", peer) response, err := peer.device.CreateMessageResponse(peer) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to create response message:", err) + peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err) return err } - var buff [MessageResponseSize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageResponseSize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, response) packet := writer.Bytes() peer.cookieGenerator.AddMacs(packet) err = peer.BeginSymmetricSession() if err != nil { - peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err) + peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err) return err } @@ -198,28 +171,29 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffer(packet) + // TODO: allocation could be avoided + err = peer.SendBuffers([][]byte{packet}) if err != nil { - peer.device.log.Error.Println(peer, "- Failed to send handshake response", err) + peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } return err } func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error { - - device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString()) + device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString()) sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8]) reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes()) if err != nil { - device.log.Error.Println("Failed to create cookie reply:", err) + device.log.Errorf("Failed to create cookie reply: %v", err) return err } - var buff [MessageCookieReplySize]byte - writer := bytes.NewBuffer(buff[:0]) + var buf [MessageCookieReplySize]byte + writer := bytes.NewBuffer(buf[:0]) binary.Write(writer, binary.LittleEndian, reply) - device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) + // TODO: allocation could be avoided + device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint) return nil } @@ -228,280 +202,255 @@ func (peer *Peer) keepKeyFreshSending() { if keypair == nil { return } - nonce := atomic.LoadUint64(&keypair.sendNonce) + nonce := keypair.sendNonce.Load() if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) { peer.SendHandshakeInitiation(false) } } -/* Reads packets from the TUN and inserts - * into nonce queue for peer - * - * Obs. Single instance per TUN device - */ func (device *Device) RoutineReadFromTUN() { - - logDebug := device.log.Debug - logError := device.log.Error - defer func() { - logDebug.Println("Routine: TUN reader - stopped") + device.log.Verbosef("Routine: TUN reader - stopped") device.state.stopping.Done() + device.queue.encryption.wg.Done() }() - logDebug.Println("Routine: TUN reader - started") - device.state.starting.Done() - - var elem *QueueOutboundElement + device.log.Verbosef("Routine: TUN reader - started") + + var ( + batchSize = device.BatchSize() + readErr error + elems = make([]*QueueOutboundElement, batchSize) + bufs = make([][]byte, batchSize) + elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) + count = 0 + sizes = make([]int, batchSize) + offset = MessageTransportHeaderSize + ) + + for i := range elems { + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] + } - for { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) + defer func() { + for _, elem := range elems { + if elem != nil { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } } - elem = device.NewOutboundElement() - - // read packet - - offset := MessageTransportHeaderSize - size, err := device.tun.device.Read(elem.buffer[:], offset) + }() - if err != nil { - if !device.isClosed.Get() { - logError.Println("Failed to read packet from TUN device:", err) - device.Close() + for { + // read packets + count, readErr = device.tun.device.Read(bufs, sizes, offset) + for i := 0; i < count; i++ { + if sizes[i] < 1 { + continue } - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return - } - if size == 0 || size > MaxContentSize { - continue - } + elem := elems[i] + elem.packet = bufs[i][offset : offset+sizes[i]] - elem.packet = elem.buffer[offset : offset+size] + // lookup peer + var peer *Peer + switch elem.packet[0] >> 4 { + case 4: + if len(elem.packet) < ipv4.HeaderLen { + continue + } + dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] + peer = device.allowedips.Lookup(dst) - // lookup peer + case 6: + if len(elem.packet) < ipv6.HeaderLen { + continue + } + dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] + peer = device.allowedips.Lookup(dst) - var peer *Peer - switch elem.packet[0] >> 4 { - case ipv4.Version: - if len(elem.packet) < ipv4.HeaderLen { - continue + default: + device.log.Verbosef("Received packet with unknown IP version") } - dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len] - peer = device.allowedips.LookupIPv4(dst) - case ipv6.Version: - if len(elem.packet) < ipv6.HeaderLen { + if peer == nil { continue } - dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len] - peer = device.allowedips.LookupIPv6(dst) - - default: - logDebug.Println("Received packet with unknown IP version") + elemsForPeer, ok := elemsByPeer[peer] + if !ok { + elemsForPeer = device.GetOutboundElementsContainer() + elemsByPeer[peer] = elemsForPeer + } + elemsForPeer.elems = append(elemsForPeer.elems, elem) + elems[i] = device.NewOutboundElement() + bufs[i] = elems[i].buffer[:] } - if peer == nil { - continue + for peer, elemsForPeer := range elemsByPeer { + if peer.isRunning.Load() { + peer.StagePackets(elemsForPeer) + peer.SendStagedPackets() + } else { + for _, elem := range elemsForPeer.elems { + device.PutMessageBuffer(elem.buffer) + device.PutOutboundElement(elem) + } + device.PutOutboundElementsContainer(elemsForPeer) + } + delete(elemsByPeer, peer) } - // insert into nonce/pre-handshake queue - - if peer.isRunning.Get() { - if peer.queue.packetInNonceQueueIsAwaitingKey.Get() { - peer.SendHandshakeInitiation(false) + if readErr != nil { + if errors.Is(readErr, tun.ErrTooManySegments) { + // TODO: record stat for this + // This will happen if MSS is surprisingly small (< 576) + // coincident with reasonably high throughput. + device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr) + continue } - addToNonceQueue(peer.queue.nonce, elem, device) - elem = nil + if !device.isClosed() { + if !errors.Is(readErr, os.ErrClosed) { + device.log.Errorf("Failed to read packet from TUN device: %v", readErr) + } + go device.Close() + } + return } } } -func (peer *Peer) FlushNonceQueue() { - select { - case peer.signals.flushNonceQueue <- struct{}{}: - default: - } -} - -/* Queues packets when there is no handshake. - * Then assigns nonces to packets sequentially - * and creates "work" structs for workers - * - * Obs. A single instance per peer - */ -func (peer *Peer) RoutineNonce() { - var keypair *Keypair - - device := peer.device - logDebug := device.log.Debug - - flush := func() { - for { - select { - case elem := <-peer.queue.nonce: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - default: - return +func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { + for { + select { + case peer.queue.staged <- elems: + return + default: + } + select { + case tooOld := <-peer.queue.staged: + for _, elem := range tooOld.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) } + peer.device.PutOutboundElementsContainer(tooOld) + default: } } +} - defer func() { - flush() - logDebug.Println(peer, "- Routine: nonce worker - stopped") - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - peer.routines.stopping.Done() - }() +func (peer *Peer) SendStagedPackets() { +top: + if len(peer.queue.staged) == 0 || !peer.device.isUp() { + return + } - peer.routines.starting.Done() - logDebug.Println(peer, "- Routine: nonce worker - started") + keypair := peer.keypairs.Current() + if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime { + peer.SendHandshakeInitiation(false) + return + } for { - NextPacket: - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - + var elemsContainerOOO *QueueOutboundElementsContainer select { - case <-peer.routines.stop: - return - - case <-peer.signals.flushNonceQueue: - flush() - goto NextPacket - - case elem, ok := <-peer.queue.nonce: - - if !ok { - return - } - - // make sure to always pick the newest key - - for { - - // check validity of newest key pair - - keypair = peer.keypairs.Current() - if keypair != nil && keypair.sendNonce < RejectAfterMessages { - if time.Since(keypair.created) < RejectAfterTime { - break + case elemsContainer := <-peer.queue.staged: + i := 0 + for _, elem := range elemsContainer.elems { + elem.peer = peer + elem.nonce = keypair.sendNonce.Add(1) - 1 + if elem.nonce >= RejectAfterMessages { + keypair.sendNonce.Store(RejectAfterMessages) + if elemsContainerOOO == nil { + elemsContainerOOO = peer.device.GetOutboundElementsContainer() } + elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem) + continue + } else { + elemsContainer.elems[i] = elem + i++ } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(true) - - // no suitable key pair, request for new handshake - - select { - case <-peer.signals.newKeypairArrived: - default: - } - - peer.SendHandshakeInitiation(false) - // wait for key to be established - - logDebug.Println(peer, "- Awaiting keypair") + elem.keypair = keypair + } + elemsContainer.Lock() + elemsContainer.elems = elemsContainer.elems[:i] - select { - case <-peer.signals.newKeypairArrived: - logDebug.Println(peer, "- Obtained awaited keypair") + if elemsContainerOOO != nil { + peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans + } - case <-peer.signals.flushNonceQueue: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - flush() - goto NextPacket + if len(elemsContainer.elems) == 0 { + peer.device.PutOutboundElementsContainer(elemsContainer) + goto top + } - case <-peer.routines.stop: - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - return + // add to parallel and sequential queue + if peer.isRunning.Load() { + peer.queue.outbound.c <- elemsContainer + peer.device.queue.encryption.c <- elemsContainer + } else { + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) } + peer.device.PutOutboundElementsContainer(elemsContainer) } - peer.queue.packetInNonceQueueIsAwaitingKey.Set(false) - - // populate work element - elem.peer = peer - elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1 - - // double check in case of race condition added by future code - - if elem.nonce >= RejectAfterMessages { - atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - goto NextPacket + if elemsContainerOOO != nil { + goto top } + default: + return + } + } +} - elem.keypair = keypair - elem.dropped = AtomicFalse - elem.Lock() - - // add to parallel and sequential queue - addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem) +func (peer *Peer) FlushStagedPackets() { + for { + select { + case elemsContainer := <-peer.queue.staged: + for _, elem := range elemsContainer.elems { + peer.device.PutMessageBuffer(elem.buffer) + peer.device.PutOutboundElement(elem) + } + peer.device.PutOutboundElementsContainer(elemsContainer) + default: + return } } } +func calculatePaddingSize(packetSize, mtu int) int { + lastUnit := packetSize + if mtu == 0 { + return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit + } + if lastUnit > mtu { + lastUnit %= mtu + } + paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) + if paddedSize > mtu { + paddedSize = mtu + } + return paddedSize - lastUnit +} + /* Encrypts the elements in the queue * and marks them for sequential consumption (by releasing the mutex) * * Obs. One instance per core */ -func (device *Device) RoutineEncryption() { - +func (device *Device) RoutineEncryption(id int) { + var paddingZeros [PaddingMultiple]byte var nonce [chacha20poly1305.NonceSize]byte - logDebug := device.log.Debug - - defer func() { - for { - select { - case elem, ok := <-device.queue.encryption: - if ok && !elem.IsDropped() { - elem.Drop() - device.PutMessageBuffer(elem.buffer) - elem.Unlock() - } - default: - goto out - } - } - out: - logDebug.Println("Routine: encryption worker - stopped") - device.state.stopping.Done() - }() - - logDebug.Println("Routine: encryption worker - started") - device.state.starting.Done() - - for { - - // fetch next element - - select { - case <-device.signals.stop: - return - - case elem, ok := <-device.queue.encryption: - - if !ok { - return - } - - // check if dropped - - if elem.IsDropped() { - continue - } + defer device.log.Verbosef("Routine: encryption worker %d - stopped", id) + device.log.Verbosef("Routine: encryption worker %d - started", id) + for elemsContainer := range device.queue.encryption.c { + for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[:MessageTransportHeaderSize] fieldType := header[0:4] @@ -513,24 +462,8 @@ func (device *Device) RoutineEncryption() { binary.LittleEndian.PutUint64(fieldNonce, elem.nonce) // pad content to multiple of 16 - - mtu := int(atomic.LoadInt32(&device.tun.mtu)) - var paddedSize int - if mtu == 0 { - paddedSize = (len(elem.packet) + PaddingMultiple - 1) & ^(PaddingMultiple - 1) - } else { - lastUnit := len(elem.packet) - if lastUnit > mtu { - lastUnit %= mtu - } - paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) - if paddedSize > mtu { - paddedSize = mtu - } - } - for i := len(elem.packet); i < paddedSize; i++ { - elem.packet = append(elem.packet, 0) - } + paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load())) + elem.packet = append(elem.packet, paddingZeros[:paddingSize]...) // encrypt content and release to consumer @@ -541,82 +474,73 @@ func (device *Device) RoutineEncryption() { elem.packet, nil, ) - elem.Unlock() } + elemsContainer.Unlock() } } -/* Sequentially reads packets from queue and sends to endpoint - * - * Obs. Single instance per peer. - * The routine terminates then the outbound queue is closed. - */ -func (peer *Peer) RoutineSequentialSender() { - +func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { device := peer.device - - logDebug := device.log.Debug - logError := device.log.Error - defer func() { - for { - select { - case elem, ok := <-peer.queue.outbound: - if ok { - if !elem.IsDropped() { - device.PutMessageBuffer(elem.buffer) - elem.Drop() - } - device.PutOutboundElement(elem) - } - default: - goto out - } - } - out: - logDebug.Println(peer, "- Routine: sequential sender - stopped") - peer.routines.stopping.Done() + defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer) + peer.stopping.Done() }() + device.log.Verbosef("%v - Routine: sequential sender - started", peer) - logDebug.Println(peer, "- Routine: sequential sender - started") - - peer.routines.starting.Done() + bufs := make([][]byte, 0, maxBatchSize) - for { - select { - - case <-peer.routines.stop: + for elemsContainer := range peer.queue.outbound.c { + bufs = bufs[:0] + if elemsContainer == nil { return - - case elem, ok := <-peer.queue.outbound: - - if !ok { - return - } - - elem.Lock() - if elem.IsDropped() { + } + if !peer.isRunning.Load() { + // peer has been stopped; return re-usable elems to the shared pool. + // This is an optimization only. It is possible for the peer to be stopped + // immediately after this check, in which case, elem will get processed. + // The timers and SendBuffers code are resilient to a few stragglers. + // TODO: rework peer shutdown order to ensure + // that we never accidentally keep timers alive longer than necessary. + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { + device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) - continue } - - peer.timersAnyAuthenticatedPacketTraversal() - peer.timersAnyAuthenticatedPacketSent() - - // send message and return buffer to pool - - err := peer.SendBuffer(elem.packet) + continue + } + dataSent := false + elemsContainer.Lock() + for _, elem := range elemsContainer.elems { if len(elem.packet) != MessageKeepaliveSize { - peer.timersDataSent() + dataSent = true } + bufs = append(bufs, elem.packet) + } + + peer.timersAnyAuthenticatedPacketTraversal() + peer.timersAnyAuthenticatedPacketSent() + + err := peer.SendBuffers(bufs) + if dataSent { + peer.timersDataSent() + } + for _, elem := range elemsContainer.elems { device.PutMessageBuffer(elem.buffer) device.PutOutboundElement(elem) - if err != nil { - logError.Println(peer, "- Failed to send data packet", err) - continue + } + device.PutOutboundElementsContainer(elemsContainer) + if err != nil { + var errGSO conn.ErrUDPGSODisabled + if errors.As(err, &errGSO) { + device.log.Verbosef(err.Error()) + err = errGSO.RetryErr } - - peer.keepKeyFreshSending() } + if err != nil { + device.log.Errorf("%v - Failed to send data packets: %v", peer, err) + continue + } + + peer.keepKeyFreshSending() } } diff --git a/device/sticky_default.go b/device/sticky_default.go index 1cc52f6..1038256 100644 --- a/device/sticky_default.go +++ b/device/sticky_default.go @@ -1,4 +1,4 @@ -// +build !linux +//go:build !linux package device diff --git a/device/sticky_linux.go b/device/sticky_linux.go index f9522c2..6057ff1 100644 --- a/device/sticky_linux.go +++ b/device/sticky_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This implements userspace semantics of "sticky sockets", modeled after * WireGuard's kernelspace implementation. This is more or less a straight port @@ -19,11 +19,19 @@ import ( "unsafe" "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) { + if !conn.StdNetSupportsStickySockets { + return nil, nil + } + if _, ok := bind.(*conn.StdNetBind); !ok { + return nil, nil + } + netlinkSock, err := createNetlinkRouteSocket() if err != nil { return nil, err @@ -47,6 +55,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl var reqPeer map[uint32]peerEndpointPtr var reqPeerLock sync.Mutex + defer netlinkCancel.Close() defer unix.Close(netlinkSock) for msg := make([]byte, 1<<16); ; { @@ -101,17 +110,17 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl if !ok { break } - pePtr.peer.Lock() - if &pePtr.peer.endpoint != pePtr.endpoint { - pePtr.peer.Unlock() + pePtr.peer.endpoint.Lock() + if &pePtr.peer.endpoint.val != pePtr.endpoint { + pePtr.peer.endpoint.Unlock() break } - if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx { - pePtr.peer.Unlock() + if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx { + pePtr.peer.endpoint.Unlock() break } - pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc() - pePtr.peer.Unlock() + pePtr.peer.endpoint.clearSrcOnTx = true + pePtr.peer.endpoint.Unlock() } attr = attr[attrhdr.Len:] } @@ -125,18 +134,18 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl device.peers.RLock() i := uint32(1) for _, peer := range device.peers.keyMap { - peer.RLock() - if peer.endpoint == nil { - peer.RUnlock() + peer.endpoint.Lock() + if peer.endpoint.val == nil { + peer.endpoint.Unlock() continue } - nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint) + nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint) if nativeEP == nil { - peer.RUnlock() + peer.endpoint.Unlock() continue } - if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 { - peer.RUnlock() + if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 { + peer.endpoint.Unlock() break } nlmsg := struct { @@ -163,26 +172,26 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl Len: 8, Type: unix.RTA_DST, }, - nativeEP.Dst4().Addr, + nativeEP.DstIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_SRC, }, - nativeEP.Src4().Src, + nativeEP.SrcIP().As4(), unix.RtAttr{ Len: 8, Type: unix.RTA_MARK, }, - uint32(bind.LastMark()), + device.net.fwmark, } nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) reqPeerLock.Lock() reqPeer[i] = peerEndpointPtr{ peer: peer, - endpoint: &peer.endpoint, + endpoint: &peer.endpoint.val, } reqPeerLock.Unlock() - peer.RUnlock() + peer.endpoint.Unlock() i++ _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) if err != nil { @@ -198,13 +207,13 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl } func createNetlinkRouteSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } saddr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, - Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), + Groups: unix.RTMGRP_IPV4_ROUTE, } err = unix.Bind(sock, saddr) if err != nil { diff --git a/device/timers.go b/device/timers.go index 18ee736..d4a4ed4 100644 --- a/device/timers.go +++ b/device/timers.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. * * This is based heavily on timers.c from the kernel implementation. */ @@ -8,16 +8,16 @@ package device import ( - "math/rand" "sync" - "sync/atomic" "time" + _ "unsafe" ) -/* This Timer structure and related functions should roughly copy the interface of - * the Linux kernel's struct timer_list. - */ +//go:linkname fastrandn runtime.fastrandn +func fastrandn(n uint32) uint32 +// A Timer manages time-based aspects of the WireGuard protocol. +// Timer roughly copies the interface of the Linux kernel's struct timer_list. type Timer struct { *time.Timer modifyingLock sync.RWMutex @@ -29,18 +29,17 @@ func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer { timer := &Timer{} timer.Timer = time.AfterFunc(time.Hour, func() { timer.runningLock.Lock() + defer timer.runningLock.Unlock() timer.modifyingLock.Lock() if !timer.isPending { timer.modifyingLock.Unlock() - timer.runningLock.Unlock() return } timer.isPending = false timer.modifyingLock.Unlock() expirationFunction(peer) - timer.runningLock.Unlock() }) timer.Stop() return timer @@ -74,12 +73,12 @@ func (timer *Timer) IsPending() bool { } func (peer *Peer) timersActive() bool { - return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 + return peer.isRunning.Load() && peer.device != nil && peer.device.isUp() } func expiredRetransmitHandshake(peer *Peer) { - if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes { - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) + if peer.timers.handshakeAttempts.Load() > MaxTimerHandshakes { + peer.device.log.Verbosef("%s - Handshake did not complete after %d attempts, giving up", peer, MaxTimerHandshakes+2) if peer.timersActive() { peer.timers.sendKeepalive.Del() @@ -88,7 +87,7 @@ func expiredRetransmitHandshake(peer *Peer) { /* We drop all packets without a keypair and don't try again, * if we try unsuccessfully for too long to make a handshake. */ - peer.FlushNonceQueue() + peer.FlushStagedPackets() /* We set a timer for destroying any residue that might be left * of a partial exchange. @@ -97,15 +96,11 @@ func expiredRetransmitHandshake(peer *Peer) { peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) } } else { - atomic.AddUint32(&peer.timers.handshakeAttempts, 1) - peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1) + peer.timers.handshakeAttempts.Add(1) + peer.device.log.Verbosef("%s - Handshake did not complete after %d seconds, retrying (try %d)", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts.Load()+1) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(true) } @@ -113,8 +108,8 @@ func expiredRetransmitHandshake(peer *Peer) { func expiredSendKeepalive(peer *Peer) { peer.SendKeepalive() - if peer.timers.needAnotherKeepalive.Get() { - peer.timers.needAnotherKeepalive.Set(false) + if peer.timers.needAnotherKeepalive.Load() { + peer.timers.needAnotherKeepalive.Store(false) if peer.timersActive() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } @@ -122,24 +117,19 @@ func expiredSendKeepalive(peer *Peer) { } func expiredNewHandshake(peer *Peer) { - peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) + peer.device.log.Verbosef("%s - Retrying handshake because we stopped hearing back after %d seconds", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds())) /* We clear the endpoint address src address, in case this is the cause of trouble. */ - peer.Lock() - if peer.endpoint != nil { - peer.endpoint.ClearSrc() - } - peer.Unlock() + peer.markEndpointSrcForClearing() peer.SendHandshakeInitiation(false) - } func expiredZeroKeyMaterial(peer *Peer) { - peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) + peer.device.log.Verbosef("%s - Removing all keys, since we haven't received a new one in %d seconds", peer, int((RejectAfterTime * 3).Seconds())) peer.ZeroAndFlushAll() } func expiredPersistentKeepalive(peer *Peer) { - if peer.persistentKeepaliveInterval > 0 { + if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } @@ -147,7 +137,7 @@ func expiredPersistentKeepalive(peer *Peer) { /* Should be called after an authenticated data packet is sent. */ func (peer *Peer) timersDataSent() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() { - peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } @@ -157,7 +147,7 @@ func (peer *Peer) timersDataReceived() { if !peer.timers.sendKeepalive.IsPending() { peer.timers.sendKeepalive.Mod(KeepaliveTimeout) } else { - peer.timers.needAnotherKeepalive.Set(true) + peer.timers.needAnotherKeepalive.Store(true) } } } @@ -179,7 +169,7 @@ func (peer *Peer) timersAnyAuthenticatedPacketReceived() { /* Should be called after a handshake initiation message is sent. */ func (peer *Peer) timersHandshakeInitiated() { if peer.timersActive() { - peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs))) + peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(fastrandn(RekeyTimeoutJitterMaxMs))) } } @@ -188,9 +178,9 @@ func (peer *Peer) timersHandshakeComplete() { if peer.timersActive() { peer.timers.retransmitHandshake.Del() } - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.lastHandshakeNano.Store(time.Now().UnixNano()) } /* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */ @@ -202,8 +192,9 @@ func (peer *Peer) timersSessionDerived() { /* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */ func (peer *Peer) timersAnyAuthenticatedPacketTraversal() { - if peer.persistentKeepaliveInterval > 0 && peer.timersActive() { - peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second) + keepalive := peer.persistentKeepaliveInterval.Load() + if keepalive > 0 && peer.timersActive() { + peer.timers.persistentKeepalive.Mod(time.Duration(keepalive) * time.Second) } } @@ -213,9 +204,12 @@ func (peer *Peer) timersInit() { peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) - atomic.StoreUint32(&peer.timers.handshakeAttempts, 0) - peer.timers.sentLastMinuteHandshake.Set(false) - peer.timers.needAnotherKeepalive.Set(false) +} + +func (peer *Peer) timersStart() { + peer.timers.handshakeAttempts.Store(0) + peer.timers.sentLastMinuteHandshake.Store(false) + peer.timers.needAnotherKeepalive.Store(false) } func (peer *Peer) timersStop() { diff --git a/device/tun.go b/device/tun.go index 0a3fc79..2a2ace9 100644 --- a/device/tun.go +++ b/device/tun.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( - "sync/atomic" + "fmt" "golang.zx2c4.com/wireguard/tun" ) @@ -14,43 +14,40 @@ import ( const DefaultMTU = 1420 func (device *Device) RoutineTUNEventReader() { - setUp := false - logDebug := device.log.Debug - logInfo := device.log.Info - logError := device.log.Error - - logDebug.Println("Routine: event worker - started") - device.state.starting.Done() + device.log.Verbosef("Routine: event worker - started") for event := range device.tun.device.Events() { if event&tun.EventMTUUpdate != 0 { mtu, err := device.tun.device.MTU() - old := atomic.LoadInt32(&device.tun.mtu) if err != nil { - logError.Println("Failed to load updated MTU of device:", err) - } else if int(old) != mtu { - if mtu+MessageTransportSize > MaxMessageSize { - logInfo.Println("MTU updated:", mtu, "(too large)") - } else { - logInfo.Println("MTU updated:", mtu) - } - atomic.StoreInt32(&device.tun.mtu, int32(mtu)) + device.log.Errorf("Failed to load updated MTU of device: %v", err) + continue + } + if mtu < 0 { + device.log.Errorf("MTU not updated to negative value: %v", mtu) + continue + } + var tooLarge string + if mtu > MaxContentSize { + tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize) + mtu = MaxContentSize + } + old := device.tun.mtu.Swap(int32(mtu)) + if int(old) != mtu { + device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge) } } - if event&tun.EventUp != 0 && !setUp { - logInfo.Println("Interface set up") - setUp = true + if event&tun.EventUp != 0 { + device.log.Verbosef("Interface up requested") device.Up() } - if event&tun.EventDown != 0 && setUp { - logInfo.Println("Interface set down") - setUp = false + if event&tun.EventDown != 0 { + device.log.Verbosef("Interface down requested") device.Down() } } - logDebug.Println("Routine: event worker - stopped") - device.state.stopping.Done() + device.log.Verbosef("Routine: event worker - stopped") } diff --git a/device/tun_test.go b/device/tun_test.go deleted file mode 100644 index 5614771..0000000 --- a/device/tun_test.go +++ /dev/null @@ -1,56 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package device - -import ( - "errors" - "os" - - "golang.zx2c4.com/wireguard/tun" -) - -// newDummyTUN creates a dummy TUN device with the specified name. -func newDummyTUN(name string) tun.Device { - return &dummyTUN{ - name: name, - packets: make(chan []byte, 100), - events: make(chan tun.Event, 10), - } -} - -// A dummyTUN is a tun.Device which is used in unit tests. -type dummyTUN struct { - name string - mtu int - packets chan []byte - events chan tun.Event -} - -func (d *dummyTUN) Events() chan tun.Event { return d.events } -func (*dummyTUN) File() *os.File { return nil } -func (*dummyTUN) Flush() error { return nil } -func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil } -func (d *dummyTUN) Name() (string, error) { return d.name, nil } - -func (d *dummyTUN) Close() error { - close(d.events) - close(d.packets) - return nil -} - -func (d *dummyTUN) Read(b []byte, offset int) (int, error) { - buf, ok := <-d.packets - if !ok { - return 0, errors.New("device closed") - } - copy(b[offset:], buf) - return len(buf), nil -} - -func (d *dummyTUN) Write(b []byte, offset int) (int, error) { - d.packets <- b[offset:] - return len(b), nil -} diff --git a/device/uapi.go b/device/uapi.go index b266f4c..d81dae3 100644 --- a/device/uapi.go +++ b/device/uapi.go @@ -1,46 +1,77 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package device import ( "bufio" + "bytes" "errors" "fmt" "io" "net" + "net/netip" "strconv" "strings" - "sync/atomic" + "sync" "time" - "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/wgcfg" ) type IPCError struct { - int64 + code int64 // error code + err error // underlying/wrapped error } func (s IPCError) Error() string { - return fmt.Sprintf("IPC error: %d", s.int64) + return fmt.Sprintf("IPC error %d: %v", s.code, s.err) +} + +func (s IPCError) Unwrap() error { + return s.err } func (s IPCError) ErrorCode() int64 { - return s.int64 + return s.code +} + +func ipcErrorf(code int64, msg string, args ...any) *IPCError { + return &IPCError{code: code, err: fmt.Errorf(msg, args...)} } -func (device *Device) IpcGetOperation(socket *bufio.Writer) error { - lines := make([]string, 0, 100) - send := func(line string) { - lines = append(lines, line) +var byteBufferPool = &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, +} + +// IpcGetOperation implements the WireGuard configuration protocol "get" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcGetOperation(w io.Writer) error { + device.ipcMutex.RLock() + defer device.ipcMutex.RUnlock() + + buf := byteBufferPool.Get().(*bytes.Buffer) + buf.Reset() + defer byteBufferPool.Put(buf) + sendf := func(format string, args ...any) { + fmt.Fprintf(buf, format, args...) + buf.WriteByte('\n') + } + keyf := func(prefix string, key *[32]byte) { + buf.Grow(len(key)*2 + 2 + len(prefix)) + buf.WriteString(prefix) + buf.WriteByte('=') + const hex = "0123456789abcdef" + for i := 0; i < len(key); i++ { + buf.WriteByte(hex[key[i]>>4]) + buf.WriteByte(hex[key[i]&0xf]) + } + buf.WriteByte('\n') } func() { - // lock required resources device.net.RLock() @@ -55,352 +86,326 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) error { // serialize device related values if !device.staticIdentity.privateKey.IsZero() { - send("private_key=" + device.staticIdentity.privateKey.HexString()) + keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey)) } if device.net.port != 0 { - send(fmt.Sprintf("listen_port=%d", device.net.port)) + sendf("listen_port=%d", device.net.port) } if device.net.fwmark != 0 { - send(fmt.Sprintf("fwmark=%d", device.net.fwmark)) + sendf("fwmark=%d", device.net.fwmark) } - // serialize each peer state - for _, peer := range device.peers.keyMap { - peer.RLock() - defer peer.RUnlock() - - send("public_key=" + peer.handshake.remoteStatic.HexString()) - send("preshared_key=" + peer.handshake.presharedKey.HexString()) - send("protocol_version=1") - if peer.endpoint != nil { - send("endpoint=" + peer.endpoint.DstToString()) + // Serialize peer state. + peer.handshake.mutex.RLock() + keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic)) + keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey)) + peer.handshake.mutex.RUnlock() + sendf("protocol_version=1") + peer.endpoint.Lock() + if peer.endpoint.val != nil { + sendf("endpoint=%s", peer.endpoint.val.DstToString()) } + peer.endpoint.Unlock() - nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano) + nano := peer.lastHandshakeNano.Load() secs := nano / time.Second.Nanoseconds() nano %= time.Second.Nanoseconds() - send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) - send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) - send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes))) - send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))) - send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) - - for _, ip := range device.allowedips.EntriesForPeer(peer) { - send("allowed_ip=" + ip.String()) - } + sendf("last_handshake_time_sec=%d", secs) + sendf("last_handshake_time_nsec=%d", nano) + sendf("tx_bytes=%d", peer.txBytes.Load()) + sendf("rx_bytes=%d", peer.rxBytes.Load()) + sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load()) + device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool { + sendf("allowed_ip=%s", prefix.String()) + return true + }) } }() // send lines (does not require resource locks) - - for _, line := range lines { - _, err := socket.WriteString(line + "\n") - if err != nil { - return &IPCError{ipc.IpcErrorIO} - } + if _, err := w.Write(buf.Bytes()); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err) } return nil } -func (device *Device) IpcSetOperation(socket *bufio.Reader) error { - scanner := bufio.NewScanner(socket) - logError := device.log.Error - logDebug := device.log.Debug +// IpcSetOperation implements the WireGuard configuration protocol "set" operation. +// See https://www.wireguard.com/xplatform/#configuration-protocol for details. +func (device *Device) IpcSetOperation(r io.Reader) (err error) { + device.ipcMutex.Lock() + defer device.ipcMutex.Unlock() - var peer *Peer + defer func() { + if err != nil { + device.log.Errorf("%v", err) + } + }() - dummy := false - createdNewPeer := false + peer := new(ipcSetPeer) deviceConfig := true + scanner := bufio.NewScanner(r) for scanner.Scan() { - - // parse line - line := scanner.Text() if line == "" { + // Blank line means terminate operation. + peer.handlePostConfig() return nil } - parts := strings.Split(line, "=") - if len(parts) != 2 { - return &IPCError{ipc.IpcErrorProtocol} + key, value, ok := strings.Cut(line, "=") + if !ok { + return ipcErrorf(ipc.IpcErrorProtocol, "failed to parse line %q", line) } - key := parts[0] - value := parts[1] - - /* device configuration */ - - if deviceConfig { - - switch key { - case "private_key": - sk, err := wgcfg.ParsePrivateHexKey(value) - if err != nil { - logError.Println("Failed to set private_key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Updating private key") - device.SetPrivateKey(sk) - - case "listen_port": - - // parse port number - - port, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to parse listen_port:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // update port and rebind - - logDebug.Println("UAPI: Updating listen port") - - device.net.Lock() - device.net.port = uint16(port) - device.net.Unlock() - - if err := device.BindUpdate(); err != nil { - logError.Println("Failed to set listen_port:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - case "fwmark": - - // parse fwmark field - - fwmark, err := func() (uint32, error) { - if value == "" { - return 0, nil - } - mark, err := strconv.ParseUint(value, 10, 32) - return uint32(mark), err - }() - - if err != nil { - logError.Println("Invalid fwmark", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - logDebug.Println("UAPI: Updating fwmark") - - if err := device.BindSetMark(uint32(fwmark)); err != nil { - logError.Println("Failed to update fwmark:", err) - return &IPCError{ipc.IpcErrorPortInUse} - } - - case "public_key": - // switch to peer configuration - logDebug.Println("UAPI: Transition to peer configuration") + if key == "public_key" { + if deviceConfig { deviceConfig = false - - case "replace_peers": - if value != "true" { - logError.Println("Failed to set replace_peers, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - logDebug.Println("UAPI: Removing all peers") - device.RemoveAllPeers() - - default: - logError.Println("Invalid UAPI device key:", key) - return &IPCError{ipc.IpcErrorInvalid} } + peer.handlePostConfig() + // Load/create the peer we are now configuring. + err := device.handlePublicKeyLine(peer, value) + if err != nil { + return err + } + continue } - /* peer configuration */ - - if !deviceConfig { - - switch key { - - case "public_key": - publicKey, err := wgcfg.ParseHexKey(value) - if err != nil { - logError.Println("Failed to get peer by public key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - // ignore peer with public key of device - - device.staticIdentity.RLock() - dummy = device.staticIdentity.publicKey.Equal(publicKey) - device.staticIdentity.RUnlock() + var err error + if deviceConfig { + err = device.handleDeviceLine(key, value) + } else { + err = device.handlePeerLine(peer, key, value) + } + if err != nil { + return err + } + } + peer.handlePostConfig() - if dummy { - peer = &Peer{} - } else { - peer = device.LookupPeer(publicKey) - } + if err := scanner.Err(); err != nil { + return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err) + } + return nil +} - createdNewPeer = peer == nil - if createdNewPeer { - peer, err = device.NewPeer(publicKey) - if err != nil { - logError.Println("Failed to create new peer:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - if peer == nil { - dummy = true - peer = &Peer{} - } else { - logDebug.Println(peer, "- UAPI: Created") - } - } +func (device *Device) handleDeviceLine(key, value string) error { + switch key { + case "private_key": + var sk NoisePrivateKey + err := sk.FromMaybeZeroHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err) + } + device.log.Verbosef("UAPI: Updating private key") + device.SetPrivateKey(sk) - case "update_only": - - // allow disabling of creation + case "listen_port": + port, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set update only, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if createdNewPeer && !dummy { - device.RemovePeer(peer.handshake.remoteStatic) - peer = &Peer{} - dummy = true - } + // update port and rebind + device.log.Verbosef("UAPI: Updating listen port") - case "remove": + device.net.Lock() + device.net.port = uint16(port) + device.net.Unlock() - // remove currently selected peer from device + if err := device.BindUpdate(); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err) + } - if value != "true" { - logError.Println("Failed to set remove, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } - if !dummy { - logDebug.Println(peer, "- UAPI: Removing") - device.RemovePeer(peer.handshake.remoteStatic) - } - peer = &Peer{} - dummy = true + case "fwmark": + mark, err := strconv.ParseUint(value, 10, 32) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err) + } - case "preshared_key": + device.log.Verbosef("UAPI: Updating fwmark") + if err := device.BindSetMark(uint32(mark)); err != nil { + return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err) + } - // update PSK + case "replace_peers": + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value) + } + device.log.Verbosef("UAPI: Removing all peers") + device.RemoveAllPeers() - logDebug.Println(peer, "- UAPI: Updating preshared key") + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key) + } - peer.handshake.mutex.Lock() - key, err := wgcfg.ParseSymmetricHexKey(value) - peer.handshake.presharedKey = key - peer.handshake.mutex.Unlock() + return nil +} - if err != nil { - logError.Println("Failed to set preshared key:", err) - return &IPCError{ipc.IpcErrorInvalid} - } +// An ipcSetPeer is the current state of an IPC set operation on a peer. +type ipcSetPeer struct { + *Peer // Peer is the current peer being operated on + dummy bool // dummy reports whether this peer is a temporary, placeholder peer + created bool // new reports whether this is a newly created peer + pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on +} - case "endpoint": +func (peer *ipcSetPeer) handlePostConfig() { + if peer.Peer == nil || peer.dummy { + return + } + if peer.created { + peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil + } + if peer.device.isUp() { + peer.Start() + if peer.pkaOn { + peer.SendKeepalive() + } + peer.SendStagedPackets() + } +} - // set endpoint destination +func (device *Device) handlePublicKeyLine(peer *ipcSetPeer, value string) error { + // Load/create the peer we are configuring. + var publicKey NoisePublicKey + err := publicKey.FromHex(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err) + } - logDebug.Println(peer, "- UAPI: Updating endpoint") + // Ignore peer with the same public key as this device. + device.staticIdentity.RLock() + peer.dummy = device.staticIdentity.publicKey.Equals(publicKey) + device.staticIdentity.RUnlock() - err := func() error { - peer.Lock() - defer peer.Unlock() - endpoint, err := conn.CreateEndpoint(value) - if err != nil { - return err - } - peer.endpoint = endpoint - return nil - }() + if peer.dummy { + peer.Peer = &Peer{} + } else { + peer.Peer = device.LookupPeer(publicKey) + } - if err != nil { - logError.Println("Failed to set endpoint:", err, ":", value) - return &IPCError{ipc.IpcErrorInvalid} - } - - case "persistent_keepalive_interval": - - // update persistent keepalive interval - - logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval") - - secs, err := strconv.ParseUint(value, 10, 16) - if err != nil { - logError.Println("Failed to set persistent keepalive interval:", err) - return &IPCError{ipc.IpcErrorInvalid} - } - - old := peer.persistentKeepaliveInterval - peer.persistentKeepaliveInterval = uint16(secs) - - // send immediate keepalive if we're turning it on and before it wasn't on - - if old == 0 && secs != 0 { - if err != nil { - logError.Println("Failed to get tun device status:", err) - return &IPCError{ipc.IpcErrorIO} - } - if device.isUp.Get() && !dummy { - peer.SendKeepalive() - } - } + peer.created = peer.Peer == nil + if peer.created { + peer.Peer, err = device.NewPeer(publicKey) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err) + } + device.log.Verbosef("%v - UAPI: Created", peer.Peer) + } + return nil +} - case "replace_allowed_ips": +func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error { + switch key { + case "update_only": + // allow disabling of creation + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value) + } + if peer.created && !peer.dummy { + device.RemovePeer(peer.handshake.remoteStatic) + peer.Peer = &Peer{} + peer.dummy = true + } - logDebug.Println(peer, "- UAPI: Removing all allowedips") + case "remove": + // remove currently selected peer from device + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value) + } + if !peer.dummy { + device.log.Verbosef("%v - UAPI: Removing", peer.Peer) + device.RemovePeer(peer.handshake.remoteStatic) + } + peer.Peer = &Peer{} + peer.dummy = true - if value != "true" { - logError.Println("Failed to replace allowedips, invalid value:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "preshared_key": + device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer) - if dummy { - continue - } + peer.handshake.mutex.Lock() + err := peer.handshake.presharedKey.FromHex(value) + peer.handshake.mutex.Unlock() - device.allowedips.RemoveByPeer(peer) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err) + } - case "allowed_ip": + case "endpoint": + device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer) + endpoint, err := device.net.bind.ParseEndpoint(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err) + } + peer.endpoint.Lock() + defer peer.endpoint.Unlock() + peer.endpoint.val = endpoint - logDebug.Println(peer, "- UAPI: Adding allowedip") + case "persistent_keepalive_interval": + device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer) - _, network, err := net.ParseCIDR(value) - if err != nil { - logError.Println("Failed to set allowed ip:", err) - return &IPCError{ipc.IpcErrorInvalid} - } + secs, err := strconv.ParseUint(value, 10, 16) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err) + } - if dummy { - continue - } + old := peer.persistentKeepaliveInterval.Swap(uint32(secs)) - ones, _ := network.Mask.Size() - device.allowedips.Insert(network.IP, uint(ones), peer) + // Send immediate keepalive if we're turning it on and before it wasn't on. + peer.pkaOn = old == 0 && secs != 0 - case "protocol_version": + case "replace_allowed_ips": + device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer) + if value != "true" { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value) + } + if peer.dummy { + return nil + } + device.allowedips.RemoveByPeer(peer.Peer) - if value != "1" { - logError.Println("Invalid protocol version:", value) - return &IPCError{ipc.IpcErrorInvalid} - } + case "allowed_ip": + device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer) + prefix, err := netip.ParsePrefix(value) + if err != nil { + return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err) + } + if peer.dummy { + return nil + } + device.allowedips.Insert(prefix, peer.Peer) - default: - logError.Println("Invalid UAPI peer key:", key) - return &IPCError{ipc.IpcErrorInvalid} - } + case "protocol_version": + if value != "1" { + return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value) } + + default: + return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key) } return nil } -func (device *Device) IpcHandle(socket net.Conn) { +func (device *Device) IpcGet() (string, error) { + buf := new(strings.Builder) + if err := device.IpcGetOperation(buf); err != nil { + return "", err + } + return buf.String(), nil +} - // create buffered read/writer +func (device *Device) IpcSet(uapiConf string) error { + return device.IpcSetOperation(strings.NewReader(uapiConf)) +} +func (device *Device) IpcHandle(socket net.Conn) { defer socket.Close() buffered := func(s io.ReadWriter) *bufio.ReadWriter { @@ -409,45 +414,44 @@ func (device *Device) IpcHandle(socket net.Conn) { return bufio.NewReadWriter(reader, writer) }(socket) - defer buffered.Flush() - - op, err := buffered.ReadString('\n') - if err != nil { - return - } - - // handle operation - - var status *IPCError - - switch op { - case "set=1\n": - err = device.IpcSetOperation(buffered.Reader) - if !errors.As(err, &status) { - // should never happen - device.log.Error.Println("Invalid UAPI error:", err) - status = &IPCError{1} + for { + op, err := buffered.ReadString('\n') + if err != nil { + return } - case "get=1\n": - err = device.IpcGetOperation(buffered.Writer) - if !errors.As(err, &status) { - // should never happen - device.log.Error.Println("Invalid UAPI error:", err) - status = &IPCError{1} + // handle operation + switch op { + case "set=1\n": + err = device.IpcSetOperation(buffered.Reader) + case "get=1\n": + var nextByte byte + nextByte, err = buffered.ReadByte() + if err != nil { + return + } + if nextByte != '\n' { + err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte) + break + } + err = device.IpcGetOperation(buffered.Writer) + default: + device.log.Errorf("invalid UAPI operation: %v", op) + return } - default: - device.log.Error.Println("Invalid UAPI operation:", op) - return - } - - // write status - - if status != nil { - device.log.Error.Println(status) - fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) - } else { - fmt.Fprintf(buffered, "errno=0\n\n") + // write status + var status *IPCError + if err != nil && !errors.As(err, &status) { + // shouldn't happen + status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err) + } + if status != nil { + device.log.Errorf("%v", status) + fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode()) + } else { + fmt.Fprintf(buffered, "errno=0\n\n") + } + buffered.Flush() } } diff --git a/device/version.go b/device/version.go deleted file mode 100644 index 0877595..0000000 --- a/device/version.go +++ /dev/null @@ -1,3 +0,0 @@ -package device - -const WireGuardGoVersion = "0.0.20200320" diff --git a/format_test.go b/format_test.go new file mode 100644 index 0000000..6f6cab7 --- /dev/null +++ b/format_test.go @@ -0,0 +1,51 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ +package main + +import ( + "bytes" + "go/format" + "io/fs" + "os" + "path/filepath" + "runtime" + "sync" + "testing" +) + +func TestFormatting(t *testing.T) { + var wg sync.WaitGroup + filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + t.Errorf("unable to walk %s: %v", path, err) + return nil + } + if d.IsDir() || filepath.Ext(path) != ".go" { + return nil + } + wg.Add(1) + go func(path string) { + defer wg.Done() + src, err := os.ReadFile(path) + if err != nil { + t.Errorf("unable to read %s: %v", path, err) + return + } + if runtime.GOOS == "windows" { + src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'}) + } + formatted, err := format.Source(src) + if err != nil { + t.Errorf("unable to format %s: %v", path, err) + return + } + if !bytes.Equal(src, formatted) { + t.Errorf("unformatted code: %s", path) + } + }(path) + return nil + }) + wg.Wait() +} @@ -1,10 +1,16 @@ module golang.zx2c4.com/wireguard -go 1.13 +go 1.20 require ( - golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc - golang.org/x/net v0.0.0-20191003171128-d98b1b443823 - golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 - golang.org/x/text v0.3.2 + golang.org/x/crypto v0.13.0 + golang.org/x/net v0.15.0 + golang.org/x/sys v0.12.0 + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 + gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 +) + +require ( + github.com/google/btree v1.0.1 // indirect + golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 // indirect ) @@ -1,14 +1,14 @@ -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4= -golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4= -golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So= -golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= -golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +github.com/google/btree v1.0.1 h1:gK4Kx5IaGY9CD5sPJ36FHiBJ6ZXl0kilRiiCj+jdYp4= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= +golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= +golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8 h1:vVKdlvoWBphwdxWKrFZEuM0kGgGLxUOYcY4U/2Vjg44= +golang.org/x/time v0.0.0-20220210224613-90d013bbcef8/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259 h1:TbRPT0HtzFP3Cno1zZo7yPzEEnfu8EjLfl6IU9VfqkQ= +gvisor.dev/gvisor v0.0.0-20230927004350-cbd86285d259/go.mod h1:AVgIgHMwK63XvmAzWG9vLQ41YnVHN0du0tEC46fI7yY= diff --git a/ipc/winpipe/file.go b/ipc/namedpipe/file.go index 29d02a7..ab1e13d 100644 --- a/ipc/winpipe/file.go +++ b/ipc/namedpipe/file.go @@ -1,63 +1,31 @@ -// +build windows +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ -package winpipe +//go:build windows + +package namedpipe import ( - "errors" "io" + "os" "runtime" "sync" "sync/atomic" "time" + "unsafe" "golang.org/x/sys/windows" ) -//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx -//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort -//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus -//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes -//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult - -type atomicBool int32 - -func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 } -func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) } -func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) } -func (b *atomicBool) swap(new bool) bool { - var newInt int32 - if new { - newInt = 1 - } - return atomic.SwapInt32((*int32)(b), newInt) == 1 -} - -const ( - cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1 - cFILE_SKIP_SET_EVENT_ON_HANDLE = 2 -) +type timeoutChan chan struct{} var ( - ErrFileClosed = errors.New("file has already been closed") - ErrTimeout = &timeoutError{} + ioInitOnce sync.Once + ioCompletionPort windows.Handle ) -type timeoutError struct{} - -func (e *timeoutError) Error() string { return "i/o timeout" } -func (e *timeoutError) Timeout() bool { return true } -func (e *timeoutError) Temporary() bool { return true } - -type timeoutChan chan struct{} - -var ioInitOnce sync.Once -var ioCompletionPort windows.Handle - // ioResult contains the result of an asynchronous IO operation type ioResult struct { bytes uint32 @@ -71,7 +39,7 @@ type ioOperation struct { } func initIo() { - h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff) + h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0) if err != nil { panic(err) } @@ -79,13 +47,13 @@ func initIo() { go ioCompletionProcessor(h) } -// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. +// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // It takes ownership of this handle and will close it if it is garbage collected. -type win32File struct { +type file struct { handle windows.Handle wg sync.WaitGroup wgLock sync.RWMutex - closing atomicBool + closing atomic.Bool socket bool readDeadline deadlineHandler writeDeadline deadlineHandler @@ -96,18 +64,18 @@ type deadlineHandler struct { channel timeoutChan channelLock sync.RWMutex timer *time.Timer - timedout atomicBool + timedout atomic.Bool } -// makeWin32File makes a new win32File from an existing file handle -func makeWin32File(h windows.Handle) (*win32File, error) { - f := &win32File{handle: h} +// makeFile makes a new file from an existing file handle +func makeFile(h windows.Handle) (*file, error) { + f := &file{handle: h} ioInitOnce.Do(initIo) - _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) + _, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0) if err != nil { return nil, err } - err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE) + err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE) if err != nil { return nil, err } @@ -116,18 +84,14 @@ func makeWin32File(h windows.Handle) (*win32File, error) { return f, nil } -func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) { - return makeWin32File(h) -} - // closeHandle closes the resources associated with a Win32 handle -func (f *win32File) closeHandle() { +func (f *file) closeHandle() { f.wgLock.Lock() // Atomically set that we are closing, releasing the resources only once. - if !f.closing.swap(true) { + if f.closing.Swap(true) == false { f.wgLock.Unlock() // cancel all IO and wait for it to complete - cancelIoEx(f.handle, nil) + windows.CancelIoEx(f.handle, nil) f.wg.Wait() // at this point, no new IO can start windows.Close(f.handle) @@ -137,19 +101,19 @@ func (f *win32File) closeHandle() { } } -// Close closes a win32File. -func (f *win32File) Close() error { +// Close closes a file. +func (f *file) Close() error { f.closeHandle() return nil } // prepareIo prepares for a new IO operation. // The caller must call f.wg.Done() when the IO is finished, prior to Close() returning. -func (f *win32File) prepareIo() (*ioOperation, error) { +func (f *file) prepareIo() (*ioOperation, error) { f.wgLock.RLock() - if f.closing.isSet() { + if f.closing.Load() { f.wgLock.RUnlock() - return nil, ErrFileClosed + return nil, os.ErrClosed } f.wg.Add(1) f.wgLock.RUnlock() @@ -164,7 +128,7 @@ func ioCompletionProcessor(h windows.Handle) { var bytes uint32 var key uintptr var op *ioOperation - err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE) + err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE) if op == nil { panic(err) } @@ -174,13 +138,13 @@ func ioCompletionProcessor(h windows.Handle) { // asyncIo processes the return value from ReadFile or WriteFile, blocking until // the operation has actually completed. -func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { +func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { if err != windows.ERROR_IO_PENDING { return int(bytes), err } - if f.closing.isSet() { - cancelIoEx(f.handle, &c.o) + if f.closing.Load() { + windows.CancelIoEx(f.handle, &c.o) } var timeout timeoutChan @@ -195,20 +159,20 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er case r = <-c.ch: err = r.err if err == windows.ERROR_OPERATION_ABORTED { - if f.closing.isSet() { - err = ErrFileClosed + if f.closing.Load() { + err = os.ErrClosed } } else if err != nil && f.socket { // err is from Win32. Query the overlapped structure to get the winsock error. var bytes, flags uint32 - err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) + err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags) } case <-timeout: - cancelIoEx(f.handle, &c.o) + windows.CancelIoEx(f.handle, &c.o) r = <-c.ch err = r.err if err == windows.ERROR_OPERATION_ABORTED { - err = ErrTimeout + err = os.ErrDeadlineExceeded } } @@ -220,15 +184,15 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er } // Read reads from a file handle. -func (f *win32File) Read(b []byte) (int, error) { +func (f *file) Read(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.readDeadline.timedout.isSet() { - return 0, ErrTimeout + if f.readDeadline.timedout.Load() { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -247,15 +211,15 @@ func (f *win32File) Read(b []byte) (int, error) { } // Write writes to a file handle. -func (f *win32File) Write(b []byte) (int, error) { +func (f *file) Write(b []byte) (int, error) { c, err := f.prepareIo() if err != nil { return 0, err } defer f.wg.Done() - if f.writeDeadline.timedout.isSet() { - return 0, ErrTimeout + if f.writeDeadline.timedout.Load() { + return 0, os.ErrDeadlineExceeded } var bytes uint32 @@ -265,19 +229,19 @@ func (f *win32File) Write(b []byte) (int, error) { return n, err } -func (f *win32File) SetReadDeadline(deadline time.Time) error { +func (f *file) SetReadDeadline(deadline time.Time) error { return f.readDeadline.set(deadline) } -func (f *win32File) SetWriteDeadline(deadline time.Time) error { +func (f *file) SetWriteDeadline(deadline time.Time) error { return f.writeDeadline.set(deadline) } -func (f *win32File) Flush() error { +func (f *file) Flush() error { return windows.FlushFileBuffers(f.handle) } -func (f *win32File) Fd() uintptr { +func (f *file) Fd() uintptr { return uintptr(f.handle) } @@ -291,7 +255,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } d.timer = nil } - d.timedout.setFalse() + d.timedout.Store(false) select { case <-d.channel: @@ -306,7 +270,7 @@ func (d *deadlineHandler) set(deadline time.Time) error { } timeoutIO := func() { - d.timedout.setTrue() + d.timedout.Store(true) close(d.channel) } diff --git a/ipc/namedpipe/namedpipe.go b/ipc/namedpipe/namedpipe.go new file mode 100644 index 0000000..ef3dea1 --- /dev/null +++ b/ipc/namedpipe/namedpipe.go @@ -0,0 +1,485 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes. +package namedpipe + +import ( + "context" + "io" + "net" + "os" + "runtime" + "sync/atomic" + "time" + "unsafe" + + "golang.org/x/sys/windows" +) + +type pipe struct { + *file + path string +} + +type messageBytePipe struct { + pipe + writeClosed atomic.Bool + readEOF bool +} + +type pipeAddress string + +func (f *pipe) LocalAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) RemoteAddr() net.Addr { + return pipeAddress(f.path) +} + +func (f *pipe) SetDeadline(t time.Time) error { + f.SetReadDeadline(t) + f.SetWriteDeadline(t) + return nil +} + +// CloseWrite closes the write side of a message pipe in byte mode. +func (f *messageBytePipe) CloseWrite() error { + if !f.writeClosed.CompareAndSwap(false, true) { + return io.ErrClosedPipe + } + err := f.file.Flush() + if err != nil { + f.writeClosed.Store(false) + return err + } + _, err = f.file.Write(nil) + if err != nil { + f.writeClosed.Store(false) + return err + } + return nil +} + +// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since +// they are used to implement CloseWrite. +func (f *messageBytePipe) Write(b []byte) (int, error) { + if f.writeClosed.Load() { + return 0, io.ErrClosedPipe + } + if len(b) == 0 { + return 0, nil + } + return f.file.Write(b) +} + +// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message +// mode pipe will return io.EOF, as will all subsequent reads. +func (f *messageBytePipe) Read(b []byte) (int, error) { + if f.readEOF { + return 0, io.EOF + } + n, err := f.file.Read(b) + if err == io.EOF { + // If this was the result of a zero-byte read, then + // it is possible that the read was due to a zero-size + // message. Since we are simulating CloseWrite with a + // zero-byte message, ensure that all future Read calls + // also return EOF. + f.readEOF = true + } else if err == windows.ERROR_MORE_DATA { + // ERROR_MORE_DATA indicates that the pipe's read mode is message mode + // and the message still has more bytes. Treat this as a success, since + // this package presents all named pipes as byte streams. + err = nil + } + return n, err +} + +func (f *pipe) Handle() windows.Handle { + return f.handle +} + +func (s pipeAddress) Network() string { + return "pipe" +} + +func (s pipeAddress) String() string { + return string(s) +} + +// tryDialPipe attempts to dial the specified pipe until cancellation or timeout. +func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { + for { + select { + case <-ctx.Done(): + return 0, ctx.Err() + default: + path16, err := windows.UTF16PtrFromString(*path) + if err != nil { + return 0, err + } + h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + return h, nil + } + if err != windows.ERROR_PIPE_BUSY { + return h, &os.PathError{Err: err, Op: "open", Path: *path} + } + // Wait 10 msec and try again. This is a rather simplistic + // view, as we always try each 10 milliseconds. + time.Sleep(10 * time.Millisecond) + } + } +} + +// DialConfig exposes various options for use in Dial and DialContext. +type DialConfig struct { + ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID. +} + +// DialTimeout connects to the specified named pipe by path, timing out if the +// connection takes longer than the specified duration. If timeout is zero, then +// we use a default timeout of 2 seconds. +func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + if timeout == 0 { + timeout = time.Second * 2 + } + absTimeout := time.Now().Add(timeout) + ctx, _ := context.WithDeadline(context.Background(), absTimeout) + conn, err := config.DialContext(ctx, path) + if err == context.DeadlineExceeded { + return nil, os.ErrDeadlineExceeded + } + return conn, err +} + +// DialContext attempts to connect to the specified named pipe by path. +func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) { + var err error + var h windows.Handle + h, err = tryDialPipe(ctx, &path) + if err != nil { + return nil, err + } + + if config.ExpectedOwner != nil { + sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) + if err != nil { + windows.Close(h) + return nil, err + } + realOwner, _, err := sd.Owner() + if err != nil { + windows.Close(h) + return nil, err + } + if !realOwner.Equals(config.ExpectedOwner) { + windows.Close(h) + return nil, windows.ERROR_ACCESS_DENIED + } + } + + var flags uint32 + err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil) + if err != nil { + windows.Close(h) + return nil, err + } + + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + + // If the pipe is in message mode, return a message byte pipe, which + // supports CloseWrite. + if flags&windows.PIPE_TYPE_MESSAGE != 0 { + return &messageBytePipe{ + pipe: pipe{file: f, path: path}, + }, nil + } + return &pipe{file: f, path: path}, nil +} + +var defaultDialer DialConfig + +// DialTimeout calls DialConfig.DialTimeout using an empty configuration. +func DialTimeout(path string, timeout time.Duration) (net.Conn, error) { + return defaultDialer.DialTimeout(path, timeout) +} + +// DialContext calls DialConfig.DialContext using an empty configuration. +func DialContext(ctx context.Context, path string) (net.Conn, error) { + return defaultDialer.DialContext(ctx, path) +} + +type acceptResponse struct { + f *file + err error +} + +type pipeListener struct { + firstHandle windows.Handle + path string + config ListenConfig + acceptCh chan chan acceptResponse + closeCh chan int + doneCh chan int +} + +func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) { + path16, err := windows.UTF16PtrFromString(path) + if err != nil { + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + var oa windows.OBJECT_ATTRIBUTES + oa.Length = uint32(unsafe.Sizeof(oa)) + + var ntPath windows.NTUnicodeString + if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer))) + oa.ObjectName = &ntPath + + // The security descriptor is only needed for the first pipe. + if isFirstPipe { + if sd != nil { + oa.SecurityDescriptor = sd + } else { + // Construct the default named pipe security descriptor. + var acl *windows.ACL + if err := windows.RtlDefaultNpAcl(&acl); err != nil { + return 0, err + } + defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl))) + sd, err = windows.NewSecurityDescriptor() + if err != nil { + return 0, err + } + if err = sd.SetDACL(acl, true, false); err != nil { + return 0, err + } + oa.SecurityDescriptor = sd + } + } + + typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS) + if c.MessageMode { + typ |= windows.FILE_PIPE_MESSAGE_TYPE + } + + disposition := uint32(windows.FILE_OPEN) + access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) + if isFirstPipe { + disposition = windows.FILE_CREATE + // By not asking for read or write access, the named pipe file system + // will put this pipe into an initially disconnected state, blocking + // client connections until the next call with isFirstPipe == false. + access = windows.SYNCHRONIZE + } + + timeout := int64(-50 * 10000) // 50ms + + var ( + h windows.Handle + iosb windows.IO_STATUS_BLOCK + ) + err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout) + if err != nil { + if ntstatus, ok := err.(windows.NTStatus); ok { + err = ntstatus.Errno() + } + return 0, &os.PathError{Op: "open", Path: path, Err: err} + } + + runtime.KeepAlive(ntPath) + return h, nil +} + +func (l *pipeListener) makeServerPipe() (*file, error) { + h, err := makeServerPipeHandle(l.path, nil, &l.config, false) + if err != nil { + return nil, err + } + f, err := makeFile(h) + if err != nil { + windows.Close(h) + return nil, err + } + return f, nil +} + +func (l *pipeListener) makeConnectedServerPipe() (*file, error) { + p, err := l.makeServerPipe() + if err != nil { + return nil, err + } + + // Wait for the client to connect. + ch := make(chan error) + go func(p *file) { + ch <- connectPipe(p) + }(p) + + select { + case err = <-ch: + if err != nil { + p.Close() + p = nil + } + case <-l.closeCh: + // Abort the connect request by closing the handle. + p.Close() + p = nil + err = <-ch + if err == nil || err == os.ErrClosed { + err = net.ErrClosed + } + } + return p, err +} + +func (l *pipeListener) listenerRoutine() { + closed := false + for !closed { + select { + case <-l.closeCh: + closed = true + case responseCh := <-l.acceptCh: + var ( + p *file + err error + ) + for { + p, err = l.makeConnectedServerPipe() + // If the connection was immediately closed by the client, try + // again. + if err != windows.ERROR_NO_DATA { + break + } + } + responseCh <- acceptResponse{p, err} + closed = err == net.ErrClosed + } + } + windows.Close(l.firstHandle) + l.firstHandle = 0 + // Notify Close and Accept callers that the handle has been closed. + close(l.doneCh) +} + +// ListenConfig contains configuration for the pipe listener. +type ListenConfig struct { + // SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used. + SecurityDescriptor *windows.SECURITY_DESCRIPTOR + + // MessageMode determines whether the pipe is in byte or message mode. In either + // case the pipe is read in byte mode by default. The only practical difference in + // this implementation is that CloseWrite is only supported for message mode pipes; + // CloseWrite is implemented as a zero-byte write, but zero-byte writes are only + // transferred to the reader (and returned as io.EOF in this implementation) + // when the pipe is in message mode. + MessageMode bool + + // InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed. + InputBufferSize int32 + + // OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed. + OutputBufferSize int32 +} + +// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe. +// The pipe must not already exist. +func (c *ListenConfig) Listen(path string) (net.Listener, error) { + h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) + if err != nil { + return nil, err + } + l := &pipeListener{ + firstHandle: h, + path: path, + config: *c, + acceptCh: make(chan chan acceptResponse), + closeCh: make(chan int), + doneCh: make(chan int), + } + // The first connection is swallowed on Windows 7 & 8, so synthesize it. + if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) { + path16, err := windows.UTF16PtrFromString(path) + if err == nil { + h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0) + if err == nil { + windows.CloseHandle(h) + } + } + } + go l.listenerRoutine() + return l, nil +} + +var defaultListener ListenConfig + +// Listen calls ListenConfig.Listen using an empty configuration. +func Listen(path string) (net.Listener, error) { + return defaultListener.Listen(path) +} + +func connectPipe(p *file) error { + c, err := p.prepareIo() + if err != nil { + return err + } + defer p.wg.Done() + + err = windows.ConnectNamedPipe(p.handle, &c.o) + _, err = p.asyncIo(c, nil, 0, err) + if err != nil && err != windows.ERROR_PIPE_CONNECTED { + return err + } + return nil +} + +func (l *pipeListener) Accept() (net.Conn, error) { + ch := make(chan acceptResponse) + select { + case l.acceptCh <- ch: + response := <-ch + err := response.err + if err != nil { + return nil, err + } + if l.config.MessageMode { + return &messageBytePipe{ + pipe: pipe{file: response.f, path: l.path}, + }, nil + } + return &pipe{file: response.f, path: l.path}, nil + case <-l.doneCh: + return nil, net.ErrClosed + } +} + +func (l *pipeListener) Close() error { + select { + case l.closeCh <- 1: + <-l.doneCh + case <-l.doneCh: + } + return nil +} + +func (l *pipeListener) Addr() net.Addr { + return pipeAddress(l.path) +} diff --git a/ipc/namedpipe/namedpipe_test.go b/ipc/namedpipe/namedpipe_test.go new file mode 100644 index 0000000..998453b --- /dev/null +++ b/ipc/namedpipe/namedpipe_test.go @@ -0,0 +1,674 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Copyright 2015 Microsoft +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build windows + +package namedpipe_test + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net" + "os" + "sync" + "syscall" + "testing" + "time" + + "golang.org/x/sys/windows" + "golang.zx2c4.com/wireguard/ipc/namedpipe" +) + +func randomPipePath() string { + guid, err := windows.GenerateGUID() + if err != nil { + panic(err) + } + return `\\.\PIPE\go-namedpipe-test-` + guid.String() +} + +func TestPingPong(t *testing.T) { + const ( + ping = 42 + pong = 24 + ) + pipePath := randomPipePath() + listener, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatalf("unable to listen on pipe: %v", err) + } + defer listener.Close() + go func() { + incoming, err := listener.Accept() + if err != nil { + t.Fatalf("unable to accept pipe connection: %v", err) + } + defer incoming.Close() + var data [1]byte + _, err = incoming.Read(data[:]) + if err != nil { + t.Fatalf("unable to read ping from pipe: %v", err) + } + if data[0] != ping { + t.Fatalf("expected ping, got %d", data[0]) + } + data[0] = pong + _, err = incoming.Write(data[:]) + if err != nil { + t.Fatalf("unable to write pong to pipe: %v", err) + } + }() + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatalf("unable to dial pipe: %v", err) + } + defer client.Close() + client.SetDeadline(time.Now().Add(time.Second * 5)) + var data [1]byte + data[0] = ping + _, err = client.Write(data[:]) + if err != nil { + t.Fatalf("unable to write ping to pipe: %v", err) + } + _, err = client.Read(data[:]) + if err != nil { + t.Fatalf("unable to read pong from pipe: %v", err) + } + if data[0] != pong { + t.Fatalf("expected pong, got %d", data[0]) + } +} + +func TestDialUnknownFailsImmediately(t *testing.T) { + _, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0)) + if !errors.Is(err, syscall.ENOENT) { + t.Fatalf("expected ENOENT got %v", err) + } +} + +func TestDialListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond) + if err == nil { + pipe.Close() + } + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestDialContextListenerTimesOut(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + d := 10 * time.Millisecond + ctx, _ := context.WithTimeout(context.Background(), d) + pipe, err := namedpipe.DialContext(ctx, pipePath) + if err == nil { + pipe.Close() + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestDialListenerGetsCancelled(t *testing.T) { + pipePath := randomPipePath() + ctx, cancel := context.WithCancel(context.Background()) + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + ch := make(chan error) + go func(ctx context.Context, ch chan error) { + _, err := namedpipe.DialContext(ctx, pipePath) + ch <- err + }(ctx, ch) + time.Sleep(time.Millisecond * 30) + cancel() + err = <-ch + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestDialAccessDeniedWithRestrictedSD(t *testing.T) { + if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil { + t.Skip("dacls on named pipes are broken on wine") + } + pipePath := randomPipePath() + sd, _ := windows.SecurityDescriptorFromString("D:") + l, err := (&namedpipe.ListenConfig{ + SecurityDescriptor: sd, + }).Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + pipe.Close() + } + if !errors.Is(err, windows.ERROR_ACCESS_DENIED) { + t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err) + } +} + +func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) { + pipePath := randomPipePath() + if cfg == nil { + cfg = &namedpipe.ListenConfig{} + } + l, err := cfg.Listen(pipePath) + if err != nil { + return + } + defer l.Close() + + type response struct { + c net.Conn + err error + } + ch := make(chan response) + go func() { + c, err := l.Accept() + ch <- response{c, err} + }() + + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + return + } + + r := <-ch + if err = r.err; err != nil { + c.Close() + return + } + + client = c + server = r.c + return +} + +func TestReadTimeout(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + c.SetReadDeadline(time.Now().Add(10 * time.Millisecond)) + + buf := make([]byte, 10) + _, err = c.Read(buf) + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func server(l net.Listener, ch chan int) { + c, err := l.Accept() + if err != nil { + panic(err) + } + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + s, err := rw.ReadString('\n') + if err != nil { + panic(err) + } + _, err = rw.WriteString("got " + s) + if err != nil { + panic(err) + } + err = rw.Flush() + if err != nil { + panic(err) + } + c.Close() + ch <- 1 +} + +func TestFullListenDialReadWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + ch := make(chan int) + go server(l, ch) + + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c)) + _, err = rw.WriteString("hello world\n") + if err != nil { + t.Fatal(err) + } + err = rw.Flush() + if err != nil { + t.Fatal(err) + } + + s, err := rw.ReadString('\n') + if err != nil { + t.Fatal(err) + } + ms := "got hello world\n" + if s != ms { + t.Errorf("expected '%s', got '%s'", ms, s) + } + + <-ch +} + +func TestCloseAbortsListen(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + + ch := make(chan error) + go func() { + _, err := l.Accept() + ch <- err + }() + + time.Sleep(30 * time.Millisecond) + l.Close() + + err = <-ch + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) { + b := make([]byte, 10) + w.Close() + n, err := r.Read(b) + if n > 0 { + t.Errorf("unexpected byte count %d", n) + } + if err != io.EOF { + t.Errorf("expected EOF: %v", err) + } +} + +func TestCloseClientEOFServer(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, c, s) +} + +func TestCloseServerEOFClient(t *testing.T) { + c, s, err := getConnection(nil) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + ensureEOFOnClose(t, s, c) +} + +func TestCloseWriteEOF(t *testing.T) { + cfg := &namedpipe.ListenConfig{ + MessageMode: true, + } + c, s, err := getConnection(cfg) + if err != nil { + t.Fatal(err) + } + defer c.Close() + defer s.Close() + + type closeWriter interface { + CloseWrite() error + } + + err = c.(closeWriter).CloseWrite() + if err != nil { + t.Fatal(err) + } + + b := make([]byte, 10) + _, err = s.Read(b) + if err != io.EOF { + t.Fatal(err) + } +} + +func TestAcceptAfterCloseFails(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + l.Close() + _, err = l.Accept() + if err != net.ErrClosed { + t.Fatalf("expected net.ErrClosed, got %v", err) + } +} + +func TestDialTimesOutByDefault(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds. + if err == nil { + pipe.Close() + } + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } +} + +func TestTimeoutPendingRead(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + buf := make([]byte, 10) + _, err = client.Read(buf) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline + client.SetReadDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for read to cancel") + <-clientErr + } + <-serverDone +} + +func TestTimeoutPendingWrite(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + serverDone := make(chan struct{}) + + go func() { + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + time.Sleep(1 * time.Second) + s.Close() + close(serverDone) + }() + + client, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + clientErr := make(chan error) + go func() { + _, err = client.Write([]byte("this should timeout")) + clientErr <- err + }() + + time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline + client.SetWriteDeadline(time.Unix(1, 0)) + + select { + case err = <-clientErr: + if err != os.ErrDeadlineExceeded { + t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("timed out while waiting for write to cancel") + <-clientErr + } + <-serverDone +} + +type CloseWriter interface { + CloseWrite() error +} + +func TestEchoWithMessaging(t *testing.T) { + pipePath := randomPipePath() + l, err := (&namedpipe.ListenConfig{ + MessageMode: true, // Use message mode so that CloseWrite() is supported + InputBufferSize: 65536, // Use 64KB buffers to improve performance + OutputBufferSize: 65536, + }).Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + listenerDone := make(chan bool) + clientDone := make(chan bool) + go func() { + // server echo + conn, err := l.Accept() + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent + _, err = io.Copy(conn, conn) + if err != nil { + t.Fatal(err) + } + conn.(CloseWriter).CloseWrite() + close(listenerDone) + }() + client, err := namedpipe.DialTimeout(pipePath, time.Second) + if err != nil { + t.Fatal(err) + } + defer client.Close() + + go func() { + // client read back + bytes := make([]byte, 2) + n, e := client.Read(bytes) + if e != nil { + t.Fatal(e) + } + if n != 2 || bytes[0] != 0 || bytes[1] != 1 { + t.Fatalf("expected 2 bytes, got %v", n) + } + close(clientDone) + }() + + payload := make([]byte, 2) + payload[0] = 0 + payload[1] = 1 + + n, err := client.Write(payload) + if err != nil { + t.Fatal(err) + } + if n != 2 { + t.Fatalf("expected 2 bytes, got %v", n) + } + client.(CloseWriter).CloseWrite() + <-listenerDone + <-clientDone +} + +func TestConnectRace(t *testing.T) { + pipePath := randomPipePath() + l, err := namedpipe.Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + go func() { + for { + s, err := l.Accept() + if err == net.ErrClosed { + return + } + + if err != nil { + t.Fatal(err) + } + s.Close() + } + }() + + for i := 0; i < 1000; i++ { + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + c.Close() + } +} + +func TestMessageReadMode(t *testing.T) { + if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 { + t.Skipf("Skipping on Windows %d", maj) + } + var wg sync.WaitGroup + defer wg.Wait() + pipePath := randomPipePath() + l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath) + if err != nil { + t.Fatal(err) + } + defer l.Close() + + msg := ([]byte)("hello world") + + wg.Add(1) + go func() { + defer wg.Done() + s, err := l.Accept() + if err != nil { + t.Fatal(err) + } + _, err = s.Write(msg) + if err != nil { + t.Fatal(err) + } + s.Close() + }() + + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + mode := uint32(windows.PIPE_READMODE_MESSAGE) + err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil) + if err != nil { + t.Fatal(err) + } + + ch := make([]byte, 1) + var vmsg []byte + for { + n, err := c.Read(ch) + if err == io.EOF { + break + } + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatalf("expected 1, got %d", n) + } + vmsg = append(vmsg, ch[0]) + } + if !bytes.Equal(msg, vmsg) { + t.Fatalf("expected %s, got %s", msg, vmsg) + } +} + +func TestListenConnectRace(t *testing.T) { + if testing.Short() { + t.Skip("Skipping long race test") + } + pipePath := randomPipePath() + for i := 0; i < 50 && !t.Failed(); i++ { + var wg sync.WaitGroup + wg.Add(1) + go func() { + c, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) + if err == nil { + c.Close() + } + wg.Done() + }() + s, err := namedpipe.Listen(pipePath) + if err != nil { + t.Error(i, err) + } else { + s.Close() + } + wg.Wait() + } +} diff --git a/ipc/uapi_bsd.go b/ipc/uapi_bsd.go index 75cc0e3..ddcaf27 100644 --- a/ipc/uapi_bsd.go +++ b/ipc/uapi_bsd.go @@ -1,33 +1,21 @@ -// +build darwin freebsd openbsd +//go:build darwin || freebsd || openbsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ipc import ( "errors" - "fmt" "net" "os" - "path" "unsafe" "golang.org/x/sys/unix" ) -var socketDirectory = "/var/run/wireguard" - -const ( - IpcErrorIO = -int64(unix.EIO) - IpcErrorProtocol = -int64(unix.EPROTO) - IpcErrorInvalid = -int64(unix.EINVAL) - IpcErrorPortInUse = -int64(unix.EADDRINUSE) - socketName = "%s.sock" -) - type UAPIListener struct { listener net.Listener // unix socket listener connNew chan net.Conn @@ -66,7 +54,6 @@ func (l *UAPIListener) Addr() net.Addr { } func UAPIListen(name string, file *os.File) (net.Listener, error) { - // wrap file in listener listener, err := net.FileListener(file) @@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) { unixListener.SetUnlinkOnClose(true) } - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) + socketPath := sockPath(name) // watch for deletion of socket @@ -119,7 +103,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) { l.connErr <- err return } - if kerr != nil || n != 1 { + if (kerr != nil || n != 1) && kerr != unix.EINTR { if kerr != nil { l.connErr <- kerr } else { @@ -146,58 +130,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) { return uapi, nil } - -func UAPIOpen(name string) (*os.File, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 0755) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - addr, err := net.ResolveUnixAddr("unix", socketPath) - if err != nil { - return nil, err - } - - oldUmask := unix.Umask(0077) - listener, err := func() (*net.UnixListener, error) { - - // initial connection attempt - - listener, err := net.ListenUnix("unix", addr) - if err == nil { - return listener, nil - } - - // check if socket already active - - _, err = net.Dial("unix", socketPath) - if err == nil { - return nil, errors.New("unix socket in use") - } - - // cleanup & attempt again - - err = os.Remove(socketPath) - if err != nil { - return nil, err - } - return net.ListenUnix("unix", addr) - }() - unix.Umask(oldUmask) - - if err != nil { - return nil, err - } - - return listener.File() -} diff --git a/ipc/uapi_linux.go b/ipc/uapi_linux.go index a3c95ca..1562a18 100644 --- a/ipc/uapi_linux.go +++ b/ipc/uapi_linux.go @@ -1,31 +1,18 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ipc import ( - "errors" - "fmt" "net" "os" - "path" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/rwcancel" ) -var socketDirectory = "/var/run/wireguard" - -const ( - IpcErrorIO = -int64(unix.EIO) - IpcErrorProtocol = -int64(unix.EPROTO) - IpcErrorInvalid = -int64(unix.EINVAL) - IpcErrorPortInUse = -int64(unix.EADDRINUSE) - socketName = "%s.sock" -) - type UAPIListener struct { listener net.Listener // unix socket listener connNew chan net.Conn @@ -64,7 +51,6 @@ func (l *UAPIListener) Addr() net.Addr { } func UAPIListen(name string, file *os.File) (net.Listener, error) { - // wrap file in listener listener, err := net.FileListener(file) @@ -84,10 +70,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) { // watch for deletion of socket - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) + socketPath := sockPath(name) uapi.inotifyFd, err = unix.InotifyInit() if err != nil { @@ -113,14 +96,15 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) { } go func(l *UAPIListener) { - var buff [0]byte + var buf [0]byte for { + defer uapi.inotifyRWCancel.Close() // start with lstat to avoid race condition if _, err := os.Lstat(socketPath); os.IsNotExist(err) { l.connErr <- err return } - _, err := uapi.inotifyRWCancel.Read(buff[:]) + _, err := uapi.inotifyRWCancel.Read(buf[:]) if err != nil { l.connErr <- err return @@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) { return uapi, nil } - -func UAPIOpen(name string) (*os.File, error) { - - // check if path exist - - err := os.MkdirAll(socketDirectory, 0755) - if err != nil && !os.IsExist(err) { - return nil, err - } - - // open UNIX socket - - socketPath := path.Join( - socketDirectory, - fmt.Sprintf(socketName, name), - ) - - addr, err := net.ResolveUnixAddr("unix", socketPath) - if err != nil { - return nil, err - } - - oldUmask := unix.Umask(0077) - listener, err := func() (*net.UnixListener, error) { - - // initial connection attempt - - listener, err := net.ListenUnix("unix", addr) - if err == nil { - return listener, nil - } - - // check if socket already active - - _, err = net.Dial("unix", socketPath) - if err == nil { - return nil, errors.New("unix socket in use") - } - - // cleanup & attempt again - - err = os.Remove(socketPath) - if err != nil { - return nil, err - } - return net.ListenUnix("unix", addr) - }() - unix.Umask(oldUmask) - - if err != nil { - return nil, err - } - - return listener.File() -} diff --git a/ipc/uapi_unix.go b/ipc/uapi_unix.go new file mode 100644 index 0000000..e67be26 --- /dev/null +++ b/ipc/uapi_unix.go @@ -0,0 +1,66 @@ +//go:build linux || darwin || freebsd || openbsd + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +import ( + "errors" + "fmt" + "net" + "os" + + "golang.org/x/sys/unix" +) + +const ( + IpcErrorIO = -int64(unix.EIO) + IpcErrorProtocol = -int64(unix.EPROTO) + IpcErrorInvalid = -int64(unix.EINVAL) + IpcErrorPortInUse = -int64(unix.EADDRINUSE) + IpcErrorUnknown = -55 // ENOANO +) + +// socketDirectory is variable because it is modified by a linker +// flag in wireguard-android. +var socketDirectory = "/var/run/wireguard" + +func sockPath(iface string) string { + return fmt.Sprintf("%s/%s.sock", socketDirectory, iface) +} + +func UAPIOpen(name string) (*os.File, error) { + if err := os.MkdirAll(socketDirectory, 0o755); err != nil { + return nil, err + } + + socketPath := sockPath(name) + addr, err := net.ResolveUnixAddr("unix", socketPath) + if err != nil { + return nil, err + } + + oldUmask := unix.Umask(0o077) + defer unix.Umask(oldUmask) + + listener, err := net.ListenUnix("unix", addr) + if err == nil { + return listener.File() + } + + // Test socket, if not in use cleanup and try again. + if _, err := net.Dial("unix", socketPath); err == nil { + return nil, errors.New("unix socket in use") + } + if err := os.Remove(socketPath); err != nil { + return nil, err + } + listener, err = net.ListenUnix("unix", addr) + if err != nil { + return nil, err + } + return listener.File() +} diff --git a/ipc/uapi_wasm.go b/ipc/uapi_wasm.go new file mode 100644 index 0000000..fa84684 --- /dev/null +++ b/ipc/uapi_wasm.go @@ -0,0 +1,15 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package ipc + +// Made up sentinel error codes for {js,wasip1}/wasm. +const ( + IpcErrorIO = 1 + IpcErrorInvalid = 2 + IpcErrorPortInUse = 3 + IpcErrorUnknown = 4 + IpcErrorProtocol = 5 +) diff --git a/ipc/uapi_windows.go b/ipc/uapi_windows.go index ead0dc5..aa023c9 100644 --- a/ipc/uapi_windows.go +++ b/ipc/uapi_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ipc @@ -9,8 +9,7 @@ import ( "net" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/ipc/winpipe" + "golang.zx2c4.com/wireguard/ipc/namedpipe" ) // TODO: replace these with actual standard windows error numbers from the win package @@ -19,6 +18,7 @@ const ( IpcErrorProtocol = -int64(71) IpcErrorInvalid = -int64(22) IpcErrorPortInUse = -int64(98) + IpcErrorUnknown = -int64(55) ) type UAPIListener struct { @@ -53,18 +53,16 @@ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR func init() { var err error - /* SDDL_DEVOBJ_SYS_ALL from the WDK */ - UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)") + UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)") if err != nil { panic(err) } } func UAPIListen(name string) (net.Listener, error) { - config := winpipe.PipeConfig{ + listener, err := (&namedpipe.ListenConfig{ SecurityDescriptor: UAPISecurityDescriptor, - } - listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config) + }).Listen(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\` + name) if err != nil { return nil, err } diff --git a/ipc/winpipe/mksyscall.go b/ipc/winpipe/mksyscall.go deleted file mode 100644 index 19ac03a..0000000 --- a/ipc/winpipe/mksyscall.go +++ /dev/null @@ -1,9 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go diff --git a/ipc/winpipe/pipe.go b/ipc/winpipe/pipe.go deleted file mode 100644 index 06b3037..0000000 --- a/ipc/winpipe/pipe.go +++ /dev/null @@ -1,509 +0,0 @@ -// +build windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2005 Microsoft - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package winpipe - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "os" - "runtime" - "time" - "unsafe" - - "golang.org/x/sys/windows" -) - -//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe -//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW -//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW -//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo -//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW -//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc -//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile -//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb -//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U -//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl - -type ioStatusBlock struct { - Status, Information uintptr -} - -type objectAttributes struct { - Length uintptr - RootDirectory uintptr - ObjectName *unicodeString - Attributes uintptr - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - SecurityQoS uintptr -} - -type unicodeString struct { - Length uint16 - MaximumLength uint16 - Buffer uintptr -} - -type ntstatus int32 - -func (status ntstatus) Err() error { - if status >= 0 { - return nil - } - return rtlNtStatusToDosError(status) -} - -const ( - cSECURITY_SQOS_PRESENT = 0x100000 - cSECURITY_ANONYMOUS = 0 - - cPIPE_TYPE_MESSAGE = 4 - - cPIPE_READMODE_MESSAGE = 2 - - cFILE_OPEN = 1 - cFILE_CREATE = 2 - - cFILE_PIPE_MESSAGE_TYPE = 1 - cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 -) - -var ( - // ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed. - // This error should match net.errClosing since docker takes a dependency on its text. - ErrPipeListenerClosed = errors.New("use of closed network connection") - - errPipeWriteClosed = errors.New("pipe has been closed for write") -) - -type win32Pipe struct { - *win32File - path string -} - -type win32MessageBytePipe struct { - win32Pipe - writeClosed bool - readEOF bool -} - -type pipeAddress string - -func (f *win32Pipe) LocalAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) RemoteAddr() net.Addr { - return pipeAddress(f.path) -} - -func (f *win32Pipe) SetDeadline(t time.Time) error { - f.SetReadDeadline(t) - f.SetWriteDeadline(t) - return nil -} - -// CloseWrite closes the write side of a message pipe in byte mode. -func (f *win32MessageBytePipe) CloseWrite() error { - if f.writeClosed { - return errPipeWriteClosed - } - err := f.win32File.Flush() - if err != nil { - return err - } - _, err = f.win32File.Write(nil) - if err != nil { - return err - } - f.writeClosed = true - return nil -} - -// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since -// they are used to implement CloseWrite(). -func (f *win32MessageBytePipe) Write(b []byte) (int, error) { - if f.writeClosed { - return 0, errPipeWriteClosed - } - if len(b) == 0 { - return 0, nil - } - return f.win32File.Write(b) -} - -// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message -// mode pipe will return io.EOF, as will all subsequent reads. -func (f *win32MessageBytePipe) Read(b []byte) (int, error) { - if f.readEOF { - return 0, io.EOF - } - n, err := f.win32File.Read(b) - if err == io.EOF { - // If this was the result of a zero-byte read, then - // it is possible that the read was due to a zero-size - // message. Since we are simulating CloseWrite with a - // zero-byte message, ensure that all future Read() calls - // also return EOF. - f.readEOF = true - } else if err == windows.ERROR_MORE_DATA { - // ERROR_MORE_DATA indicates that the pipe's read mode is message mode - // and the message still has more bytes. Treat this as a success, since - // this package presents all named pipes as byte streams. - err = nil - } - return n, err -} - -func (s pipeAddress) Network() string { - return "pipe" -} - -func (s pipeAddress) String() string { - return string(s) -} - -// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. -func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) { - for { - select { - case <-ctx.Done(): - return windows.Handle(0), ctx.Err() - default: - h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) - if err == nil { - return h, nil - } - if err != windows.ERROR_PIPE_BUSY { - return h, &os.PathError{Err: err, Op: "open", Path: *path} - } - // Wait 10 msec and try again. This is a rather simplistic - // view, as we always try each 10 milliseconds. - time.Sleep(time.Millisecond * 10) - } - } -} - -// DialPipe connects to a named pipe by path, timing out if the connection -// takes longer than the specified duration. If timeout is nil, then we use -// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) -func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) { - var absTimeout time.Time - if timeout != nil { - absTimeout = time.Now().Add(*timeout) - } else { - absTimeout = time.Now().Add(time.Second * 2) - } - ctx, _ := context.WithDeadline(context.Background(), absTimeout) - conn, err := DialPipeContext(ctx, path, expectedOwner) - if err == context.DeadlineExceeded { - return nil, ErrTimeout - } - return conn, err -} - -// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` -// cancellation or timeout. -func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) { - var err error - var h windows.Handle - h, err = tryDialPipe(ctx, &path) - if err != nil { - return nil, err - } - - if expectedOwner != nil { - sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION) - if err != nil { - windows.Close(h) - return nil, err - } - realOwner, _, err := sd.Owner() - if err != nil { - windows.Close(h) - return nil, err - } - if !realOwner.Equals(expectedOwner) { - windows.Close(h) - return nil, windows.ERROR_ACCESS_DENIED - } - } - - var flags uint32 - err = getNamedPipeInfo(h, &flags, nil, nil, nil) - if err != nil { - windows.Close(h) - return nil, err - } - - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - - // If the pipe is in message mode, return a message byte pipe, which - // supports CloseWrite(). - if flags&cPIPE_TYPE_MESSAGE != 0 { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: f, path: path}, - }, nil - } - return &win32Pipe{win32File: f, path: path}, nil -} - -type acceptResponse struct { - f *win32File - err error -} - -type win32PipeListener struct { - firstHandle windows.Handle - path string - config PipeConfig - acceptCh chan (chan acceptResponse) - closeCh chan int - doneCh chan int -} - -func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) { - path16, err := windows.UTF16FromString(path) - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - var oa objectAttributes - oa.Length = unsafe.Sizeof(oa) - - var ntPath unicodeString - if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - defer windows.LocalFree(windows.Handle(ntPath.Buffer)) - oa.ObjectName = &ntPath - - // The security descriptor is only needed for the first pipe. - if first { - if sd != nil { - oa.SecurityDescriptor = sd - } else { - // Construct the default named pipe security descriptor. - var dacl uintptr - if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { - return 0, fmt.Errorf("getting default named pipe ACL: %s", err) - } - defer windows.LocalFree(windows.Handle(dacl)) - sd, err := windows.NewSecurityDescriptor() - if err != nil { - return 0, fmt.Errorf("creating new security descriptor: %s", err) - } - if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil { - return 0, fmt.Errorf("assigning dacl: %s", err) - } - sd, err = sd.ToSelfRelative() - if err != nil { - return 0, fmt.Errorf("converting to self-relative: %s", err) - } - oa.SecurityDescriptor = sd - } - } - - typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS) - if c.MessageMode { - typ |= cFILE_PIPE_MESSAGE_TYPE - } - - disposition := uint32(cFILE_OPEN) - access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE) - if first { - disposition = cFILE_CREATE - // By not asking for read or write access, the named pipe file system - // will put this pipe into an initially disconnected state, blocking - // client connections until the next call with first == false. - access = windows.SYNCHRONIZE - } - - timeout := int64(-50 * 10000) // 50ms - - var ( - h windows.Handle - iosb ioStatusBlock - ) - err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() - if err != nil { - return 0, &os.PathError{Op: "open", Path: path, Err: err} - } - - runtime.KeepAlive(ntPath) - return h, nil -} - -func (l *win32PipeListener) makeServerPipe() (*win32File, error) { - h, err := makeServerPipeHandle(l.path, nil, &l.config, false) - if err != nil { - return nil, err - } - f, err := makeWin32File(h) - if err != nil { - windows.Close(h) - return nil, err - } - return f, nil -} - -func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) { - p, err := l.makeServerPipe() - if err != nil { - return nil, err - } - - // Wait for the client to connect. - ch := make(chan error) - go func(p *win32File) { - ch <- connectPipe(p) - }(p) - - select { - case err = <-ch: - if err != nil { - p.Close() - p = nil - } - case <-l.closeCh: - // Abort the connect request by closing the handle. - p.Close() - p = nil - err = <-ch - if err == nil || err == ErrFileClosed { - err = ErrPipeListenerClosed - } - } - return p, err -} - -func (l *win32PipeListener) listenerRoutine() { - closed := false - for !closed { - select { - case <-l.closeCh: - closed = true - case responseCh := <-l.acceptCh: - var ( - p *win32File - err error - ) - for { - p, err = l.makeConnectedServerPipe() - // If the connection was immediately closed by the client, try - // again. - if err != windows.ERROR_NO_DATA { - break - } - } - responseCh <- acceptResponse{p, err} - closed = err == ErrPipeListenerClosed - } - } - windows.Close(l.firstHandle) - l.firstHandle = 0 - // Notify Close() and Accept() callers that the handle has been closed. - close(l.doneCh) -} - -// PipeConfig contain configuration for the pipe listener. -type PipeConfig struct { - // SecurityDescriptor contains a Windows security descriptor. - SecurityDescriptor *windows.SECURITY_DESCRIPTOR - - // MessageMode determines whether the pipe is in byte or message mode. In either - // case the pipe is read in byte mode by default. The only practical difference in - // this implementation is that CloseWrite() is only supported for message mode pipes; - // CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only - // transferred to the reader (and returned as io.EOF in this implementation) - // when the pipe is in message mode. - MessageMode bool - - // InputBufferSize specifies the size the input buffer, in bytes. - InputBufferSize int32 - - // OutputBufferSize specifies the size the input buffer, in bytes. - OutputBufferSize int32 -} - -// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. -// The pipe must not already exist. -func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { - if c == nil { - c = &PipeConfig{} - } - h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true) - if err != nil { - return nil, err - } - l := &win32PipeListener{ - firstHandle: h, - path: path, - config: *c, - acceptCh: make(chan (chan acceptResponse)), - closeCh: make(chan int), - doneCh: make(chan int), - } - go l.listenerRoutine() - return l, nil -} - -func connectPipe(p *win32File) error { - c, err := p.prepareIo() - if err != nil { - return err - } - defer p.wg.Done() - - err = connectNamedPipe(p.handle, &c.o) - _, err = p.asyncIo(c, nil, 0, err) - if err != nil && err != windows.ERROR_PIPE_CONNECTED { - return err - } - return nil -} - -func (l *win32PipeListener) Accept() (net.Conn, error) { - ch := make(chan acceptResponse) - select { - case l.acceptCh <- ch: - response := <-ch - err := response.err - if err != nil { - return nil, err - } - if l.config.MessageMode { - return &win32MessageBytePipe{ - win32Pipe: win32Pipe{win32File: response.f, path: l.path}, - }, nil - } - return &win32Pipe{win32File: response.f, path: l.path}, nil - case <-l.doneCh: - return nil, ErrPipeListenerClosed - } -} - -func (l *win32PipeListener) Close() error { - select { - case l.closeCh <- 1: - <-l.doneCh - case <-l.doneCh: - } - return nil -} - -func (l *win32PipeListener) Addr() net.Addr { - return pipeAddress(l.path) -} diff --git a/ipc/winpipe/zsyscall_windows.go b/ipc/winpipe/zsyscall_windows.go deleted file mode 100644 index 9954329..0000000 --- a/ipc/winpipe/zsyscall_windows.go +++ /dev/null @@ -1,238 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package winpipe - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") - - procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") - procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") - procCreateFileW = modkernel32.NewProc("CreateFileW") - procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") - procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") - procLocalAlloc = modkernel32.NewProc("LocalAlloc") - procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") - procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") - procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") - procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") - procCancelIoEx = modkernel32.NewProc("CancelIoEx") - procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") - procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus") - procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes") - procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult") -) - -func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) -} - -func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - var _p0 *uint16 - _p0, err = syscall.UTF16PtrFromString(name) - if err != nil { - return - } - return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile) -} - -func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) - handle = windows.Handle(r0) - if handle == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { - r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func localAlloc(uFlags uint32, length uint32) (ptr uintptr) { - r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0) - ptr = uintptr(r0) - return -} - -func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { - r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) - status = ntstatus(r0) - return -} - -func rtlNtStatusToDosError(status ntstatus) (winerr error) { - r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0) - if r0 != 0 { - winerr = syscall.Errno(r0) - } - return -} - -func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0) - status = ntstatus(r0) - return -} - -func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) { - r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0) - status = ntstatus(r0) - return -} - -func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) { - r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) { - r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) - newport = windows.Handle(r0) - if newport == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) { - r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { - var _p0 uint32 - if wait { - _p0 = 1 - } else { - _p0 = 0 - } - r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} @@ -1,8 +1,8 @@ -// +build !windows +//go:build !windows /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package main @@ -13,8 +13,9 @@ import ( "os/signal" "runtime" "strconv" - "syscall" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" @@ -32,32 +33,33 @@ const ( ) func printUsage() { - fmt.Printf("usage:\n") - fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) + fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0]) } func warning() { - if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { + switch runtime.GOOS { + case "linux", "freebsd", "openbsd": + if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { + return + } + default: return } - fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") - fmt.Fprintln(os.Stderr, "W G") - fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G") - fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G") - fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G") - fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G") - fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G") - fmt.Fprintln(os.Stderr, "W implementation. For more information on G") - fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G") - fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G") - fmt.Fprintln(os.Stderr, "W G") - fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") + fmt.Fprintln(os.Stderr, "āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā") + fmt.Fprintln(os.Stderr, "ā ā") + fmt.Fprintln(os.Stderr, "ā Running wireguard-go is not required because this ā") + fmt.Fprintln(os.Stderr, "ā kernel has first class support for WireGuard. For ā") + fmt.Fprintln(os.Stderr, "ā information on installing the kernel module, ā") + fmt.Fprintln(os.Stderr, "ā please visit: ā") + fmt.Fprintln(os.Stderr, "ā https://www.wireguard.com/install/ ā") + fmt.Fprintln(os.Stderr, "ā ā") + fmt.Fprintln(os.Stderr, "āāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāāā") } func main() { if len(os.Args) == 2 && os.Args[1] == "--version" { - fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", device.WireGuardGoVersion, runtime.GOOS, runtime.GOARCH) + fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", Version, runtime.GOOS, runtime.GOARCH) return } @@ -97,21 +99,19 @@ func main() { logLevel := func() int { switch os.Getenv("LOG_LEVEL") { - case "debug": - return device.LogLevelDebug - case "info": - return device.LogLevelInfo + case "verbose", "debug": + return device.LogLevelVerbose case "error": return device.LogLevelError case "silent": return device.LogLevelSilent } - return device.LogLevelInfo + return device.LogLevelError }() // open TUN device (or use supplied fd) - tun, err := func() (tun.Device, error) { + tdev, err := func() (tun.Device, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) if tunFdStr == "" { return tun.CreateTUN(interfaceName, device.DefaultMTU) @@ -124,7 +124,7 @@ func main() { return nil, err } - err = syscall.SetNonblock(int(fd), true) + err = unix.SetNonblock(int(fd), true) if err != nil { return nil, err } @@ -134,7 +134,7 @@ func main() { }() if err == nil { - realInterfaceName, err2 := tun.Name() + realInterfaceName, err2 := tdev.Name() if err2 == nil { interfaceName = realInterfaceName } @@ -145,12 +145,10 @@ func main() { fmt.Sprintf("(%s) ", interfaceName), ) - logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion) - - logger.Debug.Println("Debug log enabled") + logger.Verbosef("Starting wireguard-go version %s", Version) if err != nil { - logger.Error.Println("Failed to create TUN device:", err) + logger.Errorf("Failed to create TUN device: %v", err) os.Exit(ExitSetupFailed) } @@ -171,9 +169,8 @@ func main() { return os.NewFile(uintptr(fd), ""), nil }() - if err != nil { - logger.Error.Println("UAPI listen error:", err) + logger.Errorf("UAPI listen error: %v", err) os.Exit(ExitSetupFailed) return } @@ -199,7 +196,7 @@ func main() { files[0], // stdin files[1], // stdout files[2], // stderr - tun.File(), + tdev.File(), fileUAPI, }, Dir: ".", @@ -208,7 +205,7 @@ func main() { path, err := os.Executable() if err != nil { - logger.Error.Println("Failed to determine executable:", err) + logger.Errorf("Failed to determine executable: %v", err) os.Exit(ExitSetupFailed) } @@ -218,23 +215,23 @@ func main() { attr, ) if err != nil { - logger.Error.Println("Failed to daemonize:", err) + logger.Errorf("Failed to daemonize: %v", err) os.Exit(ExitSetupFailed) } process.Release() return } - device := device.NewDevice(tun, logger) + device := device.NewDevice(tdev, conn.NewDefaultBind(), logger) - logger.Info.Println("Device started") + logger.Verbosef("Device started") errs := make(chan error) term := make(chan os.Signal, 1) uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) if err != nil { - logger.Error.Println("Failed to listen on uapi socket:", err) + logger.Errorf("Failed to listen on uapi socket: %v", err) os.Exit(ExitSetupFailed) } @@ -249,11 +246,11 @@ func main() { } }() - logger.Info.Println("UAPI listener started") + logger.Verbosef("UAPI listener started") // wait for program to terminate - signal.Notify(term, syscall.SIGTERM) + signal.Notify(term, unix.SIGTERM) signal.Notify(term, os.Interrupt) select { @@ -267,5 +264,5 @@ func main() { uapi.Close() device.Close() - logger.Info.Println("Shutting down") + logger.Verbosef("Shutting down") } diff --git a/main_windows.go b/main_windows.go index f57bc8d..a4dc46f 100644 --- a/main_windows.go +++ b/main_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package main @@ -9,8 +9,10 @@ import ( "fmt" "os" "os/signal" - "syscall" + "golang.org/x/sys/windows" + + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" @@ -31,11 +33,10 @@ func main() { fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.") logger := device.NewLogger( - device.LogLevelDebug, + device.LogLevelVerbose, fmt.Sprintf("(%s) ", interfaceName), ) - logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion) - logger.Debug.Println("Debug log enabled") + logger.Verbosef("Starting wireguard-go version %s", Version) tun, err := tun.CreateTUN(interfaceName, 0) if err == nil { @@ -44,17 +45,21 @@ func main() { interfaceName = realInterfaceName } } else { - logger.Error.Println("Failed to create TUN device:", err) + logger.Errorf("Failed to create TUN device: %v", err) os.Exit(ExitSetupFailed) } - device := device.NewDevice(tun, logger) - device.Up() - logger.Info.Println("Device started") + device := device.NewDevice(tun, conn.NewDefaultBind(), logger) + err = device.Up() + if err != nil { + logger.Errorf("Failed to bring up device: %v", err) + os.Exit(ExitSetupFailed) + } + logger.Verbosef("Device started") uapi, err := ipc.UAPIListen(interfaceName) if err != nil { - logger.Error.Println("Failed to listen on uapi socket:", err) + logger.Errorf("Failed to listen on uapi socket: %v", err) os.Exit(ExitSetupFailed) } @@ -71,13 +76,13 @@ func main() { go device.IpcHandle(conn) } }() - logger.Info.Println("UAPI listener started") + logger.Verbosef("UAPI listener started") // wait for program to terminate signal.Notify(term, os.Interrupt) signal.Notify(term, os.Kill) - signal.Notify(term, syscall.SIGTERM) + signal.Notify(term, windows.SIGTERM) select { case <-term: @@ -90,5 +95,5 @@ func main() { uapi.Close() device.Close() - logger.Info.Println("Shutting down") + logger.Verbosef("Shutting down") } diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go index a6d0ea2..f7d05ef 100644 --- a/ratelimiter/ratelimiter.go +++ b/ratelimiter/ratelimiter.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ratelimiter import ( - "net" + "net/netip" "sync" "time" ) @@ -30,8 +30,7 @@ type Ratelimiter struct { timeNow func() time.Time stopReset chan struct{} // send to reset, close to stop - tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry - tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry + table map[netip.Addr]*RatelimiterEntry } func (rate *Ratelimiter) Close() { @@ -57,8 +56,7 @@ func (rate *Ratelimiter) Init() { } rate.stopReset = make(chan struct{}) - rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) - rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) + rate.table = make(map[netip.Addr]*RatelimiterEntry) stopReset := rate.stopReset // store in case Init is called again. @@ -87,71 +85,39 @@ func (rate *Ratelimiter) cleanup() (empty bool) { rate.mu.Lock() defer rate.mu.Unlock() - for key, entry := range rate.tableIPv4 { + for key, entry := range rate.table { entry.mu.Lock() if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv4, key) + delete(rate.table, key) } entry.mu.Unlock() } - for key, entry := range rate.tableIPv6 { - entry.mu.Lock() - if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime { - delete(rate.tableIPv6, key) - } - entry.mu.Unlock() - } - - return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 + return len(rate.table) == 0 } -func (rate *Ratelimiter) Allow(ip net.IP) bool { +func (rate *Ratelimiter) Allow(ip netip.Addr) bool { var entry *RatelimiterEntry - var keyIPv4 [net.IPv4len]byte - var keyIPv6 [net.IPv6len]byte - // lookup entry - - IPv4 := ip.To4() - IPv6 := ip.To16() - rate.mu.RLock() - - if IPv4 != nil { - copy(keyIPv4[:], IPv4) - entry = rate.tableIPv4[keyIPv4] - } else { - copy(keyIPv6[:], IPv6) - entry = rate.tableIPv6[keyIPv6] - } - + entry = rate.table[ip] rate.mu.RUnlock() // make new entry if not found - if entry == nil { entry = new(RatelimiterEntry) entry.tokens = maxTokens - packetCost entry.lastTime = rate.timeNow() rate.mu.Lock() - if IPv4 != nil { - rate.tableIPv4[keyIPv4] = entry - if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 { - rate.stopReset <- struct{}{} - } - } else { - rate.tableIPv6[keyIPv6] = entry - if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 { - rate.stopReset <- struct{}{} - } + rate.table[ip] = entry + if len(rate.table) == 1 { + rate.stopReset <- struct{}{} } rate.mu.Unlock() return true } // add tokens to entry - entry.mu.Lock() now := rate.timeNow() entry.tokens += now.Sub(entry.lastTime).Nanoseconds() @@ -161,7 +127,6 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool { } // subtract cost of packet - if entry.tokens > packetCost { entry.tokens -= packetCost entry.mu.Unlock() diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go index 25d5d63..0bfa3af 100644 --- a/ratelimiter/ratelimiter_test.go +++ b/ratelimiter/ratelimiter_test.go @@ -1,12 +1,12 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package ratelimiter import ( - "net" + "net/netip" "testing" "time" ) @@ -71,21 +71,21 @@ func TestRatelimiter(t *testing.T) { text: "packet following 2 packet burst", }) - ips := []net.IP{ - net.ParseIP("127.0.0.1"), - net.ParseIP("192.168.1.1"), - net.ParseIP("172.167.2.3"), - net.ParseIP("97.231.252.215"), - net.ParseIP("248.97.91.167"), - net.ParseIP("188.208.233.47"), - net.ParseIP("104.2.183.179"), - net.ParseIP("72.129.46.120"), - net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), - net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"), - net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), - net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), - net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), - net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), + ips := []netip.Addr{ + netip.MustParseAddr("127.0.0.1"), + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("172.167.2.3"), + netip.MustParseAddr("97.231.252.215"), + netip.MustParseAddr("248.97.91.167"), + netip.MustParseAddr("188.208.233.47"), + netip.MustParseAddr("104.2.183.179"), + netip.MustParseAddr("72.129.46.120"), + netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"), + netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"), + netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"), + netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"), + netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"), + netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"), } now := time.Now() diff --git a/replay/replay.go b/replay/replay.go index 0f6b6c9..8b99e23 100644 --- a/replay/replay.go +++ b/replay/replay.go @@ -1,83 +1,62 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ +// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479. package replay -/* Implementation of RFC6479 - * https://tools.ietf.org/html/rfc6479 - * - * The implementation is not safe for concurrent use! - */ - -const ( - // See: https://golang.org/src/math/big/arith.go - _Wordm = ^uintptr(0) - _WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1 - _WordSize = 1 << _WordLogSize -) +type block uint64 const ( - CounterRedundantBitsLog = _WordLogSize + 3 - CounterRedundantBits = _WordSize * 8 - CounterBitsTotal = 2048 - CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits) + blockBitLog = 6 // 1<<6 == 64 bits + blockBits = 1 << blockBitLog // must be power of 2 + ringBlocks = 1 << 7 // must be power of 2 + windowSize = (ringBlocks - 1) * blockBits + blockMask = ringBlocks - 1 + bitMask = blockBits - 1 ) -const ( - BacktrackWords = CounterBitsTotal / _WordSize -) - -func minUint64(a uint64, b uint64) uint64 { - if a > b { - return b - } - return a -} - -type ReplayFilter struct { - counter uint64 - backtrack [BacktrackWords]uintptr +// A Filter rejects replayed messages by checking if message counter value is +// within a sliding window of previously received messages. +// The zero value for Filter is an empty filter ready to use. +// Filters are unsafe for concurrent use. +type Filter struct { + last uint64 + ring [ringBlocks]block } -func (filter *ReplayFilter) Init() { - filter.counter = 0 - filter.backtrack[0] = 0 +// Reset resets the filter to empty state. +func (f *Filter) Reset() { + f.last = 0 + f.ring[0] = 0 } -func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool { +// ValidateCounter checks if the counter should be accepted. +// Overlimit counters (>= limit) are always rejected. +func (f *Filter) ValidateCounter(counter, limit uint64) bool { if counter >= limit { return false } - - indexWord := counter >> CounterRedundantBitsLog - - if counter > filter.counter { - - // move window forward - - current := filter.counter >> CounterRedundantBitsLog - diff := minUint64(indexWord-current, BacktrackWords) - for i := uint64(1); i <= diff; i++ { - filter.backtrack[(current+i)%BacktrackWords] = 0 + indexBlock := counter >> blockBitLog + if counter > f.last { // move window forward + current := f.last >> blockBitLog + diff := indexBlock - current + if diff > ringBlocks { + diff = ringBlocks // cap diff to clear the whole ring } - filter.counter = counter - - } else if filter.counter-counter > CounterWindowSize { - - // behind current window - + for i := current + 1; i <= current+diff; i++ { + f.ring[i&blockMask] = 0 + } + f.last = counter + } else if f.last-counter > windowSize { // behind current window return false } - - indexWord %= BacktrackWords - indexBit := counter & uint64(CounterRedundantBits-1) - // check and set bit - - oldValue := filter.backtrack[indexWord] - newValue := oldValue | (1 << indexBit) - filter.backtrack[indexWord] = newValue - return oldValue != newValue + indexBlock &= blockMask + indexBit := counter & bitMask + old := f.ring[indexBlock] + new := old | 1<<indexBit + f.ring[indexBlock] = new + return old != new } diff --git a/replay/replay_test.go b/replay/replay_test.go index 5365f10..9a9e4a8 100644 --- a/replay/replay_test.go +++ b/replay/replay_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package replay @@ -14,22 +14,22 @@ import ( * */ -const RejectAfterMessages = (1 << 64) - (1 << 4) - 1 +const RejectAfterMessages = 1<<64 - 1<<13 - 1 func TestReplay(t *testing.T) { - var filter ReplayFilter + var filter Filter - T_LIM := CounterWindowSize + 1 + const T_LIM = windowSize + 1 testNumber := 0 - T := func(n uint64, v bool) { + T := func(n uint64, expected bool) { testNumber++ - if filter.ValidateCounter(n, RejectAfterMessages) != v { - t.Fatal("Test", testNumber, "failed", n, v) + if filter.ValidateCounter(n, RejectAfterMessages) != expected { + t.Fatal("Test", testNumber, "failed", n, expected) } } - filter.Init() + filter.Reset() T(0, true) /* 1 */ T(1, true) /* 2 */ @@ -67,53 +67,53 @@ func TestReplay(t *testing.T) { T(0, false) /* 34 */ t.Log("Bulk test 1") - filter.Init() + filter.Reset() testNumber = 0 - for i := uint64(1); i <= CounterWindowSize; i++ { + for i := uint64(1); i <= windowSize; i++ { T(i, true) } T(0, true) T(0, false) t.Log("Bulk test 2") - filter.Init() + filter.Reset() testNumber = 0 - for i := uint64(2); i <= CounterWindowSize+1; i++ { + for i := uint64(2); i <= windowSize+1; i++ { T(i, true) } T(1, true) T(0, false) t.Log("Bulk test 3") - filter.Init() + filter.Reset() testNumber = 0 - for i := CounterWindowSize + 1; i > 0; i-- { + for i := uint64(windowSize + 1); i > 0; i-- { T(i, true) } t.Log("Bulk test 4") - filter.Init() + filter.Reset() testNumber = 0 - for i := CounterWindowSize + 2; i > 1; i-- { + for i := uint64(windowSize + 2); i > 1; i-- { T(i, true) } T(0, false) t.Log("Bulk test 5") - filter.Init() + filter.Reset() testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { + for i := uint64(windowSize); i > 0; i-- { T(i, true) } - T(CounterWindowSize+1, true) + T(windowSize+1, true) T(0, false) t.Log("Bulk test 6") - filter.Init() + filter.Reset() testNumber = 0 - for i := CounterWindowSize; i > 0; i-- { + for i := uint64(windowSize); i > 0; i-- { T(i, true) } T(0, true) - T(CounterWindowSize+1, true) + T(windowSize+1, true) } diff --git a/rwcancel/fdset.go b/rwcancel/fdset.go deleted file mode 100644 index 91bd757..0000000 --- a/rwcancel/fdset.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build !windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package rwcancel - -import "golang.org/x/sys/unix" - -type fdSet struct { - unix.FdSet -} - -func (fdset *fdSet) set(i int) { - bits := 32 << (^uint(0) >> 63) - fdset.Bits[i/bits] |= 1 << uint(i%bits) -} - -func (fdset *fdSet) check(i int) bool { - bits := 32 << (^uint(0) >> 63) - return (fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0 -} diff --git a/rwcancel/rwcancel.go b/rwcancel/rwcancel.go index 3c4a9ed..e397c0e 100644 --- a/rwcancel/rwcancel.go +++ b/rwcancel/rwcancel.go @@ -1,8 +1,8 @@ -// +build !windows +//go:build !windows && !wasm /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ // Package rwcancel implements cancelable read/write operations on @@ -17,13 +17,6 @@ import ( "golang.org/x/sys/unix" ) -func max(a, b int) int { - if a > b { - return a - } - return b -} - type RWCancel struct { fd int closingReader *os.File @@ -46,27 +39,16 @@ func NewRWCancel(fd int) (*RWCancel, error) { } func RetryAfterError(err error) bool { - if pe, ok := err.(*os.PathError); ok { - err = pe.Err - } - if errno, ok := err.(syscall.Errno); ok { - switch errno { - case syscall.EAGAIN, syscall.EINTR: - return true - } - - } - return false + return errors.Is(err, syscall.EAGAIN) || errors.Is(err, syscall.EINTR) } func (rw *RWCancel) ReadyRead() bool { - closeFd := int(rw.closingReader.Fd()) - fdset := fdSet{} - fdset.set(rw.fd) - fdset.set(closeFd) + closeFd := int32(rw.closingReader.Fd()) + + pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLIN}, {Fd: closeFd, Events: unix.POLLIN}} var err error for { - err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) + _, err = unix.Poll(pollFds, -1) if err == nil || !RetryAfterError(err) { break } @@ -74,20 +56,18 @@ func (rw *RWCancel) ReadyRead() bool { if err != nil { return false } - if fdset.check(closeFd) { + if pollFds[1].Revents != 0 { return false } - return fdset.check(rw.fd) + return pollFds[0].Revents != 0 } func (rw *RWCancel) ReadyWrite() bool { - closeFd := int(rw.closingReader.Fd()) - fdset := fdSet{} - fdset.set(rw.fd) - fdset.set(closeFd) + closeFd := int32(rw.closingReader.Fd()) + pollFds := []unix.PollFd{{Fd: int32(rw.fd), Events: unix.POLLOUT}, {Fd: closeFd, Events: unix.POLLOUT}} var err error for { - err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) + _, err = unix.Poll(pollFds, -1) if err == nil || !RetryAfterError(err) { break } @@ -95,10 +75,11 @@ func (rw *RWCancel) ReadyWrite() bool { if err != nil { return false } - if fdset.check(closeFd) { + + if pollFds[1].Revents != 0 { return false } - return fdset.check(rw.fd) + return pollFds[0].Revents != 0 } func (rw *RWCancel) Read(p []byte) (n int, err error) { @@ -108,7 +89,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) { return n, err } if !rw.ReadyRead() { - return 0, errors.New("fd closed") + return 0, os.ErrClosed } } } @@ -120,7 +101,7 @@ func (rw *RWCancel) Write(p []byte) (n int, err error) { return n, err } if !rw.ReadyWrite() { - return 0, errors.New("fd closed") + return 0, os.ErrClosed } } } @@ -129,3 +110,8 @@ func (rw *RWCancel) Cancel() (err error) { _, err = rw.closingWriter.Write([]byte{0}) return } + +func (rw *RWCancel) Close() { + rw.closingReader.Close() + rw.closingWriter.Close() +} diff --git a/rwcancel/rwcancel_windows.go b/rwcancel/rwcancel_stub.go index 0316911..2a98b2b 100644 --- a/rwcancel/rwcancel_windows.go +++ b/rwcancel/rwcancel_stub.go @@ -1,8 +1,9 @@ +//go:build windows || wasm + // SPDX-License-Identifier: MIT package rwcancel -type RWCancel struct { -} +type RWCancel struct{} func (*RWCancel) Cancel() {} diff --git a/rwcancel/select_default.go b/rwcancel/select_default.go deleted file mode 100644 index 82e07dc..0000000 --- a/rwcancel/select_default.go +++ /dev/null @@ -1,15 +0,0 @@ -// +build !linux,!windows - -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package rwcancel - -import "golang.org/x/sys/unix" - -func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error { - _, err := unix.Select(nfd, r, w, e, timeout) - return err -} diff --git a/rwcancel/select_linux.go b/rwcancel/select_linux.go deleted file mode 100644 index 1a72e0a..0000000 --- a/rwcancel/select_linux.go +++ /dev/null @@ -1,13 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. - */ - -package rwcancel - -import "golang.org/x/sys/unix" - -func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) (err error) { - _, err = unix.Select(nfd, r, w, e, timeout) - return -} diff --git a/tai64n/tai64n.go b/tai64n/tai64n.go index 565aaa4..8f10b39 100644 --- a/tai64n/tai64n.go +++ b/tai64n/tai64n.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tai64n @@ -11,22 +11,31 @@ import ( "time" ) -const TimestampSize = 12 -const base = uint64(0x400000000000000a) -const whitenerMask = uint32(0x1000000 - 1) +const ( + TimestampSize = 12 + base = uint64(0x400000000000000a) + whitenerMask = uint32(0x1000000 - 1) +) type Timestamp [TimestampSize]byte -func Now() Timestamp { +func stamp(t time.Time) Timestamp { var tai64n Timestamp - now := time.Now() - secs := base + uint64(now.Unix()) - nano := uint32(now.Nanosecond()) &^ whitenerMask + secs := base + uint64(t.Unix()) + nano := uint32(t.Nanosecond()) &^ whitenerMask binary.BigEndian.PutUint64(tai64n[:], secs) binary.BigEndian.PutUint32(tai64n[8:], nano) return tai64n } +func Now() Timestamp { + return stamp(time.Now()) +} + func (t1 Timestamp) After(t2 Timestamp) bool { return bytes.Compare(t1[:], t2[:]) > 0 } + +func (t Timestamp) String() string { + return time.Unix(int64(binary.BigEndian.Uint64(t[:8])-base), int64(binary.BigEndian.Uint32(t[8:12]))).String() +} diff --git a/tai64n/tai64n_test.go b/tai64n/tai64n_test.go index 859660f..c70fc1a 100644 --- a/tai64n/tai64n_test.go +++ b/tai64n/tai64n_test.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tai64n @@ -10,21 +10,31 @@ import ( "time" ) -/* Testing the essential property of the timestamp - * as used by WireGuard. - */ +// Test that timestamps are monotonic as required by Wireguard and that +// nanosecond-level information is whitened to prevent side channel attacks. func TestMonotonic(t *testing.T) { - old := Now() - for i := 0; i < 50; i++ { - next := Now() - if next.After(old) { - t.Error("Whitening insufficient") - } - time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1) - next = Now() - if !next.After(old) { - t.Error("Not monotonically increasing on whitened nano-second scale") - } - old = next + startTime := time.Unix(0, 123456789) // a nontrivial bit pattern + // Whitening should reduce timestamp granularity + // to more than 10 but fewer than 20 milliseconds. + tests := []struct { + name string + t1, t2 time.Time + wantAfter bool + }{ + {"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false}, + {"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false}, + {"after_1_ms", startTime, startTime.Add(time.Millisecond), false}, + {"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false}, + {"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ts1, ts2 := stamp(tt.t1), stamp(tt.t2) + got := ts2.After(ts1) + if got != tt.wantAfter { + t.Errorf("after = %v; want %v", got, tt.wantAfter) + } + }) } } diff --git a/tests/netns.sh b/tests/netns.sh index 02d428b..2f2a2cd 100755 --- a/tests/netns.sh +++ b/tests/netns.sh @@ -36,7 +36,7 @@ netns0="wg-test-$$-0" netns1="wg-test-$$-1" netns2="wg-test-$$-2" program=$1 -export LOG_LEVEL="info" +export LOG_LEVEL="verbose" pretty() { echo -e "\x1b[32m\x1b[1m[+] ${1:+NS$1: }${2}\x1b[0m" >&3; } pp() { pretty "" "$*"; "$@"; } diff --git a/tun/alignment_windows_test.go b/tun/alignment_windows_test.go new file mode 100644 index 0000000..67a785e --- /dev/null +++ b/tun/alignment_windows_test.go @@ -0,0 +1,67 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "reflect" + "testing" + "unsafe" +) + +func checkAlignment(t *testing.T, name string, offset uintptr) { + t.Helper() + if offset%8 != 0 { + t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8)) + } +} + +// TestRateJugglerAlignment checks that atomically-accessed fields are +// aligned to 64-bit boundaries, as required by the atomic package. +// +// Unfortunately, violating this rule on 32-bit platforms results in a +// hard segfault at runtime. +func TestRateJugglerAlignment(t *testing.T) { + var r rateJuggler + + typ := reflect.TypeOf(&r).Elem() + t.Logf("Peer type size: %d, with fields:", typ.Size()) + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", + field.Name, + field.Offset, + field.Type.Size(), + field.Type.Align(), + ) + } + + checkAlignment(t, "rateJuggler.current", unsafe.Offsetof(r.current)) + checkAlignment(t, "rateJuggler.nextByteCount", unsafe.Offsetof(r.nextByteCount)) + checkAlignment(t, "rateJuggler.nextStartTime", unsafe.Offsetof(r.nextStartTime)) +} + +// TestNativeTunAlignment checks that atomically-accessed fields are +// aligned to 64-bit boundaries, as required by the atomic package. +// +// Unfortunately, violating this rule on 32-bit platforms results in a +// hard segfault at runtime. +func TestNativeTunAlignment(t *testing.T) { + var tun NativeTun + + typ := reflect.TypeOf(&tun).Elem() + t.Logf("Peer type size: %d, with fields:", typ.Size()) + for i := 0; i < typ.NumField(); i++ { + field := typ.Field(i) + t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)", + field.Name, + field.Offset, + field.Type.Size(), + field.Type.Align(), + ) + } + + checkAlignment(t, "NativeTun.rate", unsafe.Offsetof(tun.rate)) +} diff --git a/tun/checksum.go b/tun/checksum.go new file mode 100644 index 0000000..29a8fc8 --- /dev/null +++ b/tun/checksum.go @@ -0,0 +1,118 @@ +package tun + +import "encoding/binary" + +// TODO: Explore SIMD and/or other assembly optimizations. +// TODO: Test native endian loads. See RFC 1071 section 2 part B. +func checksumNoFold(b []byte, initial uint64) uint64 { + ac := initial + + for len(b) >= 128 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + ac += uint64(binary.BigEndian.Uint32(b[64:68])) + ac += uint64(binary.BigEndian.Uint32(b[68:72])) + ac += uint64(binary.BigEndian.Uint32(b[72:76])) + ac += uint64(binary.BigEndian.Uint32(b[76:80])) + ac += uint64(binary.BigEndian.Uint32(b[80:84])) + ac += uint64(binary.BigEndian.Uint32(b[84:88])) + ac += uint64(binary.BigEndian.Uint32(b[88:92])) + ac += uint64(binary.BigEndian.Uint32(b[92:96])) + ac += uint64(binary.BigEndian.Uint32(b[96:100])) + ac += uint64(binary.BigEndian.Uint32(b[100:104])) + ac += uint64(binary.BigEndian.Uint32(b[104:108])) + ac += uint64(binary.BigEndian.Uint32(b[108:112])) + ac += uint64(binary.BigEndian.Uint32(b[112:116])) + ac += uint64(binary.BigEndian.Uint32(b[116:120])) + ac += uint64(binary.BigEndian.Uint32(b[120:124])) + ac += uint64(binary.BigEndian.Uint32(b[124:128])) + b = b[128:] + } + if len(b) >= 64 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + ac += uint64(binary.BigEndian.Uint32(b[32:36])) + ac += uint64(binary.BigEndian.Uint32(b[36:40])) + ac += uint64(binary.BigEndian.Uint32(b[40:44])) + ac += uint64(binary.BigEndian.Uint32(b[44:48])) + ac += uint64(binary.BigEndian.Uint32(b[48:52])) + ac += uint64(binary.BigEndian.Uint32(b[52:56])) + ac += uint64(binary.BigEndian.Uint32(b[56:60])) + ac += uint64(binary.BigEndian.Uint32(b[60:64])) + b = b[64:] + } + if len(b) >= 32 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + ac += uint64(binary.BigEndian.Uint32(b[16:20])) + ac += uint64(binary.BigEndian.Uint32(b[20:24])) + ac += uint64(binary.BigEndian.Uint32(b[24:28])) + ac += uint64(binary.BigEndian.Uint32(b[28:32])) + b = b[32:] + } + if len(b) >= 16 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + ac += uint64(binary.BigEndian.Uint32(b[8:12])) + ac += uint64(binary.BigEndian.Uint32(b[12:16])) + b = b[16:] + } + if len(b) >= 8 { + ac += uint64(binary.BigEndian.Uint32(b[:4])) + ac += uint64(binary.BigEndian.Uint32(b[4:8])) + b = b[8:] + } + if len(b) >= 4 { + ac += uint64(binary.BigEndian.Uint32(b)) + b = b[4:] + } + if len(b) >= 2 { + ac += uint64(binary.BigEndian.Uint16(b)) + b = b[2:] + } + if len(b) == 1 { + ac += uint64(b[0]) << 8 + } + + return ac +} + +func checksum(b []byte, initial uint64) uint16 { + ac := checksumNoFold(b, initial) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + ac = (ac >> 16) + (ac & 0xffff) + return uint16(ac) +} + +func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 { + sum := checksumNoFold(srcAddr, 0) + sum = checksumNoFold(dstAddr, sum) + sum = checksumNoFold([]byte{0, protocol}, sum) + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + return checksumNoFold(tmp, sum) +} diff --git a/tun/checksum_test.go b/tun/checksum_test.go new file mode 100644 index 0000000..c1ccff5 --- /dev/null +++ b/tun/checksum_test.go @@ -0,0 +1,35 @@ +package tun + +import ( + "fmt" + "math/rand" + "testing" +) + +func BenchmarkChecksum(b *testing.B) { + lengths := []int{ + 64, + 128, + 256, + 512, + 1024, + 1500, + 2048, + 4096, + 8192, + 9000, + 9001, + } + + for _, length := range lengths { + b.Run(fmt.Sprintf("%d", length), func(b *testing.B) { + buf := make([]byte, length) + rng := rand.New(rand.NewSource(1)) + rng.Read(buf) + b.ResetTimer() + for i := 0; i < b.N; i++ { + checksum(buf, 0) + } + }) + } +} diff --git a/tun/errors.go b/tun/errors.go new file mode 100644 index 0000000..75ae3a4 --- /dev/null +++ b/tun/errors.go @@ -0,0 +1,12 @@ +package tun + +import ( + "errors" +) + +var ( + // ErrTooManySegments is returned by Device.Read() when segmentation + // overflows the length of supplied buffers. This error should not cause + // reads to cease. + ErrTooManySegments = errors.New("too many segments") +) diff --git a/tun/netstack/examples/http_client.go b/tun/netstack/examples/http_client.go new file mode 100644 index 0000000..ccd32ed --- /dev/null +++ b/tun/netstack/examples/http_client.go @@ -0,0 +1,54 @@ +//go:build ignore + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package main + +import ( + "io" + "log" + "net/http" + "net/netip" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func main() { + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr("192.168.4.28")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + 1420) + if err != nil { + log.Panic(err) + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + err = dev.IpcSet(`private_key=087ec6e14bbed210e7215cdc73468dfa23f080a1bfb8665b2fd809bd99d28379 +public_key=c4c8e984c5322c8184c72265b92b250fdb63688705f504ba003c88f03393cf28 +allowed_ip=0.0.0.0/0 +endpoint=127.0.0.1:58120 +`) + err = dev.Up() + if err != nil { + log.Panic(err) + } + + client := http.Client{ + Transport: &http.Transport{ + DialContext: tnet.DialContext, + }, + } + resp, err := client.Get("http://192.168.4.29/") + if err != nil { + log.Panic(err) + } + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Panic(err) + } + log.Println(string(body)) +} diff --git a/tun/netstack/examples/http_server.go b/tun/netstack/examples/http_server.go new file mode 100644 index 0000000..f5b7a8f --- /dev/null +++ b/tun/netstack/examples/http_server.go @@ -0,0 +1,51 @@ +//go:build ignore + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package main + +import ( + "io" + "log" + "net" + "net/http" + "net/netip" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func main() { + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")}, + 1420, + ) + if err != nil { + log.Panic(err) + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + dev.IpcSet(`private_key=003ed5d73b55806c30de3f8a7bdab38af13539220533055e635690b8b87ad641 +listen_port=58120 +public_key=f928d4f6c1b86c12f2562c10b07c555c5c57fd00f59e90c8d8d88767271cbf7c +allowed_ip=192.168.4.28/32 +persistent_keepalive_interval=25 +`) + dev.Up() + listener, err := tnet.ListenTCP(&net.TCPAddr{Port: 80}) + if err != nil { + log.Panicln(err) + } + http.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { + log.Printf("> %s - %s - %s", request.RemoteAddr, request.URL.String(), request.UserAgent()) + io.WriteString(writer, "Hello from userspace TCP!") + }) + err = http.Serve(listener, nil) + if err != nil { + log.Panicln(err) + } +} diff --git a/tun/netstack/examples/ping_client.go b/tun/netstack/examples/ping_client.go new file mode 100644 index 0000000..2eef0fb --- /dev/null +++ b/tun/netstack/examples/ping_client.go @@ -0,0 +1,75 @@ +//go:build ignore + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package main + +import ( + "bytes" + "log" + "math/rand" + "net/netip" + "time" + + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func main() { + tun, tnet, err := netstack.CreateNetTUN( + []netip.Addr{netip.MustParseAddr("192.168.4.29")}, + []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + 1420) + if err != nil { + log.Panic(err) + } + dev := device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(device.LogLevelVerbose, "")) + dev.IpcSet(`private_key=a8dac1d8a70a751f0f699fb14ba1cff7b79cf4fbd8f09f44c6e6a90d0369604f +public_key=25123c5dcd3328ff645e4f2a3fce0d754400d3887a0cb7c56f0267e20fbf3c5b +endpoint=163.172.161.0:12912 +allowed_ip=0.0.0.0/0 +`) + err = dev.Up() + if err != nil { + log.Panic(err) + } + + socket, err := tnet.Dial("ping4", "zx2c4.com") + if err != nil { + log.Panic(err) + } + requestPing := icmp.Echo{ + Seq: rand.Intn(1 << 16), + Data: []byte("gopher burrow"), + } + icmpBytes, _ := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + socket.SetReadDeadline(time.Now().Add(time.Second * 10)) + start := time.Now() + _, err = socket.Write(icmpBytes) + if err != nil { + log.Panic(err) + } + n, err := socket.Read(icmpBytes[:]) + if err != nil { + log.Panic(err) + } + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + if err != nil { + log.Panic(err) + } + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + log.Panicf("invalid reply type: %v", replyPacket) + } + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + log.Panicf("invalid ping reply: %v", replyPing) + } + log.Printf("Ping latency: %v", time.Since(start)) +} diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go new file mode 100644 index 0000000..2b73054 --- /dev/null +++ b/tun/netstack/tun.go @@ -0,0 +1,1055 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package netstack + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "golang.zx2c4.com/wireguard/tun" + + "golang.org/x/net/dns/dnsmessage" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +type netTun struct { + ep *channel.Endpoint + stack *stack.Stack + events chan tun.Event + incomingPacket chan *buffer.View + mtu int + dnsServers []netip.Addr + hasV4, hasV6 bool +} + +type Net netTun + +func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { + opts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4}, + HandleLocal: true, + } + dev := &netTun{ + ep: channel.New(1024, uint32(mtu), ""), + stack: stack.New(opts), + events: make(chan tun.Event, 10), + incomingPacket: make(chan *buffer.View), + dnsServers: dnsServers, + mtu: mtu, + } + sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default + tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) + } + dev.ep.AddNotify(dev) + tcpipErr = dev.stack.CreateNIC(1, dev.ep) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr) + } + for _, ip := range localAddresses { + var protoNumber tcpip.NetworkProtocolNumber + if ip.Is4() { + protoNumber = ipv4.ProtocolNumber + } else if ip.Is6() { + protoNumber = ipv6.ProtocolNumber + } + protoAddr := tcpip.ProtocolAddress{ + Protocol: protoNumber, + AddressWithPrefix: tcpip.AddrFromSlice(ip.AsSlice()).WithPrefix(), + } + tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}) + if tcpipErr != nil { + return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr) + } + if ip.Is4() { + dev.hasV4 = true + } else if ip.Is6() { + dev.hasV6 = true + } + } + if dev.hasV4 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv4EmptySubnet, NIC: 1}) + } + if dev.hasV6 { + dev.stack.AddRoute(tcpip.Route{Destination: header.IPv6EmptySubnet, NIC: 1}) + } + + dev.events <- tun.EventUp + return dev, (*Net)(dev), nil +} + +func (tun *netTun) Name() (string, error) { + return "go", nil +} + +func (tun *netTun) File() *os.File { + return nil +} + +func (tun *netTun) Events() <-chan tun.Event { + return tun.events +} + +func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { + view, ok := <-tun.incomingPacket + if !ok { + return 0, os.ErrClosed + } + + n, err := view.Read(buf[0][offset:]) + if err != nil { + return 0, err + } + sizes[0] = n + return 1, nil +} + +func (tun *netTun) Write(buf [][]byte, offset int) (int, error) { + for _, buf := range buf { + packet := buf[offset:] + if len(packet) == 0 { + continue + } + + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{Payload: buffer.MakeWithData(packet)}) + switch packet[0] >> 4 { + case 4: + tun.ep.InjectInbound(header.IPv4ProtocolNumber, pkb) + case 6: + tun.ep.InjectInbound(header.IPv6ProtocolNumber, pkb) + default: + return 0, syscall.EAFNOSUPPORT + } + } + return len(buf), nil +} + +func (tun *netTun) WriteNotify() { + pkt := tun.ep.Read() + if pkt.IsNil() { + return + } + + view := pkt.ToView() + pkt.DecRef() + + tun.incomingPacket <- view +} + +func (tun *netTun) Close() error { + tun.stack.RemoveNIC(1) + + if tun.events != nil { + close(tun.events) + } + + tun.ep.Close() + + if tun.incomingPacket != nil { + close(tun.incomingPacket) + } + + return nil +} + +func (tun *netTun) MTU() (int, error) { + return tun.mtu, nil +} + +func (tun *netTun) BatchSize() int { + return 1 +} + +func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) { + var protoNumber tcpip.NetworkProtocolNumber + if endpoint.Addr().Is4() { + protoNumber = ipv4.ProtocolNumber + } else { + protoNumber = ipv6.ProtocolNumber + } + return tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFromSlice(endpoint.Addr().AsSlice()), + Port: endpoint.Port(), + }, protoNumber +} + +func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialContextTCP(ctx, net.stack, fa, pn) +} + +func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialContextTCPAddrPort(ctx, netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) { + fa, pn := convertToFullAddr(addr) + return gonet.DialTCP(net.stack, fa, pn) +} + +func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) { + if addr == nil { + return net.DialTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.DialTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) { + fa, pn := convertToFullAddr(addr) + return gonet.ListenTCP(net.stack, fa, pn) +} + +func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) { + if addr == nil { + return net.ListenTCPAddrPort(netip.AddrPort{}) + } + ip, _ := netip.AddrFromSlice(addr.IP) + return net.ListenTCPAddrPort(netip.AddrPortFrom(ip, uint16(addr.Port))) +} + +func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) { + var lfa, rfa *tcpip.FullAddress + var pn tcpip.NetworkProtocolNumber + if laddr.IsValid() || laddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(laddr) + lfa = &addr + } + if raddr.IsValid() || raddr.Port() > 0 { + var addr tcpip.FullAddress + addr, pn = convertToFullAddr(raddr) + rfa = &addr + } + return gonet.DialUDP(net.stack, lfa, rfa, pn) +} + +func (net *Net) ListenUDPAddrPort(laddr netip.AddrPort) (*gonet.UDPConn, error) { + return net.DialUDPAddrPort(laddr, netip.AddrPort{}) +} + +func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) { + var la, ra netip.AddrPort + if laddr != nil { + ip, _ := netip.AddrFromSlice(laddr.IP) + la = netip.AddrPortFrom(ip, uint16(laddr.Port)) + } + if raddr != nil { + ip, _ := netip.AddrFromSlice(raddr.IP) + ra = netip.AddrPortFrom(ip, uint16(raddr.Port)) + } + return net.DialUDPAddrPort(la, ra) +} + +func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { + return net.DialUDP(laddr, nil) +} + +type PingConn struct { + laddr PingAddr + raddr PingAddr + wq waiter.Queue + ep tcpip.Endpoint + deadline *time.Timer +} + +type PingAddr struct{ addr netip.Addr } + +func (ia PingAddr) String() string { + return ia.addr.String() +} + +func (ia PingAddr) Network() string { + if ia.addr.Is4() { + return "ping4" + } else if ia.addr.Is6() { + return "ping6" + } + return "ping" +} + +func (ia PingAddr) Addr() netip.Addr { + return ia.addr +} + +func PingAddrFromAddr(addr netip.Addr) *PingAddr { + return &PingAddr{addr} +} + +func (net *Net) DialPingAddr(laddr, raddr netip.Addr) (*PingConn, error) { + if !laddr.IsValid() && !raddr.IsValid() { + return nil, errors.New("ping dial: invalid address") + } + v6 := laddr.Is6() || raddr.Is6() + bind := laddr.IsValid() + if !bind { + if v6 { + laddr = netip.IPv6Unspecified() + } else { + laddr = netip.IPv4Unspecified() + } + } + + tn := icmp.ProtocolNumber4 + pn := ipv4.ProtocolNumber + if v6 { + tn = icmp.ProtocolNumber6 + pn = ipv6.ProtocolNumber + } + + pc := &PingConn{ + laddr: PingAddr{laddr}, + deadline: time.NewTimer(time.Hour << 10), + } + pc.deadline.Stop() + + ep, tcpipErr := net.stack.NewEndpoint(tn, pn, &pc.wq) + if tcpipErr != nil { + return nil, fmt.Errorf("ping socket: endpoint: %s", tcpipErr) + } + pc.ep = ep + + if bind { + fa, _ := convertToFullAddr(netip.AddrPortFrom(laddr, 0)) + if tcpipErr = pc.ep.Bind(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping bind: %s", tcpipErr) + } + } + + if raddr.IsValid() { + pc.raddr = PingAddr{raddr} + fa, _ := convertToFullAddr(netip.AddrPortFrom(raddr, 0)) + if tcpipErr = pc.ep.Connect(fa); tcpipErr != nil { + return nil, fmt.Errorf("ping connect: %s", tcpipErr) + } + } + + return pc, nil +} + +func (net *Net) ListenPingAddr(laddr netip.Addr) (*PingConn, error) { + return net.DialPingAddr(laddr, netip.Addr{}) +} + +func (net *Net) DialPing(laddr, raddr *PingAddr) (*PingConn, error) { + var la, ra netip.Addr + if laddr != nil { + la = laddr.addr + } + if raddr != nil { + ra = raddr.addr + } + return net.DialPingAddr(la, ra) +} + +func (net *Net) ListenPing(laddr *PingAddr) (*PingConn, error) { + var la netip.Addr + if laddr != nil { + la = laddr.addr + } + return net.ListenPingAddr(la) +} + +func (pc *PingConn) LocalAddr() net.Addr { + return pc.laddr +} + +func (pc *PingConn) RemoteAddr() net.Addr { + return pc.raddr +} + +func (pc *PingConn) Close() error { + pc.deadline.Reset(0) + pc.ep.Close() + return nil +} + +func (pc *PingConn) SetWriteDeadline(t time.Time) error { + return errors.New("not implemented") +} + +func (pc *PingConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + var na netip.Addr + switch v := addr.(type) { + case *PingAddr: + na = v.addr + case *net.IPAddr: + na, _ = netip.AddrFromSlice(v.IP) + default: + return 0, fmt.Errorf("ping write: wrong net.Addr type") + } + if !((na.Is4() && pc.laddr.addr.Is4()) || (na.Is6() && pc.laddr.addr.Is6())) { + return 0, fmt.Errorf("ping write: mismatched protocols") + } + + buf := bytes.NewReader(p) + rfa, _ := convertToFullAddr(netip.AddrPortFrom(na, 0)) + // won't block, no deadlines + n64, tcpipErr := pc.ep.Write(buf, tcpip.WriteOptions{ + To: &rfa, + }) + if tcpipErr != nil { + return int(n64), fmt.Errorf("ping write: %s", tcpipErr) + } + + return int(n64), nil +} + +func (pc *PingConn) Write(p []byte) (n int, err error) { + return pc.WriteTo(p, &pc.raddr) +} + +func (pc *PingConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + e, notifyCh := waiter.NewChannelEntry(waiter.EventIn) + pc.wq.EventRegister(&e) + defer pc.wq.EventUnregister(&e) + + select { + case <-pc.deadline.C: + return 0, nil, os.ErrDeadlineExceeded + case <-notifyCh: + } + + w := tcpip.SliceWriter(p) + + res, tcpipErr := pc.ep.Read(&w, tcpip.ReadOptions{ + NeedRemoteAddr: true, + }) + if tcpipErr != nil { + return 0, nil, fmt.Errorf("ping read: %s", tcpipErr) + } + + remoteAddr, _ := netip.AddrFromSlice(res.RemoteAddr.Addr.AsSlice()) + return res.Count, &PingAddr{remoteAddr}, nil +} + +func (pc *PingConn) Read(p []byte) (n int, err error) { + n, _, err = pc.ReadFrom(p) + return +} + +func (pc *PingConn) SetDeadline(t time.Time) error { + // pc.SetWriteDeadline is unimplemented + + return pc.SetReadDeadline(t) +} + +func (pc *PingConn) SetReadDeadline(t time.Time) error { + pc.deadline.Reset(time.Until(t)) + return nil +} + +var ( + errNoSuchHost = errors.New("no such host") + errLameReferral = errors.New("lame referral") + errCannotUnmarshalDNSMessage = errors.New("cannot unmarshal DNS message") + errCannotMarshalDNSMessage = errors.New("cannot marshal DNS message") + errServerMisbehaving = errors.New("server misbehaving") + errInvalidDNSResponse = errors.New("invalid DNS response") + errNoAnswerFromDNSServer = errors.New("no answer from DNS server") + errServerTemporarilyMisbehaving = errors.New("server misbehaving") + errCanceled = errors.New("operation was canceled") + errTimeout = errors.New("i/o timeout") + errNumericPort = errors.New("port must be numeric") + errNoSuitableAddress = errors.New("no suitable address found") + errMissingAddress = errors.New("missing address") +) + +func (net *Net) LookupHost(host string) (addrs []string, err error) { + return net.LookupContextHost(context.Background(), host) +} + +func isDomainName(s string) bool { + l := len(s) + if l == 0 || l > 254 || l == 254 && s[l-1] != '.' { + return false + } + last := byte('.') + nonNumeric := false + partlen := 0 + for i := 0; i < len(s); i++ { + c := s[i] + switch { + default: + return false + case 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || c == '_': + nonNumeric = true + partlen++ + case '0' <= c && c <= '9': + partlen++ + case c == '-': + if last == '.' { + return false + } + partlen++ + nonNumeric = true + case c == '.': + if last == '.' || last == '-' { + return false + } + if partlen > 63 || partlen == 0 { + return false + } + partlen = 0 + } + last = c + } + if last == '-' || partlen > 63 { + return false + } + return nonNumeric +} + +func randU16() uint16 { + var b [2]byte + _, err := rand.Read(b[:]) + if err != nil { + panic(err) + } + return binary.LittleEndian.Uint16(b[:]) +} + +func newRequest(q dnsmessage.Question) (id uint16, udpReq, tcpReq []byte, err error) { + id = randU16() + b := dnsmessage.NewBuilder(make([]byte, 2, 514), dnsmessage.Header{ID: id, RecursionDesired: true}) + b.EnableCompression() + if err := b.StartQuestions(); err != nil { + return 0, nil, nil, err + } + if err := b.Question(q); err != nil { + return 0, nil, nil, err + } + tcpReq, err = b.Finish() + udpReq = tcpReq[2:] + l := len(tcpReq) - 2 + tcpReq[0] = byte(l >> 8) + tcpReq[1] = byte(l) + return id, udpReq, tcpReq, err +} + +func equalASCIIName(x, y dnsmessage.Name) bool { + if x.Length != y.Length { + return false + } + for i := 0; i < int(x.Length); i++ { + a := x.Data[i] + b := y.Data[i] + if 'A' <= a && a <= 'Z' { + a += 0x20 + } + if 'A' <= b && b <= 'Z' { + b += 0x20 + } + if a != b { + return false + } + } + return true +} + +func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage.Header, respQues dnsmessage.Question) bool { + if !respHdr.Response { + return false + } + if reqID != respHdr.ID { + return false + } + if reqQues.Type != respQues.Type || reqQues.Class != respQues.Class || !equalASCIIName(reqQues.Name, respQues.Name) { + return false + } + return true +} + +func dnsPacketRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 512) + for { + n, err := c.Read(b) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + continue + } + q, err := p.Question() + if err != nil || !checkResponse(id, query, h, q) { + continue + } + return p, h, nil + } +} + +func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []byte) (dnsmessage.Parser, dnsmessage.Header, error) { + if _, err := c.Write(b); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + b = make([]byte, 1280) + if _, err := io.ReadFull(c, b[:2]); err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + l := int(b[0])<<8 | int(b[1]) + if l > len(b) { + b = make([]byte, l) + } + n, err := io.ReadFull(c, b[:l]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + var p dnsmessage.Parser + h, err := p.Start(b[:n]) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + q, err := p.Question() + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotUnmarshalDNSMessage + } + if !checkResponse(id, query, h, q) { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + return p, h, nil +} + +func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) { + q.Class = dnsmessage.ClassINET + id, udpReq, tcpReq, err := newRequest(q) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, errCannotMarshalDNSMessage + } + + for _, useUDP := range []bool{true, false} { + ctx, cancel := context.WithDeadline(ctx, time.Now().Add(timeout)) + defer cancel() + + var c net.Conn + var err error + if useUDP { + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53)) + } else { + c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53)) + } + + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if d, ok := ctx.Deadline(); ok && !d.IsZero() { + err := c.SetDeadline(d) + if err != nil { + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + } + var p dnsmessage.Parser + var h dnsmessage.Header + if useUDP { + p, h, err = dnsPacketRoundTrip(c, id, q, udpReq) + } else { + p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq) + } + c.Close() + if err != nil { + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return dnsmessage.Parser{}, dnsmessage.Header{}, err + } + if err := p.SkipQuestion(); err != dnsmessage.ErrSectionDone { + return dnsmessage.Parser{}, dnsmessage.Header{}, errInvalidDNSResponse + } + if h.Truncated { + continue + } + return p, h, nil + } + return dnsmessage.Parser{}, dnsmessage.Header{}, errNoAnswerFromDNSServer +} + +func checkHeader(p *dnsmessage.Parser, h dnsmessage.Header) error { + if h.RCode == dnsmessage.RCodeNameError { + return errNoSuchHost + } + _, err := p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + return errCannotUnmarshalDNSMessage + } + if h.RCode == dnsmessage.RCodeSuccess && !h.Authoritative && !h.RecursionAvailable && err == dnsmessage.ErrSectionDone { + return errLameReferral + } + if h.RCode != dnsmessage.RCodeSuccess && h.RCode != dnsmessage.RCodeNameError { + if h.RCode == dnsmessage.RCodeServerFailure { + return errServerTemporarilyMisbehaving + } + return errServerMisbehaving + } + return nil +} + +func skipToAnswer(p *dnsmessage.Parser, qtype dnsmessage.Type) error { + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + return errNoSuchHost + } + if err != nil { + return errCannotUnmarshalDNSMessage + } + if h.Type == qtype { + return nil + } + if err := p.SkipAnswer(); err != nil { + return errCannotUnmarshalDNSMessage + } + } +} + +func (tnet *Net) tryOneName(ctx context.Context, name string, qtype dnsmessage.Type) (dnsmessage.Parser, string, error) { + var lastErr error + + n, err := dnsmessage.NewName(name) + if err != nil { + return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + } + q := dnsmessage.Question{ + Name: n, + Type: qtype, + Class: dnsmessage.ClassINET, + } + + for i := 0; i < 2; i++ { + for _, server := range tnet.dnsServers { + p, h, err := tnet.exchange(ctx, server, q, time.Second*5) + if err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if nerr, ok := err.(net.Error); ok && nerr.Timeout() { + dnsErr.IsTimeout = true + } + if _, ok := err.(*net.OpError); ok { + dnsErr.IsTemporary = true + } + lastErr = dnsErr + continue + } + + if err := checkHeader(&p, h); err != nil { + dnsErr := &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errServerTemporarilyMisbehaving { + dnsErr.IsTemporary = true + } + if err == errNoSuchHost { + dnsErr.IsNotFound = true + return p, server.String(), dnsErr + } + lastErr = dnsErr + continue + } + + err = skipToAnswer(&p, qtype) + if err == nil { + return p, server.String(), nil + } + lastErr = &net.DNSError{ + Err: err.Error(), + Name: name, + Server: server.String(), + } + if err == errNoSuchHost { + lastErr.(*net.DNSError).IsNotFound = true + return p, server.String(), lastErr + } + } + } + return dnsmessage.Parser{}, "", lastErr +} + +func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string, error) { + if host == "" || (!tnet.hasV6 && !tnet.hasV4) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + zlen := len(host) + if strings.IndexByte(host, ':') != -1 { + if zidx := strings.LastIndexByte(host, '%'); zidx != -1 { + zlen = zidx + } + } + if ip, err := netip.ParseAddr(host[:zlen]); err == nil { + return []string{ip.String()}, nil + } + + if !isDomainName(host) { + return nil, &net.DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + } + type result struct { + p dnsmessage.Parser + server string + error + } + var addrsV4, addrsV6 []netip.Addr + lanes := 0 + if tnet.hasV4 { + lanes++ + } + if tnet.hasV6 { + lanes++ + } + lane := make(chan result, lanes) + var lastErr error + if tnet.hasV4 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeA) + lane <- result{p, server, err} + }() + } + if tnet.hasV6 { + go func() { + p, server, err := tnet.tryOneName(ctx, host+".", dnsmessage.TypeAAAA) + lane <- result{p, server, err} + }() + } + for l := 0; l < lanes; l++ { + result := <-lane + if result.error != nil { + if lastErr == nil { + lastErr = result.error + } + continue + } + + loop: + for { + h, err := result.p.AnswerHeader() + if err != nil && err != dnsmessage.ErrSectionDone { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + } + if err != nil { + break + } + switch h.Type { + case dnsmessage.TypeA: + a, err := result.p.AResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV4 = append(addrsV4, netip.AddrFrom4(a.A)) + + case dnsmessage.TypeAAAA: + aaaa, err := result.p.AAAAResource() + if err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA)) + + default: + if err := result.p.SkipAnswer(); err != nil { + lastErr = &net.DNSError{ + Err: errCannotMarshalDNSMessage.Error(), + Name: host, + Server: result.server, + } + break loop + } + continue + } + } + } + // We don't do RFC6724. Instead just put V6 addresses first if an IPv6 address is enabled + var addrs []netip.Addr + if tnet.hasV6 { + addrs = append(addrsV6, addrsV4...) + } else { + addrs = append(addrsV4, addrsV6...) + } + + if len(addrs) == 0 && lastErr != nil { + return nil, lastErr + } + saddrs := make([]string, 0, len(addrs)) + for _, ip := range addrs { + saddrs = append(saddrs, ip.String()) + } + return saddrs, nil +} + +func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) { + if deadline.IsZero() { + return deadline, nil + } + timeRemaining := deadline.Sub(now) + if timeRemaining <= 0 { + return time.Time{}, errTimeout + } + timeout := timeRemaining / time.Duration(addrsRemaining) + const saneMinimum = 2 * time.Second + if timeout < saneMinimum { + if timeRemaining < saneMinimum { + timeout = timeRemaining + } else { + timeout = saneMinimum + } + } + return now.Add(timeout), nil +} + +var protoSplitter = regexp.MustCompile(`^(tcp|udp|ping)(4|6)?$`) + +func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if ctx == nil { + panic("nil context") + } + var acceptV4, acceptV6 bool + matches := protoSplitter.FindStringSubmatch(network) + if matches == nil { + return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)} + } else if len(matches[2]) == 0 { + acceptV4 = true + acceptV6 = true + } else { + acceptV4 = matches[2][0] == '4' + acceptV6 = !acceptV4 + } + var host string + var port int + if matches[1] == "ping" { + host = address + } else { + var sport string + var err error + host, sport, err = net.SplitHostPort(address) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + port, err = strconv.Atoi(sport) + if err != nil || port < 0 || port > 65535 { + return nil, &net.OpError{Op: "dial", Err: errNumericPort} + } + } + allAddr, err := tnet.LookupContextHost(ctx, host) + if err != nil { + return nil, &net.OpError{Op: "dial", Err: err} + } + var addrs []netip.AddrPort + for _, addr := range allAddr { + ip, err := netip.ParseAddr(addr) + if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) { + addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port))) + } + } + if len(addrs) == 0 && len(allAddr) != 0 { + return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress} + } + + var firstErr error + for i, addr := range addrs { + select { + case <-ctx.Done(): + err := ctx.Err() + if err == context.Canceled { + err = errCanceled + } else if err == context.DeadlineExceeded { + err = errTimeout + } + return nil, &net.OpError{Op: "dial", Err: err} + default: + } + + dialCtx := ctx + if deadline, hasDeadline := ctx.Deadline(); hasDeadline { + partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i) + if err != nil { + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: err} + } + break + } + if partialDeadline.Before(deadline) { + var cancel context.CancelFunc + dialCtx, cancel = context.WithDeadline(ctx, partialDeadline) + defer cancel() + } + } + + var c net.Conn + switch matches[1] { + case "tcp": + c, err = tnet.DialContextTCPAddrPort(dialCtx, addr) + case "udp": + c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr) + case "ping": + c, err = tnet.DialPingAddr(netip.Addr{}, addr.Addr()) + } + if err == nil { + return c, nil + } + if firstErr == nil { + firstErr = err + } + } + if firstErr == nil { + firstErr = &net.OpError{Op: "dial", Err: errMissingAddress} + } + return nil, firstErr +} + +func (tnet *Net) Dial(network, address string) (net.Conn, error) { + return tnet.DialContext(context.Background(), network, address) +} diff --git a/tun/offload_linux.go b/tun/offload_linux.go new file mode 100644 index 0000000..9ff7fea --- /dev/null +++ b/tun/offload_linux.go @@ -0,0 +1,993 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "unsafe" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" +) + +const tcpFlagsOffset = 13 + +const ( + tcpFlagFIN uint8 = 0x01 + tcpFlagPSH uint8 = 0x08 + tcpFlagACK uint8 = 0x10 +) + +// virtioNetHdr is defined in the kernel in include/uapi/linux/virtio_net.h. The +// kernel symbol is virtio_net_hdr. +type virtioNetHdr struct { + flags uint8 + gsoType uint8 + hdrLen uint16 + gsoSize uint16 + csumStart uint16 + csumOffset uint16 +} + +func (v *virtioNetHdr) decode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen), b[:virtioNetHdrLen]) + return nil +} + +func (v *virtioNetHdr) encode(b []byte) error { + if len(b) < virtioNetHdrLen { + return io.ErrShortBuffer + } + copy(b[:virtioNetHdrLen], unsafe.Slice((*byte)(unsafe.Pointer(v)), virtioNetHdrLen)) + return nil +} + +const ( + // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the + // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). + virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) +) + +// tcpFlowKey represents the key for a TCP flow. +type tcpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + rxAck uint32 // varying ack values should not be coalesced. Treat them as separate flows. + isV6 bool +} + +// tcpGROTable holds flow and coalescing information for the purposes of TCP GRO. +type tcpGROTable struct { + itemsByFlow map[tcpFlowKey][]tcpGROItem + itemsPool [][]tcpGROItem +} + +func newTCPGROTable() *tcpGROTable { + t := &tcpGROTable{ + itemsByFlow: make(map[tcpFlowKey][]tcpGROItem, conn.IdealBatchSize), + itemsPool: make([][]tcpGROItem, conn.IdealBatchSize), + } + for i := range t.itemsPool { + t.itemsPool[i] = make([]tcpGROItem, 0, conn.IdealBatchSize) + } + return t +} + +func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcpFlowKey { + key := tcpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[tcphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[tcphOffset+2:]) + key.rxAck = binary.BigEndian.Uint32(pkt[tcphOffset+8:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + items, ok := t.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { + key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + item := tcpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + } + items, ok := t.itemsByFlow[key] + if !ok { + items = t.newItems() + } + items = append(items, item) + t.itemsByFlow[key] = items +} + +func (t *tcpGROTable) updateAt(item tcpGROItem, i int) { + items, _ := t.itemsByFlow[item.key] + items[i] = item +} + +func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { + items, _ := t.itemsByFlow[key] + items = append(items[:i], items[i+1:]...) + t.itemsByFlow[key] = items +} + +// tcpGROItem represents bookkeeping data for a TCP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type tcpGROItem struct { + key tcpFlowKey + sentSeq uint32 // the sequence number + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set +} + +func (t *tcpGROTable) newItems() []tcpGROItem { + var items []tcpGROItem + items, t.itemsPool = t.itemsPool[len(t.itemsPool)-1], t.itemsPool[:len(t.itemsPool)-1] + return items +} + +func (t *tcpGROTable) reset() { + for k, items := range t.itemsByFlow { + items = items[:0] + t.itemsPool = append(t.itemsPool, items) + delete(t.itemsByFlow, k) + } +} + +// udpFlowKey represents the key for a UDP flow. +type udpFlowKey struct { + srcAddr, dstAddr [16]byte + srcPort, dstPort uint16 + isV6 bool +} + +// udpGROTable holds flow and coalescing information for the purposes of UDP GRO. +type udpGROTable struct { + itemsByFlow map[udpFlowKey][]udpGROItem + itemsPool [][]udpGROItem +} + +func newUDPGROTable() *udpGROTable { + u := &udpGROTable{ + itemsByFlow: make(map[udpFlowKey][]udpGROItem, conn.IdealBatchSize), + itemsPool: make([][]udpGROItem, conn.IdealBatchSize), + } + for i := range u.itemsPool { + u.itemsPool[i] = make([]udpGROItem, 0, conn.IdealBatchSize) + } + return u +} + +func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udpFlowKey { + key := udpFlowKey{} + addrSize := dstAddrOffset - srcAddrOffset + copy(key.srcAddr[:], pkt[srcAddrOffset:dstAddrOffset]) + copy(key.dstAddr[:], pkt[dstAddrOffset:dstAddrOffset+addrSize]) + key.srcPort = binary.BigEndian.Uint16(pkt[udphOffset:]) + key.dstPort = binary.BigEndian.Uint16(pkt[udphOffset+2:]) + key.isV6 = addrSize == 16 + return key +} + +// lookupOrInsert looks up a flow for the provided packet and metadata, +// returning the packets found for the flow, or inserting a new one if none +// is found. +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + items, ok := u.itemsByFlow[key] + if ok { + return items, ok + } + // TODO: insert() performs another map lookup. This could be rearranged to avoid. + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + return nil, false +} + +// insert an item in the table for the provided packet and packet metadata. +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { + key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + item := udpGROItem{ + key: key, + bufsIndex: uint16(bufsIndex), + gsoSize: uint16(len(pkt[udphOffset+udphLen:])), + iphLen: uint8(udphOffset), + cSumKnownInvalid: cSumKnownInvalid, + } + items, ok := u.itemsByFlow[key] + if !ok { + items = u.newItems() + } + items = append(items, item) + u.itemsByFlow[key] = items +} + +func (u *udpGROTable) updateAt(item udpGROItem, i int) { + items, _ := u.itemsByFlow[item.key] + items[i] = item +} + +// udpGROItem represents bookkeeping data for a UDP packet during the lifetime +// of a GRO evaluation across a vector of packets. +type udpGROItem struct { + key udpFlowKey + bufsIndex uint16 // the index into the original bufs slice + numMerged uint16 // the number of packets merged into this item + gsoSize uint16 // payload size + iphLen uint8 // ip header len + cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. +} + +func (u *udpGROTable) newItems() []udpGROItem { + var items []udpGROItem + items, u.itemsPool = u.itemsPool[len(u.itemsPool)-1], u.itemsPool[:len(u.itemsPool)-1] + return items +} + +func (u *udpGROTable) reset() { + for k, items := range u.itemsByFlow { + items = items[:0] + u.itemsPool = append(u.itemsPool, items) + delete(u.itemsByFlow, k) + } +} + +// canCoalesce represents the outcome of checking if two TCP packets are +// candidates for coalescing. +type canCoalesce int + +const ( + coalescePrepend canCoalesce = -1 + coalesceUnavailable canCoalesce = 0 + coalesceAppend canCoalesce = 1 +) + +// ipHeadersCanCoalesce returns true if the IP headers found in pktA and pktB +// meet all requirements to be merged as part of a GRO operation, otherwise it +// returns false. +func ipHeadersCanCoalesce(pktA, pktB []byte) bool { + if len(pktA) < 9 || len(pktB) < 9 { + return false + } + if pktA[0]>>4 == 6 { + if pktA[0] != pktB[0] || pktA[1]>>4 != pktB[1]>>4 { + // cannot coalesce with unequal Traffic class values + return false + } + if pktA[7] != pktB[7] { + // cannot coalesce with unequal Hop limit values + return false + } + } else { + if pktA[1] != pktB[1] { + // cannot coalesce with unequal ToS values + return false + } + if pktA[6]>>5 != pktB[6]>>5 { + // cannot coalesce with unequal DF or reserved bits. MF is checked + // further up the stack. + return false + } + if pktA[8] != pktB[8] { + // cannot coalesce with unequal TTL values + return false + } + } + return true +} + +// udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. iphLen and gsoSize describe pkt. bufs is the vector of +// packets involved in the current GRO evaluation. bufsOffset is the offset at +// which packet data begins within bufs. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend +} + +// tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet +// described by item. This function makes considerations that match the kernel's +// GRO self tests, which can be found in tools/testing/selftests/net/gro.c. +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { + pktTarget := bufs[item.bufsIndex][bufsOffset:] + if tcphLen != item.tcphLen { + // cannot coalesce with unequal tcp options len + return coalesceUnavailable + } + if tcphLen > 20 { + if !bytes.Equal(pkt[iphLen+20:iphLen+tcphLen], pktTarget[item.iphLen+20:iphLen+tcphLen]) { + // cannot coalesce with unequal tcp options + return coalesceUnavailable + } + } + if !ipHeadersCanCoalesce(pkt, pktTarget) { + return coalesceUnavailable + } + // seq adjacency + lhsLen := item.gsoSize + lhsLen += item.numMerged * item.gsoSize + if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if item.pshSet { + // We cannot append to a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + // A smaller than gsoSize packet has been appended previously. + // Nothing can come after a smaller packet on the end. + return coalesceUnavailable + } + if gsoSize > item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + return coalesceAppend + } else if seq+uint32(gsoSize) == item.sentSeq { // pkt aligns in front of item from a seq num perspective + if pshSet { + // We cannot prepend with a segment that has the PSH flag set, PSH + // can only be set on the final segment in a reassembled group. + return coalesceUnavailable + } + if gsoSize < item.gsoSize { + // We cannot have a larger packet following a smaller one. + return coalesceUnavailable + } + if gsoSize > item.gsoSize && item.numMerged > 0 { + // There's at least one previous merge, and we're larger than all + // previous. This would put multiple smaller packets on the end. + return coalesceUnavailable + } + return coalescePrepend + } + return coalesceUnavailable +} + +func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { + srcAddrAt := ipv4SrcAddrOffset + addrSize := 4 + if isV6 { + srcAddrAt = ipv6SrcAddrOffset + addrSize = 16 + } + lenForPseudo := uint16(len(pkt) - int(iphLen)) + cSum := pseudoHeaderChecksumNoFold(proto, pkt[srcAddrAt:srcAddrAt+addrSize], pkt[srcAddrAt+addrSize:srcAddrAt+addrSize*2], lenForPseudo) + return ^checksum(pkt[iphLen:], cSum) == 0 +} + +// coalesceResult represents the result of attempting to coalesce two TCP +// packets. +type coalesceResult int + +const ( + coalesceInsufficientCap coalesceResult = iota + coalescePSHEnding + coalesceItemInvalidCSum + coalescePktInvalidCSum + coalesceSuccess +) + +// coalesceUDPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. +func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front + headersLen := item.iphLen + udphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { + return coalescePktInvalidCSum + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + + item.numMerged++ + return coalesceSuccess +} + +// coalesceTCPPackets attempts to coalesce pkt with the packet described by +// item, and returns the outcome. This function may swap bufs elements in the +// event of a prepend as item's bufs index is already being tracked for writing +// to a Device. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { + var pktHead []byte // the packet that will end up at the front + headersLen := item.iphLen + item.tcphLen + coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) + + // Copy data + if mode == coalescePrepend { + pktHead = pkt + if cap(pkt)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if pshSet { + return coalescePSHEnding + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + item.sentSeq = seq + extendBy := coalescedLen - len(pktHead) + bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) + copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) + // Flip the slice headers in bufs as part of prepend. The index of item + // is already being tracked for writing. + bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + } else { + pktHead = bufs[item.bufsIndex][bufsOffset:] + if cap(pktHead)-bufsOffset < coalescedLen { + // We don't want to allocate a new underlying array if capacity is + // too small. + return coalesceInsufficientCap + } + if item.numMerged == 0 { + if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalesceItemInvalidCSum + } + } + if !checksumValid(pkt, item.iphLen, unix.IPPROTO_TCP, isV6) { + return coalescePktInvalidCSum + } + if pshSet { + // We are appending a segment with PSH set. + item.pshSet = pshSet + pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + } + extendBy := len(pkt) - int(headersLen) + bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) + copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + } + + if gsoSize > item.gsoSize { + item.gsoSize = gsoSize + } + + item.numMerged++ + return coalesceSuccess +} + +const ( + ipv4FlagMoreFragments uint8 = 0x20 +) + +const ( + ipv4SrcAddrOffset = 12 + ipv6SrcAddrOffset = 8 + maxUint16 = 1<<16 - 1 +) + +type groResult int + +const ( + groResultNoop groResult = iota + groResultTableInsert + groResultCoalesced +) + +// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + tcphLen := int((pkt[iphLen+12] >> 4) * 4) + if tcphLen < 20 || tcphLen > 60 { + return groResultNoop + } + if len(pkt) < iphLen+tcphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + tcpFlags := pkt[iphLen+tcpFlagsOffset] + var pshSet bool + // not a candidate if any non-ACK flags (except PSH+ACK) are set + if tcpFlags != tcpFlagACK { + if pkt[iphLen+tcpFlagsOffset] != tcpFlagACK|tcpFlagPSH { + return groResultNoop + } + pshSet = true + } + gsoSize := uint16(len(pkt) - tcphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + seq := binary.BigEndian.Uint32(pkt[iphLen+4:]) + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + if !existing { + return groResultTableInsert + } + for i := len(items) - 1; i >= 0; i-- { + // In the best case of packets arriving in order iterating in reverse is + // more efficient if there are multiple items for a given flow. This + // also enables a natural table.deleteAt() in the + // coalesceItemInvalidCSum case without the need for index tracking. + // This algorithm makes a best effort to coalesce in the event of + // unordered packets, where pkt may land anywhere in items from a + // sequence number perspective, however once an item is inserted into + // the table it is never compared across other items later. + item := items[i] + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + if can != coalesceUnavailable { + result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, i) + return groResultCoalesced + case coalesceItemInvalidCSum: + // delete the item with an invalid csum + table.deleteAt(item.key, i) + case coalescePktInvalidCSum: + // no point in inserting an item that we can't coalesce + return groResultNoop + default: + } + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + return groResultTableInsert +} + +// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + item.tcphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 16, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + if item.key.isV6 { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Calculate the pseudo header checksum and place it at the TCP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the tcp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// metadata found in table. +func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { + for _, items := range table.itemsByFlow { + for _, item := range items { + if item.numMerged > 0 { + hdr := virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb + hdrLen: uint16(item.iphLen + udphLen), + gsoSize: item.gsoSize, + csumStart: uint16(item.iphLen), + csumOffset: 6, + } + pkt := bufs[item.bufsIndex][offset:] + + // Recalculate the total len (IPv4) or payload len (IPv6). + // Recalculate the (IPv4) header checksum. + hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 + if item.key.isV6 { + binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + } else { + pkt[10], pkt[11] = 0, 0 + binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length + iphCSum := ^checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + } + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + + // Recalculate the UDP len field value + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + + // Calculate the pseudo header checksum and place it at the UDP + // checksum offset. Downstream checksum offloading will combine + // this with computation of the udp header and payload checksum. + addrLen := 4 + addrOffset := ipv4SrcAddrOffset + if item.key.isV6 { + addrLen = 16 + addrOffset = ipv6SrcAddrOffset + } + srcAddrAt := offset + addrOffset + srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] + dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] + psum := pseudoHeaderChecksumNoFold(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], checksum([]byte{}, psum)) + } else { + hdr := virtioNetHdr{} + err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + } + } + } + return nil +} + +type groCandidateType uint8 + +const ( + notGROCandidate groCandidateType = iota + tcp4GROCandidate + tcp6GROCandidate + udp4GROCandidate + udp6GROCandidate +) + +func packetIsGROCandidate(b []byte, canUDPGRO bool) groCandidateType { + if len(b) < 28 { + return notGROCandidate + } + if b[0]>>4 == 4 { + if b[0]&0x0F != 5 { + // IPv4 packets w/IP options do not coalesce + return notGROCandidate + } + if b[9] == unix.IPPROTO_TCP && len(b) >= 40 { + return tcp4GROCandidate + } + if b[9] == unix.IPPROTO_UDP && canUDPGRO { + return udp4GROCandidate + } + } else if b[0]>>4 == 6 { + if b[6] == unix.IPPROTO_TCP && len(b) >= 60 { + return tcp6GROCandidate + } + if b[6] == unix.IPPROTO_UDP && len(b) >= 48 && canUDPGRO { + return udp6GROCandidate + } + } + return notGROCandidate +} + +const ( + udphLen = 8 +) + +// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// existing packets tracked in table. It returns a groResultNoop when no +// action was taken, groResultTableInsert when the evaluated packet was +// inserted into table, and groResultCoalesced when the evaluated packet was +// coalesced with another packet in table. +func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { + pkt := bufs[pktI][offset:] + if len(pkt) > maxUint16 { + // A valid IPv4 or IPv6 packet will never exceed this. + return groResultNoop + } + iphLen := int((pkt[0] & 0x0F) * 4) + if isV6 { + iphLen = 40 + ipv6HPayloadLen := int(binary.BigEndian.Uint16(pkt[4:])) + if ipv6HPayloadLen != len(pkt)-iphLen { + return groResultNoop + } + } else { + totalLen := int(binary.BigEndian.Uint16(pkt[2:])) + if totalLen != len(pkt) { + return groResultNoop + } + } + if len(pkt) < iphLen { + return groResultNoop + } + if len(pkt) < iphLen+udphLen { + return groResultNoop + } + if !isV6 { + if pkt[6]&ipv4FlagMoreFragments != 0 || pkt[6]<<3 != 0 || pkt[7] != 0 { + // no GRO support for fragmented segments for now + return groResultNoop + } + } + gsoSize := uint16(len(pkt) - udphLen - iphLen) + // not a candidate if payload len is 0 + if gsoSize < 1 { + return groResultNoop + } + srcAddrOffset := ipv4SrcAddrOffset + addrLen := 4 + if isV6 { + srcAddrOffset = ipv6SrcAddrOffset + addrLen = 16 + } + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + if !existing { + return groResultTableInsert + } + // With UDP we only check the last item, otherwise we could reorder packets + // for a given flow. We must also always insert a new item, or successfully + // coalesce with an existing item, for the same reason. + item := items[len(items)-1] + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + var pktCSumKnownInvalid bool + if can == coalesceAppend { + result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + switch result { + case coalesceSuccess: + table.updateAt(item, len(items)-1) + return groResultCoalesced + case coalesceItemInvalidCSum: + // If the existing item has an invalid csum we take no action. A new + // item will be stored after it, and the existing item will never be + // revisited as part of future coalescing candidacy checks. + case coalescePktInvalidCSum: + // We must insert a new item, but we also mark it as invalid csum + // to prevent a repeat checksum validation. + pktCSumKnownInvalid = true + default: + } + } + // failed to coalesce with any other packets; store the item in the flow + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + return groResultTableInsert +} + +// handleGRO evaluates bufs for GRO, and writes the indices of the resulting +// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// empty (but non-nil), and are passed in to save allocs as the caller may reset +// and recycle them across vectors of packets. canUDPGRO indicates if UDP GRO is +// supported. +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, canUDPGRO bool, toWrite *[]int) error { + for i := range bufs { + if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { + return errors.New("invalid offset") + } + var result groResult + switch packetIsGROCandidate(bufs[i][offset:], canUDPGRO) { + case tcp4GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, false) + case tcp6GROCandidate: + result = tcpGRO(bufs, offset, i, tcpTable, true) + case udp4GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, false) + case udp6GROCandidate: + result = udpGRO(bufs, offset, i, udpTable, true) + } + switch result { + case groResultNoop: + hdr := virtioNetHdr{} + err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) + if err != nil { + return err + } + fallthrough + case groResultTableInsert: + *toWrite = append(*toWrite, i) + } + } + errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) + errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + return errors.Join(errTCP, errUDP) +} + +// gsoSplit splits packets from in into outBuffs, writing the size of each +// element into sizes. It returns the number of buffers populated, and/or an +// error. +func gsoSplit(in []byte, hdr virtioNetHdr, outBuffs [][]byte, sizes []int, outOffset int, isV6 bool) (int, error) { + iphLen := int(hdr.csumStart) + srcAddrOffset := ipv6SrcAddrOffset + addrLen := 16 + if !isV6 { + in[10], in[11] = 0, 0 // clear ipv4 header checksum + srcAddrOffset = ipv4SrcAddrOffset + addrLen = 4 + } + transportCsumAt := int(hdr.csumStart + hdr.csumOffset) + in[transportCsumAt], in[transportCsumAt+1] = 0, 0 // clear tcp/udp checksum + var firstTCPSeqNum uint32 + var protocol uint8 + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV4 || hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_TCPV6 { + protocol = unix.IPPROTO_TCP + firstTCPSeqNum = binary.BigEndian.Uint32(in[hdr.csumStart+4:]) + } else { + protocol = unix.IPPROTO_UDP + } + nextSegmentDataAt := int(hdr.hdrLen) + i := 0 + for ; nextSegmentDataAt < len(in); i++ { + if i == len(outBuffs) { + return i - 1, ErrTooManySegments + } + nextSegmentEnd := nextSegmentDataAt + int(hdr.gsoSize) + if nextSegmentEnd > len(in) { + nextSegmentEnd = len(in) + } + segmentDataLen := nextSegmentEnd - nextSegmentDataAt + totalLen := int(hdr.hdrLen) + segmentDataLen + sizes[i] = totalLen + out := outBuffs[i][outOffset:] + + copy(out, in[:iphLen]) + if !isV6 { + // For IPv4 we are responsible for incrementing the ID field, + // updating the total len field, and recalculating the header + // checksum. + if i > 0 { + id := binary.BigEndian.Uint16(out[4:]) + id += uint16(i) + binary.BigEndian.PutUint16(out[4:], id) + } + binary.BigEndian.PutUint16(out[2:], uint16(totalLen)) + ipv4CSum := ^checksum(out[:iphLen], 0) + binary.BigEndian.PutUint16(out[10:], ipv4CSum) + } else { + // For IPv6 we are responsible for updating the payload length field. + binary.BigEndian.PutUint16(out[4:], uint16(totalLen-iphLen)) + } + + // copy transport header + copy(out[hdr.csumStart:hdr.hdrLen], in[hdr.csumStart:hdr.hdrLen]) + + if protocol == unix.IPPROTO_TCP { + // set TCP seq and adjust TCP flags + tcpSeq := firstTCPSeqNum + uint32(hdr.gsoSize*uint16(i)) + binary.BigEndian.PutUint32(out[hdr.csumStart+4:], tcpSeq) + if nextSegmentEnd != len(in) { + // FIN and PSH should only be set on last segment + clearFlags := tcpFlagFIN | tcpFlagPSH + out[hdr.csumStart+tcpFlagsOffset] &^= clearFlags + } + } else { + // set UDP header len + binary.BigEndian.PutUint16(out[hdr.csumStart+4:], uint16(segmentDataLen)+(hdr.hdrLen-hdr.csumStart)) + } + + // payload + copy(out[hdr.hdrLen:], in[nextSegmentDataAt:nextSegmentEnd]) + + // transport checksum + transportHeaderLen := int(hdr.hdrLen - hdr.csumStart) + lenForPseudo := uint16(transportHeaderLen + segmentDataLen) + transportCSumNoFold := pseudoHeaderChecksumNoFold(protocol, in[srcAddrOffset:srcAddrOffset+addrLen], in[srcAddrOffset+addrLen:srcAddrOffset+addrLen*2], lenForPseudo) + transportCSum := ^checksum(out[hdr.csumStart:totalLen], transportCSumNoFold) + binary.BigEndian.PutUint16(out[hdr.csumStart+hdr.csumOffset:], transportCSum) + + nextSegmentDataAt += int(hdr.gsoSize) + } + return i, nil +} + +func gsoNoneChecksum(in []byte, cSumStart, cSumOffset uint16) error { + cSumAt := cSumStart + cSumOffset + // The initial value at the checksum offset should be summed with the + // checksum we compute. This is typically the pseudo-header checksum. + initial := binary.BigEndian.Uint16(in[cSumAt:]) + in[cSumAt], in[cSumAt+1] = 0, 0 + binary.BigEndian.PutUint16(in[cSumAt:], ^checksum(in[cSumStart:], uint64(initial))) + return nil +} diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go new file mode 100644 index 0000000..ae55c8c --- /dev/null +++ b/tun/offload_linux_test.go @@ -0,0 +1,752 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package tun + +import ( + "net/netip" + "testing" + + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" +) + +const ( + offset = virtioNetHdrLen +) + +var ( + ip4PortA = netip.MustParseAddrPort("192.0.2.1:1") + ip4PortB = netip.MustParseAddrPort("192.0.2.2:1") + ip4PortC = netip.MustParseAddrPort("192.0.2.3:1") + ip6PortA = netip.MustParseAddrPort("[2001:db8::1]:1") + ip6PortB = netip.MustParseAddrPort("[2001:db8::2]:1") + ip6PortC = netip.MustParseAddrPort("[2001:db8::3]:1") +) + +func udp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 28 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_UDP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + udpH := header.UDP(b[offset+20:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp6Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp6PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func udp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, payloadLen int, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 48 + payloadLen + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_UDP, + HopLimit: 64, + PayloadLength: uint16(payloadLen + udphLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + udpH := header.UDP(b[offset+40:]) + udpH.Encode(&header.UDPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + Length: uint16(payloadLen + udphLen), + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_UDP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(udphLen+payloadLen)) + udpH.SetChecksum(^udpH.CalculateChecksum(pseudoCsum)) + return b +} + +func udp4Packet(srcIPPort, dstIPPort netip.AddrPort, payloadLen int) []byte { + return udp4PacketMutateIPFields(srcIPPort, dstIPPort, payloadLen, nil) +} + +func tcp4PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv4Fields)) []byte { + totalLen := 40 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv4H := header.IPv4(b[offset:]) + srcAs4 := srcIPPort.Addr().As4() + dstAs4 := dstIPPort.Addr().As4() + ipFields := &header.IPv4Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs4[:]), + DstAddr: tcpip.AddrFromSlice(dstAs4[:]), + Protocol: unix.IPPROTO_TCP, + TTL: 64, + TotalLength: uint16(totalLen), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv4H.Encode(ipFields) + tcpH := header.TCP(b[offset+20:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + ipv4H.SetChecksum(^ipv4H.CalculateChecksum()) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv4H.SourceAddress(), ipv4H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp4Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp4PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func tcp6PacketMutateIPFields(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32, ipFn func(*header.IPv6Fields)) []byte { + totalLen := 60 + segmentSize + b := make([]byte, offset+int(totalLen), 65535) + ipv6H := header.IPv6(b[offset:]) + srcAs16 := srcIPPort.Addr().As16() + dstAs16 := dstIPPort.Addr().As16() + ipFields := &header.IPv6Fields{ + SrcAddr: tcpip.AddrFromSlice(srcAs16[:]), + DstAddr: tcpip.AddrFromSlice(dstAs16[:]), + TransportProtocol: unix.IPPROTO_TCP, + HopLimit: 64, + PayloadLength: uint16(segmentSize + 20), + } + if ipFn != nil { + ipFn(ipFields) + } + ipv6H.Encode(ipFields) + tcpH := header.TCP(b[offset+40:]) + tcpH.Encode(&header.TCPFields{ + SrcPort: srcIPPort.Port(), + DstPort: dstIPPort.Port(), + SeqNum: seq, + AckNum: 1, + DataOffset: 20, + Flags: flags, + WindowSize: 3000, + }) + pseudoCsum := header.PseudoHeaderChecksum(unix.IPPROTO_TCP, ipv6H.SourceAddress(), ipv6H.DestinationAddress(), uint16(20+segmentSize)) + tcpH.SetChecksum(^tcpH.CalculateChecksum(pseudoCsum)) + return b +} + +func tcp6Packet(srcIPPort, dstIPPort netip.AddrPort, flags header.TCPFlags, segmentSize, seq uint32) []byte { + return tcp6PacketMutateIPFields(srcIPPort, dstIPPort, flags, segmentSize, seq, nil) +} + +func Test_handleVirtioRead(t *testing.T) { + tests := []struct { + name string + hdr virtioNetHdr + pktIn []byte + wantLens []int + wantErr bool + }{ + { + "tcp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV4, + gsoSize: 100, + hdrLen: 40, + csumStart: 20, + csumOffset: 16, + }, + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{140, 140}, + false, + }, + { + "tcp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_TCPV6, + gsoSize: 100, + hdrLen: 60, + csumStart: 40, + csumOffset: 16, + }, + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 200, 1), + []int{160, 160}, + false, + }, + { + "udp4", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 28, + csumStart: 20, + csumOffset: 6, + }, + udp4Packet(ip4PortA, ip4PortB, 200), + []int{128, 128}, + false, + }, + { + "udp6", + virtioNetHdr{ + flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, + gsoType: unix.VIRTIO_NET_HDR_GSO_UDP_L4, + gsoSize: 100, + hdrLen: 48, + csumStart: 40, + csumOffset: 6, + }, + udp6Packet(ip6PortA, ip6PortB, 200), + []int{148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := make([][]byte, conn.IdealBatchSize) + sizes := make([]int, conn.IdealBatchSize) + for i := range out { + out[i] = make([]byte, 65535) + } + tt.hdr.encode(tt.pktIn) + n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if n != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) + } + for i := range tt.wantLens { + if tt.wantLens[i] != sizes[i] { + t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + } + } + }) + } +} + +func flipTCP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 16 // 20 byte ipv4 header; tcp csum offset is 16 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func flipUDP4Checksum(b []byte) []byte { + at := virtioNetHdrLen + 20 + 6 // 20 byte ipv4 header; udp csum offset is 6 + b[at] ^= 0xFF + b[at+1] ^= 0xFF + return b +} + +func Fuzz_handleGRO(f *testing.F) { + pkt0 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1) + pkt1 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101) + pkt2 := tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201) + pkt3 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1) + pkt4 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101) + pkt5 := tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201) + pkt6 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt7 := udp4Packet(ip4PortA, ip4PortB, 100) + pkt8 := udp4Packet(ip4PortA, ip4PortC, 100) + pkt9 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt10 := udp6Packet(ip6PortA, ip6PortB, 100) + pkt11 := udp6Packet(ip6PortA, ip6PortC, 100) + f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, true, offset) + f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, canUDPGRO bool, offset int) { + pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} + toWrite := make([]int, 0, len(pkts)) + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), canUDPGRO, &toWrite) + if len(toWrite) > len(pkts) { + t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) + } + seenWriteI := make(map[int]bool) + for _, writeI := range toWrite { + if writeI < 0 || writeI > len(pkts)-1 { + t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) + } + if seenWriteI[writeI] { + t.Errorf("duplicate toWrite value: %d", writeI) + } + seenWriteI[writeI] = true + } + }) +} + +func Test_handleGRO(t *testing.T) { + tests := []struct { + name string + pktsIn [][]byte + canUDPGRO bool + wantToWrite []int + wantLens []int + wantErr bool + }{ + { + "multiple protocols and flows", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + true, + []int{0, 1, 2, 4, 5, 7, 9}, + []int{240, 228, 128, 140, 260, 160, 248}, + false, + }, + { + "multiple protocols and flows no UDP GRO", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // tcp4 flow 1 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp4Packet(ip4PortA, ip4PortC, 100), // udp4 flow 2 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // tcp4 flow 1 + tcp4Packet(ip4PortA, ip4PortC, header.TCPFlagAck, 100, 201), // tcp4 flow 2 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101), // tcp6 flow 1 + tcp6Packet(ip6PortA, ip6PortC, header.TCPFlagAck, 100, 201), // tcp6 flow 2 + udp4Packet(ip4PortA, ip4PortB, 100), // udp4 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 + }, + false, + []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, + []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + false, + }, + { + "PSH interleaved", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 301), // v4 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck|header.TCPFlagPsh, 100, 101), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 201), // v6 flow 1 + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 + }, + true, + []int{0, 2, 4, 6}, + []int{240, 240, 260, 260}, + false, + }, + { + "coalesceItemInvalidCSum", + [][]byte{ + flipTCP4Checksum(tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + flipUDP4Checksum(udp4Packet(ip4PortA, ip4PortB, 100)), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4Packet(ip4PortA, ip4PortB, 100), + }, + true, + []int{0, 1, 3, 4}, + []int{140, 240, 128, 228}, + false, + }, + { + "out of order", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101), // v4 flow 1 seq 101 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), // v4 flow 1 seq 1 len 100 + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 + }, + true, + []int{0}, + []int{340}, + false, + }, + { + "unequal TTL", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TTL++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal ToS", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.TOS++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags more fragments set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 1 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "unequal flags DF set", + [][]byte{ + tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1), + tcp4PacketMutateIPFields(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + udp4Packet(ip4PortA, ip4PortB, 100), + udp4PacketMutateIPFields(ip4PortA, ip4PortB, 100, func(fields *header.IPv4Fields) { + fields.Flags = 2 + }), + }, + true, + []int{0, 1, 2, 3}, + []int{140, 140, 128, 128}, + false, + }, + { + "ipv6 unequal hop limit", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.HopLimit++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + { + "ipv6 unequal traffic class", + [][]byte{ + tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1), + tcp6PacketMutateIPFields(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 101, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + udp6Packet(ip6PortA, ip6PortB, 100), + udp6PacketMutateIPFields(ip6PortA, ip6PortB, 100, func(fields *header.IPv6Fields) { + fields.TrafficClass++ + }), + }, + true, + []int{0, 1, 2, 3}, + []int{160, 160, 148, 148}, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + toWrite := make([]int, 0, len(tt.pktsIn)) + err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.canUDPGRO, &toWrite) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(toWrite) != len(tt.wantToWrite) { + t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) + } + for i, pktI := range tt.wantToWrite { + if tt.wantToWrite[i] != toWrite[i] { + t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + } + if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { + t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + } + } + }) + } +} + +func Test_packetIsGROCandidate(t *testing.T) { + tcp4 := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp4TooShort := tcp4[:39] + ip4InvalidHeaderLen := make([]byte, len(tcp4)) + copy(ip4InvalidHeaderLen, tcp4) + ip4InvalidHeaderLen[0] = 0x46 + ip4InvalidProtocol := make([]byte, len(tcp4)) + copy(ip4InvalidProtocol, tcp4) + ip4InvalidProtocol[9] = unix.IPPROTO_GRE + + tcp6 := tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 1)[virtioNetHdrLen:] + tcp6TooShort := tcp6[:59] + ip6InvalidProtocol := make([]byte, len(tcp6)) + copy(ip6InvalidProtocol, tcp6) + ip6InvalidProtocol[6] = unix.IPPROTO_GRE + + udp4 := udp4Packet(ip4PortA, ip4PortB, 100)[virtioNetHdrLen:] + udp4TooShort := udp4[:27] + + udp6 := udp6Packet(ip6PortA, ip6PortB, 100)[virtioNetHdrLen:] + udp6TooShort := udp6[:47] + + tests := []struct { + name string + b []byte + canUDPGRO bool + want groCandidateType + }{ + { + "tcp4", + tcp4, + true, + tcp4GROCandidate, + }, + { + "tcp6", + tcp6, + true, + tcp6GROCandidate, + }, + { + "udp4", + udp4, + true, + udp4GROCandidate, + }, + { + "udp4 no support", + udp4, + false, + notGROCandidate, + }, + { + "udp6", + udp6, + true, + udp6GROCandidate, + }, + { + "udp6 no support", + udp6, + false, + notGROCandidate, + }, + { + "udp4 too short", + udp4TooShort, + true, + notGROCandidate, + }, + { + "udp6 too short", + udp6TooShort, + true, + notGROCandidate, + }, + { + "tcp4 too short", + tcp4TooShort, + true, + notGROCandidate, + }, + { + "tcp6 too short", + tcp6TooShort, + true, + notGROCandidate, + }, + { + "invalid IP version", + []byte{0x00}, + true, + notGROCandidate, + }, + { + "invalid IP header len", + ip4InvalidHeaderLen, + true, + notGROCandidate, + }, + { + "ip4 invalid protocol", + ip4InvalidProtocol, + true, + notGROCandidate, + }, + { + "ip6 invalid protocol", + ip6InvalidProtocol, + true, + notGROCandidate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := packetIsGROCandidate(tt.b, tt.canUDPGRO); got != tt.want { + t.Errorf("packetIsGROCandidate() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_udpPacketsCanCoalesce(t *testing.T) { + udp4a := udp4Packet(ip4PortA, ip4PortB, 100) + udp4b := udp4Packet(ip4PortA, ip4PortB, 100) + udp4c := udp4Packet(ip4PortA, ip4PortB, 110) + + type args struct { + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + bufs [][]byte + bufsOffset int + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceAppend equal gso", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceAppend smaller gso", + args{ + pkt: udp4a[offset : len(udp4a)-90], + iphLen: 20, + gsoSize: 10, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4b, + }, + bufsOffset: offset, + }, + coalesceAppend, + }, + { + "coalesceUnavailable smaller gso previously appended", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4c, + udp4b, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable larger following smaller", + args{ + pkt: udp4c[offset:], + iphLen: 20, + gsoSize: 110, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + }, + bufs: [][]byte{ + udp4a, + udp4c, + }, + bufsOffset: offset, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/operateonfd.go b/tun/operateonfd.go index 31747a2..f1beb6d 100644 --- a/tun/operateonfd.go +++ b/tun/operateonfd.go @@ -1,8 +1,8 @@ -// +build !windows +//go:build darwin || freebsd /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -18,12 +18,36 @@ const ( ) type Device interface { - File() *os.File // returns the file descriptor of the device - Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) - Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) - Flush() error // flush all previous writes to the device - MTU() (int, error) // returns the MTU of the device - Name() (string, error) // fetches and returns the current name - Events() chan Event // returns a constant channel of events related to the device - Close() error // stops the device and closes the event channel + // File returns the file descriptor of the device. + File() *os.File + + // Read one or more packets from the Device (without any additional headers). + // On a successful read it returns the number of packets read, and sets + // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // A nonzero offset can be used to instruct the Device on where to begin + // reading into each element of the bufs slice. + Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + + // Write one or more packets to the device (without any additional headers). + // On a successful write it returns the number of packets written. A nonzero + // offset can be used to instruct the Device on where to begin writing from + // each packet contained within the bufs slice. + Write(bufs [][]byte, offset int) (int, error) + + // MTU returns the MTU of the Device. + MTU() (int, error) + + // Name returns the current name of the Device. + Name() (string, error) + + // Events returns a channel of type Event, which is fed Device events. + Events() <-chan Event + + // Close stops the Device and closes the Event channel. + Close() error + + // BatchSize returns the preferred/max number of packets that can be read or + // written in a single read/write call. BatchSize must not change over the + // lifetime of a Device. + BatchSize() int } diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index f19a7df..c9a6c0b 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -1,58 +1,41 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( + "errors" "fmt" - "io/ioutil" + "io" "net" "os" + "sync" "syscall" "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) const utunControlName = "com.apple.net.utun_control" -// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h -const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3 - -// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h -type sockaddrCtl struct { - scLen uint8 - scFamily uint8 - ssSysaddr uint16 - scID uint32 - scUnit uint32 - scReserved [5]uint32 -} - type NativeTun struct { name string tunFile *os.File events chan Event errors chan error routeSocket int + closeOnce sync.Once } -var sockaddrCtlSize uintptr = 32 - func retryInterfaceByIndex(index int) (iface *net.Interface, err error) { for i := 0; i < 20; i++ { iface, err = net.InterfaceByIndex(index) - if err != nil { - if opErr, ok := err.(*net.OpError); ok { - if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM { - time.Sleep(time.Duration(i) * time.Second / 3) - continue - } - } + if err != nil && errors.Is(err, unix.ENOMEM) { + time.Sleep(time.Duration(i) * time.Second / 3) + continue } return iface, err } @@ -72,7 +55,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { + if errno, ok := err.(unix.Errno); ok && errno == unix.EINTR { goto retry } tun.errors <- err @@ -124,53 +107,33 @@ func CreateTUN(name string, mtu int) (Device, error) { } } - fd, err := unix.Socket(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) - + fd, err := socketCloexec(unix.AF_SYSTEM, unix.SOCK_DGRAM, 2) if err != nil { return nil, err } - var ctlInfo = &struct { - ctlID uint32 - ctlName [96]byte - }{} - - copy(ctlInfo.ctlName[:], []byte(utunControlName)) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(_CTLIOCGINFO), - uintptr(unsafe.Pointer(ctlInfo)), - ) - - if errno != 0 { - return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno) + ctlInfo := &unix.CtlInfo{} + copy(ctlInfo.Name[:], []byte(utunControlName)) + err = unix.IoctlCtlInfo(fd, ctlInfo) + if err != nil { + unix.Close(fd) + return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err) } - sc := sockaddrCtl{ - scLen: uint8(sockaddrCtlSize), - scFamily: unix.AF_SYSTEM, - ssSysaddr: 2, - scID: ctlInfo.ctlID, - scUnit: uint32(ifIndex) + 1, + sc := &unix.SockaddrCtl{ + ID: ctlInfo.Id, + Unit: uint32(ifIndex) + 1, } - scPointer := unsafe.Pointer(&sc) - - _, _, errno = unix.RawSyscall( - unix.SYS_CONNECT, - uintptr(fd), - uintptr(scPointer), - uintptr(sockaddrCtlSize), - ) - - if errno != 0 { - return nil, fmt.Errorf("SYS_CONNECT: %v", errno) + err = unix.Connect(fd, sc) + if err != nil { + unix.Close(fd) + return nil, err } - err = syscall.SetNonblock(fd, true) + err = unix.SetNonblock(fd, true) if err != nil { + unix.Close(fd) return nil, err } tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu) @@ -178,7 +141,7 @@ func CreateTUN(name string, mtu int) (Device, error) { if err == nil && name == "utun" { fname := os.Getenv("WG_TUN_NAME_FILE") if fname != "" { - ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) + os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) } } @@ -210,7 +173,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + tun.routeSocket, err = socketCloexec(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err @@ -230,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { } func (tun *NativeTun) Name() (string, error) { - var ifName struct { - name [16]byte - } - ifNameSize := uintptr(16) - - var errno syscall.Errno + var err error tun.operateOnFd(func(fd uintptr) { - _, _, errno = unix.Syscall6( - unix.SYS_GETSOCKOPT, - fd, + tun.name, err = unix.GetsockoptString( + int(fd), 2, /* #define SYSPROTO_CONTROL 2 */ 2, /* #define UTUN_OPT_IFNAME 2 */ - uintptr(unsafe.Pointer(&ifName)), - uintptr(unsafe.Pointer(&ifNameSize)), 0) + ) }) - if errno != 0 { - return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno) + if err != nil { + return "", fmt.Errorf("GetSockoptString: %w", err) } - tun.name = string(ifName.name[:ifNameSize-1]) return tun.name, nil } @@ -258,61 +213,63 @@ func (tun *NativeTun) File() *os.File { return tun.tunFile } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + // TODO: the BSDs look very similar in Read() and Write(). They should be + // collapsed, with platform-specific files containing the varying parts of + // their implementations. select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) + buf := bufs[0][offset-4:] + n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range bufs { + buf = buf[offset-4:] + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(bufs), nil } func (tun *NativeTun) Close() error { - var err2 error - err1 := tun.tunFile.Close() - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err2 = unix.Close(tun.routeSocket) - } else if tun.events != nil { - close(tun.events) - } + var err1, err2 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err2 = unix.Close(tun.routeSocket) + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } @@ -320,71 +277,60 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) setMTU(n int) error { - - // open datagram socket - - var fd int - - fd, err := unix.Socket( + fd, err := socketCloexec( unix.AF_INET, unix.SOCK_DGRAM, 0, ) - if err != nil { return err } defer unix.Close(fd) - // do ioctl call - - var ifr [32]byte - copy(ifr[:], tun.name) - *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - - if errno != 0 { - return fmt.Errorf("failed to set MTU on %s", tun.name) + var ifr unix.IfreqMTU + copy(ifr.Name[:], tun.name) + ifr.MTU = int32(n) + err = unix.IoctlSetIfreqMTU(fd, &ifr) + if err != nil { + return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err) } return nil } func (tun *NativeTun) MTU() (int, error) { - - // open datagram socket - - fd, err := unix.Socket( + fd, err := socketCloexec( unix.AF_INET, unix.SOCK_DGRAM, 0, ) - if err != nil { return 0, err } defer unix.Close(fd) - // do ioctl call - - var ifr [64]byte - copy(ifr[:], tun.name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return 0, fmt.Errorf("failed to get MTU on %s", tun.name) + ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name) + if err != nil { + return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err) } - return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil + return int(ifr.MTU), nil +} + +func (tun *NativeTun) BatchSize() int { + return 1 +} + +func socketCloexec(family, sotype, proto int) (fd int, err error) { + // See go/src/net/sys_cloexec.go for background. + syscall.ForkLock.RLock() + defer syscall.ForkLock.RUnlock() + + fd, err = unix.Socket(family, sotype, proto) + if err == nil { + unix.CloseOnExec(fd) + } + return } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 6cf9313..7c65fd9 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -1,66 +1,57 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( - "bytes" "errors" "fmt" + "io" "net" "os" + "sync" "syscall" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) -// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h -// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96)) const ( _TUNSIFHEAD = 0x80047460 _TUNSIFMODE = 0x8004745e + _TUNGIFNAME = 0x4020745d _TUNSIFPID = 0x2000745f -) -// TODO: move into x/sys/unix -const ( - SIOCGIFINFO_IN6 = 0xc048696c - SIOCSIFINFO_IN6 = 0xc048696d - ND6_IFF_AUTO_LINKLOCAL = 0x20 - ND6_IFF_NO_DAD = 0x100 + _SIOCGIFINFO_IN6 = 0xc048696c + _SIOCSIFINFO_IN6 = 0xc048696d + _ND6_IFF_AUTO_LINKLOCAL = 0x20 + _ND6_IFF_NO_DAD = 0x100 ) -// Iface status string max len -const _IFSTATMAX = 800 - -const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1) +// Iface requests with just the name +type ifreqName struct { + Name [unix.IFNAMSIZ]byte + _ [16]byte +} -// structure for iface requests with a pointer -type ifreq_ptr struct { +// Iface requests with a pointer +type ifreqPtr struct { Name [unix.IFNAMSIZ]byte Data uintptr - Pad0 [16 - SIZEOF_UINTPTR]byte + _ [16 - unsafe.Sizeof(uintptr(0))]byte } -// Structure for iface mtu get/set ioctls -type ifreq_mtu struct { +// Iface requests with MTU +type ifreqMtu struct { Name [unix.IFNAMSIZ]byte MTU uint32 - Pad0 [12]byte -} - -// Structure for interface status request ioctl -type ifstat struct { - IfsName [unix.IFNAMSIZ]byte - Ascii [_IFSTATMAX]byte + _ [12]byte } -// Structures for nd6 flag manipulation -type in6_ndireq struct { +// ND6 flag manipulation +type nd6Req struct { Name [unix.IFNAMSIZ]byte Linkmtu uint32 Maxmtu uint32 @@ -82,6 +73,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -97,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { retry: n, err := unix.Read(tun.routeSocket, data) if err != nil { - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { + if errors.Is(err, syscall.EINTR) { goto retry } tun.errors <- err @@ -141,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } func tunName(fd uintptr) (string, error) { - //Terrible hack to make up for freebsd not having a TUNGIFNAME - - //First, make sure the tun pid matches this proc's pid - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(_TUNSIFPID), - uintptr(0), - ) - - if errno != 0 { - return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error()) - } - - // Open iface control socket - - confd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - - if err != nil { + var ifreq ifreqName + _, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq))) + if err != 0 { return "", err } - - defer unix.Close(confd) - - procPid := os.Getpid() - - //Try to find interface with matching PID - for i := 1; ; i++ { - iface, _ := net.InterfaceByIndex(i) - if err != nil || iface == nil { - break - } - - // Structs for getting data in and out of SIOCGIFSTATUS ioctl - var ifstatus ifstat - copy(ifstatus.IfsName[:], iface.Name) - - // Make the syscall to get the status string - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd), - uintptr(unix.SIOCGIFSTATUS), - uintptr(unsafe.Pointer(&ifstatus)), - ) - - if errno != 0 { - continue - } - - nullStr := ifstatus.Ascii[:] - i := bytes.IndexByte(nullStr, 0) - if i < 1 { - continue - } - statStr := string(nullStr[:i]) - var pidNum int = 0 - - // Finally get the owning PID - // Format string taken from sys/net/if_tun.c - _, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum) - if err != nil { - continue - } - - if pidNum == procPid { - return iface.Name, nil - } - } - - return "", nil + return unix.ByteSliceToString(ifreq.Name[:]), nil } // Destroy a named system interface func tunDestroy(name string) error { - // Open control socket. - var fd int - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return err } @@ -233,14 +151,9 @@ func tunDestroy(name string) error { var ifr [32]byte copy(ifr[:], name) - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCIFDESTROY), - uintptr(unsafe.Pointer(&ifr[0])), - ) + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0]))) if errno != 0 { - return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) + return fmt.Errorf("failed to destroy interface %s: %w", name, errno) } return nil @@ -257,7 +170,7 @@ func CreateTUN(name string, mtu int) (Device, error) { return nil, fmt.Errorf("interface %s already exists", name) } - tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0) + tunFile, err := os.OpenFile("/dev/tun", unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { return nil, err } @@ -276,103 +189,94 @@ func CreateTUN(name string, mtu int) (Device, error) { ifheadmode := 1 var errno syscall.Errno tun.operateOnFd(func(fd uintptr) { - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - fd, - uintptr(_TUNSIFHEAD), - uintptr(unsafe.Pointer(&ifheadmode)), - ) + _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode))) }) if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno) + return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno) } - // Open control sockets - confd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - if err != nil { + // Get out of PTP mode. + ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST + tun.operateOnFd(func(fd uintptr) { + _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags))) + }) + + if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, err + return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno) } - defer unix.Close(confd) - confd6, err := unix.Socket( - unix.AF_INET6, - unix.SOCK_DGRAM, - 0, - ) + + // Disable link-local v6, not just because WireGuard doesn't do that anyway, but + // also because there are serious races with attaching and detaching LLv6 addresses + // in relation to interface lifetime within the FreeBSD kernel. + confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { tunFile.Close() tunDestroy(assignedName) return nil, err } defer unix.Close(confd6) - - // Disable link-local v6, not just because WireGuard doesn't do that anyway, but - // also because there are serious races with attaching and detaching LLv6 addresses - // in relation to interface lifetime within the FreeBSD kernel. - var ndireq in6_ndireq + var ndireq nd6Req copy(ndireq.Name[:], assignedName) - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd6), - uintptr(SIOCGIFINFO_IN6), - uintptr(unsafe.Pointer(&ndireq)), - ) + _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno) + return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno) } - ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL - ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd6), - uintptr(SIOCSIFINFO_IN6), - uintptr(unsafe.Pointer(&ndireq)), - ) + ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL + ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD + _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq))) if errno != 0 { tunFile.Close() tunDestroy(assignedName) - return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno) + return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno) } - // Rename the interface - var newnp [unix.IFNAMSIZ]byte - copy(newnp[:], name) - var ifr ifreq_ptr - copy(ifr.Name[:], assignedName) - ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) - _, _, errno = unix.Syscall( - unix.SYS_IOCTL, - uintptr(confd), - uintptr(unix.SIOCSIFNAME), - uintptr(unsafe.Pointer(&ifr)), - ) - if errno != 0 { - tunFile.Close() - tunDestroy(assignedName) - return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno) + if name != "" { + confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) + if err != nil { + tunFile.Close() + tunDestroy(assignedName) + return nil, err + } + defer unix.Close(confd) + var newnp [unix.IFNAMSIZ]byte + copy(newnp[:], name) + var ifr ifreqPtr + copy(ifr.Name[:], assignedName) + ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) + _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr))) + if errno != 0 { + tunFile.Close() + tunDestroy(assignedName) + return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno) + } } return CreateTUNFromFile(tunFile, mtu) } func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { - tun := &NativeTun{ tunFile: file, events: make(chan Event, 10), errors: make(chan error, 1), } + var errno syscall.Errno + tun.operateOnFd(func(fd uintptr) { + _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0)) + }) + if errno != 0 { + tun.tunFile.Close() + return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno) + } + name, err := tun.Name() if err != nil { tun.tunFile.Close() @@ -391,7 +295,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err @@ -425,63 +329,65 @@ func (tun *NativeTun) File() *os.File { return tun.tunFile } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) + buf := bufs[0][offset-4:] + n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range bufs { + buf = buf[offset-4:] + if len(buf) < 5 { + return i, io.ErrShortBuffer + } + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(bufs), nil } func (tun *NativeTun) Close() error { - var err3 error - err1 := tun.tunFile.Close() - err2 := tunDestroy(tun.name) - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err3 = unix.Close(tun.routeSocket) - tun.routeSocket = -1 - } else if tun.events != nil { - close(tun.events) - } + var err1, err2, err3 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + err2 = tunDestroy(tun.name) + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err3 = unix.Close(tun.routeSocket) + tun.routeSocket = -1 + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } @@ -492,70 +398,38 @@ func (tun *NativeTun) Close() error { } func (tun *NativeTun) setMTU(n int) error { - // open datagram socket - - var fd int - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return err } - defer unix.Close(fd) - // do ioctl call - - var ifr ifreq_mtu + var ifr ifreqMtu copy(ifr.Name[:], tun.name) ifr.MTU = uint32(n) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCSIFMTU), - uintptr(unsafe.Pointer(&ifr)), - ) - + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { - return fmt.Errorf("failed to set MTU on %s", tun.name) + return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno) } - return nil } func (tun *NativeTun) MTU() (int, error) { - // open datagram socket - - fd, err := unix.Socket( - unix.AF_INET, - unix.SOCK_DGRAM, - 0, - ) - + fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0) if err != nil { return 0, err } - defer unix.Close(fd) - // do ioctl call - var ifr ifreq_mtu + var ifr ifreqMtu copy(ifr.Name[:], tun.name) - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(fd), - uintptr(unix.SIOCGIFMTU), - uintptr(unsafe.Pointer(&ifr)), - ) + _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { - return 0, fmt.Errorf("failed to get MTU on %s", tun.name) + return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno) } - return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index cc90c45..bd69cb5 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -9,7 +9,6 @@ package tun */ import ( - "bytes" "errors" "fmt" "os" @@ -18,8 +17,8 @@ import ( "time" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/rwcancel" ) @@ -33,15 +32,27 @@ type NativeTun struct { index int32 // if index errors chan error // async error handling events chan Event // device related events - nopi bool // the device was passed IFF_NO_PI netlinkSock int netlinkCancel *rwcancel.RWCancel hackListenerClosed sync.Mutex statusListenersShutdown chan struct{} + batchSize int + vnetHdr bool + udpGSO bool + + closeOnce sync.Once nameOnce sync.Once // guards calling initNameCache, which sets following fields nameCache string // name of interface nameErr error + + readOpMu sync.Mutex // readOpMu guards readBuff + readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr + + writeOpMu sync.Mutex // writeOpMu guards toWrite, tcpGROTable + toWrite []int + tcpGROTable *tcpGROTable + udpGROTable *udpGROTable } func (tun *NativeTun) File() *os.File { @@ -53,6 +64,11 @@ func (tun *NativeTun) routineHackListener() { /* This is needed for the detection to work across network namespaces * If you are reading this and know a better method, please get in touch. */ + last := 0 + const ( + up = 1 + down = 2 + ) for { sysconn, err := tun.tunFile.SyscallConn() if err != nil { @@ -66,13 +82,19 @@ func (tun *NativeTun) routineHackListener() { } switch err { case unix.EINVAL: - // If the tunnel is up, it reports that write() is - // allowed but we provided invalid data. - tun.events <- EventUp + if last != up { + // If the tunnel is up, it reports that write() is + // allowed but we provided invalid data. + tun.events <- EventUp + last = up + } case unix.EIO: - // If the tunnel is down, it reports that no I/O - // is possible, without checking our provided data. - tun.events <- EventDown + if last != down { + // If the tunnel is down, it reports that no I/O + // is possible, without checking our provided data. + tun.events <- EventDown + last = down + } default: return } @@ -86,7 +108,7 @@ func (tun *NativeTun) routineHackListener() { } func createNetlinkSocket() (int, error) { - sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE) + sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE) if err != nil { return -1, err } @@ -106,10 +128,10 @@ func (tun *NativeTun) routineNetlinkListener() { unix.Close(tun.netlinkSock) tun.hackListenerClosed.Lock() close(tun.events) + tun.netlinkCancel.Close() }() for msg := make([]byte, 1<<16); ; { - var err error var msgn int for { @@ -118,12 +140,12 @@ func (tun *NativeTun) routineNetlinkListener() { break } if !tun.netlinkCancel.ReadyRead() { - tun.errors <- fmt.Errorf("netlink socket closed: %s", err.Error()) + tun.errors <- fmt.Errorf("netlink socket closed: %w", err) return } } if err != nil { - tun.errors <- fmt.Errorf("failed to receive netlink message: %s", err.Error()) + tun.errors <- fmt.Errorf("failed to receive netlink message: %w", err) return } @@ -181,7 +203,7 @@ func (tun *NativeTun) routineNetlinkListener() { func getIFIndex(name string) (int32, error) { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) if err != nil { @@ -215,10 +237,9 @@ func (tun *NativeTun) setMTU(n int) error { // open datagram socket fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return err } @@ -237,7 +258,7 @@ func (tun *NativeTun) setMTU(n int) error { ) if errno != 0 { - return errors.New("failed to set MTU of TUN device") + return fmt.Errorf("failed to set MTU of TUN device: %w", errno) } return nil @@ -252,10 +273,9 @@ func (tun *NativeTun) MTU() (int, error) { // open datagram socket fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return 0, err } @@ -273,7 +293,7 @@ func (tun *NativeTun) MTU() (int, error) { uintptr(unsafe.Pointer(&ifr[0])), ) if errno != 0 { - return 0, errors.New("failed to get MTU of TUN device: " + errno.Error()) + return 0, fmt.Errorf("failed to get MTU of TUN device: %w", errno) } return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil @@ -304,137 +324,273 @@ func (tun *NativeTun) nameSlow() (string, error) { ) }) if err != nil { - return "", errors.New("failed to get name of TUN device: " + err.Error()) + return "", fmt.Errorf("failed to get name of TUN device: %w", err) } if errno != 0 { - return "", errors.New("failed to get name of TUN device: " + errno.Error()) + return "", fmt.Errorf("failed to get name of TUN device: %w", errno) } - name := ifr[:] - if i := bytes.IndexByte(name, 0); i != -1 { - name = name[:i] - } - return string(name), nil + return unix.ByteSliceToString(ifr[:]), nil } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - if tun.nopi { - buff = buff[offset:] +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + tun.writeOpMu.Lock() + defer func() { + tun.tcpGROTable.reset() + tun.udpGROTable.reset() + tun.writeOpMu.Unlock() + }() + var ( + errs error + total int + ) + tun.toWrite = tun.toWrite[:0] + if tun.vnetHdr { + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.udpGSO, &tun.toWrite) + if err != nil { + return 0, err + } + offset -= virtioNetHdrLen } else { - // reserve space for header + for i := range bufs { + tun.toWrite = append(tun.toWrite, i) + } + } + for _, bufsI := range tun.toWrite { + n, err := tun.tunFile.Write(bufs[bufsI][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = errors.Join(errs, err) + } else { + total += n + } + } + return total, errs +} - buff = buff[offset-4:] +// handleVirtioRead splits in into bufs, leaving offset bytes at the front of +// each buffer. It mutates sizes to reflect the size of each element of bufs, +// and returns the number of packets read. +func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { + var hdr virtioNetHdr + err := hdr.decode(in) + if err != nil { + return 0, err + } + in = in[virtioNetHdrLen:] + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE { + if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 { + // This means CHECKSUM_PARTIAL in skb context. We are responsible + // for computing the checksum starting at hdr.csumStart and placing + // at hdr.csumOffset. + err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset) + if err != nil { + return 0, err + } + } + if len(in) > len(bufs[0][offset:]) { + return 0, fmt.Errorf("read len %d overflows bufs element len %d", len(in), len(bufs[0][offset:])) + } + n := copy(bufs[0][offset:], in) + sizes[0] = n + return 1, nil + } + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType) + } - // add packet information header + ipVersion := in[0] >> 4 + switch ipVersion { + case 4: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + case 6: + if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType) + } + default: + return 0, fmt.Errorf("invalid ip header version: %d", ipVersion) + } - buff[0] = 0x00 - buff[1] = 0x00 + // Don't trust hdr.hdrLen from the kernel as it can be equal to the length + // of the entire first packet when the kernel is handling it as part of a + // FORWARD path. Instead, parse the transport header length and add it onto + // csumStart, which is synonymous for IP header length. + if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_UDP_L4 { + hdr.hdrLen = hdr.csumStart + 8 + } else { + if len(in) <= int(hdr.csumStart+12) { + return 0, errors.New("packet is too short") + } - if buff[4]>>4 == ipv6.Version { - buff[2] = 0x86 - buff[3] = 0xdd - } else { - buff[2] = 0x08 - buff[3] = 0x00 + tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4) + if tcpHLen < 20 || tcpHLen > 60 { + // A TCP header must be between 20 and 60 bytes in length. + return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen) } + hdr.hdrLen = hdr.csumStart + tcpHLen } - // write + if len(in) < int(hdr.hdrLen) { + return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen) + } - return tun.tunFile.Write(buff) -} + if hdr.hdrLen < hdr.csumStart { + return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart) + } + cSumAt := int(hdr.csumStart + hdr.csumOffset) + if cSumAt+1 >= len(in) { + return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) + } -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + return gsoSplit(in, hdr, bufs, sizes, offset, ipVersion == 6) } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + tun.readOpMu.Lock() + defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: - if tun.nopi { - return tun.tunFile.Read(buff[offset:]) + readInto := bufs[0][offset:] + if tun.vnetHdr { + readInto = tun.readBuff[:] + } + n, err := tun.tunFile.Read(readInto) + if errors.Is(err, syscall.EBADFD) { + err = os.ErrClosed + } + if err != nil { + return 0, err + } + if tun.vnetHdr { + return handleVirtioRead(readInto[:n], bufs, sizes, offset) } else { - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) - if n < 4 { - return 0, err - } - return n - 4, err + sizes[0] = n + return 1, nil } } } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { - var err1 error - if tun.statusListenersShutdown != nil { - close(tun.statusListenersShutdown) - if tun.netlinkCancel != nil { - err1 = tun.netlinkCancel.Cancel() + var err1, err2 error + tun.closeOnce.Do(func() { + if tun.statusListenersShutdown != nil { + close(tun.statusListenersShutdown) + if tun.netlinkCancel != nil { + err1 = tun.netlinkCancel.Cancel() + } + } else if tun.events != nil { + close(tun.events) } - } else if tun.events != nil { - close(tun.events) - } - err2 := tun.tunFile.Close() - + err2 = tun.tunFile.Close() + }) if err1 != nil { return err1 } return err2 } +func (tun *NativeTun) BatchSize() int { + return tun.batchSize +} + +const ( + // TODO: support TSO with ECN bits + tunTCPOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6 + tunUDPOffloads = unix.TUN_F_USO4 | unix.TUN_F_USO6 +) + +func (tun *NativeTun) initFromFlags(name string) error { + sc, err := tun.tunFile.SyscallConn() + if err != nil { + return err + } + if e := sc.Control(func(fd uintptr) { + var ( + ifr *unix.Ifreq + ) + ifr, err = unix.NewIfreq(name) + if err != nil { + return + } + err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr) + if err != nil { + return + } + got := ifr.Uint16() + if got&unix.IFF_VNET_HDR != 0 { + // tunTCPOffloads were added in Linux v2.6. We require their support + // if IFF_VNET_HDR is set. + err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads) + if err != nil { + return + } + tun.vnetHdr = true + tun.batchSize = conn.IdealBatchSize + // tunUDPOffloads were added in Linux v6.2. We do not return an + // error if they are unsupported at runtime. + tun.udpGSO = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunTCPOffloads|tunUDPOffloads) == nil + } else { + tun.batchSize = 1 + } + }); e != nil { + return e + } + return err +} + +// CreateTUN creates a Device with the provided name and MTU. func CreateTUN(name string, mtu int) (Device, error) { - nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0) + nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0) if err != nil { if os.IsNotExist(err) { - return nil, fmt.Errorf("CreateTUN(%q) failed; tuntap %s does not exist", name, cloneDevicePath) + return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath) } return nil, err } - var ifr [ifReqSize]byte - var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack) - nameBytes := []byte(name) - if len(nameBytes) >= unix.IFNAMSIZ { - return nil, errors.New("interface name too long") + ifr, err := unix.NewIfreq(name) + if err != nil { + return nil, err } - copy(ifr[:], nameBytes) - *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags - - _, _, errno := unix.Syscall( - unix.SYS_IOCTL, - uintptr(nfd), - uintptr(unix.TUNSETIFF), - uintptr(unsafe.Pointer(&ifr[0])), - ) - if errno != 0 { - return nil, errno + // IFF_VNET_HDR enables the "tun status hack" via routineHackListener() + // where a null write will return EINVAL indicating the TUN is up. + ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR) + err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr) + if err != nil { + return nil, err } - err = unix.SetNonblock(nfd, true) - - // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. - fd := os.NewFile(uintptr(nfd), cloneDevicePath) + err = unix.SetNonblock(nfd, true) if err != nil { + unix.Close(nfd) return nil, err } + // Note that the above -- open,ioctl,nonblock -- must happen prior to handing it to netpoll as below this line. + + fd := os.NewFile(uintptr(nfd), cloneDevicePath) return CreateTUNFromFile(fd, mtu) } +// CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), errors: make(chan error, 5), statusListenersShutdown: make(chan struct{}), - nopi: false, + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() @@ -442,8 +598,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - // start event listener + err = tun.initFromFlags(name) + if err != nil { + return nil, err + } + // start event listener tun.index, err = getIFIndex(name) if err != nil { return nil, err @@ -472,6 +632,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return tun, nil } +// CreateUnmonitoredTUNFromFD creates a Device from the provided file +// descriptor. func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { err := unix.SetNonblock(fd, true) if err != nil { @@ -479,14 +641,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { } file := os.NewFile(uintptr(fd), "/dev/tun") tun := &NativeTun{ - tunFile: file, - events: make(chan Event, 5), - errors: make(chan error, 5), - nopi: true, + tunFile: file, + events: make(chan Event, 5), + errors: make(chan error, 5), + tcpGROTable: newTCPGROTable(), + udpGROTable: newUDPGROTable(), + toWrite: make([]int, 0, conn.IdealBatchSize), } name, err := tun.Name() if err != nil { return nil, "", err } - return tun, name, nil + err = tun.initFromFlags(name) + if err != nil { + return nil, "", err + } + return tun, name, err } diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index 44cedaa..ae571b9 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -1,19 +1,20 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun import ( + "errors" "fmt" - "io/ioutil" + "io" "net" "os" + "sync" "syscall" "unsafe" - "golang.org/x/net/ipv6" "golang.org/x/sys/unix" ) @@ -32,6 +33,7 @@ type NativeTun struct { events chan Event errors chan error routeSocket int + closeOnce sync.Once } func (tun *NativeTun) routineRouteListener(tunIfindex int) { @@ -99,16 +101,6 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) { } } -func errorIsEBUSY(err error) bool { - if pe, ok := err.(*os.PathError); ok { - err = pe.Err - } - if errno, ok := err.(syscall.Errno); ok && errno == syscall.EBUSY { - return true - } - return false -} - func CreateTUN(name string, mtu int) (Device, error) { ifIndex := -1 if name != "tun" { @@ -122,11 +114,11 @@ func CreateTUN(name string, mtu int) (Device, error) { var err error if ifIndex != -1 { - tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) + tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) } else { - for ifIndex = 0; ifIndex < 256; ifIndex += 1 { - tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0) - if err == nil || !errorIsEBUSY(err) { + for ifIndex = 0; ifIndex < 256; ifIndex++ { + tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR|unix.O_CLOEXEC, 0) + if err == nil || !errors.Is(err, syscall.EBUSY) { break } } @@ -141,7 +133,7 @@ func CreateTUN(name string, mtu int) (Device, error) { if err == nil && name == "tun" { fname := os.Getenv("WG_TUN_NAME_FILE") if fname != "" { - ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) + os.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0o400) } } @@ -173,7 +165,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { return nil, err } - tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.AF_UNSPEC) if err != nil { tun.tunFile.Close() return nil, err @@ -208,62 +200,61 @@ func (tun *NativeTun) File() *os.File { return tun.tunFile } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { select { case err := <-tun.errors: return 0, err default: - buff := buff[offset-4:] - n, err := tun.tunFile.Read(buff[:]) + buf := bufs[0][offset-4:] + n, err := tun.tunFile.Read(buf[:]) if n < 4 { return 0, err } - return n - 4, err + sizes[0] = n - 4 + return 1, err } } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - - // reserve space for header - - buff = buff[offset-4:] - - // add packet information header - - buff[0] = 0x00 - buff[1] = 0x00 - buff[2] = 0x00 - - if buff[4]>>4 == ipv6.Version { - buff[3] = unix.AF_INET6 - } else { - buff[3] = unix.AF_INET +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + if offset < 4 { + return 0, io.ErrShortBuffer } - - // write - - return tun.tunFile.Write(buff) -} - -func (tun *NativeTun) Flush() error { - // TODO: can flushing be implemented by buffering and using sendmmsg? - return nil + for i, buf := range bufs { + buf = buf[offset-4:] + buf[0] = 0x00 + buf[1] = 0x00 + buf[2] = 0x00 + switch buf[4] >> 4 { + case 4: + buf[3] = unix.AF_INET + case 6: + buf[3] = unix.AF_INET6 + default: + return i, unix.EAFNOSUPPORT + } + if _, err := tun.tunFile.Write(buf); err != nil { + return i, err + } + } + return len(bufs), nil } func (tun *NativeTun) Close() error { - var err2 error - err1 := tun.tunFile.Close() - if tun.routeSocket != -1 { - unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) - err2 = unix.Close(tun.routeSocket) - tun.routeSocket = -1 - } else if tun.events != nil { - close(tun.events) - } + var err1, err2 error + tun.closeOnce.Do(func() { + err1 = tun.tunFile.Close() + if tun.routeSocket != -1 { + unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) + err2 = unix.Close(tun.routeSocket) + tun.routeSocket = -1 + } else if tun.events != nil { + close(tun.events) + } + }) if err1 != nil { return err1 } @@ -277,10 +268,9 @@ func (tun *NativeTun) setMTU(n int) error { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return err } @@ -312,10 +302,9 @@ func (tun *NativeTun) MTU() (int, error) { fd, err := unix.Socket( unix.AF_INET, - unix.SOCK_DGRAM, + unix.SOCK_DGRAM|unix.SOCK_CLOEXEC, 0, ) - if err != nil { return 0, err } @@ -338,3 +327,7 @@ func (tun *NativeTun) MTU() (int, error) { return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil } + +func (tun *NativeTun) BatchSize() int { + return 1 +} diff --git a/tun/tun_windows.go b/tun/tun_windows.go index b16dbb7..2af8e3e 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -12,11 +12,10 @@ import ( "sync" "sync/atomic" "time" - "unsafe" + _ "unsafe" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun/wintun" + "golang.zx2c4.com/wintun" ) const ( @@ -26,25 +25,31 @@ const ( ) type rateJuggler struct { - current uint64 - nextByteCount uint64 - nextStartTime int64 - changing int32 + current atomic.Uint64 + nextByteCount atomic.Uint64 + nextStartTime atomic.Int64 + changing atomic.Bool } type NativeTun struct { - wt *wintun.Interface + wt *wintun.Adapter + name string handle windows.Handle - close bool + rate rateJuggler + session wintun.Session + readWait windows.Handle events chan Event - errors chan error + running sync.WaitGroup + closeOnce sync.Once + close atomic.Bool forcedMTU int - rate rateJuggler - rings *wintun.RingDescriptor - writeLock sync.Mutex + outSizes []int } -const WintunPool = wintun.Pool("WireGuard") +var ( + WintunTunnelType = "WireGuard" + WintunStaticRequestedGUID *windows.GUID +) //go:linkname procyield runtime.procyield func procyield(cycles uint32) @@ -52,34 +57,18 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { - return CreateTUNWithRequestedGUID(ifname, nil, mtu) + return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { - var err error - var wt *wintun.Interface - - // Does an interface with this name already exist? - wt, err = WintunPool.GetInterface(ifname) - if err == nil { - // If so, we delete it, in case it has weird residual configuration. - _, err = wt.DeleteInterface() - if err != nil { - return nil, fmt.Errorf("Error deleting already existing interface: %v", err) - } - } - wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID) + wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { - return nil, fmt.Errorf("Error creating interface: %v", err) + return nil, fmt.Errorf("Error creating interface: %w", err) } forcedMTU := 1420 @@ -89,52 +78,46 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun := &NativeTun{ wt: wt, + name: ifname, handle: windows.InvalidHandle, events: make(chan Event, 10), - errors: make(chan error, 1), forcedMTU: forcedMTU, } - tun.rings, err = wintun.NewRingDescriptor() - if err != nil { - tun.Close() - return nil, fmt.Errorf("Error creating events: %v", err) - } - - tun.handle, err = tun.wt.Register(tun.rings) + tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { - tun.Close() - return nil, fmt.Errorf("Error registering rings: %v", err) + tun.wt.Close() + close(tun.events) + return nil, fmt.Errorf("Error starting session: %w", err) } + tun.readWait = tun.session.ReadWaitEvent() return tun, nil } func (tun *NativeTun) Name() (string, error) { - return tun.wt.Name() + return tun.name, nil } func (tun *NativeTun) File() *os.File { return nil } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { - tun.close = true - if tun.rings.Send.TailMoved != 0 { - windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping - } - if tun.handle != windows.InvalidHandle { - windows.CloseHandle(tun.handle) - } - tun.rings.Close() var err error - if tun.wt != nil { - _, err = tun.wt.DeleteInterface() - } - close(tun.events) + tun.closeOnce.Do(func() { + tun.close.Store(true) + windows.SetEvent(tun.readWait) + tun.running.Wait() + tun.session.End() + if tun.wt != nil { + tun.wt.Close() + } + close(tun.events) + }) return err } @@ -144,132 +127,115 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } + update := tun.forcedMTU != mtu tun.forcedMTU = mtu + if update { + tun.events <- EventMTUUpdate + } +} + +func (tun *NativeTun) BatchSize() int { + // TODO: implement batching with wintun + return 1 } // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + tun.running.Add(1) + defer tun.running.Done() retry: - select { - case err := <-tun.errors: - return 0, err - default: - } - if tun.close { - return 0, os.ErrClosed - } - - buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head) - if buffHead >= wintun.PacketCapacity { + if tun.close.Load() { return 0, os.ErrClosed } - start := nanotime() - shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 - var buffTail uint32 + shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { - buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail) - if buffHead != buffTail { - break - } - if tun.close { + if tun.close.Load() { return 0, os.ErrClosed } - if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { - windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE) - goto retry + packet, err := tun.session.ReceivePacket() + switch err { + case nil: + n := copy(bufs[0][offset:], packet) + sizes[0] = n + tun.session.ReleaseReceivePacket(packet) + tun.rate.update(uint64(n)) + return 1, nil + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(tun.readWait, windows.INFINITE) + goto retry + } + procyield(1) + continue + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_INVALID_DATA: + return 0, errors.New("Send ring corrupt") } - procyield(1) - } - if buffTail >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead) - if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) { - return 0, errors.New("incomplete packet header in send ring") - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead])) - if packet.Size > wintun.PacketSizeMax { - return 0, errors.New("packet too big in send ring") + return 0, fmt.Errorf("Read failed: %w", err) } - - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size) - if alignedPacketSize > buffContent { - return 0, errors.New("incomplete packet in send ring") - } - - copy(buff[offset:], packet.Data[:packet.Size]) - buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize) - atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead) - tun.rate.update(uint64(packet.Size)) - return int(packet.Size), nil -} - -func (tun *NativeTun) Flush() error { - return nil } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - if tun.close { +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + tun.running.Add(1) + defer tun.running.Done() + if tun.close.Load() { return 0, os.ErrClosed } - packetSize := uint32(len(buff) - offset) - tun.rate.update(uint64(packetSize)) - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) - - tun.writeLock.Lock() - defer tun.writeLock.Unlock() - - buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) - if buffHead >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail) - if buffTail >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment) - if alignedPacketSize > buffSpace { - return 0, nil // Dropping when ring is full. - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail])) - packet.Size = packetSize - copy(packet.Data[:packetSize], buff[offset:]) - atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize)) - if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 { - windows.SetEvent(tun.rings.Receive.TailMoved) + for i, buf := range bufs { + packetSize := len(buf) - offset + tun.rate.update(uint64(packetSize)) + + packet, err := tun.session.AllocateSendPacket(packetSize) + switch err { + case nil: + // TODO: Explore options to eliminate this copy. + copy(packet, buf[offset:]) + tun.session.SendPacket(packet) + continue + case windows.ERROR_HANDLE_EOF: + return i, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + continue // Dropping when ring is full. + default: + return i, fmt.Errorf("Write failed: %w", err) + } } - return int(packetSize), nil + return len(bufs), nil } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { + tun.running.Add(1) + defer tun.running.Done() + if tun.close.Load() { + return 0 + } return tun.wt.LUID() } -// Version returns the version of the Wintun driver and NDIS system currently loaded. -func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) { - return tun.wt.Version() +// RunningVersion returns the running version of the Wintun driver. +func (tun *NativeTun) RunningVersion() (version uint32, err error) { + return wintun.RunningVersion() } func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() - total := atomic.AddUint64(&rate.nextByteCount, packetLen) - period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + total := rate.nextByteCount.Add(packetLen) + period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { - if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + if !rate.changing.CompareAndSwap(false, true) { return } - atomic.StoreInt64(&rate.nextStartTime, now) - atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) - atomic.StoreUint64(&rate.nextByteCount, 0) - atomic.StoreInt32(&rate.changing, 0) + rate.nextStartTime.Store(now) + rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) + rate.nextByteCount.Store(0) + rate.changing.Store(false) } } diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index bdd96ac..d07e860 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tuntest @@ -8,13 +8,13 @@ package tuntest import ( "encoding/binary" "io" - "net" + "net/netip" "os" "golang.zx2c4.com/wireguard/tun" ) -func Ping(dst, src net.IP) []byte { +func Ping(dst, src netip.Addr) []byte { localPort := uint16(1337) seq := uint16(0) @@ -40,7 +40,7 @@ func checksum(buf []byte, initial uint16) uint16 { return ^uint16(v) } -func genICMPv4(payload []byte, dst, src net.IP) []byte { +func genICMPv4(payload []byte, dst, src netip.Addr) []byte { const ( icmpv4ProtocolNumber = 1 icmpv4Echo = 8 @@ -50,12 +50,13 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { ipv4TotalLenOffset = 2 ipv4ChecksumOffset = 10 ttl = 65 + headerSize = ipv4Size + icmpv4Size ) - hdr := make([]byte, ipv4Size+icmpv4Size) + pkt := make([]byte, headerSize+len(payload)) - ip := hdr[0:ipv4Size] - icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size] + ip := pkt[0:ipv4Size] + icmpv4 := pkt[ipv4Size : ipv4Size+icmpv4Size] // https://tools.ietf.org/html/rfc792 icmpv4[0] = icmpv4Echo // type @@ -64,23 +65,20 @@ func genICMPv4(payload []byte, dst, src net.IP) []byte { binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum) // https://tools.ietf.org/html/rfc760 section 3.1 - length := uint16(len(hdr) + len(payload)) + length := uint16(len(pkt)) ip[0] = (4 << 4) | (ipv4Size / 4) binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length) ip[8] = ttl ip[9] = icmpv4ProtocolNumber - copy(ip[12:], src.To4()) - copy(ip[16:], dst.To4()) + copy(ip[12:], src.AsSlice()) + copy(ip[16:], dst.AsSlice()) chksum = ^checksum(ip[:], 0) binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum) - var v []byte - v = append(v, hdr...) - v = append(v, payload...) - return []byte(v) + copy(pkt[headerSize:], payload) + return pkt } -// TODO(crawshaw): find a reusable home for this. package devicetest? type ChannelTUN struct { Inbound chan []byte // incoming packets, closed on TUN close Outbound chan []byte // outbound packets, blocks forever on TUN close @@ -112,38 +110,45 @@ type chTun struct { func (t *chTun) File() *os.File { return nil } -func (t *chTun) Read(data []byte, offset int) (int, error) { +func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { select { case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? + return 0, os.ErrClosed case msg := <-t.c.Outbound: - return copy(data[offset:], msg), nil + n := copy(packets[0][offset:], msg) + sizes[0] = n + return 1, nil } } // Write is called by the wireguard device to deliver a packet for routing. -func (t *chTun) Write(data []byte, offset int) (int, error) { +func (t *chTun) Write(packets [][]byte, offset int) (int, error) { if offset == -1 { close(t.c.closed) close(t.c.events) return 0, io.EOF } - msg := make([]byte, len(data)-offset) - copy(msg, data[offset:]) - select { - case <-t.c.closed: - return 0, io.EOF // TODO(crawshaw): what is the correct error value? - case t.c.Inbound <- msg: - return len(data) - offset, nil + for i, data := range packets { + msg := make([]byte, len(data)-offset) + copy(msg, data[offset:]) + select { + case <-t.c.closed: + return i, os.ErrClosed + case t.c.Inbound <- msg: + } } + return len(packets), nil +} + +func (t *chTun) BatchSize() int { + return 1 } const DefaultMTU = 1420 -func (t *chTun) Flush() error { return nil } -func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } -func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } -func (t *chTun) Events() chan tun.Event { return t.c.events } +func (t *chTun) MTU() (int, error) { return DefaultMTU, nil } +func (t *chTun) Name() (string, error) { return "loopbackTun1", nil } +func (t *chTun) Events() <-chan tun.Event { return t.c.events } func (t *chTun) Close() error { t.Write(nil, -1) return nil diff --git a/tun/wintun/iphlpapi/conversion_windows.go b/tun/wintun/iphlpapi/conversion_windows.go deleted file mode 100644 index a19e961..0000000 --- a/tun/wintun/iphlpapi/conversion_windows.go +++ /dev/null @@ -1,25 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package iphlpapi - -import "golang.org/x/sys/windows" - -//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid -//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid - -func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) { - var luid uint64 - var guid windows.GUID - err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid) - if err != nil { - return nil, err - } - err = convertInterfaceLUIDToGUID(&luid, &guid) - if err != nil { - return nil, err - } - return &guid, nil -} diff --git a/tun/wintun/iphlpapi/mksyscall.go b/tun/wintun/iphlpapi/mksyscall.go deleted file mode 100644 index fc7dba4..0000000 --- a/tun/wintun/iphlpapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package iphlpapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go diff --git a/tun/wintun/iphlpapi/zsyscall_windows.go b/tun/wintun/iphlpapi/zsyscall_windows.go deleted file mode 100644 index dc14294..0000000 --- a/tun/wintun/iphlpapi/zsyscall_windows.go +++ /dev/null @@ -1,60 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package iphlpapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") - - procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid") - procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid") -) - -func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} - -func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) { - r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} diff --git a/tun/wintun/namespace_windows.go b/tun/wintun/namespace_windows.go deleted file mode 100644 index 5f8a041..0000000 --- a/tun/wintun/namespace_windows.go +++ /dev/null @@ -1,101 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "encoding/hex" - "errors" - "fmt" - "sync" - "unsafe" - - "golang.org/x/crypto/blake2s" - "golang.org/x/sys/windows" - "golang.org/x/text/unicode/norm" - - "golang.zx2c4.com/wireguard/tun/wintun/namespaceapi" -) - -var ( - wintunObjectSecurityAttributes *windows.SecurityAttributes - hasInitializedNamespace bool - initializingNamespace sync.Mutex -) - -func initializeNamespace() error { - initializingNamespace.Lock() - defer initializingNamespace.Unlock() - if hasInitializedNamespace { - return nil - } - sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)") - if err != nil { - return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err) - } - wintunObjectSecurityAttributes = &windows.SecurityAttributes{ - Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})), - SecurityDescriptor: sd, - } - sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) - if err != nil { - return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err) - } - - boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun") - if err != nil { - return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err) - } - err = boundary.AddSid(sid) - if err != nil { - return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err) - } - for { - _, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun") - if err == windows.ERROR_ALREADY_EXISTS { - _, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun") - if err == windows.ERROR_PATH_NOT_FOUND { - continue - } - if err != nil { - return fmt.Errorf("OpenPrivateNamespace failed: %v", err) - } - } - if err != nil { - return fmt.Errorf("CreatePrivateNamespace failed: %v", err) - } - break - } - hasInitializedNamespace = true - return nil -} - -func (pool Pool) takeNameMutex() (windows.Handle, error) { - err := initializeNamespace() - if err != nil { - return 0, err - } - - const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com" - b2, _ := blake2s.New256(nil) - b2.Write([]byte(mutexLabel)) - b2.Write(norm.NFC.Bytes([]byte(string(pool)))) - mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil)) - mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName)) - if err != nil { - err = fmt.Errorf("Error creating name mutex: %v", err) - return 0, err - } - event, err := windows.WaitForSingleObject(mutex, windows.INFINITE) - if err != nil { - windows.CloseHandle(mutex) - return 0, fmt.Errorf("Error waiting on name mutex: %v", err) - } - if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED { - windows.CloseHandle(mutex) - return 0, errors.New("Error with event trigger of name mutex") - } - return mutex, nil -} diff --git a/tun/wintun/namespaceapi/mksyscall.go b/tun/wintun/namespaceapi/mksyscall.go deleted file mode 100644 index 93d43b0..0000000 --- a/tun/wintun/namespaceapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package namespaceapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go diff --git a/tun/wintun/namespaceapi/namespaceapi_windows.go b/tun/wintun/namespaceapi/namespaceapi_windows.go deleted file mode 100644 index a3a6274..0000000 --- a/tun/wintun/namespaceapi/namespaceapi_windows.go +++ /dev/null @@ -1,83 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package namespaceapi - -import "golang.org/x/sys/windows" - -//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW -//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor -//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor -//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW -//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW -//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace - -// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated. -type BoundaryDescriptor windows.Handle - -// CreateBoundaryDescriptor creates a boundary descriptor. -func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) { - name16, err := windows.UTF16PtrFromString(name) - if err != nil { - return 0, err - } - handle, err := createBoundaryDescriptor(name16, 0) - if err != nil { - return 0, err - } - return BoundaryDescriptor(handle), nil -} - -// Delete deletes the specified boundary descriptor. -func (bd BoundaryDescriptor) Delete() { - deleteBoundaryDescriptor(windows.Handle(bd)) -} - -// AddSid adds a security identifier (SID) to the specified boundary descriptor. -func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error { - return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid) -} - -// PrivateNamespace represents a private namespace. -type PrivateNamespace windows.Handle - -// CreatePrivateNamespace creates a private namespace. -func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) { - aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix) - if err != nil { - return 0, err - } - handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16) - if err != nil { - return 0, err - } - return PrivateNamespace(handle), nil -} - -// OpenPrivateNamespace opens a private namespace. -func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) { - aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix) - if err != nil { - return 0, err - } - handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16) - if err != nil { - return 0, err - } - return PrivateNamespace(handle), nil -} - -// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method. -type ClosePrivateNamespaceFlags uint32 - -const ( - // PrivateNamespaceFlagDestroy makes the close to destroy the namespace. - PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1) -) - -// Close closes an open namespace handle. -func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error { - return closePrivateNamespace(windows.Handle(pns), uint32(flags)) -} diff --git a/tun/wintun/namespaceapi/zsyscall_windows.go b/tun/wintun/namespaceapi/zsyscall_windows.go deleted file mode 100644 index 508c223..0000000 --- a/tun/wintun/namespaceapi/zsyscall_windows.go +++ /dev/null @@ -1,116 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package namespaceapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - - procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW") - procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor") - procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor") - procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW") - procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW") - procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace") -) - -func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0) - handle = windows.Handle(r0) - if handle == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) { - syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0) - return -} - -func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) { - r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix))) - handle = windows.Handle(r0) - if handle == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) { - r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0) - handle = windows.Handle(r0) - if handle == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) { - r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} diff --git a/tun/wintun/nci/mksyscall.go b/tun/wintun/nci/mksyscall.go deleted file mode 100644 index 019da93..0000000 --- a/tun/wintun/nci/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package nci - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go diff --git a/tun/wintun/nci/nci_windows.go b/tun/wintun/nci/nci_windows.go deleted file mode 100644 index 9dc6699..0000000 --- a/tun/wintun/nci/nci_windows.go +++ /dev/null @@ -1,28 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package nci - -import "golang.org/x/sys/windows" - -//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName -//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName - -func SetConnectionName(guid *windows.GUID, newName string) error { - newName16, err := windows.UTF16PtrFromString(newName) - if err != nil { - return err - } - return nciSetConnectionName(guid, newName16) -} - -func ConnectionName(guid *windows.GUID) (string, error) { - var name [0x400]uint16 - err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil) - if err != nil { - return "", err - } - return windows.UTF16ToString(name[:]), nil -} diff --git a/tun/wintun/nci/zsyscall_windows.go b/tun/wintun/nci/zsyscall_windows.go deleted file mode 100644 index 2a7b79e..0000000 --- a/tun/wintun/nci/zsyscall_windows.go +++ /dev/null @@ -1,60 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package nci - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modnci = windows.NewLazySystemDLL("nci.dll") - - procNciSetConnectionName = modnci.NewProc("NciSetConnectionName") - procNciGetConnectionName = modnci.NewProc("NciGetConnectionName") -) - -func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) { - r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} - -func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) { - r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0) - if r0 != 0 { - ret = syscall.Errno(r0) - } - return -} diff --git a/tun/wintun/registry/mksyscall.go b/tun/wintun/registry/mksyscall.go deleted file mode 100644 index 6ad82d2..0000000 --- a/tun/wintun/registry/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package registry - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go diff --git a/tun/wintun/registry/registry_windows.go b/tun/wintun/registry/registry_windows.go deleted file mode 100644 index 12a0336..0000000 --- a/tun/wintun/registry/registry_windows.go +++ /dev/null @@ -1,272 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package registry - -import ( - "errors" - "fmt" - "runtime" - "strings" - "time" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -const ( - // REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted. - REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001 - - // REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information. - REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002 - - // REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value. - REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004 - - // REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key. - REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008 - - // REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later. - REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000 -) - -//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue - -func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - deadline := time.Now().Add(timeout) - pathSpl := strings.Split(path, "\\") - for i := 0; ; i++ { - keyName := pathSpl[i] - isLast := i+1 == len(pathSpl) - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return 0, fmt.Errorf("Error creating event: %v", err) - } - defer windows.CloseHandle(event) - - var key registry.Key - for { - err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true) - if err != nil { - return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err) - } - - var accessFlags uint32 - if isLast { - accessFlags = access - } else { - accessFlags = registry.NOTIFY - } - key, err = registry.OpenKey(k, keyName, accessFlags) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return 0, fmt.Errorf("Unable to wait on registry key: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return 0, errors.New("Timeout waiting for registry key") - } - } else if err != nil { - return 0, fmt.Errorf("Error opening registry key %v: %v", path, err) - } else { - if isLast { - return key, nil - } - defer key.Close() - break - } - } - - k = key - } -} - -func WaitForKey(k registry.Key, path string, timeout time.Duration) error { - key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout) - if err != nil { - return err - } - key.Close() - return nil -} - -// -// getValue is more or less the same as windows/registry's getValue. -// -func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) { - var name16 *uint16 - name16, err = windows.UTF16PtrFromString(name) - if err != nil { - return - } - n := uint32(len(buf)) - for { - err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n) - if err == nil { - value = buf[:n] - return - } - if err != windows.ERROR_MORE_DATA { - return - } - if n <= uint32(len(buf)) { - return - } - buf = make([]byte, n) - } -} - -// -// getValueRetry function reads any value from registry. It waits for -// the registry value to become available or returns error on timeout. -// -// Key must be opened with at least QUERY_VALUE|NOTIFY access. -// -func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) { - runtime.LockOSThread() - defer runtime.UnlockOSThread() - - event, err := windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return nil, 0, fmt.Errorf("Error creating event: %v", err) - } - defer windows.CloseHandle(event) - - deadline := time.Now().Add(timeout) - for { - err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true) - if err != nil { - return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err) - } - - buf, valueType, err := getValue(key, name, buf) - if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND { - timeout := time.Until(deadline) / time.Millisecond - if timeout < 0 { - timeout = 0 - } - s, err := windows.WaitForSingleObject(event, uint32(timeout)) - if err != nil { - return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err) - } - if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows - return nil, 0, errors.New("Timeout waiting for registry value") - } - } else if err != nil { - return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err) - } else { - return buf, valueType, nil - } - } -} - -func toString(buf []byte, valueType uint32, err error) (string, error) { - if err != nil { - return "", err - } - - var value string - switch valueType { - case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ: - if len(buf) == 0 { - return "", nil - } - value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2]) - - default: - return "", registry.ErrUnexpectedType - } - - if valueType != registry.EXPAND_SZ { - // Value does not require expansion. - return value, nil - } - - valueExp, err := registry.ExpandString(value) - if err != nil { - // Expanding failed: return original sting value. - return value, nil - } - - // Return expanded value. - return valueExp, nil -} - -func toInteger(buf []byte, valueType uint32, err error) (uint64, error) { - if err != nil { - return 0, err - } - - switch valueType { - case registry.DWORD: - if len(buf) != 4 { - return 0, errors.New("DWORD value is not 4 bytes long") - } - var val uint32 - copy((*[4]byte)(unsafe.Pointer(&val))[:], buf) - return uint64(val), nil - - case registry.QWORD: - if len(buf) != 8 { - return 0, errors.New("QWORD value is not 8 bytes long") - } - var val uint64 - copy((*[8]byte)(unsafe.Pointer(&val))[:], buf) - return val, nil - - default: - return 0, registry.ErrUnexpectedType - } -} - -// -// GetStringValueWait function reads a string value from registry. It waits -// for the registry value to become available or returns error on timeout. -// -// Key must be opened with at least QUERY_VALUE|NOTIFY access. -// -// If the value type is REG_EXPAND_SZ the environment variables are expanded. -// Should expanding fail, original string value and nil error are returned. -// -// If the value type is REG_MULTI_SZ only the first string is returned. -// -func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) { - return toString(getValueRetry(key, name, make([]byte, 256), timeout)) -} - -// -// GetStringValue function reads a string value from registry. -// -// Key must be opened with at least QUERY_VALUE access. -// -// If the value type is REG_EXPAND_SZ the environment variables are expanded. -// Should expanding fail, original string value and nil error are returned. -// -// If the value type is REG_MULTI_SZ only the first string is returned. -// -func GetStringValue(key registry.Key, name string) (string, error) { - return toString(getValue(key, name, make([]byte, 256))) -} - -// -// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry. -// It waits for the registry value to become available or returns error on -// timeout. -// -// Key must be opened with at least QUERY_VALUE|NOTIFY access. -// -func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) { - return toInteger(getValueRetry(key, name, make([]byte, 8), timeout)) -} diff --git a/tun/wintun/registry/registry_windows_test.go b/tun/wintun/registry/registry_windows_test.go deleted file mode 100644 index c56b51b..0000000 --- a/tun/wintun/registry/registry_windows_test.go +++ /dev/null @@ -1,103 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package registry - -import ( - "testing" - "time" - - "golang.org/x/sys/windows/registry" -) - -const keyRoot = registry.CURRENT_USER -const pathRoot = "Software\\WireGuardRegistryTest" -const path = pathRoot + "\\foobar" -const pathFake = pathRoot + "\\raboof" - -func Test_WaitForKey(t *testing.T) { - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) - go func() { - time.Sleep(time.Second * 1) - key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE) - if err != nil { - t.Errorf("Error creating registry key: %v", err) - } - key.Close() - registry.DeleteKey(keyRoot, pathFake) - - key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE) - if err != nil { - t.Errorf("Error creating registry key: %v", err) - } - key.Close() - }() - err := WaitForKey(keyRoot, path, time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry key: %v", err) - } - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) - - err = WaitForKey(keyRoot, path, time.Second*1) - if err == nil { - t.Error("Registry key notification expected to timeout but it succeeded.") - } -} - -func Test_GetValueWait(t *testing.T) { - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) - go func() { - time.Sleep(time.Second * 1) - key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE) - if err != nil { - t.Errorf("Error creating registry key: %v", err) - } - time.Sleep(time.Second * 1) - key.SetStringValue("name1", "eulav") - key.SetExpandStringValue("name2", "value") - time.Sleep(time.Second * 1) - key.SetDWordValue("name3", ^uint32(123)) - key.SetDWordValue("name4", 123) - key.Close() - }() - - key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry key: %v", err) - } - - valueStr, err := GetStringValueWait(key, "name2", time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry value: %v", err) - } - if valueStr != "value" { - t.Errorf("Wrong value read: %v", valueStr) - } - - _, err = GetStringValueWait(key, "nonexisting", time.Second*1) - if err == nil { - t.Error("Registry value notification expected to timeout but it succeeded.") - } - - valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2) - if err != nil { - t.Errorf("Error waiting for registry value: %v", err) - } - if valueInt != 123 { - t.Errorf("Wrong value read: %v", valueInt) - } - - _, err = GetIntegerValueWait(key, "nonexisting", time.Second*1) - if err == nil { - t.Error("Registry value notification expected to timeout but it succeeded.") - } - - key.Close() - registry.DeleteKey(keyRoot, path) - registry.DeleteKey(keyRoot, pathRoot) -} diff --git a/tun/wintun/registry/zregistry_windows.go b/tun/wintun/registry/zregistry_windows.go deleted file mode 100644 index f7ac33b..0000000 --- a/tun/wintun/registry/zregistry_windows.go +++ /dev/null @@ -1,63 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package registry - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - - procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue") -) - -func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) { - var _p0 uint32 - if watchSubtree { - _p0 = 1 - } else { - _p0 = 0 - } - var _p1 uint32 - if asynchronous { - _p1 = 1 - } else { - _p1 = 0 - } - r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0) - if r0 != 0 { - regerrno = syscall.Errno(r0) - } - return -} diff --git a/tun/wintun/ring_windows.go b/tun/wintun/ring_windows.go deleted file mode 100644 index 8e6b375..0000000 --- a/tun/wintun/ring_windows.go +++ /dev/null @@ -1,117 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "runtime" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - PacketAlignment = 4 // Number of bytes packets are aligned to in rings - PacketSizeMax = 0xffff // Maximum packet size - PacketCapacity = 0x800000 // Ring capacity, 8MiB - PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment - ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14) -) - -type PacketHeader struct { - Size uint32 -} - -type Packet struct { - PacketHeader - Data [PacketSizeMax]byte -} - -type Ring struct { - Head uint32 - Tail uint32 - Alertable int32 - Data [PacketCapacity + PacketTrailingSize]byte -} - -type RingDescriptor struct { - Send, Receive struct { - Size uint32 - Ring *Ring - TailMoved windows.Handle - } -} - -// Wrap returns value modulo ring capacity -func (rb *Ring) Wrap(value uint32) uint32 { - return value & (PacketCapacity - 1) -} - -// Aligns a packet size to PacketAlignment -func PacketAlign(size uint32) uint32 { - return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) -} - -func NewRingDescriptor() (descriptor *RingDescriptor, err error) { - descriptor = new(RingDescriptor) - allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE) - if err != nil { - return - } - defer func() { - if err != nil { - descriptor.free() - descriptor = nil - } - }() - descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion)) - descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - return - } - - descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) - descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{}))) - descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) - if err != nil { - windows.CloseHandle(descriptor.Send.TailMoved) - return - } - runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() }) - return -} - -func (descriptor *RingDescriptor) free() { - if descriptor.Send.Ring != nil { - windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE) - descriptor.Send.Ring = nil - descriptor.Receive.Ring = nil - } -} - -func (descriptor *RingDescriptor) Close() { - if descriptor.Send.TailMoved != 0 { - windows.CloseHandle(descriptor.Send.TailMoved) - descriptor.Send.TailMoved = 0 - } - if descriptor.Send.TailMoved != 0 { - windows.CloseHandle(descriptor.Receive.TailMoved) - descriptor.Receive.TailMoved = 0 - } -} - -func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) { - handle, err := wintun.handle() - if err != nil { - return 0, err - } - var bytesReturned uint32 - err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil) - if err != nil { - return 0, err - } - return handle, nil -} diff --git a/tun/wintun/setupapi/mksyscall.go b/tun/wintun/setupapi/mksyscall.go deleted file mode 100644 index ac103a1..0000000 --- a/tun/wintun/setupapi/mksyscall.go +++ /dev/null @@ -1,8 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go diff --git a/tun/wintun/setupapi/setupapi_windows.go b/tun/wintun/setupapi/setupapi_windows.go deleted file mode 100644 index 60a8eb7..0000000 --- a/tun/wintun/setupapi/setupapi_windows.go +++ /dev/null @@ -1,506 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "encoding/binary" - "fmt" - "runtime" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" -) - -//sys setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiCreateDeviceInfoListExW - -// SetupDiCreateDeviceInfoListEx function creates an empty device information set on a remote or a local computer and optionally associates the set with a device setup class. -func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) { - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return - } - } - return setupDiCreateDeviceInfoListEx(classGUID, hwndParent, machineNameUTF16, 0) -} - -//sys setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) = setupapi.SetupDiGetDeviceInfoListDetailW - -// SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. -func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) { - data := &DevInfoListDetailData{} - data.size = sizeofDevInfoListDetailData - - return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data) -} - -// DeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. -func (deviceInfoSet DevInfo) DeviceInfoListDetail() (*DevInfoListDetailData, error) { - return SetupDiGetDeviceInfoListDetail(deviceInfoSet) -} - -//sys setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCreateDeviceInfoW - -// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set. -func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) { - deviceNameUTF16, err := windows.UTF16PtrFromString(deviceName) - if err != nil { - return - } - - var deviceDescriptionUTF16 *uint16 - if deviceDescription != "" { - deviceDescriptionUTF16, err = windows.UTF16PtrFromString(deviceDescription) - if err != nil { - return - } - } - - data := &DevInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiCreateDeviceInfo(deviceInfoSet, deviceNameUTF16, classGUID, deviceDescriptionUTF16, hwndParent, creationFlags, data) -} - -// CreateDeviceInfo method creates a new device information element and adds it as a new member to the specified device information set. -func (deviceInfoSet DevInfo) CreateDeviceInfo(deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (*DevInfoData, error) { - return SetupDiCreateDeviceInfo(deviceInfoSet, deviceName, classGUID, deviceDescription, hwndParent, creationFlags) -} - -//sys setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiEnumDeviceInfo - -// SetupDiEnumDeviceInfo function returns a DevInfoData structure that specifies a device information element in a device information set. -func SetupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex int) (*DevInfoData, error) { - data := &DevInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiEnumDeviceInfo(deviceInfoSet, uint32(memberIndex), data) -} - -// EnumDeviceInfo method returns a DevInfoData structure that specifies a device information element in a device information set. -func (deviceInfoSet DevInfo) EnumDeviceInfo(memberIndex int) (*DevInfoData, error) { - return SetupDiEnumDeviceInfo(deviceInfoSet, memberIndex) -} - -// SetupDiDestroyDeviceInfoList function deletes a device information set and frees all associated memory. -//sys SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiDestroyDeviceInfoList - -// Close method deletes a device information set and frees all associated memory. -func (deviceInfoSet DevInfo) Close() error { - return SetupDiDestroyDeviceInfoList(deviceInfoSet) -} - -//sys SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiBuildDriverInfoList - -// BuildDriverInfoList method builds a list of drivers that is associated with a specific device or with the global class driver list for a device information set. -func (deviceInfoSet DevInfo) BuildDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error { - return SetupDiBuildDriverInfoList(deviceInfoSet, deviceInfoData, driverType) -} - -//sys SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) = setupapi.SetupDiCancelDriverInfoSearch - -// CancelDriverInfoSearch method cancels a driver list search that is currently in progress in a different thread. -func (deviceInfoSet DevInfo) CancelDriverInfoSearch() error { - return SetupDiCancelDriverInfoSearch(deviceInfoSet) -} - -//sys setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiEnumDriverInfoW - -// SetupDiEnumDriverInfo function enumerates the members of a driver list. -func SetupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) { - data := &DrvInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, uint32(memberIndex), data) -} - -// EnumDriverInfo method enumerates the members of a driver list. -func (deviceInfoSet DevInfo) EnumDriverInfo(deviceInfoData *DevInfoData, driverType SPDIT, memberIndex int) (*DrvInfoData, error) { - return SetupDiEnumDriverInfo(deviceInfoSet, deviceInfoData, driverType, memberIndex) -} - -//sys setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiGetSelectedDriverW - -// SetupDiGetSelectedDriver function retrieves the selected driver for a device information set or a particular device information element. -func SetupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DrvInfoData, error) { - data := &DrvInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data) -} - -// SelectedDriver method retrieves the selected driver for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) SelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) { - return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData) -} - -//sys SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) = setupapi.SetupDiSetSelectedDriverW - -// SetSelectedDriver method sets, or resets, the selected driver for a device information element or the selected class driver for a device information set. -func (deviceInfoSet DevInfo) SetSelectedDriver(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) error { - return SetupDiSetSelectedDriver(deviceInfoSet, deviceInfoData, driverInfoData) -} - -//sys setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDriverInfoDetailW - -// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set. -func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) { - reqSize := uint32(2048) - for { - buf := make([]byte, reqSize) - data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0])) - data.size = sizeofDrvInfoDetailData - err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, uint32(len(buf)), &reqSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return nil, err - } - data.size = reqSize - return data, nil - } -} - -// DriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set. -func (deviceInfoSet DevInfo) DriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) { - return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData) -} - -//sys SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) = setupapi.SetupDiDestroyDriverInfoList - -// DestroyDriverInfoList method deletes a driver list. -func (deviceInfoSet DevInfo) DestroyDriverInfoList(deviceInfoData *DevInfoData, driverType SPDIT) error { - return SetupDiDestroyDriverInfoList(deviceInfoSet, deviceInfoData, driverType) -} - -//sys setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) [failretval==DevInfo(windows.InvalidHandle)] = setupapi.SetupDiGetClassDevsExW - -// SetupDiGetClassDevsEx function returns a handle to a device information set that contains requested device information elements for a local or a remote computer. -func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) { - var enumeratorUTF16 *uint16 - if enumerator != "" { - enumeratorUTF16, err = windows.UTF16PtrFromString(enumerator) - if err != nil { - return - } - } - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return - } - } - return setupDiGetClassDevsEx(classGUID, enumeratorUTF16, hwndParent, flags, deviceInfoSet, machineNameUTF16, 0) -} - -// SetupDiCallClassInstaller function calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code). -//sys SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiCallClassInstaller - -// CallClassInstaller member calls the appropriate class installer, and any registered co-installers, with the specified installation request (DIF code). -func (deviceInfoSet DevInfo) CallClassInstaller(installFunction DI_FUNCTION, deviceInfoData *DevInfoData) error { - return SetupDiCallClassInstaller(installFunction, deviceInfoSet, deviceInfoData) -} - -//sys setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) [failretval==windows.InvalidHandle] = setupapi.SetupDiOpenDevRegKey - -// SetupDiOpenDevRegKey function opens a registry key for device-specific configuration information. -func SetupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, scope DICS_FLAG, hwProfile uint32, keyType DIREG, samDesired uint32) (registry.Key, error) { - handle, err := setupDiOpenDevRegKey(deviceInfoSet, deviceInfoData, scope, hwProfile, keyType, samDesired) - return registry.Key(handle), err -} - -// OpenDevRegKey method opens a registry key for device-specific configuration information. -func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (registry.Key, error) { - return SetupDiOpenDevRegKey(deviceInfoSet, DeviceInfoData, Scope, HwProfile, KeyType, samDesired) -} - -//sys setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceRegistryPropertyW - -// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property. -func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) { - reqSize := uint32(256) - for { - var dataType uint32 - buf := make([]byte, reqSize) - err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(len(buf)), &reqSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return - } - return getRegistryValue(buf[:reqSize], dataType) - } -} - -func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) { - switch dataType { - case windows.REG_SZ: - ret := windows.UTF16ToString(bufToUTF16(buf)) - runtime.KeepAlive(buf) - return ret, nil - case windows.REG_EXPAND_SZ: - ret, err := registry.ExpandString(windows.UTF16ToString(bufToUTF16(buf))) - runtime.KeepAlive(buf) - return ret, err - case windows.REG_BINARY: - return buf, nil - case windows.REG_DWORD_LITTLE_ENDIAN: - return binary.LittleEndian.Uint32(buf), nil - case windows.REG_DWORD_BIG_ENDIAN: - return binary.BigEndian.Uint32(buf), nil - case windows.REG_MULTI_SZ: - bufW := bufToUTF16(buf) - a := []string{} - for i := 0; i < len(bufW); { - j := i + wcslen(bufW[i:]) - if i < j { - a = append(a, windows.UTF16ToString(bufW[i:j])) - } - i = j + 1 - } - runtime.KeepAlive(buf) - return a, nil - case windows.REG_QWORD_LITTLE_ENDIAN: - return binary.LittleEndian.Uint64(buf), nil - default: - return nil, fmt.Errorf("Unsupported registry value type: %v", dataType) - } -} - -// bufToUTF16 function reinterprets []byte buffer as []uint16 -func bufToUTF16(buf []byte) []uint16 { - sl := struct { - addr *uint16 - len int - cap int - }{(*uint16)(unsafe.Pointer(&buf[0])), len(buf) / 2, cap(buf) / 2} - return *(*[]uint16)(unsafe.Pointer(&sl)) -} - -// utf16ToBuf function reinterprets []uint16 as []byte -func utf16ToBuf(buf []uint16) []byte { - sl := struct { - addr *byte - len int - cap int - }{(*byte)(unsafe.Pointer(&buf[0])), len(buf) * 2, cap(buf) * 2} - return *(*[]byte)(unsafe.Pointer(&sl)) -} - -func wcslen(str []uint16) int { - for i := 0; i < len(str); i++ { - if str[i] == 0 { - return i - } - } - return len(str) -} - -// DeviceRegistryProperty method retrieves a specified Plug and Play device property. -func (deviceInfoSet DevInfo) DeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) { - return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property) -} - -//sys setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) = setupapi.SetupDiSetDeviceRegistryPropertyW - -// SetupDiSetDeviceRegistryProperty function sets a Plug and Play device property for a device. -func SetupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error { - return setupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &propertyBuffers[0], uint32(len(propertyBuffers))) -} - -// SetDeviceRegistryProperty function sets a Plug and Play device property for a device. -func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP, propertyBuffers []byte) error { - return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers) -} - -// SetDeviceRegistryPropertyString method sets a Plug and Play device property string for a device. -func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error { - str16, err := windows.UTF16FromString(str) - if err != nil { - return err - } - err = SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, utf16ToBuf(append(str16, 0))) - runtime.KeepAlive(str16) - return err -} - -//sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW - -// SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element. -func SetupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (*DevInstallParams, error) { - params := &DevInstallParams{} - params.size = uint32(unsafe.Sizeof(*params)) - - return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params) -} - -// DeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) DeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) { - return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData) -} - -//sys setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceInstanceIdW - -// SetupDiGetDeviceInstanceId function retrieves the instance ID of the device. -func SetupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (string, error) { - reqSize := uint32(1024) - for { - buf := make([]uint16, reqSize) - err := setupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData, &buf[0], uint32(len(buf)), &reqSize) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return "", err - } - return windows.UTF16ToString(buf), nil - } -} - -// DeviceInstanceID method retrieves the instance ID of the device. -func (deviceInfoSet DevInfo) DeviceInstanceID(deviceInfoData *DevInfoData) (string, error) { - return SetupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData) -} - -// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element. -//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW - -// ClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) ClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error { - return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize) -} - -//sys SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiSetDeviceInstallParamsW - -// SetDeviceInstallParams member sets device installation parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) SetDeviceInstallParams(deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) error { - return SetupDiSetDeviceInstallParams(deviceInfoSet, deviceInfoData, deviceInstallParams) -} - -// SetupDiSetClassInstallParams function sets or clears class install parameters for a device information set or a particular device information element. -//sys SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) = setupapi.SetupDiSetClassInstallParamsW - -// SetClassInstallParams method sets or clears class install parameters for a device information set or a particular device information element. -func (deviceInfoSet DevInfo) SetClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) error { - return SetupDiSetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize) -} - -//sys setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassNameFromGuidExW - -// SetupDiClassNameFromGuidEx function retrieves the class name associated with a class GUID. The class can be installed on a local or remote computer. -func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (className string, err error) { - var classNameUTF16 [MAX_CLASS_NAME_LEN]uint16 - - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return - } - } - - err = setupDiClassNameFromGuidEx(classGUID, &classNameUTF16[0], MAX_CLASS_NAME_LEN, nil, machineNameUTF16, 0) - if err != nil { - return - } - - className = windows.UTF16ToString(classNameUTF16[:]) - return -} - -//sys setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) = setupapi.SetupDiClassGuidsFromNameExW - -// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer. -func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) { - classNameUTF16, err := windows.UTF16PtrFromString(className) - if err != nil { - return nil, err - } - - var machineNameUTF16 *uint16 - if machineName != "" { - machineNameUTF16, err = windows.UTF16PtrFromString(machineName) - if err != nil { - return nil, err - } - } - - reqSize := uint32(4) - for { - buf := make([]windows.GUID, reqSize) - err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], uint32(len(buf)), &reqSize, machineNameUTF16, 0) - if err == windows.ERROR_INSUFFICIENT_BUFFER { - continue - } - if err != nil { - return nil, err - } - return buf[:reqSize], nil - } -} - -//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice - -// SetupDiGetSelectedDevice function retrieves the selected device information element in a device information set. -func SetupDiGetSelectedDevice(deviceInfoSet DevInfo) (*DevInfoData, error) { - data := &DevInfoData{} - data.size = uint32(unsafe.Sizeof(*data)) - - return data, setupDiGetSelectedDevice(deviceInfoSet, data) -} - -// SelectedDevice method retrieves the selected device information element in a device information set. -func (deviceInfoSet DevInfo) SelectedDevice() (*DevInfoData, error) { - return SetupDiGetSelectedDevice(deviceInfoSet) -} - -// SetupDiSetSelectedDevice function sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard. -//sys SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiSetSelectedDevice - -// SetSelectedDevice method sets a device information element as the selected member of a device information set. This function is typically used by an installation wizard. -func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error { - return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData) -} - -//sys cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_List_SizeW -//sys cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_ListW - -func CM_Get_Device_Interface_List(deviceID string, interfaceClass *windows.GUID, flags uint32) ([]string, error) { - deviceID16, err := windows.UTF16PtrFromString(deviceID) - if err != nil { - return nil, err - } - var buf []uint16 - var buflen uint32 - for { - if ret := cm_Get_Device_Interface_List_Size(&buflen, interfaceClass, deviceID16, flags); ret != CR_SUCCESS { - return nil, fmt.Errorf("CfgMgr error: 0x%x", ret) - } - buf = make([]uint16, buflen) - if ret := cm_Get_Device_Interface_List(interfaceClass, deviceID16, &buf[0], buflen, flags); ret == CR_SUCCESS { - break - } else if ret != CR_BUFFER_SMALL { - return nil, fmt.Errorf("CfgMgr error: 0x%x", ret) - } - } - var interfaces []string - for i := 0; i < len(buf); { - j := i + wcslen(buf[i:]) - if i < j { - interfaces = append(interfaces, windows.UTF16ToString(buf[i:j])) - } - i = j + 1 - } - if interfaces == nil { - return nil, fmt.Errorf("no interfaces found") - } - return interfaces, nil -} diff --git a/tun/wintun/setupapi/setupapi_windows_test.go b/tun/wintun/setupapi/setupapi_windows_test.go deleted file mode 100644 index a9e6b89..0000000 --- a/tun/wintun/setupapi/setupapi_windows_test.go +++ /dev/null @@ -1,488 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "runtime" - "strings" - "testing" - - "golang.org/x/sys/windows" -) - -var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} -var computerName string - -func init() { - computerName, _ = windows.ComputerName() -} - -func TestSetupDiCreateDeviceInfoListEx(t *testing.T) { - devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) - } else { - devInfoList.Close() - } - - devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName) - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) - } else { - devInfoList.Close() - } - - devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "") - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error()) - } else { - devInfoList.Close() - } -} - -func TestSetupDiGetDeviceInfoListDetail(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - data, err := devInfoList.DeviceInfoListDetail() - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error()) - } else { - if data.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID") - } - - if data.RemoteMachineHandle != windows.Handle(0) { - t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle") - } - - if data.RemoteMachineName() != "" { - t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name") - } - } - - devInfoList, err = SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), computerName) - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - data, err = devInfoList.DeviceInfoListDetail() - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error()) - } else { - if data.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiGetDeviceInfoListDetail returned different class GUID") - } - - if data.RemoteMachineHandle == windows.Handle(0) { - t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle") - } - - if data.RemoteMachineName() != computerName { - t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name") - } - } - - data = &DevInfoListDetailData{} - data.SetRemoteMachineName("foobar") - if data.RemoteMachineName() != "foobar" { - t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ") - } -} - -func TestSetupDiCreateDeviceInfo(t *testing.T) { - devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName) - if err != nil { - t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) - } - defer devInfoList.Close() - - deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName) - if err != nil { - t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error()) - } - - devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID) - if err != nil { - // Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed. - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED { - t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error()) - } - } else if devInfoData.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiCreateDeviceInfo returned different class GUID") - } -} - -func TestSetupDiEnumDeviceInfo(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - if data.ClassGUID != deviceClassNetGUID { - t.Error("SetupDiEnumDeviceInfo returned different class GUID") - } - - _, err = devInfoList.DeviceInstanceID(data) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInstanceId: %s", err.Error()) - } - } -} - -func TestDevInfo_BuildDriverInfoList(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - deviceData, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - const driverType SPDIT = SPDIT_COMPATDRIVER - err = devInfoList.BuildDriverInfoList(deviceData, driverType) - if err != nil { - t.Errorf("Error calling SetupDiBuildDriverInfoList: %s", err.Error()) - } - defer devInfoList.DestroyDriverInfoList(deviceData, driverType) - - var selectedDriverData *DrvInfoData - for j := 0; true; j++ { - driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - if driverData.DriverType == 0 { - continue - } - - if !driverData.IsNewer(windows.Filetime{}, 0) { - t.Error("Driver should have non-zero date and version") - } - if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime}, 0) { - t.Error("Driver should have non-zero date and version") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime + 1}, 0) { - t.Error("Driver should report newer version on high-date-time") - } - if !driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, 0) { - t.Error("Driver should have non-zero version") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime + 1}, 0) { - t.Error("Driver should report newer version on low-date-time") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion) { - t.Error("Driver should not be newer than itself") - } - if driverData.IsNewer(windows.Filetime{HighDateTime: driverData.DriverDate.HighDateTime, LowDateTime: driverData.DriverDate.LowDateTime}, driverData.DriverVersion+1) { - t.Error("Driver should report newer version on version") - } - - err = devInfoList.SetSelectedDriver(deviceData, driverData) - if err != nil { - t.Errorf("Error calling SetupDiSetSelectedDriver: %s", err.Error()) - } else { - selectedDriverData = driverData - } - - driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) - if err != nil { - t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error()) - } - - if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") { - t.Error("Invalid HWID compatibitlity reported") - } - if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.HardwareID())) { - t.Error("HWID compatibitlity missed") - } - a := driverDetailData.CompatIDs() - for k := range a { - if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) { - t.Error("HWID compatibitlity missed") - } - } - } - - selectedDriverData2, err := devInfoList.SelectedDriver(deviceData) - if err != nil { - t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error()) - } else if *selectedDriverData != *selectedDriverData2 { - t.Error("SetupDiGetSelectedDriver should return driver selected with SetupDiSetSelectedDriver") - } - } - - data := &DrvInfoData{} - data.SetDescription("foobar") - if data.Description() != "foobar" { - t.Error("DrvInfoData.(Get|Set)Description() differ") - } - data.SetMfgName("foobar") - if data.MfgName() != "foobar" { - t.Error("DrvInfoData.(Get|Set)MfgName() differ") - } - data.SetProviderName("foobar") - if data.ProviderName() != "foobar" { - t.Error("DrvInfoData.(Get|Set)ProviderName() differ") - } -} - -func TestSetupDiGetClassDevsEx(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName) - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } else { - devInfoList.Close() - } - - devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER { - t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER") - } - } else { - devInfoList.Close() - t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail") - } -} - -func TestSetupDiOpenDevRegKey(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - key, err := devInfoList.OpenDevRegKey(data, DICS_FLAG_GLOBAL, 0, DIREG_DRV, windows.KEY_READ) - if err != nil { - t.Errorf("Error calling SetupDiOpenDevRegKey: %s", err.Error()) - } - defer key.Close() - } -} - -func TestSetupDiGetDeviceRegistryProperty(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - val, err := devInfoList.DeviceRegistryProperty(data, SPDRP_CLASS) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error()) - } else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" { - t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"") - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CLASSGUID) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error()) - } else if valStr, ok := val.(string); !ok { - t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string") - } else { - classGUID, err := windows.GUIDFromString(valStr) - if err != nil { - t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error()) - } else if classGUID != deviceClassNetGUID { - t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID) - } - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS) - if err != nil { - // Some devices have no SPDRP_COMPATIBLEIDS. - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_DATA { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error()) - } - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CONFIGFLAGS) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error()) - } - - val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error()) - } - } -} - -func TestSetupDiGetDeviceInstallParams(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - _, err = devInfoList.DeviceInstallParams(data) - if err != nil { - t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error()) - } - } - - params := &DevInstallParams{} - params.SetDriverPath("foobar") - if params.DriverPath() != "foobar" { - t.Error("DevInstallParams.(Get|Set)DriverPath() differ") - } -} - -func TestSetupDiClassNameFromGuidEx(t *testing.T) { - deviceClassNetName, err := SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") - if err != nil { - t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error()) - } else if strings.ToLower(deviceClassNetName) != "net" { - t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID) - } - - deviceClassNetName, err = SetupDiClassNameFromGuidEx(&deviceClassNetGUID, computerName) - if err != nil { - t.Errorf("Error calling SetupDiClassNameFromGuidEx: %s", err.Error()) - } else if strings.ToLower(deviceClassNetName) != "net" { - t.Errorf("SetupDiClassNameFromGuidEx(%x) should return \"Net\"", deviceClassNetGUID) - } - - _, err = SetupDiClassNameFromGuidEx(nil, "") - if err != nil { - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER { - t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER") - } - } else { - t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail") - } -} - -func TestSetupDiClassGuidsFromNameEx(t *testing.T) { - ClassGUIDs, err := SetupDiClassGuidsFromNameEx("Net", "") - if err != nil { - t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error()) - } else { - found := false - for i := range ClassGUIDs { - if ClassGUIDs[i] == deviceClassNetGUID { - found = true - break - } - } - if !found { - t.Errorf("SetupDiClassGuidsFromNameEx(\"Net\") should return %x", deviceClassNetGUID) - } - } - - ClassGUIDs, err = SetupDiClassGuidsFromNameEx("foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe", computerName) - if err != nil { - t.Errorf("Error calling SetupDiClassGuidsFromNameEx: %s", err.Error()) - } else if len(ClassGUIDs) != 0 { - t.Errorf("SetupDiClassGuidsFromNameEx(\"foobar-34274a51-a6e6-45f0-80d6-c62be96dd5fe\") should return an empty GUID set") - } -} - -func TestSetupDiGetSelectedDevice(t *testing.T) { - devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, DIGCF_PRESENT, DevInfo(0), "") - if err != nil { - t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) - } - defer devInfoList.Close() - - for i := 0; true; i++ { - data, err := devInfoList.EnumDeviceInfo(i) - if err != nil { - if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - err = devInfoList.SetSelectedDevice(data) - if err != nil { - t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error()) - } - - data2, err := devInfoList.SelectedDevice() - if err != nil { - t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error()) - } else if *data != *data2 { - t.Error("SetupDiGetSelectedDevice returned different data than was set by SetupDiSetSelectedDevice") - } - } - - err = devInfoList.SetSelectedDevice(nil) - if err != nil { - if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER { - t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") - } - } else { - t.Errorf("SetupDiSetSelectedDevice(nil) should fail") - } -} - -func TestUTF16ToBuf(t *testing.T) { - buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef} - buf2 := utf16ToBuf(buf) - if len(buf)*2 != len(buf2) || - cap(buf)*2 != cap(buf2) || - buf2[0] != 0x23 || buf2[1] != 0x01 || - buf2[2] != 0x67 || buf2[3] != 0x45 || - buf2[4] != 0xab || buf2[5] != 0x89 || - buf2[6] != 0xef || buf2[7] != 0xcd { - t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") - } - runtime.KeepAlive(buf) -} diff --git a/tun/wintun/setupapi/types_windows.go b/tun/wintun/setupapi/types_windows.go deleted file mode 100644 index 239d80b..0000000 --- a/tun/wintun/setupapi/types_windows.go +++ /dev/null @@ -1,568 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "strings" - "unsafe" - - "golang.org/x/sys/windows" -) - -const ( - MAX_DEVICE_ID_LEN = 200 - MAX_DEVNODE_ID_LEN = MAX_DEVICE_ID_LEN - MAX_GUID_STRING_LEN = 39 // 38 chars + terminator null - MAX_CLASS_NAME_LEN = 32 - MAX_PROFILE_LEN = 80 - MAX_CONFIG_VALUE = 9999 - MAX_INSTANCE_VALUE = 9999 - CONFIGMG_VERSION = 0x0400 -) - -// -// Define maximum string length constants -// -const ( - ANYSIZE_ARRAY = 1 - LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF. - MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions). - MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters. - MAX_TITLE_LEN = 60 - MAX_INSTRUCTION_LEN = 256 - MAX_LABEL_LEN = 30 - MAX_SERVICE_NAME_LEN = 256 - MAX_SUBTITLE_LEN = 256 -) - -const ( - // SP_MAX_MACHINENAME_LENGTH defines maximum length of a machine name in the format expected by ConfigMgr32 CM_Connect_Machine (i.e., "\\\\MachineName\0"). - SP_MAX_MACHINENAME_LENGTH = windows.MAX_PATH + 3 -) - -// HSPFILEQ is type for setup file queue -type HSPFILEQ uintptr - -// DevInfo holds reference to device information set -type DevInfo windows.Handle - -// DevInfoData is a device information structure (references a device instance that is a member of a device information set) -type DevInfoData struct { - size uint32 - ClassGUID windows.GUID - DevInst uint32 // DEVINST handle - _ uintptr -} - -// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass). -type DevInfoListDetailData struct { - size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const. - ClassGUID windows.GUID - RemoteMachineHandle windows.Handle - remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16 -} - -func (data *DevInfoListDetailData) RemoteMachineName() string { - return windows.UTF16ToString(data.remoteMachineName[:]) -} - -func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error { - str, err := windows.UTF16FromString(remoteMachineName) - if err != nil { - return err - } - copy(data.remoteMachineName[:], str) - return nil -} - -// DI_FUNCTION is function type for device installer -type DI_FUNCTION uint32 - -const ( - DIF_SELECTDEVICE DI_FUNCTION = 0x00000001 - DIF_INSTALLDEVICE DI_FUNCTION = 0x00000002 - DIF_ASSIGNRESOURCES DI_FUNCTION = 0x00000003 - DIF_PROPERTIES DI_FUNCTION = 0x00000004 - DIF_REMOVE DI_FUNCTION = 0x00000005 - DIF_FIRSTTIMESETUP DI_FUNCTION = 0x00000006 - DIF_FOUNDDEVICE DI_FUNCTION = 0x00000007 - DIF_SELECTCLASSDRIVERS DI_FUNCTION = 0x00000008 - DIF_VALIDATECLASSDRIVERS DI_FUNCTION = 0x00000009 - DIF_INSTALLCLASSDRIVERS DI_FUNCTION = 0x0000000A - DIF_CALCDISKSPACE DI_FUNCTION = 0x0000000B - DIF_DESTROYPRIVATEDATA DI_FUNCTION = 0x0000000C - DIF_VALIDATEDRIVER DI_FUNCTION = 0x0000000D - DIF_DETECT DI_FUNCTION = 0x0000000F - DIF_INSTALLWIZARD DI_FUNCTION = 0x00000010 - DIF_DESTROYWIZARDDATA DI_FUNCTION = 0x00000011 - DIF_PROPERTYCHANGE DI_FUNCTION = 0x00000012 - DIF_ENABLECLASS DI_FUNCTION = 0x00000013 - DIF_DETECTVERIFY DI_FUNCTION = 0x00000014 - DIF_INSTALLDEVICEFILES DI_FUNCTION = 0x00000015 - DIF_UNREMOVE DI_FUNCTION = 0x00000016 - DIF_SELECTBESTCOMPATDRV DI_FUNCTION = 0x00000017 - DIF_ALLOW_INSTALL DI_FUNCTION = 0x00000018 - DIF_REGISTERDEVICE DI_FUNCTION = 0x00000019 - DIF_NEWDEVICEWIZARD_PRESELECT DI_FUNCTION = 0x0000001A - DIF_NEWDEVICEWIZARD_SELECT DI_FUNCTION = 0x0000001B - DIF_NEWDEVICEWIZARD_PREANALYZE DI_FUNCTION = 0x0000001C - DIF_NEWDEVICEWIZARD_POSTANALYZE DI_FUNCTION = 0x0000001D - DIF_NEWDEVICEWIZARD_FINISHINSTALL DI_FUNCTION = 0x0000001E - DIF_INSTALLINTERFACES DI_FUNCTION = 0x00000020 - DIF_DETECTCANCEL DI_FUNCTION = 0x00000021 - DIF_REGISTER_COINSTALLERS DI_FUNCTION = 0x00000022 - DIF_ADDPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000023 - DIF_ADDPROPERTYPAGE_BASIC DI_FUNCTION = 0x00000024 - DIF_TROUBLESHOOTER DI_FUNCTION = 0x00000026 - DIF_POWERMESSAGEWAKE DI_FUNCTION = 0x00000027 - DIF_ADDREMOTEPROPERTYPAGE_ADVANCED DI_FUNCTION = 0x00000028 - DIF_UPDATEDRIVER_UI DI_FUNCTION = 0x00000029 - DIF_FINISHINSTALL_ACTION DI_FUNCTION = 0x0000002A -) - -// DevInstallParams is device installation parameters structure (associated with a particular device information element, or globally with a device information set) -type DevInstallParams struct { - size uint32 - Flags DI_FLAGS - FlagsEx DI_FLAGSEX - hwndParent uintptr - InstallMsgHandler uintptr - InstallMsgHandlerContext uintptr - FileQueue HSPFILEQ - _ uintptr - _ uint32 - driverPath [windows.MAX_PATH]uint16 -} - -func (params *DevInstallParams) DriverPath() string { - return windows.UTF16ToString(params.driverPath[:]) -} - -func (params *DevInstallParams) SetDriverPath(driverPath string) error { - str, err := windows.UTF16FromString(driverPath) - if err != nil { - return err - } - copy(params.driverPath[:], str) - return nil -} - -// DI_FLAGS is SP_DEVINSTALL_PARAMS.Flags values -type DI_FLAGS uint32 - -const ( - // Flags for choosing a device - DI_SHOWOEM DI_FLAGS = 0x00000001 // support Other... button - DI_SHOWCOMPAT DI_FLAGS = 0x00000002 // show compatibility list - DI_SHOWCLASS DI_FLAGS = 0x00000004 // show class list - DI_SHOWALL DI_FLAGS = 0x00000007 // both class & compat list shown - DI_NOVCP DI_FLAGS = 0x00000008 // don't create a new copy queue--use caller-supplied FileQueue - DI_DIDCOMPAT DI_FLAGS = 0x00000010 // Searched for compatible devices - DI_DIDCLASS DI_FLAGS = 0x00000020 // Searched for class devices - DI_AUTOASSIGNRES DI_FLAGS = 0x00000040 // No UI for resources if possible - - // Flags returned by DiInstallDevice to indicate need to reboot/restart - DI_NEEDRESTART DI_FLAGS = 0x00000080 // Reboot required to take effect - DI_NEEDREBOOT DI_FLAGS = 0x00000100 // "" - - // Flags for device installation - DI_NOBROWSE DI_FLAGS = 0x00000200 // no Browse... in InsertDisk - - // Flags set by DiBuildDriverInfoList - DI_MULTMFGS DI_FLAGS = 0x00000400 // Set if multiple manufacturers in class driver list - - // Flag indicates that device is disabled - DI_DISABLED DI_FLAGS = 0x00000800 // Set if device disabled - - // Flags for Device/Class Properties - DI_GENERALPAGE_ADDED DI_FLAGS = 0x00001000 - DI_RESOURCEPAGE_ADDED DI_FLAGS = 0x00002000 - - // Flag to indicate the setting properties for this Device (or class) caused a change so the Dev Mgr UI probably needs to be updated. - DI_PROPERTIES_CHANGE DI_FLAGS = 0x00004000 - - // Flag to indicate that the sorting from the INF file should be used. - DI_INF_IS_SORTED DI_FLAGS = 0x00008000 - - // Flag to indicate that only the the INF specified by SP_DEVINSTALL_PARAMS.DriverPath should be searched. - DI_ENUMSINGLEINF DI_FLAGS = 0x00010000 - - // Flag that prevents ConfigMgr from removing/re-enumerating devices during device - // registration, installation, and deletion. - DI_DONOTCALLCONFIGMG DI_FLAGS = 0x00020000 - - // The following flag can be used to install a device disabled - DI_INSTALLDISABLED DI_FLAGS = 0x00040000 - - // Flag that causes SetupDiBuildDriverInfoList to build a device's compatible driver - // list from its existing class driver list, instead of the normal INF search. - DI_COMPAT_FROM_CLASS DI_FLAGS = 0x00080000 - - // This flag is set if the Class Install params should be used. - DI_CLASSINSTALLPARAMS DI_FLAGS = 0x00100000 - - // This flag is set if the caller of DiCallClassInstaller does NOT want the internal default action performed if the Class installer returns ERROR_DI_DO_DEFAULT. - DI_NODI_DEFAULTACTION DI_FLAGS = 0x00200000 - - // Flags for device installation - DI_QUIETINSTALL DI_FLAGS = 0x00800000 // don't confuse the user with questions or excess info - DI_NOFILECOPY DI_FLAGS = 0x01000000 // No file Copy necessary - DI_FORCECOPY DI_FLAGS = 0x02000000 // Force files to be copied from install path - DI_DRIVERPAGE_ADDED DI_FLAGS = 0x04000000 // Prop provider added Driver page. - DI_USECI_SELECTSTRINGS DI_FLAGS = 0x08000000 // Use Class Installer Provided strings in the Select Device Dlg - DI_OVERRIDE_INFFLAGS DI_FLAGS = 0x10000000 // Override INF flags - DI_PROPS_NOCHANGEUSAGE DI_FLAGS = 0x20000000 // No Enable/Disable in General Props - - DI_NOSELECTICONS DI_FLAGS = 0x40000000 // No small icons in select device dialogs - - DI_NOWRITE_IDS DI_FLAGS = 0x80000000 // Don't write HW & Compat IDs on install -) - -// DI_FLAGSEX is SP_DEVINSTALL_PARAMS.FlagsEx values -type DI_FLAGSEX uint32 - -const ( - DI_FLAGSEX_CI_FAILED DI_FLAGSEX = 0x00000004 // Failed to Load/Call class installer - DI_FLAGSEX_FINISHINSTALL_ACTION DI_FLAGSEX = 0x00000008 // Class/co-installer wants to get a DIF_FINISH_INSTALL action in client context. - DI_FLAGSEX_DIDINFOLIST DI_FLAGSEX = 0x00000010 // Did the Class Info List - DI_FLAGSEX_DIDCOMPATINFO DI_FLAGSEX = 0x00000020 // Did the Compat Info List - DI_FLAGSEX_FILTERCLASSES DI_FLAGSEX = 0x00000040 - DI_FLAGSEX_SETFAILEDINSTALL DI_FLAGSEX = 0x00000080 - DI_FLAGSEX_DEVICECHANGE DI_FLAGSEX = 0x00000100 - DI_FLAGSEX_ALWAYSWRITEIDS DI_FLAGSEX = 0x00000200 - DI_FLAGSEX_PROPCHANGE_PENDING DI_FLAGSEX = 0x00000400 // One or more device property sheets have had changes made to them, and need to have a DIF_PROPERTYCHANGE occur. - DI_FLAGSEX_ALLOWEXCLUDEDDRVS DI_FLAGSEX = 0x00000800 - DI_FLAGSEX_NOUIONQUERYREMOVE DI_FLAGSEX = 0x00001000 - DI_FLAGSEX_USECLASSFORCOMPAT DI_FLAGSEX = 0x00002000 // Use the device's class when building compat drv list. (Ignored if DI_COMPAT_FROM_CLASS flag is specified.) - DI_FLAGSEX_NO_DRVREG_MODIFY DI_FLAGSEX = 0x00008000 // Don't run AddReg and DelReg for device's software (driver) key. - DI_FLAGSEX_IN_SYSTEM_SETUP DI_FLAGSEX = 0x00010000 // Installation is occurring during initial system setup. - DI_FLAGSEX_INET_DRIVER DI_FLAGSEX = 0x00020000 // Driver came from Windows Update - DI_FLAGSEX_APPENDDRIVERLIST DI_FLAGSEX = 0x00040000 // Cause SetupDiBuildDriverInfoList to append a new driver list to an existing list. - DI_FLAGSEX_PREINSTALLBACKUP DI_FLAGSEX = 0x00080000 // not used - DI_FLAGSEX_BACKUPONREPLACE DI_FLAGSEX = 0x00100000 // not used - DI_FLAGSEX_DRIVERLIST_FROM_URL DI_FLAGSEX = 0x00200000 // build driver list from INF(s) retrieved from URL specified in SP_DEVINSTALL_PARAMS.DriverPath (empty string means Windows Update website) - DI_FLAGSEX_EXCLUDE_OLD_INET_DRIVERS DI_FLAGSEX = 0x00800000 // Don't include old Internet drivers when building a driver list. Ignored on Windows Vista and later. - DI_FLAGSEX_POWERPAGE_ADDED DI_FLAGSEX = 0x01000000 // class installer added their own power page - DI_FLAGSEX_FILTERSIMILARDRIVERS DI_FLAGSEX = 0x02000000 // only include similar drivers in class list - DI_FLAGSEX_INSTALLEDDRIVER DI_FLAGSEX = 0x04000000 // only add the installed driver to the class or compat driver list. Used in calls to SetupDiBuildDriverInfoList - DI_FLAGSEX_NO_CLASSLIST_NODE_MERGE DI_FLAGSEX = 0x08000000 // Don't remove identical driver nodes from the class list - DI_FLAGSEX_ALTPLATFORM_DRVSEARCH DI_FLAGSEX = 0x10000000 // Build driver list based on alternate platform information specified in associated file queue - DI_FLAGSEX_RESTART_DEVICE_ONLY DI_FLAGSEX = 0x20000000 // only restart the device drivers are being installed on as opposed to restarting all devices using those drivers. - DI_FLAGSEX_RECURSIVESEARCH DI_FLAGSEX = 0x40000000 // Tell SetupDiBuildDriverInfoList to do a recursive search - DI_FLAGSEX_SEARCH_PUBLISHED_INFS DI_FLAGSEX = 0x80000000 // Tell SetupDiBuildDriverInfoList to do a "published INF" search -) - -// ClassInstallHeader is the first member of any class install parameters structure. It contains the device installation request code that defines the format of the rest of the install parameters structure. -type ClassInstallHeader struct { - size uint32 - InstallFunction DI_FUNCTION -} - -func MakeClassInstallHeader(installFunction DI_FUNCTION) *ClassInstallHeader { - hdr := &ClassInstallHeader{InstallFunction: installFunction} - hdr.size = uint32(unsafe.Sizeof(*hdr)) - return hdr -} - -// DICS_STATE specifies values indicating a change in a device's state -type DICS_STATE uint32 - -const ( - DICS_ENABLE DICS_STATE = 0x00000001 // The device is being enabled. - DICS_DISABLE DICS_STATE = 0x00000002 // The device is being disabled. - DICS_PROPCHANGE DICS_STATE = 0x00000003 // The properties of the device have changed. - DICS_START DICS_STATE = 0x00000004 // The device is being started (if the request is for the currently active hardware profile). - DICS_STOP DICS_STATE = 0x00000005 // The device is being stopped. The driver stack will be unloaded and the CSCONFIGFLAG_DO_NOT_START flag will be set for the device. -) - -// DICS_FLAG specifies the scope of a device property change -type DICS_FLAG uint32 - -const ( - DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles - DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only - DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow (obsolete) -) - -// PropChangeParams is a structure corresponding to a DIF_PROPERTYCHANGE install function. -type PropChangeParams struct { - ClassInstallHeader ClassInstallHeader - StateChange DICS_STATE - Scope DICS_FLAG - HwProfile uint32 -} - -// DI_REMOVEDEVICE specifies the scope of the device removal -type DI_REMOVEDEVICE uint32 - -const ( - DI_REMOVEDEVICE_GLOBAL DI_REMOVEDEVICE = 0x00000001 // Make this change in all hardware profiles. Remove information about the device from the registry. - DI_REMOVEDEVICE_CONFIGSPECIFIC DI_REMOVEDEVICE = 0x00000002 // Make this change to only the hardware profile specified by HwProfile. this flag only applies to root-enumerated devices. When Windows removes the device from the last hardware profile in which it was configured, Windows performs a global removal. -) - -// RemoveDeviceParams is a structure corresponding to a DIF_REMOVE install function. -type RemoveDeviceParams struct { - ClassInstallHeader ClassInstallHeader - Scope DI_REMOVEDEVICE - HwProfile uint32 -} - -// DrvInfoData is driver information structure (member of a driver info list that may be associated with a particular device instance, or (globally) with a device information set) -type DrvInfoData struct { - size uint32 - DriverType uint32 - _ uintptr - description [LINE_LEN]uint16 - mfgName [LINE_LEN]uint16 - providerName [LINE_LEN]uint16 - DriverDate windows.Filetime - DriverVersion uint64 -} - -func (data *DrvInfoData) Description() string { - return windows.UTF16ToString(data.description[:]) -} - -func (data *DrvInfoData) SetDescription(description string) error { - str, err := windows.UTF16FromString(description) - if err != nil { - return err - } - copy(data.description[:], str) - return nil -} - -func (data *DrvInfoData) MfgName() string { - return windows.UTF16ToString(data.mfgName[:]) -} - -func (data *DrvInfoData) SetMfgName(mfgName string) error { - str, err := windows.UTF16FromString(mfgName) - if err != nil { - return err - } - copy(data.mfgName[:], str) - return nil -} - -func (data *DrvInfoData) ProviderName() string { - return windows.UTF16ToString(data.providerName[:]) -} - -func (data *DrvInfoData) SetProviderName(providerName string) error { - str, err := windows.UTF16FromString(providerName) - if err != nil { - return err - } - copy(data.providerName[:], str) - return nil -} - -// IsNewer method returns true if DrvInfoData date and version is newer than supplied parameters. -func (data *DrvInfoData) IsNewer(driverDate windows.Filetime, driverVersion uint64) bool { - if data.DriverDate.HighDateTime > driverDate.HighDateTime { - return true - } - if data.DriverDate.HighDateTime < driverDate.HighDateTime { - return false - } - - if data.DriverDate.LowDateTime > driverDate.LowDateTime { - return true - } - if data.DriverDate.LowDateTime < driverDate.LowDateTime { - return false - } - - if data.DriverVersion > driverVersion { - return true - } - if data.DriverVersion < driverVersion { - return false - } - - return false -} - -// DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure) -type DrvInfoDetailData struct { - size uint32 // Warning: unsafe.Sizeof(DrvInfoDetailData) > sizeof(SP_DRVINFO_DETAIL_DATA) when GOARCH == 386 => use sizeofDrvInfoDetailData const. - InfDate windows.Filetime - compatIDsOffset uint32 - compatIDsLength uint32 - _ uintptr - sectionName [LINE_LEN]uint16 - infFileName [windows.MAX_PATH]uint16 - drvDescription [LINE_LEN]uint16 - hardwareID [ANYSIZE_ARRAY]uint16 -} - -func (data *DrvInfoDetailData) SectionName() string { - return windows.UTF16ToString(data.sectionName[:]) -} - -func (data *DrvInfoDetailData) InfFileName() string { - return windows.UTF16ToString(data.infFileName[:]) -} - -func (data *DrvInfoDetailData) DrvDescription() string { - return windows.UTF16ToString(data.drvDescription[:]) -} - -func (data *DrvInfoDetailData) HardwareID() string { - if data.compatIDsOffset > 1 { - bufW := data.getBuf() - return windows.UTF16ToString(bufW[:wcslen(bufW)]) - } - - return "" -} - -func (data *DrvInfoDetailData) CompatIDs() []string { - a := make([]string, 0) - - if data.compatIDsLength > 0 { - bufW := data.getBuf() - bufW = bufW[data.compatIDsOffset : data.compatIDsOffset+data.compatIDsLength] - for i := 0; i < len(bufW); { - j := i + wcslen(bufW[i:]) - if i < j { - a = append(a, windows.UTF16ToString(bufW[i:j])) - } - i = j + 1 - } - } - - return a -} - -func (data *DrvInfoDetailData) getBuf() []uint16 { - len := (data.size - uint32(unsafe.Offsetof(data.hardwareID))) / 2 - sl := struct { - addr *uint16 - len int - cap int - }{&data.hardwareID[0], int(len), int(len)} - return *(*[]uint16)(unsafe.Pointer(&sl)) -} - -// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list. -func (data *DrvInfoDetailData) IsCompatible(hwid string) bool { - hwidLC := strings.ToLower(hwid) - if strings.ToLower(data.HardwareID()) == hwidLC { - return true - } - a := data.CompatIDs() - for i := range a { - if strings.ToLower(a[i]) == hwidLC { - return true - } - } - - return false -} - -// DICD flags control SetupDiCreateDeviceInfo -type DICD uint32 - -const ( - DICD_GENERATE_ID DICD = 0x00000001 - DICD_INHERIT_CLASSDRVS DICD = 0x00000002 -) - -// -// SPDIT flags to distinguish between class drivers and -// device drivers. -// (Passed in 'DriverType' parameter of driver information list APIs) -// -type SPDIT uint32 - -const ( - SPDIT_NODRIVER SPDIT = 0x00000000 - SPDIT_CLASSDRIVER SPDIT = 0x00000001 - SPDIT_COMPATDRIVER SPDIT = 0x00000002 -) - -// DIGCF flags control what is included in the device information set built by SetupDiGetClassDevs -type DIGCF uint32 - -const ( - DIGCF_DEFAULT DIGCF = 0x00000001 // only valid with DIGCF_DEVICEINTERFACE - DIGCF_PRESENT DIGCF = 0x00000002 - DIGCF_ALLCLASSES DIGCF = 0x00000004 - DIGCF_PROFILE DIGCF = 0x00000008 - DIGCF_DEVICEINTERFACE DIGCF = 0x00000010 -) - -// DIREG specifies values for SetupDiCreateDevRegKey, SetupDiOpenDevRegKey, and SetupDiDeleteDevRegKey. -type DIREG uint32 - -const ( - DIREG_DEV DIREG = 0x00000001 // Open/Create/Delete device key - DIREG_DRV DIREG = 0x00000002 // Open/Create/Delete driver key - DIREG_BOTH DIREG = 0x00000004 // Delete both driver and Device key -) - -// -// SPDRP specifies device registry property codes -// (Codes marked as read-only (R) may only be used for -// SetupDiGetDeviceRegistryProperty) -// -// These values should cover the same set of registry properties -// as defined by the CM_DRP codes in cfgmgr32.h. -// -// Note that SPDRP codes are zero based while CM_DRP codes are one based! -// -type SPDRP uint32 - -const ( - SPDRP_DEVICEDESC SPDRP = 0x00000000 // DeviceDesc (R/W) - SPDRP_HARDWAREID SPDRP = 0x00000001 // HardwareID (R/W) - SPDRP_COMPATIBLEIDS SPDRP = 0x00000002 // CompatibleIDs (R/W) - SPDRP_SERVICE SPDRP = 0x00000004 // Service (R/W) - SPDRP_CLASS SPDRP = 0x00000007 // Class (R--tied to ClassGUID) - SPDRP_CLASSGUID SPDRP = 0x00000008 // ClassGUID (R/W) - SPDRP_DRIVER SPDRP = 0x00000009 // Driver (R/W) - SPDRP_CONFIGFLAGS SPDRP = 0x0000000A // ConfigFlags (R/W) - SPDRP_MFG SPDRP = 0x0000000B // Mfg (R/W) - SPDRP_FRIENDLYNAME SPDRP = 0x0000000C // FriendlyName (R/W) - SPDRP_LOCATION_INFORMATION SPDRP = 0x0000000D // LocationInformation (R/W) - SPDRP_PHYSICAL_DEVICE_OBJECT_NAME SPDRP = 0x0000000E // PhysicalDeviceObjectName (R) - SPDRP_CAPABILITIES SPDRP = 0x0000000F // Capabilities (R) - SPDRP_UI_NUMBER SPDRP = 0x00000010 // UiNumber (R) - SPDRP_UPPERFILTERS SPDRP = 0x00000011 // UpperFilters (R/W) - SPDRP_LOWERFILTERS SPDRP = 0x00000012 // LowerFilters (R/W) - SPDRP_BUSTYPEGUID SPDRP = 0x00000013 // BusTypeGUID (R) - SPDRP_LEGACYBUSTYPE SPDRP = 0x00000014 // LegacyBusType (R) - SPDRP_BUSNUMBER SPDRP = 0x00000015 // BusNumber (R) - SPDRP_ENUMERATOR_NAME SPDRP = 0x00000016 // Enumerator Name (R) - SPDRP_SECURITY SPDRP = 0x00000017 // Security (R/W, binary form) - SPDRP_SECURITY_SDS SPDRP = 0x00000018 // Security (W, SDS form) - SPDRP_DEVTYPE SPDRP = 0x00000019 // Device Type (R/W) - SPDRP_EXCLUSIVE SPDRP = 0x0000001A // Device is exclusive-access (R/W) - SPDRP_CHARACTERISTICS SPDRP = 0x0000001B // Device Characteristics (R/W) - SPDRP_ADDRESS SPDRP = 0x0000001C // Device Address (R) - SPDRP_UI_NUMBER_DESC_FORMAT SPDRP = 0x0000001D // UiNumberDescFormat (R/W) - SPDRP_DEVICE_POWER_DATA SPDRP = 0x0000001E // Device Power Data (R) - SPDRP_REMOVAL_POLICY SPDRP = 0x0000001F // Removal Policy (R) - SPDRP_REMOVAL_POLICY_HW_DEFAULT SPDRP = 0x00000020 // Hardware Removal Policy (R) - SPDRP_REMOVAL_POLICY_OVERRIDE SPDRP = 0x00000021 // Removal Policy Override (RW) - SPDRP_INSTALL_STATE SPDRP = 0x00000022 // Device Install State (R) - SPDRP_LOCATION_PATHS SPDRP = 0x00000023 // Device Location Paths (R) - SPDRP_BASE_CONTAINERID SPDRP = 0x00000024 // Base ContainerID (R) - - SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals -) - -const ( - CR_SUCCESS = 0x0 - CR_BUFFER_SMALL = 0x1a -) - -const ( - CM_GET_DEVICE_INTERFACE_LIST_PRESENT = 0 // only currently 'live' device interfaces - CM_GET_DEVICE_INTERFACE_LIST_ALL_DEVICES = 1 // all registered device interfaces, live or not -) diff --git a/tun/wintun/setupapi/types_windows_386.go b/tun/wintun/setupapi/types_windows_386.go deleted file mode 100644 index 132f921..0000000 --- a/tun/wintun/setupapi/types_windows_386.go +++ /dev/null @@ -1,11 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -const ( - sizeofDevInfoListDetailData uint32 = 550 - sizeofDrvInfoDetailData uint32 = 1570 -) diff --git a/tun/wintun/setupapi/types_windows_amd64.go b/tun/wintun/setupapi/types_windows_amd64.go deleted file mode 100644 index d4dd65c..0000000 --- a/tun/wintun/setupapi/types_windows_amd64.go +++ /dev/null @@ -1,11 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -const ( - sizeofDevInfoListDetailData uint32 = 560 - sizeofDrvInfoDetailData uint32 = 1584 -) diff --git a/tun/wintun/setupapi/zsetupapi_windows.go b/tun/wintun/setupapi/zsetupapi_windows.go deleted file mode 100644 index 375862d..0000000 --- a/tun/wintun/setupapi/zsetupapi_windows.go +++ /dev/null @@ -1,398 +0,0 @@ -// Code generated by 'go generate'; DO NOT EDIT. - -package setupapi - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -var _ unsafe.Pointer - -// Do the interface allocations only once for common -// Errno values. -const ( - errnoERROR_IO_PENDING = 997 -) - -var ( - errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case errnoERROR_IO_PENDING: - return errERROR_IO_PENDING - } - // TODO: add more here, after collecting data on the common - // error values see on Windows. (perhaps when running - // all.bat?) - return e -} - -var ( - modsetupapi = windows.NewLazySystemDLL("setupapi.dll") - modCfgMgr32 = windows.NewLazySystemDLL("CfgMgr32.dll") - - procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW") - procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW") - procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW") - procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo") - procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList") - procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList") - procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch") - procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW") - procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW") - procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW") - procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW") - procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList") - procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW") - procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller") - procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey") - procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW") - procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW") - procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW") - procSetupDiGetDeviceInstanceIdW = modsetupapi.NewProc("SetupDiGetDeviceInstanceIdW") - procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW") - procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW") - procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW") - procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW") - procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW") - procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice") - procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice") - procCM_Get_Device_Interface_List_SizeW = modCfgMgr32.NewProc("CM_Get_Device_Interface_List_SizeW") - procCM_Get_Device_Interface_ListW = modCfgMgr32.NewProc("CM_Get_Device_Interface_ListW") -) - -func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) { - r0, _, e1 := syscall.Syscall6(procSetupDiCreateDeviceInfoListExW.Addr(), 4, uintptr(unsafe.Pointer(classGUID)), uintptr(hwndParent), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0) - handle = DevInfo(r0) - if handle == DevInfo(windows.InvalidHandle) { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo, deviceInfoSetDetailData *DevInfoListDetailData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInfoListDetailW.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoSetDetailData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiCreateDeviceInfo(deviceInfoSet DevInfo, DeviceName *uint16, classGUID *windows.GUID, DeviceDescription *uint16, hwndParent uintptr, CreationFlags DICD, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall9(procSetupDiCreateDeviceInfoW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(DeviceName)), uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(DeviceDescription)), uintptr(hwndParent), uintptr(CreationFlags), uintptr(unsafe.Pointer(deviceInfoData)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiEnumDeviceInfo(deviceInfoSet DevInfo, memberIndex uint32, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiEnumDeviceInfo.Addr(), 3, uintptr(deviceInfoSet), uintptr(memberIndex), uintptr(unsafe.Pointer(deviceInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiDestroyDeviceInfoList(deviceInfoSet DevInfo) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiDestroyDeviceInfoList.Addr(), 1, uintptr(deviceInfoSet), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiBuildDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiBuildDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiCancelDriverInfoSearch(deviceInfoSet DevInfo) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiCancelDriverInfoSearch.Addr(), 1, uintptr(deviceInfoSet), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiEnumDriverInfo(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT, memberIndex uint32, driverInfoData *DrvInfoData) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiEnumDriverInfoW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType), uintptr(memberIndex), uintptr(unsafe.Pointer(driverInfoData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDriverW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData, driverInfoDetailData *DrvInfoDetailData, driverInfoDetailDataSize uint32, requiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiGetDriverInfoDetailW.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(driverInfoData)), uintptr(unsafe.Pointer(driverInfoDetailData)), uintptr(driverInfoDetailDataSize), uintptr(unsafe.Pointer(requiredSize))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiDestroyDriverInfoList(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverType SPDIT) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiDestroyDriverInfoList.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(driverType)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetClassDevsEx(classGUID *windows.GUID, Enumerator *uint16, hwndParent uintptr, Flags DIGCF, deviceInfoSet DevInfo, machineName *uint16, reserved uintptr) (handle DevInfo, err error) { - r0, _, e1 := syscall.Syscall9(procSetupDiGetClassDevsExW.Addr(), 7, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(Enumerator)), uintptr(hwndParent), uintptr(Flags), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(machineName)), uintptr(reserved), 0, 0) - handle = DevInfo(r0) - if handle == DevInfo(windows.InvalidHandle) { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiCallClassInstaller(installFunction DI_FUNCTION, deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiCallClassInstaller.Addr(), 3, uintptr(installFunction), uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiOpenDevRegKey(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, Scope DICS_FLAG, HwProfile uint32, KeyType DIREG, samDesired uint32) (key windows.Handle, err error) { - r0, _, e1 := syscall.Syscall6(procSetupDiOpenDevRegKey.Addr(), 6, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(Scope), uintptr(HwProfile), uintptr(KeyType), uintptr(samDesired)) - key = windows.Handle(r0) - if key == windows.InvalidHandle { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyRegDataType *uint32, propertyBuffer *byte, propertyBufferSize uint32, requiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall9(procSetupDiGetDeviceRegistryPropertyW.Addr(), 7, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyRegDataType)), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), uintptr(unsafe.Pointer(requiredSize)), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiSetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP, propertyBuffer *byte, propertyBufferSize uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiSetDeviceRegistryPropertyW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(property), uintptr(unsafe.Pointer(propertyBuffer)), uintptr(propertyBufferSize), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiGetDeviceInstanceIdW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(instanceId)), uintptr(instanceIdSize), uintptr(unsafe.Pointer(instanceIdRequiredSize)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiSetDeviceInstallParamsW.Addr(), 3, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(deviceInstallParams))) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiSetClassInstallParamsW.Addr(), 4, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), 0, 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiClassNameFromGuidEx(classGUID *windows.GUID, className *uint16, classNameSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiClassNameFromGuidExW.Addr(), 6, uintptr(unsafe.Pointer(classGUID)), uintptr(unsafe.Pointer(className)), uintptr(classNameSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiClassGuidsFromNameEx(className *uint16, classGuidList *windows.GUID, classGuidListSize uint32, requiredSize *uint32, machineName *uint16, reserved uintptr) (err error) { - r1, _, e1 := syscall.Syscall6(procSetupDiClassGuidsFromNameExW.Addr(), 6, uintptr(unsafe.Pointer(className)), uintptr(unsafe.Pointer(classGuidList)), uintptr(classGuidListSize), uintptr(unsafe.Pointer(requiredSize)), uintptr(unsafe.Pointer(machineName)), uintptr(reserved)) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiGetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) { - r1, _, e1 := syscall.Syscall(procSetupDiSetSelectedDevice.Addr(), 2, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), 0) - if r1 == 0 { - if e1 != 0 { - err = errnoErr(e1) - } else { - err = syscall.EINVAL - } - } - return -} - -func cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) { - r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_List_SizeW.Addr(), 4, uintptr(unsafe.Pointer(len)), uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(flags), 0, 0) - ret = uint32(r0) - return -} - -func cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) { - r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_ListW.Addr(), 5, uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(unsafe.Pointer(buffer)), uintptr(bufferLen), uintptr(flags), 0) - ret = uint32(r0) - return -} diff --git a/tun/wintun/setupapi/zsetupapi_windows_test.go b/tun/wintun/setupapi/zsetupapi_windows_test.go deleted file mode 100644 index 915b427..0000000 --- a/tun/wintun/setupapi/zsetupapi_windows_test.go +++ /dev/null @@ -1,20 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package setupapi - -import ( - "syscall" - "testing" - - "golang.org/x/sys/windows" -) - -func TestSetupDiDestroyDeviceInfoList(t *testing.T) { - err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle)) - if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_HANDLE { - t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE") - } -} diff --git a/tun/wintun/wintun_windows.go b/tun/wintun/wintun_windows.go deleted file mode 100644 index 4c12d97..0000000 --- a/tun/wintun/wintun_windows.go +++ /dev/null @@ -1,803 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wintun - -import ( - "errors" - "fmt" - "strings" - "time" - "unsafe" - - "golang.org/x/sys/windows" - "golang.org/x/sys/windows/registry" - - "golang.zx2c4.com/wireguard/tun/wintun/iphlpapi" - "golang.zx2c4.com/wireguard/tun/wintun/nci" - registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry" - "golang.zx2c4.com/wireguard/tun/wintun/setupapi" -) - -type Pool string - -type Interface struct { - cfgInstanceID windows.GUID - devInstanceID string - luidIndex uint32 - ifType uint32 - pool Pool -} - -var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} -var deviceInterfaceNetGUID = windows.GUID{Data1: 0xcac88484, Data2: 0x7515, Data3: 0x4c03, Data4: [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}} - -const ( - hardwareID = "Wintun" - waitForRegistryTimeout = time.Second * 10 -) - -// makeWintun creates a Wintun interface handle and populates it from the device's registry key. -func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) { - // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. - key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE) - if err != nil { - return nil, fmt.Errorf("Device-specific registry key open failed: %v", err) - } - defer key.Close() - - // Read the NetCfgInstanceId value. - valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId") - if err != nil { - return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err) - } - - // Convert to GUID. - ifid, err := windows.GUIDFromString(valueStr) - if err != nil { - return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: %q)", valueStr) - } - - // Read the NetLuidIndex value. - luidIdx, _, err := key.GetIntegerValue("NetLuidIndex") - if err != nil { - return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err) - } - - // Read the NetLuidIndex value. - ifType, _, err := key.GetIntegerValue("*IfType") - if err != nil { - return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err) - } - - instanceID, err := devInfo.DeviceInstanceID(devInfoData) - if err != nil { - return nil, fmt.Errorf("DeviceInstanceID failed: %v", err) - } - - return &Interface{ - cfgInstanceID: ifid, - devInstanceID: instanceID, - luidIndex: uint32(luidIdx), - ifType: uint32(ifType), - pool: pool, - }, nil -} - -func removeNumberedSuffix(ifname string) string { - removed := strings.TrimRight(ifname, "0123456789") - if removed != ifname && len(removed) > 1 && removed[len(removed)-1] == ' ' { - return removed[:len(removed)-1] - } - return ifname -} - -// GetInterface finds a Wintun interface by its name. This function returns -// the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If -// the interface is found but not a Wintun-class or a member of the pool, -// this function returns windows.ERROR_ALREADY_EXISTS. -func (pool Pool) GetInterface(ifname string) (*Interface, error) { - mutex, err := pool.takeNameMutex() - if err != nil { - return nil, err - } - defer func() { - windows.ReleaseMutex(mutex) - windows.CloseHandle(mutex) - }() - - // Create a list of network devices. - devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") - if err != nil { - return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err) - } - defer devInfo.Close() - - // Windows requires each interface to have a different name. When - // enforcing this, Windows treats interface names case-insensitive. If an - // interface "FooBar" exists and this function reports there is no - // interface "foobar", an attempt to create a new interface and name it - // "foobar" would cause conflict with "FooBar". - ifname = strings.ToLower(ifname) - - for index := 0; ; index++ { - devInfoData, err := devInfo.EnumDeviceInfo(index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices. - property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID) - if err != nil { - continue - } - if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID { - continue - } - - wintun, err := makeWintun(devInfo, devInfoData, pool) - if err != nil { - continue - } - - // TODO: is there a better way than comparing ifnames? - ifname2, err := wintun.Name() - if err != nil { - continue - } - ifname2 = strings.ToLower(ifname2) - ifname3 := removeNumberedSuffix(ifname2) - - if ifname == ifname2 || ifname == ifname3 { - err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - if err != nil { - return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) - } - defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - - for index := 0; ; index++ { - driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Get driver info details. - driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData) - if err != nil { - continue - } - - if driverDetailData.IsCompatible(hardwareID) { - isMember, err := pool.isMember(devInfo, devInfoData) - if err != nil { - return nil, err - } - if !isMember { - return nil, windows.ERROR_ALREADY_EXISTS - } - - return wintun, nil - } - } - - // This interface is not using Wintun driver. - return nil, windows.ERROR_ALREADY_EXISTS - } - } - - return nil, windows.ERROR_OBJECT_NOT_FOUND -} - -// CreateInterface creates a Wintun interface. ifname is the requested name of -// the interface, while requestedGUID is the GUID of the created network -// interface, which then influences NLA generation deterministically. If it is -// set to nil, the GUID is chosen by the system at random, and hence a new NLA -// entry is created for each new interface. It is called "requested" GUID -// because the API it uses is completely undocumented, and so there could be minor -// interesting complications with its usage. This function returns the network -// interface ID and a flag if reboot is required. -func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wintun *Interface, rebootRequired bool, err error) { - mutex, err := pool.takeNameMutex() - if err != nil { - return - } - defer func() { - windows.ReleaseMutex(mutex) - windows.CloseHandle(mutex) - }() - - // Create an empty device info set for network adapter device class. - devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") - if err != nil { - err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err) - return - } - defer devInfo.Close() - - // Get the device class name from GUID. - className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") - if err != nil { - err = fmt.Errorf("SetupDiClassNameFromGuidEx(%v) failed: %v", deviceClassNetGUID, err) - return - } - - // Create a new device info element and add it to the device info set. - deviceTypeName := pool.deviceTypeName() - devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID) - if err != nil { - err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err) - return - } - - err = setQuietInstall(devInfo, devInfoData) - if err != nil { - err = fmt.Errorf("Setting quiet installation failed: %v", err) - return - } - - // Set a device information element as the selected member of a device information set. - err = devInfo.SetSelectedDevice(devInfoData) - if err != nil { - err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err) - return - } - - // Set Plug&Play device hardware ID property. - err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID) - if err != nil { - err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err) - return - } - - err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms - if err != nil { - err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) - return - } - defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - - driverDate := windows.Filetime{} - driverVersion := uint64(0) - for index := 0; ; index++ { // TODO: This loop takes ~600ms - driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match. - if driverData.IsNewer(driverDate, driverVersion) { - driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData) - if err != nil { - continue - } - - if driverDetailData.IsCompatible(hardwareID) { - err := devInfo.SetSelectedDriver(devInfoData, driverData) - if err != nil { - continue - } - - driverDate = driverData.DriverDate - driverVersion = driverData.DriverVersion - } - } - } - - if driverVersion == 0 { - err = fmt.Errorf("No driver for device %q installed", hardwareID) - return - } - - defer func() { - if err != nil { - // The interface failed to install, or the interface ID was unobtainable. Clean-up. - removeDeviceParams := setupapi.RemoveDeviceParams{ - ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), - Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, - } - - // Set class installer parameters for DIF_REMOVE. - if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil { - // Call appropriate class installer. - if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil { - rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData) - } - } - - wintun = nil - } - }() - - // Call appropriate class installer. - err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData) - if err != nil { - err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err) - return - } - - // Register device co-installers if any. (Ignore errors) - devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData) - - var netDevRegKey registry.Key - const pollTimeout = time.Millisecond * 50 - for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ { - if i != 0 { - time.Sleep(pollTimeout) - } - netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY) - if err == nil { - break - } - } - if err != nil { - err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err) - return - } - defer netDevRegKey.Close() - if requestedGUID != nil { - err = netDevRegKey.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String()) - if err != nil { - err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err) - return - } - } - - // Install interfaces if any. (Ignore errors) - devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData) - - // Install the device. - err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData) - if err != nil { - err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err) - return - } - rebootRequired = checkReboot(devInfo, devInfoData) - - err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName) - if err != nil { - err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) - return - } - - // DIF_INSTALLDEVICE returns almost immediately, while the device installation - // continues in the background. It might take a while, before all registry - // keys and values are populated. - _, err = registryEx.GetStringValueWait(netDevRegKey, "NetCfgInstanceId", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err) - return - } - _, err = registryEx.GetIntegerValueWait(netDevRegKey, "NetLuidIndex", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err) - return - } - _, err = registryEx.GetIntegerValueWait(netDevRegKey, "*IfType", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err) - return - } - - // Get network interface. - wintun, err = makeWintun(devInfo, devInfoData, pool) - if err != nil { - err = fmt.Errorf("makeWintun failed: %v", err) - return - } - - // Wait for TCP/IP adapter registry key to emerge and populate. - tcpipAdapterRegKey, err := registryEx.OpenKeyWait( - registry.LOCAL_MACHINE, - wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY, - waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err) - return - } - defer tcpipAdapterRegKey.Close() - _, err = registryEx.GetStringValueWait(tcpipAdapterRegKey, "IpConfig", waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err) - return - } - - tcpipInterfaceRegKeyName, err := wintun.tcpipInterfaceRegKeyName() - if err != nil { - err = fmt.Errorf("tcpipInterfaceRegKeyName failed: %v", err) - return - } - - // Wait for TCP/IP interface registry key to emerge. - tcpipInterfaceRegKey, err := registryEx.OpenKeyWait( - registry.LOCAL_MACHINE, - tcpipInterfaceRegKeyName, registry.QUERY_VALUE|registry.SET_VALUE, - waitForRegistryTimeout) - if err != nil { - err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err) - return - } - defer tcpipInterfaceRegKey.Close() - // Disable dead gateway detection on our interface. - tcpipInterfaceRegKey.SetDWordValue("EnableDeadGWDetect", 0) - - err = wintun.SetName(ifname) - if err != nil { - err = fmt.Errorf("Unable to set name of Wintun interface: %v", err) - return - } - - return -} - -// DeleteInterface deletes a Wintun interface. This function succeeds -// if the interface was not found. It returns a bool indicating whether -// a reboot is required. -func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) { - devInfo, devInfoData, err := wintun.devInfoData() - if err == windows.ERROR_OBJECT_NOT_FOUND { - return false, nil - } - if err != nil { - return false, err - } - defer devInfo.Close() - - // Remove the device. - removeDeviceParams := setupapi.RemoveDeviceParams{ - ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), - Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, - } - - // Set class installer parameters for DIF_REMOVE. - err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) - if err != nil { - return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err) - } - - // Call appropriate class installer. - err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) - if err != nil { - return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err) - } - - return checkReboot(devInfo, devInfoData), nil -} - -// DeleteMatchingInterfaces deletes all Wintun interfaces, which match -// given criteria, and returns which ones it deleted, whether a reboot -// is required after, and which errors occurred during the process. -func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { - mutex, err := pool.takeNameMutex() - if err != nil { - errors = append(errors, err) - return - } - defer func() { - windows.ReleaseMutex(mutex) - windows.CloseHandle(mutex) - }() - - devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") - if err != nil { - return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())} - } - defer devInfo.Close() - - for i := 0; ; i++ { - devInfoData, err := devInfo.EnumDeviceInfo(i) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices. - property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID) - if err != nil { - continue - } - if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID { - continue - } - - err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - if err != nil { - continue - } - defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) - - isWintun := false - for j := 0; ; j++ { - driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData) - if err != nil { - continue - } - if driverDetailData.IsCompatible(hardwareID) { - isWintun = true - break - } - } - if !isWintun { - continue - } - - isMember, err := pool.isMember(devInfo, devInfoData) - if err != nil { - errors = append(errors, err) - continue - } - if !isMember { - continue - } - - wintun, err := makeWintun(devInfo, devInfoData, pool) - if err != nil { - errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err)) - continue - } - if !matches(wintun) { - continue - } - - err = setQuietInstall(devInfo, devInfoData) - if err != nil { - errors = append(errors, err) - continue - } - - inst := devInfoData.DevInst - removeDeviceParams := setupapi.RemoveDeviceParams{ - ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), - Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, - } - err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) - if err != nil { - errors = append(errors, err) - continue - } - err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) - if err != nil { - errors = append(errors, err) - continue - } - rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData) - deviceInstancesDeleted = append(deviceInstancesDeleted, inst) - } - return -} - -// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name. -func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) { - deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC) - if err != nil { - return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) - } - deviceDesc, _ := deviceDescVal.(string) - friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME) - if err != nil { - return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err) - } - friendlyName, _ := friendlyNameVal.(string) - deviceTypeName := pool.deviceTypeName() - return friendlyName == deviceTypeName || deviceDesc == deviceTypeName || - removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil -} - -// checkReboot checks device install parameters if a system reboot is required. -func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool { - devInstallParams, err := devInfo.DeviceInstallParams(devInfoData) - if err != nil { - return false - } - - return (devInstallParams.Flags & (setupapi.DI_NEEDREBOOT | setupapi.DI_NEEDRESTART)) != 0 -} - -// setQuietInstall sets device install parameters for a quiet installation -func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error { - devInstallParams, err := devInfo.DeviceInstallParams(devInfoData) - if err != nil { - return err - } - - devInstallParams.Flags |= setupapi.DI_QUIETINSTALL - return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams) -} - -// deviceTypeName returns pool-specific device type name. -func (pool Pool) deviceTypeName() string { - return fmt.Sprintf("%s Tunnel", pool) -} - -// Name returns the name of the Wintun interface. -func (wintun *Interface) Name() (string, error) { - return nci.ConnectionName(&wintun.cfgInstanceID) -} - -// SetName sets name of the Wintun interface. -func (wintun *Interface) SetName(ifname string) error { - const maxSuffix = 1000 - availableIfname := ifname - for i := 0; ; i++ { - err := nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname) - if err == windows.ERROR_DUP_NAME { - duplicateGuid, err2 := iphlpapi.InterfaceGUIDFromAlias(availableIfname) - if err2 == nil { - for j := 0; j < maxSuffix; j++ { - proposal := fmt.Sprintf("%s %d", ifname, j+1) - if proposal == availableIfname { - continue - } - err2 = nci.SetConnectionName(duplicateGuid, proposal) - if err2 == windows.ERROR_DUP_NAME { - continue - } - if err2 == nil { - err = nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname) - if err == nil { - break - } - } - break - } - } - } - if err == nil { - break - } - - if i > maxSuffix || err != windows.ERROR_DUP_NAME { - return fmt.Errorf("NciSetConnectionName failed: %v", err) - } - availableIfname = fmt.Sprintf("%s %d", ifname, i+1) - } - - // TODO: This should use NetSetup2 so that it doesn't get unset. - deviceRegKey, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.deviceRegKeyName(), registry.SET_VALUE) - if err != nil { - return fmt.Errorf("Device-level registry key open failed: %v", err) - } - defer deviceRegKey.Close() - err = deviceRegKey.SetStringValue("FriendlyName", wintun.pool.deviceTypeName()) - if err != nil { - return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err) - } - return nil -} - -// tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name. -func (wintun *Interface) tcpipAdapterRegKeyName() string { - return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID) -} - -// deviceRegKeyName returns the device-level registry key name. -func (wintun *Interface) deviceRegKeyName() string { - return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID) -} - -// Version returns the version of the Wintun driver and NDIS system currently loaded. -func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE) - if err != nil { - return - } - defer key.Close() - driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion") - if err != nil { - return - } - driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion") - if err != nil { - return - } - ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion") - if err != nil { - return - } - ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion") - if err != nil { - return - } - driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor) - ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor) - return -} - -// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name. -func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) { - key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE) - if err != nil { - return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err) - } - paths, _, err := key.GetStringsValue("IpConfig") - key.Close() - if err != nil { - return "", fmt.Errorf("Error reading IpConfig registry key: %v", err) - } - if len(paths) == 0 { - return "", errors.New("No TCP/IP interfaces found on adapter") - } - return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil -} - -// devInfoData returns TUN device info list handle and interface device info -// data. The device info list handle must be closed after use. In case the -// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned. -func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) { - // Create a list of network devices. - devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") - if err != nil { - return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error()) - } - - for index := 0; ; index++ { - devInfoData, err := devInfo.EnumDeviceInfo(index) - if err != nil { - if err == windows.ERROR_NO_MORE_ITEMS { - break - } - continue - } - - // Get interface ID. - // TODO: Store some ID in the Wintun object such that this call isn't required. - wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool) - if err != nil { - continue - } - - if wintun.cfgInstanceID == wintun2.cfgInstanceID { - err = setQuietInstall(devInfo, devInfoData) - if err != nil { - devInfo.Close() - return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err) - } - return devInfo, devInfoData, nil - } - } - - devInfo.Close() - return 0, nil, windows.ERROR_OBJECT_NOT_FOUND -} - -// handle returns a handle to the interface device object. -func (wintun *Interface) handle() (windows.Handle, error) { - interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT) - if err != nil { - return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err) - } - handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0) - if err != nil { - return windows.InvalidHandle, fmt.Errorf("Error opening NDIS device: %v", err) - } - return handle, nil -} - -// GUID returns the GUID of the interface. -func (wintun *Interface) GUID() windows.GUID { - return wintun.cfgInstanceID -} - -// LUID returns the LUID of the interface. -func (wintun *Interface) LUID() uint64 { - return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48) -} diff --git a/version.go b/version.go new file mode 100644 index 0000000..db75bb9 --- /dev/null +++ b/version.go @@ -0,0 +1,3 @@ +package main + +const Version = "0.0.20230223" diff --git a/wgcfg/config.go b/wgcfg/config.go deleted file mode 100644 index 2b5e714..0000000 --- a/wgcfg/config.go +++ /dev/null @@ -1,78 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -// Package wgcfg has types and a parser for representing WireGuard config. -package wgcfg - -import ( - "fmt" - "strings" -) - -// Config is a wireguard configuration. -type Config struct { - Name string - PrivateKey PrivateKey - Addresses []CIDR - ListenPort uint16 - MTU uint16 - DNS []IP - Peers []Peer -} - -type Peer struct { - PublicKey Key - PresharedKey SymmetricKey - AllowedIPs []CIDR - Endpoints []Endpoint - PersistentKeepalive uint16 -} - -type Endpoint struct { - Host string - Port uint16 -} - -func (e *Endpoint) String() string { - if strings.IndexByte(e.Host, ':') > 0 { - return fmt.Sprintf("[%s]:%d", e.Host, e.Port) - } - return fmt.Sprintf("%s:%d", e.Host, e.Port) -} - -func (e *Endpoint) IsEmpty() bool { - return len(e.Host) == 0 -} - -// Copy makes a deep copy of Config. -// The result aliases no memory with the original. -func (cfg Config) Copy() Config { - res := cfg - if res.Addresses != nil { - res.Addresses = append([]CIDR{}, res.Addresses...) - } - if res.DNS != nil { - res.DNS = append([]IP{}, res.DNS...) - } - peers := make([]Peer, 0, len(res.Peers)) - for _, peer := range res.Peers { - peers = append(peers, peer.Copy()) - } - res.Peers = peers - return res -} - -// Copy makes a deep copy of Peer. -// The result aliases no memory with the original. -func (peer Peer) Copy() Peer { - res := peer - if res.AllowedIPs != nil { - res.AllowedIPs = append([]CIDR{}, res.AllowedIPs...) - } - if res.Endpoints != nil { - res.Endpoints = append([]Endpoint{}, res.Endpoints...) - } - return res -} diff --git a/wgcfg/ip.go b/wgcfg/ip.go deleted file mode 100644 index 47fa91c..0000000 --- a/wgcfg/ip.go +++ /dev/null @@ -1,152 +0,0 @@ -package wgcfg - -import ( - "fmt" - "math" - "net" -) - -// IP is an IPv4 or an IPv6 address. -// -// Internally the address is always represented in its IPv6 form. -// IPv4 addresses use the IPv4-in-IPv6 syntax. -type IP struct { - Addr [16]byte -} - -func (ip IP) String() string { return net.IP(ip.Addr[:]).String() } - -// IP converts ip into a standard library net.IP. -func (ip IP) IP() net.IP { return net.IP(ip.Addr[:]) } - -// Is6 reports whether ip is an IPv6 address. -func (ip IP) Is6() bool { return !ip.Is4() } - -// Is4 reports whether ip is an IPv4 address. -func (ip IP) Is4() bool { - return ip.Addr[0] == 0 && ip.Addr[1] == 0 && - ip.Addr[2] == 0 && ip.Addr[3] == 0 && - ip.Addr[4] == 0 && ip.Addr[5] == 0 && - ip.Addr[6] == 0 && ip.Addr[7] == 0 && - ip.Addr[8] == 0 && ip.Addr[9] == 0 && - ip.Addr[10] == 0xff && ip.Addr[11] == 0xff -} - -// To4 returns either a 4 byte slice for an IPv4 address, or nil if -// it's not IPv4. -func (ip IP) To4() []byte { - if ip.Is4() { - return ip.Addr[12:16] - } else { - return nil - } -} - -// Equal reports whether ip == x. -func (ip IP) Equal(x IP) bool { - return ip == x -} - -func (ip IP) MarshalText() ([]byte, error) { - return []byte(ip.String()), nil -} - -func (ip *IP) UnmarshalText(text []byte) error { - parsedIP, ok := ParseIP(string(text)) - if !ok { - return fmt.Errorf("wgcfg.IP: UnmarshalText: bad IP address %q", text) - } - *ip = parsedIP - return nil -} - -func IPv4(b0, b1, b2, b3 byte) (ip IP) { - ip.Addr[10], ip.Addr[11] = 0xff, 0xff // IPv4-in-IPv6 prefix - ip.Addr[12] = b0 - ip.Addr[13] = b1 - ip.Addr[14] = b2 - ip.Addr[15] = b3 - return ip -} - -// ParseIP parses the string representation of an address into an IP. -// -// It accepts IPv4 notation such as "1.2.3.4" and IPv6 notation like ""::0". -// The ok result reports whether s was a valid IP and ip is valid. -func ParseIP(s string) (ip IP, ok bool) { - netIP := net.ParseIP(s) - if netIP == nil { - return IP{}, false - } - copy(ip.Addr[:], netIP.To16()) - return ip, true -} - -// CIDR is a compact IP address and subnet mask. -type CIDR struct { - IP IP - Mask uint8 // 0-32 for IsIPv4, 4-128 for IsIPv6 -} - -// ParseCIDR parses CIDR notation into a CIDR type. -// Typical CIDR strings look like "192.168.1.0/24". -func ParseCIDR(s string) (CIDR, error) { - netIP, netAddr, err := net.ParseCIDR(s) - if err != nil { - return CIDR{}, err - } - var cidr CIDR - copy(cidr.IP.Addr[:], netIP.To16()) - ones, _ := netAddr.Mask.Size() - cidr.Mask = uint8(ones) - - return cidr, nil -} - -func (r CIDR) String() string { return r.IPNet().String() } - -func (r CIDR) IPNet() *net.IPNet { - bits := 128 - if r.IP.Is4() { - bits = 32 - } - return &net.IPNet{IP: r.IP.IP(), Mask: net.CIDRMask(int(r.Mask), bits)} -} - -func (r CIDR) Contains(ip IP) bool { - c := int8(r.Mask) - i := 0 - if r.IP.Is4() { - i = 12 - if ip.Is6() { - return false - } - } - for ; i < 16 && c > 0; i++ { - var x uint8 - if c < 8 { - x = 8 - uint8(c) - } - m := uint8(math.MaxUint8) >> x << x - a := r.IP.Addr[i] & m - b := ip.Addr[i] & m - if a != b { - return false - } - c -= 8 - } - return true -} - -func (r CIDR) MarshalText() ([]byte, error) { - return []byte(r.String()), nil -} - -func (r *CIDR) UnmarshalText(text []byte) error { - cidr, err := ParseCIDR(string(text)) - if err != nil { - return fmt.Errorf("wgcfg.CIDR: UnmarshalText: %v", err) - } - *r = cidr - return nil -} diff --git a/wgcfg/ip_test.go b/wgcfg/ip_test.go deleted file mode 100644 index 6cd41d3..0000000 --- a/wgcfg/ip_test.go +++ /dev/null @@ -1,106 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wgcfg_test - -import ( - "testing" - - "golang.zx2c4.com/wireguard/wgcfg" -) - -func parseIP(t testing.TB, ipStr string) wgcfg.IP { - t.Helper() - ip, ok := wgcfg.ParseIP(ipStr) - if !ok { - t.Fatalf("failed to parse IP: %q", ipStr) - } - return ip -} - -func TestCIDRContains(t *testing.T) { - t.Run("home router test", func(t *testing.T) { - r, err := wgcfg.ParseCIDR("192.168.0.0/24") - if err != nil { - t.Fatal(err) - } - ip := parseIP(t, "192.168.0.1") - if !r.Contains(ip) { - t.Fatalf("%q should contain %q", r, ip) - } - }) - - t.Run("IPv4 outside network", func(t *testing.T) { - r, err := wgcfg.ParseCIDR("192.168.0.0/30") - if err != nil { - t.Fatal(err) - } - ip := parseIP(t, "192.168.0.4") - if r.Contains(ip) { - t.Fatalf("%q should not contain %q", r, ip) - } - }) - - t.Run("IPv4 does not contain IPv6", func(t *testing.T) { - r, err := wgcfg.ParseCIDR("192.168.0.0/24") - if err != nil { - t.Fatal(err) - } - ip := parseIP(t, "2001:db8:85a3:0:0:8a2e:370:7334") - if r.Contains(ip) { - t.Fatalf("%q should not contain %q", r, ip) - } - }) - - t.Run("IPv6 inside network", func(t *testing.T) { - r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") - if err != nil { - t.Fatal(err) - } - ip := parseIP(t, "2001:db8:1234:0000:0000:0000:0000:0001") - if !r.Contains(ip) { - t.Fatalf("%q should not contain %q", r, ip) - } - }) - - t.Run("IPv6 outside network", func(t *testing.T) { - r, err := wgcfg.ParseCIDR("2001:db8:1234:0:190b:0:1982::/126") - if err != nil { - t.Fatal(err) - } - ip := parseIP(t, "2001:db8:1234:0:190b:0:1982:4") - if r.Contains(ip) { - t.Fatalf("%q should not contain %q", r, ip) - } - }) -} - -func BenchmarkCIDRContainsIPv4(b *testing.B) { - b.Run("IPv4", func(b *testing.B) { - r, err := wgcfg.ParseCIDR("192.168.1.0/24") - if err != nil { - b.Fatal(err) - } - ip := parseIP(b, "1.2.3.4") - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Contains(ip) - } - }) - - b.Run("IPv6", func(b *testing.B) { - r, err := wgcfg.ParseCIDR("2001:db8:1234::/48") - if err != nil { - b.Fatal(err) - } - ip := parseIP(b, "2001:db8:1234:0000:0000:0000:0000:0001") - b.ResetTimer() - - for i := 0; i < b.N; i++ { - r.Contains(ip) - } - }) -} diff --git a/wgcfg/key.go b/wgcfg/key.go deleted file mode 100644 index cdbbeea..0000000 --- a/wgcfg/key.go +++ /dev/null @@ -1,239 +0,0 @@ -package wgcfg - -import ( - "bytes" - "crypto/rand" - "crypto/subtle" - "encoding/base64" - "encoding/hex" - "errors" - "fmt" - "strings" - - "golang.org/x/crypto/chacha20poly1305" - "golang.org/x/crypto/curve25519" -) - -const KeySize = 32 - -// Key is curve25519 key. -// It is used by WireGuard to represent public and preshared keys. -type Key [KeySize]byte - -// NewPresharedKey generates a new random key. -func NewPresharedKey() (*Key, error) { - var k [KeySize]byte - _, err := rand.Read(k[:]) - if err != nil { - return nil, err - } - return (*Key)(&k), nil -} - -func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } - -func ParseHexKey(s string) (Key, error) { - b, err := hex.DecodeString(s) - if err != nil { - return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} - } - if len(b) != KeySize { - return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} - } - - var key Key - copy(key[:], b) - return key, nil -} - -func ParsePrivateHexKey(v string) (PrivateKey, error) { - k, err := ParseHexKey(v) - if err != nil { - return PrivateKey{}, err - } - pk := PrivateKey(k) - if pk.IsZero() { - // Do not clamp a zero key, pass the zero through - // (much like NaN propagation) so that IsZero reports - // a useful result. - return pk, nil - } - pk.clamp() - return pk, nil -} - -func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k Key) String() string { return "pub:" + k.Base64()[:8] } -func (k Key) HexString() string { return hex.EncodeToString(k[:]) } -func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } - -func (k *Key) ShortString() string { - if k.IsZero() { - return "[empty]" - } - long := k.String() - if len(long) < 10 { - return "invalid" - } - return "[" + long[0:4] + "ā¦" + long[len(long)-5:len(long)-1] + "]" -} - -func (k *Key) IsZero() bool { - if k == nil { - return true - } - var zeros Key - return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 -} - -func (k *Key) MarshalJSON() ([]byte, error) { - if k == nil { - return []byte("null"), nil - } - buf := new(bytes.Buffer) - fmt.Fprintf(buf, `"%x"`, k[:]) - return buf.Bytes(), nil -} - -func (k *Key) UnmarshalJSON(b []byte) error { - if k == nil { - return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") - } - if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { - return errors.New("wgcfg.Key: UnmarshalJSON not given a string") - } - b = b[1 : len(b)-1] - key, err := ParseHexKey(string(b)) - if err != nil { - return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) - } - copy(k[:], key[:]) - return nil -} - -func (a *Key) LessThan(b *Key) bool { - for i := range a { - if a[i] < b[i] { - return true - } else if a[i] > b[i] { - return false - } - } - return false -} - -// PrivateKey is curve25519 key. -// It is used by WireGuard to represent private keys. -type PrivateKey [KeySize]byte - -// NewPrivateKey generates a new curve25519 secret key. -// It conforms to the format described on https://cr.yp.to/ecdh.html. -func NewPrivateKey() (PrivateKey, error) { - k, err := NewPresharedKey() - if err != nil { - return PrivateKey{}, err - } - k[0] &= 248 - k[31] = (k[31] & 127) | 64 - return (PrivateKey)(*k), nil -} - -func ParsePrivateKey(b64 string) (*PrivateKey, error) { - k, err := parseKeyBase64(base64.StdEncoding, b64) - return (*PrivateKey)(k), err -} - -func (k *PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } -func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } - -func (k *PrivateKey) IsZero() bool { - pk := Key(*k) - return pk.IsZero() -} - -func (k *PrivateKey) clamp() { - k[0] &= 248 - k[31] = (k[31] & 127) | 64 -} - -// Public computes the public key matching this curve25519 secret key. -func (k *PrivateKey) Public() Key { - pk := Key(*k) - if pk.IsZero() { - panic("Tried to generate emptyPrivateKey.Public()") - } - var p [KeySize]byte - curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) - return (Key)(p) -} - -func (k PrivateKey) MarshalText() ([]byte, error) { - buf := new(bytes.Buffer) - fmt.Fprintf(buf, `privkey:%x`, k[:]) - return buf.Bytes(), nil -} - -func (k *PrivateKey) UnmarshalText(b []byte) error { - s := string(b) - if !strings.HasPrefix(s, `privkey:`) { - return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") - } - s = strings.TrimPrefix(s, `privkey:`) - key, err := ParseHexKey(s) - if err != nil { - return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) - } - copy(k[:], key[:]) - return nil -} - -func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { - apk := (*[KeySize]byte)(&pub) - ask := (*[KeySize]byte)(&k) - curve25519.ScalarMult(&ss, ask, apk) - return ss -} - -func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { - k, err := enc.DecodeString(s) - if err != nil { - return nil, &ParseError{"Invalid key: " + err.Error(), s} - } - if len(k) != KeySize { - return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} - } - var key Key - copy(key[:], k) - return &key, nil -} - -func ParseSymmetricKey(b64 string) (SymmetricKey, error) { - k, err := parseKeyBase64(base64.StdEncoding, b64) - if err != nil { - return SymmetricKey{}, err - } - return SymmetricKey(*k), nil -} - -func ParseSymmetricHexKey(s string) (SymmetricKey, error) { - b, err := hex.DecodeString(s) - if err != nil { - return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s} - } - if len(b) != chacha20poly1305.KeySize { - return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s} - } - var key SymmetricKey - copy(key[:], b) - return key, nil -} - -// SymmetricKey is a 32-byte value used as a pre-shared key. -type SymmetricKey [chacha20poly1305.KeySize]byte - -func (k SymmetricKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } -func (k SymmetricKey) String() string { return "sym:" + k.Base64()[:8] } -func (k SymmetricKey) HexString() string { return hex.EncodeToString(k[:]) } -func (k SymmetricKey) IsZero() bool { return k.Equal(SymmetricKey{}) } -func (k SymmetricKey) Equal(k2 SymmetricKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } diff --git a/wgcfg/key_test.go b/wgcfg/key_test.go deleted file mode 100644 index 0b82d5f..0000000 --- a/wgcfg/key_test.go +++ /dev/null @@ -1,107 +0,0 @@ -package wgcfg - -import ( - "bytes" - "testing" -) - -func TestKeyBasics(t *testing.T) { - k1, err := NewPresharedKey() - if err != nil { - t.Fatal(err) - } - - b, err := k1.MarshalJSON() - if err != nil { - t.Fatal(err) - } - - t.Run("JSON round-trip", func(t *testing.T) { - // should preserve the keys - k2 := new(Key) - if err := k2.UnmarshalJSON(b); err != nil { - t.Fatal(err) - } - if !bytes.Equal(k1[:], k2[:]) { - t.Fatalf("k1 %v != k2 %v", k1[:], k2[:]) - } - if b1, b2 := k1.String(), k2.String(); b1 != b2 { - t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) - } - }) - - t.Run("JSON incompatible with PrivateKey", func(t *testing.T) { - k2 := new(PrivateKey) - if err := k2.UnmarshalText(b); err == nil { - t.Fatalf("successfully decoded key as private key") - } - }) - - t.Run("second key", func(t *testing.T) { - // A second call to NewPresharedKey should make a new key. - k3, err := NewPresharedKey() - if err != nil { - t.Fatal(err) - } - if bytes.Equal(k1[:], k3[:]) { - t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) - } - // Check for obvious comparables to make sure we are not generating bad strings somewhere. - if b1, b2 := k1.String(), k3.String(); b1 == b2 { - t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) - } - }) -} -func TestPrivateKeyBasics(t *testing.T) { - pri, err := NewPrivateKey() - if err != nil { - t.Fatal(err) - } - - b, err := pri.MarshalText() - if err != nil { - t.Fatal(err) - } - - t.Run("JSON round-trip", func(t *testing.T) { - // should preserve the keys - pri2 := new(PrivateKey) - if err := pri2.UnmarshalText(b); err != nil { - t.Fatal(err) - } - if !bytes.Equal(pri[:], pri2[:]) { - t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:]) - } - if b1, b2 := pri.String(), pri2.String(); b1 != b2 { - t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) - } - if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 { - t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2) - } - }) - - t.Run("JSON incompatible with Key", func(t *testing.T) { - k2 := new(Key) - if err := k2.UnmarshalJSON(b); err == nil { - t.Fatalf("successfully decoded private key as key") - } - }) - - t.Run("second key", func(t *testing.T) { - // A second call to New should make a new key. - pri3, err := NewPrivateKey() - if err != nil { - t.Fatal(err) - } - if bytes.Equal(pri[:], pri3[:]) { - t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:]) - } - // Check for obvious comparables to make sure we are not generating bad strings somewhere. - if b1, b2 := pri.String(), pri3.String(); b1 == b2 { - t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) - } - if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 { - t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2) - } - }) -} diff --git a/wgcfg/name.go b/wgcfg/name.go deleted file mode 100644 index 28bc0f0..0000000 --- a/wgcfg/name.go +++ /dev/null @@ -1,49 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wgcfg - -import ( - "regexp" - "strings" -) - -var reservedNames = []string{ - "CON", "PRN", "AUX", "NUL", - "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", - "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", -} - -const specialChars = "/\\<>:\"|?*\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x00" - -var allowedNameFormat *regexp.Regexp - -func init() { - allowedNameFormat = regexp.MustCompile("^[a-zA-Z0-9_=+.-]{1,32}$") -} - -func isReserved(name string) bool { - if len(name) == 0 { - return false - } - for _, reserved := range reservedNames { - if strings.EqualFold(name, reserved) { - return true - } - } - return false -} - -func hasSpecialChars(name string) bool { - return strings.ContainsAny(name, specialChars) -} - -func TunnelNameIsValid(name string) bool { - // Aside from our own restrictions, let's impose the Windows restrictions first - if isReserved(name) || hasSpecialChars(name) { - return false - } - return allowedNameFormat.MatchString(name) -} diff --git a/wgcfg/parser.go b/wgcfg/parser.go deleted file mode 100644 index e71d32b..0000000 --- a/wgcfg/parser.go +++ /dev/null @@ -1,397 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wgcfg - -import ( - "encoding/hex" - "fmt" - "net" - "strconv" - "strings" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: ā%sā", e.why, e.offender) -} - -func parseEndpoints(s string) ([]Endpoint, error) { - var eps []Endpoint - vals := strings.Split(s, ",") - for _, val := range vals { - e, err := parseEndpoint(val) - if err != nil { - return nil, err - } - eps = append(eps, *e) - } - return eps, nil -} - -func parseEndpoint(s string) (*Endpoint, error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return nil, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return nil, &ParseError{"Invalid endpoint host", host} - } - port, err := parsePort(portStr) - if err != nil { - return nil, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return nil, err - } - } else { - return nil, err - } - host = host[1 : len(host)-1] - } - return &Endpoint{host, uint16(port)}, nil -} - -func parseMTU(s string) (uint16, error) { - m, err := strconv.Atoi(s) - if err != nil { - return 0, err - } - if m < 576 || m > 65535 { - return 0, &ParseError{"Invalid MTU", s} - } - return uint16(m), nil -} - -func parsePort(s string) (uint16, error) { - m, err := strconv.Atoi(s) - if err != nil { - return 0, err - } - if m < 0 || m > 65535 { - return 0, &ParseError{"Invalid port", s} - } - return uint16(m), nil -} - -func parsePersistentKeepalive(s string) (uint16, error) { - if s == "off" { - return 0, nil - } - m, err := strconv.Atoi(s) - if err != nil { - return 0, err - } - if m < 0 || m > 65535 { - return 0, &ParseError{"Invalid persistent keepalive", s} - } - return uint16(m), nil -} - -func parseKeyHex(s string) (*Key, error) { - k, err := hex.DecodeString(s) - if err != nil { - return nil, &ParseError{"Invalid key: " + err.Error(), s} - } - if len(k) != KeySize { - return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} - } - var key Key - copy(key[:], k) - return &key, nil -} - -func parseBytesOrStamp(s string) (uint64, error) { - b, err := strconv.ParseUint(s, 10, 64) - if err != nil { - return 0, &ParseError{"Number must be a number between 0 and 2^64-1: " + err.Error(), s} - } - return b, nil -} - -func splitList(s string) ([]string, error) { - var out []string - for _, split := range strings.Split(s, ",") { - trim := strings.TrimSpace(split) - if len(trim) == 0 { - return nil, &ParseError{"Two commas in a row", s} - } - out = append(out, trim) - } - return out, nil -} - -type parserState int - -const ( - inInterfaceSection parserState = iota - inPeerSection - notInASection -) - -func (c *Config) maybeAddPeer(p *Peer) { - if p != nil { - c.Peers = append(c.Peers, *p) - } -} - -func FromWgQuick(s string, name string) (*Config, error) { - if !TunnelNameIsValid(name) { - return nil, &ParseError{"Tunnel name is not valid", name} - } - lines := strings.Split(s, "\n") - parserState := notInASection - conf := Config{Name: name} - sawPrivateKey := false - var peer *Peer - for _, line := range lines { - pound := strings.IndexByte(line, '#') - if pound >= 0 { - line = line[:pound] - } - line = strings.TrimSpace(line) - lineLower := strings.ToLower(line) - if len(line) == 0 { - continue - } - if lineLower == "[interface]" { - conf.maybeAddPeer(peer) - parserState = inInterfaceSection - continue - } - if lineLower == "[peer]" { - conf.maybeAddPeer(peer) - peer = &Peer{} - parserState = inPeerSection - continue - } - if parserState == notInASection { - return nil, &ParseError{"Line must occur in a section", line} - } - equals := strings.IndexByte(line, '=') - if equals < 0 { - return nil, &ParseError{"Invalid config key is missing an equals separator", line} - } - key, val := strings.TrimSpace(lineLower[:equals]), strings.TrimSpace(line[equals+1:]) - if len(val) == 0 { - return nil, &ParseError{"Key must have a value", line} - } - if parserState == inInterfaceSection { - switch key { - case "privatekey": - k, err := ParseKey(val) - if err != nil { - return nil, err - } - conf.PrivateKey = PrivateKey(*k) - sawPrivateKey = true - case "listenport": - p, err := parsePort(val) - if err != nil { - return nil, err - } - conf.ListenPort = p - case "mtu": - m, err := parseMTU(val) - if err != nil { - return nil, err - } - conf.MTU = m - case "address": - addresses, err := splitList(val) - if err != nil { - return nil, err - } - for _, address := range addresses { - a, err := ParseCIDR(address) - if err != nil { - return nil, err - } - conf.Addresses = append(conf.Addresses, a) - } - case "dns": - addresses, err := splitList(val) - if err != nil { - return nil, err - } - for _, address := range addresses { - a, ok := ParseIP(address) - if !ok { - return nil, &ParseError{"Invalid IP address", address} - } - conf.DNS = append(conf.DNS, a) - } - default: - return nil, &ParseError{"Invalid key for [Interface] section", key} - } - } else if parserState == inPeerSection { - switch key { - case "publickey": - k, err := ParseKey(val) - if err != nil { - return nil, err - } - peer.PublicKey = *k - case "presharedkey": - k, err := ParseKey(val) - if err != nil { - return nil, err - } - peer.PresharedKey = SymmetricKey(*k) - case "allowedips": - addresses, err := splitList(val) - if err != nil { - return nil, err - } - for _, address := range addresses { - a, err := ParseCIDR(address) - if err != nil { - return nil, err - } - peer.AllowedIPs = append(peer.AllowedIPs, a) - } - case "persistentkeepalive": - p, err := parsePersistentKeepalive(val) - if err != nil { - return nil, err - } - peer.PersistentKeepalive = p - case "endpoint": - eps, err := parseEndpoints(val) - if err != nil { - return nil, err - } - peer.Endpoints = eps - default: - return nil, &ParseError{"Invalid key for [Peer] section", key} - } - } - } - conf.maybeAddPeer(peer) - - if !sawPrivateKey { - return nil, &ParseError{"An interface must have a private key", "[none specified]"} - } - for _, p := range conf.Peers { - if p.PublicKey.IsZero() { - return nil, &ParseError{"All peers must have public keys", "[none specified]"} - } - } - - return &conf, nil -} - -// TODO(apenwarr): This is incompatibe with current Device.IpcSetOperation. -// It duplicates all the parser stuff in there, but is missing some -// keywords. Nothing useful seems to need it anymore. -func Broken_FromUAPI(s string, existingConfig *Config) (*Config, error) { - lines := strings.Split(s, "\n") - parserState := inInterfaceSection - conf := Config{ - Name: existingConfig.Name, - Addresses: existingConfig.Addresses, - DNS: existingConfig.DNS, - MTU: existingConfig.MTU, - } - var peer *Peer - for _, line := range lines { - if len(line) == 0 { - continue - } - equals := strings.IndexByte(line, '=') - if equals < 0 { - return nil, &ParseError{"Invalid config key is missing an equals separator", line} - } - key, val := line[:equals], line[equals+1:] - if len(val) == 0 { - return nil, &ParseError{"Key must have a value", line} - } - switch key { - case "public_key": - conf.maybeAddPeer(peer) - peer = &Peer{} - parserState = inPeerSection - case "errno": - if val == "0" { - continue - } else { - return nil, &ParseError{"Error in getting configuration", val} - } - } - if parserState == inInterfaceSection { - switch key { - case "private_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - conf.PrivateKey = PrivateKey(*k) - case "listen_port": - p, err := parsePort(val) - if err != nil { - return nil, err - } - conf.ListenPort = p - case "fwmark": - // Ignored for now. - - default: - return nil, &ParseError{"Invalid key for interface section", key} - } - } else if parserState == inPeerSection { - switch key { - case "public_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - peer.PublicKey = *k - case "preshared_key": - k, err := parseKeyHex(val) - if err != nil { - return nil, err - } - peer.PresharedKey = SymmetricKey(*k) - case "protocol_version": - if val != "1" { - return nil, &ParseError{"Protocol version must be 1", val} - } - case "allowed_ip": - a, err := ParseCIDR(val) - if err != nil { - return nil, err - } - peer.AllowedIPs = append(peer.AllowedIPs, a) - case "persistent_keepalive_interval": - p, err := parsePersistentKeepalive(val) - if err != nil { - return nil, err - } - peer.PersistentKeepalive = p - case "endpoint": - eps, err := parseEndpoints(val) - if err != nil { - return nil, err - } - peer.Endpoints = eps - default: - return nil, &ParseError{"Invalid key for peer section", key} - } - } - } - conf.maybeAddPeer(peer) - - return &conf, nil -} diff --git a/wgcfg/parser_test.go b/wgcfg/parser_test.go deleted file mode 100644 index d0df537..0000000 --- a/wgcfg/parser_test.go +++ /dev/null @@ -1,127 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wgcfg - -import ( - "reflect" - "runtime" - "testing" -) - -const testInput = ` -[Interface] -Address = 10.192.122.1/24 -Address = 10.10.0.1/16 -PrivateKey = yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk= -ListenPort = 51820 #comments don't matter - -[Peer] -PublicKey = xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg= -Endpoint = 192.95.5.67:1234 -AllowedIPs = 10.192.122.3/32, 10.192.124.1/24 - -[Peer] -PublicKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= -Endpoint = [2607:5300:60:6b0::c05f:543]:2468 -AllowedIPs = 10.192.122.4/32, 192.168.0.0/16 -PersistentKeepalive = 100 - -[Peer] -PublicKey = gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA= -PresharedKey = TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0= -Endpoint = test.wireguard.com:18981 -AllowedIPs = 10.10.10.230/32` - -func noError(t *testing.T, err error) bool { - if err == nil { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Error at %s:%d: %#v", fn, line, err) - return false -} - -func equal(t *testing.T, expected, actual interface{}) bool { - if reflect.DeepEqual(expected, actual) { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) - return false -} -func lenTest(t *testing.T, actualO interface{}, expected int) bool { - actual := reflect.ValueOf(actualO).Len() - if reflect.DeepEqual(expected, actual) { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Wrong length at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) - return false -} -func contains(t *testing.T, list, element interface{}) bool { - listValue := reflect.ValueOf(list) - for i := 0; i < listValue.Len(); i++ { - if reflect.DeepEqual(listValue.Index(i).Interface(), element) { - return true - } - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Error %s:%d\nelement not found: %#v", fn, line, element) - return false -} - -func TestFromWgQuick(t *testing.T) { - conf, err := FromWgQuick(testInput, "test") - if noError(t, err) { - - lenTest(t, conf.Addresses, 2) - contains(t, conf.Addresses, CIDR{IPv4(10, 10, 0, 1), 16}) - contains(t, conf.Addresses, CIDR{IPv4(10, 192, 122, 1), 24}) - equal(t, "yAnz5TF+lXXJte14tji3zlMNq+hd2rYUIgJBgB3fBmk=", conf.PrivateKey.String()) - equal(t, uint16(51820), conf.ListenPort) - - lenTest(t, conf.Peers, 3) - lenTest(t, conf.Peers[0].AllowedIPs, 2) - equal(t, Endpoint{Host: "192.95.5.67", Port: 1234}, conf.Peers[0].Endpoints[0]) - equal(t, "xTIBA5rboUvnH4htodjb6e697QjLERt1NAB4mZqp8Dg=", conf.Peers[0].PublicKey.Base64()) - - lenTest(t, conf.Peers[1].AllowedIPs, 2) - equal(t, Endpoint{Host: "2607:5300:60:6b0::c05f:543", Port: 2468}, conf.Peers[1].Endpoints[0]) - equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[1].PublicKey.Base64()) - equal(t, uint16(100), conf.Peers[1].PersistentKeepalive) - - lenTest(t, conf.Peers[2].AllowedIPs, 1) - equal(t, Endpoint{Host: "test.wireguard.com", Port: 18981}, conf.Peers[2].Endpoints[0]) - equal(t, "gN65BkIKy1eCE9pP1wdc8ROUtkHLF2PfAqYdyYBz6EA=", conf.Peers[2].PublicKey.Base64()) - equal(t, "TrMvSoP4jYQlY6RIzBgbssQqY3vxI2Pi+y71lOWWXX0=", conf.Peers[2].PresharedKey.Base64()) - } -} - -func TestParseEndpoint(t *testing.T) { - _, err := parseEndpoint("[192.168.42.0:]:51880") - if err == nil { - t.Error("Error was expected") - } - e, err := parseEndpoint("192.168.42.0:51880") - if noError(t, err) { - equal(t, "192.168.42.0", e.Host) - equal(t, uint16(51880), e.Port) - } - e, err = parseEndpoint("test.wireguard.com:18981") - if noError(t, err) { - equal(t, "test.wireguard.com", e.Host) - equal(t, uint16(18981), e.Port) - } - e, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") - if noError(t, err) { - equal(t, "2607:5300:60:6b0::c05f:543", e.Host) - equal(t, uint16(2468), e.Port) - } - _, err = parseEndpoint("[::::::invalid:18981") - if err == nil { - t.Error("Error was expected") - } -} diff --git a/wgcfg/writer.go b/wgcfg/writer.go deleted file mode 100644 index 246a57d..0000000 --- a/wgcfg/writer.go +++ /dev/null @@ -1,73 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package wgcfg - -import ( - "fmt" - "net" - "strings" -) - -func (conf *Config) ToUAPI() (string, error) { - output := new(strings.Builder) - fmt.Fprintf(output, "private_key=%s\n", conf.PrivateKey.HexString()) - - if conf.ListenPort > 0 { - fmt.Fprintf(output, "listen_port=%d\n", conf.ListenPort) - } - - output.WriteString("replace_peers=true\n") - - for _, peer := range conf.Peers { - fmt.Fprintf(output, "public_key=%s\n", peer.PublicKey.HexString()) - fmt.Fprintf(output, "protocol_version=1\n") - fmt.Fprintf(output, "replace_allowed_ips=true\n") - - if !peer.PresharedKey.IsZero() { - fmt.Fprintf(output, "preshared_key = %s\n", peer.PresharedKey.String()) - } - - if len(peer.AllowedIPs) > 0 { - for _, address := range peer.AllowedIPs { - fmt.Fprintf(output, "allowed_ip=%s\n", address.String()) - } - } - - if len(peer.Endpoints) > 0 { - var reps []string - for _, ep := range peer.Endpoints { - ips, err := net.LookupIP(ep.Host) - if err != nil { - return "", err - } - var ip net.IP - for _, iterip := range ips { - if ip4 := iterip.To4(); ip4 != nil { - ip = ip4 - break - } - if ip == nil { - ip = iterip - } - } - if ip == nil { - return "", fmt.Errorf("unable to resolve IP address of endpoint %q (%v)", ep.Host, ips) - } - resolvedEndpoint := Endpoint{ip.String(), ep.Port} - reps = append(reps, resolvedEndpoint.String()) - } - fmt.Fprintf(output, "endpoint=%s\n", strings.Join(reps, ",")) - } else { - fmt.Fprint(output, "endpoint=\n") - } - - // Note: this needs to come *after* endpoint definitions, - // because setting it will trigger a handshake to all - // already-defined endpoints. - fmt.Fprintf(output, "persistent_keepalive_interval=%d\n", peer.PersistentKeepalive) - } - return output.String(), nil -} |