aboutsummaryrefslogtreecommitdiffstats
path: root/net/mptcp/pm_netlink.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--net/mptcp/pm_netlink.c826
1 files changed, 739 insertions, 87 deletions
diff --git a/net/mptcp/pm_netlink.c b/net/mptcp/pm_netlink.c
index a6d983d80576..8e8e35fa4002 100644
--- a/net/mptcp/pm_netlink.c
+++ b/net/mptcp/pm_netlink.c
@@ -26,6 +26,7 @@ struct mptcp_pm_addr_entry {
struct list_head list;
struct mptcp_addr_info addr;
struct rcu_head rcu;
+ struct socket *lsk;
};
struct mptcp_pm_add_entry {
@@ -36,6 +37,9 @@ struct mptcp_pm_add_entry {
u8 retrans_times;
};
+#define MAX_ADDR_ID 255
+#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)
+
struct pm_nl_pernet {
/* protects pernet updates */
spinlock_t lock;
@@ -46,25 +50,33 @@ struct pm_nl_pernet {
unsigned int local_addr_max;
unsigned int subflows_max;
unsigned int next_id;
+ unsigned long id_bitmap[BITMAP_SZ];
};
#define MPTCP_PM_ADDR_MAX 8
#define ADD_ADDR_RETRANS_MAX 3
+static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk);
+
static bool addresses_equal(const struct mptcp_addr_info *a,
struct mptcp_addr_info *b, bool use_port)
{
bool addr_equals = false;
- if (a->family != b->family)
- return false;
-
- if (a->family == AF_INET)
- addr_equals = a->addr.s_addr == b->addr.s_addr;
+ if (a->family == b->family) {
+ if (a->family == AF_INET)
+ addr_equals = a->addr.s_addr == b->addr.s_addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
- else
- addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
+ else
+ addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
+ } else if (a->family == AF_INET) {
+ if (ipv6_addr_v4mapped(&b->addr6))
+ addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
+ } else if (b->family == AF_INET) {
+ if (ipv6_addr_v4mapped(&a->addr6))
+ addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
#endif
+ }
if (!addr_equals)
return false;
@@ -81,14 +93,14 @@ static bool address_zero(const struct mptcp_addr_info *addr)
memset(&zero, 0, sizeof(zero));
zero.family = addr->family;
- return addresses_equal(addr, &zero, false);
+ return addresses_equal(addr, &zero, true);
}
static void local_address(const struct sock_common *skc,
struct mptcp_addr_info *addr)
{
- addr->port = 0;
addr->family = skc->skc_family;
+ addr->port = htons(skc->skc_num);
if (addr->family == AF_INET)
addr->addr.s_addr = skc->skc_rcv_saddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
@@ -121,7 +133,7 @@ static bool lookup_subflow_by_saddr(const struct list_head *list,
skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
local_address(skc, &cur);
- if (addresses_equal(&cur, saddr, false))
+ if (addresses_equal(&cur, saddr, saddr->port))
return true;
}
@@ -133,6 +145,9 @@ select_local_address(const struct pm_nl_pernet *pernet,
struct mptcp_sock *msk)
{
struct mptcp_pm_addr_entry *entry, *ret = NULL;
+ struct sock *sk = (struct sock *)msk;
+
+ msk_owned_by_me(msk);
rcu_read_lock();
__mptcp_flush_join_list(msk);
@@ -140,11 +155,20 @@ select_local_address(const struct pm_nl_pernet *pernet,
if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
continue;
+ if (entry->addr.family != sk->sk_family) {
+#if IS_ENABLED(CONFIG_MPTCP_IPV6)
+ if ((entry->addr.family == AF_INET &&
+ !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
+ (sk->sk_family == AF_INET &&
+ !ipv6_addr_v4mapped(&entry->addr.addr6)))
+#endif
+ continue;
+ }
+
/* avoid any address already in use by subflows and
* pending join
*/
- if (entry->addr.family == ((struct sock *)msk)->sk_family &&
- !lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
+ if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
ret = entry;
break;
}
@@ -177,11 +201,47 @@ select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
return ret;
}
+unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
+{
+ struct pm_nl_pernet *pernet;
+
+ pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
+ return READ_ONCE(pernet->add_addr_signal_max);
+}
+EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
+
+unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
+{
+ struct pm_nl_pernet *pernet;
+
+ pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
+ return READ_ONCE(pernet->add_addr_accept_max);
+}
+EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
+
+unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
+{
+ struct pm_nl_pernet *pernet;
+
+ pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
+ return READ_ONCE(pernet->subflows_max);
+}
+EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
+
+unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
+{
+ struct pm_nl_pernet *pernet;
+
+ pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
+ return READ_ONCE(pernet->local_addr_max);
+}
+EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);
+
static void check_work_pending(struct mptcp_sock *msk)
{
- if (msk->pm.add_addr_signaled == msk->pm.add_addr_signal_max &&
- (msk->pm.local_addr_used == msk->pm.local_addr_max ||
- msk->pm.subflows == msk->pm.subflows_max))
+ if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
+ (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
+ msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
WRITE_ONCE(msk->pm.work_pending, false);
}
@@ -191,14 +251,37 @@ lookup_anno_list_by_saddr(struct mptcp_sock *msk,
{
struct mptcp_pm_add_entry *entry;
+ lockdep_assert_held(&msk->pm.lock);
+
list_for_each_entry(entry, &msk->pm.anno_list, list) {
- if (addresses_equal(&entry->addr, addr, false))
+ if (addresses_equal(&entry->addr, addr, true))
return entry;
}
return NULL;
}
+bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
+{
+ struct mptcp_pm_add_entry *entry;
+ struct mptcp_addr_info saddr;
+ bool ret = false;
+
+ local_address((struct sock_common *)sk, &saddr);
+
+ spin_lock_bh(&msk->pm.lock);
+ list_for_each_entry(entry, &msk->pm.anno_list, list) {
+ if (addresses_equal(&entry->addr, &saddr, true)) {
+ ret = true;
+ goto out;
+ }
+ }
+
+out:
+ spin_unlock_bh(&msk->pm.lock);
+ return ret;
+}
+
static void mptcp_pm_add_timer(struct timer_list *timer)
{
struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
@@ -266,6 +349,8 @@ static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
struct sock *sk = (struct sock *)msk;
struct net *net = sock_net(sk);
+ lockdep_assert_held(&msk->pm.lock);
+
if (lookup_anno_list_by_saddr(msk, &entry->addr))
return false;
@@ -306,20 +391,26 @@ void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
- struct mptcp_addr_info remote = { 0 };
struct sock *sk = (struct sock *)msk;
struct mptcp_pm_addr_entry *local;
+ unsigned int add_addr_signal_max;
+ unsigned int local_addr_max;
struct pm_nl_pernet *pernet;
+ unsigned int subflows_max;
pernet = net_generic(sock_net(sk), pm_nl_pernet_id);
+ add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
+ local_addr_max = mptcp_pm_get_local_addr_max(msk);
+ subflows_max = mptcp_pm_get_subflows_max(msk);
+
pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
- msk->pm.local_addr_used, msk->pm.local_addr_max,
- msk->pm.add_addr_signaled, msk->pm.add_addr_signal_max,
- msk->pm.subflows, msk->pm.subflows_max);
+ msk->pm.local_addr_used, local_addr_max,
+ msk->pm.add_addr_signaled, add_addr_signal_max,
+ msk->pm.subflows, subflows_max);
/* check first for announce */
- if (msk->pm.add_addr_signaled < msk->pm.add_addr_signal_max) {
+ if (msk->pm.add_addr_signaled < add_addr_signal_max) {
local = select_signal_address(pernet,
msk->pm.add_addr_signaled);
@@ -331,22 +422,23 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
}
} else {
/* pick failed, avoid fourther attempts later */
- msk->pm.local_addr_used = msk->pm.add_addr_signal_max;
+ msk->pm.local_addr_used = add_addr_signal_max;
}
check_work_pending(msk);
}
/* check if should create a new subflow */
- if (msk->pm.local_addr_used < msk->pm.local_addr_max &&
- msk->pm.subflows < msk->pm.subflows_max) {
- remote_address((struct sock_common *)sk, &remote);
-
+ if (msk->pm.local_addr_used < local_addr_max &&
+ msk->pm.subflows < subflows_max) {
local = select_local_address(pernet, msk);
if (local) {
+ struct mptcp_addr_info remote = { 0 };
+
msk->pm.local_addr_used++;
msk->pm.subflows++;
check_work_pending(msk);
+ remote_address((struct sock_common *)sk, &remote);
spin_unlock_bh(&msk->pm.lock);
__mptcp_subflow_connect(sk, &local->addr, &remote);
spin_lock_bh(&msk->pm.lock);
@@ -354,35 +446,40 @@ static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
}
/* lookup failed, avoid fourther attempts later */
- msk->pm.local_addr_used = msk->pm.local_addr_max;
+ msk->pm.local_addr_used = local_addr_max;
check_work_pending(msk);
}
}
-void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
+static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
mptcp_pm_create_subflow_or_signal_addr(msk);
}
-void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
+static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
mptcp_pm_create_subflow_or_signal_addr(msk);
}
-void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
+static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
{
struct sock *sk = (struct sock *)msk;
+ unsigned int add_addr_accept_max;
struct mptcp_addr_info remote;
struct mptcp_addr_info local;
+ unsigned int subflows_max;
bool use_port = false;
+ add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
+ subflows_max = mptcp_pm_get_subflows_max(msk);
+
pr_debug("accepted %d:%d remote family %d",
- msk->pm.add_addr_accepted, msk->pm.add_addr_accept_max,
+ msk->pm.add_addr_accepted, add_addr_accept_max,
msk->pm.remote.family);
msk->pm.add_addr_accepted++;
msk->pm.subflows++;
- if (msk->pm.add_addr_accepted >= msk->pm.add_addr_accept_max ||
- msk->pm.subflows >= msk->pm.subflows_max)
+ if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
+ msk->pm.subflows >= subflows_max)
WRITE_ONCE(msk->pm.accept_addr, false);
/* connect to the specified remote address, using whatever
@@ -404,12 +501,14 @@ void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
mptcp_pm_nl_add_addr_send_ack(msk);
}
-void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
+static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
{
struct mptcp_subflow_context *subflow;
- if (!mptcp_pm_should_add_signal_ipv6(msk) &&
- !mptcp_pm_should_add_signal_port(msk))
+ msk_owned_by_me(msk);
+ lockdep_assert_held(&msk->pm.lock);
+
+ if (!mptcp_pm_should_add_signal(msk))
return;
__mptcp_flush_join_list(msk);
@@ -419,10 +518,9 @@ void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
u8 add_addr;
spin_unlock_bh(&msk->pm.lock);
- if (mptcp_pm_should_add_signal_ipv6(msk))
- pr_debug("send ack for add_addr6");
- if (mptcp_pm_should_add_signal_port(msk))
- pr_debug("send ack for add_addr_port");
+ pr_debug("send ack for add_addr%s%s",
+ mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
+ mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");
lock_sock(ssk);
tcp_send_ack(ssk);
@@ -438,13 +536,50 @@ void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
}
}
-void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
+int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
+ struct mptcp_addr_info *addr,
+ u8 bkup)
+{
+ struct mptcp_subflow_context *subflow;
+
+ pr_debug("bkup=%d", bkup);
+
+ mptcp_for_each_subflow(msk, subflow) {
+ struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
+ struct sock *sk = (struct sock *)msk;
+ struct mptcp_addr_info local;
+
+ local_address((struct sock_common *)ssk, &local);
+ if (!addresses_equal(&local, addr, addr->port))
+ continue;
+
+ subflow->backup = bkup;
+ subflow->send_mp_prio = 1;
+ subflow->request_bkup = bkup;
+ __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);
+
+ spin_unlock_bh(&msk->pm.lock);
+ pr_debug("send ack for mp_prio");
+ lock_sock(ssk);
+ tcp_send_ack(ssk);
+ release_sock(ssk);
+ spin_lock_bh(&msk->pm.lock);
+
+ return 0;
+ }
+
+ return -EINVAL;
+}
+
+static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
{
struct mptcp_subflow_context *subflow, *tmp;
struct sock *sk = (struct sock *)msk;
pr_debug("address rm_id %d", msk->pm.rm_id);
+ msk_owned_by_me(msk);
+
if (!msk->pm.rm_id)
return;
@@ -460,7 +595,7 @@ void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
spin_unlock_bh(&msk->pm.lock);
mptcp_subflow_shutdown(sk, ssk, how);
- __mptcp_close_ssk(sk, ssk, subflow);
+ mptcp_close_ssk(sk, ssk, subflow);
spin_lock_bh(&msk->pm.lock);
msk->pm.add_addr_accepted--;
@@ -473,6 +608,39 @@ void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
}
}
+void mptcp_pm_nl_work(struct mptcp_sock *msk)
+{
+ struct mptcp_pm_data *pm = &msk->pm;
+
+ msk_owned_by_me(msk);
+
+ spin_lock_bh(&msk->pm.lock);
+
+ pr_debug("msk=%p status=%x", msk, pm->status);
+ if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
+ pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
+ mptcp_pm_nl_add_addr_received(msk);
+ }
+ if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
+ pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
+ mptcp_pm_nl_add_addr_send_ack(msk);
+ }
+ if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
+ pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
+ mptcp_pm_nl_rm_addr_received(msk);
+ }
+ if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
+ pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
+ mptcp_pm_nl_fully_established(msk);
+ }
+ if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
+ pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
+ mptcp_pm_nl_subflow_established(msk);
+ }
+
+ spin_unlock_bh(&msk->pm.lock);
+}
+
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
{
struct mptcp_subflow_context *subflow, *tmp;
@@ -480,6 +648,8 @@ void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
pr_debug("subflow rm_id %d", rm_id);
+ msk_owned_by_me(msk);
+
if (!rm_id)
return;
@@ -495,7 +665,7 @@ void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
spin_unlock_bh(&msk->pm.lock);
mptcp_subflow_shutdown(sk, ssk, how);
- __mptcp_close_ssk(sk, ssk, subflow);
+ mptcp_close_ssk(sk, ssk, subflow);
spin_lock_bh(&msk->pm.lock);
msk->pm.local_addr_used--;
@@ -518,16 +688,19 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
struct mptcp_pm_addr_entry *entry)
{
struct mptcp_pm_addr_entry *cur;
+ unsigned int addr_max;
int ret = -EINVAL;
spin_lock_bh(&pernet->lock);
/* to keep the code simple, don't do IDR-like allocation for address ID,
* just bail when we exceed limits
*/
- if (pernet->next_id > 255)
- goto out;
+ if (pernet->next_id == MAX_ADDR_ID)
+ pernet->next_id = 1;
if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
goto out;
+ if (test_bit(entry->addr.id, pernet->id_bitmap))
+ goto out;
/* do not insert duplicate address, differentiate on port only
* singled addresses
@@ -539,12 +712,34 @@ static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
goto out;
}
- if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
- pernet->add_addr_signal_max++;
- if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
- pernet->local_addr_max++;
+ if (!entry->addr.id) {
+find_next:
+ entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
+ MAX_ADDR_ID + 1,
+ pernet->next_id);
+ if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
+ pernet->next_id != 1) {
+ pernet->next_id = 1;
+ goto find_next;
+ }
+ }
+
+ if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
+ goto out;
+
+ __set_bit(entry->addr.id, pernet->id_bitmap);
+ if (entry->addr.id > pernet->next_id)
+ pernet->next_id = entry->addr.id;
+
+ if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
+ addr_max = pernet->add_addr_signal_max;
+ WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
+ }
+ if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
+ addr_max = pernet->local_addr_max;
+ WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
+ }
- entry->addr.id = pernet->next_id++;
pernet->addrs++;
list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
ret = entry->addr.id;
@@ -554,6 +749,53 @@ out:
return ret;
}
+static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
+ struct mptcp_pm_addr_entry *entry)
+{
+ struct sockaddr_storage addr;
+ struct mptcp_sock *msk;
+ struct socket *ssock;
+ int backlog = 1024;
+ int err;
+
+ err = sock_create_kern(sock_net(sk), entry->addr.family,
+ SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
+ if (err)
+ return err;
+
+ msk = mptcp_sk(entry->lsk->sk);
+ if (!msk) {
+ err = -EINVAL;
+ goto out;
+ }
+
+ ssock = __mptcp_nmpc_socket(msk);
+ if (!ssock) {
+ err = -EINVAL;
+ goto out;
+ }
+
+ mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
+ err = kernel_bind(ssock, (struct sockaddr *)&addr,
+ sizeof(struct sockaddr_in));
+ if (err) {
+ pr_warn("kernel_bind error, err=%d", err);
+ goto out;
+ }
+
+ err = kernel_listen(ssock, backlog);
+ if (err) {
+ pr_warn("kernel_listen error, err=%d", err);
+ goto out;
+ }
+
+ return 0;
+
+out:
+ sock_release(entry->lsk);
+ return err;
+}
+
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
{
struct mptcp_pm_addr_entry *entry;
@@ -580,7 +822,7 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
rcu_read_lock();
list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
- if (addresses_equal(&entry->addr, &skc_local, false)) {
+ if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
ret = entry->addr.id;
break;
}
@@ -597,6 +839,9 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
entry->addr = skc_local;
entry->addr.ifindex = 0;
entry->addr.flags = 0;
+ entry->addr.id = 0;
+ entry->addr.port = 0;
+ entry->lsk = NULL;
ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
if (ret < 0)
kfree(entry);
@@ -607,26 +852,23 @@ int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
{
struct mptcp_pm_data *pm = &msk->pm;
- struct pm_nl_pernet *pernet;
bool subflows;
- pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
-
- pm->add_addr_signal_max = READ_ONCE(pernet->add_addr_signal_max);
- pm->add_addr_accept_max = READ_ONCE(pernet->add_addr_accept_max);
- pm->local_addr_max = READ_ONCE(pernet->local_addr_max);
- pm->subflows_max = READ_ONCE(pernet->subflows_max);
- subflows = !!pm->subflows_max;
- WRITE_ONCE(pm->work_pending, (!!pm->local_addr_max && subflows) ||
- !!pm->add_addr_signal_max);
- WRITE_ONCE(pm->accept_addr, !!pm->add_addr_accept_max && subflows);
+ subflows = !!mptcp_pm_get_subflows_max(msk);
+ WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
+ !!mptcp_pm_get_add_addr_signal_max(msk));
+ WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
WRITE_ONCE(pm->accept_subflow, subflows);
}
-#define MPTCP_PM_CMD_GRP_OFFSET 0
+#define MPTCP_PM_CMD_GRP_OFFSET 0
+#define MPTCP_PM_EV_GRP_OFFSET 1
static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
[MPTCP_PM_CMD_GRP_OFFSET] = { .name = MPTCP_PM_CMD_GRP_NAME, },
+ [MPTCP_PM_EV_GRP_OFFSET] = { .name = MPTCP_PM_EV_GRP_NAME,
+ .flags = GENL_UNS_ADMIN_PERM,
+ },
};
static const struct nla_policy
@@ -722,6 +964,9 @@ skip_family:
if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);
+ if (tb[MPTCP_PM_ADDR_ATTR_PORT])
+ entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));
+
return 0;
}
@@ -730,6 +975,31 @@ static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
return net_generic(genl_info_net(info), pm_nl_pernet_id);
}
+static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
+{
+ struct mptcp_sock *msk;
+ long s_slot = 0, s_num = 0;
+
+ while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
+ struct sock *sk = (struct sock *)msk;
+
+ if (!READ_ONCE(msk->fully_established))
+ goto next;
+
+ lock_sock(sk);
+ spin_lock_bh(&msk->pm.lock);
+ mptcp_pm_create_subflow_or_signal_addr(msk);
+ spin_unlock_bh(&msk->pm.lock);
+ release_sock(sk);
+
+next:
+ sock_put(sk);
+ cond_resched();
+ }
+
+ return 0;
+}
+
static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
{
struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
@@ -748,13 +1018,25 @@ static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
}
*entry = addr;
+ if (entry->addr.port) {
+ ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
+ if (ret) {
+ GENL_SET_ERR_MSG(info, "create listen socket error");
+ kfree(entry);
+ return ret;
+ }
+ }
ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
if (ret < 0) {
GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
+ if (entry->lsk)
+ sock_release(entry->lsk);
kfree(entry);
return ret;
}
+ mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));
+
return 0;
}
@@ -832,11 +1114,44 @@ next:
return 0;
}
+struct addr_entry_release_work {
+ struct rcu_work rwork;
+ struct mptcp_pm_addr_entry *entry;
+};
+
+static void mptcp_pm_release_addr_entry(struct work_struct *work)
+{
+ struct addr_entry_release_work *w;
+ struct mptcp_pm_addr_entry *entry;
+
+ w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork);
+ entry = w->entry;
+ if (entry) {
+ if (entry->lsk)
+ sock_release(entry->lsk);
+ kfree(entry);
+ }
+ kfree(w);
+}
+
+static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
+{
+ struct addr_entry_release_work *w;
+
+ w = kmalloc(sizeof(*w), GFP_ATOMIC);
+ if (w) {
+ INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry);
+ w->entry = entry;
+ queue_rcu_work(system_wq, &w->rwork);
+ }
+}
+
static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
{
struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
struct mptcp_pm_addr_entry addr, *entry;
+ unsigned int addr_max;
int ret;
ret = mptcp_pm_parse_addr(attr, info, false, &addr);
@@ -850,17 +1165,22 @@ static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
spin_unlock_bh(&pernet->lock);
return -EINVAL;
}
- if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)
- pernet->add_addr_signal_max--;
- if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
- pernet->local_addr_max--;
+ if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
+ addr_max = pernet->add_addr_signal_max;
+ WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
+ }
+ if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
+ addr_max = pernet->local_addr_max;
+ WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
+ }
pernet->addrs--;
list_del_rcu(&entry->list);
+ __clear_bit(entry->addr.id, pernet->id_bitmap);
spin_unlock_bh(&pernet->lock);
mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
- kfree_rcu(entry, rcu);
+ mptcp_pm_free_addr_entry(entry);
return ret;
}
@@ -874,15 +1194,15 @@ static void __flush_addrs(struct net *net, struct list_head *list)
struct mptcp_pm_addr_entry, list);
mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr);
list_del_rcu(&cur->list);
- kfree_rcu(cur, rcu);
+ mptcp_pm_free_addr_entry(cur);
}
}
static void __reset_counters(struct pm_nl_pernet *pernet)
{
- pernet->add_addr_signal_max = 0;
- pernet->add_addr_accept_max = 0;
- pernet->local_addr_max = 0;
+ WRITE_ONCE(pernet->add_addr_signal_max, 0);
+ WRITE_ONCE(pernet->add_addr_accept_max, 0);
+ WRITE_ONCE(pernet->local_addr_max, 0);
pernet->addrs = 0;
}
@@ -894,6 +1214,8 @@ static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
spin_lock_bh(&pernet->lock);
list_splice_init(&pernet->local_addr_list, &free_list);
__reset_counters(pernet);
+ pernet->next_id = 1;
+ bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
spin_unlock_bh(&pernet->lock);
__flush_addrs(sock_net(skb->sk), &free_list);
return 0;
@@ -911,6 +1233,8 @@ static int mptcp_nl_fill_addr(struct sk_buff *skb,
if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
goto nla_put_failure;
+ if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
+ goto nla_put_failure;
if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
goto nla_put_failure;
if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags))
@@ -994,27 +1318,34 @@ static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
struct pm_nl_pernet *pernet;
int id = cb->args[0];
void *hdr;
+ int i;
pernet = net_generic(net, pm_nl_pernet_id);
spin_lock_bh(&pernet->lock);
- list_for_each_entry(entry, &pernet->local_addr_list, list) {
- if (entry->addr.id <= id)
- continue;
-
- hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
- cb->nlh->nlmsg_seq, &mptcp_genl_family,
- NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
- if (!hdr)
- break;
+ for (i = id; i < MAX_ADDR_ID + 1; i++) {
+ if (test_bit(i, pernet->id_bitmap)) {
+ entry = __lookup_addr_by_id(pernet, i);
+ if (!entry)
+ break;
+
+ if (entry->addr.id <= id)
+ continue;
+
+ hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
+ cb->nlh->nlmsg_seq, &mptcp_genl_family,
+ NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
+ if (!hdr)
+ break;
+
+ if (mptcp_nl_fill_addr(msg, entry) < 0) {
+ genlmsg_cancel(msg, hdr);
+ break;
+ }
- if (mptcp_nl_fill_addr(msg, entry) < 0) {
- genlmsg_cancel(msg, hdr);
- break;
+ id = entry->addr.id;
+ genlmsg_end(msg, hdr);
}
-
- id = entry->addr.id;
- genlmsg_end(msg, hdr);
}
spin_unlock_bh(&pernet->lock);
@@ -1096,6 +1427,321 @@ fail:
return -EMSGSIZE;
}
+static int mptcp_nl_addr_backup(struct net *net,
+ struct mptcp_addr_info *addr,
+ u8 bkup)
+{
+ long s_slot = 0, s_num = 0;
+ struct mptcp_sock *msk;
+ int ret = -EINVAL;
+
+ while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
+ struct sock *sk = (struct sock *)msk;
+
+ if (list_empty(&msk->conn_list))
+ goto next;
+
+ lock_sock(sk);
+ spin_lock_bh(&msk->pm.lock);
+ ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
+ spin_unlock_bh(&msk->pm.lock);
+ release_sock(sk);
+
+next:
+ sock_put(sk);
+ cond_resched();
+ }
+
+ return ret;
+}
+
+static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
+{
+ struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
+ struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
+ struct mptcp_pm_addr_entry addr, *entry;
+ struct net *net = sock_net(skb->sk);
+ u8 bkup = 0;
+ int ret;
+
+ ret = mptcp_pm_parse_addr(attr, info, true, &addr);
+ if (ret < 0)
+ return ret;
+
+ if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
+ bkup = 1;
+
+ list_for_each_entry(entry, &pernet->local_addr_list, list) {
+ if (addresses_equal(&entry->addr, &addr.addr, true)) {
+ ret = mptcp_nl_addr_backup(net, &entry->addr, bkup);
+ if (ret)
+ return ret;
+
+ if (bkup)
+ entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
+ else
+ entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
+ }
+ }
+
+ return 0;
+}
+
+static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp)
+{
+ genlmsg_multicast_netns(&mptcp_genl_family, net,
+ nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp);
+}
+
+static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk)
+{
+ const struct inet_sock *issk = inet_sk(ssk);
+ const struct mptcp_subflow_context *sf;
+
+ if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family))
+ return -EMSGSIZE;
+
+ switch (ssk->sk_family) {
+ case AF_INET:
+ if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr))
+ return -EMSGSIZE;
+ if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr))
+ return -EMSGSIZE;
+ break;
+#if IS_ENABLED(CONFIG_MPTCP_IPV6)
+ case AF_INET6: {
+ const struct ipv6_pinfo *np = inet6_sk(ssk);
+
+ if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr))
+ return -EMSGSIZE;
+ if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr))
+ return -EMSGSIZE;
+ break;
+ }
+#endif
+ default:
+ WARN_ON_ONCE(1);
+ return -EMSGSIZE;
+ }
+
+ if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport))
+ return -EMSGSIZE;
+ if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport))
+ return -EMSGSIZE;
+
+ sf = mptcp_subflow_ctx(ssk);
+ if (WARN_ON_ONCE(!sf))
+ return -EINVAL;
+
+ if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id))
+ return -EMSGSIZE;
+
+ if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id))
+ return -EMSGSIZE;
+
+ return 0;
+}
+
+static int mptcp_event_put_token_and_ssk(struct sk_buff *skb,
+ const struct mptcp_sock *msk,
+ const struct sock *ssk)
+{
+ const struct sock *sk = (const struct sock *)msk;
+ const struct mptcp_subflow_context *sf;
+ u8 sk_err;
+
+ if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
+ return -EMSGSIZE;
+
+ if (mptcp_event_add_subflow(skb, ssk))
+ return -EMSGSIZE;
+
+ sf = mptcp_subflow_ctx(ssk);
+ if (WARN_ON_ONCE(!sf))
+ return -EINVAL;
+
+ if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup))
+ return -EMSGSIZE;
+
+ if (ssk->sk_bound_dev_if &&
+ nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if))
+ return -EMSGSIZE;
+
+ sk_err = ssk->sk_err;
+ if (sk_err && sk->sk_state == TCP_ESTABLISHED &&
+ nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err))
+ return -EMSGSIZE;
+
+ return 0;
+}
+
+static int mptcp_event_sub_established(struct sk_buff *skb,
+ const struct mptcp_sock *msk,
+ const struct sock *ssk)
+{
+ return mptcp_event_put_token_and_ssk(skb, msk, ssk);
+}
+
+static int mptcp_event_sub_closed(struct sk_buff *skb,
+ const struct mptcp_sock *msk,
+ const struct sock *ssk)
+{
+ if (mptcp_event_put_token_and_ssk(skb, msk, ssk))
+ return -EMSGSIZE;
+
+ return 0;
+}
+
+static int mptcp_event_created(struct sk_buff *skb,
+ const struct mptcp_sock *msk,
+ const struct sock *ssk)
+{
+ int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token);
+
+ if (err)
+ return err;
+
+ return mptcp_event_add_subflow(skb, ssk);
+}
+
+void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id)
+{
+ struct net *net = sock_net((const struct sock *)msk);
+ struct nlmsghdr *nlh;
+ struct sk_buff *skb;
+
+ if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
+ return;
+
+ skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
+ if (!skb)
+ return;
+
+ nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED);
+ if (!nlh)
+ goto nla_put_failure;
+
+ if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
+ goto nla_put_failure;
+
+ if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id))
+ goto nla_put_failure;
+
+ genlmsg_end(skb, nlh);
+ mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
+ return;
+
+nla_put_failure:
+ kfree_skb(skb);
+}
+
+void mptcp_event_addr_announced(const struct mptcp_sock *msk,
+ const struct mptcp_addr_info *info)
+{
+ struct net *net = sock_net((const struct sock *)msk);
+ struct nlmsghdr *nlh;
+ struct sk_buff *skb;
+
+ if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
+ return;
+
+ skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC);
+ if (!skb)
+ return;
+
+ nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0,
+ MPTCP_EVENT_ANNOUNCED);
+ if (!nlh)
+ goto nla_put_failure;
+
+ if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token))
+ goto nla_put_failure;
+
+ if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id))
+ goto nla_put_failure;
+
+ if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port))
+ goto nla_put_failure;
+
+ switch (info->family) {
+ case AF_INET:
+ if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr))
+ goto nla_put_failure;
+ break;
+#if IS_ENABLED(CONFIG_MPTCP_IPV6)
+ case AF_INET6:
+ if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6))
+ goto nla_put_failure;
+ break;
+#endif
+ default:
+ WARN_ON_ONCE(1);
+ goto nla_put_failure;
+ }
+
+ genlmsg_end(skb, nlh);
+ mptcp_nl_mcast_send(net, skb, GFP_ATOMIC);
+ return;
+
+nla_put_failure:
+ kfree_skb(skb);
+}
+
+void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk,
+ const struct sock *ssk, gfp_t gfp)
+{
+ struct net *net = sock_net((const struct sock *)msk);
+ struct nlmsghdr *nlh;
+ struct sk_buff *skb;
+
+ if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET))
+ return;
+
+ skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp);
+ if (!skb)
+ return;
+
+ nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type);
+ if (!nlh)
+ goto nla_put_failure;
+
+ switch (type) {
+ case MPTCP_EVENT_UNSPEC:
+ WARN_ON_ONCE(1);
+ break;
+ case MPTCP_EVENT_CREATED:
+ case MPTCP_EVENT_ESTABLISHED:
+ if (mptcp_event_created(skb, msk, ssk) < 0)
+ goto nla_put_failure;
+ break;
+ case MPTCP_EVENT_CLOSED:
+ if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0)
+ goto nla_put_failure;
+ break;
+ case MPTCP_EVENT_ANNOUNCED:
+ case MPTCP_EVENT_REMOVED:
+ /* call mptcp_event_addr_announced()/removed instead */
+ WARN_ON_ONCE(1);
+ break;
+ case MPTCP_EVENT_SUB_ESTABLISHED:
+ case MPTCP_EVENT_SUB_PRIORITY:
+ if (mptcp_event_sub_established(skb, msk, ssk) < 0)
+ goto nla_put_failure;
+ break;
+ case MPTCP_EVENT_SUB_CLOSED:
+ if (mptcp_event_sub_closed(skb, msk, ssk) < 0)
+ goto nla_put_failure;
+ break;
+ }
+
+ genlmsg_end(skb, nlh);
+ mptcp_nl_mcast_send(net, skb, gfp);
+ return;
+
+nla_put_failure:
+ kfree_skb(skb);
+}
+
static const struct genl_small_ops mptcp_pm_ops[] = {
{
.cmd = MPTCP_PM_CMD_ADD_ADDR,
@@ -1126,6 +1772,11 @@ static const struct genl_small_ops mptcp_pm_ops[] = {
.cmd = MPTCP_PM_CMD_GET_LIMITS,
.doit = mptcp_nl_cmd_get_limits,
},
+ {
+ .cmd = MPTCP_PM_CMD_SET_FLAGS,
+ .doit = mptcp_nl_cmd_set_flags,
+ .flags = GENL_ADMIN_PERM,
+ },
};
static struct genl_family mptcp_genl_family __ro_after_init = {
@@ -1148,6 +1799,7 @@ static int __net_init pm_nl_init_net(struct net *net)
INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
__reset_counters(pernet);
pernet->next_id = 1;
+ bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
spin_lock_init(&pernet->lock);
return 0;
}