diff options
Diffstat (limited to '')
-rw-r--r-- | conn/bind_windows.go | 105 |
1 files changed, 62 insertions, 43 deletions
diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 9268bc1..d5095e0 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn @@ -74,7 +74,7 @@ type afWinRingBind struct { type WinRingBind struct { v4, v6 afWinRingBind mu sync.RWMutex - isOpen uint32 + isOpen atomic.Uint32 // 0, 1, or 2 } func NewDefaultBind() Bind { return NewWinRingBind() } @@ -164,7 +164,7 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { @@ -212,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() { } func (bind *WinRingBind) closeAndZero() { - atomic.StoreUint32(&bind.isOpen, 0) + bind.isOpen.Store(0) bind.v4.CloseAndZero() bind.v6.CloseAndZero() } @@ -276,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort bind.closeAndZero() } }() - if atomic.LoadUint32(&bind.isOpen) != 0 { + if bind.isOpen.Load() != 0 { return nil, 0, ErrBindAlreadyOpen } var sa windows.Sockaddr @@ -299,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort return nil, 0, err } } - atomic.StoreUint32(&bind.isOpen, 1) + bind.isOpen.Store(1) return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err } func (bind *WinRingBind) Close() error { bind.mu.RLock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { bind.mu.RUnlock() return nil } - atomic.StoreUint32(&bind.isOpen, 2) + bind.isOpen.Store(2) windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) @@ -321,6 +321,13 @@ func (bind *WinRingBind) Close() error { return nil } +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + func (bind *WinRingBind) SetMark(mark uint32) error { return nil } @@ -345,8 +352,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error { //go:linkname procyield runtime.procyield func procyield(cycles uint32) -func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { - if atomic.LoadUint32(isOpen) != 1 { +func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } bind.rx.mu.Lock() @@ -359,7 +366,7 @@ retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { if tries > 0 { - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } procyield(1) @@ -378,7 +385,7 @@ retry: if err != nil { return 0, nil, err } - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } count = winrio.DequeueCompletion(bind.rx.cq, results[:]) @@ -395,7 +402,7 @@ retry: // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to // attacker bandwidth, just like the rest of the receive path. if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } goto retry @@ -409,20 +416,26 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v4.Receive(buf, &bind.isOpen) + n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } -func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v6.Receive(buf, &bind.isOpen) + n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } -func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { - if atomic.LoadUint32(isOpen) != 1 { +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { + if isOpen.Load() != 1 { return net.ErrClosed } if len(buf) > bytesPerPacket { @@ -444,7 +457,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3 if err != nil { return err } - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return net.ErrClosed } count = winrio.DequeueCompletion(bind.tx.cq, results[:]) @@ -473,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType } bind.mu.RLock() defer bind.mu.RUnlock() - switch nend.family { - case windows.AF_INET: - if bind.v4.blackhole { - return nil - } - return bind.v4.Send(buf, nend, &bind.isOpen) - case windows.AF_INET6: - if bind.v6.blackhole { - return nil + for _, buf := range bufs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } } - return bind.v6.Send(buf, nend, &bind.isOpen) } return nil } -func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -511,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole4 = blackhole + s.blackhole4 = blackhole return nil } -func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() if err != nil { return err } @@ -531,14 +550,14 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole6 = blackhole + s.blackhole6 = blackhole return nil } func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) @@ -552,7 +571,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) |