diff options
-rw-r--r-- | src/device.c | 23 | ||||
-rw-r--r-- | src/device.h | 3 | ||||
-rw-r--r-- | src/peer.c | 11 | ||||
-rw-r--r-- | src/peer.h | 2 | ||||
-rw-r--r-- | src/queueing.h | 4 | ||||
-rw-r--r-- | src/receive.c | 37 | ||||
-rw-r--r-- | src/send.c | 6 |
7 files changed, 48 insertions, 38 deletions
diff --git a/src/device.c b/src/device.c index 54a94fb..9af2533 100644 --- a/src/device.c +++ b/src/device.c @@ -227,9 +227,13 @@ static void destruct(struct net_device *dev) destroy_workqueue(wg->handshake_receive_wq); destroy_workqueue(wg->handshake_send_wq); destroy_workqueue(wg->packet_crypt_wq); + napi_disable(&wg->napi); + packet_queue_free(&wg->rx_queue, false); + packet_queue_free(&wg->tx_queue, false); packet_queue_free(&wg->decrypt_queue, true); packet_queue_free(&wg->encrypt_queue, true); rcu_barrier_bh(); /* Wait for all the peers to be actually freed. */ + netif_napi_del(&wg->napi); ratelimiter_uninit(); memzero_explicit(&wg->static_identity, sizeof(struct noise_static_identity)); skb_queue_purge(&wg->incoming_handshakes); @@ -322,13 +326,21 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t if (packet_queue_init(&wg->decrypt_queue, packet_decrypt_worker, true, MAX_QUEUED_PACKETS) < 0) goto error_7; + if (packet_queue_init(&wg->tx_queue, packet_tx_worker, false, MAX_QUEUED_PACKETS) < 0) + goto error_8; + + if (packet_queue_init(&wg->rx_queue, NULL, false, MAX_QUEUED_PACKETS) < 0) + goto error_9; + ret = ratelimiter_init(); if (ret < 0) - goto error_8; + goto error_10; + netif_napi_add(dev, &wg->napi, packet_rx_poll, NAPI_POLL_WEIGHT); + napi_enable(&wg->napi); ret = register_netdevice(dev); if (ret < 0) - goto error_9; + goto error_11; list_add(&wg->device_list, &device_list); @@ -340,8 +352,13 @@ static int newlink(struct net *src_net, struct net_device *dev, struct nlattr *t pr_debug("%s: Interface created\n", dev->name); return ret; -error_9: +error_11: + netif_napi_del(&wg->napi); ratelimiter_uninit(); +error_10: + packet_queue_free(&wg->rx_queue, false); +error_9: + packet_queue_free(&wg->tx_queue, false); error_8: packet_queue_free(&wg->decrypt_queue, true); error_7: diff --git a/src/device.h b/src/device.h index 2a0e2c7..45b742f 100644 --- a/src/device.h +++ b/src/device.h @@ -38,7 +38,8 @@ struct crypt_queue { struct wireguard_device { struct net_device *dev; - struct crypt_queue encrypt_queue, decrypt_queue; + struct crypt_queue encrypt_queue, decrypt_queue, tx_queue, rx_queue; + struct napi_struct napi; struct sock __rcu *sock4, *sock6; struct net *creating_net; struct noise_static_identity static_identity; @@ -50,14 +50,10 @@ struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_ INIT_WORK(&peer->transmit_handshake_work, packet_handshake_send_worker); rwlock_init(&peer->endpoint_lock); kref_init(&peer->refcount); - packet_queue_init(&peer->tx_queue, packet_tx_worker, false, MAX_QUEUED_PACKETS); - packet_queue_init(&peer->rx_queue, NULL, false, MAX_QUEUED_PACKETS); skb_queue_head_init(&peer->staged_packet_queue); list_add_tail(&peer->peer_list, &wg->peer_list); pubkey_hashtable_add(&wg->peer_hashtable, peer); peer->last_sent_handshake = ktime_get_boot_fast_ns() - (u64)(REKEY_TIMEOUT + 1) * NSEC_PER_SEC; - netif_napi_add(wg->dev, &peer->napi, packet_rx_poll, NAPI_POLL_WEIGHT); - napi_enable(&peer->napi); pr_debug("%s: Peer %llu created\n", wg->dev->name, peer->internal_id); return peer; } @@ -94,11 +90,6 @@ void peer_remove(struct wireguard_peer *peer) noise_keypairs_clear(&peer->keypairs); list_del_init(&peer->peer_list); timers_stop(peer); - flush_workqueue(peer->device->packet_crypt_wq); /* The first flush is for encrypt/decrypt. */ - flush_workqueue(peer->device->packet_crypt_wq); /* The second.1 flush is for send (but not receive, since that's napi). */ - napi_disable(&peer->napi); /* The second.2 flush is for receive (but not send, since that's wq). */ - flush_workqueue(peer->device->handshake_send_wq); - netif_napi_del(&peer->napi); --peer->device->num_peers; peer_put(peer); } @@ -109,8 +100,6 @@ static void rcu_release(struct rcu_head *rcu) pr_debug("%s: Peer %llu (%pISpfsc) destroyed\n", peer->device->dev->name, peer->internal_id, &peer->endpoint.addr); dst_cache_destroy(&peer->endpoint_cache); - packet_queue_free(&peer->rx_queue, false); - packet_queue_free(&peer->tx_queue, false); kzfree(peer); } @@ -35,7 +35,6 @@ struct endpoint { struct wireguard_peer { struct wireguard_device *device; - struct crypt_queue tx_queue, rx_queue; struct sk_buff_head staged_packet_queue; int serial_work_cpu; struct noise_keypairs keypairs; @@ -57,7 +56,6 @@ struct wireguard_peer { struct rcu_head rcu; struct list_head peer_list; u64 internal_id; - struct napi_struct napi; }; struct wireguard_peer *peer_create(struct wireguard_device *wg, const u8 public_key[NOISE_PUBLIC_KEY_LEN], const u8 preshared_key[NOISE_SYMMETRIC_KEY_LEN]); diff --git a/src/queueing.h b/src/queueing.h index c17b0d8..2869675 100644 --- a/src/queueing.h +++ b/src/queueing.h @@ -139,12 +139,12 @@ static inline void queue_enqueue_per_peer(struct crypt_queue *queue, struct sk_b peer_put(peer); } -static inline void queue_enqueue_per_peer_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) +static inline void queue_enqueue_per_device_napi(struct crypt_queue *queue, struct sk_buff *skb, enum packet_state state) { struct wireguard_peer *peer = peer_rcu_get(PACKET_PEER(skb)); atomic_set(&PACKET_CB(skb)->state, state); - napi_schedule(&peer->napi); + napi_schedule(&peer->device->napi); peer_put(peer); } diff --git a/src/receive.c b/src/receive.c index 144064e..64ea92e 100644 --- a/src/receive.c +++ b/src/receive.c @@ -342,7 +342,7 @@ static void packet_consume_data_done(struct sk_buff *skb, struct endpoint *endpo if (unlikely(routed_peer != peer)) goto dishonest_packet_peer; - if (unlikely(napi_gro_receive(&peer->napi, skb) == NET_RX_DROP)) { + if (unlikely(napi_gro_receive(&peer->device->napi, skb) == NET_RX_DROP)) { ++dev->stats.rx_dropped; net_dbg_ratelimited("%s: Failed to give packet to userspace from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); } else @@ -370,9 +370,10 @@ packet_processed: int packet_rx_poll(struct napi_struct *napi, int budget) { - struct wireguard_peer *peer = container_of(napi, struct wireguard_peer, napi); - struct crypt_queue *queue = &peer->rx_queue; + struct wireguard_device *wg = container_of(napi, struct wireguard_device, napi); + struct crypt_queue *queue = &wg->rx_queue; struct noise_keypair *keypair; + struct wireguard_peer *peer; struct sk_buff *skb; struct endpoint endpoint; enum packet_state state; @@ -392,7 +393,7 @@ int packet_rx_poll(struct napi_struct *napi, int budget) goto next; if (unlikely(!counter_validate(&keypair->receiving.counter, PACKET_CB(skb)->nonce))) { - net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", peer->device->dev->name, PACKET_CB(skb)->nonce, keypair->receiving.counter.receive.counter); + net_dbg_ratelimited("%s: Packet has invalid nonce %llu (max %llu)\n", wg->dev->name, PACKET_CB(skb)->nonce, keypair->receiving.counter.receive.counter); goto next; } @@ -427,7 +428,7 @@ void packet_decrypt_worker(struct work_struct *work) while ((skb = ptr_ring_consume_bh(&queue->ring)) != NULL) { enum packet_state state = likely(skb_decrypt(skb, &PACKET_CB(skb)->keypair->receiving, have_simd)) ? PACKET_STATE_CRYPTED : PACKET_STATE_DEAD; - queue_enqueue_per_peer_napi(&PACKET_PEER(skb)->rx_queue, skb, state); + queue_enqueue_per_device_napi(&PACKET_PEER(skb)->device->rx_queue, skb, state); have_simd = simd_relax(have_simd); } @@ -443,25 +444,29 @@ static void packet_consume_data(struct wireguard_device *wg, struct sk_buff *skb rcu_read_lock_bh(); PACKET_CB(skb)->keypair = noise_keypair_get((struct noise_keypair *)index_hashtable_lookup(&wg->index_hashtable, INDEX_HASHTABLE_KEYPAIR, idx)); rcu_read_unlock_bh(); - if (unlikely(!PACKET_CB(skb)->keypair)) { - dev_kfree_skb(skb); - return; - } + if (unlikely(!PACKET_CB(skb)->keypair)) + goto err_keypair; /* The call to index_hashtable_lookup gives us a reference to its underlying peer, so we don't need to call peer_rcu_get(). */ peer = PACKET_PEER(skb); + /* If elsewhere has called peer_remove, we should not queue up more packets, so that eventually the reference count goes to zero. */ + if (unlikely(list_empty(&peer->peer_list))) + goto err_peer; - ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &peer->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); + ret = queue_enqueue_per_device_and_peer(&wg->decrypt_queue, &wg->rx_queue, skb, wg->packet_crypt_wq, &wg->decrypt_queue.last_cpu); if (likely(!ret)) return; /* Successful. No need to drop references below. */ - if (ret == -EPIPE) - queue_enqueue_per_peer(&peer->rx_queue, skb, PACKET_STATE_DEAD); - else { - peer_put(peer); - noise_keypair_put(PACKET_CB(skb)->keypair); - dev_kfree_skb(skb); + if (ret == -EPIPE) { + queue_enqueue_per_peer(&wg->rx_queue, skb, PACKET_STATE_DEAD); + return; } + +err_peer: + peer_put(peer); + noise_keypair_put(PACKET_CB(skb)->keypair); +err_keypair: + dev_kfree_skb(skb); } void packet_receive(struct wireguard_device *wg, struct sk_buff *skb) @@ -255,7 +255,7 @@ void packet_encrypt_worker(struct work_struct *work) break; } } - queue_enqueue_per_peer(&PACKET_PEER(first)->tx_queue, first, state); + queue_enqueue_per_peer(&PACKET_PEER(first)->device->tx_queue, first, state); have_simd = simd_relax(have_simd); } @@ -268,12 +268,12 @@ static void packet_create_data(struct sk_buff *first) struct wireguard_device *wg = peer->device; int ret; - ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &peer->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); + ret = queue_enqueue_per_device_and_peer(&wg->encrypt_queue, &wg->tx_queue, first, wg->packet_crypt_wq, &wg->encrypt_queue.last_cpu); if (likely(!ret)) return; /* Successful. No need to fall through to drop references below. */ if (ret == -EPIPE) - queue_enqueue_per_peer(&peer->tx_queue, first, PACKET_STATE_DEAD); + queue_enqueue_per_peer(&wg->tx_queue, first, PACKET_STATE_DEAD); else { peer_put(peer); noise_keypair_put(PACKET_CB(first)->keypair); |