aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--drivers/crypto/chelsio/chtls/chtls_main.c6
-rw-r--r--include/net/tls.h48
-rw-r--r--net/tls/tls_device.c78
-rw-r--r--net/tls/tls_main.c46
-rw-r--r--net/tls/tls_sw.c6
5 files changed, 85 insertions, 99 deletions
diff --git a/drivers/crypto/chelsio/chtls/chtls_main.c b/drivers/crypto/chelsio/chtls/chtls_main.c
index 635bb4b447fb..e6df5b95ed47 100644
--- a/drivers/crypto/chelsio/chtls/chtls_main.c
+++ b/drivers/crypto/chelsio/chtls/chtls_main.c
@@ -474,7 +474,8 @@ static int chtls_getsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS)
- return ctx->getsockopt(sk, level, optname, optval, optlen);
+ return ctx->sk_proto->getsockopt(sk, level,
+ optname, optval, optlen);
return do_chtls_getsockopt(sk, optval, optlen);
}
@@ -541,7 +542,8 @@ static int chtls_setsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS)
- return ctx->setsockopt(sk, level, optname, optval, optlen);
+ return ctx->sk_proto->setsockopt(sk, level,
+ optname, optval, optlen);
return do_chtls_setsockopt(sk, optname, optval, optlen);
}
diff --git a/include/net/tls.h b/include/net/tls.h
index ec3c3ed2c6c3..c664e6dba0d1 100644
--- a/include/net/tls.h
+++ b/include/net/tls.h
@@ -275,16 +275,6 @@ struct tls_context {
struct proto *sk_proto;
void (*sk_destruct)(struct sock *sk);
- void (*sk_proto_close)(struct sock *sk, long timeout);
-
- int (*setsockopt)(struct sock *sk, int level,
- int optname, char __user *optval,
- unsigned int optlen);
- int (*getsockopt)(struct sock *sk, int level,
- int optname, char __user *optval,
- int __user *optlen);
- int (*hash)(struct sock *sk);
- void (*unhash)(struct sock *sk);
union tls_crypto_context crypto_send;
union tls_crypto_context crypto_recv;
@@ -376,13 +366,9 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
struct pipe_inode_info *pipe,
size_t len, unsigned int flags);
-int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
int tls_device_sendpage(struct sock *sk, struct page *page,
int offset, size_t size, int flags);
-void tls_device_free_resources_tx(struct sock *sk);
-void tls_device_init(void);
-void tls_device_cleanup(void);
int tls_tx_records(struct sock *sk, int flags);
struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context,
@@ -659,7 +645,6 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
unsigned char *record_type);
void tls_register_device(struct tls_device *device);
void tls_unregister_device(struct tls_device *device);
-int tls_device_decrypted(struct sock *sk, struct sk_buff *skb);
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout);
struct sk_buff *tls_encrypt_skb(struct sk_buff *skb);
@@ -672,9 +657,40 @@ int tls_sw_fallback_init(struct sock *sk,
struct tls_offload_context_tx *offload_ctx,
struct tls_crypto_info *crypto_info);
+#ifdef CONFIG_TLS_DEVICE
+void tls_device_init(void);
+void tls_device_cleanup(void);
+int tls_set_device_offload(struct sock *sk, struct tls_context *ctx);
+void tls_device_free_resources_tx(struct sock *sk);
int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx);
-
void tls_device_offload_cleanup_rx(struct sock *sk);
void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq);
+int tls_device_decrypted(struct sock *sk, struct sk_buff *skb);
+#else
+static inline void tls_device_init(void) {}
+static inline void tls_device_cleanup(void) {}
+static inline int
+tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void tls_device_free_resources_tx(struct sock *sk) {}
+
+static inline int
+tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
+{
+ return -EOPNOTSUPP;
+}
+
+static inline void tls_device_offload_cleanup_rx(struct sock *sk) {}
+static inline void
+tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {}
+
+static inline int tls_device_decrypted(struct sock *sk, struct sk_buff *skb)
+{
+ return 0;
+}
+#endif
#endif /* _TLS_OFFLOAD_H */
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index e188139f0464..41c106e45f01 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -159,12 +159,8 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq)
spin_lock_irqsave(&ctx->lock, flags);
info = ctx->retransmit_hint;
- if (info && !before(acked_seq, info->end_seq)) {
+ if (info && !before(acked_seq, info->end_seq))
ctx->retransmit_hint = NULL;
- list_del(&info->list);
- destroy_record(info);
- deleted_records++;
- }
list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
if (before(acked_seq, info->end_seq))
@@ -838,22 +834,18 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
struct net_device *netdev;
char *iv, *rec_seq;
struct sk_buff *skb;
- int rc = -EINVAL;
__be64 rcd_sn;
+ int rc;
if (!ctx)
- goto out;
+ return -EINVAL;
- if (ctx->priv_ctx_tx) {
- rc = -EEXIST;
- goto out;
- }
+ if (ctx->priv_ctx_tx)
+ return -EEXIST;
start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
- if (!start_marker_record) {
- rc = -ENOMEM;
- goto out;
- }
+ if (!start_marker_record)
+ return -ENOMEM;
offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
if (!offload_ctx) {
@@ -939,17 +931,11 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
if (skb)
TCP_SKB_CB(skb)->eor = 1;
- /* We support starting offload on multiple sockets
- * concurrently, so we only need a read lock here.
- * This lock must precede get_netdev_for_sock to prevent races between
- * NETDEV_DOWN and setsockopt.
- */
- down_read(&device_offload_lock);
netdev = get_netdev_for_sock(sk);
if (!netdev) {
pr_err_ratelimited("%s: netdev not found\n", __func__);
rc = -EINVAL;
- goto release_lock;
+ goto disable_cad;
}
if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
@@ -960,10 +946,15 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
/* Avoid offloading if the device is down
* We don't want to offload new flows after
* the NETDEV_DOWN event
+ *
+ * device_offload_lock is taken in tls_devices's NETDEV_DOWN
+ * handler thus protecting from the device going down before
+ * ctx was added to tls_device_list.
*/
+ down_read(&device_offload_lock);
if (!(netdev->flags & IFF_UP)) {
rc = -EINVAL;
- goto release_netdev;
+ goto release_lock;
}
ctx->priv_ctx_tx = offload_ctx;
@@ -971,9 +962,10 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
&ctx->crypto_send.info,
tcp_sk(sk)->write_seq);
if (rc)
- goto release_netdev;
+ goto release_lock;
tls_device_attach(ctx, sk, netdev);
+ up_read(&device_offload_lock);
/* following this assignment tls_is_sk_tx_device_offloaded
* will return true and the context might be accessed
@@ -981,13 +973,14 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
*/
smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
dev_put(netdev);
- up_read(&device_offload_lock);
- goto out;
-release_netdev:
- dev_put(netdev);
+ return 0;
+
release_lock:
up_read(&device_offload_lock);
+release_netdev:
+ dev_put(netdev);
+disable_cad:
clean_acked_data_disable(inet_csk(sk));
crypto_free_aead(offload_ctx->aead_send);
free_rec_seq:
@@ -999,7 +992,6 @@ free_offload_ctx:
ctx->priv_ctx_tx = NULL;
free_marker_record:
kfree(start_marker_record);
-out:
return rc;
}
@@ -1012,17 +1004,10 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
return -EOPNOTSUPP;
- /* We support starting offload on multiple sockets
- * concurrently, so we only need a read lock here.
- * This lock must precede get_netdev_for_sock to prevent races between
- * NETDEV_DOWN and setsockopt.
- */
- down_read(&device_offload_lock);
netdev = get_netdev_for_sock(sk);
if (!netdev) {
pr_err_ratelimited("%s: netdev not found\n", __func__);
- rc = -EINVAL;
- goto release_lock;
+ return -EINVAL;
}
if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
@@ -1033,16 +1018,21 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
/* Avoid offloading if the device is down
* We don't want to offload new flows after
* the NETDEV_DOWN event
+ *
+ * device_offload_lock is taken in tls_devices's NETDEV_DOWN
+ * handler thus protecting from the device going down before
+ * ctx was added to tls_device_list.
*/
+ down_read(&device_offload_lock);
if (!(netdev->flags & IFF_UP)) {
rc = -EINVAL;
- goto release_netdev;
+ goto release_lock;
}
context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
if (!context) {
rc = -ENOMEM;
- goto release_netdev;
+ goto release_lock;
}
context->resync_nh_reset = 1;
@@ -1058,7 +1048,11 @@ int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx)
goto free_sw_resources;
tls_device_attach(ctx, sk, netdev);
- goto release_netdev;
+ up_read(&device_offload_lock);
+
+ dev_put(netdev);
+
+ return 0;
free_sw_resources:
up_read(&device_offload_lock);
@@ -1066,10 +1060,10 @@ free_sw_resources:
down_read(&device_offload_lock);
release_ctx:
ctx->priv_ctx_rx = NULL;
-release_netdev:
- dev_put(netdev);
release_lock:
up_read(&device_offload_lock);
+release_netdev:
+ dev_put(netdev);
return rc;
}
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 277f7c209fed..ac88877dcade 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -286,19 +286,14 @@ static void tls_sk_proto_cleanup(struct sock *sk,
kfree(ctx->tx.rec_seq);
kfree(ctx->tx.iv);
tls_sw_release_resources_tx(sk);
-#ifdef CONFIG_TLS_DEVICE
} else if (ctx->tx_conf == TLS_HW) {
tls_device_free_resources_tx(sk);
-#endif
}
if (ctx->rx_conf == TLS_SW)
tls_sw_release_resources_rx(sk);
-
-#ifdef CONFIG_TLS_DEVICE
- if (ctx->rx_conf == TLS_HW)
+ else if (ctx->rx_conf == TLS_HW)
tls_device_offload_cleanup_rx(sk);
-#endif
}
static void tls_sk_proto_close(struct sock *sk, long timeout)
@@ -331,7 +326,7 @@ static void tls_sk_proto_close(struct sock *sk, long timeout)
tls_sw_strparser_done(ctx);
if (ctx->rx_conf == TLS_SW)
tls_sw_free_ctx_rx(ctx);
- ctx->sk_proto_close(sk, timeout);
+ ctx->sk_proto->close(sk, timeout);
if (free_ctx)
tls_ctx_free(sk, ctx);
@@ -451,7 +446,8 @@ static int tls_getsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS)
- return ctx->getsockopt(sk, level, optname, optval, optlen);
+ return ctx->sk_proto->getsockopt(sk, level,
+ optname, optval, optlen);
return do_tls_getsockopt(sk, optname, optval, optlen);
}
@@ -536,26 +532,18 @@ static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
}
if (tx) {
-#ifdef CONFIG_TLS_DEVICE
rc = tls_set_device_offload(sk, ctx);
conf = TLS_HW;
if (rc) {
-#else
- {
-#endif
rc = tls_set_sw_offload(sk, ctx, 1);
if (rc)
goto err_crypto_info;
conf = TLS_SW;
}
} else {
-#ifdef CONFIG_TLS_DEVICE
rc = tls_set_device_offload_rx(sk, ctx);
conf = TLS_HW;
if (rc) {
-#else
- {
-#endif
rc = tls_set_sw_offload(sk, ctx, 0);
if (rc)
goto err_crypto_info;
@@ -609,7 +597,8 @@ static int tls_setsockopt(struct sock *sk, int level, int optname,
struct tls_context *ctx = tls_get_ctx(sk);
if (level != SOL_TLS)
- return ctx->setsockopt(sk, level, optname, optval, optlen);
+ return ctx->sk_proto->setsockopt(sk, level, optname, optval,
+ optlen);
return do_tls_setsockopt(sk, optname, optval, optlen);
}
@@ -624,10 +613,7 @@ static struct tls_context *create_ctx(struct sock *sk)
return NULL;
rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
- ctx->setsockopt = sk->sk_prot->setsockopt;
- ctx->getsockopt = sk->sk_prot->getsockopt;
- ctx->sk_proto_close = sk->sk_prot->close;
- ctx->unhash = sk->sk_prot->unhash;
+ ctx->sk_proto = sk->sk_prot;
return ctx;
}
@@ -683,9 +669,6 @@ static int tls_hw_prot(struct sock *sk)
spin_unlock_bh(&device_spinlock);
tls_build_proto(sk);
- ctx->hash = sk->sk_prot->hash;
- ctx->unhash = sk->sk_prot->unhash;
- ctx->sk_proto_close = sk->sk_prot->close;
ctx->sk_destruct = sk->sk_destruct;
sk->sk_destruct = tls_hw_sk_destruct;
ctx->rx_conf = TLS_HW_RECORD;
@@ -717,7 +700,7 @@ static void tls_hw_unhash(struct sock *sk)
}
}
spin_unlock_bh(&device_spinlock);
- ctx->unhash(sk);
+ ctx->sk_proto->unhash(sk);
}
static int tls_hw_hash(struct sock *sk)
@@ -726,7 +709,7 @@ static int tls_hw_hash(struct sock *sk)
struct tls_device *dev;
int err;
- err = ctx->hash(sk);
+ err = ctx->sk_proto->hash(sk);
spin_lock_bh(&device_spinlock);
list_for_each_entry(dev, &device_list, dev_list) {
if (dev->hash) {
@@ -816,7 +799,6 @@ static int tls_init(struct sock *sk)
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
- ctx->sk_proto = sk->sk_prot;
update_sk_prot(sk, ctx);
out:
write_unlock_bh(&sk->sk_callback_lock);
@@ -828,12 +810,10 @@ static void tls_update(struct sock *sk, struct proto *p)
struct tls_context *ctx;
ctx = tls_get_ctx(sk);
- if (likely(ctx)) {
- ctx->sk_proto_close = p->close;
+ if (likely(ctx))
ctx->sk_proto = p;
- } else {
+ else
sk->sk_prot = p;
- }
}
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
@@ -927,9 +907,7 @@ static int __init tls_register(void)
tls_sw_proto_ops = inet_stream_ops;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;
-#ifdef CONFIG_TLS_DEVICE
tls_device_init();
-#endif
tcp_register_ulp(&tcp_tls_ulp_ops);
return 0;
@@ -938,9 +916,7 @@ static int __init tls_register(void)
static void __exit tls_unregister(void)
{
tcp_unregister_ulp(&tcp_tls_ulp_ops);
-#ifdef CONFIG_TLS_DEVICE
tls_device_cleanup();
-#endif
}
module_init(tls_register);
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 91d21b048a9b..c2b5e0d2ba1a 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1489,13 +1489,12 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
int pad, err = 0;
if (!ctx->decrypted) {
-#ifdef CONFIG_TLS_DEVICE
if (tls_ctx->rx_conf == TLS_HW) {
err = tls_device_decrypted(sk, skb);
if (err < 0)
return err;
}
-#endif
+
/* Still not decrypted after tls_device */
if (!ctx->decrypted) {
err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
@@ -2014,10 +2013,9 @@ static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
ret = -EINVAL;
goto read_failure;
}
-#ifdef CONFIG_TLS_DEVICE
+
tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
TCP_SKB_CB(skb)->seq + rxm->offset);
-#endif
return data_len + TLS_HEADER_SIZE;
read_failure: