From d82ea502b971fdc75cd500973546a164d3302828 Mon Sep 17 00:00:00 2001 From: Julian Orth Date: Thu, 6 Sep 2018 20:48:03 +0200 Subject: device: store a copy of the device net This eliminates the need for have_transit_net_ref because have_transit_net_ref == true if and only if dev_net != transit_net. --- src/device.c | 27 +++++++++++++++++---------- src/device.h | 4 +++- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/device.c b/src/device.c index ae4b9ad..0c0c17b 100644 --- a/src/device.c +++ b/src/device.c @@ -251,7 +251,7 @@ static void wg_destruct(struct net_device *dev) skb_queue_purge(&wg->incoming_handshakes); free_percpu(dev->tstats); free_percpu(wg->incoming_handshakes_worker); - if (wg->have_transit_net_ref) + if (wg->transit_net != wg->dev_net) put_net(wg->transit_net); mutex_unlock(&wg->device_update_lock); @@ -304,7 +304,9 @@ static int wg_newlink(struct net *src_net, struct net_device *dev, struct wg_device *wg = netdev_priv(dev); int ret = -ENOMEM; - wg->transit_net = src_net; + wg->dev_net = NULL; + wg->transit_net = NULL; + wg_device_set_nets(wg, dev_net(dev), src_net); init_rwsem(&wg->static_identity.lock); mutex_init(&wg->socket_update_lock); mutex_init(&wg->device_update_lock); @@ -406,14 +408,8 @@ static int wg_netdevice_notification(struct notifier_block *nb, if (action != NETDEV_REGISTER || dev->netdev_ops != &netdev_ops) return 0; - if (dev_net(dev) == wg->transit_net && wg->have_transit_net_ref) { - put_net(wg->transit_net); - wg->have_transit_net_ref = false; - } else if (dev_net(dev) != wg->transit_net && - !wg->have_transit_net_ref) { - wg->have_transit_net_ref = true; - get_net(wg->transit_net); - } + wg_device_set_nets(wg, dev_net(dev), wg->transit_net); + return 0; } @@ -459,3 +455,14 @@ void wg_device_uninit(void) #endif rcu_barrier_bh(); } + +void wg_device_set_nets(struct wg_device *wg, struct net *dev_net, + struct net *transit_net) +{ + if (wg->transit_net != wg->dev_net) + put_net(wg->transit_net); + wg->dev_net = dev_net; + wg->transit_net = transit_net; + if (wg->transit_net != wg->dev_net) + get_net(wg->transit_net); +} diff --git a/src/device.h b/src/device.h index be54d4a..e1d4e84 100644 --- a/src/device.h +++ b/src/device.h @@ -41,6 +41,7 @@ struct wg_device { struct crypt_queue encrypt_queue, decrypt_queue; struct sock __rcu *sock4, *sock6; struct net *transit_net; + struct net *dev_net; struct noise_static_identity static_identity; struct workqueue_struct *handshake_receive_wq, *handshake_send_wq; struct workqueue_struct *packet_crypt_wq; @@ -56,10 +57,11 @@ struct wg_device { unsigned int num_peers, device_update_gen; u32 fwmark; u16 incoming_port; - bool have_transit_net_ref; }; int wg_device_init(void); void wg_device_uninit(void); +void wg_device_set_nets(struct wg_device *wg, struct net *dev_net, + struct net *transit_net); #endif /* _WG_DEVICE_H */ -- cgit v1.2.3-59-g8ed1b