about | summary | refs | log | tree | commit | diff | stats
path: root/net/mctp
diff options
context:
space:
mode:
Diffstat (limited to 'net/mctp')
-rw-r--r-- net/mctp/af_mctp.c 252
-rw-r--r-- net/mctp/device.c 36
-rw-r--r-- net/mctp/neigh.c 2
-rw-r--r-- net/mctp/route.c 180
-rw-r--r-- net/mctp/test/route-test.c 169
-rw-r--r-- net/mctp/test/utils.c 1
6 files changed, 521 insertions, 119 deletions
diff --git a/net/mctp/af_mctp.c b/net/mctp/af_mctp.c
index c921de63b494..fc9e728b6333 100644
--- a/net/mctp/af_mctp.c
+++ b/net/mctp/af_mctp.c
@@ -6,6 +6,7 @@
* Copyright (c) 2021 Google
*/
+#include <linux/compat.h>
#include <linux/if_arp.h>
#include <linux/net.h>
#include <linux/mctp.h>
@@ -21,6 +22,8 @@
/* socket implementation */
+static void mctp_sk_expire_keys(struct timer_list *timer);
+
static int mctp_release(struct socket *sock)
{
struct sock *sk = sock->sk;
@@ -90,22 +93,29 @@ out_release:
static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
{
DECLARE_SOCKADDR(struct sockaddr_mctp *, addr, msg->msg_name);
- const int hlen = MCTP_HEADER_MAXLEN + sizeof(struct mctp_hdr);
int rc, addrlen = msg->msg_namelen;
struct sock *sk = sock->sk;
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct mctp_skb_cb *cb;
struct mctp_route *rt;
- struct sk_buff *skb;
+ struct sk_buff *skb = NULL;
+ int hlen;
if (addr) {
+ const u8 tagbits = MCTP_TAG_MASK | MCTP_TAG_OWNER |
+ MCTP_TAG_PREALLOC;
+
if (addrlen < sizeof(struct sockaddr_mctp))
return -EINVAL;
if (addr->smctp_family != AF_MCTP)
return -EINVAL;
if (!mctp_sockaddr_is_ok(addr))
return -EINVAL;
- if (addr->smctp_tag & ~(MCTP_TAG_MASK | MCTP_TAG_OWNER))
+ if (addr->smctp_tag & ~tagbits)
+ return -EINVAL;
+ /* can't preallocate a non-owned tag */
+ if (addr->smctp_tag & MCTP_TAG_PREALLOC &&
+ !(addr->smctp_tag & MCTP_TAG_OWNER))
return -EINVAL;
} else {
@@ -119,6 +129,34 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
if (addr->smctp_network == MCTP_NET_ANY)
addr->smctp_network = mctp_default_net(sock_net(sk));
+ /* direct addressing */
+ if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+ DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
+ extaddr, msg->msg_name);
+ struct net_device *dev;
+
+ rc = -EINVAL;
+ rcu_read_lock();
+ dev = dev_get_by_index_rcu(sock_net(sk), extaddr->smctp_ifindex);
+ /* check for correct halen */
+ if (dev && extaddr->smctp_halen == dev->addr_len) {
+ hlen = LL_RESERVED_SPACE(dev) + sizeof(struct mctp_hdr);
+ rc = 0;
+ }
+ rcu_read_unlock();
+ if (rc)
+ goto err_free;
+ rt = NULL;
+ } else {
+ rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
+ addr->smctp_addr.s_addr);
+ if (!rt) {
+ rc = -EHOSTUNREACH;
+ goto err_free;
+ }
+ hlen = LL_RESERVED_SPACE(rt->dev->dev) + sizeof(struct mctp_hdr);
+ }
+
skb = sock_alloc_send_skb(sk, hlen + 1 + len,
msg->msg_flags & MSG_DONTWAIT, &rc);
if (!skb)
@@ -137,8 +175,8 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
cb = __mctp_cb(skb);
cb->net = addr->smctp_network;
- /* direct addressing */
- if (msk->addr_ext && addrlen >= sizeof(struct sockaddr_mctp_ext)) {
+ if (!rt) {
+ /* fill extended address in cb */
DECLARE_SOCKADDR(struct sockaddr_mctp_ext *,
extaddr, msg->msg_name);
@@ -149,17 +187,9 @@ static int mctp_sendmsg(struct socket *sock, struct msghdr *msg, size_t len)
}
cb->ifindex = extaddr->smctp_ifindex;
+ /* smctp_halen is checked above */
cb->halen = extaddr->smctp_halen;
memcpy(cb->haddr, extaddr->smctp_haddr, cb->halen);
-
- rt = NULL;
- } else {
- rt = mctp_route_lookup(sock_net(sk), addr->smctp_network,
- addr->smctp_addr.s_addr);
- if (!rt) {
- rc = -EHOSTUNREACH;
- goto err_free;
- }
}
rc = mctp_local_output(sk, rt, skb, addr->smctp_addr.s_addr,
@@ -186,7 +216,7 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (flags & ~(MSG_DONTWAIT | MSG_TRUNC | MSG_PEEK))
return -EOPNOTSUPP;
- skb = skb_recv_datagram(sk, flags, flags & MSG_DONTWAIT, &rc);
+ skb = skb_recv_datagram(sk, flags, &rc);
if (!skb)
return rc;
@@ -208,7 +238,7 @@ static int mctp_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
if (rc < 0)
goto out_free;
- sock_recv_ts_and_drops(msg, sk, skb);
+ sock_recv_cmsgs(msg, sk, skb);
if (addr) {
struct mctp_skb_cb *cb = mctp_cb(skb);
@@ -248,6 +278,33 @@ out_free:
return rc;
}
+/* We're done with the key; invalidate, stop reassembly, and remove from lists.
+ *
+ * Entered with key->lock held (the caller's saved irq state is passed in
+ * @flags); that lock is released here. The annotations below also require
+ * the caller to hold net->mctp.keys_lock, which covers the list removal.
+ * @reason is handed to the mctp_key_release tracepoint.
+ */
+static void __mctp_key_remove(struct mctp_sk_key *key, struct net *net,
+ unsigned long flags, unsigned long reason)
+__releases(&key->lock)
+__must_hold(&net->mctp.keys_lock)
+{
+ struct sk_buff *skb;
+
+ trace_mctp_key_release(key, reason);
+ /* detach any partial reassembly while key->lock is held; the skb is
+ * only freed at the end, once key->lock has been dropped
+ */
+ skb = key->reasm_head;
+ key->reasm_head = NULL;
+ key->reasm_dead = true;
+ key->valid = false;
+ mctp_dev_release_key(key->dev, key);
+ spin_unlock_irqrestore(&key->lock, flags);
+
+ /* hlist_del_init() leaves the node unhashed, so a key that has already
+ * been unlinked skips both the unlink and the list unref here
+ */
+ if (!hlist_unhashed(&key->hlist)) {
+ hlist_del_init(&key->hlist);
+ hlist_del_init(&key->sklist);
+ /* unref for the lists */
+ mctp_key_unref(key);
+ }
+
+ kfree_skb(skb);
+}
+
static int mctp_setsockopt(struct socket *sock, int level, int optname,
sockptr_t optval, unsigned int optlen)
{
@@ -293,6 +350,123 @@ static int mctp_getsockopt(struct socket *sock, int level, int optname,
return -EINVAL;
}
+/* SIOCMCTPALLOCTAG: reserve a locally-owned tag for ctl.peer_addr on this
+ * socket and report it back to userspace through the same ctl structure.
+ *
+ * Returns 0 on success, -EFAULT if the ctl structure cannot be copied in or
+ * out, -EINVAL if ctl.tag or ctl.flags is non-zero on entry, or the error
+ * from mctp_alloc_local_tag().
+ */
+static int mctp_ioctl_alloctag(struct mctp_sock *msk, unsigned long arg)
+{
+	void __user *uarg = (void __user *)arg;
+	struct net *net = sock_net(&msk->sk);
+	unsigned long keys_flags, key_flags;
+	struct mctp_ioc_tag_ctl ctl;
+	struct mctp_sk_key *key;
+	u8 tag;
+
+	if (copy_from_user(&ctl, uarg, sizeof(ctl)))
+		return -EFAULT;
+
+	/* tag is filled in by us, and any flag bits are rejected */
+	if (ctl.tag || ctl.flags)
+		return -EINVAL;
+
+	key = mctp_alloc_local_tag(msk, ctl.peer_addr, MCTP_ADDR_ANY,
+				   true, &tag);
+	if (IS_ERR(key))
+		return PTR_ERR(key);
+
+	ctl.tag = tag | MCTP_TAG_OWNER | MCTP_TAG_PREALLOC;
+	if (!copy_to_user(uarg, &ctl, sizeof(ctl))) {
+		/* drop the reference returned by the allocation */
+		mctp_key_unref(key);
+		return 0;
+	}
+
+	/* Userspace never saw the tag, so unwind the allocation. The keys
+	 * list lock is taken before the individual key lock, and the key
+	 * lock uses irqsave purely so that there is a valid flags value
+	 * (key_flags) to hand to __mctp_key_remove().
+	 */
+	spin_lock_irqsave(&net->mctp.keys_lock, keys_flags);
+	spin_lock_irqsave(&key->lock, key_flags);
+	__mctp_key_remove(key, net, key_flags, MCTP_TRACE_KEY_DROPPED);
+	mctp_key_unref(key);
+	spin_unlock_irqrestore(&net->mctp.keys_lock, keys_flags);
+
+	return -EFAULT;
+}
+
+/* SIOCMCTPDROPTAG: release every manually-allocated key on this socket that
+ * matches the (peer_addr, tag) pair supplied in the ctl structure.
+ *
+ * Returns 0 if at least one key was removed, -EFAULT if the ctl structure
+ * cannot be read, and -EINVAL for bad arguments or when nothing matched.
+ */
+static int mctp_ioctl_droptag(struct mctp_sock *msk, unsigned long arg)
+{
+	struct net *net = sock_net(&msk->sk);
+	unsigned long keys_flags, key_flags;
+	struct mctp_ioc_tag_ctl ctl;
+	struct mctp_sk_key *key;
+	struct hlist_node *next;
+	int rc = -EINVAL;
+	u8 tag;
+
+	if (copy_from_user(&ctl, (void __user *)arg, sizeof(ctl)))
+		return -EFAULT;
+
+	if (ctl.flags)
+		return -EINVAL;
+
+	/* Must be a local tag, TO set, preallocated */
+	if ((MCTP_TAG_OWNER | MCTP_TAG_PREALLOC) != (ctl.tag & ~MCTP_TAG_MASK))
+		return -EINVAL;
+
+	tag = ctl.tag & MCTP_TAG_MASK;
+
+	spin_lock_irqsave(&net->mctp.keys_lock, keys_flags);
+	hlist_for_each_entry_safe(key, next, &msk->keys, sklist) {
+		bool match;
+
+		/* irqsave even though the irq state is already known: the
+		 * saved flags are what __mctp_key_remove() expects
+		 */
+		spin_lock_irqsave(&key->lock, key_flags);
+		match = key->manual_alloc &&
+			ctl.peer_addr == key->peer_addr &&
+			tag == key->tag;
+		if (!match) {
+			spin_unlock_irqrestore(&key->lock, key_flags);
+			continue;
+		}
+
+		/* releases key->lock for us */
+		__mctp_key_remove(key, net, key_flags, MCTP_TRACE_KEY_DROPPED);
+		rc = 0;
+	}
+	spin_unlock_irqrestore(&net->mctp.keys_lock, keys_flags);
+
+	return rc;
+}
+
+/* Socket ioctl entry point: dispatches the two tag-control commands and
+ * answers -EINVAL for every other command.
+ */
+static int mctp_ioctl(struct socket *sock, unsigned int cmd, unsigned long arg)
+{
+	struct mctp_sock *msk = container_of(sock->sk, struct mctp_sock, sk);
+	int rc;
+
+	switch (cmd) {
+	case SIOCMCTPALLOCTAG:
+		rc = mctp_ioctl_alloctag(msk, arg);
+		break;
+	case SIOCMCTPDROPTAG:
+		rc = mctp_ioctl_droptag(msk, arg);
+		break;
+	default:
+		rc = -EINVAL;
+		break;
+	}
+
+	return rc;
+}
+
+#ifdef CONFIG_COMPAT
+/* Compat ioctl entry point: forwards the tag-control commands to
+ * mctp_ioctl() with the argument converted by compat_ptr(); every other
+ * command gets -ENOIOCTLCMD.
+ */
+static int mctp_compat_ioctl(struct socket *sock, unsigned int cmd,
+			     unsigned long arg)
+{
+	/* Only these two are handled; they have compatible ptr layouts */
+	if (cmd != SIOCMCTPALLOCTAG && cmd != SIOCMCTPDROPTAG)
+		return -ENOIOCTLCMD;
+
+	return mctp_ioctl(sock, cmd, (unsigned long)compat_ptr(arg));
+}
+#endif
+
static const struct proto_ops mctp_dgram_ops = {
.family = PF_MCTP,
.release = mctp_release,
@@ -302,7 +476,7 @@ static const struct proto_ops mctp_dgram_ops = {
.accept = sock_no_accept,
.getname = sock_no_getname,
.poll = datagram_poll,
- .ioctl = sock_no_ioctl,
+ .ioctl = mctp_ioctl,
.gettstamp = sock_gettstamp,
.listen = sock_no_listen,
.shutdown = sock_no_shutdown,
@@ -312,6 +486,9 @@ static const struct proto_ops mctp_dgram_ops = {
.recvmsg = mctp_recvmsg,
.mmap = sock_no_mmap,
.sendpage = sock_no_sendpage,
+#ifdef CONFIG_COMPAT
+ .compat_ioctl = mctp_compat_ioctl,
+#endif
};
static void mctp_sk_expire_keys(struct timer_list *timer)
@@ -319,7 +496,7 @@ static void mctp_sk_expire_keys(struct timer_list *timer)
struct mctp_sock *msk = container_of(timer, struct mctp_sock,
key_expiry);
struct net *net = sock_net(&msk->sk);
- unsigned long next_expiry, flags;
+ unsigned long next_expiry, flags, fl2;
struct mctp_sk_key *key;
struct hlist_node *tmp;
bool next_expiry_valid = false;
@@ -327,15 +504,16 @@ static void mctp_sk_expire_keys(struct timer_list *timer)
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
- spin_lock(&key->lock);
+ /* don't expire. manual_alloc is immutable, no locking
+ * required.
+ */
+ if (key->manual_alloc)
+ continue;
+ spin_lock_irqsave(&key->lock, fl2);
if (!time_after_eq(key->expiry, jiffies)) {
- trace_mctp_key_release(key, MCTP_TRACE_KEY_TIMEOUT);
- key->valid = false;
- hlist_del_rcu(&key->hlist);
- hlist_del_rcu(&key->sklist);
- spin_unlock(&key->lock);
- mctp_key_unref(key);
+ __mctp_key_remove(key, net, fl2,
+ MCTP_TRACE_KEY_TIMEOUT);
continue;
}
@@ -346,7 +524,7 @@ static void mctp_sk_expire_keys(struct timer_list *timer)
next_expiry = key->expiry;
next_expiry_valid = true;
}
- spin_unlock(&key->lock);
+ spin_unlock_irqrestore(&key->lock, fl2);
}
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
@@ -387,9 +565,9 @@ static void mctp_sk_unhash(struct sock *sk)
{
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct net *net = sock_net(sk);
+ unsigned long flags, fl2;
struct mctp_sk_key *key;
struct hlist_node *tmp;
- unsigned long flags;
/* remove from any type-based binds */
mutex_lock(&net->mctp.bind_lock);
@@ -399,20 +577,8 @@ static void mctp_sk_unhash(struct sock *sk)
/* remove tag allocations */
spin_lock_irqsave(&net->mctp.keys_lock, flags);
hlist_for_each_entry_safe(key, tmp, &msk->keys, sklist) {
- hlist_del(&key->sklist);
- hlist_del(&key->hlist);
-
- trace_mctp_key_release(key, MCTP_TRACE_KEY_CLOSED);
-
- spin_lock(&key->lock);
- kfree_skb(key->reasm_head);
- key->reasm_head = NULL;
- key->reasm_dead = true;
- key->valid = false;
- spin_unlock(&key->lock);
-
- /* key is no longer on the lookup lists, unref */
- mctp_key_unref(key);
+ spin_lock_irqsave(&key->lock, fl2);
+ __mctp_key_remove(key, net, fl2, MCTP_TRACE_KEY_CLOSED);
}
spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
}
@@ -499,12 +665,14 @@ static __init int mctp_init(void)
rc = mctp_neigh_init();
if (rc)
- goto err_unreg_proto;
+ goto err_unreg_routes;
mctp_device_init();
return 0;
+err_unreg_routes:
+ mctp_routes_exit();
err_unreg_proto:
proto_unregister(&mctp_proto);
err_unreg_sock:
diff --git a/net/mctp/device.c b/net/mctp/device.c
index ef2755f82f87..99a3bda8852f 100644
--- a/net/mctp/device.c
+++ b/net/mctp/device.c
@@ -6,6 +6,7 @@
* Copyright (c) 2021 Google
*/
+#include <linux/if_arp.h>
#include <linux/if_link.h>
#include <linux/mctp.h>
#include <linux/netdevice.h>
@@ -24,12 +25,25 @@ struct mctp_dump_cb {
size_t a_idx;
};
-/* unlocked: caller must hold rcu_read_lock */
+/* unlocked: caller must hold rcu_read_lock.
+ * Returned mctp_dev has its refcount incremented, or NULL if unset.
+ */
struct mctp_dev *__mctp_dev_get(const struct net_device *dev)
{
- return rcu_dereference(dev->mctp_ptr);
+ struct mctp_dev *mdev = rcu_dereference(dev->mctp_ptr);
+
+ /* RCU guarantees that any mdev is still live.
+ * Zero refcount implies a pending free, return NULL.
+ */
+ if (mdev)
+ if (!refcount_inc_not_zero(&mdev->refs))
+ return NULL;
+ return mdev;
}
+/* Returned mctp_dev does not have refcount incremented. The returned pointer
+ * remains live while rtnl_lock is held, as that prevents mctp_unregister()
+ */
struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev)
{
return rtnl_dereference(dev->mctp_ptr);
@@ -106,7 +120,7 @@ static int mctp_dump_addrinfo(struct sk_buff *skb, struct netlink_callback *cb)
struct ifaddrmsg *hdr;
struct mctp_dev *mdev;
int ifindex;
- int idx, rc;
+ int idx = 0, rc;
hdr = nlmsg_data(cb->nlh);
// filter by ifindex if requested
@@ -123,6 +137,7 @@ static int mctp_dump_addrinfo(struct sk_buff *skb, struct netlink_callback *cb)
if (mdev) {
rc = mctp_dump_dev_addrinfo(mdev,
skb, cb);
+ mctp_dev_put(mdev);
// Error indicates full buffer, this
// callback will get retried.
if (rc < 0)
@@ -208,7 +223,7 @@ static int mctp_rtm_newaddr(struct sk_buff *skb, struct nlmsghdr *nlh,
if (!mdev)
return -ENODEV;
- if (!mctp_address_ok(addr->s_addr))
+ if (!mctp_address_unicast(addr->s_addr))
return -EINVAL;
/* Prevent duplicates. Under RTNL so don't need to lock for reading */
@@ -297,7 +312,8 @@ void mctp_dev_hold(struct mctp_dev *mdev)
void mctp_dev_put(struct mctp_dev *mdev)
{
- if (refcount_dec_and_test(&mdev->refs)) {
+ if (mdev && refcount_dec_and_test(&mdev->refs)) {
+ kfree(mdev->addrs);
dev_put(mdev->dev);
kfree_rcu(mdev, rcu);
}
@@ -369,6 +385,7 @@ static size_t mctp_get_link_af_size(const struct net_device *dev,
if (!mdev)
return 0;
ret = nla_total_size(4); /* IFLA_MCTP_NET */
+ mctp_dev_put(mdev);
return ret;
}
@@ -412,10 +429,10 @@ static void mctp_unregister(struct net_device *dev)
struct mctp_dev *mdev;
mdev = mctp_dev_get_rtnl(dev);
- if (mctp_known(dev) != (bool)mdev) {
+ if (mdev && !mctp_known(dev)) {
// Sanity check, should match what was set in mctp_register
- netdev_warn(dev, "%s: mdev pointer %d but type (%d) match is %d",
- __func__, (bool)mdev, mctp_known(dev), dev->type);
+ netdev_warn(dev, "%s: BUG mctp_ptr set for unknown type %d",
+ __func__, dev->type);
return;
}
if (!mdev)
@@ -425,7 +442,6 @@ static void mctp_unregister(struct net_device *dev)
mctp_route_remove_dev(mdev);
mctp_neigh_remove_dev(mdev);
- kfree(mdev->addrs);
mctp_dev_put(mdev);
}
@@ -439,7 +455,7 @@ static int mctp_register(struct net_device *dev)
if (mdev) {
if (!mctp_known(dev))
- netdev_warn(dev, "%s: mctp_dev set for unknown type %d",
+ netdev_warn(dev, "%s: BUG mctp_ptr set for unknown type %d",
__func__, dev->type);
return 0;
}
diff --git a/net/mctp/neigh.c b/net/mctp/neigh.c
index 6ad3e33bd4d4..ffa0f9e0983f 100644
--- a/net/mctp/neigh.c
+++ b/net/mctp/neigh.c
@@ -143,7 +143,7 @@ static int mctp_rtm_newneigh(struct sk_buff *skb, struct nlmsghdr *nlh,
}
eid = nla_get_u8(tb[NDA_DST]);
- if (!mctp_address_ok(eid)) {
+ if (!mctp_address_unicast(eid)) {
NL_SET_ERR_MSG(extack, "Invalid neighbour EID");
return -EINVAL;
}
diff --git a/net/mctp/route.c b/net/mctp/route.c
index 8d9f4ff3e285..f9a80b82dc51 100644
--- a/net/mctp/route.c
+++ b/net/mctp/route.c
@@ -64,8 +64,7 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
if (msk->bind_type != type)
continue;
- if (msk->bind_addr != MCTP_ADDR_ANY &&
- msk->bind_addr != mh->dest)
+ if (!mctp_address_matches(msk->bind_addr, mh->dest))
continue;
return msk;
@@ -77,7 +76,7 @@ static struct mctp_sock *mctp_lookup_bind(struct net *net, struct sk_buff *skb)
static bool mctp_key_match(struct mctp_sk_key *key, mctp_eid_t local,
mctp_eid_t peer, u8 tag)
{
- if (key->local_addr != local)
+ if (!mctp_address_matches(key->local_addr, local))
return false;
if (key->peer_addr != peer)
@@ -204,29 +203,38 @@ static int mctp_key_add(struct mctp_sk_key *key, struct mctp_sock *msk)
return rc;
}
-/* We're done with the key; unset valid and remove from lists. There may still
- * be outstanding refs on the key though...
+/* Helper for mctp_route_input().
+ * We're done with the key; unlock and unref the key.
+ * For the usual case of automatic expiry we remove the key from lists.
+ * In the case that manual allocation is set on a key we release the lock
+ * and local ref, reset reassembly, but don't remove from lists.
*/
-static void __mctp_key_unlock_drop(struct mctp_sk_key *key, struct net *net,
- unsigned long flags)
- __releases(&key->lock)
+static void __mctp_key_done_in(struct mctp_sk_key *key, struct net *net,
+ unsigned long flags, unsigned long reason)
+__releases(&key->lock)
{
struct sk_buff *skb;
+ trace_mctp_key_release(key, reason);
skb = key->reasm_head;
key->reasm_head = NULL;
- key->reasm_dead = true;
- key->valid = false;
- mctp_dev_release_key(key->dev, key);
- spin_unlock_irqrestore(&key->lock, flags);
- spin_lock_irqsave(&net->mctp.keys_lock, flags);
- hlist_del(&key->hlist);
- hlist_del(&key->sklist);
- spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+ if (!key->manual_alloc) {
+ key->reasm_dead = true;
+ key->valid = false;
+ mctp_dev_release_key(key->dev, key);
+ }
+ spin_unlock_irqrestore(&key->lock, flags);
- /* one unref for the lists */
- mctp_key_unref(key);
+ if (!key->manual_alloc) {
+ spin_lock_irqsave(&net->mctp.keys_lock, flags);
+ if (!hlist_unhashed(&key->hlist)) {
+ hlist_del_init(&key->hlist);
+ hlist_del_init(&key->sklist);
+ mctp_key_unref(key);
+ }
+ spin_unlock_irqrestore(&net->mctp.keys_lock, flags);
+ }
/* and one for the local reference */
mctp_key_unref(key);
@@ -380,9 +388,8 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
/* we've hit a pending reassembly; not much we
* can do but drop it
*/
- trace_mctp_key_release(key,
- MCTP_TRACE_KEY_REPLIED);
- __mctp_key_unlock_drop(key, net, f);
+ __mctp_key_done_in(key, net, f,
+ MCTP_TRACE_KEY_REPLIED);
key = NULL;
}
rc = 0;
@@ -412,21 +419,21 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
* this function.
*/
rc = mctp_key_add(key, msk);
- if (rc)
+ if (rc) {
kfree(key);
+ } else {
+ trace_mctp_key_acquire(key);
- trace_mctp_key_acquire(key);
-
- /* we don't need to release key->lock on exit */
- mctp_key_unref(key);
+ /* we don't need to release key->lock on exit */
+ mctp_key_unref(key);
+ }
key = NULL;
} else {
if (key->reasm_head || key->reasm_dead) {
/* duplicate start? drop everything */
- trace_mctp_key_release(key,
- MCTP_TRACE_KEY_INVALIDATED);
- __mctp_key_unlock_drop(key, net, f);
+ __mctp_key_done_in(key, net, f,
+ MCTP_TRACE_KEY_INVALIDATED);
rc = -EEXIST;
key = NULL;
} else {
@@ -451,8 +458,7 @@ static int mctp_route_input(struct mctp_route *route, struct sk_buff *skb)
if (!rc && flags & MCTP_HDR_FLAG_EOM) {
sock_queue_rcv_skb(key->sk, key->reasm_head);
key->reasm_head = NULL;
- trace_mctp_key_release(key, MCTP_TRACE_KEY_REPLIED);
- __mctp_key_unlock_drop(key, net, f);
+ __mctp_key_done_in(key, net, f, MCTP_TRACE_KEY_REPLIED);
key = NULL;
}
@@ -497,6 +503,11 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
if (cb->ifindex) {
/* direct route; use the hwaddr we stashed in sendmsg */
+ if (cb->halen != skb->dev->addr_len) {
+ /* sanity check, sendmsg should have already caught this */
+ kfree_skb(skb);
+ return -EMSGSIZE;
+ }
daddr = cb->haddr;
} else {
/* If lookup fails let the device handle daddr==NULL */
@@ -506,7 +517,7 @@ static int mctp_route_output(struct mctp_route *route, struct sk_buff *skb)
rc = dev_hard_header(skb, skb->dev, ntohs(skb->protocol),
daddr, skb->dev->dev_addr, skb->len);
- if (rc) {
+ if (rc < 0) {
kfree_skb(skb);
return -EHOSTUNREACH;
}
@@ -580,9 +591,9 @@ static void mctp_reserve_tag(struct net *net, struct mctp_sk_key *key,
/* Allocate a locally-owned tag value for (saddr, daddr), and reserve
* it for the socket msk
*/
-static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
- mctp_eid_t saddr,
- mctp_eid_t daddr, u8 *tagp)
+struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
+ mctp_eid_t daddr, mctp_eid_t saddr,
+ bool manual, u8 *tagp)
{
struct net *net = sock_net(&msk->sk);
struct netns_mctp *mns = &net->mctp;
@@ -616,9 +627,8 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
if (tmp->tag & MCTP_HDR_FLAG_TO)
continue;
- if (!((tmp->peer_addr == daddr ||
- tmp->peer_addr == MCTP_ADDR_ANY) &&
- tmp->local_addr == saddr))
+ if (!(mctp_address_matches(tmp->peer_addr, daddr) &&
+ mctp_address_matches(tmp->local_addr, saddr)))
continue;
spin_lock(&tmp->lock);
@@ -638,6 +648,7 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
mctp_reserve_tag(net, key, msk);
trace_mctp_key_acquire(key);
+ key->manual_alloc = manual;
*tagp = key->tag;
}
@@ -651,6 +662,50 @@ static struct mctp_sk_key *mctp_alloc_local_tag(struct mctp_sock *msk,
return key;
}
+/* Find a valid, manually-allocated key for (daddr, req_tag) in the socket's
+ * network namespace.
+ *
+ * The PREALLOC and OWNER bits of @req_tag are ignored for the comparison.
+ * On success the key is returned with an extra reference held, and its tag
+ * is stored through @tagp when that is non-NULL. Returns ERR_PTR(-ENOENT)
+ * when no such key exists.
+ */
+static struct mctp_sk_key *mctp_lookup_prealloc_tag(struct mctp_sock *msk,
+						    mctp_eid_t daddr,
+						    u8 req_tag, u8 *tagp)
+{
+	const u8 tag = req_tag & ~(MCTP_TAG_PREALLOC | MCTP_TAG_OWNER);
+	struct net *net = sock_net(&msk->sk);
+	struct netns_mctp *mns = &net->mctp;
+	struct mctp_sk_key *found = NULL;
+	struct mctp_sk_key *cur;
+	unsigned long flags;
+
+	spin_lock_irqsave(&mns->keys_lock, flags);
+
+	hlist_for_each_entry(cur, &mns->keys, hlist) {
+		bool usable;
+
+		if (cur->tag != tag ||
+		    !mctp_address_matches(cur->peer_addr, daddr) ||
+		    !cur->manual_alloc)
+			continue;
+
+		/* validity is checked, and the ref taken, under the key lock */
+		spin_lock(&cur->lock);
+		usable = cur->valid;
+		if (usable)
+			refcount_inc(&cur->refs);
+		spin_unlock(&cur->lock);
+
+		if (usable) {
+			found = cur;
+			break;
+		}
+	}
+
+	spin_unlock_irqrestore(&mns->keys_lock, flags);
+
+	if (!found)
+		return ERR_PTR(-ENOENT);
+
+	if (tagp)
+		*tagp = found->tag;
+
+	return found;
+}
+
/* routing lookups */
static bool mctp_rt_match_eid(struct mctp_route *rt,
unsigned int net, mctp_eid_t eid)
@@ -706,7 +761,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
{
const unsigned int hlen = sizeof(struct mctp_hdr);
struct mctp_hdr *hdr, *hdr2;
- unsigned int pos, size;
+ unsigned int pos, size, headroom;
struct sk_buff *skb2;
int rc;
u8 seq;
@@ -720,6 +775,9 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
return -EMSGSIZE;
}
+ /* keep same headroom as the original skb */
+ headroom = skb_headroom(skb);
+
/* we've got the header */
skb_pull(skb, hlen);
@@ -727,7 +785,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
/* size of message payload */
size = min(mtu - hlen, skb->len - pos);
- skb2 = alloc_skb(MCTP_HEADER_MAXLEN + hlen + size, GFP_KERNEL);
+ skb2 = alloc_skb(headroom + hlen + size, GFP_KERNEL);
if (!skb2) {
rc = -ENOMEM;
break;
@@ -743,7 +801,7 @@ static int mctp_do_fragment_route(struct mctp_route *rt, struct sk_buff *skb,
skb_set_owner_w(skb2, skb->sk);
/* establish packet */
- skb_reserve(skb2, MCTP_HEADER_MAXLEN);
+ skb_reserve(skb2, headroom);
skb_reset_network_header(skb2);
skb_put(skb2, hlen + size);
skb2->transport_header = skb2->network_header + hlen;
@@ -785,9 +843,8 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
{
struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk);
struct mctp_skb_cb *cb = mctp_cb(skb);
- struct mctp_route tmp_rt;
+ struct mctp_route tmp_rt = {0};
struct mctp_sk_key *key;
- struct net_device *dev;
struct mctp_hdr *hdr;
unsigned long flags;
unsigned int mtu;
@@ -800,12 +857,12 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
if (rt) {
ext_rt = false;
- dev = NULL;
-
if (WARN_ON(!rt->dev))
goto out_release;
} else if (cb->ifindex) {
+ struct net_device *dev;
+
ext_rt = true;
rt = &tmp_rt;
@@ -815,7 +872,6 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
rcu_read_unlock();
return rc;
}
-
rt->dev = __mctp_dev_get(dev);
rcu_read_unlock();
@@ -845,8 +901,14 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
if (rc)
goto out_release;
- if (req_tag & MCTP_HDR_FLAG_TO) {
- key = mctp_alloc_local_tag(msk, saddr, daddr, &tag);
+ if (req_tag & MCTP_TAG_OWNER) {
+ if (req_tag & MCTP_TAG_PREALLOC)
+ key = mctp_lookup_prealloc_tag(msk, daddr,
+ req_tag, &tag);
+ else
+ key = mctp_alloc_local_tag(msk, daddr, saddr,
+ false, &tag);
+
if (IS_ERR(key)) {
rc = PTR_ERR(key);
goto out_release;
@@ -857,7 +919,7 @@ int mctp_local_output(struct sock *sk, struct mctp_route *rt,
tag |= MCTP_HDR_FLAG_TO;
} else {
key = NULL;
- tag = req_tag;
+ tag = req_tag & MCTP_TAG_MASK;
}
skb->protocol = htons(ETH_P_MCTP);
@@ -890,10 +952,9 @@ out_release:
if (!ext_rt)
mctp_route_release(rt);
- dev_put(dev);
+ mctp_dev_put(tmp_rt.dev);
return rc;
-
}
/* route management */
@@ -905,7 +966,7 @@ static int mctp_route_add(struct mctp_dev *mdev, mctp_eid_t daddr_start,
struct net *net = dev_net(mdev->dev);
struct mctp_route *rt, *ert;
- if (!mctp_address_ok(daddr_start))
+ if (!mctp_address_unicast(daddr_start))
return -EINVAL;
if (daddr_extent > 0xff || daddr_start + daddr_extent >= 255)
@@ -1035,6 +1096,17 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
if (mh->ver < MCTP_VER_MIN || mh->ver > MCTP_VER_MAX)
goto err_drop;
+ /* source must be valid unicast or null; drop reserved ranges and
+ * broadcast
+ */
+ if (!(mctp_address_unicast(mh->src) || mctp_address_null(mh->src)))
+ goto err_drop;
+
+ /* dest address: as above, but allow broadcast */
+ if (!(mctp_address_unicast(mh->dest) || mctp_address_null(mh->dest) ||
+ mctp_address_broadcast(mh->dest)))
+ goto err_drop;
+
/* MCTP drivers must populate halen/haddr */
if (dev->type == ARPHRD_MCTP) {
cb = mctp_cb(skb);
@@ -1056,11 +1128,13 @@ static int mctp_pkttype_receive(struct sk_buff *skb, struct net_device *dev,
rt->output(rt, skb);
mctp_route_release(rt);
+ mctp_dev_put(mdev);
return NET_RX_SUCCESS;
err_drop:
kfree_skb(skb);
+ mctp_dev_put(mdev);
return NET_RX_DROP;
}
@@ -1326,7 +1400,7 @@ int __init mctp_routes_init(void)
return register_pernet_subsys(&mctp_net_ops);
}
-void __exit mctp_routes_exit(void)
+void mctp_routes_exit(void)
{
unregister_pernet_subsys(&mctp_net_ops);
rtnl_unregister(PF_MCTP, RTM_DELROUTE);
diff --git a/net/mctp/test/route-test.c b/net/mctp/test/route-test.c
index 86ad15abf897..92ea4158f7fc 100644
--- a/net/mctp/test/route-test.c
+++ b/net/mctp/test/route-test.c
@@ -285,7 +285,7 @@ static void __mctp_route_test_init(struct kunit *test,
struct mctp_test_route **rtp,
struct socket **sockp)
{
- struct sockaddr_mctp addr;
+ struct sockaddr_mctp addr = {0};
struct mctp_test_route *rt;
struct mctp_test_dev *dev;
struct socket *sock;
@@ -352,7 +352,7 @@ static void mctp_test_route_input_sk(struct kunit *test)
if (params->deliver) {
KUNIT_EXPECT_EQ(test, rc, 0);
- skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
+ skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
KUNIT_EXPECT_EQ(test, skb->len, 1);
@@ -360,8 +360,8 @@ static void mctp_test_route_input_sk(struct kunit *test)
} else {
KUNIT_EXPECT_NE(test, rc, 0);
- skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
- KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
+ skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
+ KUNIT_EXPECT_NULL(test, skb2);
}
__mctp_route_test_fini(test, dev, rt, sock);
@@ -369,14 +369,15 @@ static void mctp_test_route_input_sk(struct kunit *test)
#define FL_S (MCTP_HDR_FLAG_SOM)
#define FL_E (MCTP_HDR_FLAG_EOM)
-#define FL_T (MCTP_HDR_FLAG_TO)
+#define FL_TO (MCTP_HDR_FLAG_TO)
+#define FL_T(t) ((t) & MCTP_HDR_TAG_MASK)
static const struct mctp_route_input_sk_test mctp_route_input_sk_tests[] = {
- { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T), .type = 0, .deliver = true },
- { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T), .type = 1, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 0, .deliver = true },
+ { .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_TO), .type = 1, .deliver = false },
{ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E), .type = 0, .deliver = false },
- { .hdr = RX_HDR(1, 10, 8, FL_E | FL_T), .type = 0, .deliver = false },
- { .hdr = RX_HDR(1, 10, 8, FL_T), .type = 0, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, FL_E | FL_TO), .type = 0, .deliver = false },
+ { .hdr = RX_HDR(1, 10, 8, FL_TO), .type = 0, .deliver = false },
{ .hdr = RX_HDR(1, 10, 8, 0), .type = 0, .deliver = false },
};
@@ -422,7 +423,7 @@ static void mctp_test_route_input_sk_reasm(struct kunit *test)
rc = mctp_route_input(&rt->rt, skb);
}
- skb2 = skb_recv_datagram(sock->sk, 0, 1, &rc);
+ skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
if (params->rx_len) {
KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
@@ -430,13 +431,13 @@ static void mctp_test_route_input_sk_reasm(struct kunit *test)
skb_free_datagram(sock->sk, skb2);
} else {
- KUNIT_EXPECT_PTR_EQ(test, skb2, NULL);
+ KUNIT_EXPECT_NULL(test, skb2);
}
__mctp_route_test_fini(test, dev, rt, sock);
}
-#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_T | (f) | ((s) << MCTP_HDR_SEQ_SHIFT))
+#define RX_FRAG(f, s) RX_HDR(1, 10, 8, FL_TO | (f) | ((s) << MCTP_HDR_SEQ_SHIFT))
static const struct mctp_route_input_sk_reasm_test mctp_route_input_sk_reasm_tests[] = {
{
@@ -522,12 +523,156 @@ static void mctp_route_input_sk_reasm_to_desc(
KUNIT_ARRAY_PARAM(mctp_route_input_sk_reasm, mctp_route_input_sk_reasm_tests,
mctp_route_input_sk_reasm_to_desc);
+/* One parameterised case for mctp_test_route_input_sk_keys(): the key to
+ * install on the socket, the header of the packet to inject, and whether
+ * that packet is expected to reach the socket's receive queue.
+ */
+struct mctp_route_input_sk_keys_test {
+ const char *name; /* case description, copied into the KUnit param desc */
+ mctp_eid_t key_peer_addr; /* peer address passed to mctp_key_alloc() */
+ mctp_eid_t key_local_addr; /* local address passed to mctp_key_alloc() */
+ u8 key_tag; /* tag value passed to mctp_key_alloc() */
+ struct mctp_hdr hdr; /* header of the single test packet */
+ bool deliver; /* true if skb_recv_datagram() should return a packet */
+};
+
+/* Test packet rx in the presence of various key configurations.
+ *
+ * For each parameter set: create a test device, route and socket, install a
+ * key built from the case's (local, peer, tag) values on that socket, feed
+ * one packet with the case's header through mctp_route_input(), and check
+ * whether a datagram shows up on the socket according to ->deliver.
+ */
+static void mctp_test_route_input_sk_keys(struct kunit *test)
+{
+	const struct mctp_route_input_sk_keys_test *params;
+	struct mctp_test_route *rt;
+	struct sk_buff *skb, *skb2;
+	struct mctp_test_dev *dev;
+	struct mctp_sk_key *key;
+	struct netns_mctp *mns;
+	struct mctp_sock *msk;
+	struct socket *sock;
+	unsigned long flags;
+	int rc;
+	u8 c;
+
+	params = test->param_value;
+
+	dev = mctp_test_create_dev();
+	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, dev);
+
+	rt = mctp_test_create_route(&init_net, dev->mdev, 8, 68);
+	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, rt);
+
+	rc = sock_create_kern(&init_net, AF_MCTP, SOCK_DGRAM, 0, &sock);
+	KUNIT_ASSERT_EQ(test, rc, 0);
+
+	msk = container_of(sock->sk, struct mctp_sock, sk);
+	mns = &sock_net(sock->sk)->mctp;
+
+	/* set the incoming tag according to test params */
+	key = mctp_key_alloc(msk, params->key_local_addr, params->key_peer_addr,
+			     params->key_tag, GFP_KERNEL);
+
+	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, key);
+
+	/* the reservation is made under the namespace keys lock */
+	spin_lock_irqsave(&mns->keys_lock, flags);
+	mctp_reserve_tag(&init_net, key, msk);
+	spin_unlock_irqrestore(&mns->keys_lock, flags);
+
+	/* create packet and route */
+	c = 0;
+	skb = mctp_test_create_skb_data(&params->hdr, &c);
+	KUNIT_ASSERT_NOT_ERR_OR_NULL(test, skb);
+
+	skb->dev = dev->ndev;
+	__mctp_cb(skb);
+
+	/* the input return code is not checked; delivery is judged solely by
+	 * what the socket receives below
+	 */
+	rc = mctp_route_input(&rt->rt, skb);
+
+	/* (potentially) receive message */
+	skb2 = skb_recv_datagram(sock->sk, MSG_DONTWAIT, &rc);
+
+	if (params->deliver)
+		KUNIT_EXPECT_NOT_ERR_OR_NULL(test, skb2);
+	else
+		KUNIT_EXPECT_NULL(test, skb2);
+
+	if (skb2)
+		skb_free_datagram(sock->sk, skb2);
+
+	/* drop the reference obtained from mctp_key_alloc() */
+	mctp_key_unref(key);
+	__mctp_route_test_fini(test, dev, rt, sock);
+}
+
+/* Key configuration vs. received header, one entry per case.
+ *
+ * NOTE(review): RX_HDR() is defined outside this hunk; judging by the case
+ * names ("flipped src/dest", "peer addr mismatch") its arguments after the
+ * first appear to be (src, dest, flags | tag) - confirm against the macro.
+ * FL_T(t) masks t with MCTP_HDR_TAG_MASK and FL_TO is MCTP_HDR_FLAG_TO.
+ */
+static const struct mctp_route_input_sk_keys_test mctp_route_input_sk_keys_tests[] = {
+ {
+ .name = "direct match",
+ .key_peer_addr = 9,
+ .key_local_addr = 8,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
+ .deliver = true,
+ },
+ {
+ .name = "flipped src/dest",
+ .key_peer_addr = 8,
+ .key_local_addr = 9,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1)),
+ .deliver = false,
+ },
+ {
+ .name = "peer addr mismatch",
+ .key_peer_addr = 9,
+ .key_local_addr = 8,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 10, 8, FL_S | FL_E | FL_T(1)),
+ .deliver = false,
+ },
+ {
+ .name = "tag value mismatch",
+ .key_peer_addr = 9,
+ .key_local_addr = 8,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(2)),
+ .deliver = false,
+ },
+ {
+ .name = "TO mismatch",
+ .key_peer_addr = 9,
+ .key_local_addr = 8,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 9, 8, FL_S | FL_E | FL_T(1) | FL_TO),
+ .deliver = false,
+ },
+ {
+ /* key peer is MCTP_ADDR_ANY, packet comes from 11 */
+ .name = "broadcast response",
+ .key_peer_addr = MCTP_ADDR_ANY,
+ .key_local_addr = 8,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 11, 8, FL_S | FL_E | FL_T(1)),
+ .deliver = true,
+ },
+ {
+ /* key local address is MCTP_ADDR_ANY */
+ .name = "any local match",
+ .key_peer_addr = 12,
+ .key_local_addr = MCTP_ADDR_ANY,
+ .key_tag = 1,
+ .hdr = RX_HDR(1, 12, 8, FL_S | FL_E | FL_T(1)),
+ .deliver = true,
+ },
+};
+
+/* KUnit parameter-description callback: copy the case name into @desc.
+ * The copy is bounded by KUNIT_PARAM_DESC_SIZE, the size of the description
+ * buffer KUnit hands to this callback, so an over-long name is truncated
+ * rather than written past the end of the buffer.
+ */
+static void mctp_route_input_sk_keys_to_desc(
+		const struct mctp_route_input_sk_keys_test *t,
+		char *desc)
+{
+	snprintf(desc, KUNIT_PARAM_DESC_SIZE, "%s", t->name);
+}
+
+KUNIT_ARRAY_PARAM(mctp_route_input_sk_keys, mctp_route_input_sk_keys_tests,
+ mctp_route_input_sk_keys_to_desc);
+
static struct kunit_case mctp_test_cases[] = {
KUNIT_CASE_PARAM(mctp_test_fragment, mctp_frag_gen_params),
KUNIT_CASE_PARAM(mctp_test_rx_input, mctp_rx_input_gen_params),
KUNIT_CASE_PARAM(mctp_test_route_input_sk, mctp_route_input_sk_gen_params),
KUNIT_CASE_PARAM(mctp_test_route_input_sk_reasm,
mctp_route_input_sk_reasm_gen_params),
+ KUNIT_CASE_PARAM(mctp_test_route_input_sk_keys,
+ mctp_route_input_sk_keys_gen_params),
{}
};
diff --git a/net/mctp/test/utils.c b/net/mctp/test/utils.c
index 7b7918702592..e03ba66bbe18 100644
--- a/net/mctp/test/utils.c
+++ b/net/mctp/test/utils.c
@@ -54,7 +54,6 @@ struct mctp_test_dev *mctp_test_create_dev(void)
rcu_read_lock();
dev->mdev = __mctp_dev_get(ndev);
- mctp_dev_hold(dev->mdev);
rcu_read_unlock();
return dev;