aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/net/packet/af_packet.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/packet/af_packet.c')
-rw-r--r--net/packet/af_packet.c36
1 files changed, 24 insertions, 12 deletions
diff --git a/net/packet/af_packet.c b/net/packet/af_packet.c
index d288f52c53f7..2986941164b1 100644
--- a/net/packet/af_packet.c
+++ b/net/packet/af_packet.c
@@ -1769,7 +1769,7 @@ static int fanout_add(struct sock *sk, u16 id, u16 type_flags)
out:
if (err && rollover) {
- kfree(rollover);
+ kfree_rcu(rollover, rcu);
po->rollover = NULL;
}
mutex_unlock(&fanout_mutex);
@@ -1796,8 +1796,10 @@ static struct packet_fanout *fanout_release(struct sock *sk)
else
f = NULL;
- if (po->rollover)
+ if (po->rollover) {
kfree_rcu(po->rollover, rcu);
+ po->rollover = NULL;
+ }
}
mutex_unlock(&fanout_mutex);
@@ -2840,6 +2842,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
struct virtio_net_hdr vnet_hdr = { 0 };
int offset = 0;
struct packet_sock *po = pkt_sk(sk);
+ bool has_vnet_hdr = false;
int hlen, tlen, linear;
int extra_len = 0;
@@ -2883,6 +2886,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);
if (err)
goto out_unlock;
+ has_vnet_hdr = true;
}
if (unlikely(sock_flag(sk, SOCK_NOFCS))) {
@@ -2941,7 +2945,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
skb->priority = sk->sk_priority;
skb->mark = sockc.mark;
- if (po->has_vnet_hdr) {
+ if (has_vnet_hdr) {
err = virtio_net_hdr_to_skb(skb, &vnet_hdr, vio_le());
if (err)
goto out_free;
@@ -3069,13 +3073,15 @@ static int packet_do_bind(struct sock *sk, const char *name, int ifindex,
int ret = 0;
bool unlisted = false;
- if (po->fanout)
- return -EINVAL;
-
lock_sock(sk);
spin_lock(&po->bind_lock);
rcu_read_lock();
+ if (po->fanout) {
+ ret = -EINVAL;
+ goto out_unlock;
+ }
+
if (name) {
dev = dev_get_by_name_rcu(sock_net(sk), name);
if (!dev) {
@@ -3847,6 +3853,7 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
void *data = &val;
union tpacket_stats_u st;
struct tpacket_rollover_stats rstats;
+ struct packet_rollover *rollover;
if (level != SOL_PACKET)
return -ENOPROTOOPT;
@@ -3925,13 +3932,18 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
0);
break;
case PACKET_ROLLOVER_STATS:
- if (!po->rollover)
+ rcu_read_lock();
+ rollover = rcu_dereference(po->rollover);
+ if (rollover) {
+ rstats.tp_all = atomic_long_read(&rollover->num);
+ rstats.tp_huge = atomic_long_read(&rollover->num_huge);
+ rstats.tp_failed = atomic_long_read(&rollover->num_failed);
+ data = &rstats;
+ lv = sizeof(rstats);
+ }
+ rcu_read_unlock();
+ if (!rollover)
return -EINVAL;
- rstats.tp_all = atomic_long_read(&po->rollover->num);
- rstats.tp_huge = atomic_long_read(&po->rollover->num_huge);
- rstats.tp_failed = atomic_long_read(&po->rollover->num_failed);
- data = &rstats;
- lv = sizeof(rstats);
break;
case PACKET_TX_HAS_OFF:
val = po->tp_tx_has_off;