Diffstat:
 -rw-r--r--  net/tls/Makefile               |    2
 -rw-r--r--  net/tls/tls.h                  |  321
 -rw-r--r--  net/tls/tls_device.c           |  345
 -rw-r--r--  net/tls/tls_device_fallback.c  |   90
 -rw-r--r--  net/tls/tls_main.c             |  267
 -rw-r--r--  net/tls/tls_proc.c             |    4
 -rw-r--r--  net/tls/tls_strp.c             |  518
 -rw-r--r--  net/tls/tls_sw.c               | 1182
 -rw-r--r--  net/tls/tls_toe.c              |    2
 9 files changed, 2051 insertions(+), 680 deletions(-)
diff --git a/net/tls/Makefile b/net/tls/Makefile
index f1ffbfe8968d..e41c800489ac 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -7,7 +7,7 @@ CFLAGS_trace.o := -I$(src)
obj-$(CONFIG_TLS) += tls.o
-tls-y := tls_main.o tls_sw.o tls_proc.o trace.o
+tls-y := tls_main.o tls_sw.o tls_proc.o trace.o tls_strp.o
tls-$(CONFIG_TLS_TOE) += tls_toe.o
tls-$(CONFIG_TLS_DEVICE) += tls_device.o tls_device_fallback.o
diff --git a/net/tls/tls.h b/net/tls/tls.h
new file mode 100644
index 000000000000..0e840a0c3437
--- /dev/null
+++ b/net/tls/tls.h
@@ -0,0 +1,321 @@
+/*
+ * Copyright (c) 2016 Tom Herbert <tom@herbertland.com>
+ * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
+ * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the
+ * OpenIB.org BSD license below:
+ *
+ * Redistribution and use in source and binary forms, with or
+ * without modification, are permitted provided that the following
+ * conditions are met:
+ *
+ * - Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * - Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef _TLS_INT_H
+#define _TLS_INT_H
+
+#include <asm/byteorder.h>
+#include <linux/types.h>
+#include <linux/skmsg.h>
+#include <net/tls.h>
+
+#define TLS_PAGE_ORDER (min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER, \
+ TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT))
+
+#define __TLS_INC_STATS(net, field) \
+ __SNMP_INC_STATS((net)->mib.tls_statistics, field)
+#define TLS_INC_STATS(net, field) \
+ SNMP_INC_STATS((net)->mib.tls_statistics, field)
+#define TLS_DEC_STATS(net, field) \
+ SNMP_DEC_STATS((net)->mib.tls_statistics, field)
+
+/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages
+ * allocated or mapped for each TLS record. After encryption, the records
+ * are stored in a linked list.
+ */
+struct tls_rec {
+ struct list_head list;
+ int tx_ready;
+ int tx_flags;
+
+ struct sk_msg msg_plaintext;
+ struct sk_msg msg_encrypted;
+
+ /* AAD | msg_plaintext.sg.data | sg_tag */
+ struct scatterlist sg_aead_in[2];
+ /* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */
+ struct scatterlist sg_aead_out[2];
+
+ char content_type;
+ struct scatterlist sg_content_type;
+
+ char aad_space[TLS_AAD_SPACE_SIZE];
+ u8 iv_data[MAX_IV_SIZE];
+ struct aead_request aead_req;
+ u8 aead_req_ctx[];
+};
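
Editorial note, not part of the patch: the two-entry sg tables above hold the AAD in slot 0 and leave slot 1 open so the plaintext/ciphertext scatterlist can be chained in at send time. A minimal sketch of that setup, mirroring how tls_sw.c assembles it (prot->aad_size assumed from the surrounding code):

    /* Editor's sketch only — mirrors the setup done in tls_sw.c. */
    sg_init_table(rec->sg_aead_in, 2);
    sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
    sg_unmark_end(&rec->sg_aead_in[1]);   /* slot 1: chained to msg_plaintext later */

    sg_init_table(rec->sg_aead_out, 2);
    sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
    sg_unmark_end(&rec->sg_aead_out[1]);  /* slot 1: chained to msg_encrypted later */
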
+
+int __net_init tls_proc_init(struct net *net);
+void __net_exit tls_proc_fini(struct net *net);
+
+struct tls_context *tls_ctx_create(struct sock *sk);
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
+void update_sk_prot(struct sock *sk, struct tls_context *ctx);
+
+int wait_on_pending_writer(struct sock *sk, long *timeo);
+int tls_sk_query(struct sock *sk, int optname, char __user *optval,
+ int __user *optlen);
+int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
+ unsigned int optlen);
+void tls_err_abort(struct sock *sk, int err);
+
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
+void tls_sw_strparser_done(struct tls_context *tls_ctx);
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
+void tls_sw_release_resources_tx(struct sock *sk);
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
+void tls_sw_free_resources_rx(struct sock *sk);
+void tls_sw_release_resources_rx(struct sock *sk);
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
+int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+ int flags, int *addr_len);
+bool tls_sw_sock_is_readable(struct sock *sk);
+ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
+ struct pipe_inode_info *pipe,
+ size_t len, unsigned int flags);
+
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_device_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+int tls_tx_records(struct sock *sk, int flags);
+
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
+
+int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
+ unsigned char *record_type);
+int decrypt_skb(struct sock *sk, struct scatterlist *sgout);
+
+int tls_sw_fallback_init(struct sock *sk,
+ struct tls_offload_context_tx *offload_ctx,
+ struct tls_crypto_info *crypto_info);
+
+int tls_strp_dev_init(void);
+void tls_strp_dev_exit(void);
+
+void tls_strp_done(struct tls_strparser *strp);
+void tls_strp_stop(struct tls_strparser *strp);
+int tls_strp_init(struct tls_strparser *strp, struct sock *sk);
+void tls_strp_data_ready(struct tls_strparser *strp);
+
+void tls_strp_check_rcv(struct tls_strparser *strp);
+void tls_strp_msg_done(struct tls_strparser *strp);
+
+int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb);
+void tls_rx_msg_ready(struct tls_strparser *strp);
+
+void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh);
+int tls_strp_msg_cow(struct tls_sw_context_rx *ctx);
+struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx);
+int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst);
+
+static inline struct tls_msg *tls_msg(struct sk_buff *skb)
+{
+ struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
+
+ return &scb->tls;
+}
+
+static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx)
+{
+ DEBUG_NET_WARN_ON_ONCE(!ctx->strp.msg_ready || !ctx->strp.anchor->len);
+ return ctx->strp.anchor;
+}
+
+static inline bool tls_strp_msg_ready(struct tls_sw_context_rx *ctx)
+{
+ return ctx->strp.msg_ready;
+}
+
+#ifdef CONFIG_TLS_DEVICE
+int tls_device_init(void);
+void tls_device_cleanup(void);
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
+void tls_device_free_resources_tx(struct sock *sk);
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
+void tls_device_offload_cleanup_rx(struct sock *sk);
+void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
+int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx);
+#else
+static inline int tls_device_init(void) { return 0; }
+static inline void tls_device_cleanup(void) {}
+
+static inline int
+tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void tls_device_free_resources_tx(struct sock *sk) {}
+
+static inline int
+tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void tls_device_offload_cleanup_rx(struct sock *sk) {}
+static inline void
+tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
+
+static inline int
+tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
+{
+ return 0;
+}
+#endif
+
+int tls_push_sg(struct sock *sk, struct tls_context *ctx,
+ struct scatterlist *sg, u16 first_offset,
+ int flags);
+int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
+ int flags);
+void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
+
+static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
+{
+ return !!ctx->partially_sent_record;
+}
+
+static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
+{
+ return tls_ctx->pending_open_record_frags;
+}
+
+static inline bool tls_bigint_increment(unsigned char *seq, int len)
+{
+ int i;
+
+ for (i = len - 1; i >= 0; i--) {
+ ++seq[i];
+ if (seq[i] != 0)
+ break;
+ }
+
+ return (i == -1);
+}
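
Worked example (editor's sketch, not in the patch): the buffer is treated as a big-endian counter, and the return value flags a full wrap-around:

    unsigned char seq[8] = { 0, 0, 0, 0, 0, 0, 0, 0xff };

    tls_bigint_increment(seq, sizeof(seq)); /* seq: 00..00 01 00, returns false */

    memset(seq, 0xff, sizeof(seq));
    /* Returns true only when every byte wraps to zero — the sequence
     * space is exhausted, which is why tls_advance_record_sn() below
     * aborts the connection with -EBADMSG in that case.
     */
    tls_bigint_increment(seq, sizeof(seq));
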
+
+static inline void tls_bigint_subtract(unsigned char *seq, int n)
+{
+ u64 rcd_sn;
+ __be64 *p;
+
+ BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8);
+
+ p = (__be64 *)seq;
+ rcd_sn = be64_to_cpu(*p);
+ *p = cpu_to_be64(rcd_sn - n);
+}
+
+static inline void
+tls_advance_record_sn(struct sock *sk, struct tls_prot_info *prot,
+ struct cipher_context *ctx)
+{
+ if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
+ tls_err_abort(sk, -EBADMSG);
+
+ if (prot->version != TLS_1_3_VERSION &&
+ prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
+ tls_bigint_increment(ctx->iv + prot->salt_size,
+ prot->iv_size);
+}
+
+static inline void
+tls_xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq)
+{
+ int i;
+
+ if (prot->version == TLS_1_3_VERSION ||
+ prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
+ for (i = 0; i < 8; i++)
+ iv[i + 4] ^= seq[i];
+ }
+}
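
For reference (editor's note): this is the RFC 8446 §5.3 nonce construction, used by TLS 1.3 and by ChaCha20-Poly1305 in TLS 1.2:

    /* Editor's sketch of the resulting nonce layout:
     *   iv[0..3]    salt, left untouched
     *   iv[4..11]  ^= 64-bit big-endian record sequence number
     * Record 0 therefore uses the static IV verbatim; record 1 flips
     * the low bit of iv[11]; and so on.
     */
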
+
+static inline void
+tls_fill_prepend(struct tls_context *ctx, char *buf, size_t plaintext_len,
+ unsigned char record_type)
+{
+ struct tls_prot_info *prot = &ctx->prot_info;
+ size_t pkt_len, iv_size = prot->iv_size;
+
+ pkt_len = plaintext_len + prot->tag_size;
+ if (prot->version != TLS_1_3_VERSION &&
+ prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) {
+ pkt_len += iv_size;
+
+ memcpy(buf + TLS_NONCE_OFFSET,
+ ctx->tx.iv + prot->salt_size, iv_size);
+ }
+
+ /* we cover the explicit nonce here as well, so buf should be of
+ * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE
+ */
+ buf[0] = prot->version == TLS_1_3_VERSION ?
+ TLS_RECORD_TYPE_DATA : record_type;
+ /* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */
+ buf[1] = TLS_1_2_VERSION_MINOR;
+ buf[2] = TLS_1_2_VERSION_MAJOR;
+ /* we can use the IV as the explicit nonce, according to the spec */
+ buf[3] = pkt_len >> 8;
+ buf[4] = pkt_len & 0xFF;
+}
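
Illustration (editor's example, values assumed): for TLS 1.2 AES-GCM-128 with a 100-byte plaintext, tag_size 16 and iv_size 8, the prepend comes out as:

    /*   buf[0]    = record_type (e.g. 0x17, application data)
     *   buf[1..2] = 0x03 0x03   (wire version pinned to TLS 1.2)
     *   buf[3..4] = 0x00 0x7c   (124 = 100 + 16 tag + 8 explicit nonce)
     * followed at TLS_NONCE_OFFSET by the 8-byte explicit nonce copied
     * from ctx->tx.iv. For TLS 1.3 (and ChaCha20-Poly1305) the explicit
     * nonce is omitted and buf[0] is always "application data", the real
     * content type being hidden inside the encrypted payload.
     */
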
+
+static inline
+void tls_make_aad(char *buf, size_t size, char *record_sequence,
+ unsigned char record_type, struct tls_prot_info *prot)
+{
+ if (prot->version != TLS_1_3_VERSION) {
+ memcpy(buf, record_sequence, prot->rec_seq_size);
+ buf += 8;
+ } else {
+ size += prot->tag_size;
+ }
+
+ buf[0] = prot->version == TLS_1_3_VERSION ?
+ TLS_RECORD_TYPE_DATA : record_type;
+ buf[1] = TLS_1_2_VERSION_MAJOR;
+ buf[2] = TLS_1_2_VERSION_MINOR;
+ buf[3] = size >> 8;
+ buf[4] = size & 0xFF;
+}
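
The two AAD layouts this produces, side by side (editor's sketch):

    /* TLS 1.2, 13 bytes: seq(8) | type(1) | 0x03 0x03 | len(2)
     *   len = @size, the plaintext length.
     * TLS 1.3,  5 bytes:          type(1) | 0x03 0x03 | len(2)
     *   len = @size + tag, i.e. the ciphertext length, and type is
     *   always 0x17, matching the on-wire record header.
     */
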
+
+#endif
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index b932469ee69c..a03d66046ca3 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -38,6 +38,7 @@
#include <net/tcp.h>
#include <net/tls.h>
+#include "tls.h"
#include "trace.h"
/* device_offload_lock is used to synchronize tls_dev_add
@@ -45,10 +46,8 @@
*/
static DECLARE_RWSEM(device_offload_lock);
-static void tls_device_gc_task(struct work_struct *work);
+static struct workqueue_struct *destruct_wq __read_mostly;
-static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
-static LIST_HEAD(tls_device_gc_list);
static LIST_HEAD(tls_device_list);
static LIST_HEAD(tls_device_down_list);
static DEFINE_SPINLOCK(tls_device_lock);
@@ -67,44 +66,58 @@ static void tls_device_free_ctx(struct tls_context *ctx)
tls_ctx_free(NULL, ctx);
}
-static void tls_device_gc_task(struct work_struct *work)
+static void tls_device_tx_del_task(struct work_struct *work)
{
- struct tls_context *ctx, *tmp;
- unsigned long flags;
- LIST_HEAD(gc_list);
-
- spin_lock_irqsave(&tls_device_lock, flags);
- list_splice_init(&tls_device_gc_list, &gc_list);
- spin_unlock_irqrestore(&tls_device_lock, flags);
-
- list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
- struct net_device *netdev = ctx->netdev;
+ struct tls_offload_context_tx *offload_ctx =
+ container_of(work, struct tls_offload_context_tx, destruct_work);
+ struct tls_context *ctx = offload_ctx->ctx;
+ struct net_device *netdev;
- if (netdev && ctx->tx_conf == TLS_HW) {
- netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
- TLS_OFFLOAD_CTX_DIR_TX);
- dev_put(netdev);
- ctx->netdev = NULL;
- }
+ /* Safe, because this is the destroy flow, refcount is 0, so
+ * tls_device_down can't store this field in parallel.
+ */
+ netdev = rcu_dereference_protected(ctx->netdev,
+ !refcount_read(&ctx->refcount));
- list_del(&ctx->list);
- tls_device_free_ctx(ctx);
- }
+ netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX);
+ dev_put(netdev);
+ ctx->netdev = NULL;
+ tls_device_free_ctx(ctx);
}
static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
{
+ struct net_device *netdev;
unsigned long flags;
+ bool async_cleanup;
spin_lock_irqsave(&tls_device_lock, flags);
- list_move_tail(&ctx->list, &tls_device_gc_list);
+ if (unlikely(!refcount_dec_and_test(&ctx->refcount))) {
+ spin_unlock_irqrestore(&tls_device_lock, flags);
+ return;
+ }
+
+ list_del(&ctx->list); /* Remove from tls_device_list / tls_device_down_list */
- /* schedule_work inside the spinlock
- * to make sure tls_device_down waits for that work.
+ /* Safe, because this is the destroy flow, refcount is 0, so
+ * tls_device_down can't store this field in parallel.
*/
- schedule_work(&tls_device_gc_work);
+ netdev = rcu_dereference_protected(ctx->netdev,
+ !refcount_read(&ctx->refcount));
+ async_cleanup = netdev && ctx->tx_conf == TLS_HW;
+ if (async_cleanup) {
+ struct tls_offload_context_tx *offload_ctx = tls_offload_ctx_tx(ctx);
+
+ /* queue_work inside the spinlock
+ * to make sure tls_device_down waits for that work.
+ */
+ queue_work(destruct_wq, &offload_ctx->destruct_work);
+ }
spin_unlock_irqrestore(&tls_device_lock, flags);
+
+ if (!async_cleanup)
+ tls_device_free_ctx(ctx);
}
/* We assume that the socket is already connected */
@@ -194,8 +207,7 @@ void tls_device_sk_destruct(struct sock *sk)
clean_acked_data_disable(inet_csk(sk));
}
- if (refcount_dec_and_test(&tls_ctx->refcount))
- tls_device_queue_ctx_destruction(tls_ctx);
+ tls_device_queue_ctx_destruction(tls_ctx);
}
EXPORT_SYMBOL_GPL(tls_device_sk_destruct);
@@ -231,7 +243,8 @@ static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
trace_tls_device_tx_resync_send(sk, seq, rcd_sn);
down_read(&device_offload_lock);
- netdev = tls_ctx->netdev;
+ netdev = rcu_dereference_protected(tls_ctx->netdev,
+ lockdep_is_held(&device_offload_lock));
if (netdev)
err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
rcd_sn,
@@ -411,10 +424,16 @@ static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
return 0;
}
+union tls_iter_offset {
+ struct iov_iter *msg_iter;
+ int offset;
+};
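
The union is loaded in one of two ways, consolidated here from the call sites below (editor's note):

    /* Copy path — data comes from an iov_iter:
     *     iter.msg_iter = &msg->msg_iter;
     *     tls_push_data(sk, iter, size, flags, record_type, NULL);
     * Zero-copy sendfile path — data stays in the caller's page:
     *     iter.offset = offset;
     *     tls_push_data(sk, iter, size, flags, record_type, page);
     * A non-NULL zc_page argument selects which member is valid.
     */
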
+
static int tls_push_data(struct sock *sk,
- struct iov_iter *msg_iter,
+ union tls_iter_offset iter_offset,
size_t size, int flags,
- unsigned char record_type)
+ unsigned char record_type,
+ struct page *zc_page)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -480,14 +499,25 @@ handle_error:
}
record = ctx->open_record;
- copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
- copy = min_t(size_t, copy, (max_open_record_len - record->len));
- rc = tls_device_copy_data(page_address(pfrag->page) +
- pfrag->offset, copy, msg_iter);
- if (rc)
- goto handle_error;
- tls_append_frag(record, pfrag, copy);
+ copy = min_t(size_t, size, max_open_record_len - record->len);
+ if (copy && zc_page) {
+ struct page_frag zc_pfrag;
+
+ zc_pfrag.page = zc_page;
+ zc_pfrag.offset = iter_offset.offset;
+ zc_pfrag.size = copy;
+ tls_append_frag(record, &zc_pfrag, copy);
+ } else if (copy) {
+ copy = min_t(size_t, copy, pfrag->size - pfrag->offset);
+
+ rc = tls_device_copy_data(page_address(pfrag->page) +
+ pfrag->offset, copy,
+ iter_offset.msg_iter);
+ if (rc)
+ goto handle_error;
+ tls_append_frag(record, pfrag, copy);
+ }
size -= copy;
if (!size) {
@@ -538,19 +568,20 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
unsigned char record_type = TLS_RECORD_TYPE_DATA;
struct tls_context *tls_ctx = tls_get_ctx(sk);
+ union tls_iter_offset iter;
int rc;
mutex_lock(&tls_ctx->tx_lock);
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
- rc = tls_proccess_cmsg(sk, msg, &record_type);
+ rc = tls_process_cmsg(sk, msg, &record_type);
if (rc)
goto out;
}
- rc = tls_push_data(sk, &msg->msg_iter, size,
- msg->msg_flags, record_type);
+ iter.msg_iter = &msg->msg_iter;
+ rc = tls_push_data(sk, iter, size, msg->msg_flags, record_type, NULL);
out:
release_sock(sk);
@@ -562,7 +593,8 @@ int tls_device_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
- struct iov_iter msg_iter;
+ union tls_iter_offset iter_offset;
+ struct iov_iter msg_iter;
char *kaddr;
struct kvec iov;
int rc;
@@ -578,12 +610,20 @@ int tls_device_sendpage(struct sock *sk, struct page *page,
goto out;
}
+ if (tls_ctx->zerocopy_sendfile) {
+ iter_offset.offset = offset;
+ rc = tls_push_data(sk, iter_offset, size,
+ flags, TLS_RECORD_TYPE_DATA, page);
+ goto out;
+ }
+
kaddr = kmap(page);
iov.iov_base = kaddr + offset;
iov.iov_len = size;
iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
- rc = tls_push_data(sk, &msg_iter, size,
- flags, TLS_RECORD_TYPE_DATA);
+ iter_offset.msg_iter = &msg_iter;
+ rc = tls_push_data(sk, iter_offset, size, flags, TLS_RECORD_TYPE_DATA,
+ NULL);
kunmap(page);
out:
@@ -654,10 +694,12 @@ EXPORT_SYMBOL(tls_get_record);
static int tls_device_push_pending_record(struct sock *sk, int flags)
{
- struct iov_iter msg_iter;
+ union tls_iter_offset iter;
+ struct iov_iter msg_iter;
iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
- return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
+ iter.msg_iter = &msg_iter;
+ return tls_push_data(sk, iter, 0, flags, TLS_RECORD_TYPE_DATA, NULL);
}
void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
@@ -683,7 +725,7 @@ static void tls_device_resync_rx(struct tls_context *tls_ctx,
trace_tls_device_rx_resync_send(sk, seq, rcd_sn, rx_ctx->resync_type);
rcu_read_lock();
- netdev = READ_ONCE(tls_ctx->netdev);
+ netdev = rcu_dereference(tls_ctx->netdev);
if (netdev)
netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
TLS_OFFLOAD_CTX_DIR_RX);
@@ -859,43 +901,56 @@ static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
}
}
-static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
+static int
+tls_device_reencrypt(struct sock *sk, struct tls_context *tls_ctx)
{
- struct strp_msg *rxm = strp_msg(skb);
- int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
- struct sk_buff *skb_iter, *unused;
+ struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
+ const struct tls_cipher_size_desc *cipher_sz;
+ int err, offset, copy, data_len, pos;
+ struct sk_buff *skb, *skb_iter;
struct scatterlist sg[1];
+ struct strp_msg *rxm;
char *orig_buf, *buf;
- orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation);
+ switch (tls_ctx->crypto_recv.info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ case TLS_CIPHER_AES_GCM_256:
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_recv.info.cipher_type];
+
+ rxm = strp_msg(tls_strp_msg(sw_ctx));
+ orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv,
+ sk->sk_allocation);
if (!orig_buf)
return -ENOMEM;
buf = orig_buf;
- nsg = skb_cow_data(skb, 0, &unused);
- if (unlikely(nsg < 0)) {
- err = nsg;
+ err = tls_strp_msg_cow(sw_ctx);
+ if (unlikely(err))
goto free_buf;
- }
+
+ skb = tls_strp_msg(sw_ctx);
+ rxm = strp_msg(skb);
+ offset = rxm->offset;
sg_init_table(sg, 1);
sg_set_buf(&sg[0], buf,
- rxm->full_len + TLS_HEADER_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE);
- err = skb_copy_bits(skb, offset, buf,
- TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+ rxm->full_len + TLS_HEADER_SIZE + cipher_sz->iv);
+ err = skb_copy_bits(skb, offset, buf, TLS_HEADER_SIZE + cipher_sz->iv);
if (err)
goto free_buf;
/* We are interested only in the decrypted data not the auth */
- err = decrypt_skb(sk, skb, sg);
+ err = decrypt_skb(sk, sg);
if (err != -EBADMSG)
goto free_buf;
else
err = 0;
- data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ data_len = rxm->full_len - cipher_sz->tag;
if (skb_pagelen(skb) > offset) {
copy = min_t(int, skb_pagelen(skb) - offset, data_len);
@@ -944,35 +999,41 @@ free_buf:
return err;
}
-int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
- struct sk_buff *skb, struct strp_msg *rxm)
+int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx)
{
struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx);
+ struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx);
+ struct sk_buff *skb = tls_strp_msg(sw_ctx);
+ struct strp_msg *rxm = strp_msg(skb);
int is_decrypted = skb->decrypted;
int is_encrypted = !is_decrypted;
struct sk_buff *skb_iter;
+ int left;
+ left = rxm->full_len - skb->len;
/* Check if all the data is decrypted already */
- skb_walk_frags(skb, skb_iter) {
+ skb_iter = skb_shinfo(skb)->frag_list;
+ while (skb_iter && left > 0) {
is_decrypted &= skb_iter->decrypted;
is_encrypted &= !skb_iter->decrypted;
+
+ left -= skb_iter->len;
+ skb_iter = skb_iter->next;
}
trace_tls_device_decrypted(sk, tcp_sk(sk)->copied_seq - rxm->full_len,
tls_ctx->rx.rec_seq, rxm->full_len,
is_encrypted, is_decrypted);
- ctx->sw.decrypted |= is_decrypted;
-
if (unlikely(test_bit(TLS_RX_DEV_DEGRADED, &tls_ctx->flags))) {
if (likely(is_encrypted || is_decrypted))
- return 0;
+ return is_decrypted;
/* After tls_device_down disables the offload, the next SKB will
* likely have initial fragments decrypted, and final ones not
* decrypted. We need to reencrypt that single SKB.
*/
- return tls_device_reencrypt(sk, skb);
+ return tls_device_reencrypt(sk, tls_ctx);
}
/* Return immediately if the record is either entirely plaintext or
@@ -981,7 +1042,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
*/
if (is_decrypted) {
ctx->resync_nh_reset = 1;
- return 0;
+ return is_decrypted;
}
if (is_encrypted) {
tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
@@ -989,7 +1050,7 @@ int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
}
ctx->resync_nh_reset = 1;
- return tls_device_reencrypt(sk, skb);
+ return tls_device_reencrypt(sk, tls_ctx);
}
static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
@@ -998,7 +1059,7 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
if (sk->sk_destruct != tls_device_sk_destruct) {
refcount_set(&ctx->refcount, 1);
dev_hold(netdev);
- ctx->netdev = netdev;
+ RCU_INIT_POINTER(ctx->netdev, netdev);
spin_lock_irq(&tls_device_lock);
list_add_tail(&ctx->list, &tls_device_list);
spin_unlock_irq(&tls_device_lock);
@@ -1010,9 +1071,9 @@ static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
{
- u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
+ const struct tls_cipher_size_desc *cipher_sz;
struct tls_record_info *start_marker_record;
struct tls_offload_context_tx *offload_ctx;
struct tls_crypto_info *crypto_info;
@@ -1028,70 +1089,83 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
if (ctx->priv_ctx_tx)
return -EEXIST;
- start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
- if (!start_marker_record)
- return -ENOMEM;
+ netdev = get_netdev_for_sock(sk);
+ if (!netdev) {
+ pr_err_ratelimited("%s: netdev not found\n", __func__);
+ return -EINVAL;
+ }
- offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
- if (!offload_ctx) {
- rc = -ENOMEM;
- goto free_marker_record;
+ if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
+ rc = -EOPNOTSUPP;
+ goto release_netdev;
}
crypto_info = &ctx->crypto_send.info;
if (crypto_info->version != TLS_1_2_VERSION) {
rc = -EOPNOTSUPP;
- goto free_offload_ctx;
+ goto release_netdev;
}
switch (crypto_info->cipher_type) {
case TLS_CIPHER_AES_GCM_128:
- nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
- tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
- iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
- rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
- salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
rec_seq =
((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
break;
+ case TLS_CIPHER_AES_GCM_256:
+ iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+ rec_seq =
+ ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+ break;
default:
rc = -EINVAL;
- goto free_offload_ctx;
+ goto release_netdev;
}
+ cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
/* Sanity-check the rec_seq_size for stack allocations */
- if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+ if (cipher_sz->rec_seq > TLS_MAX_REC_SEQ_SIZE) {
rc = -EINVAL;
- goto free_offload_ctx;
+ goto release_netdev;
}
prot->version = crypto_info->version;
prot->cipher_type = crypto_info->cipher_type;
- prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
- prot->tag_size = tag_size;
+ prot->prepend_size = TLS_HEADER_SIZE + cipher_sz->iv;
+ prot->tag_size = cipher_sz->tag;
prot->overhead_size = prot->prepend_size + prot->tag_size;
- prot->iv_size = iv_size;
- prot->salt_size = salt_size;
- ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
- GFP_KERNEL);
+ prot->iv_size = cipher_sz->iv;
+ prot->salt_size = cipher_sz->salt;
+ ctx->tx.iv = kmalloc(cipher_sz->iv + cipher_sz->salt, GFP_KERNEL);
if (!ctx->tx.iv) {
rc = -ENOMEM;
- goto free_offload_ctx;
+ goto release_netdev;
}
- memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+ memcpy(ctx->tx.iv + cipher_sz->salt, iv, cipher_sz->iv);
- prot->rec_seq_size = rec_seq_size;
- ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
+ prot->rec_seq_size = cipher_sz->rec_seq;
+ ctx->tx.rec_seq = kmemdup(rec_seq, cipher_sz->rec_seq, GFP_KERNEL);
if (!ctx->tx.rec_seq) {
rc = -ENOMEM;
goto free_iv;
}
+ start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
+ if (!start_marker_record) {
+ rc = -ENOMEM;
+ goto free_rec_seq;
+ }
+
+ offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
+ if (!offload_ctx) {
+ rc = -ENOMEM;
+ goto free_marker_record;
+ }
+
rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info);
if (rc)
- goto free_rec_seq;
+ goto free_offload_ctx;
/* start at rec_seq - 1 to account for the start marker record */
memcpy(&rcd_sn, ctx->tx.rec_seq, sizeof(rcd_sn));
@@ -1101,6 +1175,9 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
start_marker_record->len = 0;
start_marker_record->num_frags = 0;
+ INIT_WORK(&offload_ctx->destruct_work, tls_device_tx_del_task);
+ offload_ctx->ctx = ctx;
+
INIT_LIST_HEAD(&offload_ctx->records_list);
list_add_tail(&start_marker_record->list, &offload_ctx->records_list);
spin_lock_init(&offload_ctx->lock);
@@ -1118,18 +1195,6 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
if (skb)
TCP_SKB_CB(skb)->eor = 1;
- netdev = get_netdev_for_sock(sk);
- if (!netdev) {
- pr_err_ratelimited("%s: netdev not found\n", __func__);
- rc = -EINVAL;
- goto disable_cad;
- }
-
- if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
- rc = -EOPNOTSUPP;
- goto release_netdev;
- }
-
/* Avoid offloading if the device is down
* We don't want to offload new flows after
* the NETDEV_DOWN event
@@ -1167,20 +1232,19 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
release_lock:
up_read(&device_offload_lock);
-release_netdev:
- dev_put(netdev);
-disable_cad:
clean_acked_data_disable(inet_csk(sk));
crypto_free_aead(offload_ctx->aead_send);
-free_rec_seq:
- kfree(ctx->tx.rec_seq);
-free_iv:
- kfree(ctx->tx.iv);
free_offload_ctx:
kfree(offload_ctx);
ctx->priv_ctx_tx = NULL;
free_marker_record:
kfree(start_marker_record);
+free_rec_seq:
+ kfree(ctx->tx.rec_seq);
+free_iv:
+ kfree(ctx->tx.iv);
+release_netdev:
+ dev_put(netdev);
return rc;
}
@@ -1266,7 +1330,8 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
struct net_device *netdev;
down_read(&device_offload_lock);
- netdev = tls_ctx->netdev;
+ netdev = rcu_dereference_protected(tls_ctx->netdev,
+ lockdep_is_held(&device_offload_lock));
if (!netdev)
goto out;
@@ -1275,7 +1340,7 @@ void tls_device_offload_cleanup_rx(struct sock *sk)
if (tls_ctx->tx_conf != TLS_HW) {
dev_put(netdev);
- tls_ctx->netdev = NULL;
+ rcu_assign_pointer(tls_ctx->netdev, NULL);
} else {
set_bit(TLS_RX_DEV_CLOSED, &tls_ctx->flags);
}
@@ -1295,7 +1360,11 @@ static int tls_device_down(struct net_device *netdev)
spin_lock_irqsave(&tls_device_lock, flags);
list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) {
- if (ctx->netdev != netdev ||
+ struct net_device *ctx_netdev =
+ rcu_dereference_protected(ctx->netdev,
+ lockdep_is_held(&device_offload_lock));
+
+ if (ctx_netdev != netdev ||
!refcount_inc_not_zero(&ctx->refcount))
continue;
@@ -1312,7 +1381,7 @@ static int tls_device_down(struct net_device *netdev)
/* Stop the RX and TX resync.
* tls_dev_resync must not be called after tls_dev_del.
*/
- WRITE_ONCE(ctx->netdev, NULL);
+ rcu_assign_pointer(ctx->netdev, NULL);
/* Start skipping the RX resync logic completely. */
set_bit(TLS_RX_DEV_DEGRADED, &ctx->flags);
@@ -1345,12 +1414,20 @@ static int tls_device_down(struct net_device *netdev)
/* Device contexts for RX and TX will be freed on sk_destruct
* by tls_device_free_ctx. rx_conf and tx_conf stay in TLS_HW.
+ * Now release the ref taken above.
*/
+ if (refcount_dec_and_test(&ctx->refcount)) {
+ /* sk_destruct ran after tls_device_down took a ref, and
+ * it returned early. Complete the destruction here.
+ */
+ list_del(&ctx->list);
+ tls_device_free_ctx(ctx);
+ }
}
up_write(&device_offload_lock);
- flush_work(&tls_device_gc_work);
+ flush_workqueue(destruct_wq);
return NOTIFY_DONE;
}
@@ -1389,14 +1466,24 @@ static struct notifier_block tls_dev_notifier = {
.notifier_call = tls_dev_event,
};
-void __init tls_device_init(void)
+int __init tls_device_init(void)
{
- register_netdevice_notifier(&tls_dev_notifier);
+ int err;
+
+ destruct_wq = alloc_workqueue("ktls_device_destruct", 0, 0);
+ if (!destruct_wq)
+ return -ENOMEM;
+
+ err = register_netdevice_notifier(&tls_dev_notifier);
+ if (err)
+ destroy_workqueue(destruct_wq);
+
+ return err;
}
void __exit tls_device_cleanup(void)
{
unregister_netdevice_notifier(&tls_dev_notifier);
- flush_work(&tls_device_gc_work);
+ destroy_workqueue(destruct_wq);
clean_acked_data_flush();
}
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index e40bedd112b6..cdb391a8754b 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -34,6 +34,8 @@
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>
+#include "tls.h"
+
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
struct scatterlist *src = walk->sg;
@@ -52,13 +54,25 @@ static int tls_enc_record(struct aead_request *aead_req,
struct scatter_walk *out, int *in_len,
struct tls_prot_info *prot)
{
- unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE];
+ unsigned char buf[TLS_HEADER_SIZE + MAX_IV_SIZE];
+ const struct tls_cipher_size_desc *cipher_sz;
struct scatterlist sg_in[3];
struct scatterlist sg_out[3];
+ unsigned int buf_size;
u16 len;
int rc;
- len = min_t(int, *in_len, ARRAY_SIZE(buf));
+ switch (prot->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ case TLS_CIPHER_AES_GCM_256:
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[prot->cipher_type];
+
+ buf_size = TLS_HEADER_SIZE + cipher_sz->iv;
+ len = min_t(int, *in_len, buf_size);
scatterwalk_copychunks(buf, in, len, 0);
scatterwalk_copychunks(buf, out, len, 1);
@@ -71,13 +85,11 @@ static int tls_enc_record(struct aead_request *aead_req,
scatterwalk_pagedone(out, 1, 1);
len = buf[4] | (buf[3] << 8);
- len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ len -= cipher_sz->iv;
- tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
- (char *)&rcd_sn, buf[0], prot);
+ tls_make_aad(aad, len - cipher_sz->tag, (char *)&rcd_sn, buf[0], prot);
- memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
- TLS_CIPHER_AES_GCM_128_IV_SIZE);
+ memcpy(iv + cipher_sz->salt, buf + TLS_HEADER_SIZE, cipher_sz->iv);
sg_init_table(sg_in, ARRAY_SIZE(sg_in));
sg_init_table(sg_out, ARRAY_SIZE(sg_out));
@@ -88,7 +100,7 @@ static int tls_enc_record(struct aead_request *aead_req,
*in_len -= len;
if (*in_len < 0) {
- *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ *in_len += cipher_sz->tag;
/* the input buffer doesn't contain the entire record.
* trim len accordingly. The resulting authentication tag
* will contain garbage, but we don't care, so we won't
@@ -109,7 +121,7 @@ static int tls_enc_record(struct aead_request *aead_req,
scatterwalk_pagedone(out, 1, 1);
}
- len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ len -= cipher_sz->tag;
aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv);
rc = crypto_aead_encrypt(aead_req);
@@ -232,7 +244,7 @@ static int fill_sg_in(struct scatterlist *sg_in,
s32 *sync_size,
int *resync_sgs)
{
- int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+ int tcp_payload_offset = skb_tcp_all_headers(skb);
int payload_len = skb->len - tcp_payload_offset;
u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
struct tls_record_info *record;
@@ -297,11 +309,14 @@ static void fill_sg_out(struct scatterlist sg_out[3], void *buf,
int sync_size,
void *dummy_buf)
{
+ const struct tls_cipher_size_desc *cipher_sz =
+ &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type];
+
sg_set_buf(&sg_out[0], dummy_buf, sync_size);
sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, payload_len);
/* Add room for authentication tag produced by crypto */
dummy_buf += sync_size;
- sg_set_buf(&sg_out[2], dummy_buf, TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ sg_set_buf(&sg_out[2], dummy_buf, cipher_sz->tag);
}
static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
@@ -310,10 +325,11 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
struct sk_buff *skb,
s32 sync_size, u64 rcd_sn)
{
- int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
+ int tcp_payload_offset = skb_tcp_all_headers(skb);
int payload_len = skb->len - tcp_payload_offset;
- void *buf, *iv, *aad, *dummy_buf;
+ const struct tls_cipher_size_desc *cipher_sz;
+ void *buf, *iv, *aad, *dummy_buf, *salt;
struct aead_request *aead_req;
struct sk_buff *nskb = NULL;
int buf_len;
@@ -322,20 +338,26 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
if (!aead_req)
return NULL;
- buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE +
- TLS_AAD_SPACE_SIZE +
- sync_size +
- TLS_CIPHER_AES_GCM_128_TAG_SIZE;
+ switch (tls_ctx->crypto_send.info.cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ salt = tls_ctx->crypto_send.aes_gcm_128.salt;
+ break;
+ case TLS_CIPHER_AES_GCM_256:
+ salt = tls_ctx->crypto_send.aes_gcm_256.salt;
+ break;
+ default:
+ return NULL;
+ }
+ cipher_sz = &tls_cipher_size_desc[tls_ctx->crypto_send.info.cipher_type];
+ buf_len = cipher_sz->salt + cipher_sz->iv + TLS_AAD_SPACE_SIZE +
+ sync_size + cipher_sz->tag;
buf = kmalloc(buf_len, GFP_ATOMIC);
if (!buf)
goto free_req;
iv = buf;
- memcpy(iv, tls_ctx->crypto_send.aes_gcm_128.salt,
- TLS_CIPHER_AES_GCM_128_SALT_SIZE);
- aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE +
- TLS_CIPHER_AES_GCM_128_IV_SIZE;
+ memcpy(iv, salt, cipher_sz->salt);
+ aad = buf + cipher_sz->salt + cipher_sz->iv;
dummy_buf = aad + TLS_AAD_SPACE_SIZE;
nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC);
@@ -372,7 +394,7 @@ free_nskb:
static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
- int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+ int tcp_payload_offset = skb_tcp_all_headers(skb);
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int payload_len = skb->len - tcp_payload_offset;
@@ -424,7 +446,8 @@ struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
struct net_device *dev,
struct sk_buff *skb)
{
- if (dev == tls_get_ctx(sk)->netdev || netif_is_bond_master(dev))
+ if (dev == rcu_dereference_bh(tls_get_ctx(sk)->netdev) ||
+ netif_is_bond_master(dev))
return skb;
return tls_sw_fallback(sk, skb);
@@ -448,6 +471,7 @@ int tls_sw_fallback_init(struct sock *sk,
struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info)
{
+ const struct tls_cipher_size_desc *cipher_sz;
const u8 *key;
int rc;
@@ -460,15 +484,23 @@ int tls_sw_fallback_init(struct sock *sk,
goto err_out;
}
- key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+ switch (crypto_info->cipher_type) {
+ case TLS_CIPHER_AES_GCM_128:
+ key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key;
+ break;
+ case TLS_CIPHER_AES_GCM_256:
+ key = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->key;
+ break;
+ default:
+ return -EINVAL;
+ }
+ cipher_sz = &tls_cipher_size_desc[crypto_info->cipher_type];
- rc = crypto_aead_setkey(offload_ctx->aead_send, key,
- TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+ rc = crypto_aead_setkey(offload_ctx->aead_send, key, cipher_sz->key);
if (rc)
goto free_aead;
- rc = crypto_aead_setauthsize(offload_ctx->aead_send,
- TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+ rc = crypto_aead_setauthsize(offload_ctx->aead_send, cipher_sz->tag);
if (rc)
goto free_aead;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 6bc2879ba637..3735cb00905d 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -45,6 +45,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
+#include "tls.h"
+
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
@@ -56,6 +58,23 @@ enum {
TLS_NUM_PROTS,
};
+#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
+ .iv = cipher ## _IV_SIZE, \
+ .key = cipher ## _KEY_SIZE, \
+ .salt = cipher ## _SALT_SIZE, \
+ .tag = cipher ## _TAG_SIZE, \
+ .rec_seq = cipher ## _REC_SEQ_SIZE, \
+}
+
+const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
+ CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
+ CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
+ CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
+};
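
The token-pasting macro pulls every field from the uapi size constants; one entry expands as follows (editor's expansion):

    [TLS_CIPHER_AES_GCM_128] = {
            .iv      = TLS_CIPHER_AES_GCM_128_IV_SIZE,
            .key     = TLS_CIPHER_AES_GCM_128_KEY_SIZE,
            .salt    = TLS_CIPHER_AES_GCM_128_SALT_SIZE,
            .tag     = TLS_CIPHER_AES_GCM_128_TAG_SIZE,
            .rec_seq = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
    },
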
+
static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
@@ -164,8 +183,8 @@ static int tls_handle_open_record(struct sock *sk, int flags)
return 0;
}
-int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
- unsigned char *record_type)
+int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
+ unsigned char *record_type)
{
struct cmsghdr *cmsg;
int rc = -EINVAL;
@@ -505,6 +524,54 @@ static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
rc = -EFAULT;
break;
}
+ case TLS_CIPHER_ARIA_GCM_128: {
+ struct tls12_crypto_info_aria_gcm_128 *
+ crypto_info_aria_gcm_128 =
+ container_of(crypto_info,
+ struct tls12_crypto_info_aria_gcm_128,
+ info);
+
+ if (len != sizeof(*crypto_info_aria_gcm_128)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(crypto_info_aria_gcm_128->iv,
+ cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE,
+ TLS_CIPHER_ARIA_GCM_128_IV_SIZE);
+ memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval,
+ crypto_info_aria_gcm_128,
+ sizeof(*crypto_info_aria_gcm_128)))
+ rc = -EFAULT;
+ break;
+ }
+ case TLS_CIPHER_ARIA_GCM_256: {
+ struct tls12_crypto_info_aria_gcm_256 *
+ crypto_info_aria_gcm_256 =
+ container_of(crypto_info,
+ struct tls12_crypto_info_aria_gcm_256,
+ info);
+
+ if (len != sizeof(*crypto_info_aria_gcm_256)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(crypto_info_aria_gcm_256->iv,
+ cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE,
+ TLS_CIPHER_ARIA_GCM_256_IV_SIZE);
+ memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval,
+ crypto_info_aria_gcm_256,
+ sizeof(*crypto_info_aria_gcm_256)))
+ rc = -EFAULT;
+ break;
+ }
default:
rc = -EINVAL;
}
@@ -513,6 +580,56 @@ out:
return rc;
}
+static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
+ int __user *optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ unsigned int value;
+ int len;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+
+ if (len != sizeof(value))
+ return -EINVAL;
+
+ value = ctx->zerocopy_sendfile;
+ if (copy_to_user(optval, &value, sizeof(value)))
+ return -EFAULT;
+
+ return 0;
+}
+
+static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
+ int __user *optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ int value, len;
+
+ if (ctx->prot_info.version != TLS_1_3_VERSION)
+ return -EINVAL;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+ if (len < sizeof(value))
+ return -EINVAL;
+
+ lock_sock(sk);
+ value = -EINVAL;
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+ value = ctx->rx_no_pad;
+ release_sock(sk);
+ if (value < 0)
+ return value;
+
+ if (put_user(sizeof(value), optlen))
+ return -EFAULT;
+ if (copy_to_user(optval, &value, sizeof(value)))
+ return -EFAULT;
+
+ return 0;
+}
+
static int do_tls_getsockopt(struct sock *sk, int optname,
char __user *optval, int __user *optlen)
{
@@ -524,6 +641,12 @@ static int do_tls_getsockopt(struct sock *sk, int optname,
rc = do_tls_getsockopt_conf(sk, optval, optlen,
optname == TLS_TX);
break;
+ case TLS_TX_ZEROCOPY_RO:
+ rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
+ break;
+ case TLS_RX_EXPECT_NO_PAD:
+ rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -553,10 +676,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
int rc = 0;
int conf;
- if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) {
- rc = -EINVAL;
- goto out;
- }
+ if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
+ return -EINVAL;
if (tx) {
crypto_info = &ctx->crypto_send.info;
@@ -567,10 +688,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
}
/* Currently we don't support setting crypto info more than once */
- if (TLS_CRYPTO_INFO_READY(crypto_info)) {
- rc = -EBUSY;
- goto out;
- }
+ if (TLS_CRYPTO_INFO_READY(crypto_info))
+ return -EBUSY;
rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
if (rc) {
@@ -614,6 +733,20 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
case TLS_CIPHER_SM4_CCM:
optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
break;
+ case TLS_CIPHER_ARIA_GCM_128:
+ if (crypto_info->version != TLS_1_2_VERSION) {
+ rc = -EINVAL;
+ goto err_crypto_info;
+ }
+ optsize = sizeof(struct tls12_crypto_info_aria_gcm_128);
+ break;
+ case TLS_CIPHER_ARIA_GCM_256:
+ if (crypto_info->version != TLS_1_2_VERSION) {
+ rc = -EINVAL;
+ goto err_crypto_info;
+ }
+ optsize = sizeof(struct tls12_crypto_info_aria_gcm_256);
+ break;
default:
rc = -EINVAL;
goto err_crypto_info;
@@ -671,12 +804,67 @@ static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
if (tx) {
ctx->sk_write_space = sk->sk_write_space;
sk->sk_write_space = tls_write_space;
+ } else {
+ struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);
+
+ tls_strp_check_rcv(&rx_ctx->strp);
}
- goto out;
+ return 0;
err_crypto_info:
memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
-out:
+ return rc;
+}
+
+static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
+ unsigned int optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ unsigned int value;
+
+ if (sockptr_is_null(optval) || optlen != sizeof(value))
+ return -EINVAL;
+
+ if (copy_from_sockptr(&value, optval, sizeof(value)))
+ return -EFAULT;
+
+ if (value > 1)
+ return -EINVAL;
+
+ ctx->zerocopy_sendfile = value;
+
+ return 0;
+}
+
+static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
+ unsigned int optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ u32 val;
+ int rc;
+
+ if (ctx->prot_info.version != TLS_1_3_VERSION ||
+ sockptr_is_null(optval) || optlen < sizeof(val))
+ return -EINVAL;
+
+ rc = copy_from_sockptr(&val, optval, sizeof(val));
+ if (rc)
+ return -EFAULT;
+ if (val > 1)
+ return -EINVAL;
+ rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
+ if (rc < 1)
+ return rc == 0 ? -EINVAL : rc;
+
+ lock_sock(sk);
+ rc = -EINVAL;
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
+ ctx->rx_no_pad = val;
+ tls_update_rx_zc_capable(ctx);
+ rc = 0;
+ }
+ release_sock(sk);
+
return rc;
}
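
From userspace both new knobs are plain integer setsockopts on an established kTLS socket; a minimal sketch (fd and session setup assumed):

    int one = 1;

    /* Opt in to zero-copy sendfile on a device-offloaded TX socket.
     * Pages are then mapped directly, so the file must not change
     * while records are in flight.
     */
    setsockopt(fd, SOL_TLS, TLS_TX_ZEROCOPY_RO, &one, sizeof(one));

    /* Declare that the TLS 1.3 peer sends no record padding, which
     * lets the RX path keep its zero-copy fast path (see
     * tls_update_rx_zc_capable()).
     */
    setsockopt(fd, SOL_TLS, TLS_RX_EXPECT_NO_PAD, &one, sizeof(one));
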
@@ -693,6 +881,14 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
optname == TLS_TX);
release_sock(sk);
break;
+ case TLS_TX_ZEROCOPY_RO:
+ lock_sock(sk);
+ rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
+ release_sock(sk);
+ break;
+ case TLS_RX_EXPECT_NO_PAD:
+ rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -878,6 +1074,8 @@ static void tls_update(struct sock *sk, struct proto *p,
{
struct tls_context *ctx;
+ WARN_ON_ONCE(sk->sk_prot == p);
+
ctx = tls_get_ctx(sk);
if (likely(ctx)) {
ctx->sk_write_space = write_space;
@@ -889,6 +1087,23 @@ static void tls_update(struct sock *sk, struct proto *p,
}
}
+static u16 tls_user_config(struct tls_context *ctx, bool tx)
+{
+ u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
+
+ switch (config) {
+ case TLS_BASE:
+ return TLS_CONF_BASE;
+ case TLS_SW:
+ return TLS_CONF_SW;
+ case TLS_HW:
+ return TLS_CONF_HW;
+ case TLS_HW_RECORD:
+ return TLS_CONF_HW_RECORD;
+ }
+ return 0;
+}
+
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
u16 version, cipher_type;
@@ -926,6 +1141,17 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
if (err)
goto nla_failure;
+ if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
+ err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
+ if (err)
+ goto nla_failure;
+ }
+ if (ctx->rx_no_pad) {
+ err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
+ if (err)
+ goto nla_failure;
+ }
+
rcu_read_unlock();
nla_nest_end(skb, start);
return 0;
@@ -945,6 +1171,8 @@ static size_t tls_get_info_size(const struct sock *sk)
nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */
nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */
nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
+ nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */
+ nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */
0;
return size;
@@ -996,15 +1224,28 @@ static int __init tls_register(void)
if (err)
return err;
- tls_device_init();
+ err = tls_strp_dev_init();
+ if (err)
+ goto err_pernet;
+
+ err = tls_device_init();
+ if (err)
+ goto err_strp;
+
tcp_register_ulp(&tcp_tls_ulp_ops);
return 0;
+err_strp:
+ tls_strp_dev_exit();
+err_pernet:
+ unregister_pernet_subsys(&tls_proc_ops);
+ return err;
}
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
+ tls_strp_dev_exit();
tls_device_cleanup();
unregister_pernet_subsys(&tls_proc_ops);
}
diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c
index feeceb0e4cb4..68982728f620 100644
--- a/net/tls/tls_proc.c
+++ b/net/tls/tls_proc.c
@@ -6,6 +6,8 @@
#include <net/snmp.h>
#include <net/tls.h>
+#include "tls.h"
+
#ifdef CONFIG_PROC_FS
static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsCurrTxSw", LINUX_MIB_TLSCURRTXSW),
@@ -18,6 +20,8 @@ static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsRxDevice", LINUX_MIB_TLSRXDEVICE),
SNMP_MIB_ITEM("TlsDecryptError", LINUX_MIB_TLSDECRYPTERROR),
SNMP_MIB_ITEM("TlsRxDeviceResync", LINUX_MIB_TLSRXDEVICERESYNC),
+ SNMP_MIB_ITEM("TlsDecryptRetry", LINUX_MIB_TLSDECRYPTRETRY),
+ SNMP_MIB_ITEM("TlsRxNoPadViolation", LINUX_MIB_TLSRXNOPADVIOL),
SNMP_MIB_SENTINEL
};
diff --git a/net/tls/tls_strp.c b/net/tls/tls_strp.c
new file mode 100644
index 000000000000..955ac3e0bf4d
--- /dev/null
+++ b/net/tls/tls_strp.c
@@ -0,0 +1,518 @@
+// SPDX-License-Identifier: GPL-2.0-only
+/* Copyright (c) 2016 Tom Herbert <tom@herbertland.com> */
+
+#include <linux/skbuff.h>
+#include <linux/workqueue.h>
+#include <net/strparser.h>
+#include <net/tcp.h>
+#include <net/sock.h>
+#include <net/tls.h>
+
+#include "tls.h"
+
+static struct workqueue_struct *tls_strp_wq;
+
+static void tls_strp_abort_strp(struct tls_strparser *strp, int err)
+{
+ if (strp->stopped)
+ return;
+
+ strp->stopped = 1;
+
+ /* Report an error on the lower socket */
+ strp->sk->sk_err = -err;
+ sk_error_report(strp->sk);
+}
+
+static void tls_strp_anchor_free(struct tls_strparser *strp)
+{
+ struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
+
+ DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
+ shinfo->frag_list = NULL;
+ consume_skb(strp->anchor);
+ strp->anchor = NULL;
+}
+
+/* Create a new skb with the contents of input copied to its page frags */
+static struct sk_buff *tls_strp_msg_make_copy(struct tls_strparser *strp)
+{
+ struct strp_msg *rxm;
+ struct sk_buff *skb;
+ int i, err, offset;
+
+ skb = alloc_skb_with_frags(0, strp->stm.full_len, TLS_PAGE_ORDER,
+ &err, strp->sk->sk_allocation);
+ if (!skb)
+ return NULL;
+
+ offset = strp->stm.offset;
+ for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+ skb_frag_t *frag = &skb_shinfo(skb)->frags[i];
+
+ WARN_ON_ONCE(skb_copy_bits(strp->anchor, offset,
+ skb_frag_address(frag),
+ skb_frag_size(frag)));
+ offset += skb_frag_size(frag);
+ }
+
+ skb_copy_header(skb, strp->anchor);
+ rxm = strp_msg(skb);
+ rxm->offset = 0;
+ return skb;
+}
+
+/* Steal the input skb; the input msg is invalid after calling this function */
+struct sk_buff *tls_strp_msg_detach(struct tls_sw_context_rx *ctx)
+{
+ struct tls_strparser *strp = &ctx->strp;
+
+#ifdef CONFIG_TLS_DEVICE
+ DEBUG_NET_WARN_ON_ONCE(!strp->anchor->decrypted);
+#else
+ /* This function turns an input into an output,
+ * that can only happen if we have offload.
+ */
+ WARN_ON(1);
+#endif
+
+ if (strp->copy_mode) {
+ struct sk_buff *skb;
+
+ /* Replace anchor with an empty skb, this is a little
+ * dangerous but __tls_cur_msg() warns on empty skbs
+ * so hopefully we'll catch abuses.
+ */
+ skb = alloc_skb(0, strp->sk->sk_allocation);
+ if (!skb)
+ return NULL;
+
+ swap(strp->anchor, skb);
+ return skb;
+ }
+
+ return tls_strp_msg_make_copy(strp);
+}
+
+/* Force the input skb to be in copy mode. The data ownership remains
+ * with the input skb itself (meaning unpause will wipe it) but it can
+ * be modified.
+ */
+int tls_strp_msg_cow(struct tls_sw_context_rx *ctx)
+{
+ struct tls_strparser *strp = &ctx->strp;
+ struct sk_buff *skb;
+
+ if (strp->copy_mode)
+ return 0;
+
+ skb = tls_strp_msg_make_copy(strp);
+ if (!skb)
+ return -ENOMEM;
+
+ tls_strp_anchor_free(strp);
+ strp->anchor = skb;
+
+ tcp_read_done(strp->sk, strp->stm.full_len);
+ strp->copy_mode = 1;
+
+ return 0;
+}
+
+/* Make a clone (in the skb sense) of the input msg to keep a reference
+ * to the underlying data. The reference-holding skbs get placed on
+ * @dst.
+ */
+int tls_strp_msg_hold(struct tls_strparser *strp, struct sk_buff_head *dst)
+{
+ struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
+
+ if (strp->copy_mode) {
+ struct sk_buff *skb;
+
+ WARN_ON_ONCE(!shinfo->nr_frags);
+
+ /* We can't skb_clone() the anchor, it gets wiped by unpause */
+ skb = alloc_skb(0, strp->sk->sk_allocation);
+ if (!skb)
+ return -ENOMEM;
+
+ __skb_queue_tail(dst, strp->anchor);
+ strp->anchor = skb;
+ } else {
+ struct sk_buff *iter, *clone;
+ int chunk, len, offset;
+
+ offset = strp->stm.offset;
+ len = strp->stm.full_len;
+ iter = shinfo->frag_list;
+
+ while (len > 0) {
+ if (iter->len <= offset) {
+ offset -= iter->len;
+ goto next;
+ }
+
+ chunk = iter->len - offset;
+ offset = 0;
+
+ clone = skb_clone(iter, strp->sk->sk_allocation);
+ if (!clone)
+ return -ENOMEM;
+ __skb_queue_tail(dst, clone);
+
+ len -= chunk;
+next:
+ iter = iter->next;
+ }
+ }
+
+ return 0;
+}
+
+static void tls_strp_flush_anchor_copy(struct tls_strparser *strp)
+{
+ struct skb_shared_info *shinfo = skb_shinfo(strp->anchor);
+ int i;
+
+ DEBUG_NET_WARN_ON_ONCE(atomic_read(&shinfo->dataref) != 1);
+
+ for (i = 0; i < shinfo->nr_frags; i++)
+ __skb_frag_unref(&shinfo->frags[i], false);
+ shinfo->nr_frags = 0;
+ strp->copy_mode = 0;
+}
+
+static int tls_strp_copyin(read_descriptor_t *desc, struct sk_buff *in_skb,
+ unsigned int offset, size_t in_len)
+{
+ struct tls_strparser *strp = (struct tls_strparser *)desc->arg.data;
+ struct sk_buff *skb;
+ skb_frag_t *frag;
+ size_t len, chunk;
+ int sz;
+
+ if (strp->msg_ready)
+ return 0;
+
+ skb = strp->anchor;
+ frag = &skb_shinfo(skb)->frags[skb->len / PAGE_SIZE];
+
+ len = in_len;
+ /* First make sure we got the header */
+ if (!strp->stm.full_len) {
+ /* Assume one page is more than enough for headers */
+ chunk = min_t(size_t, len, PAGE_SIZE - skb_frag_size(frag));
+ WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
+ skb_frag_address(frag) +
+ skb_frag_size(frag),
+ chunk));
+
+ sz = tls_rx_msg_size(strp, strp->anchor);
+ if (sz < 0) {
+ desc->error = sz;
+ return 0;
+ }
+
+ /* We may have over-read; sz == 0 means a guaranteed under-read */
+ if (sz > 0)
+ chunk = min_t(size_t, chunk, sz - skb->len);
+
+ skb->len += chunk;
+ skb->data_len += chunk;
+ skb_frag_size_add(frag, chunk);
+ frag++;
+ len -= chunk;
+ offset += chunk;
+
+ strp->stm.full_len = sz;
+ if (!strp->stm.full_len)
+ goto read_done;
+ }
+
+ /* Load up more data */
+ while (len && strp->stm.full_len > skb->len) {
+ chunk = min_t(size_t, len, strp->stm.full_len - skb->len);
+ chunk = min_t(size_t, chunk, PAGE_SIZE - skb_frag_size(frag));
+ WARN_ON_ONCE(skb_copy_bits(in_skb, offset,
+ skb_frag_address(frag) +
+ skb_frag_size(frag),
+ chunk));
+
+ skb->len += chunk;
+ skb->data_len += chunk;
+ skb_frag_size_add(frag, chunk);
+ frag++;
+ len -= chunk;
+ offset += chunk;
+ }
+
+ if (strp->stm.full_len == skb->len) {
+ desc->count = 0;
+
+ strp->msg_ready = 1;
+ tls_rx_msg_ready(strp);
+ }
+
+read_done:
+ return in_len - len;
+}
+
+static int tls_strp_read_copyin(struct tls_strparser *strp)
+{
+ struct socket *sock = strp->sk->sk_socket;
+ read_descriptor_t desc;
+
+ desc.arg.data = strp;
+ desc.error = 0;
+ desc.count = 1; /* give more than one skb per call */
+
+ /* sk should be locked here, so okay to do read_sock */
+ sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin);
+
+ return desc.error;
+}
+
+static int tls_strp_read_copy(struct tls_strparser *strp, bool qshort)
+{
+ struct skb_shared_info *shinfo;
+ struct page *page;
+ int need_spc, len;
+
+ /* If the rbuf is small or the rcv window has collapsed to 0 we need
+ * to read the data out. Otherwise the connection will stall.
+ * Without memory pressure a threshold of INT_MAX will never be reached.
+ */
+ if (likely(qshort && !tcp_epollin_ready(strp->sk, INT_MAX)))
+ return 0;
+
+ shinfo = skb_shinfo(strp->anchor);
+ shinfo->frag_list = NULL;
+
+ /* If we don't know the length, go with the max record size plus a page for cipher overhead */
+ need_spc = strp->stm.full_len ?: TLS_MAX_PAYLOAD_SIZE + PAGE_SIZE;
+
+ for (len = need_spc; len > 0; len -= PAGE_SIZE) {
+ page = alloc_page(strp->sk->sk_allocation);
+ if (!page) {
+ tls_strp_flush_anchor_copy(strp);
+ return -ENOMEM;
+ }
+
+ skb_fill_page_desc(strp->anchor, shinfo->nr_frags++,
+ page, 0, 0);
+ }
+
+ strp->copy_mode = 1;
+ strp->stm.offset = 0;
+
+ strp->anchor->len = 0;
+ strp->anchor->data_len = 0;
+ strp->anchor->truesize = round_up(need_spc, PAGE_SIZE);
+
+ tls_strp_read_copyin(strp);
+
+ return 0;
+}
+
+static bool tls_strp_check_no_dup(struct tls_strparser *strp)
+{
+ unsigned int len = strp->stm.offset + strp->stm.full_len;
+ struct sk_buff *skb;
+ u32 seq;
+
+ skb = skb_shinfo(strp->anchor)->frag_list;
+ seq = TCP_SKB_CB(skb)->seq;
+
+ while (skb->len < len) {
+ seq += skb->len;
+ len -= skb->len;
+ skb = skb->next;
+
+ if (TCP_SKB_CB(skb)->seq != seq)
+ return false;
+ }
+
+ return true;
+}
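A worked example of the contiguity check above: if the frag_list holds skbs of (seq 1000, len 500) followed by (seq 1600, len 500), the walk expects the second skb at seq 1500, so the mismatch returns false and the caller falls back to the copy path rather than risk decrypting overlapping retransmit data.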
+
+static void tls_strp_load_anchor_with_queue(struct tls_strparser *strp, int len)
+{
+ struct tcp_sock *tp = tcp_sk(strp->sk);
+ struct sk_buff *first;
+ u32 offset;
+
+ first = tcp_recv_skb(strp->sk, tp->copied_seq, &offset);
+ if (WARN_ON_ONCE(!first))
+ return;
+
+ /* Bestow the state onto the anchor */
+ strp->anchor->len = offset + len;
+ strp->anchor->data_len = offset + len;
+ strp->anchor->truesize = offset + len;
+
+ skb_shinfo(strp->anchor)->frag_list = first;
+
+ skb_copy_header(strp->anchor, first);
+ strp->anchor->destructor = NULL;
+
+ strp->stm.offset = offset;
+}
+
+void tls_strp_msg_load(struct tls_strparser *strp, bool force_refresh)
+{
+ struct strp_msg *rxm;
+ struct tls_msg *tlm;
+
+ DEBUG_NET_WARN_ON_ONCE(!strp->msg_ready);
+ DEBUG_NET_WARN_ON_ONCE(!strp->stm.full_len);
+
+ if (!strp->copy_mode && force_refresh) {
+ if (WARN_ON(tcp_inq(strp->sk) < strp->stm.full_len))
+ return;
+
+ tls_strp_load_anchor_with_queue(strp, strp->stm.full_len);
+ }
+
+ rxm = strp_msg(strp->anchor);
+ rxm->full_len = strp->stm.full_len;
+ rxm->offset = strp->stm.offset;
+ tlm = tls_msg(strp->anchor);
+ tlm->control = strp->mark;
+}
+
+/* Called with lock held on lower socket */
+static int tls_strp_read_sock(struct tls_strparser *strp)
+{
+ int sz, inq;
+
+ inq = tcp_inq(strp->sk);
+ if (inq < 1)
+ return 0;
+
+ if (unlikely(strp->copy_mode))
+ return tls_strp_read_copyin(strp);
+
+ if (inq < strp->stm.full_len)
+ return tls_strp_read_copy(strp, true);
+
+ if (!strp->stm.full_len) {
+ tls_strp_load_anchor_with_queue(strp, inq);
+
+ sz = tls_rx_msg_size(strp, strp->anchor);
+ if (sz < 0) {
+ tls_strp_abort_strp(strp, sz);
+ return sz;
+ }
+
+ strp->stm.full_len = sz;
+
+ if (!strp->stm.full_len || inq < strp->stm.full_len)
+ return tls_strp_read_copy(strp, true);
+ }
+
+ if (!tls_strp_check_no_dup(strp))
+ return tls_strp_read_copy(strp, false);
+
+ strp->msg_ready = 1;
+ tls_rx_msg_ready(strp);
+
+ return 0;
+}
+
+void tls_strp_check_rcv(struct tls_strparser *strp)
+{
+ if (unlikely(strp->stopped) || strp->msg_ready)
+ return;
+
+ if (tls_strp_read_sock(strp) == -ENOMEM)
+ queue_work(tls_strp_wq, &strp->work);
+}
+
+/* Lower sock lock held */
+void tls_strp_data_ready(struct tls_strparser *strp)
+{
+ /* This check is needed to synchronize with do_tls_strp_work.
+ * do_tls_strp_work acquires a process lock (lock_sock) whereas
+ * the lock held here is bh_lock_sock. The two locks can be
+ * held by different threads at the same time, but bh_lock_sock
+ * allows a thread in BH context to safely check if the process
+ * lock is held. In this case, if the lock is held, queue work.
+ */
+ if (sock_owned_by_user_nocheck(strp->sk)) {
+ queue_work(tls_strp_wq, &strp->work);
+ return;
+ }
+
+ tls_strp_check_rcv(strp);
+}
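The same BH-vs-process synchronization, reduced to a sketch (hypothetical names, not part of this patch):

/* A BH handler may only peek at the socket owner flag; work that needs
 * lock_sock() is deferred to process context via a workqueue.
 */
static void demo_parse(struct sock *sk);        /* runs under lock_sock() */
static struct work_struct demo_deferred;        /* queues demo_parse() */

static void demo_data_ready(struct sock *sk)    /* called with bh_lock_sock */
{
        if (sock_owned_by_user_nocheck(sk)) {
                schedule_work(&demo_deferred);  /* retry in process context */
                return;
        }
        demo_parse(sk);         /* owner flag clear: safe to parse now */
}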
+
+static void tls_strp_work(struct work_struct *w)
+{
+ struct tls_strparser *strp =
+ container_of(w, struct tls_strparser, work);
+
+ lock_sock(strp->sk);
+ tls_strp_check_rcv(strp);
+ release_sock(strp->sk);
+}
+
+void tls_strp_msg_done(struct tls_strparser *strp)
+{
+ WARN_ON(!strp->stm.full_len);
+
+ if (likely(!strp->copy_mode))
+ tcp_read_done(strp->sk, strp->stm.full_len);
+ else
+ tls_strp_flush_anchor_copy(strp);
+
+ strp->msg_ready = 0;
+ memset(&strp->stm, 0, sizeof(strp->stm));
+
+ tls_strp_check_rcv(strp);
+}
+
+void tls_strp_stop(struct tls_strparser *strp)
+{
+ strp->stopped = 1;
+}
+
+int tls_strp_init(struct tls_strparser *strp, struct sock *sk)
+{
+ memset(strp, 0, sizeof(*strp));
+
+ strp->sk = sk;
+
+ strp->anchor = alloc_skb(0, GFP_KERNEL);
+ if (!strp->anchor)
+ return -ENOMEM;
+
+ INIT_WORK(&strp->work, tls_strp_work);
+
+ return 0;
+}
+
+/* strp must already be stopped so that the parser will no longer be invoked.
+ * Note that tls_strp_done is not called with the lower socket held.
+ */
+void tls_strp_done(struct tls_strparser *strp)
+{
+ WARN_ON(!strp->stopped);
+
+ cancel_work_sync(&strp->work);
+ tls_strp_anchor_free(strp);
+}
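Taken together, the lifecycle these helpers assume is roughly the following (a sketch of the ordering that tls_strp_done() asserts with WARN_ON(!strp->stopped)):

static int demo_rx_setup(struct tls_strparser *strp, struct sock *sk)
{
        return tls_strp_init(strp, sk);  /* allocates anchor, arms work item */
}

static void demo_rx_teardown(struct tls_strparser *strp)
{
        tls_strp_stop(strp);    /* must precede done: no new parsing */
        tls_strp_done(strp);    /* cancels work, frees the anchor */
}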
+
+int __init tls_strp_dev_init(void)
+{
+ tls_strp_wq = create_workqueue("tls-strp");
+ if (unlikely(!tls_strp_wq))
+ return -ENOMEM;
+
+ return 0;
+}
+
+void tls_strp_dev_exit(void)
+{
+ destroy_workqueue(tls_strp_wq);
+}
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index dfe623a4e72f..264cf367e265 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -44,6 +44,25 @@
#include <net/strparser.h>
#include <net/tls.h>
+#include "tls.h"
+
+struct tls_decrypt_arg {
+ struct_group(inargs,
+ bool zc;
+ bool async;
+ u8 tail;
+ );
+
+ struct sk_buff *skb;
+};
+
+struct tls_decrypt_ctx {
+ u8 iv[MAX_IV_SIZE];
+ u8 aad[TLS_MAX_AAD_SIZE];
+ u8 tail;
+ struct scatterlist sg[];
+};
+
noinline void tls_err_abort(struct sock *sk, int err)
{
WARN_ON_ONCE(err >= 0);
@@ -128,32 +147,32 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
return __skb_nsg(skb, offset, len, 0);
}
-static int padding_length(struct tls_sw_context_rx *ctx,
- struct tls_prot_info *prot, struct sk_buff *skb)
+static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
+ struct tls_decrypt_arg *darg)
{
struct strp_msg *rxm = strp_msg(skb);
+ struct tls_msg *tlm = tls_msg(skb);
int sub = 0;
/* Determine zero-padding length */
if (prot->version == TLS_1_3_VERSION) {
- char content_type = 0;
+ int offset = rxm->full_len - TLS_TAG_SIZE - 1;
+ char content_type = darg->zc ? darg->tail : 0;
int err;
- int back = 17;
while (content_type == 0) {
- if (back > rxm->full_len - prot->prepend_size)
+ if (offset < prot->prepend_size)
return -EBADMSG;
- err = skb_copy_bits(skb,
- rxm->offset + rxm->full_len - back,
+ err = skb_copy_bits(skb, rxm->offset + offset,
&content_type, 1);
if (err)
return err;
if (content_type)
break;
sub++;
- back++;
+ offset--;
}
- ctx->control = content_type;
+ tlm->control = content_type;
}
return sub;
}
@@ -165,45 +184,22 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
struct tls_context *tls_ctx;
- struct tls_prot_info *prot;
struct scatterlist *sg;
- struct sk_buff *skb;
unsigned int pages;
- int pending;
+ struct sock *sk;
- skb = (struct sk_buff *)req->data;
- tls_ctx = tls_get_ctx(skb->sk);
+ sk = (struct sock *)req->data;
+ tls_ctx = tls_get_ctx(sk);
ctx = tls_sw_ctx_rx(tls_ctx);
- prot = &tls_ctx->prot_info;
/* Propagate if there was an err */
if (err) {
if (err == -EBADMSG)
- TLS_INC_STATS(sock_net(skb->sk),
- LINUX_MIB_TLSDECRYPTERROR);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
ctx->async_wait.err = err;
- tls_err_abort(skb->sk, err);
- } else {
- struct strp_msg *rxm = strp_msg(skb);
- int pad;
-
- pad = padding_length(ctx, prot, skb);
- if (pad < 0) {
- ctx->async_wait.err = pad;
- tls_err_abort(skb->sk, pad);
- } else {
- rxm->full_len -= pad;
- rxm->offset += prot->prepend_size;
- rxm->full_len -= prot->overhead_size;
- }
+ tls_err_abort(sk, err);
}
- /* After using skb->sk to propagate sk through crypto async callback
- * we need to NULL it again.
- */
- skb->sk = NULL;
-
-
/* Free the destination pages if skb was not decrypted inplace */
if (sgout != sgin) {
/* Skip the first S/G entry as it points to AAD */
@@ -217,21 +213,18 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
kfree(aead_req);
spin_lock_bh(&ctx->decrypt_compl_lock);
- pending = atomic_dec_return(&ctx->decrypt_pending);
-
- if (!pending && ctx->async_notify)
+ if (!atomic_dec_return(&ctx->decrypt_pending))
complete(&ctx->async_wait.completion);
spin_unlock_bh(&ctx->decrypt_compl_lock);
}
static int tls_do_decryption(struct sock *sk,
- struct sk_buff *skb,
struct scatterlist *sgin,
struct scatterlist *sgout,
char *iv_recv,
size_t data_len,
struct aead_request *aead_req,
- bool async)
+ struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
@@ -244,17 +237,10 @@ static int tls_do_decryption(struct sock *sk,
data_len + prot->tag_size,
(u8 *)iv_recv);
- if (async) {
- /* Using skb->sk to push sk through to crypto async callback
- * handler. This allows propagating errors up to the socket
- * if needed. It _must_ be cleared in the async handler
- * before consume_skb is called. We _know_ skb->sk is NULL
- * because it is a clone from strparser.
- */
- skb->sk = sk;
+ if (darg->async) {
aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG,
- tls_decrypt_done, skb);
+ tls_decrypt_done, sk);
atomic_inc(&ctx->decrypt_pending);
} else {
aead_request_set_callback(aead_req,
@@ -264,14 +250,12 @@ static int tls_do_decryption(struct sock *sk,
ret = crypto_aead_decrypt(aead_req);
if (ret == -EINPROGRESS) {
- if (async)
- return ret;
+ if (darg->async)
+ return 0;
ret = crypto_wait_req(ret, &ctx->async_wait);
}
-
- if (async)
- atomic_dec(&ctx->decrypt_pending);
+ darg->async = false;
return ret;
}
@@ -521,7 +505,8 @@ static int tls_do_encryption(struct sock *sk,
memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
prot->iv_size + prot->salt_size);
- xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
+ tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
+ tls_ctx->tx.rec_seq);
sge->offset += prot->prepend_size;
sge->length -= prot->prepend_size;
@@ -958,7 +943,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
- ret = tls_proccess_cmsg(sk, msg, &record_type);
+ ret = tls_process_cmsg(sk, msg, &record_type);
if (ret) {
if (ret == -EINPROGRESS)
num_async++;
@@ -1296,65 +1281,67 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
return ret;
}
-static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
- bool nonblock, long timeo, int *err)
+static int
+tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
+ bool released)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- struct sk_buff *skb;
DEFINE_WAIT_FUNC(wait, woken_wake_function);
+ long timeo;
- while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
- if (sk->sk_err) {
- *err = sock_error(sk);
- return NULL;
- }
+ timeo = sock_rcvtimeo(sk, nonblock);
+
+ while (!tls_strp_msg_ready(ctx)) {
+ if (!sk_psock_queue_empty(psock))
+ return 0;
+
+ if (sk->sk_err)
+ return sock_error(sk);
if (!skb_queue_empty(&sk->sk_receive_queue)) {
- __strp_unpause(&ctx->strp);
- if (ctx->recv_pkt)
- return ctx->recv_pkt;
+ tls_strp_check_rcv(&ctx->strp);
+ if (tls_strp_msg_ready(ctx))
+ break;
}
if (sk->sk_shutdown & RCV_SHUTDOWN)
- return NULL;
+ return 0;
if (sock_flag(sk, SOCK_DONE))
- return NULL;
+ return 0;
- if (nonblock || !timeo) {
- *err = -EAGAIN;
- return NULL;
- }
+ if (!timeo)
+ return -EAGAIN;
+ released = true;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo,
- ctx->recv_pkt != skb ||
+ tls_strp_msg_ready(ctx) ||
!sk_psock_queue_empty(psock),
&wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
/* Handle signals */
- if (signal_pending(current)) {
- *err = sock_intr_errno(timeo);
- return NULL;
- }
+ if (signal_pending(current))
+ return sock_intr_errno(timeo);
}
- return skb;
+ tls_strp_msg_load(&ctx->strp, released);
+
+ return 1;
}
-static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
+static int tls_setup_from_iter(struct iov_iter *from,
int length, int *pages_used,
- unsigned int *size_used,
struct scatterlist *to,
int to_max_pages)
{
int rc = 0, i = 0, num_elem = *pages_used, maxpages;
struct page *pages[MAX_SKB_FRAGS];
- unsigned int size = *size_used;
+ unsigned int size = 0;
ssize_t copied, use;
size_t offset;
@@ -1365,7 +1352,7 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
rc = -EFAULT;
goto out;
}
- copied = iov_iter_get_pages(from, pages,
+ copied = iov_iter_get_pages2(from, pages,
length,
maxpages, &offset);
if (copied <= 0) {
@@ -1373,8 +1360,6 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
goto out;
}
- iov_iter_advance(from, copied);
-
length -= copied;
size += copied;
while (copied) {
@@ -1397,246 +1382,363 @@ static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
sg_mark_end(&to[num_elem - 1]);
out:
if (rc)
- iov_iter_revert(from, size - *size_used);
- *size_used = size;
+ iov_iter_revert(from, size);
*pages_used = num_elem;
return rc;
}
+static struct sk_buff *
+tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb,
+ unsigned int full_len)
+{
+ struct strp_msg *clr_rxm;
+ struct sk_buff *clr_skb;
+ int err;
+
+ clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER,
+ &err, sk->sk_allocation);
+ if (!clr_skb)
+ return NULL;
+
+ skb_copy_header(clr_skb, skb);
+ clr_skb->len = full_len;
+ clr_skb->data_len = full_len;
+
+ clr_rxm = strp_msg(clr_skb);
+ clr_rxm->offset = 0;
+
+ return clr_skb;
+}
+
+/* Decrypt handlers
+ *
+ * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers.
+ * They must transform the darg in/out argument as follows:
+ * | Input | Output
+ * -------------------------------------------------------------------
+ * zc | Zero-copy decrypt allowed | Zero-copy performed
+ * async | Async decrypt allowed | Async crypto used / in progress
+ * skb | * | Output skb
+ *
+ * If ZC decryption was performed darg.skb will point to the input skb.
+ */
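A hypothetical caller sketch of that contract (mirroring how tls_sw_recvmsg drives the handlers below; process_inline() is an illustrative placeholder):

/* Set inputs, call a handler, read outputs -- per the table above. */
struct tls_decrypt_arg darg;
int err;

memset(&darg.inargs, 0, sizeof(darg.inargs));
darg.zc = true;                 /* in: zero-copy decrypt allowed */
darg.async = false;             /* in: force synchronous crypto */

err = tls_decrypt_sw(sk, tls_ctx, msg, &darg);
if (!err && darg.zc)            /* out: plaintext landed in msg directly */
        process_inline(darg.skb);       /* darg.skb == input skb in ZC case */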
+
/* This function decrypts the input skb into either out_iov or in out_sg
- * or in skb buffers itself. The input parameter 'zc' indicates if
+ * or in the skb buffers themselves. The input parameter 'darg->zc' indicates if
* zero-copy mode needs to be tried or not. With zero-copy mode, either
* out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
* NULL, then the decryption happens inside skb buffers itself, i.e.
- * zero-copy gets disabled and 'zc' is updated.
+ * zero-copy gets disabled and 'darg->zc' is updated.
*/
-
-static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
- struct iov_iter *out_iov,
- struct scatterlist *out_sg,
- int *chunk, bool *zc, bool async)
+static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
+ struct scatterlist *out_sg,
+ struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
- struct strp_msg *rxm = strp_msg(skb);
- int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
+ int n_sgin, n_sgout, aead_size, err, pages = 0;
+ struct sk_buff *skb = tls_strp_msg(ctx);
+ const struct strp_msg *rxm = strp_msg(skb);
+ const struct tls_msg *tlm = tls_msg(skb);
struct aead_request *aead_req;
- struct sk_buff *unused;
- u8 *aad, *iv, *mem = NULL;
struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL;
- const int data_len = rxm->full_len - prot->overhead_size +
- prot->tail_size;
+ const int data_len = rxm->full_len - prot->overhead_size;
+ int tail_pages = !!prot->tail_size;
+ struct tls_decrypt_ctx *dctx;
+ struct sk_buff *clear_skb;
int iv_offset = 0;
+ u8 *mem;
+
+ n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+ rxm->full_len - prot->prepend_size);
+ if (n_sgin < 1)
+ return n_sgin ?: -EBADMSG;
+
+ if (darg->zc && (out_iov || out_sg)) {
+ clear_skb = NULL;
- if (*zc && (out_iov || out_sg)) {
if (out_iov)
- n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
+ n_sgout = 1 + tail_pages +
+ iov_iter_npages_cap(out_iov, INT_MAX, data_len);
else
n_sgout = sg_nents(out_sg);
- n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
- rxm->full_len - prot->prepend_size);
} else {
- n_sgout = 0;
- *zc = false;
- n_sgin = skb_cow_data(skb, 0, &unused);
- }
+ darg->zc = false;
- if (n_sgin < 1)
- return -EBADMSG;
+ clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len);
+ if (!clear_skb)
+ return -ENOMEM;
+
+ n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags;
+ }
/* Increment to accommodate AAD */
n_sgin = n_sgin + 1;
- nsg = n_sgin + n_sgout;
-
- aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
- mem_size = aead_size + (nsg * sizeof(struct scatterlist));
- mem_size = mem_size + prot->aad_size;
- mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
-
/* Allocate a single block of memory which contains
- * aead_req || sgin[] || sgout[] || aad || iv.
- * This order achieves correct alignment for aead_req, sgin, sgout.
+ * aead_req || tls_decrypt_ctx.
+ * Both structs are variable length.
*/
- mem = kmalloc(mem_size, sk->sk_allocation);
- if (!mem)
- return -ENOMEM;
+ aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+ mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
+ sk->sk_allocation);
+ if (!mem) {
+ err = -ENOMEM;
+ goto exit_free_skb;
+ }
/* Segment the allocated memory */
aead_req = (struct aead_request *)mem;
- sgin = (struct scatterlist *)(mem + aead_size);
- sgout = sgin + n_sgin;
- aad = (u8 *)(sgout + n_sgout);
- iv = aad + prot->aad_size;
+ dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
+ sgin = &dctx->sg[0];
+ sgout = &dctx->sg[n_sgin];
/* For CCM based ciphers, first byte of nonce+iv is a constant */
switch (prot->cipher_type) {
case TLS_CIPHER_AES_CCM_128:
- iv[0] = TLS_AES_CCM_IV_B0_BYTE;
+ dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
iv_offset = 1;
break;
case TLS_CIPHER_SM4_CCM:
- iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
+ dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
iv_offset = 1;
break;
}
/* Prepare IV */
- err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
- iv + iv_offset + prot->salt_size,
- prot->iv_size);
- if (err < 0) {
- kfree(mem);
- return err;
- }
if (prot->version == TLS_1_3_VERSION ||
- prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305)
- memcpy(iv + iv_offset, tls_ctx->rx.iv,
- crypto_aead_ivsize(ctx->aead_recv));
- else
- memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
-
- xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
+ prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
+ memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
+ prot->iv_size + prot->salt_size);
+ } else {
+ err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+ &dctx->iv[iv_offset] + prot->salt_size,
+ prot->iv_size);
+ if (err < 0)
+ goto exit_free;
+ memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
+ }
+ tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
/* Prepare AAD */
- tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+ tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
prot->tail_size,
- tls_ctx->rx.rec_seq, ctx->control, prot);
+ tls_ctx->rx.rec_seq, tlm->control, prot);
/* Prepare sgin */
sg_init_table(sgin, n_sgin);
- sg_set_buf(&sgin[0], aad, prot->aad_size);
+ sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
err = skb_to_sgvec(skb, &sgin[1],
rxm->offset + prot->prepend_size,
rxm->full_len - prot->prepend_size);
- if (err < 0) {
- kfree(mem);
- return err;
- }
-
- if (n_sgout) {
- if (out_iov) {
- sg_init_table(sgout, n_sgout);
- sg_set_buf(&sgout[0], aad, prot->aad_size);
-
- *chunk = 0;
- err = tls_setup_from_iter(sk, out_iov, data_len,
- &pages, chunk, &sgout[1],
- (n_sgout - 1));
- if (err < 0)
- goto fallback_to_reg_recv;
- } else if (out_sg) {
- memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
- } else {
- goto fallback_to_reg_recv;
+ if (err < 0)
+ goto exit_free;
+
+ if (clear_skb) {
+ sg_init_table(sgout, n_sgout);
+ sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
+
+ err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size,
+ data_len + prot->tail_size);
+ if (err < 0)
+ goto exit_free;
+ } else if (out_iov) {
+ sg_init_table(sgout, n_sgout);
+ sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
+
+ err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1],
+ (n_sgout - 1 - tail_pages));
+ if (err < 0)
+ goto exit_free_pages;
+
+ if (prot->tail_size) {
+ sg_unmark_end(&sgout[pages]);
+ sg_set_buf(&sgout[pages + 1], &dctx->tail,
+ prot->tail_size);
+ sg_mark_end(&sgout[pages + 1]);
}
- } else {
-fallback_to_reg_recv:
- sgout = sgin;
- pages = 0;
- *chunk = data_len;
- *zc = false;
+ } else if (out_sg) {
+ memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
}
/* Prepare and submit AEAD request */
- err = tls_do_decryption(sk, skb, sgin, sgout, iv,
- data_len, aead_req, async);
- if (err == -EINPROGRESS)
+ err = tls_do_decryption(sk, sgin, sgout, dctx->iv,
+ data_len + prot->tail_size, aead_req, darg);
+ if (err)
+ goto exit_free_pages;
+
+ darg->skb = clear_skb ?: tls_strp_msg(ctx);
+ clear_skb = NULL;
+
+ if (unlikely(darg->async)) {
+ err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold);
+ if (err)
+ __skb_queue_tail(&ctx->async_hold, darg->skb);
return err;
+ }
+
+ if (prot->tail_size)
+ darg->tail = dctx->tail;
+exit_free_pages:
/* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--)
put_page(sg_page(&sgout[pages]));
-
+exit_free:
kfree(mem);
+exit_free_skb:
+ consume_skb(clear_skb);
return err;
}
-static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
- struct iov_iter *dest, int *chunk, bool *zc,
- bool async)
+static int
+tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx,
+ struct msghdr *msg, struct tls_decrypt_arg *darg)
{
- struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
- struct strp_msg *rxm = strp_msg(skb);
- int pad, err = 0;
-
- if (!ctx->decrypted) {
- if (tls_ctx->rx_conf == TLS_HW) {
- err = tls_device_decrypted(sk, tls_ctx, skb, rxm);
- if (err < 0)
- return err;
- }
+ struct strp_msg *rxm;
+ int pad, err;
- /* Still not decrypted after tls_device */
- if (!ctx->decrypted) {
- err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
- async);
- if (err < 0) {
- if (err == -EINPROGRESS)
- tls_advance_record_sn(sk, prot,
- &tls_ctx->rx);
- else if (err == -EBADMSG)
- TLS_INC_STATS(sock_net(sk),
- LINUX_MIB_TLSDECRYPTERROR);
- return err;
- }
- } else {
- *zc = false;
- }
+ err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg);
+ if (err < 0) {
+ if (err == -EBADMSG)
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR);
+ return err;
+ }
+ /* keep going even for ->async, the code below is TLS 1.3 */
- pad = padding_length(ctx, prot, skb);
- if (pad < 0)
- return pad;
+ /* If opportunistic TLS 1.3 ZC failed retry without ZC */
+ if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
+ darg->tail != TLS_RECORD_TYPE_DATA)) {
+ darg->zc = false;
+ if (!darg->tail)
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
+ return tls_decrypt_sw(sk, tls_ctx, msg, darg);
+ }
- rxm->full_len -= pad;
- rxm->offset += prot->prepend_size;
- rxm->full_len -= prot->overhead_size;
- tls_advance_record_sn(sk, prot, &tls_ctx->rx);
- ctx->decrypted = 1;
- ctx->saved_data_ready(sk);
- } else {
- *zc = false;
+ pad = tls_padding_length(prot, darg->skb, darg);
+ if (pad < 0) {
+ if (darg->skb != tls_strp_msg(ctx))
+ consume_skb(darg->skb);
+ return pad;
}
- return err;
+ rxm = strp_msg(darg->skb);
+ rxm->full_len -= pad;
+
+ return 0;
}
-int decrypt_skb(struct sock *sk, struct sk_buff *skb,
- struct scatterlist *sgout)
+static int
+tls_decrypt_device(struct sock *sk, struct msghdr *msg,
+ struct tls_context *tls_ctx, struct tls_decrypt_arg *darg)
{
- bool zc = true;
- int chunk;
+ struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
+ struct strp_msg *rxm;
+ int pad, err;
- return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
+ if (tls_ctx->rx_conf != TLS_HW)
+ return 0;
+
+ err = tls_device_decrypted(sk, tls_ctx);
+ if (err <= 0)
+ return err;
+
+ pad = tls_padding_length(prot, tls_strp_msg(ctx), darg);
+ if (pad < 0)
+ return pad;
+
+ darg->async = false;
+ darg->skb = tls_strp_msg(ctx);
+ /* ->zc downgrade check, in case TLS 1.3 gets here */
+ darg->zc &= !(prot->version == TLS_1_3_VERSION &&
+ tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA);
+
+ rxm = strp_msg(darg->skb);
+ rxm->full_len -= pad;
+
+ if (!darg->zc) {
+ /* Non-ZC case needs a real skb */
+ darg->skb = tls_strp_msg_detach(ctx);
+ if (!darg->skb)
+ return -ENOMEM;
+ } else {
+ unsigned int off, len;
+
+ /* In ZC case nobody cares about the output skb.
+ * Just copy the data here. Note the skb is not fully trimmed.
+ */
+ off = rxm->offset + prot->prepend_size;
+ len = rxm->full_len - prot->overhead_size;
+
+ err = skb_copy_datagram_msg(darg->skb, off, msg, len);
+ if (err)
+ return err;
+ }
+ return 1;
}
-static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
- unsigned int len)
+static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
+ struct tls_decrypt_arg *darg)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+ struct tls_prot_info *prot = &tls_ctx->prot_info;
+ struct strp_msg *rxm;
+ int err;
- if (skb) {
- struct strp_msg *rxm = strp_msg(skb);
+ err = tls_decrypt_device(sk, msg, tls_ctx, darg);
+ if (!err)
+ err = tls_decrypt_sw(sk, tls_ctx, msg, darg);
+ if (err < 0)
+ return err;
- if (len < rxm->full_len) {
- rxm->offset += len;
- rxm->full_len -= len;
- return false;
+ rxm = strp_msg(darg->skb);
+ rxm->offset += prot->prepend_size;
+ rxm->full_len -= prot->overhead_size;
+ tls_advance_record_sn(sk, prot, &tls_ctx->rx);
+
+ return 0;
+}
+
+int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
+{
+ struct tls_decrypt_arg darg = { .zc = true, };
+
+ return tls_decrypt_sg(sk, NULL, sgout, &darg);
+}
+
+static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
+ u8 *control)
+{
+ int err;
+
+ if (!*control) {
+ *control = tlm->control;
+ if (!*control)
+ return -EBADMSG;
+
+ err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+ sizeof(*control), control);
+ if (*control != TLS_RECORD_TYPE_DATA) {
+ if (err || msg->msg_flags & MSG_CTRUNC)
+ return -EIO;
}
- consume_skb(skb);
+ } else if (*control != tlm->control) {
+ return 0;
}
- /* Finished with message */
- ctx->recv_pkt = NULL;
- __strp_unpause(&ctx->strp);
+ return 1;
+}
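From userspace, the control byte stored here surfaces as a cmsg on recvmsg() (per Documentation/networking/tls.rst); a sketch, where handle_record_type() is a hypothetical application callback:

char buf[16384];
union { struct cmsghdr _align; char buf[CMSG_SPACE(1)]; } cbuf;
struct iovec iov = { .iov_base = buf, .iov_len = sizeof(buf) };
struct msghdr m = { .msg_iov = &iov, .msg_iovlen = 1,
                    .msg_control = &cbuf, .msg_controllen = sizeof(cbuf) };
ssize_t n = recvmsg(fd, &m, 0);
struct cmsghdr *c = CMSG_FIRSTHDR(&m);

if (n > 0 && c && c->cmsg_level == SOL_TLS &&
    c->cmsg_type == TLS_GET_RECORD_TYPE)
        handle_record_type(*CMSG_DATA(c), buf, n);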
- return true;
+static void tls_rx_rec_done(struct tls_sw_context_rx *ctx)
+{
+ tls_strp_msg_done(&ctx->strp);
}
/* This function traverses the rx_list in tls receive context and copies the
@@ -1647,31 +1749,22 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
static int process_rx_list(struct tls_sw_context_rx *ctx,
struct msghdr *msg,
u8 *control,
- bool *cmsg,
size_t skip,
size_t len,
- bool zc,
bool is_peek)
{
struct sk_buff *skb = skb_peek(&ctx->rx_list);
- u8 ctrl = *control;
- u8 msgc = *cmsg;
struct tls_msg *tlm;
ssize_t copied = 0;
-
- /* Set the record type in 'control' if caller didn't pass it */
- if (!ctrl && skb) {
- tlm = tls_msg(skb);
- ctrl = tlm->control;
- }
+ int err;
while (skip && skb) {
struct strp_msg *rxm = strp_msg(skb);
tlm = tls_msg(skb);
- /* Cannot process a record of different type */
- if (ctrl != tlm->control)
- return 0;
+ err = tls_record_content_type(msg, tlm, control);
+ if (err <= 0)
+ goto out;
if (skip < rxm->full_len)
break;
@@ -1687,31 +1780,14 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
tlm = tls_msg(skb);
- /* Cannot process a record of different type */
- if (ctrl != tlm->control)
- return 0;
-
- /* Set record type if not already done. For a non-data record,
- * do not proceed if record type could not be copied.
- */
- if (!msgc) {
- int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
- sizeof(ctrl), &ctrl);
- msgc = true;
- if (ctrl != TLS_RECORD_TYPE_DATA) {
- if (cerr || msg->msg_flags & MSG_CTRUNC)
- return -EIO;
-
- *cmsg = msgc;
- }
- }
+ err = tls_record_content_type(msg, tlm, control);
+ if (err <= 0)
+ goto out;
- if (!zc || (rxm->full_len - skip) > len) {
- int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
- msg, chunk);
- if (err < 0)
- return err;
- }
+ err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+ msg, chunk);
+ if (err < 0)
+ goto out;
len = len - chunk;
copied = copied + chunk;
@@ -1737,127 +1813,186 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
next_skb = skb_peek_next(skb, &ctx->rx_list);
if (!is_peek) {
- skb_unlink(skb, &ctx->rx_list);
+ __skb_unlink(skb, &ctx->rx_list);
consume_skb(skb);
}
skb = next_skb;
}
+ err = 0;
+
+out:
+ return copied ? : err;
+}
+
+static bool
+tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
+ size_t len_left, size_t decrypted, ssize_t done,
+ size_t *flushed_at)
+{
+ size_t max_rec;
+
+ if (len_left <= decrypted)
+ return false;
+
+ max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
+ if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
+ return false;
+
+ *flushed_at = done;
+ return sk_flush_backlog(sk);
+}
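Concretely (assuming TLS 1.2 with AES-GCM-128, i.e. a 13-byte prepend and a 16-byte tag with no tail): max_rec = 29 + 16384 = 16413 bytes, so the backlog is flushed once at least SZ_128K (128 KiB) has been copied since the last flush, or earlier if the TCP queue has drained to at most one max-size record.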
+
+static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx,
+ bool nonblock)
+{
+ long timeo;
+ int err;
+
+ lock_sock(sk);
+
+ timeo = sock_rcvtimeo(sk, nonblock);
+
+ while (unlikely(ctx->reader_present)) {
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+ ctx->reader_contended = 1;
+
+ add_wait_queue(&ctx->wq, &wait);
+ sk_wait_event(sk, &timeo,
+ !READ_ONCE(ctx->reader_present), &wait);
+ remove_wait_queue(&ctx->wq, &wait);
+
+ if (timeo <= 0) {
+ err = -EAGAIN;
+ goto err_unlock;
+ }
+ if (signal_pending(current)) {
+ err = sock_intr_errno(timeo);
+ goto err_unlock;
+ }
+ }
- *control = ctrl;
- return copied;
+ WRITE_ONCE(ctx->reader_present, 1);
+
+ return 0;
+
+err_unlock:
+ release_sock(sk);
+ return err;
+}
+
+static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx)
+{
+ if (unlikely(ctx->reader_contended)) {
+ if (wq_has_sleeper(&ctx->wq))
+ wake_up(&ctx->wq);
+ else
+ ctx->reader_contended = 0;
+
+ WARN_ON_ONCE(!ctx->reader_present);
+ }
+
+ WRITE_ONCE(ctx->reader_present, 0);
+ release_sock(sk);
}
int tls_sw_recvmsg(struct sock *sk,
struct msghdr *msg,
size_t len,
- int nonblock,
int flags,
int *addr_len)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
+ ssize_t decrypted = 0, async_copy_bytes = 0;
struct sk_psock *psock;
unsigned char control = 0;
- ssize_t decrypted = 0;
+ size_t flushed_at = 0;
struct strp_msg *rxm;
struct tls_msg *tlm;
- struct sk_buff *skb;
ssize_t copied = 0;
- bool cmsg = false;
- int target, err = 0;
- long timeo;
+ bool async = false;
+ int target, err;
bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
bool is_peek = flags & MSG_PEEK;
+ bool released = true;
bool bpf_strp_enabled;
- int num_async = 0;
- int pending;
-
- flags |= nonblock;
+ bool zc_capable;
if (unlikely(flags & MSG_ERRQUEUE))
return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
psock = sk_psock_get(sk);
- lock_sock(sk);
+ err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT);
+ if (err < 0)
+ return err;
bpf_strp_enabled = sk_psock_strp_enabled(psock);
+ /* If crypto failed, the connection is broken */
+ err = ctx->async_wait.err;
+ if (err)
+ goto end;
+
+ /* Process pending decrypted records. This must be done without zero-copy */
- err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
- is_peek);
- if (err < 0) {
- tls_err_abort(sk, err);
+ err = process_rx_list(ctx, msg, &control, 0, len, is_peek);
+ if (err < 0)
goto end;
- } else {
- copied = err;
- }
+ copied = err;
if (len <= copied)
- goto recv_end;
+ goto end;
target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
len = len - copied;
- timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
-
- while (len && (decrypted + copied < target || ctx->recv_pkt)) {
- bool retain_skb = false;
- bool zc = false;
- int to_decrypt;
- int chunk = 0;
- bool async_capable;
- bool async = false;
-
- skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
- if (!skb) {
- if (psock) {
- int ret = sk_msg_recvmsg(sk, psock, msg, len,
- flags);
- if (ret > 0) {
- decrypted += ret;
- len -= ret;
+ zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
+ ctx->zc_capable;
+ decrypted = 0;
+ while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) {
+ struct tls_decrypt_arg darg;
+ int to_decrypt, chunk;
+
+ err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT,
+ released);
+ if (err <= 0) {
+ if (psock) {
+ chunk = sk_msg_recvmsg(sk, psock, msg, len,
+ flags);
+ if (chunk > 0) {
+ decrypted += chunk;
+ len -= chunk;
continue;
}
}
goto recv_end;
- } else {
- tlm = tls_msg(skb);
- if (prot->version == TLS_1_3_VERSION)
- tlm->control = 0;
- else
- tlm->control = ctx->control;
}
- rxm = strp_msg(skb);
+ memset(&darg.inargs, 0, sizeof(darg.inargs));
+
+ rxm = strp_msg(tls_strp_msg(ctx));
+ tlm = tls_msg(tls_strp_msg(ctx));
to_decrypt = rxm->full_len - prot->overhead_size;
- if (to_decrypt <= len && !is_kvec && !is_peek &&
- ctx->control == TLS_RECORD_TYPE_DATA &&
- prot->version != TLS_1_3_VERSION &&
- !bpf_strp_enabled)
- zc = true;
+ if (zc_capable && to_decrypt <= len &&
+ tlm->control == TLS_RECORD_TYPE_DATA)
+ darg.zc = true;
/* Do not use async mode if record is non-data */
- if (ctx->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
- async_capable = ctx->async_capable;
+ if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled)
+ darg.async = ctx->async_capable;
else
- async_capable = false;
+ darg.async = false;
- err = decrypt_skb_update(sk, skb, &msg->msg_iter,
- &chunk, &zc, async_capable);
- if (err < 0 && err != -EINPROGRESS) {
+ err = tls_rx_one_record(sk, msg, &darg);
+ if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto recv_end;
}
- if (err == -EINPROGRESS) {
- async = true;
- num_async++;
- } else if (prot->version == TLS_1_3_VERSION) {
- tlm->control = ctx->control;
- }
+ async |= darg.async;
/* If the type of records being processed is not known yet,
* set it to record type just dequeued. If it is already known,
@@ -1866,130 +2001,120 @@ int tls_sw_recvmsg(struct sock *sk,
* is known just after record is dequeued from stream parser.
* For tls1.3, we disable async.
*/
-
- if (!control)
- control = tlm->control;
- else if (control != tlm->control)
+ err = tls_record_content_type(msg, tls_msg(darg.skb), &control);
+ if (err <= 0) {
+ DEBUG_NET_WARN_ON_ONCE(darg.zc);
+ tls_rx_rec_done(ctx);
+put_on_rx_list_err:
+ __skb_queue_tail(&ctx->rx_list, darg.skb);
goto recv_end;
-
- if (!cmsg) {
- int cerr;
-
- cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
- sizeof(control), &control);
- cmsg = true;
- if (control != TLS_RECORD_TYPE_DATA) {
- if (cerr || msg->msg_flags & MSG_CTRUNC) {
- err = -EIO;
- goto recv_end;
- }
- }
}
- if (async)
- goto pick_next_record;
+ /* periodically flush backlog, and feed strparser */
+ released = tls_read_flush_backlog(sk, prot, len, to_decrypt,
+ decrypted + copied,
+ &flushed_at);
+
+ /* TLS 1.3 may have updated the length by more than overhead */
+ rxm = strp_msg(darg.skb);
+ chunk = rxm->full_len;
+ tls_rx_rec_done(ctx);
+
+ if (!darg.zc) {
+ bool partially_consumed = chunk > len;
+ struct sk_buff *skb = darg.skb;
+
+ DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor);
+
+ if (async) {
+ /* TLS 1.2-only, to_decrypt must be text len */
+ chunk = min_t(int, to_decrypt, len);
+ async_copy_bytes += chunk;
+put_on_rx_list:
+ decrypted += chunk;
+ len -= chunk;
+ __skb_queue_tail(&ctx->rx_list, skb);
+ continue;
+ }
- if (!zc) {
if (bpf_strp_enabled) {
+ released = true;
err = sk_psock_tls_strp_read(psock, skb);
if (err != __SK_PASS) {
rxm->offset = rxm->offset + rxm->full_len;
rxm->full_len = 0;
if (err == __SK_DROP)
consume_skb(skb);
- ctx->recv_pkt = NULL;
- __strp_unpause(&ctx->strp);
continue;
}
}
- if (rxm->full_len > len) {
- retain_skb = true;
+ if (partially_consumed)
chunk = len;
- } else {
- chunk = rxm->full_len;
- }
err = skb_copy_datagram_msg(skb, rxm->offset,
msg, chunk);
if (err < 0)
- goto recv_end;
+ goto put_on_rx_list_err;
- if (!is_peek) {
- rxm->offset = rxm->offset + chunk;
- rxm->full_len = rxm->full_len - chunk;
+ if (is_peek)
+ goto put_on_rx_list;
+
+ if (partially_consumed) {
+ rxm->offset += chunk;
+ rxm->full_len -= chunk;
+ goto put_on_rx_list;
}
- }
-pick_next_record:
- if (chunk > len)
- chunk = len;
+ consume_skb(skb);
+ }
decrypted += chunk;
len -= chunk;
- /* For async or peek case, queue the current skb */
- if (async || is_peek || retain_skb) {
- skb_queue_tail(&ctx->rx_list, skb);
- skb = NULL;
- }
-
- if (tls_sw_advance_skb(sk, skb, chunk)) {
- /* Return full control message to
- * userspace before trying to parse
- * another message type
- */
- msg->msg_flags |= MSG_EOR;
- if (control != TLS_RECORD_TYPE_DATA)
- goto recv_end;
- } else {
+ /* Return full control message to userspace before trying
+ * to parse another message type
+ */
+ msg->msg_flags |= MSG_EOR;
+ if (control != TLS_RECORD_TYPE_DATA)
break;
- }
}
recv_end:
- if (num_async) {
+ if (async) {
+ int ret, pending;
+
/* Wait for all previously submitted records to be decrypted */
spin_lock_bh(&ctx->decrypt_compl_lock);
- ctx->async_notify = true;
+ reinit_completion(&ctx->async_wait.completion);
pending = atomic_read(&ctx->decrypt_pending);
spin_unlock_bh(&ctx->decrypt_compl_lock);
- if (pending) {
- err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
- if (err) {
- /* one of async decrypt failed */
- tls_err_abort(sk, err);
- copied = 0;
- decrypted = 0;
- goto end;
- }
- } else {
- reinit_completion(&ctx->async_wait.completion);
- }
+ ret = 0;
+ if (pending)
+ ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+ __skb_queue_purge(&ctx->async_hold);
- /* There can be no concurrent accesses, since we have no
- * pending decrypt operations
- */
- WRITE_ONCE(ctx->async_notify, false);
+ if (ret) {
+ if (err >= 0 || err == -EINPROGRESS)
+ err = ret;
+ decrypted = 0;
+ goto end;
+ }
/* Drain records from the rx_list & copy if required */
if (is_peek || is_kvec)
- err = process_rx_list(ctx, msg, &control, &cmsg, copied,
- decrypted, false, is_peek);
+ err = process_rx_list(ctx, msg, &control, copied,
+ decrypted, is_peek);
else
- err = process_rx_list(ctx, msg, &control, &cmsg, 0,
- decrypted, true, is_peek);
- if (err < 0) {
- tls_err_abort(sk, err);
- copied = 0;
- goto end;
- }
+ err = process_rx_list(ctx, msg, &control, 0,
+ async_copy_bytes, is_peek);
+ decrypted = max(err, 0);
}
copied += decrypted;
end:
- release_sock(sk);
+ tls_rx_reader_unlock(sk, ctx);
if (psock)
sk_psock_put(sk, psock);
return copied ? : err;
@@ -2003,62 +2128,67 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct strp_msg *rxm = NULL;
struct sock *sk = sock->sk;
+ struct tls_msg *tlm;
struct sk_buff *skb;
ssize_t copied = 0;
- bool from_queue;
- int err = 0;
- long timeo;
int chunk;
- bool zc = false;
-
- lock_sock(sk);
+ int err;
- timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK);
+ err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK);
+ if (err < 0)
+ return err;
- from_queue = !skb_queue_empty(&ctx->rx_list);
- if (from_queue) {
+ if (!skb_queue_empty(&ctx->rx_list)) {
skb = __skb_dequeue(&ctx->rx_list);
} else {
- skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
- &err);
- if (!skb)
+ struct tls_decrypt_arg darg;
+
+ err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
+ true);
+ if (err <= 0)
goto splice_read_end;
- err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
+ memset(&darg.inargs, 0, sizeof(darg.inargs));
+
+ err = tls_rx_one_record(sk, NULL, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
goto splice_read_end;
}
+
+ tls_rx_rec_done(ctx);
+ skb = darg.skb;
}
+ rxm = strp_msg(skb);
+ tlm = tls_msg(skb);
+
/* splice does not support reading control messages */
- if (ctx->control != TLS_RECORD_TYPE_DATA) {
+ if (tlm->control != TLS_RECORD_TYPE_DATA) {
err = -EINVAL;
- goto splice_read_end;
+ goto splice_requeue;
}
- rxm = strp_msg(skb);
-
chunk = min_t(unsigned int, rxm->full_len, len);
copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
if (copied < 0)
- goto splice_read_end;
+ goto splice_requeue;
- if (!from_queue) {
- ctx->recv_pkt = NULL;
- __strp_unpause(&ctx->strp);
- }
if (chunk < rxm->full_len) {
- __skb_queue_head(&ctx->rx_list, skb);
rxm->offset += len;
rxm->full_len -= len;
- } else {
- consume_skb(skb);
+ goto splice_requeue;
}
+ consume_skb(skb);
+
splice_read_end:
- release_sock(sk);
+ tls_rx_reader_unlock(sk, ctx);
return copied ? : err;
+
+splice_requeue:
+ __skb_queue_head(&ctx->rx_list, skb);
+ goto splice_read_end;
}
bool tls_sw_sock_is_readable(struct sock *sk)
@@ -2074,23 +2204,21 @@ bool tls_sw_sock_is_readable(struct sock *sk)
ingress_empty = list_empty(&psock->ingress_msg);
rcu_read_unlock();
- return !ingress_empty || ctx->recv_pkt ||
+ return !ingress_empty || tls_strp_msg_ready(ctx) ||
!skb_queue_empty(&ctx->rx_list);
}
-static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
+int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb)
{
struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
- struct strp_msg *rxm = strp_msg(skb);
size_t cipher_overhead;
size_t data_len = 0;
int ret;
/* Verify that we have a full TLS header, or wait for more data */
- if (rxm->offset + prot->prepend_size > skb->len)
+ if (strp->stm.offset + prot->prepend_size > skb->len)
return 0;
/* Sanity-check size of on-stack buffer. */
@@ -2100,12 +2228,11 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
}
/* Linearize header to local buffer */
- ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
-
+ ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size);
if (ret < 0)
goto read_failure;
- ctx->control = header[0];
+ strp->mark = header[0];
data_len = ((header[4] & 0xFF) | (header[3] << 8));
@@ -2132,7 +2259,7 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
}
tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
- TCP_SKB_CB(skb)->seq + rxm->offset);
+ TCP_SKB_CB(skb)->seq + strp->stm.offset);
return data_len + TLS_HEADER_SIZE;
read_failure:
@@ -2141,16 +2268,11 @@ read_failure:
return ret;
}
-static void tls_queue(struct strparser *strp, struct sk_buff *skb)
+void tls_rx_msg_ready(struct tls_strparser *strp)
{
- struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
- struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-
- ctx->decrypted = 0;
-
- ctx->recv_pkt = skb;
- strp_pause(strp);
+ struct tls_sw_context_rx *ctx;
+ ctx = container_of(strp, struct tls_sw_context_rx, strp);
ctx->saved_data_ready(strp->sk);
}
@@ -2160,7 +2282,7 @@ static void tls_data_ready(struct sock *sk)
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct sk_psock *psock;
- strp_data_ready(&ctx->strp);
+ tls_strp_data_ready(&ctx->strp);
psock = sk_psock_get(sk);
if (psock) {
@@ -2236,13 +2358,11 @@ void tls_sw_release_resources_rx(struct sock *sk)
kfree(tls_ctx->rx.iv);
if (ctx->aead_recv) {
- kfree_skb(ctx->recv_pkt);
- ctx->recv_pkt = NULL;
- skb_queue_purge(&ctx->rx_list);
+ __skb_queue_purge(&ctx->rx_list);
crypto_free_aead(ctx->aead_recv);
- strp_stop(&ctx->strp);
+ tls_strp_stop(&ctx->strp);
/* If tls_sw_strparser_arm() was not called (cleanup paths)
- * we still want to strp_stop(), but sk->sk_data_ready was
+ * we still want to tls_strp_stop(), but sk->sk_data_ready was
* never swapped.
*/
if (ctx->saved_data_ready) {
@@ -2257,7 +2377,7 @@ void tls_sw_strparser_done(struct tls_context *tls_ctx)
{
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- strp_done(&ctx->strp);
+ tls_strp_done(&ctx->strp);
}
void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
@@ -2301,12 +2421,23 @@ static void tx_work_handler(struct work_struct *work)
mutex_unlock(&tls_ctx->tx_lock);
}
+static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
+{
+ struct tls_rec *rec;
+
+ rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list);
+ if (!rec)
+ return false;
+
+ return READ_ONCE(rec->tx_ready);
+}
+
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
{
struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
/* Schedule the transmission if tx list is ready */
- if (is_tx_ready(tx_ctx) &&
+ if (tls_is_tx_ready(tx_ctx) &&
!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}
@@ -2319,8 +2450,14 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
rx_ctx->saved_data_ready = sk->sk_data_ready;
sk->sk_data_ready = tls_data_ready;
write_unlock_bh(&sk->sk_callback_lock);
+}
- strp_check_rcv(&rx_ctx->strp);
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
+{
+ struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+ rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
+ tls_ctx->prot_info.version != TLS_1_3_VERSION;
}
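Userspace opts into this path with the TLS_RX_EXPECT_NO_PAD sockopt added in the same series; a sketch:

/* Promise the kernel the peer sends no TLS 1.3 padding, which sets
 * rx_no_pad and re-enables the zero-copy receive path gated above.
 */
int val = 1;

if (setsockopt(fd, SOL_TLS, TLS_RX_EXPECT_NO_PAD, &val, sizeof(val)))
        perror("setsockopt(TLS_RX_EXPECT_NO_PAD)");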
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
@@ -2328,15 +2465,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_prot_info *prot = &tls_ctx->prot_info;
struct tls_crypto_info *crypto_info;
- struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
- struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
- struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
- struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
struct tls_sw_context_tx *sw_ctx_tx = NULL;
struct tls_sw_context_rx *sw_ctx_rx = NULL;
struct cipher_context *cctx;
struct crypto_aead **aead;
- struct strp_callbacks cb;
u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
struct crypto_tfm *tfm;
char *iv, *rec_seq, *key, *salt, *cipher_name;
@@ -2386,23 +2518,25 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
} else {
crypto_init_wait(&sw_ctx_rx->async_wait);
spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
+ init_waitqueue_head(&sw_ctx_rx->wq);
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
skb_queue_head_init(&sw_ctx_rx->rx_list);
+ skb_queue_head_init(&sw_ctx_rx->async_hold);
aead = &sw_ctx_rx->aead_recv;
}
switch (crypto_info->cipher_type) {
case TLS_CIPHER_AES_GCM_128: {
+ struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
+
+ gcm_128_info = (void *)crypto_info;
nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
- iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
+ iv = gcm_128_info->iv;
rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
- rec_seq =
- ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
- gcm_128_info =
- (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+ rec_seq = gcm_128_info->rec_seq;
keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
key = gcm_128_info->key;
salt = gcm_128_info->salt;
@@ -2411,15 +2545,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
break;
}
case TLS_CIPHER_AES_GCM_256: {
+ struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
+
+ gcm_256_info = (void *)crypto_info;
nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
- iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+ iv = gcm_256_info->iv;
rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
- rec_seq =
- ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
- gcm_256_info =
- (struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+ rec_seq = gcm_256_info->rec_seq;
keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
key = gcm_256_info->key;
salt = gcm_256_info->salt;
@@ -2428,15 +2562,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
break;
}
case TLS_CIPHER_AES_CCM_128: {
+ struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
+
+ ccm_128_info = (void *)crypto_info;
nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
- iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
+ iv = ccm_128_info->iv;
rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
- rec_seq =
- ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
- ccm_128_info =
- (struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
+ rec_seq = ccm_128_info->rec_seq;
keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
key = ccm_128_info->key;
salt = ccm_128_info->salt;
@@ -2445,6 +2579,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
break;
}
case TLS_CIPHER_CHACHA20_POLY1305: {
+ struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;
+
chacha20_poly1305_info = (void *)crypto_info;
nonce_size = 0;
tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
@@ -2493,14 +2629,41 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
cipher_name = "ccm(sm4)";
break;
}
- default:
- rc = -EINVAL;
- goto free_priv;
+ case TLS_CIPHER_ARIA_GCM_128: {
+ struct tls12_crypto_info_aria_gcm_128 *aria_gcm_128_info;
+
+ aria_gcm_128_info = (void *)crypto_info;
+ nonce_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE;
+ tag_size = TLS_CIPHER_ARIA_GCM_128_TAG_SIZE;
+ iv_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE;
+ iv = aria_gcm_128_info->iv;
+ rec_seq_size = TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE;
+ rec_seq = aria_gcm_128_info->rec_seq;
+ keysize = TLS_CIPHER_ARIA_GCM_128_KEY_SIZE;
+ key = aria_gcm_128_info->key;
+ salt = aria_gcm_128_info->salt;
+ salt_size = TLS_CIPHER_ARIA_GCM_128_SALT_SIZE;
+ cipher_name = "gcm(aria)";
+ break;
}
-
- /* Sanity-check the sizes for stack allocations. */
- if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
- rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+ case TLS_CIPHER_ARIA_GCM_256: {
+ struct tls12_crypto_info_aria_gcm_256 *gcm_256_info;
+
+ gcm_256_info = (void *)crypto_info;
+ nonce_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE;
+ tag_size = TLS_CIPHER_ARIA_GCM_256_TAG_SIZE;
+ iv_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE;
+ iv = gcm_256_info->iv;
+ rec_seq_size = TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE;
+ rec_seq = gcm_256_info->rec_seq;
+ keysize = TLS_CIPHER_ARIA_GCM_256_KEY_SIZE;
+ key = gcm_256_info->key;
+ salt = gcm_256_info->salt;
+ salt_size = TLS_CIPHER_ARIA_GCM_256_SALT_SIZE;
+ cipher_name = "gcm(aria)";
+ break;
+ }
+ default:
rc = -EINVAL;
goto free_priv;
}
@@ -2514,6 +2677,14 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
prot->tail_size = 0;
}
+ /* Sanity-check the sizes for stack allocations. */
+ if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
+ rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
+ prot->aad_size > TLS_MAX_AAD_SIZE) {
+ rc = -EINVAL;
+ goto free_priv;
+ }
+
prot->version = crypto_info->version;
prot->cipher_type = crypto_info->cipher_type;
prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
@@ -2560,19 +2731,14 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
if (sw_ctx_rx) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
- if (crypto_info->version == TLS_1_3_VERSION)
- sw_ctx_rx->async_capable = 0;
- else
- sw_ctx_rx->async_capable =
- !!(tfm->__crt_alg->cra_flags &
- CRYPTO_ALG_ASYNC);
-
- /* Set up strparser */
- memset(&cb, 0, sizeof(cb));
- cb.rcv_msg = tls_queue;
- cb.parse_msg = tls_read_size;
+ tls_update_rx_zc_capable(ctx);
+ sw_ctx_rx->async_capable =
+ crypto_info->version != TLS_1_3_VERSION &&
+ !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
- strp_init(&sw_ctx_rx->strp, sk, &cb);
+ rc = tls_strp_init(&sw_ctx_rx->strp, sk);
+ if (rc)
+ goto free_aead;
}
goto out;
diff --git a/net/tls/tls_toe.c b/net/tls/tls_toe.c
index 7e1330f19165..825669e1ab47 100644
--- a/net/tls/tls_toe.c
+++ b/net/tls/tls_toe.c
@@ -38,6 +38,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
+#include "tls.h"
+
static LIST_HEAD(device_list);
static DEFINE_SPINLOCK(device_spinlock);