diff options
Diffstat (limited to '')
-rw-r--r-- | tun/tun_windows.go | 102 |
1 files changed, 53 insertions, 49 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go index d057150..2af8e3e 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package tun @@ -15,7 +15,6 @@ import ( _ "unsafe" "golang.org/x/sys/windows" - "golang.zx2c4.com/wintun" ) @@ -26,10 +25,10 @@ const ( ) type rateJuggler struct { - current uint64 - nextByteCount uint64 - nextStartTime int64 - changing int32 + current atomic.Uint64 + nextByteCount atomic.Uint64 + nextStartTime atomic.Int64 + changing atomic.Bool } type NativeTun struct { @@ -42,8 +41,9 @@ type NativeTun struct { events chan Event running sync.WaitGroup closeOnce sync.Once - close int32 + close atomic.Bool forcedMTU int + outSizes []int } var ( @@ -57,18 +57,14 @@ func procyield(cycles uint32) //go:linkname nanotime runtime.nanotime func nanotime() int64 -// // CreateTUN creates a Wintun interface with the given name. Should a Wintun // interface with the same name exist, it is reused. -// func CreateTUN(ifname string, mtu int) (Device, error) { return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu) } -// // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // a requested GUID. Should a Wintun interface with the same name exist, it is reused. -// func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) { wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID) if err != nil { @@ -106,14 +102,14 @@ func (tun *NativeTun) File() *os.File { return nil } -func (tun *NativeTun) Events() chan Event { +func (tun *NativeTun) Events() <-chan Event { return tun.events } func (tun *NativeTun) Close() error { var err error tun.closeOnce.Do(func() { - atomic.StoreInt32(&tun.close, 1) + tun.close.Store(true) windows.SetEvent(tun.readWait) tun.running.Wait() tun.session.End() @@ -131,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) { // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. func (tun *NativeTun) ForceMTU(mtu int) { + if tun.close.Load() { + return + } update := tun.forcedMTU != mtu tun.forcedMTU = mtu if update { @@ -138,29 +137,34 @@ func (tun *NativeTun) ForceMTU(mtu int) { } } +func (tun *NativeTun) BatchSize() int { + // TODO: implement batching with wintun + return 1 +} + // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() retry: - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } start := nanotime() - shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2 + shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2 for { - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } packet, err := tun.session.ReceivePacket() switch err { case nil: - packetSize := len(packet) - copy(buff[offset:], packet) + n := copy(bufs[0][offset:], packet) + sizes[0] = n tun.session.ReleaseReceivePacket(packet) - tun.rate.update(uint64(packetSize)) - return packetSize, nil + tun.rate.update(uint64(n)) + return 1, nil case windows.ERROR_NO_MORE_ITEMS: if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration { windows.WaitForSingleObject(tun.readWait, windows.INFINITE) @@ -177,40 +181,40 @@ retry: } } -func (tun *NativeTun) Flush() error { - return nil -} - -func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { +func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { tun.running.Add(1) defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0, os.ErrClosed } - packetSize := len(buff) - offset - tun.rate.update(uint64(packetSize)) + for i, buf := range bufs { + packetSize := len(buf) - offset + tun.rate.update(uint64(packetSize)) - packet, err := tun.session.AllocateSendPacket(packetSize) - if err == nil { - copy(packet, buff[offset:]) - tun.session.SendPacket(packet) - return packetSize, nil - } - switch err { - case windows.ERROR_HANDLE_EOF: - return 0, os.ErrClosed - case windows.ERROR_BUFFER_OVERFLOW: - return 0, nil // Dropping when ring is full. + packet, err := tun.session.AllocateSendPacket(packetSize) + switch err { + case nil: + // TODO: Explore options to eliminate this copy. + copy(packet, buf[offset:]) + tun.session.SendPacket(packet) + continue + case windows.ERROR_HANDLE_EOF: + return i, os.ErrClosed + case windows.ERROR_BUFFER_OVERFLOW: + continue // Dropping when ring is full. + default: + return i, fmt.Errorf("Write failed: %w", err) + } } - return 0, fmt.Errorf("Write failed: %w", err) + return len(bufs), nil } // LUID returns Windows interface instance ID. func (tun *NativeTun) LUID() uint64 { tun.running.Add(1) defer tun.running.Done() - if atomic.LoadInt32(&tun.close) == 1 { + if tun.close.Load() { return 0 } return tun.wt.LUID() @@ -223,15 +227,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) { func (rate *rateJuggler) update(packetLen uint64) { now := nanotime() - total := atomic.AddUint64(&rate.nextByteCount, packetLen) - period := uint64(now - atomic.LoadInt64(&rate.nextStartTime)) + total := rate.nextByteCount.Add(packetLen) + period := uint64(now - rate.nextStartTime.Load()) if period >= rateMeasurementGranularity { - if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) { + if !rate.changing.CompareAndSwap(false, true) { return } - atomic.StoreInt64(&rate.nextStartTime, now) - atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period) - atomic.StoreUint64(&rate.nextByteCount, 0) - atomic.StoreInt32(&rate.changing, 0) + rate.nextStartTime.Store(now) + rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period) + rate.nextByteCount.Store(0) + rate.changing.Store(false) } } |