diff options
Diffstat (limited to 'conn/bind_windows.go')
-rw-r--r-- | conn/bind_windows.go | 152 |
1 files changed, 76 insertions, 76 deletions
diff --git a/conn/bind_windows.go b/conn/bind_windows.go index d744987..d5095e0 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -1,6 +1,6 @@ /* SPDX-License-Identifier: MIT * - * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved. + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ package conn @@ -9,6 +9,7 @@ import ( "encoding/binary" "io" "net" + "net/netip" "strconv" "sync" "sync/atomic" @@ -73,7 +74,7 @@ type afWinRingBind struct { type WinRingBind struct { v4, v6 afWinRingBind mu sync.RWMutex - isOpen uint32 + isOpen atomic.Uint32 // 0, 1, or 2 } func NewDefaultBind() Bind { return NewWinRingBind() } @@ -90,8 +91,10 @@ type WinRingEndpoint struct { data [30]byte } -var _ Bind = (*WinRingBind)(nil) -var _ Endpoint = (*WinRingEndpoint)(nil) +var ( + _ Bind = (*WinRingBind)(nil) + _ Endpoint = (*WinRingEndpoint)(nil) +) func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { host, port, err := net.SplitHostPort(s) @@ -121,27 +124,25 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) { if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) { return nil, windows.ERROR_INVALID_ADDRESS } - var src []byte var dst [unsafe.Sizeof(WinRingEndpoint{})]byte - unsafeSlice(unsafe.Pointer(&src), unsafe.Pointer(addrinfo.Addr), int(addrinfo.Addrlen)) - copy(dst[:], src) + copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen)) return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil } func (*WinRingEndpoint) ClearSrc() {} -func (e *WinRingEndpoint) DstIP() net.IP { +func (e *WinRingEndpoint) DstIP() netip.Addr { switch e.family { case windows.AF_INET: - return append([]byte{}, e.data[2:6]...) + return netip.AddrFrom4(*(*[4]byte)(e.data[2:6])) case windows.AF_INET6: - return append([]byte{}, e.data[6:22]...) + return netip.AddrFrom16(*(*[16]byte)(e.data[6:22])) } - return nil + return netip.Addr{} } -func (e *WinRingEndpoint) SrcIP() net.IP { - return nil // not supported +func (e *WinRingEndpoint) SrcIP() netip.Addr { + return netip.Addr{} // not supported } func (e *WinRingEndpoint) DstToBytes() []byte { @@ -163,15 +164,13 @@ func (e *WinRingEndpoint) DstToBytes() []byte { func (e *WinRingEndpoint) DstToString() string { switch e.family { case windows.AF_INET: - addr := net.UDPAddr{IP: e.data[2:6], Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String() case windows.AF_INET6: var zone string if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 { zone = strconv.FormatUint(uint64(scope), 10) } - addr := net.UDPAddr{IP: e.data[6:22], Zone: zone, Port: int(binary.BigEndian.Uint16(e.data[0:2]))} - return addr.String() + return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String() } return "" } @@ -213,7 +212,7 @@ func (bind *afWinRingBind) CloseAndZero() { } func (bind *WinRingBind) closeAndZero() { - atomic.StoreUint32(&bind.isOpen, 0) + bind.isOpen.Store(0) bind.v4.CloseAndZero() bind.v6.CloseAndZero() } @@ -277,7 +276,7 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort bind.closeAndZero() } }() - if atomic.LoadUint32(&bind.isOpen) != 0 { + if bind.isOpen.Load() != 0 { return nil, 0, ErrBindAlreadyOpen } var sa windows.Sockaddr @@ -300,17 +299,17 @@ func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort return nil, 0, err } } - atomic.StoreUint32(&bind.isOpen, 1) + bind.isOpen.Store(1) return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err } func (bind *WinRingBind) Close() error { bind.mu.RLock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { bind.mu.RUnlock() return nil } - atomic.StoreUint32(&bind.isOpen, 2) + bind.isOpen.Store(2) windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil) @@ -322,6 +321,13 @@ func (bind *WinRingBind) Close() error { return nil } +// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and +// rename the IdealBatchSize constant to BatchSize. +func (bind *WinRingBind) BatchSize() int { + // TODO: implement batching in and out of the ring + return 1 +} + func (bind *WinRingBind) SetMark(mark uint32) error { return nil } @@ -346,8 +352,8 @@ func (bind *afWinRingBind) InsertReceiveRequest() error { //go:linkname procyield runtime.procyield func procyield(cycles uint32) -func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, error) { - if atomic.LoadUint32(isOpen) != 1 { +func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } bind.rx.mu.Lock() @@ -360,7 +366,7 @@ retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { if tries > 0 { - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } procyield(1) @@ -379,13 +385,12 @@ retry: if err != nil { return 0, nil, err } - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } count = winrio.DequeueCompletion(bind.rx.cq, results[:]) if count == 0 { return 0, nil, io.ErrNoProgress - } } bind.rx.Return(1) @@ -397,7 +402,7 @@ retry: // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to // attacker bandwidth, just like the rest of the receive path. if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE { - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return 0, nil, net.ErrClosed } goto retry @@ -411,20 +416,26 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v4.Receive(buf, &bind.isOpen) + n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } -func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) { +func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - return bind.v6.Receive(buf, &bind.isOpen) + n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) + sizes[0] = n + eps[0] = ep + return 1, err } -func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint32) error { - if atomic.LoadUint32(isOpen) != 1 { +func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error { + if isOpen.Load() != 1 { return net.ErrClosed } if len(buf) > bytesPerPacket { @@ -446,7 +457,7 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3 if err != nil { return err } - if atomic.LoadUint32(isOpen) != 1 { + if isOpen.Load() != 1 { return net.ErrClosed } count = winrio.DequeueCompletion(bind.tx.cq, results[:]) @@ -475,32 +486,38 @@ func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *uint3 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0) } -func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error { +func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error { nend, ok := endpoint.(*WinRingEndpoint) if !ok { return ErrWrongEndpointType } bind.mu.RLock() defer bind.mu.RUnlock() - switch nend.family { - case windows.AF_INET: - if bind.v4.blackhole { - return nil - } - return bind.v4.Send(buf, nend, &bind.isOpen) - case windows.AF_INET6: - if bind.v6.blackhole { - return nil + for _, buf := range bufs { + switch nend.family { + case windows.AF_INET: + if bind.v4.blackhole { + continue + } + if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil { + return err + } + case windows.AF_INET6: + if bind.v6.blackhole { + continue + } + if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil { + return err + } } - return bind.v6.Send(buf, nend, &bind.isOpen) } return nil } -func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv4.SyscallConn() +func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv4.SyscallConn() if err != nil { return err } @@ -513,14 +530,14 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole4 = blackhole + s.blackhole4 = blackhole return nil } -func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { - bind.mu.Lock() - defer bind.mu.Unlock() - sysconn, err := bind.ipv6.SyscallConn() +func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { + s.mu.Lock() + defer s.mu.Unlock() + sysconn, err := s.ipv6.SyscallConn() if err != nil { return err } @@ -533,13 +550,14 @@ func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole if err != nil { return err } - bind.blackhole6 = blackhole + s.blackhole6 = blackhole return nil } + func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface4(bind.v4.sock, interfaceIndex) @@ -553,7 +571,7 @@ func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { bind.mu.RLock() defer bind.mu.RUnlock() - if atomic.LoadUint32(&bind.isOpen) != 1 { + if bind.isOpen.Load() != 1 { return net.ErrClosed } err := bindSocketToInterface6(bind.v6.sock, interfaceIndex) @@ -581,21 +599,3 @@ func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error const IPV6_UNICAST_IF = 31 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex)) } - -// unsafeSlice updates the slice slicePtr to be a slice -// referencing the provided data with its length & capacity set to -// lenCap. -// -// TODO: when Go 1.16 or Go 1.17 is the minimum supported version, -// update callers to use unsafe.Slice instead of this. -func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) { - type sliceHeader struct { - Data unsafe.Pointer - Len int - Cap int - } - h := (*sliceHeader)(slicePtr) - h.Data = data - h.Len = lenCap - h.Cap = lenCap -} |