aboutsummaryrefslogtreecommitdiffstats
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r--net/tls/tls_sw.c103
1 files changed, 76 insertions, 27 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 29d6af43dd24..d93f83f77864 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -42,8 +42,6 @@
#include <net/strparser.h>
#include <net/tls.h>
-#define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE
-
static int __skb_nsg(struct sk_buff *skb, int offset, int len,
unsigned int recursion_level)
{
@@ -121,23 +119,25 @@ static int skb_nsg(struct sk_buff *skb, int offset, int len)
}
static int padding_length(struct tls_sw_context_rx *ctx,
- struct tls_context *tls_ctx, struct sk_buff *skb)
+ struct tls_prot_info *prot, struct sk_buff *skb)
{
struct strp_msg *rxm = strp_msg(skb);
int sub = 0;
/* Determine zero-padding length */
- if (tls_ctx->prot_info.version == TLS_1_3_VERSION) {
+ if (prot->version == TLS_1_3_VERSION) {
char content_type = 0;
int err;
int back = 17;
while (content_type == 0) {
- if (back > rxm->full_len)
+ if (back > rxm->full_len - prot->prepend_size)
return -EBADMSG;
err = skb_copy_bits(skb,
rxm->offset + rxm->full_len - back,
&content_type, 1);
+ if (err)
+ return err;
if (content_type)
break;
sub++;
@@ -172,9 +172,17 @@ static void tls_decrypt_done(struct crypto_async_request *req, int err)
tls_err_abort(skb->sk, err);
} else {
struct strp_msg *rxm = strp_msg(skb);
- rxm->full_len -= padding_length(ctx, tls_ctx, skb);
- rxm->offset += prot->prepend_size;
- rxm->full_len -= prot->overhead_size;
+ int pad;
+
+ pad = padding_length(ctx, prot, skb);
+ if (pad < 0) {
+ ctx->async_wait.err = pad;
+ tls_err_abort(skb->sk, pad);
+ } else {
+ rxm->full_len -= pad;
+ rxm->offset += prot->prepend_size;
+ rxm->full_len -= prot->overhead_size;
+ }
}
/* After using skb->sk to propagate sk through crypto async callback
@@ -225,7 +233,7 @@ static int tls_do_decryption(struct sock *sk,
/* Using skb->sk to push sk through to crypto async callback
* handler. This allows propagating errors up to the socket
* if needed. It _must_ be cleared in the async handler
- * before kfree_skb is called. We _know_ skb->sk is NULL
+ * before consume_skb is called. We _know_ skb->sk is NULL
* because it is a clone from strparser.
*/
skb->sk = sk;
@@ -479,11 +487,18 @@ static int tls_do_encryption(struct sock *sk,
struct tls_rec *rec = ctx->open_rec;
struct sk_msg *msg_en = &rec->msg_encrypted;
struct scatterlist *sge = sk_msg_elem(msg_en, start);
- int rc;
+ int rc, iv_offset = 0;
- memcpy(rec->iv_data, tls_ctx->tx.iv, sizeof(rec->iv_data));
- xor_iv_with_seq(prot->version, rec->iv_data,
- tls_ctx->tx.rec_seq);
+ /* For CCM based ciphers, first byte of IV is a constant */
+ if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+ rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
+ iv_offset = 1;
+ }
+
+ memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
+ prot->iv_size + prot->salt_size);
+
+ xor_iv_with_seq(prot->version, rec->iv_data, tls_ctx->tx.rec_seq);
sge->offset += prot->prepend_size;
sge->length -= prot->prepend_size;
@@ -1344,6 +1359,7 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
struct scatterlist *sgout = NULL;
const int data_len = rxm->full_len - prot->overhead_size +
prot->tail_size;
+ int iv_offset = 0;
if (*zc && (out_iov || out_sg)) {
if (out_iov)
@@ -1386,18 +1402,25 @@ static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
aad = (u8 *)(sgout + n_sgout);
iv = aad + prot->aad_size;
+ /* For CCM based ciphers, first byte of nonce+iv is always '2' */
+ if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+ iv[0] = 2;
+ iv_offset = 1;
+ }
+
/* Prepare IV */
err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
- iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+ iv + iv_offset + prot->salt_size,
prot->iv_size);
if (err < 0) {
kfree(mem);
return err;
}
if (prot->version == TLS_1_3_VERSION)
- memcpy(iv, tls_ctx->rx.iv, crypto_aead_ivsize(ctx->aead_recv));
+ memcpy(iv + iv_offset, tls_ctx->rx.iv,
+ crypto_aead_ivsize(ctx->aead_recv));
else
- memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+ memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
@@ -1465,7 +1488,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
struct tls_prot_info *prot = &tls_ctx->prot_info;
int version = prot->version;
struct strp_msg *rxm = strp_msg(skb);
- int err = 0;
+ int pad, err = 0;
if (!ctx->decrypted) {
#ifdef CONFIG_TLS_DEVICE
@@ -1488,7 +1511,11 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
*zc = false;
}
- rxm->full_len -= padding_length(ctx, tls_ctx, skb);
+ pad = padding_length(ctx, prot, skb);
+ if (pad < 0)
+ return pad;
+
+ rxm->full_len -= pad;
rxm->offset += prot->prepend_size;
rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, &tls_ctx->rx, version);
@@ -1524,7 +1551,7 @@ static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
rxm->full_len -= len;
return false;
}
- kfree_skb(skb);
+ consume_skb(skb);
}
/* Finished with message */
@@ -1633,7 +1660,7 @@ static int process_rx_list(struct tls_sw_context_rx *ctx,
if (!is_peek) {
skb_unlink(skb, &ctx->rx_list);
- kfree_skb(skb);
+ consume_skb(skb);
}
skb = next_skb;
@@ -2144,14 +2171,15 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
struct tls_crypto_info *crypto_info;
struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
+ struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
struct tls_sw_context_tx *sw_ctx_tx = NULL;
struct tls_sw_context_rx *sw_ctx_rx = NULL;
struct cipher_context *cctx;
struct crypto_aead **aead;
struct strp_callbacks cb;
- u16 nonce_size, tag_size, iv_size, rec_seq_size;
+ u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
struct crypto_tfm *tfm;
- char *iv, *rec_seq, *key, *salt;
+ char *iv, *rec_seq, *key, *salt, *cipher_name;
size_t keysize;
int rc = 0;
@@ -2216,6 +2244,8 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
key = gcm_128_info->key;
salt = gcm_128_info->salt;
+ salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
+ cipher_name = "gcm(aes)";
break;
}
case TLS_CIPHER_AES_GCM_256: {
@@ -2231,6 +2261,25 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
key = gcm_256_info->key;
salt = gcm_256_info->salt;
+ salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
+ cipher_name = "gcm(aes)";
+ break;
+ }
+ case TLS_CIPHER_AES_CCM_128: {
+ nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+ tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
+ iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+ iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
+ rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
+ rec_seq =
+ ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
+ ccm_128_info =
+ (struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
+ keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
+ key = ccm_128_info->key;
+ salt = ccm_128_info->salt;
+ salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
+ cipher_name = "ccm(aes)";
break;
}
default:
@@ -2260,16 +2309,16 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
prot->overhead_size = prot->prepend_size +
prot->tag_size + prot->tail_size;
prot->iv_size = iv_size;
- cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
- GFP_KERNEL);
+ prot->salt_size = salt_size;
+ cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
if (!cctx->iv) {
rc = -ENOMEM;
goto free_priv;
}
/* Note: 128 & 256 bit salt are the same size */
- memcpy(cctx->iv, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
- memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
prot->rec_seq_size = rec_seq_size;
+ memcpy(cctx->iv, salt, salt_size);
+ memcpy(cctx->iv + salt_size, iv, iv_size);
cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
if (!cctx->rec_seq) {
rc = -ENOMEM;
@@ -2277,7 +2326,7 @@ int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
}
if (!*aead) {
- *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+ *aead = crypto_alloc_aead(cipher_name, 0, 0);
if (IS_ERR(*aead)) {
rc = PTR_ERR(*aead);
*aead = NULL;