diff options
Diffstat (limited to 'tun/tun_windows.go')
-rw-r--r-- | tun/tun_windows.go | 141 |
1 files changed, 64 insertions, 77 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go index ff16e2f..2af8e3e 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -8,7 +8,6 @@ package tun import ( "errors" "fmt" - "log" "os" "sync" "sync/atomic" @@ -16,8 +15,7 @@ import ( _ "unsafe" "golang.org/x/sys/windows" - - "golang.zx2c4.com/wireguard/tun/wintun" + "golang.zx2c4.com/wintun" ) const ( @@ -27,14 +25,15 @@ const ( ) type rateJuggler struct { - current uint64 - nextByteCount uint64 - nextStartTime int64 - changing int32 + current atomic.Uint64 + nextByteCount atomic.Uint64 + nextStartTime atomic.Int64 + changing atomic.Bool } type NativeTun struct { wt *wintun.Adapter + name string handle windows.Handle rate rateJuggler session wintun.Session @@ -42,12 +41,15 @@ type NativeTun struct { events chan Event running sync.WaitGroup closeOnce sync.Once - close int32 + close atomic.Bool forcedMTU int + outSizes []int } -var WintunPool, _ = wintun.MakePool("WireGuard") -var WintunStaticRequestedGUID *windows.GUID +var ( + WintunTunnelType = "WireGuard" + WintunStaticRequestedGUID *windows.GUID +) //go:linkname procyield runtime.procyield func procyield(cycles uint32) @@ -55,38 +57,19 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { - var err error - var wt *wintun.Adapter - - // Does an interface with this name already exist? - wt, err = WintunPool.OpenAdapter(ifname) - if err == nil { - // If so, we delete it, in case it has weird residual configuration. - _, err = wt.Delete(true) - if err != nil { - return nil, fmt.Errorf("Error deleting already existing interface: %w", err) - } - } - wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID) + wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { return nil, fmt.Errorf("Error creating interface: %w", err) } - if rebootRequired { - log.Println("Windows indicated a reboot is required.") - } forcedMTU := 1420 if mtu > 0 { @@ -95,6 +78,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun := &NativeTun{ wt: wt, + name: ifname, handle: windows.InvalidHandle, events: make(chan Event, 10), forcedMTU: forcedMTU, @@ -102,7 +86,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB if err != nil { - tun.wt.Delete(false) + tun.wt.Close() close(tun.events) return nil, fmt.Errorf("Error starting session: %w", err) } @@ -111,31 +95,26 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu } func (tun *NativeTun) Name() (string, error) { - tun.running.Add(1) - defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { - return "", os.ErrClosed - } - return tun.wt.Name() + return tun.name, nil } func (tun *NativeTun) File() *os.File { return nil } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { var err error tun.closeOnce.Do(func() { - atomic.StoreInt32(&tun.close, 1) + tun.close.Store(true) windows.SetEvent(tun.readWait) tun.running.Wait() tun.session.End() if tun.wt != nil { - _, err = tun.wt.Delete(false) + tun.wt.Close() } close(tun.events) }) @@ -148,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { @@ -155,29 +137,34 @@ func (tun *NativeTun) ForceMTU(mtu int) { } } +func (tun *NativeTun) BatchSize() int { + // TODO: implement batching with wintun + return 1 +} + // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() retry: - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } start := nanotime() - shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 + shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } packet, err := tun.session.ReceivePacket() switch err { case nil: - packetSize := len(packet) - copy(buff[offset:], packet) + n := copy(bufs[0][offset:], packet) + sizes[0] = n tun.session.ReleaseReceivePacket(packet) - tun.rate.update(uint64(packetSize)) - return packetSize, nil + tun.rate.update(uint64(n)) + return 1, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(tun.readWait, windows.INFINITE) @@ -194,40 +181,40 @@ retry: } } -func (tun *NativeTun) Flush() error { - return nil -} - -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } - packetSize := len(buff) - offset - tun.rate.update(uint64(packetSize)) + for i, buf := range bufs { + packetSize := len(buf) - offset + tun.rate.update(uint64(packetSize)) - packet, err := tun.session.AllocateSendPacket(packetSize) - if err == nil { - copy(packet, buff[offset:]) - tun.session.SendPacket(packet) - return packetSize, nil - } - switch err { - case windows.ERROR_HANDLE_EOF: - return 0, os.ErrClosed - case windows.ERROR_BUFFER_OVERFLOW: - return 0, nil // Dropping when ring is full. + packet, err := tun.session.AllocateSendPacket(packetSize) + switch err { + case nil: + // TODO: Explore options to eliminate this copy. + copy(packet, buf[offset:]) + tun.session.SendPacket(packet) + continue + case windows.ERROR_HANDLE_EOF: + return i, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + continue // Dropping when ring is full. + default: + return i, fmt.Errorf("Write failed: %w", err) + } } - return 0, fmt.Errorf("Write failed: %w", err) + return len(bufs), nil } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { tun.running.Add(1) defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0 } return tun.wt.LUID() @@ -240,15 +227,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) { func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() - total := atomic.AddUint64(&rate.nextByteCount, packetLen) - period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + total := rate.nextByteCount.Add(packetLen) + period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { - if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + if !rate.changing.CompareAndSwap(false, true) { return } - atomic.StoreInt64(&rate.nextStartTime, now) - atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) - atomic.StoreUint64(&rate.nextByteCount, 0) - atomic.StoreInt32(&rate.changing, 0) + rate.nextStartTime.Store(now) + rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) + rate.nextByteCount.Store(0) + rate.changing.Store(false) } } |