aboutsummaryrefslogtreecommitdiffstats
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls.h290
-rw-r--r--net/tls/tls_device.c3
-rw-r--r--net/tls/tls_device_fallback.c8
-rw-r--r--net/tls/tls_main.c97
-rw-r--r--net/tls/tls_proc.c4
-rw-r--r--net/tls/tls_sw.c236
-rw-r--r--net/tls/tls_toe.c2
7 files changed, 547 insertions, 93 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h
new file mode 100644
index 000000000000..e0ccc96a0850
--- /dev/null
+++ b/net/tls/tls.h
@@ -0,0 +1,290 @@
+/*
+ * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
+ * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
+ *
+ * This software is available to you under a choice of one of two
+ * licenses. You may choose to be licensed under the terms of the GNU
+ * General Public License (GPL) Version 2, available from the file
+ * COPYING in the main directory of this source tree, or the
+ * OpenIB.org BSD license below:
+ *
+ * Redistribution and use in source and binary forms, with or
+ * without modification, are permitted provided that the following
+ * conditions are met:
+ *
+ * - Redistributions of source code must retain the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer.
+ *
+ * - Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following
+ * disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef _TLS_INT_H
+#define _TLS_INT_H
+
+#include <asm/byteorder.h>
+#include <linux/types.h>
+#include <linux/skmsg.h>
+#include <net/tls.h>
+
+#define __TLS_INC_STATS(net, field) \
+ __SNMP_INC_STATS((net)->mib.tls_statistics, field)
+#define TLS_INC_STATS(net, field) \
+ SNMP_INC_STATS((net)->mib.tls_statistics, field)
+#define TLS_DEC_STATS(net, field) \
+ SNMP_DEC_STATS((net)->mib.tls_statistics, field)
+
+/* TLS records are maintained in 'struct tls_rec'. It stores the memory pages
+ * allocated or mapped for each TLS record. After encryption, the records are
+ * stores in a linked list.
+ */
+struct tls_rec {
+ struct list_head list;
+ int tx_ready;
+ int tx_flags;
+
+ struct sk_msg msg_plaintext;
+ struct sk_msg msg_encrypted;
+
+ /* AAD | msg_plaintext.sg.data | sg_tag */
+ struct scatterlist sg_aead_in[2];
+ /* AAD | msg_encrypted.sg.data (data contains overhead for hdr & iv & tag) */
+ struct scatterlist sg_aead_out[2];
+
+ char content_type;
+ struct scatterlist sg_content_type;
+
+ char aad_space[TLS_AAD_SPACE_SIZE];
+ u8 iv_data[MAX_IV_SIZE];
+ struct aead_request aead_req;
+ u8 aead_req_ctx[];
+};
+
+int __net_init tls_proc_init(struct net *net);
+void __net_exit tls_proc_fini(struct net *net);
+
+struct tls_context *tls_ctx_create(struct sock *sk);
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx);
+void update_sk_prot(struct sock *sk, struct tls_context *ctx);
+
+int wait_on_pending_writer(struct sock *sk, long *timeo);
+int tls_sk_query(struct sock *sk, int optname, char __user *optval,
+ int __user *optlen);
+int tls_sk_attach(struct sock *sk, int optname, char __user *optval,
+ unsigned int optlen);
+void tls_err_abort(struct sock *sk, int err);
+
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx);
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx);
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *ctx);
+void tls_sw_strparser_done(struct tls_context *tls_ctx);
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx);
+void tls_sw_release_resources_tx(struct sock *sk);
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx);
+void tls_sw_free_resources_rx(struct sock *sk);
+void tls_sw_release_resources_rx(struct sock *sk);
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx);
+int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+ int flags, int *addr_len);
+bool tls_sw_sock_is_readable(struct sock *sk);
+ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
+ struct pipe_inode_info *pipe,
+ size_t len, unsigned int flags);
+
+int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
+int tls_device_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags);
+int tls_tx_records(struct sock *sk, int flags);
+
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx);
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx);
+
+int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
+ unsigned char *record_type);
+int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+ struct scatterlist *sgout);
+
+int tls_sw_fallback_init(struct sock *sk,
+ struct tls_offload_context_tx *offload_ctx,
+ struct tls_crypto_info *crypto_info);
+
+static inline struct tls_msg *tls_msg(struct sk_buff *skb)
+{
+ struct sk_skb_cb *scb = (struct sk_skb_cb *)skb->cb;
+
+ return &scb->tls;
+}
+
+#ifdef CONFIG_TLS_DEVICE
+int tls_device_init(void);
+void tls_device_cleanup(void);
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
+void tls_device_free_resources_tx(struct sock *sk);
+int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
+void tls_device_offload_cleanup_rx(struct sock *sk);
+void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
+int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
+ struct sk_buff *skb, struct strp_msg *rxm);
+#else
+static inline int tls_device_init(void) { return 0; }
+static inline void tls_device_cleanup(void) {}
+
+static inline int
+tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void tls_device_free_resources_tx(struct sock *sk) {}
+
+static inline int
+tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void tls_device_offload_cleanup_rx(struct sock *sk) {}
+static inline void
+tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
+
+static inline int
+tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx,
+ struct sk_buff *skb, struct strp_msg *rxm)
+{
+ return 0;
+}
+#endif
+
+int tls_push_sg(struct sock *sk, struct tls_context *ctx,
+ struct scatterlist *sg, u16 first_offset,
+ int flags);
+int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
+ int flags);
+void tls_free_partial_record(struct sock *sk, struct tls_context *ctx);
+
+static inline bool tls_is_partially_sent_record(struct tls_context *ctx)
+{
+ return !!ctx->partially_sent_record;
+}
+
+static inline bool tls_is_pending_open_record(struct tls_context *tls_ctx)
+{
+ return tls_ctx->pending_open_record_frags;
+}
+
+static inline bool tls_bigint_increment(unsigned char *seq, int len)
+{
+ int i;
+
+ for (i = len - 1; i >= 0; i--) {
+ ++seq[i];
+ if (seq[i] != 0)
+ break;
+ }
+
+ return (i == -1);
+}
+
+static inline void tls_bigint_subtract(unsigned char *seq, int n)
+{
+ u64 rcd_sn;
+ __be64 *p;
+
+ BUILD_BUG_ON(TLS_MAX_REC_SEQ_SIZE != 8);
+
+ p = (__be64 *)seq;
+ rcd_sn = be64_to_cpu(*p);
+ *p = cpu_to_be64(rcd_sn - n);
+}
+
+static inline void
+tls_advance_record_sn(struct sock *sk, struct tls_prot_info *prot,
+ struct cipher_context *ctx)
+{
+ if (tls_bigint_increment(ctx->rec_seq, prot->rec_seq_size))
+ tls_err_abort(sk, -EBADMSG);
+
+ if (prot->version != TLS_1_3_VERSION &&
+ prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
+ tls_bigint_increment(ctx->iv + prot->salt_size,
+ prot->iv_size);
+}
+
+static inline void
+tls_xor_iv_with_seq(struct tls_prot_info *prot, char *iv, char *seq)
+{
+ int i;
+
+ if (prot->version == TLS_1_3_VERSION ||
+ prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
+ for (i = 0; i < 8; i++)
+ iv[i + 4] ^= seq[i];
+ }
+}
+
+static inline void
+tls_fill_prepend(struct tls_context *ctx, char *buf, size_t plaintext_len,
+ unsigned char record_type)
+{
+ struct tls_prot_info *prot = &ctx->prot_info;
+ size_t pkt_len, iv_size = prot->iv_size;
+
+ pkt_len = plaintext_len + prot->tag_size;
+ if (prot->version != TLS_1_3_VERSION &&
+ prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) {
+ pkt_len += iv_size;
+
+ memcpy(buf + TLS_NONCE_OFFSET,
+ ctx->tx.iv + prot->salt_size, iv_size);
+ }
+
+ /* we cover nonce explicit here as well, so buf should be of
+ * size KTLS_DTLS_HEADER_SIZE + KTLS_DTLS_NONCE_EXPLICIT_SIZE
+ */
+ buf[0] = prot->version == TLS_1_3_VERSION ?
+ TLS_RECORD_TYPE_DATA : record_type;
+ /* Note that VERSION must be TLS_1_2 for both TLS1.2 and TLS1.3 */
+ buf[1] = TLS_1_2_VERSION_MINOR;
+ buf[2] = TLS_1_2_VERSION_MAJOR;
+ /* we can use IV for nonce explicit according to spec */
+ buf[3] = pkt_len >> 8;
+ buf[4] = pkt_len & 0xFF;
+}
+
+static inline
+void tls_make_aad(char *buf, size_t size, char *record_sequence,
+ unsigned char record_type, struct tls_prot_info *prot)
+{
+ if (prot->version != TLS_1_3_VERSION) {
+ memcpy(buf, record_sequence, prot->rec_seq_size);
+ buf += 8;
+ } else {
+ size += prot->tag_size;
+ }
+
+ buf[0] = prot->version == TLS_1_3_VERSION ?
+ TLS_RECORD_TYPE_DATA : record_type;
+ buf[1] = TLS_1_2_VERSION_MAJOR;
+ buf[2] = TLS_1_2_VERSION_MINOR;
+ buf[3] = size >> 8;
+ buf[4] = size & 0xFF;
+}
+
+#endif
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index ce827e79c66a..6f9dff1c11f6 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -38,6 +38,7 @@
#include <net/tcp.h>
#include <net/tls.h>
+#include "tls.h"
#include "trace.h"
/* device_offload_lock is used to synchronize tls_dev_add
@@ -562,7 +563,7 @@ int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
- rc = tls_proccess_cmsg(sk, msg, &record_type);
+ rc = tls_process_cmsg(sk, msg, &record_type);
if (rc)
goto out;
}
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index e40bedd112b6..618cee704217 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -34,6 +34,8 @@
#include <crypto/scatterwalk.h>
#include <net/ip6_checksum.h>
+#include "tls.h"
+
static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk)
{
struct scatterlist *src = walk->sg;
@@ -232,7 +234,7 @@ static int fill_sg_in(struct scatterlist *sg_in,
s32 *sync_size,
int *resync_sgs)
{
- int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+ int tcp_payload_offset = skb_tcp_all_headers(skb);
int payload_len = skb->len - tcp_payload_offset;
u32 tcp_seq = ntohl(tcp_hdr(skb)->seq);
struct tls_record_info *record;
@@ -310,8 +312,8 @@ static struct sk_buff *tls_enc_skb(struct tls_context *tls_ctx,
struct sk_buff *skb,
s32 sync_size, u64 rcd_sn)
{
- int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
+ int tcp_payload_offset = skb_tcp_all_headers(skb);
int payload_len = skb->len - tcp_payload_offset;
void *buf, *iv, *aad, *dummy_buf;
struct aead_request *aead_req;
@@ -372,7 +374,7 @@ free_nskb:
static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb)
{
- int tcp_payload_offset = skb_transport_offset(skb) + tcp_hdrlen(skb);
+ int tcp_payload_offset = skb_tcp_all_headers(skb);
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
int payload_len = skb->len - tcp_payload_offset;
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index d80ab3d1764e..9703636cfc60 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -45,6 +45,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
+#include "tls.h"
+
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
@@ -164,8 +166,8 @@ static int tls_handle_open_record(struct sock *sk, int flags)
return 0;
}
-int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
- unsigned char *record_type)
+int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
+ unsigned char *record_type)
{
struct cmsghdr *cmsg;
int rc = -EINVAL;
@@ -533,6 +535,36 @@ static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
return 0;
}
+static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
+ int __user *optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ int value, len;
+
+ if (ctx->prot_info.version != TLS_1_3_VERSION)
+ return -EINVAL;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+ if (len < sizeof(value))
+ return -EINVAL;
+
+ lock_sock(sk);
+ value = -EINVAL;
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+ value = ctx->rx_no_pad;
+ release_sock(sk);
+ if (value < 0)
+ return value;
+
+ if (put_user(sizeof(value), optlen))
+ return -EFAULT;
+ if (copy_to_user(optval, &value, sizeof(value)))
+ return -EFAULT;
+
+ return 0;
+}
+
static int do_tls_getsockopt(struct sock *sk, int optname,
char __user *optval, int __user *optlen)
{
@@ -547,6 +579,9 @@ static int do_tls_getsockopt(struct sock *sk, int optname,
case TLS_TX_ZEROCOPY_RO:
rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
break;
+ case TLS_RX_EXPECT_NO_PAD:
+ rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -718,6 +753,38 @@ static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
return 0;
}
+static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
+ unsigned int optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ u32 val;
+ int rc;
+
+ if (ctx->prot_info.version != TLS_1_3_VERSION ||
+ sockptr_is_null(optval) || optlen < sizeof(val))
+ return -EINVAL;
+
+ rc = copy_from_sockptr(&val, optval, sizeof(val));
+ if (rc)
+ return -EFAULT;
+ if (val > 1)
+ return -EINVAL;
+ rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
+ if (rc < 1)
+ return rc == 0 ? -EINVAL : rc;
+
+ lock_sock(sk);
+ rc = -EINVAL;
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
+ ctx->rx_no_pad = val;
+ tls_update_rx_zc_capable(ctx);
+ rc = 0;
+ }
+ release_sock(sk);
+
+ return rc;
+}
+
static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
unsigned int optlen)
{
@@ -736,6 +803,9 @@ static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
release_sock(sk);
break;
+ case TLS_RX_EXPECT_NO_PAD:
+ rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -934,6 +1004,23 @@ static void tls_update(struct sock *sk, struct proto *p,
}
}
+static u16 tls_user_config(struct tls_context *ctx, bool tx)
+{
+ u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
+
+ switch (config) {
+ case TLS_BASE:
+ return TLS_CONF_BASE;
+ case TLS_SW:
+ return TLS_CONF_SW;
+ case TLS_HW:
+ return TLS_CONF_HW;
+ case TLS_HW_RECORD:
+ return TLS_CONF_HW_RECORD;
+ }
+ return 0;
+}
+
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
u16 version, cipher_type;
@@ -976,6 +1063,11 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
if (err)
goto nla_failure;
}
+ if (ctx->rx_no_pad) {
+ err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
+ if (err)
+ goto nla_failure;
+ }
rcu_read_unlock();
nla_nest_end(skb, start);
@@ -997,6 +1089,7 @@ static size_t tls_get_info_size(const struct sock *sk)
nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */
nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */
+ nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */
0;
return size;
diff --git a/net/tls/tls_proc.c b/net/tls/tls_proc.c
index feeceb0e4cb4..68982728f620 100644
--- a/net/tls/tls_proc.c
+++ b/net/tls/tls_proc.c
@@ -6,6 +6,8 @@
#include <net/snmp.h>
#include <net/tls.h>
+#include "tls.h"
+
#ifdef CONFIG_PROC_FS
static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsCurrTxSw", LINUX_MIB_TLSCURRTXSW),
@@ -18,6 +20,8 @@ static const struct snmp_mib tls_mib_list[] = {
SNMP_MIB_ITEM("TlsRxDevice", LINUX_MIB_TLSRXDEVICE),
SNMP_MIB_ITEM("TlsDecryptError", LINUX_MIB_TLSDECRYPTERROR),
SNMP_MIB_ITEM("TlsRxDeviceResync", LINUX_MIB_TLSRXDEVICERESYNC),
+ SNMP_MIB_ITEM("TlsDecryptRetry", LINUX_MIB_TLSDECRYPTRETRY),
+ SNMP_MIB_ITEM("TlsRxNoPadViolation", LINUX_MIB_TLSRXNOPADVIOL),
SNMP_MIB_SENTINEL
};
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index e30649f6dde5..68d79ee48a56 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -44,9 +44,19 @@
#include <net/strparser.h>
#include <net/tls.h>
+#include "tls.h"
+
struct tls_decrypt_arg {
bool zc;
bool async;
+ u8 tail;
+};
+
+struct tls_decrypt_ctx {
+ u8 iv[MAX_IV_SIZE];
+ u8 aad[TLS_MAX_AAD_SIZE];
+ u8 tail;
+ struct scatterlist sg[];
};
noinline void tls_err_abort(struct sock *sk, int err)
@@ -133,7 +143,8 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
return __skb_nsg(skb, offset, len, 0);
}
-static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
+static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
+ struct tls_decrypt_arg *darg)
{
struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
@@ -142,7 +153,7 @@ static int padding_length(struct tls_prot_info *prot, struct sk_buff *skb)
/* Determine zero-padding length */
if (prot->version == TLS_1_3_VERSION) {
int offset = rxm->full_len - TLS_TAG_SIZE - 1;
- char content_type = 0;
+ char content_type = darg->zc ? darg->tail : 0;
int err;
while (content_type == 0) {
@@ -515,7 +526,8 @@ static int tls_do_encryption(struct sock *sk,
memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
prot->iv_size + prot->salt_size);
- xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq);
+ tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset,
+ tls_ctx->tx.rec_seq);
sge->offset += prot->prepend_size;
sge->length -= prot->prepend_size;
@@ -952,7 +964,7 @@ int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
lock_sock(sk);
if (unlikely(msg->msg_controllen)) {
- ret = tls_proccess_cmsg(sk, msg, &record_type);
+ ret = tls_process_cmsg(sk, msg, &record_type);
if (ret) {
if (ret == -EINPROGRESS)
num_async++;
@@ -1290,54 +1302,50 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
return ret;
}
-static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
- bool nonblock, long timeo, int *err)
+static int
+tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
+ long timeo)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
- struct sk_buff *skb;
DEFINE_WAIT_FUNC(wait, woken_wake_function);
- while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
- if (sk->sk_err) {
- *err = sock_error(sk);
- return NULL;
- }
+ while (!ctx->recv_pkt) {
+ if (!sk_psock_queue_empty(psock))
+ return 0;
+
+ if (sk->sk_err)
+ return sock_error(sk);
if (!skb_queue_empty(&sk->sk_receive_queue)) {
__strp_unpause(&ctx->strp);
if (ctx->recv_pkt)
- return ctx->recv_pkt;
+ break;
}
if (sk->sk_shutdown & RCV_SHUTDOWN)
- return NULL;
+ return 0;
if (sock_flag(sk, SOCK_DONE))
- return NULL;
+ return 0;
- if (nonblock || !timeo) {
- *err = -EAGAIN;
- return NULL;
- }
+ if (nonblock || !timeo)
+ return -EAGAIN;
add_wait_queue(sk_sleep(sk), &wait);
sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
sk_wait_event(sk, &timeo,
- ctx->recv_pkt != skb ||
- !sk_psock_queue_empty(psock),
+ ctx->recv_pkt || !sk_psock_queue_empty(psock),
&wait);
sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
remove_wait_queue(sk_sleep(sk), &wait);
/* Handle signals */
- if (signal_pending(current)) {
- *err = sock_intr_errno(timeo);
- return NULL;
- }
+ if (signal_pending(current))
+ return sock_intr_errno(timeo);
}
- return skb;
+ return 1;
}
static int tls_setup_from_iter(struct iov_iter *from,
@@ -1412,21 +1420,22 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
struct tls_prot_info *prot = &tls_ctx->prot_info;
+ int n_sgin, n_sgout, aead_size, err, pages = 0;
struct strp_msg *rxm = strp_msg(skb);
struct tls_msg *tlm = tls_msg(skb);
- int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
struct aead_request *aead_req;
struct sk_buff *unused;
- u8 *aad, *iv, *mem = NULL;
struct scatterlist *sgin = NULL;
struct scatterlist *sgout = NULL;
- const int data_len = rxm->full_len - prot->overhead_size +
- prot->tail_size;
+ const int data_len = rxm->full_len - prot->overhead_size;
+ int tail_pages = !!prot->tail_size;
+ struct tls_decrypt_ctx *dctx;
int iv_offset = 0;
+ u8 *mem;
if (darg->zc && (out_iov || out_sg)) {
if (out_iov)
- n_sgout = 1 +
+ n_sgout = 1 + tail_pages +
iov_iter_npages_cap(out_iov, INT_MAX, data_len);
else
n_sgout = sg_nents(out_sg);
@@ -1444,36 +1453,30 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
/* Increment to accommodate AAD */
n_sgin = n_sgin + 1;
- nsg = n_sgin + n_sgout;
-
- aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
- mem_size = aead_size + (nsg * sizeof(struct scatterlist));
- mem_size = mem_size + prot->aad_size;
- mem_size = mem_size + MAX_IV_SIZE;
-
/* Allocate a single block of memory which contains
- * aead_req || sgin[] || sgout[] || aad || iv.
- * This order achieves correct alignment for aead_req, sgin, sgout.
+ * aead_req || tls_decrypt_ctx.
+ * Both structs are variable length.
*/
- mem = kmalloc(mem_size, sk->sk_allocation);
+ aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+ mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
+ sk->sk_allocation);
if (!mem)
return -ENOMEM;
/* Segment the allocated memory */
aead_req = (struct aead_request *)mem;
- sgin = (struct scatterlist *)(mem + aead_size);
- sgout = sgin + n_sgin;
- aad = (u8 *)(sgout + n_sgout);
- iv = aad + prot->aad_size;
+ dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
+ sgin = &dctx->sg[0];
+ sgout = &dctx->sg[n_sgin];
/* For CCM based ciphers, first byte of nonce+iv is a constant */
switch (prot->cipher_type) {
case TLS_CIPHER_AES_CCM_128:
- iv[0] = TLS_AES_CCM_IV_B0_BYTE;
+ dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE;
iv_offset = 1;
break;
case TLS_CIPHER_SM4_CCM:
- iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
+ dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE;
iv_offset = 1;
break;
}
@@ -1481,46 +1484,49 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
/* Prepare IV */
if (prot->version == TLS_1_3_VERSION ||
prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) {
- memcpy(iv + iv_offset, tls_ctx->rx.iv,
+ memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv,
prot->iv_size + prot->salt_size);
} else {
err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
- iv + iv_offset + prot->salt_size,
+ &dctx->iv[iv_offset] + prot->salt_size,
prot->iv_size);
- if (err < 0) {
- kfree(mem);
- return err;
- }
- memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
+ if (err < 0)
+ goto exit_free;
+ memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size);
}
- xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq);
+ tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq);
/* Prepare AAD */
- tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+ tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size +
prot->tail_size,
tls_ctx->rx.rec_seq, tlm->control, prot);
/* Prepare sgin */
sg_init_table(sgin, n_sgin);
- sg_set_buf(&sgin[0], aad, prot->aad_size);
+ sg_set_buf(&sgin[0], dctx->aad, prot->aad_size);
err = skb_to_sgvec(skb, &sgin[1],
rxm->offset + prot->prepend_size,
rxm->full_len - prot->prepend_size);
- if (err < 0) {
- kfree(mem);
- return err;
- }
+ if (err < 0)
+ goto exit_free;
if (n_sgout) {
if (out_iov) {
sg_init_table(sgout, n_sgout);
- sg_set_buf(&sgout[0], aad, prot->aad_size);
+ sg_set_buf(&sgout[0], dctx->aad, prot->aad_size);
err = tls_setup_from_iter(out_iov, data_len,
&pages, &sgout[1],
- (n_sgout - 1));
+ (n_sgout - 1 - tail_pages));
if (err < 0)
goto fallback_to_reg_recv;
+
+ if (prot->tail_size) {
+ sg_unmark_end(&sgout[pages]);
+ sg_set_buf(&sgout[pages + 1], &dctx->tail,
+ prot->tail_size);
+ sg_mark_end(&sgout[pages + 1]);
+ }
} else if (out_sg) {
memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
} else {
@@ -1534,15 +1540,18 @@ fallback_to_reg_recv:
}
/* Prepare and submit AEAD request */
- err = tls_do_decryption(sk, skb, sgin, sgout, iv,
- data_len, aead_req, darg);
+ err = tls_do_decryption(sk, skb, sgin, sgout, dctx->iv,
+ data_len + prot->tail_size, aead_req, darg);
if (darg->async)
return 0;
+ if (prot->tail_size)
+ darg->tail = dctx->tail;
+
/* Release the pages in case iov was mapped to pages */
for (; pages > 0; pages--)
put_page(sg_page(&sgout[pages]));
-
+exit_free:
kfree(mem);
return err;
}
@@ -1583,9 +1592,18 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
}
if (darg->async)
goto decrypt_next;
+ /* If opportunistic TLS 1.3 ZC failed retry without ZC */
+ if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
+ darg->tail != TLS_RECORD_TYPE_DATA)) {
+ darg->zc = false;
+ if (!darg->tail)
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
+ return decrypt_skb_update(sk, skb, dest, darg);
+ }
decrypt_done:
- pad = padding_length(prot, skb);
+ pad = tls_padding_length(prot, skb, darg);
if (pad < 0)
return pad;
@@ -1717,6 +1735,24 @@ out:
return copied ? : err;
}
+static void
+tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot,
+ size_t len_left, size_t decrypted, ssize_t done,
+ size_t *flushed_at)
+{
+ size_t max_rec;
+
+ if (len_left <= decrypted)
+ return;
+
+ max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE;
+ if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec)
+ return;
+
+ *flushed_at = done;
+ sk_flush_backlog(sk);
+}
+
int tls_sw_recvmsg(struct sock *sk,
struct msghdr *msg,
size_t len,
@@ -1729,6 +1765,7 @@ int tls_sw_recvmsg(struct sock *sk,
struct sk_psock *psock;
unsigned char control = 0;
ssize_t decrypted = 0;
+ size_t flushed_at = 0;
struct strp_msg *rxm;
struct tls_msg *tlm;
struct sk_buff *skb;
@@ -1767,14 +1804,14 @@ int tls_sw_recvmsg(struct sock *sk,
timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek &&
- prot->version != TLS_1_3_VERSION;
+ ctx->zc_capable;
decrypted = 0;
while (len && (decrypted + copied < target || ctx->recv_pkt)) {
struct tls_decrypt_arg darg = {};
int to_decrypt, chunk;
- skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err);
- if (!skb) {
+ err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, timeo);
+ if (err <= 0) {
if (psock) {
chunk = sk_msg_recvmsg(sk, psock, msg, len,
flags);
@@ -1784,6 +1821,7 @@ int tls_sw_recvmsg(struct sock *sk,
goto recv_end;
}
+ skb = ctx->recv_pkt;
rxm = strp_msg(skb);
tlm = tls_msg(skb);
@@ -1818,6 +1856,10 @@ int tls_sw_recvmsg(struct sock *sk,
if (err <= 0)
goto recv_end;
+ /* periodically flush backlog, and feed strparser */
+ tls_read_flush_backlog(sk, prot, len, to_decrypt,
+ decrypted + copied, &flushed_at);
+
ctx->recv_pkt = NULL;
__strp_unpause(&ctx->strp);
__skb_queue_tail(&ctx->rx_list, skb);
@@ -1946,11 +1988,13 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
} else {
struct tls_decrypt_arg darg = {};
- skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo,
- &err);
- if (!skb)
+ err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK,
+ timeo);
+ if (err <= 0)
goto splice_read_end;
+ skb = ctx->recv_pkt;
+
err = decrypt_skb_update(sk, skb, NULL, &darg);
if (err < 0) {
tls_err_abort(sk, -EBADMSG);
@@ -2227,12 +2271,23 @@ static void tx_work_handler(struct work_struct *work)
mutex_unlock(&tls_ctx->tx_lock);
}
+static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx)
+{
+ struct tls_rec *rec;
+
+ rec = list_first_entry(&ctx->tx_list, struct tls_rec, list);
+ if (!rec)
+ return false;
+
+ return READ_ONCE(rec->tx_ready);
+}
+
void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
{
struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
/* Schedule the transmission if tx list is ready */
- if (is_tx_ready(tx_ctx) &&
+ if (tls_is_tx_ready(tx_ctx) &&
!test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}
@@ -2249,6 +2304,14 @@ void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
strp_check_rcv(&rx_ctx->strp);
}
+void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
+{
+ struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+ rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
+ tls_ctx->prot_info.version != TLS_1_3_VERSION;
+}
+
int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -2422,13 +2485,6 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
goto free_priv;
}
- /* Sanity-check the sizes for stack allocations. */
- if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
- rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
- rc = -EINVAL;
- goto free_priv;
- }
-
if (crypto_info->version == TLS_1_3_VERSION) {
nonce_size = 0;
prot->aad_size = TLS_HEADER_SIZE;
@@ -2438,6 +2494,14 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
prot->tail_size = 0;
}
+ /* Sanity-check the sizes for stack allocations. */
+ if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
+ rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE ||
+ prot->aad_size > TLS_MAX_AAD_SIZE) {
+ rc = -EINVAL;
+ goto free_priv;
+ }
+
prot->version = crypto_info->version;
prot->cipher_type = crypto_info->cipher_type;
prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
@@ -2484,12 +2548,10 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
if (sw_ctx_rx) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
- if (crypto_info->version == TLS_1_3_VERSION)
- sw_ctx_rx->async_capable = 0;
- else
- sw_ctx_rx->async_capable =
- !!(tfm->__crt_alg->cra_flags &
- CRYPTO_ALG_ASYNC);
+ tls_update_rx_zc_capable(ctx);
+ sw_ctx_rx->async_capable =
+ crypto_info->version != TLS_1_3_VERSION &&
+ !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
/* Set up strparser */
memset(&cb, 0, sizeof(cb));
diff --git a/net/tls/tls_toe.c b/net/tls/tls_toe.c
index 7e1330f19165..825669e1ab47 100644
--- a/net/tls/tls_toe.c
+++ b/net/tls/tls_toe.c
@@ -38,6 +38,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
+#include "tls.h"
+
static LIST_HEAD(device_list);
static DEFINE_SPINLOCK(device_spinlock);