aboutsummaryrefslogtreecommitdiffstats
path: root/net/tls
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls')
-rw-r--r--net/tls/tls_main.c14
-rw-r--r--net/tls/tls_sw.c54
2 files changed, 51 insertions, 17 deletions
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 28887cf628b8..78cb4a584080 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -55,6 +55,8 @@ enum {
static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
+static struct proto *saved_tcpv4_prot;
+static DEFINE_MUTEX(tcpv4_prot_mutex);
static LIST_HEAD(device_list);
static DEFINE_SPINLOCK(device_spinlock);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
@@ -700,6 +702,16 @@ static int tls_init(struct sock *sk)
mutex_unlock(&tcpv6_prot_mutex);
}
+ if (ip_ver == TLSV4 &&
+ unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+ mutex_lock(&tcpv4_prot_mutex);
+ if (likely(sk->sk_prot != saved_tcpv4_prot)) {
+ build_protos(tls_prots[TLSV4], sk->sk_prot);
+ smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+ }
+ mutex_unlock(&tcpv4_prot_mutex);
+ }
+
ctx->tx_conf = TLS_BASE;
ctx->rx_conf = TLS_BASE;
update_sk_prot(sk, ctx);
@@ -731,8 +743,6 @@ static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
static int __init tls_register(void)
{
- build_protos(tls_prots[TLSV4], &tcp_prot);
-
tls_sw_proto_ops = inet_stream_ops;
tls_sw_proto_ops.splice_read = tls_sw_splice_read;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 29b27858fff1..11cdc8f7db63 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -686,16 +686,24 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
struct sk_psock *psock;
struct sock *sk_redir;
struct tls_rec *rec;
+ bool enospc, policy;
int err = 0, send;
- bool enospc;
+ u32 delta = 0;
+ policy = !(flags & MSG_SENDPAGE_NOPOLICY);
psock = sk_psock_get(sk);
- if (!psock)
+ if (!psock || !policy)
return tls_push_record(sk, flags, record_type);
more_data:
enospc = sk_msg_full(msg);
- if (psock->eval == __SK_NONE)
+ if (psock->eval == __SK_NONE) {
+ delta = msg->sg.size;
psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+ if (delta < msg->sg.size)
+ delta -= msg->sg.size;
+ else
+ delta = 0;
+ }
if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
!enospc && !full_record) {
err = -ENOSPC;
@@ -743,7 +751,7 @@ more_data:
msg->apply_bytes -= send;
if (msg->sg.size == 0)
tls_free_open_rec(sk);
- *copied -= send;
+ *copied -= (send + delta);
err = -EACCES;
}
@@ -1012,8 +1020,8 @@ send_end:
return copied ? copied : ret;
}
-int tls_sw_sendpage(struct sock *sk, struct page *page,
- int offset, size_t size, int flags)
+int tls_sw_do_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags)
{
long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
struct tls_context *tls_ctx = tls_get_ctx(sk);
@@ -1028,15 +1036,7 @@ int tls_sw_sendpage(struct sock *sk, struct page *page,
int ret = 0;
bool eor;
- if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
- MSG_SENDPAGE_NOTLAST))
- return -ENOTSUPP;
-
- /* No MSG_EOR from splice, only look at MSG_MORE */
eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
-
- lock_sock(sk);
-
sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
/* Wait till there is any pending write on socket */
@@ -1140,10 +1140,34 @@ wait_for_memory:
}
sendpage_end:
ret = sk_stream_error(sk, flags, ret);
- release_sock(sk);
return copied ? copied : ret;
}
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags)
+{
+ if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+ MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
+ return -ENOTSUPP;
+
+ return tls_sw_do_sendpage(sk, page, offset, size, flags);
+}
+
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+ int offset, size_t size, int flags)
+{
+ int ret;
+
+ if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+ MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
+ return -ENOTSUPP;
+
+ lock_sock(sk);
+ ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
+ release_sock(sk);
+ return ret;
+}
+
static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
int flags, long timeo, int *err)
{