aboutsummaryrefslogtreecommitdiffstats
path: root/tun/tun_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/tun_windows.go')
-rw-r--r--tun/tun_windows.go265
1 files changed, 118 insertions, 147 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index daad4aa..2af8e3e 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -9,13 +9,13 @@ import (
"errors"
"fmt"
"os"
+ "sync"
"sync/atomic"
"time"
- "unsafe"
+ _ "unsafe"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wintun"
)
const (
@@ -25,24 +25,31 @@ const (
)
type rateJuggler struct {
- current uint64
- nextByteCount uint64
- nextStartTime int64
- changing int32
+ current atomic.Uint64
+ nextByteCount atomic.Uint64
+ nextStartTime atomic.Int64
+ changing atomic.Bool
}
type NativeTun struct {
- wt *wintun.Interface
+ wt *wintun.Adapter
+ name string
handle windows.Handle
- close bool
- rings wintun.RingDescriptor
+ rate rateJuggler
+ session wintun.Session
+ readWait windows.Handle
events chan Event
- errors chan error
+ running sync.WaitGroup
+ closeOnce sync.Once
+ close atomic.Bool
forcedMTU int
- rate rateJuggler
+ outSizes []int
}
-const WintunPool = wintun.Pool("WireGuard")
+var (
+ WintunTunnelType = "WireGuard"
+ WintunStaticRequestedGUID *windows.GUID
+)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
@@ -50,34 +57,18 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
-//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
-//
func CreateTUN(ifname string, mtu int) (Device, error) {
- return CreateTUNWithRequestedGUID(ifname, nil, mtu)
+ return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
-//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
- var err error
- var wt *wintun.Interface
-
- // Does an interface with this name already exist?
- wt, err = WintunPool.GetInterface(ifname)
- if err == nil {
- // If so, we delete it, in case it has weird residual configuration.
- _, err = wt.DeleteInterface()
- if err != nil {
- return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
- }
- }
- wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
+ wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
- return nil, fmt.Errorf("Error creating interface: %v", err)
+ return nil, fmt.Errorf("Error creating interface: %w", err)
}
forcedMTU := 1420
@@ -87,52 +78,46 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{
wt: wt,
+ name: ifname,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
- errors: make(chan error, 1),
forcedMTU: forcedMTU,
}
- err = tun.rings.Init()
- if err != nil {
- tun.Close()
- return nil, fmt.Errorf("Error creating events: %v", err)
- }
-
- tun.handle, err = tun.wt.Register(&tun.rings)
+ tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
- tun.Close()
- return nil, fmt.Errorf("Error registering rings: %v", err)
+ tun.wt.Close()
+ close(tun.events)
+ return nil, fmt.Errorf("Error starting session: %w", err)
}
+ tun.readWait = tun.session.ReadWaitEvent()
return tun, nil
}
func (tun *NativeTun) Name() (string, error) {
- return tun.wt.Name()
+ return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return nil
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
- tun.close = true
- if tun.rings.Send.TailMoved != 0 {
- windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
- }
- if tun.handle != windows.InvalidHandle {
- windows.CloseHandle(tun.handle)
- }
- tun.rings.Close()
var err error
- if tun.wt != nil {
- _, err = tun.wt.DeleteInterface()
- }
- close(tun.events)
+ tun.closeOnce.Do(func() {
+ tun.close.Store(true)
+ windows.SetEvent(tun.readWait)
+ tun.running.Wait()
+ tun.session.End()
+ if tun.wt != nil {
+ tun.wt.Close()
+ }
+ close(tun.events)
+ })
return err
}
@@ -142,129 +127,115 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
+ if tun.close.Load() {
+ return
+ }
+ update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
+ if update {
+ tun.events <- EventMTUUpdate
+ }
+}
+
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
}
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
+ tun.running.Add(1)
+ defer tun.running.Done()
retry:
- select {
- case err := <-tun.errors:
- return 0, err
- default:
- }
- if tun.close {
- return 0, os.ErrClosed
- }
-
- buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
- if buffHead >= wintun.PacketCapacity {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
-
start := nanotime()
- shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
- var buffTail uint32
+ shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
- buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
- if buffHead != buffTail {
- break
- }
- if tun.close {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
- windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
- goto retry
+ packet, err := tun.session.ReceivePacket()
+ switch err {
+ case nil:
+ n := copy(bufs[0][offset:], packet)
+ sizes[0] = n
+ tun.session.ReleaseReceivePacket(packet)
+ tun.rate.update(uint64(n))
+ return 1, nil
+ case windows.ERROR_NO_MORE_ITEMS:
+ if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
+ windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
+ goto retry
+ }
+ procyield(1)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return 0, os.ErrClosed
+ case windows.ERROR_INVALID_DATA:
+ return 0, errors.New("Send ring corrupt")
}
- procyield(1)
- }
- if buffTail >= wintun.PacketCapacity {
- return 0, os.ErrClosed
- }
-
- buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
- if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
- return 0, errors.New("incomplete packet header in send ring")
- }
-
- packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
- if packet.Size > wintun.PacketSizeMax {
- return 0, errors.New("packet too big in send ring")
- }
-
- alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
- if alignedPacketSize > buffContent {
- return 0, errors.New("incomplete packet in send ring")
+ return 0, fmt.Errorf("Read failed: %w", err)
}
-
- copy(buff[offset:], packet.Data[:packet.Size])
- buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
- atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
- tun.rate.update(uint64(packet.Size))
- return int(packet.Size), nil
-}
-
-func (tun *NativeTun) Flush() error {
- return nil
}
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
- if tun.close {
- return 0, os.ErrClosed
- }
-
- packetSize := uint32(len(buff) - offset)
- tun.rate.update(uint64(packetSize))
- alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
-
- buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
- if buffHead >= wintun.PacketCapacity {
- return 0, os.ErrClosed
- }
-
- buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
- if buffTail >= wintun.PacketCapacity {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
+ tun.running.Add(1)
+ defer tun.running.Done()
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
- if alignedPacketSize > buffSpace {
- return 0, nil // Dropping when ring is full.
- }
-
- packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
- packet.Size = packetSize
- copy(packet.Data[:packetSize], buff[offset:])
- atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
- if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
- windows.SetEvent(tun.rings.Receive.TailMoved)
+ for i, buf := range bufs {
+ packetSize := len(buf) - offset
+ tun.rate.update(uint64(packetSize))
+
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buf[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return int(packetSize), nil
+ return len(bufs), nil
}
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
+ tun.running.Add(1)
+ defer tun.running.Done()
+ if tun.close.Load() {
+ return 0
+ }
return tun.wt.LUID()
}
-// Version returns the version of the Wintun driver and NDIS system currently loaded.
-func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
- return tun.wt.Version()
+// RunningVersion returns the running version of the Wintun driver.
+func (tun *NativeTun) RunningVersion() (version uint32, err error) {
+ return wintun.RunningVersion()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
- total := atomic.AddUint64(&rate.nextByteCount, packetLen)
- period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+ total := rate.nextByteCount.Add(packetLen)
+ period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
- if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+ if !rate.changing.CompareAndSwap(false, true) {
return
}
- atomic.StoreInt64(&rate.nextStartTime, now)
- atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
- atomic.StoreUint64(&rate.nextByteCount, 0)
- atomic.StoreInt32(&rate.changing, 0)
+ rate.nextStartTime.Store(now)
+ rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+ rate.nextByteCount.Store(0)
+ rate.changing.Store(false)
}
}