aboutsummaryrefslogtreecommitdiffstats
path: root/net/tls/tls_main.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_main.c')
-rw-r--r--net/tls/tls_main.c481
1 files changed, 430 insertions, 51 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 94774c0e5ff3..3735cb00905d 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -45,6 +45,8 @@
#include <net/tls.h>
#include <net/tls_toe.h>
+#include "tls.h"
+
MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
@@ -56,20 +58,40 @@ enum {
TLS_NUM_PROTS,
};
-static struct proto *saved_tcpv6_prot;
+#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
+ .iv = cipher ## _IV_SIZE, \
+ .key = cipher ## _KEY_SIZE, \
+ .salt = cipher ## _SALT_SIZE, \
+ .tag = cipher ## _TAG_SIZE, \
+ .rec_seq = cipher ## _REC_SEQ_SIZE, \
+}
+
+const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
+ CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
+ CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
+ CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
+ CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
+};
+
+static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
-static struct proto *saved_tcpv4_prot;
+static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
-static struct proto_ops tls_sw_proto_ops;
+static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
- struct proto *base);
+ const struct proto *base);
void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
- sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
+ WRITE_ONCE(sk->sk_prot,
+ &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
+ WRITE_ONCE(sk->sk_socket->ops,
+ &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}
int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -161,8 +183,8 @@ static int tls_handle_open_record(struct sock *sk, int flags)
return 0;
}
-int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
- unsigned char *record_type)
+int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
+ unsigned char *record_type)
{
struct cmsghdr *cmsg;
int rc = -EINVAL;
@@ -312,7 +334,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
write_lock_bh(&sk->sk_callback_lock);
if (free_ctx)
rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
- sk->sk_prot = ctx->sk_proto;
+ WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
if (sk->sk_write_space == tls_write_space)
sk->sk_write_space = ctx->sk_write_space;
write_unlock_bh(&sk->sk_callback_lock);
@@ -329,12 +351,13 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
tls_ctx_free(sk, ctx);
}
-static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
- int __user *optlen)
+static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
+ int __user *optlen, int tx)
{
int rc = 0;
struct tls_context *ctx = tls_get_ctx(sk);
struct tls_crypto_info *crypto_info;
+ struct cipher_context *cctx;
int len;
if (get_user(len, optlen))
@@ -351,7 +374,13 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
}
/* get user crypto info */
- crypto_info = &ctx->crypto_send.info;
+ if (tx) {
+ crypto_info = &ctx->crypto_send.info;
+ cctx = &ctx->tx;
+ } else {
+ crypto_info = &ctx->crypto_recv.info;
+ cctx = &ctx->rx;
+ }
if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
rc = -EBUSY;
@@ -378,9 +407,9 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
}
lock_sock(sk);
memcpy(crypto_info_aes_gcm_128->iv,
- ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+ cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
TLS_CIPHER_AES_GCM_128_IV_SIZE);
- memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
+ memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
release_sock(sk);
if (copy_to_user(optval,
@@ -402,9 +431,9 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
}
lock_sock(sk);
memcpy(crypto_info_aes_gcm_256->iv,
- ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+ cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
TLS_CIPHER_AES_GCM_256_IV_SIZE);
- memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
+ memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
release_sock(sk);
if (copy_to_user(optval,
@@ -413,6 +442,136 @@ static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
rc = -EFAULT;
break;
}
+ case TLS_CIPHER_AES_CCM_128: {
+ struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
+ container_of(crypto_info,
+ struct tls12_crypto_info_aes_ccm_128, info);
+
+ if (len != sizeof(*aes_ccm_128)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(aes_ccm_128->iv,
+ cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
+ TLS_CIPHER_AES_CCM_128_IV_SIZE);
+ memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
+ rc = -EFAULT;
+ break;
+ }
+ case TLS_CIPHER_CHACHA20_POLY1305: {
+ struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
+ container_of(crypto_info,
+ struct tls12_crypto_info_chacha20_poly1305,
+ info);
+
+ if (len != sizeof(*chacha20_poly1305)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(chacha20_poly1305->iv,
+ cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
+ TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
+ memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval, chacha20_poly1305,
+ sizeof(*chacha20_poly1305)))
+ rc = -EFAULT;
+ break;
+ }
+ case TLS_CIPHER_SM4_GCM: {
+ struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
+ container_of(crypto_info,
+ struct tls12_crypto_info_sm4_gcm, info);
+
+ if (len != sizeof(*sm4_gcm_info)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(sm4_gcm_info->iv,
+ cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
+ TLS_CIPHER_SM4_GCM_IV_SIZE);
+ memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
+ rc = -EFAULT;
+ break;
+ }
+ case TLS_CIPHER_SM4_CCM: {
+ struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
+ container_of(crypto_info,
+ struct tls12_crypto_info_sm4_ccm, info);
+
+ if (len != sizeof(*sm4_ccm_info)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(sm4_ccm_info->iv,
+ cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
+ TLS_CIPHER_SM4_CCM_IV_SIZE);
+ memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
+ rc = -EFAULT;
+ break;
+ }
+ case TLS_CIPHER_ARIA_GCM_128: {
+ struct tls12_crypto_info_aria_gcm_128 *
+ crypto_info_aria_gcm_128 =
+ container_of(crypto_info,
+ struct tls12_crypto_info_aria_gcm_128,
+ info);
+
+ if (len != sizeof(*crypto_info_aria_gcm_128)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(crypto_info_aria_gcm_128->iv,
+ cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE,
+ TLS_CIPHER_ARIA_GCM_128_IV_SIZE);
+ memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval,
+ crypto_info_aria_gcm_128,
+ sizeof(*crypto_info_aria_gcm_128)))
+ rc = -EFAULT;
+ break;
+ }
+ case TLS_CIPHER_ARIA_GCM_256: {
+ struct tls12_crypto_info_aria_gcm_256 *
+ crypto_info_aria_gcm_256 =
+ container_of(crypto_info,
+ struct tls12_crypto_info_aria_gcm_256,
+ info);
+
+ if (len != sizeof(*crypto_info_aria_gcm_256)) {
+ rc = -EINVAL;
+ goto out;
+ }
+ lock_sock(sk);
+ memcpy(crypto_info_aria_gcm_256->iv,
+ cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE,
+ TLS_CIPHER_ARIA_GCM_256_IV_SIZE);
+ memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq,
+ TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE);
+ release_sock(sk);
+ if (copy_to_user(optval,
+ crypto_info_aria_gcm_256,
+ sizeof(*crypto_info_aria_gcm_256)))
+ rc = -EFAULT;
+ break;
+ }
default:
rc = -EINVAL;
}
@@ -421,6 +580,56 @@ out:
return rc;
}
+static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
+ int __user *optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ unsigned int value;
+ int len;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+
+ if (len != sizeof(value))
+ return -EINVAL;
+
+ value = ctx->zerocopy_sendfile;
+ if (copy_to_user(optval, &value, sizeof(value)))
+ return -EFAULT;
+
+ return 0;
+}
+
+static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
+ int __user *optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ int value, len;
+
+ if (ctx->prot_info.version != TLS_1_3_VERSION)
+ return -EINVAL;
+
+ if (get_user(len, optlen))
+ return -EFAULT;
+ if (len < sizeof(value))
+ return -EINVAL;
+
+ lock_sock(sk);
+ value = -EINVAL;
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+ value = ctx->rx_no_pad;
+ release_sock(sk);
+ if (value < 0)
+ return value;
+
+ if (put_user(sizeof(value), optlen))
+ return -EFAULT;
+ if (copy_to_user(optval, &value, sizeof(value)))
+ return -EFAULT;
+
+ return 0;
+}
+
static int do_tls_getsockopt(struct sock *sk, int optname,
char __user *optval, int __user *optlen)
{
@@ -428,7 +637,15 @@ static int do_tls_getsockopt(struct sock *sk, int optname,
switch (optname) {
case TLS_TX:
- rc = do_tls_getsockopt_tx(sk, optval, optlen);
+ case TLS_RX:
+ rc = do_tls_getsockopt_conf(sk, optval, optlen,
+ optname == TLS_TX);
+ break;
+ case TLS_TX_ZEROCOPY_RO:
+ rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
+ break;
+ case TLS_RX_EXPECT_NO_PAD:
+ rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
break;
default:
rc = -ENOPROTOOPT;
@@ -449,7 +666,7 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
return do_tls_getsockopt(sk, optname, optval, optlen);
}
-static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
+static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
unsigned int optlen, int tx)
{
struct tls_crypto_info *crypto_info;
@@ -459,10 +676,8 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
int rc = 0;
int conf;
- if (!optval || (optlen < sizeof(*crypto_info))) {
- rc = -EINVAL;
- goto out;
- }
+ if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
+ return -EINVAL;
if (tx) {
crypto_info = &ctx->crypto_send.info;
@@ -473,12 +688,10 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
}
/* Currently we don't support set crypto info more than one time */
- if (TLS_CRYPTO_INFO_READY(crypto_info)) {
- rc = -EBUSY;
- goto out;
- }
+ if (TLS_CRYPTO_INFO_READY(crypto_info))
+ return -EBUSY;
- rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
+ rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
if (rc) {
rc = -EFAULT;
goto err_crypto_info;
@@ -511,6 +724,29 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
case TLS_CIPHER_AES_CCM_128:
optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
break;
+ case TLS_CIPHER_CHACHA20_POLY1305:
+ optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
+ break;
+ case TLS_CIPHER_SM4_GCM:
+ optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
+ break;
+ case TLS_CIPHER_SM4_CCM:
+ optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
+ break;
+ case TLS_CIPHER_ARIA_GCM_128:
+ if (crypto_info->version != TLS_1_2_VERSION) {
+ rc = -EINVAL;
+ goto err_crypto_info;
+ }
+ optsize = sizeof(struct tls12_crypto_info_aria_gcm_128);
+ break;
+ case TLS_CIPHER_ARIA_GCM_256:
+ if (crypto_info->version != TLS_1_2_VERSION) {
+ rc = -EINVAL;
+ goto err_crypto_info;
+ }
+ optsize = sizeof(struct tls12_crypto_info_aria_gcm_256);
+ break;
default:
rc = -EINVAL;
goto err_crypto_info;
@@ -521,8 +757,9 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
goto err_crypto_info;
}
- rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
- optlen - sizeof(*crypto_info));
+ rc = copy_from_sockptr_offset(crypto_info + 1, optval,
+ sizeof(*crypto_info),
+ optlen - sizeof(*crypto_info));
if (rc) {
rc = -EFAULT;
goto err_crypto_info;
@@ -568,18 +805,71 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
ctx->sk_write_space = sk->sk_write_space;
sk->sk_write_space = tls_write_space;
} else {
- sk->sk_socket->ops = &tls_sw_proto_ops;
+ struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);
+
+ tls_strp_check_rcv(&rx_ctx->strp);
}
- goto out;
+ return 0;
err_crypto_info:
memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
-out:
return rc;
}
-static int do_tls_setsockopt(struct sock *sk, int optname,
- char __user *optval, unsigned int optlen)
+static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
+ unsigned int optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ unsigned int value;
+
+ if (sockptr_is_null(optval) || optlen != sizeof(value))
+ return -EINVAL;
+
+ if (copy_from_sockptr(&value, optval, sizeof(value)))
+ return -EFAULT;
+
+ if (value > 1)
+ return -EINVAL;
+
+ ctx->zerocopy_sendfile = value;
+
+ return 0;
+}
+
+static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
+ unsigned int optlen)
+{
+ struct tls_context *ctx = tls_get_ctx(sk);
+ u32 val;
+ int rc;
+
+ if (ctx->prot_info.version != TLS_1_3_VERSION ||
+ sockptr_is_null(optval) || optlen < sizeof(val))
+ return -EINVAL;
+
+ rc = copy_from_sockptr(&val, optval, sizeof(val));
+ if (rc)
+ return -EFAULT;
+ if (val > 1)
+ return -EINVAL;
+ rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
+ if (rc < 1)
+ return rc == 0 ? -EINVAL : rc;
+
+ lock_sock(sk);
+ rc = -EINVAL;
+ if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
+ ctx->rx_no_pad = val;
+ tls_update_rx_zc_capable(ctx);
+ rc = 0;
+ }
+ release_sock(sk);
+
+ return rc;
+}
+
+static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
+ unsigned int optlen)
{
int rc = 0;
@@ -591,6 +881,14 @@ static int do_tls_setsockopt(struct sock *sk, int optname,
optname == TLS_TX);
release_sock(sk);
break;
+ case TLS_TX_ZEROCOPY_RO:
+ lock_sock(sk);
+ rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
+ release_sock(sk);
+ break;
+ case TLS_RX_EXPECT_NO_PAD:
+ rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
+ break;
default:
rc = -ENOPROTOOPT;
break;
@@ -599,7 +897,7 @@ static int do_tls_setsockopt(struct sock *sk, int optname,
}
static int tls_setsockopt(struct sock *sk, int level, int optname,
- char __user *optval, unsigned int optlen)
+ sockptr_t optval, unsigned int optlen)
{
struct tls_context *ctx = tls_get_ctx(sk);
@@ -621,38 +919,77 @@ struct tls_context *tls_ctx_create(struct sock *sk)
mutex_init(&ctx->tx_lock);
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
- ctx->sk_proto = sk->sk_prot;
+ ctx->sk_proto = READ_ONCE(sk->sk_prot);
+ ctx->sk = sk;
return ctx;
}
+static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+ const struct proto_ops *base)
+{
+ ops[TLS_BASE][TLS_BASE] = *base;
+
+ ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+ ops[TLS_SW ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;
+
+ ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE];
+ ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read;
+
+ ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE];
+ ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read;
+
+#ifdef CONFIG_TLS_DEVICE
+ ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+ ops[TLS_HW ][TLS_BASE].sendpage_locked = NULL;
+
+ ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ];
+ ops[TLS_HW ][TLS_SW ].sendpage_locked = NULL;
+
+ ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ];
+
+ ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ];
+
+ ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ];
+ ops[TLS_HW ][TLS_HW ].sendpage_locked = NULL;
+#endif
+#ifdef CONFIG_TLS_TOE
+ ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
+#endif
+}
+
static void tls_build_proto(struct sock *sk)
{
int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
+ struct proto *prot = READ_ONCE(sk->sk_prot);
/* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
if (ip_ver == TLSV6 &&
- unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
+ unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
mutex_lock(&tcpv6_prot_mutex);
- if (likely(sk->sk_prot != saved_tcpv6_prot)) {
- build_protos(tls_prots[TLSV6], sk->sk_prot);
- smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
+ if (likely(prot != saved_tcpv6_prot)) {
+ build_protos(tls_prots[TLSV6], prot);
+ build_proto_ops(tls_proto_ops[TLSV6],
+ sk->sk_socket->ops);
+ smp_store_release(&saved_tcpv6_prot, prot);
}
mutex_unlock(&tcpv6_prot_mutex);
}
if (ip_ver == TLSV4 &&
- unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+ unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
mutex_lock(&tcpv4_prot_mutex);
- if (likely(sk->sk_prot != saved_tcpv4_prot)) {
- build_protos(tls_prots[TLSV4], sk->sk_prot);
- smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+ if (likely(prot != saved_tcpv4_prot)) {
+ build_protos(tls_prots[TLSV4], prot);
+ build_proto_ops(tls_proto_ops[TLSV4],
+ sk->sk_socket->ops);
+ smp_store_release(&saved_tcpv4_prot, prot);
}
mutex_unlock(&tcpv4_prot_mutex);
}
}
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
- struct proto *base)
+ const struct proto *base)
{
prot[TLS_BASE][TLS_BASE] = *base;
prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
@@ -665,12 +1002,12 @@ static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
- prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
+ prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;
prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
- prot[TLS_SW][TLS_SW].stream_memory_read = tls_sw_stream_read;
+ prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;
#ifdef CONFIG_TLS_DEVICE
@@ -737,16 +1074,36 @@ static void tls_update(struct sock *sk, struct proto *p,
{
struct tls_context *ctx;
+ WARN_ON_ONCE(sk->sk_prot == p);
+
ctx = tls_get_ctx(sk);
if (likely(ctx)) {
ctx->sk_write_space = write_space;
ctx->sk_proto = p;
} else {
- sk->sk_prot = p;
+ /* Pairs with lockless read in sk_clone_lock(). */
+ WRITE_ONCE(sk->sk_prot, p);
sk->sk_write_space = write_space;
}
}
+static u16 tls_user_config(struct tls_context *ctx, bool tx)
+{
+ u16 config = tx ? ctx->tx_conf : ctx->rx_conf;
+
+ switch (config) {
+ case TLS_BASE:
+ return TLS_CONF_BASE;
+ case TLS_SW:
+ return TLS_CONF_SW;
+ case TLS_HW:
+ return TLS_CONF_HW;
+ case TLS_HW_RECORD:
+ return TLS_CONF_HW_RECORD;
+ }
+ return 0;
+}
+
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
u16 version, cipher_type;
@@ -784,6 +1141,17 @@ static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
if (err)
goto nla_failure;
+ if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
+ err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
+ if (err)
+ goto nla_failure;
+ }
+ if (ctx->rx_no_pad) {
+ err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
+ if (err)
+ goto nla_failure;
+ }
+
rcu_read_unlock();
nla_nest_end(skb, start);
return 0;
@@ -803,6 +1171,8 @@ static size_t tls_get_info_size(const struct sock *sk)
nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */
nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */
nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */
+ nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */
+ nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */
0;
return size;
@@ -854,19 +1224,28 @@ static int __init tls_register(void)
if (err)
return err;
- tls_sw_proto_ops = inet_stream_ops;
- tls_sw_proto_ops.splice_read = tls_sw_splice_read;
- tls_sw_proto_ops.sendpage_locked = tls_sw_sendpage_locked,
+ err = tls_strp_dev_init();
+ if (err)
+ goto err_pernet;
+
+ err = tls_device_init();
+ if (err)
+ goto err_strp;
- tls_device_init();
tcp_register_ulp(&tcp_tls_ulp_ops);
return 0;
+err_strp:
+ tls_strp_dev_exit();
+err_pernet:
+ unregister_pernet_subsys(&tls_proc_ops);
+ return err;
}
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
+ tls_strp_dev_exit();
tls_device_cleanup();
unregister_pernet_subsys(&tls_proc_ops);
}