diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_main.c | 19 | ||||
-rw-r--r-- | net/tls/tls_sw.c | 9 |
2 files changed, 16 insertions, 12 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c index 0d379970960e..20cd93be6236 100644 --- a/net/tls/tls_main.c +++ b/net/tls/tls_main.c @@ -114,6 +114,7 @@ int tls_push_sg(struct sock *sk, size = sg->length - offset; offset += sg->offset; + ctx->in_tcp_sendpages = true; while (1) { if (sg_is_last(sg)) sendpage_flags = flags; @@ -134,6 +135,7 @@ retry: offset -= sg->offset; ctx->partially_sent_offset = offset; ctx->partially_sent_record = (void *)sg; + ctx->in_tcp_sendpages = false; return ret; } @@ -148,6 +150,8 @@ retry: } clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); + ctx->in_tcp_sendpages = false; + ctx->sk_write_space(sk); return 0; } @@ -217,6 +221,10 @@ static void tls_write_space(struct sock *sk) { struct tls_context *ctx = tls_get_ctx(sk); + /* We are already sending pages, ignore notification */ + if (ctx->in_tcp_sendpages) + return; + if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) { gfp_t sk_allocation = sk->sk_allocation; int rc; @@ -241,16 +249,13 @@ static void tls_sk_proto_close(struct sock *sk, long timeout) struct tls_context *ctx = tls_get_ctx(sk); long timeo = sock_sndtimeo(sk, 0); void (*sk_proto_close)(struct sock *sk, long timeout); + bool free_ctx = false; lock_sock(sk); sk_proto_close = ctx->sk_proto_close; - if (ctx->conf == TLS_HW_RECORD) - goto skip_tx_cleanup; - - if (ctx->conf == TLS_BASE) { - kfree(ctx); - ctx = NULL; + if (ctx->conf == TLS_BASE || ctx->conf == TLS_HW_RECORD) { + free_ctx = true; goto skip_tx_cleanup; } @@ -287,7 +292,7 @@ skip_tx_cleanup: /* free ctx for TLS_HW_RECORD, used by tcp_set_state * for sk->sk_prot->unhash [tls_hw_unhash] */ - if (ctx && ctx->conf == TLS_HW_RECORD) + if (free_ctx) kfree(ctx); } diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 71e79597f940..e1c93ce74e0f 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -680,7 +680,6 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb, struct scatterlist *sgin = &sgin_arr[0]; struct strp_msg *rxm = strp_msg(skb); int ret, nsg = ARRAY_SIZE(sgin_arr); - char aad_recv[TLS_AAD_SPACE_SIZE]; struct sk_buff *unused; ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, @@ -698,13 +697,13 @@ static int decrypt_skb(struct sock *sk, struct sk_buff *skb, } sg_init_table(sgin, nsg); - sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv)); + sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE); nsg = skb_to_sgvec(skb, &sgin[1], rxm->offset + tls_ctx->rx.prepend_size, rxm->full_len - tls_ctx->rx.prepend_size); - tls_make_aad(aad_recv, + tls_make_aad(ctx->rx_aad_ciphertext, rxm->full_len - tls_ctx->rx.overhead_size, tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, @@ -803,12 +802,12 @@ int tls_sw_recvmsg(struct sock *sk, if (to_copy <= len && page_count < MAX_SKB_FRAGS && likely(!(flags & MSG_PEEK))) { struct scatterlist sgin[MAX_SKB_FRAGS + 1]; - char unused[21]; int pages = 0; zc = true; sg_init_table(sgin, MAX_SKB_FRAGS + 1); - sg_set_buf(&sgin[0], unused, 13); + sg_set_buf(&sgin[0], ctx->rx_aad_plaintext, + TLS_AAD_SPACE_SIZE); err = zerocopy_from_iter(sk, &msg->msg_iter, to_copy, &pages, |