aboutsummaryrefslogtreecommitdiffstats
path: root/tun/tun_windows.go
diff options
context:
space:
mode:
Diffstat (limited to 'tun/tun_windows.go')
-rw-r--r--tun/tun_windows.go141
1 files changed, 64 insertions, 77 deletions
diff --git a/tun/tun_windows.go b/tun/tun_windows.go
index ff16e2f..2af8e3e 100644
--- a/tun/tun_windows.go
+++ b/tun/tun_windows.go
@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: MIT
*
- * Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/
package tun
@@ -8,7 +8,6 @@ package tun
import (
"errors"
"fmt"
- "log"
"os"
"sync"
"sync/atomic"
@@ -16,8 +15,7 @@ import (
_ "unsafe"
"golang.org/x/sys/windows"
-
- "golang.zx2c4.com/wireguard/tun/wintun"
+ "golang.zx2c4.com/wintun"
)
const (
@@ -27,14 +25,15 @@ const (
)
type rateJuggler struct {
- current uint64
- nextByteCount uint64
- nextStartTime int64
- changing int32
+ current atomic.Uint64
+ nextByteCount atomic.Uint64
+ nextStartTime atomic.Int64
+ changing atomic.Bool
}
type NativeTun struct {
wt *wintun.Adapter
+ name string
handle windows.Handle
rate rateJuggler
session wintun.Session
@@ -42,12 +41,15 @@ type NativeTun struct {
events chan Event
running sync.WaitGroup
closeOnce sync.Once
- close int32
+ close atomic.Bool
forcedMTU int
+ outSizes []int
}
-var WintunPool, _ = wintun.MakePool("WireGuard")
-var WintunStaticRequestedGUID *windows.GUID
+var (
+ WintunTunnelType = "WireGuard"
+ WintunStaticRequestedGUID *windows.GUID
+)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
@@ -55,38 +57,19 @@ func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
-//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
-//
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
}
-//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
-//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
- var err error
- var wt *wintun.Adapter
-
- // Does an interface with this name already exist?
- wt, err = WintunPool.OpenAdapter(ifname)
- if err == nil {
- // If so, we delete it, in case it has weird residual configuration.
- _, err = wt.Delete(true)
- if err != nil {
- return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
- }
- }
- wt, rebootRequired, err := WintunPool.CreateAdapter(ifname, requestedGUID)
+ wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Error creating interface: %w", err)
}
- if rebootRequired {
- log.Println("Windows indicated a reboot is required.")
- }
forcedMTU := 1420
if mtu > 0 {
@@ -95,6 +78,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun := &NativeTun{
wt: wt,
+ name: ifname,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
forcedMTU: forcedMTU,
@@ -102,7 +86,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
if err != nil {
- tun.wt.Delete(false)
+ tun.wt.Close()
close(tun.events)
return nil, fmt.Errorf("Error starting session: %w", err)
}
@@ -111,31 +95,26 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
}
func (tun *NativeTun) Name() (string, error) {
- tun.running.Add(1)
- defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
- return "", os.ErrClosed
- }
- return tun.wt.Name()
+ return tun.name, nil
}
func (tun *NativeTun) File() *os.File {
return nil
}
-func (tun *NativeTun) Events() chan Event {
+func (tun *NativeTun) Events() <-chan Event {
return tun.events
}
func (tun *NativeTun) Close() error {
var err error
tun.closeOnce.Do(func() {
- atomic.StoreInt32(&tun.close, 1)
+ tun.close.Store(true)
windows.SetEvent(tun.readWait)
tun.running.Wait()
tun.session.End()
if tun.wt != nil {
- _, err = tun.wt.Delete(false)
+ tun.wt.Close()
}
close(tun.events)
})
@@ -148,6 +127,9 @@ func (tun *NativeTun) MTU() (int, error) {
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMTU(mtu int) {
+ if tun.close.Load() {
+ return
+ }
update := tun.forcedMTU != mtu
tun.forcedMTU = mtu
if update {
@@ -155,29 +137,34 @@ func (tun *NativeTun) ForceMTU(mtu int) {
}
}
+func (tun *NativeTun) BatchSize() int {
+ // TODO: implement batching with wintun
+ return 1
+}
+
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
-func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
retry:
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
start := nanotime()
- shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
+ shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
for {
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
packet, err := tun.session.ReceivePacket()
switch err {
case nil:
- packetSize := len(packet)
- copy(buff[offset:], packet)
+ n := copy(bufs[0][offset:], packet)
+ sizes[0] = n
tun.session.ReleaseReceivePacket(packet)
- tun.rate.update(uint64(packetSize))
- return packetSize, nil
+ tun.rate.update(uint64(n))
+ return 1, nil
case windows.ERROR_NO_MORE_ITEMS:
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
@@ -194,40 +181,40 @@ retry:
}
}
-func (tun *NativeTun) Flush() error {
- return nil
-}
-
-func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
+func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
tun.running.Add(1)
defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0, os.ErrClosed
}
- packetSize := len(buff) - offset
- tun.rate.update(uint64(packetSize))
+ for i, buf := range bufs {
+ packetSize := len(buf) - offset
+ tun.rate.update(uint64(packetSize))
- packet, err := tun.session.AllocateSendPacket(packetSize)
- if err == nil {
- copy(packet, buff[offset:])
- tun.session.SendPacket(packet)
- return packetSize, nil
- }
- switch err {
- case windows.ERROR_HANDLE_EOF:
- return 0, os.ErrClosed
- case windows.ERROR_BUFFER_OVERFLOW:
- return 0, nil // Dropping when ring is full.
+ packet, err := tun.session.AllocateSendPacket(packetSize)
+ switch err {
+ case nil:
+ // TODO: Explore options to eliminate this copy.
+ copy(packet, buf[offset:])
+ tun.session.SendPacket(packet)
+ continue
+ case windows.ERROR_HANDLE_EOF:
+ return i, os.ErrClosed
+ case windows.ERROR_BUFFER_OVERFLOW:
+ continue // Dropping when ring is full.
+ default:
+ return i, fmt.Errorf("Write failed: %w", err)
+ }
}
- return 0, fmt.Errorf("Write failed: %w", err)
+ return len(bufs), nil
}
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
tun.running.Add(1)
defer tun.running.Done()
- if atomic.LoadInt32(&tun.close) == 1 {
+ if tun.close.Load() {
return 0
}
return tun.wt.LUID()
@@ -240,15 +227,15 @@ func (tun *NativeTun) RunningVersion() (version uint32, err error) {
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
- total := atomic.AddUint64(&rate.nextByteCount, packetLen)
- period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
+ total := rate.nextByteCount.Add(packetLen)
+ period := uint64(now - rate.nextStartTime.Load())
if period >= rateMeasurementGranularity {
- if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
+ if !rate.changing.CompareAndSwap(false, true) {
return
}
- atomic.StoreInt64(&rate.nextStartTime, now)
- atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
- atomic.StoreUint64(&rate.nextByteCount, 0)
- atomic.StoreInt32(&rate.changing, 0)
+ rate.nextStartTime.Store(now)
+ rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
+ rate.nextByteCount.Store(0)
+ rate.changing.Store(false)
}
}