From 4e47e4a46ce9cc06a3496113f331b46ecffbc255 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Thu, 9 Sep 2021 09:06:59 +0000 Subject: driver: socket: remember to copy cmsghack when copying endpoint Otherwise, we can't reply to incoming endpoints. Reported-by: Peter Whisker Signed-off-by: Jason A. Donenfeld --- driver/peer.h | 13 ++++++++++--- driver/socket.c | 39 ++++++++++++++++----------------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/driver/peer.h b/driver/peer.h index 1cf1b9a..d5d14d7 100644 --- a/driver/peer.h +++ b/driver/peer.h @@ -20,10 +20,17 @@ typedef struct _ENDPOINT WSACMSGHDR Cmsg; union { - IN_PKTINFO Src4; - IN6_PKTINFO Src6; + struct + { + IN_PKTINFO Src4; + WSACMSGHDR CmsgHack4; + }; + struct + { + IN6_PKTINFO Src6; + WSACMSGHDR CmsgHack6; + }; }; - UCHAR CmsgHackBuf[WSA_CMSGHDR_ALIGN(sizeof(WSACMSGHDR))]; }; UINT32 RoutingGeneration; UINT32 UpdateGeneration; diff --git a/driver/socket.c b/driver/socket.c index b7f42ac..e3a021c 100644 --- a/driver/socket.c +++ b/driver/socket.c @@ -37,7 +37,7 @@ static BOOLEAN WskHasIpv4Transport, WskHasIpv6Transport; static NTSTATUS WskInitStatus = STATUS_RETRY; static EX_PUSH_LOCK WskIsIniting; static LOOKASIDE_ALIGN LOOKASIDE_LIST_EX SocketSendCtxCache; -static ULONG CmsgHackAdditionalLength = WSA_CMSG_LEN(0); +static ULONG CmsgHackAdditionalLength = WSA_CMSG_SPACE(0); #define NET_BUFFER_WSK_BUF(Nb) ((WSK_BUF_LIST *)&NET_BUFFER_MINIPORT_RESERVED(Nb)[0]) static_assert( @@ -255,14 +255,11 @@ retryWhileHoldingSharedLock: Peer->Endpoint.Cmsg.cmsg_len = WSA_CMSG_LEN(sizeof(Peer->Endpoint.Src4)); Peer->Endpoint.Cmsg.cmsg_level = IPPROTO_IP; Peer->Endpoint.Cmsg.cmsg_type = IP_PKTINFO; - WSACMSGHDR *CmsgHack = - (WSACMSGHDR - *)((UCHAR *)&Peer->Endpoint.Cmsg + WSA_CMSGHDR_ALIGN(WSA_CMSG_LEN(sizeof(Peer->Endpoint.Src4)))); - CmsgHack->cmsg_len = WSA_CMSG_LEN(0); - CmsgHack->cmsg_level = IPPROTO_IP; - CmsgHack->cmsg_type = IP_WFP_REDIRECT_RECORDS; Peer->Endpoint.Src4.ipi_addr = SrcAddr.Ipv4.sin_addr; Peer->Endpoint.Src4.ipi_ifindex = BestIndex; + Peer->Endpoint.CmsgHack4.cmsg_len = WSA_CMSG_LEN(0); + Peer->Endpoint.CmsgHack4.cmsg_level = IPPROTO_IP; + Peer->Endpoint.CmsgHack4.cmsg_type = IP_WFP_REDIRECT_RECORDS; Peer->Endpoint.RoutingGeneration = ReadNoFence(&RoutingGenerationV4); } else if (Peer->Endpoint.Addr.si_family == AF_INET6) @@ -270,14 +267,11 @@ retryWhileHoldingSharedLock: Peer->Endpoint.Cmsg.cmsg_len = WSA_CMSG_LEN(sizeof(Peer->Endpoint.Src6)); Peer->Endpoint.Cmsg.cmsg_level = IPPROTO_IPV6; Peer->Endpoint.Cmsg.cmsg_type = IPV6_PKTINFO; - WSACMSGHDR *CmsgHack = - (WSACMSGHDR - *)((UCHAR *)&Peer->Endpoint.Cmsg + WSA_CMSGHDR_ALIGN(WSA_CMSG_LEN(sizeof(Peer->Endpoint.Src6)))); - CmsgHack->cmsg_len = WSA_CMSG_LEN(0); - CmsgHack->cmsg_level = IPPROTO_IPV6; - CmsgHack->cmsg_type = IPV6_WFP_REDIRECT_RECORDS; Peer->Endpoint.Src6.ipi6_addr = SrcAddr.Ipv6.sin6_addr; Peer->Endpoint.Src6.ipi6_ifindex = BestIndex; + Peer->Endpoint.CmsgHack6.cmsg_len = WSA_CMSG_LEN(0); + Peer->Endpoint.CmsgHack6.cmsg_level = IPPROTO_IPV6; + Peer->Endpoint.CmsgHack6.cmsg_type = IPV6_WFP_REDIRECT_RECORDS; Peer->Endpoint.RoutingGeneration = ReadNoFence(&RoutingGenerationV6); } ++Peer->Endpoint.UpdateGeneration, ++UpdateGeneration; @@ -512,6 +506,7 @@ static_assert( static_assert( FIELD_OFFSET(ENDPOINT, Cmsg) + WSA_CMSG_SPACE(RTL_FIELD_SIZE(ENDPOINT, Src6)) <= sizeof(ENDPOINT), "cmsg calculation mismatch"); +static_assert(WSA_CMSG_SPACE(0) == sizeof(WSACMSGHDR), "cmsg calculation mismatch"); _Post_maybenull_ static VOID * @@ -544,12 +539,10 @@ SocketEndpointFromNbl(ENDPOINT *Endpoint, CONST NET_BUFFER_LIST *Nbl) Endpoint->Cmsg.cmsg_len = WSA_CMSG_LEN(sizeof(Endpoint->Src4)); Endpoint->Cmsg.cmsg_level = IPPROTO_IP; Endpoint->Cmsg.cmsg_type = IP_PKTINFO; - WSACMSGHDR *CmsgHack = - (WSACMSGHDR *)((UCHAR *)&Endpoint->Cmsg + WSA_CMSGHDR_ALIGN(WSA_CMSG_LEN(sizeof(Endpoint->Src4)))); - CmsgHack->cmsg_len = WSA_CMSG_LEN(0); - CmsgHack->cmsg_level = IPPROTO_IP; - CmsgHack->cmsg_type = IP_WFP_REDIRECT_RECORDS; Endpoint->Src4 = *(IN_PKTINFO *)Pktinfo; + Endpoint->CmsgHack4.cmsg_len = WSA_CMSG_LEN(0); + Endpoint->CmsgHack4.cmsg_level = IPPROTO_IP; + Endpoint->CmsgHack4.cmsg_type = IP_WFP_REDIRECT_RECORDS; Endpoint->RoutingGeneration = ReadNoFence(&RoutingGenerationV4); } else if (Addr->sa_family == AF_INET6 && (Pktinfo = FindInCmsgHdr(Data, IPPROTO_IPV6, IPV6_PKTINFO)) != NULL) @@ -558,12 +551,10 @@ SocketEndpointFromNbl(ENDPOINT *Endpoint, CONST NET_BUFFER_LIST *Nbl) Endpoint->Cmsg.cmsg_len = WSA_CMSG_LEN(sizeof(Endpoint->Src6)); Endpoint->Cmsg.cmsg_level = IPPROTO_IPV6; Endpoint->Cmsg.cmsg_type = IPV6_PKTINFO; - WSACMSGHDR *CmsgHack = - (WSACMSGHDR *)((UCHAR *)&Endpoint->Cmsg + WSA_CMSGHDR_ALIGN(WSA_CMSG_LEN(sizeof(Endpoint->Src6)))); - CmsgHack->cmsg_len = WSA_CMSG_LEN(0); - CmsgHack->cmsg_level = IPPROTO_IPV6; - CmsgHack->cmsg_type = IPV6_WFP_REDIRECT_RECORDS; Endpoint->Src6 = *(IN6_PKTINFO *)Pktinfo; + Endpoint->CmsgHack6.cmsg_len = WSA_CMSG_LEN(0); + Endpoint->CmsgHack6.cmsg_level = IPPROTO_IPV6; + Endpoint->CmsgHack6.cmsg_type = IPV6_WFP_REDIRECT_RECORDS; Endpoint->RoutingGeneration = ReadNoFence(&RoutingGenerationV6); } else @@ -614,6 +605,7 @@ SocketSetPeerEndpoint(WG_PEER *Peer, CONST ENDPOINT *Endpoint) { Peer->Endpoint.Cmsg = Endpoint->Cmsg; Peer->Endpoint.Src4 = Endpoint->Src4; + Peer->Endpoint.CmsgHack4 = Endpoint->CmsgHack4; } } else if (Endpoint->Addr.si_family == AF_INET6) @@ -623,6 +615,7 @@ SocketSetPeerEndpoint(WG_PEER *Peer, CONST ENDPOINT *Endpoint) { Peer->Endpoint.Cmsg = Endpoint->Cmsg; Peer->Endpoint.Src6 = Endpoint->Src6; + Peer->Endpoint.CmsgHack6 = Endpoint->CmsgHack6; } } else -- cgit v1.2.3-59-g8ed1b