diff options
Diffstat (limited to 'tun/tun_windows.go')
-rw-r--r-- | tun/tun_windows.go | 265 |
1 files changed, 118 insertions, 147 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go index daad4aa..2af8e3e 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -9,13 +9,13 @@ import ( "errors" "fmt" "os" + "sync" "sync/atomic" "time" - "unsafe" + _ "unsafe" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun/wintun" + "golang.zx2c4.com/wintun" ) const ( @@ -25,24 +25,31 @@ const ( ) type rateJuggler struct { - current uint64 - nextByteCount uint64 - nextStartTime int64 - changing int32 + current atomic.Uint64 + nextByteCount atomic.Uint64 + nextStartTime atomic.Int64 + changing atomic.Bool } type NativeTun struct { - wt *wintun.Interface + wt *wintun.Adapter + name string handle windows.Handle - close bool - rings wintun.RingDescriptor + rate rateJuggler + session wintun.Session + readWait windows.Handle events chan Event - errors chan error + running sync.WaitGroup + closeOnce sync.Once + close atomic.Bool forcedMTU int - rate rateJuggler + outSizes []int } -const WintunPool = wintun.Pool("WireGuard") +var ( + WintunTunnelType = "WireGuard" + WintunStaticRequestedGUID *windows.GUID +) //go:linkname procyield runtime.procyield func procyield(cycles uint32) @@ -50,34 +57,18 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { - return CreateTUNWithRequestedGUID(ifname, nil, mtu) + return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { - var err error - var wt *wintun.Interface - - // Does an interface with this name already exist? - wt, err = WintunPool.GetInterface(ifname) - if err == nil { - // If so, we delete it, in case it has weird residual configuration. - _, err = wt.DeleteInterface() - if err != nil { - return nil, fmt.Errorf("Error deleting already existing interface: %v", err) - } - } - wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID) + wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { - return nil, fmt.Errorf("Error creating interface: %v", err) + return nil, fmt.Errorf("Error creating interface: %w", err) } forcedMTU := 1420 @@ -87,52 +78,46 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun := &NativeTun{ wt: wt, + name: ifname, handle: windows.InvalidHandle, events: make(chan Event, 10), - errors: make(chan error, 1), forcedMTU: forcedMTU, } - err = tun.rings.Init() - if err != nil { - tun.Close() - return nil, fmt.Errorf("Error creating events: %v", err) - } - - tun.handle, err = tun.wt.Register(&tun.rings) + tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { - tun.Close() - return nil, fmt.Errorf("Error registering rings: %v", err) + tun.wt.Close() + close(tun.events) + return nil, fmt.Errorf("Error starting session: %w", err) } + tun.readWait = tun.session.ReadWaitEvent() return tun, nil } func (tun *NativeTun) Name() (string, error) { - return tun.wt.Name() + return tun.name, nil } func (tun *NativeTun) File() *os.File { return nil } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { - tun.close = true - if tun.rings.Send.TailMoved != 0 { - windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping - } - if tun.handle != windows.InvalidHandle { - windows.CloseHandle(tun.handle) - } - tun.rings.Close() var err error - if tun.wt != nil { - _, err = tun.wt.DeleteInterface() - } - close(tun.events) + tun.closeOnce.Do(func() { + tun.close.Store(true) + windows.SetEvent(tun.readWait) + tun.running.Wait() + tun.session.End() + if tun.wt != nil { + tun.wt.Close() + } + close(tun.events) + }) return err } @@ -142,129 +127,115 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } + update := tun.forcedMTU != mtu tun.forcedMTU = mtu + if update { + tun.events <- EventMTUUpdate + } +} + +func (tun *NativeTun) BatchSize() int { + // TODO: implement batching with wintun + return 1 } // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { + tun.running.Add(1) + defer tun.running.Done() retry: - select { - case err := <-tun.errors: - return 0, err - default: - } - if tun.close { - return 0, os.ErrClosed - } - - buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head) - if buffHead >= wintun.PacketCapacity { + if tun.close.Load() { return 0, os.ErrClosed } - start := nanotime() - shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 - var buffTail uint32 + shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { - buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail) - if buffHead != buffTail { - break - } - if tun.close { + if tun.close.Load() { return 0, os.ErrClosed } - if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { - windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE) - goto retry + packet, err := tun.session.ReceivePacket() + switch err { + case nil: + n := copy(bufs[0][offset:], packet) + sizes[0] = n + tun.session.ReleaseReceivePacket(packet) + tun.rate.update(uint64(n)) + return 1, nil + case windows.ERROR_NO_MORE_ITEMS: + if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { + windows.WaitForSingleObject(tun.readWait, windows.INFINITE) + goto retry + } + procyield(1) + continue + case windows.ERROR_HANDLE_EOF: + return 0, os.ErrClosed + case windows.ERROR_INVALID_DATA: + return 0, errors.New("Send ring corrupt") } - procyield(1) - } - if buffTail >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead) - if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) { - return 0, errors.New("incomplete packet header in send ring") - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead])) - if packet.Size > wintun.PacketSizeMax { - return 0, errors.New("packet too big in send ring") - } - - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size) - if alignedPacketSize > buffContent { - return 0, errors.New("incomplete packet in send ring") + return 0, fmt.Errorf("Read failed: %w", err) } - - copy(buff[offset:], packet.Data[:packet.Size]) - buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize) - atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead) - tun.rate.update(uint64(packet.Size)) - return int(packet.Size), nil -} - -func (tun *NativeTun) Flush() error { - return nil } -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { - if tun.close { - return 0, os.ErrClosed - } - - packetSize := uint32(len(buff) - offset) - tun.rate.update(uint64(packetSize)) - alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) - - buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) - if buffHead >= wintun.PacketCapacity { - return 0, os.ErrClosed - } - - buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail) - if buffTail >= wintun.PacketCapacity { +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { + tun.running.Add(1) + defer tun.running.Done() + if tun.close.Load() { return 0, os.ErrClosed } - buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment) - if alignedPacketSize > buffSpace { - return 0, nil // Dropping when ring is full. - } - - packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail])) - packet.Size = packetSize - copy(packet.Data[:packetSize], buff[offset:]) - atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize)) - if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 { - windows.SetEvent(tun.rings.Receive.TailMoved) + for i, buf := range bufs { + packetSize := len(buf) - offset + tun.rate.update(uint64(packetSize)) + + packet, err := tun.session.AllocateSendPacket(packetSize) + switch err { + case nil: + // TODO: Explore options to eliminate this copy. + copy(packet, buf[offset:]) + tun.session.SendPacket(packet) + continue + case windows.ERROR_HANDLE_EOF: + return i, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + continue // Dropping when ring is full. + default: + return i, fmt.Errorf("Write failed: %w", err) + } } - return int(packetSize), nil + return len(bufs), nil } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { + tun.running.Add(1) + defer tun.running.Done() + if tun.close.Load() { + return 0 + } return tun.wt.LUID() } -// Version returns the version of the Wintun driver and NDIS system currently loaded. -func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) { - return tun.wt.Version() +// RunningVersion returns the running version of the Wintun driver. +func (tun *NativeTun) RunningVersion() (version uint32, err error) { + return wintun.RunningVersion() } func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() - total := atomic.AddUint64(&rate.nextByteCount, packetLen) - period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + total := rate.nextByteCount.Add(packetLen) + period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { - if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + if !rate.changing.CompareAndSwap(false, true) { return } - atomic.StoreInt64(&rate.nextStartTime, now) - atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) - atomic.StoreUint64(&rate.nextByteCount, 0) - atomic.StoreInt32(&rate.changing, 0) + rate.nextStartTime.Store(now) + rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) + rate.nextByteCount.Store(0) + rate.changing.Store(false) } } |