aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/src
diff options
context:
space:
mode:
author: Jason A. Donenfeld <Jason@zx2c4.com> 2016-11-10 16:28:48 +0100
committer: Jason A. Donenfeld <Jason@zx2c4.com> 2016-11-10 16:28:48 +0100
commit: cc3d7df096a88cdf96d016bdcb2f78fa03abb6f3 (patch)
tree: 88858692a465c1c351fa0634b1ba74828f15df7f /src
parent: chacha20poly1305: don't forget version header (diff)
download: wireguard-monolithic-historical-cc3d7df096a88cdf96d016bdcb2f78fa03abb6f3.tar.xz
download: wireguard-monolithic-historical-cc3d7df096a88cdf96d016bdcb2f78fa03abb6f3.zip
curve25519: use kmalloc in order to not overflow stack (tag: experimental-0.0.20161110)
On MIPS, the IRQ and SoftIRQ handlers share the stack with whatever kernel thread was interrupted. This means that Curve25519 can be interrupted by, say, an ethernet controller, that then gets handled by a SoftIRQ. If something like l2tp is being used, which uses quite a bit of stack, then by the time the SoftIRQ handler gets to WireGuard code and calls into the stack-heavy ChaPoly functions, our 8k stack is shot. In other words, since Curve25519 is such a big consumer of stack, if it's interrupted by anything else that uses a healthy amount of stack, then disaster strikes. The solution here is just to allocate using kmalloc. This is quite ugly, and if performance becomes an issue, we might consider moving to a kmem_cache allocator, or even having each peer keep its own preallocated space. But for now, we'll try this.
Diffstat (limited to 'src')
-rw-r--r-- src/crypto/curve25519.c 349
1 files changed, 260 insertions, 89 deletions
diff --git a/src/crypto/curve25519.c b/src/crypto/curve25519.c
index b96b69c..afc2e8a 100644
--- a/src/crypto/curve25519.c
+++ b/src/crypto/curve25519.c
@@ -10,6 +10,12 @@
#include <linux/random.h>
#include <crypto/algapi.h>
+#define ARCH_HAS_SEPARATE_IRQ_STACK
+
+#if defined(CONFIG_MIPS) /* TODO: add other archs that are missing a separate IRQ stack. */
+#undef ARCH_HAS_SEPARATE_IRQ_STACK
+#endif
+
static __always_inline void normalize_secret(uint8_t secret[CURVE25519_POINT_SIZE])
{
secret[0] &= 248;
@@ -975,6 +981,96 @@ static void fcontract(uint8_t *output, limb *input_limbs)
#undef F
}
+/* Conditionally swap two reduced-form limb arrays if 'iswap' is 1, but leave
+ * them unchanged if 'iswap' is 0. Runs in data-invariant time to avoid
+ * side-channel attacks.
+ *
+ * NOTE that this function requires that 'iswap' be 1 or 0; other values give
+ * wrong results. Also, the two limb arrays must be in reduced-coefficient,
+ * reduced-degree form: the values in a[10..19] or b[10..19] aren't swapped,
+ * and all values in a[0..9],b[0..9] must have magnitude less than
+ * INT32_MAX. */
+static void swap_conditional(limb a[19], limb b[19], limb iswap)
+{
+ unsigned i;
+ const int32_t swap = (int32_t) -iswap;
+
+ for (i = 0; i < 10; ++i) {
+ const int32_t x = swap & ( ((int32_t)a[i]) ^ ((int32_t)b[i]) );
+ a[i] = ((int32_t)a[i]) ^ x;
+ b[i] = ((int32_t)b[i]) ^ x;
+ }
+}
+
+static void crecip(limb *out, const limb *z)
+{
+ limb z2[10];
+ limb z9[10];
+ limb z11[10];
+ limb z2_5_0[10];
+ limb z2_10_0[10];
+ limb z2_20_0[10];
+ limb z2_50_0[10];
+ limb z2_100_0[10];
+ limb t0[10];
+ limb t1[10];
+ int i;
+
+ /* 2 */ fsquare(z2,z);
+ /* 4 */ fsquare(t1,z2);
+ /* 8 */ fsquare(t0,t1);
+ /* 9 */ fmul(z9,t0,z);
+ /* 11 */ fmul(z11,z9,z2);
+ /* 22 */ fsquare(t0,z11);
+ /* 2^5 - 2^0 = 31 */ fmul(z2_5_0,t0,z9);
+
+ /* 2^6 - 2^1 */ fsquare(t0,z2_5_0);
+ /* 2^7 - 2^2 */ fsquare(t1,t0);
+ /* 2^8 - 2^3 */ fsquare(t0,t1);
+ /* 2^9 - 2^4 */ fsquare(t1,t0);
+ /* 2^10 - 2^5 */ fsquare(t0,t1);
+ /* 2^10 - 2^0 */ fmul(z2_10_0,t0,z2_5_0);
+
+ /* 2^11 - 2^1 */ fsquare(t0,z2_10_0);
+ /* 2^12 - 2^2 */ fsquare(t1,t0);
+ /* 2^20 - 2^10 */ for (i = 2; i < 10; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
+ /* 2^20 - 2^0 */ fmul(z2_20_0,t1,z2_10_0);
+
+ /* 2^21 - 2^1 */ fsquare(t0,z2_20_0);
+ /* 2^22 - 2^2 */ fsquare(t1,t0);
+ /* 2^40 - 2^20 */ for (i = 2; i < 20; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
+ /* 2^40 - 2^0 */ fmul(t0,t1,z2_20_0);
+
+ /* 2^41 - 2^1 */ fsquare(t1,t0);
+ /* 2^42 - 2^2 */ fsquare(t0,t1);
+ /* 2^50 - 2^10 */ for (i = 2; i < 10; i += 2) { fsquare(t1,t0); fsquare(t0,t1); }
+ /* 2^50 - 2^0 */ fmul(z2_50_0,t0,z2_10_0);
+
+ /* 2^51 - 2^1 */ fsquare(t0,z2_50_0);
+ /* 2^52 - 2^2 */ fsquare(t1,t0);
+ /* 2^100 - 2^50 */ for (i = 2; i < 50; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
+ /* 2^100 - 2^0 */ fmul(z2_100_0,t1,z2_50_0);
+
+ /* 2^101 - 2^1 */ fsquare(t1,z2_100_0);
+ /* 2^102 - 2^2 */ fsquare(t0,t1);
+ /* 2^200 - 2^100 */ for (i = 2; i < 100; i += 2) { fsquare(t1,t0); fsquare(t0,t1); }
+ /* 2^200 - 2^0 */ fmul(t1,t0,z2_100_0);
+
+ /* 2^201 - 2^1 */ fsquare(t0,t1);
+ /* 2^202 - 2^2 */ fsquare(t1,t0);
+ /* 2^250 - 2^50 */ for (i = 2; i < 50; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
+ /* 2^250 - 2^0 */ fmul(t0,t1,z2_50_0);
+
+ /* 2^251 - 2^1 */ fsquare(t1,t0);
+ /* 2^252 - 2^2 */ fsquare(t0,t1);
+ /* 2^253 - 2^3 */ fsquare(t1,t0);
+ /* 2^254 - 2^4 */ fsquare(t0,t1);
+ /* 2^255 - 2^5 */ fsquare(t1,t0);
+ /* 2^255 - 21 */ fmul(out,t1,z11);
+}
+
+
+#ifdef ARCH_HAS_SEPARATE_IRQ_STACK
/* Input: Q, Q', Q-Q'
* Output: 2Q, Q+Q'
*
@@ -1062,27 +1158,6 @@ static void fmonty(limb *x2, limb *z2, /* output 2Q */
/* |z2|i| < 2^26 */
}
-/* Conditionally swap two reduced-form limb arrays if 'iswap' is 1, but leave
- * them unchanged if 'iswap' is 0. Runs in data-invariant time to avoid
- * side-channel attacks.
- *
- * NOTE that this function requires that 'iswap' be 1 or 0; other values give
- * wrong results. Also, the two limb arrays must be in reduced-coefficient,
- * reduced-degree form: the values in a[10..19] or b[10..19] aren't swapped,
- * and all all values in a[0..9],b[0..9] must have magnitude less than
- * INT32_MAX. */
-static void swap_conditional(limb a[19], limb b[19], limb iswap)
-{
- unsigned i;
- const int32_t swap = (int32_t) -iswap;
-
- for (i = 0; i < 10; ++i) {
- const int32_t x = swap & ( ((int32_t)a[i]) ^ ((int32_t)b[i]) );
- a[i] = ((int32_t)a[i]) ^ x;
- b[i] = ((int32_t)b[i]) ^ x;
- }
-}
-
/* Calculates nQ where Q is the x-coordinate of a point on the curve
*
* resultx/resultz: the x coordinate of the resulting curve point (short form)
@@ -1135,73 +1210,6 @@ static void cmult(limb *resultx, limb *resultz, const uint8_t *n, const limb *q)
memcpy(resultz, nqz, sizeof(limb) * 10);
}
-static void crecip(limb *out, const limb *z)
-{
- limb z2[10];
- limb z9[10];
- limb z11[10];
- limb z2_5_0[10];
- limb z2_10_0[10];
- limb z2_20_0[10];
- limb z2_50_0[10];
- limb z2_100_0[10];
- limb t0[10];
- limb t1[10];
- int i;
-
- /* 2 */ fsquare(z2,z);
- /* 4 */ fsquare(t1,z2);
- /* 8 */ fsquare(t0,t1);
- /* 9 */ fmul(z9,t0,z);
- /* 11 */ fmul(z11,z9,z2);
- /* 22 */ fsquare(t0,z11);
- /* 2^5 - 2^0 = 31 */ fmul(z2_5_0,t0,z9);
-
- /* 2^6 - 2^1 */ fsquare(t0,z2_5_0);
- /* 2^7 - 2^2 */ fsquare(t1,t0);
- /* 2^8 - 2^3 */ fsquare(t0,t1);
- /* 2^9 - 2^4 */ fsquare(t1,t0);
- /* 2^10 - 2^5 */ fsquare(t0,t1);
- /* 2^10 - 2^0 */ fmul(z2_10_0,t0,z2_5_0);
-
- /* 2^11 - 2^1 */ fsquare(t0,z2_10_0);
- /* 2^12 - 2^2 */ fsquare(t1,t0);
- /* 2^20 - 2^10 */ for (i = 2; i < 10; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
- /* 2^20 - 2^0 */ fmul(z2_20_0,t1,z2_10_0);
-
- /* 2^21 - 2^1 */ fsquare(t0,z2_20_0);
- /* 2^22 - 2^2 */ fsquare(t1,t0);
- /* 2^40 - 2^20 */ for (i = 2; i < 20; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
- /* 2^40 - 2^0 */ fmul(t0,t1,z2_20_0);
-
- /* 2^41 - 2^1 */ fsquare(t1,t0);
- /* 2^42 - 2^2 */ fsquare(t0,t1);
- /* 2^50 - 2^10 */ for (i = 2; i < 10; i += 2) { fsquare(t1,t0); fsquare(t0,t1); }
- /* 2^50 - 2^0 */ fmul(z2_50_0,t0,z2_10_0);
-
- /* 2^51 - 2^1 */ fsquare(t0,z2_50_0);
- /* 2^52 - 2^2 */ fsquare(t1,t0);
- /* 2^100 - 2^50 */ for (i = 2; i < 50; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
- /* 2^100 - 2^0 */ fmul(z2_100_0,t1,z2_50_0);
-
- /* 2^101 - 2^1 */ fsquare(t1,z2_100_0);
- /* 2^102 - 2^2 */ fsquare(t0,t1);
- /* 2^200 - 2^100 */ for (i = 2; i < 100; i += 2) { fsquare(t1,t0); fsquare(t0,t1); }
- /* 2^200 - 2^0 */ fmul(t1,t0,z2_100_0);
-
- /* 2^201 - 2^1 */ fsquare(t0,t1);
- /* 2^202 - 2^2 */ fsquare(t1,t0);
- /* 2^250 - 2^50 */ for (i = 2; i < 50; i += 2) { fsquare(t0,t1); fsquare(t1,t0); }
- /* 2^250 - 2^0 */ fmul(t0,t1,z2_50_0);
-
- /* 2^251 - 2^1 */ fsquare(t1,t0);
- /* 2^252 - 2^2 */ fsquare(t0,t1);
- /* 2^253 - 2^3 */ fsquare(t1,t0);
- /* 2^254 - 2^4 */ fsquare(t0,t1);
- /* 2^255 - 2^5 */ fsquare(t1,t0);
- /* 2^255 - 21 */ fmul(out,t1,z11);
-}
-
void curve25519(uint8_t mypublic[CURVE25519_POINT_SIZE], const uint8_t secret[CURVE25519_POINT_SIZE], const uint8_t basepoint[CURVE25519_POINT_SIZE])
{
limb bp[10], x[10], z[11], zmone[10];
@@ -1222,8 +1230,171 @@ void curve25519(uint8_t mypublic[CURVE25519_POINT_SIZE], const uint8_t secret[CU
memzero_explicit(z, sizeof(z));
memzero_explicit(zmone, sizeof(zmone));
}
-#endif
+#else
+struct other_stack {
+ limb origx[10], origxprime[10], zzz[19], xx[19], zz[19], xxprime[19], zzprime[19], zzzprime[19], xxxprime[19];
+ limb a[19], b[19], c[19], d[19], e[19], f[19], g[19], h[19];
+ limb bp[10], x[10], z[11], zmone[10];
+ uint8_t ee[32];
+};
+
+/* Input: Q, Q', Q-Q'
+ * Output: 2Q, Q+Q'
+ *
+ * x2 z2: long form
+ * x3 z3: long form
+ * x z: short form, destroyed
+ * xprime zprime: short form, destroyed
+ * qmqp: short form, preserved
+ *
+ * On entry and exit, the absolute value of the limbs of all inputs and outputs
+ * are < 2^26. */
+static void fmonty(struct other_stack *s,
+ limb *x2, limb *z2, /* output 2Q */
+ limb *x3, limb *z3, /* output Q + Q' */
+ limb *x, limb *z, /* input Q */
+ limb *xprime, limb *zprime, /* input Q' */
+ const limb *qmqp /* input Q - Q' */)
+{
+ memcpy(s->origx, x, 10 * sizeof(limb));
+ fsum(x, z);
+ /* |x[i]| < 2^27 */
+ fdifference(z, s->origx); /* does x - z */
+ /* |z[i]| < 2^27 */
+ memcpy(s->origxprime, xprime, sizeof(limb) * 10);
+ fsum(xprime, zprime);
+ /* |xprime[i]| < 2^27 */
+ fdifference(zprime, s->origxprime);
+ /* |zprime[i]| < 2^27 */
+ fproduct(s->xxprime, xprime, z);
+ /* |s->xxprime[i]| < 14*2^54: the largest product of two limbs will be <
+ * 2^(27+27) and fproduct adds together, at most, 14 of those products.
+ * (Approximating that to 2^58 doesn't work out.) */
+ fproduct(s->zzprime, x, zprime);
+ /* |s->zzprime[i]| < 14*2^54 */
+ freduce_degree(s->xxprime);
+ freduce_coefficients(s->xxprime);
+ /* |s->xxprime[i]| < 2^26 */
+ freduce_degree(s->zzprime);
+ freduce_coefficients(s->zzprime);
+ /* |s->zzprime[i]| < 2^26 */
+ memcpy(s->origxprime, s->xxprime, sizeof(limb) * 10);
+ fsum(s->xxprime, s->zzprime);
+ /* |s->xxprime[i]| < 2^27 */
+ fdifference(s->zzprime, s->origxprime);
+ /* |s->zzprime[i]| < 2^27 */
+ fsquare(s->xxxprime, s->xxprime);
+ /* |s->xxxprime[i]| < 2^26 */
+ fsquare(s->zzzprime, s->zzprime);
+ /* |s->zzzprime[i]| < 2^26 */
+ fproduct(s->zzprime, s->zzzprime, qmqp);
+ /* |s->zzprime[i]| < 14*2^52 */
+ freduce_degree(s->zzprime);
+ freduce_coefficients(s->zzprime);
+ /* |s->zzprime[i]| < 2^26 */
+ memcpy(x3, s->xxxprime, sizeof(limb) * 10);
+ memcpy(z3, s->zzprime, sizeof(limb) * 10);
+
+ fsquare(s->xx, x);
+ /* |s->xx[i]| < 2^26 */
+ fsquare(s->zz, z);
+ /* |s->zz[i]| < 2^26 */
+ fproduct(x2, s->xx, s->zz);
+ /* |x2[i]| < 14*2^52 */
+ freduce_degree(x2);
+ freduce_coefficients(x2);
+ /* |x2[i]| < 2^26 */
+ fdifference(s->zz, s->xx); // does s->zz = s->xx - s->zz
+ /* |s->zz[i]| < 2^27 */
+ memset(s->zzz + 10, 0, sizeof(limb) * 9);
+ fscalar_product(s->zzz, s->zz, 121665);
+ /* |s->zzz[i]| < 2^(27+17) */
+ /* No need to call freduce_degree here:
+ fscalar_product doesn't increase the degree of its input. */
+ freduce_coefficients(s->zzz);
+ /* |s->zzz[i]| < 2^26 */
+ fsum(s->zzz, s->xx);
+ /* |s->zzz[i]| < 2^27 */
+ fproduct(z2, s->zz, s->zzz);
+ /* |z2[i]| < 14*2^(26+27) */
+ freduce_degree(z2);
+ freduce_coefficients(z2);
+ /* |z2[i]| < 2^26 */
+}
+
+/* Calculates nQ where Q is the x-coordinate of a point on the curve
+ *
+ * resultx/resultz: the x coordinate of the resulting curve point (short form)
+ * n: a little endian, 32-byte number
+ * q: a point of the curve (short form) */
+static void cmult(struct other_stack *s, limb *resultx, limb *resultz, const uint8_t *n, const limb *q)
+{
+ unsigned i, j;
+ limb *nqpqx = s->a, *nqpqz = s->b, *nqx = s->c, *nqz = s->d, *t;
+ limb *nqpqx2 = s->e, *nqpqz2 = s->f, *nqx2 = s->g, *nqz2 = s->h;
+
+ *nqpqz = *nqx = *nqpqz2 = *nqz2 = 1;
+ memcpy(nqpqx, q, sizeof(limb) * 10);
+
+ for (i = 0; i < 32; ++i) {
+ uint8_t byte = n[31 - i];
+ for (j = 0; j < 8; ++j) {
+ const limb bit = byte >> 7;
+
+ swap_conditional(nqx, nqpqx, bit);
+ swap_conditional(nqz, nqpqz, bit);
+ fmonty(s,
+ nqx2, nqz2,
+ nqpqx2, nqpqz2,
+ nqx, nqz,
+ nqpqx, nqpqz,
+ q);
+ swap_conditional(nqx2, nqpqx2, bit);
+ swap_conditional(nqz2, nqpqz2, bit);
+
+ t = nqx;
+ nqx = nqx2;
+ nqx2 = t;
+ t = nqz;
+ nqz = nqz2;
+ nqz2 = t;
+ t = nqpqx;
+ nqpqx = nqpqx2;
+ nqpqx2 = t;
+ t = nqpqz;
+ nqpqz = nqpqz2;
+ nqpqz2 = t;
+
+ byte <<= 1;
+ }
+ }
+
+ memcpy(resultx, nqx, sizeof(limb) * 10);
+ memcpy(resultz, nqz, sizeof(limb) * 10);
+}
+
+void curve25519(uint8_t mypublic[CURVE25519_POINT_SIZE], const uint8_t secret[CURVE25519_POINT_SIZE], const uint8_t basepoint[CURVE25519_POINT_SIZE])
+{
+ struct other_stack *s = kzalloc(sizeof(struct other_stack), GFP_KERNEL);
+ if (unlikely(!s)) {
+ memset(mypublic, 0, CURVE25519_POINT_SIZE);
+ return;
+ }
+
+ memcpy(s->ee, secret, 32);
+ normalize_secret(s->ee);
+
+ fexpand(s->bp, basepoint);
+ cmult(s, s->x, s->z, s->ee, s->bp);
+ crecip(s->zmone, s->z);
+ fmul(s->z, s->x, s->zmone);
+ fcontract(mypublic, s->z);
+
+ kzfree(s);
+}
+#endif
+#endif
void curve25519_generate_secret(uint8_t secret[CURVE25519_POINT_SIZE])
{