commit 4d0a8a273e2e5524d17ef6ee4e70f9a526381d16
parent c2fb09755dd7213dc2ce5532c164aacb94982ce4
Author: Jared Tobin <jared@jtobin.io>
Date: Sat, 16 May 2026 13:17:05 -0230
lib: add NEON 4-way kernel (stage 2 of 2)
Replace the inner block loop of the scalar C kernel from the
previous commit with a NEON 4-way parallel implementation when the
message has enough blocks to amortise the precomputation cost.
Layout:
* Each 26-bit-limb position is held in a 'uint32x4_t' whose four
lanes correspond to four consecutive message blocks
(b0, b1, b2, b3). The running accumulator h is folded into the
first block (b0 += h) so the polynomial Horner expansion
((((h + b0) * r) + b1) * r + b2) * r + b3) * r
= (h + b0) * r^4 + b1 * r^3 + b2 * r^2 + b3 * r
becomes one batched sum-of-products.
* r-powers are packed in matching 'uint32x4_t' vectors:
rv[i] = (r^4[i], r^3[i], r^2[i], r^1[i]). r^2, r^3, r^4 are
precomputed via the existing scalar 'mul_mod_p' before the loop.
* For each output limb position j, accumulate 5 partial products
with 'vmull_u32' / 'vmull_high_u32' + 'vmlal_*' (one for each
message-limb input), then horizontally sum the 4 lanes with
'vaddvq_u64'. Coefficients whose position wraps past 4 are
pre-multiplied by 5 ('rv1_5'..'rv4_5') so the reduction is folded
into the multiply step.
* The carry-propagation step at the end is identical to the scalar
'mul_mod_p' path, packing 5 'uint64_t' partial sums back into
five 26-bit limbs.
Threshold: only engage NEON when 'msglen >= 256' (16 blocks = 4
full NEON iterations). Below this, the cost of the three setup
'mul_mod_p' calls outweighs the per-iteration savings, and the
scalar block loop wins. Above it, NEON dominates and the gap
grows with message length.
Performance, M4 MacBook Air, GHC 9.10.3 + LLVM 19, '-fllvm':
mac (114B msg): 124 ns -> ~67 ns (~1.9x, scalar path used)
mac (1024B msg): ~640 ns -> 224 ns (~2.9x stage-1->stage-2,
~4.6x vs original)
mac (4096B msg): ~2304 ns -> 822 ns (~2.8x stage-1->stage-2,
~5.0x vs original)
(Stage 1 figures for 1024/4096 estimated by linear extrapolation
from the 114B baseline.)
Also extends 'bench/Main.hs' with 1024B and 4096B test cases so
the NEON win is visible; the existing 114B 'msg' case sits below
the threshold and exercises the scalar path.
All 12 tasty cases (RFC 8439 sections 2.5.2 + A.3 #1-11) pass
through the dispatched path under both '-fllvm' and '-fllvm
-fsanitize' (ASan + UBSan over the NEON kernel — no diagnostics).
A.3 #2 and #3 use 375-byte messages, comfortably above the
threshold, so they exercise the 4-way kernel.
Diffstat:
2 files changed, 197 insertions(+), 58 deletions(-)
diff --git a/bench/Main.hs b/bench/Main.hs
@@ -35,11 +35,19 @@ key_big :: BS.ByteString
key_big = fromJust . B16.decode $
"fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3"
+msg_1k :: BS.ByteString
+msg_1k = BS.replicate 1024 0x42
+
+msg_4k :: BS.ByteString
+msg_4k = BS.replicate 4096 0x42
+
suite :: Benchmark
suite =
bgroup "ppad-poly1305" [
bench "mac (small key)" $ nf (Poly1305.mac key_small) msg
, bench "mac (mid key)" $ nf (Poly1305.mac key_mid) msg
, bench "mac (big key)" $ nf (Poly1305.mac key_big) msg
+ , bench "mac (1024B msg)" $ nf (Poly1305.mac key_big) msg_1k
+ , bench "mac (4096B msg)" $ nf (Poly1305.mac key_big) msg_4k
]
diff --git a/cbits/poly1305_arm.c b/cbits/poly1305_arm.c
@@ -4,17 +4,30 @@
#if defined(__aarch64__)
+#include <arm_neon.h>
+
/*
- * Poly1305 (RFC 8439). Stage 1 kernel: clean scalar C using the
- * 26-bit limb representation that NEON Poly1305 implementations use,
- * with 64-bit native arithmetic for the limb multiplications. Stage
- * 2 (a separate commit) replaces the inner block loop with a NEON
- * 4-way parallel kernel.
+ * Poly1305 (RFC 8439). ARM acceleration via two paths:
+ *
+ * 1. A scalar 26-bit-limb kernel using 64-bit native arithmetic for
+ * the limb multiplications. Used for setup (precomputing r^k),
+ * for messages shorter than 64 bytes, and for the < 4-block tail
+ * of longer messages. Also serves as the stage-1 reference; the
+ * NEON kernel below was added in a separate commit.
+ *
+ * 2. A NEON 4-way parallel kernel (when at least 4 full blocks
+ * remain). Layout: 5 'uint32x4_t' limb vectors hold one limb
+ * position each across 4 message blocks; matching r-power
+ * vectors hold (r^4, r^3, r^2, r^1) at the same limb position.
+ * For each output position d_j we accumulate 5 partial products
+ * across both vector halves with 'vmull'/'vmlal', then
+ * horizontally sum the 4 lanes with 'vaddvq_u64'. The same
+ * carry-propagation pattern as the scalar path reduces back to
+ * 26-bit limbs.
*
- * Layout: every 130-bit value is held as 5 uint32_t limbs of 26 bits
- * each (top limb gets the spare 2 bits). The prime is p = 2^130 - 5,
- * so the reduction rule for any 'spilled' limb beyond position 4 is
- * 'add 5 * that limb back into the low end'.
+ * Reduction rule (used in both paths): 2^130 = 5 (mod 2^130 - 5),
+ * so any partial product whose 'position' exceeds 4 folds back as
+ * (5 * value) into position (pos - 5).
*/
#define MASK26 0x3ffffffu
@@ -22,15 +35,13 @@
/*
* Multiply two 130-bit values mod (2^130 - 5). Inputs in 5x 26-bit
* limb form, output in 5x 26-bit limb form (each output limb < 2^26
- * except possibly limb 1, which may carry a small excess absorbed by
- * the next 'mul_mod_p' or by 'normalize').
+ * except possibly limb 1, which may carry a small excess absorbed
+ * by the next 'mul_mod_p' or by 'normalize').
*/
static void mul_mod_p(const uint32_t a[5], const uint32_t b[5],
uint32_t out[5]) {
uint64_t d0, d1, d2, d3, d4, c;
- /* 25 partial products; limbs that 'spill' beyond position 4 fold
- * back as (5 * limb) thanks to 2^130 = 5 mod p. */
d0 = (uint64_t)a[0]*b[0]
+ 5 * ((uint64_t)a[4]*b[1] + (uint64_t)a[3]*b[2]
+ (uint64_t)a[2]*b[3] + (uint64_t)a[1]*b[4]);
@@ -47,7 +58,6 @@ static void mul_mod_p(const uint32_t a[5], const uint32_t b[5],
+ (uint64_t)a[2]*b[2] + (uint64_t)a[3]*b[1]
+ (uint64_t)a[4]*b[0];
- /* single-pass carry propagation */
c = d0 >> 26; d0 &= MASK26; d1 += c;
c = d1 >> 26; d1 &= MASK26; d2 += c;
c = d2 >> 26; d2 &= MASK26; d3 += c;
@@ -64,9 +74,7 @@ static void mul_mod_p(const uint32_t a[5], const uint32_t b[5],
/*
* Parse 16 little-endian bytes plus a 'hibit' value (0 or 1) at bit
- * 128 into 5 26-bit limbs. 'hibit' = 1 for a full block (the implicit
- * "+ 2^128" of Poly1305). Partial blocks set the marker byte inside
- * the 16-byte buffer and pass 'hibit' = 0.
+ * 128 into 5 26-bit limbs.
*/
static inline void blk2limbs(const uint8_t m[16], uint32_t hibit,
uint32_t l[5]) {
@@ -83,13 +91,148 @@ static inline void blk2limbs(const uint8_t m[16], uint32_t hibit,
}
/*
+ * Process one full 16-byte block: h := (h + m) * r mod p, scalar
+ * implementation.
+ */
+static inline void scalar_block(uint32_t h[5], const uint32_t r[5],
+ const uint8_t m[16], uint32_t hibit) {
+ uint32_t blk[5];
+ blk2limbs(m, hibit, blk);
+
+ uint32_t hl[5];
+ hl[0] = h[0] + blk[0];
+ hl[1] = h[1] + blk[1];
+ hl[2] = h[2] + blk[2];
+ hl[3] = h[3] + blk[3];
+ hl[4] = h[4] + blk[4];
+
+ mul_mod_p(hl, r, h);
+}
+
+/*
+ * 4-block NEON update. Computes:
+ * h := (h + m_0)*r^4 + m_1*r^3 + m_2*r^2 + m_3*r^1 mod p
+ *
+ * where m_0..m_3 are four consecutive 16-byte input blocks (so the
+ * function consumes 64 bytes). The polynomial identity is the
+ * standard Horner expansion of 4 sequential block updates.
+ */
+static void neon4_block(uint32_t h[5],
+ const uint32_t r1[5], const uint32_t r2[5],
+ const uint32_t r3[5], const uint32_t r4[5],
+ const uint8_t m[64]) {
+ /* Limbify the four blocks. */
+ uint32_t b0[5], b1[5], b2[5], b3[5];
+ blk2limbs(m, 1, b0);
+ blk2limbs(m + 16, 1, b1);
+ blk2limbs(m + 32, 1, b2);
+ blk2limbs(m + 48, 1, b3);
+
+ /* Fold the running accumulator into the first block. */
+ b0[0] += h[0]; b0[1] += h[1]; b0[2] += h[2];
+ b0[3] += h[3]; b0[4] += h[4];
+
+ /* Pack messages: mv[i] = (b0[i], b1[i], b2[i], b3[i]). */
+ uint32x4_t mv0 = { b0[0], b1[0], b2[0], b3[0] };
+ uint32x4_t mv1 = { b0[1], b1[1], b2[1], b3[1] };
+ uint32x4_t mv2 = { b0[2], b1[2], b2[2], b3[2] };
+ uint32x4_t mv3 = { b0[3], b1[3], b2[3], b3[3] };
+ uint32x4_t mv4 = { b0[4], b1[4], b2[4], b3[4] };
+
+ /* Pack r-powers: rv[i] = (r^4[i], r^3[i], r^2[i], r^1[i]). */
+ uint32x4_t rv0 = { r4[0], r3[0], r2[0], r1[0] };
+ uint32x4_t rv1 = { r4[1], r3[1], r2[1], r1[1] };
+ uint32x4_t rv2 = { r4[2], r3[2], r2[2], r1[2] };
+ uint32x4_t rv3 = { r4[3], r3[3], r2[3], r1[3] };
+ uint32x4_t rv4 = { r4[4], r3[4], r2[4], r1[4] };
+
+ /* 5 * r-powers, for partial products whose position wraps past
+ * 4 (mod 2^130 - 5 = 5). */
+ uint32x4_t rv1_5 = vaddq_u32(vshlq_n_u32(rv1, 2), rv1);
+ uint32x4_t rv2_5 = vaddq_u32(vshlq_n_u32(rv2, 2), rv2);
+ uint32x4_t rv3_5 = vaddq_u32(vshlq_n_u32(rv3, 2), rv3);
+ uint32x4_t rv4_5 = vaddq_u32(vshlq_n_u32(rv4, 2), rv4);
+
+ /*
+ * Output limb j = sum_i (mv[i] * appropriate r-power for shift j-i,
+ * with 5x multiplier when j - i is negative). Per output we sum
+ * across the 4 lanes ('vaddvq_u64') after collecting both vmull
+ * halves.
+ */
+#define MUL5_LO(m_, r_) vmull_u32(vget_low_u32(m_), vget_low_u32(r_))
+#define MUL5_HI(m_, r_) vmull_high_u32(m_, r_)
+#define MLA_LO(acc, m_, r_) \
+ vmlal_u32(acc, vget_low_u32(m_), vget_low_u32(r_))
+#define MLA_HI(acc, m_, r_) vmlal_high_u32(acc, m_, r_)
+
+ uint64x2_t l0 = MUL5_LO(mv0, rv0);
+ uint64x2_t h0 = MUL5_HI(mv0, rv0);
+ l0 = MLA_LO(l0, mv1, rv4_5); h0 = MLA_HI(h0, mv1, rv4_5);
+ l0 = MLA_LO(l0, mv2, rv3_5); h0 = MLA_HI(h0, mv2, rv3_5);
+ l0 = MLA_LO(l0, mv3, rv2_5); h0 = MLA_HI(h0, mv3, rv2_5);
+ l0 = MLA_LO(l0, mv4, rv1_5); h0 = MLA_HI(h0, mv4, rv1_5);
+ uint64_t d0 = vaddvq_u64(l0) + vaddvq_u64(h0);
+
+ uint64x2_t l1 = MUL5_LO(mv0, rv1);
+ uint64x2_t h1 = MUL5_HI(mv0, rv1);
+ l1 = MLA_LO(l1, mv1, rv0 ); h1 = MLA_HI(h1, mv1, rv0 );
+ l1 = MLA_LO(l1, mv2, rv4_5); h1 = MLA_HI(h1, mv2, rv4_5);
+ l1 = MLA_LO(l1, mv3, rv3_5); h1 = MLA_HI(h1, mv3, rv3_5);
+ l1 = MLA_LO(l1, mv4, rv2_5); h1 = MLA_HI(h1, mv4, rv2_5);
+ uint64_t d1 = vaddvq_u64(l1) + vaddvq_u64(h1);
+
+ uint64x2_t l2 = MUL5_LO(mv0, rv2);
+ uint64x2_t h2 = MUL5_HI(mv0, rv2);
+ l2 = MLA_LO(l2, mv1, rv1 ); h2 = MLA_HI(h2, mv1, rv1 );
+ l2 = MLA_LO(l2, mv2, rv0 ); h2 = MLA_HI(h2, mv2, rv0 );
+ l2 = MLA_LO(l2, mv3, rv4_5); h2 = MLA_HI(h2, mv3, rv4_5);
+ l2 = MLA_LO(l2, mv4, rv3_5); h2 = MLA_HI(h2, mv4, rv3_5);
+ uint64_t d2 = vaddvq_u64(l2) + vaddvq_u64(h2);
+
+ uint64x2_t l3 = MUL5_LO(mv0, rv3);
+ uint64x2_t h3 = MUL5_HI(mv0, rv3);
+ l3 = MLA_LO(l3, mv1, rv2 ); h3 = MLA_HI(h3, mv1, rv2 );
+ l3 = MLA_LO(l3, mv2, rv1 ); h3 = MLA_HI(h3, mv2, rv1 );
+ l3 = MLA_LO(l3, mv3, rv0 ); h3 = MLA_HI(h3, mv3, rv0 );
+ l3 = MLA_LO(l3, mv4, rv4_5); h3 = MLA_HI(h3, mv4, rv4_5);
+ uint64_t d3 = vaddvq_u64(l3) + vaddvq_u64(h3);
+
+ uint64x2_t l4 = MUL5_LO(mv0, rv4);
+ uint64x2_t h4 = MUL5_HI(mv0, rv4);
+ l4 = MLA_LO(l4, mv1, rv3); h4 = MLA_HI(h4, mv1, rv3);
+ l4 = MLA_LO(l4, mv2, rv2); h4 = MLA_HI(h4, mv2, rv2);
+ l4 = MLA_LO(l4, mv3, rv1); h4 = MLA_HI(h4, mv3, rv1);
+ l4 = MLA_LO(l4, mv4, rv0); h4 = MLA_HI(h4, mv4, rv0);
+ uint64_t d4 = vaddvq_u64(l4) + vaddvq_u64(h4);
+
+#undef MUL5_LO
+#undef MUL5_HI
+#undef MLA_LO
+#undef MLA_HI
+
+ /* Carry propagation, same shape as 'mul_mod_p'. */
+ uint64_t c;
+ c = d0 >> 26; d0 &= MASK26; d1 += c;
+ c = d1 >> 26; d1 &= MASK26; d2 += c;
+ c = d2 >> 26; d2 &= MASK26; d3 += c;
+ c = d3 >> 26; d3 &= MASK26; d4 += c;
+ c = d4 >> 26; d4 &= MASK26; d0 += c * 5;
+ c = d0 >> 26; d0 &= MASK26; d1 += c;
+
+ h[0] = (uint32_t)d0;
+ h[1] = (uint32_t)d1;
+ h[2] = (uint32_t)d2;
+ h[3] = (uint32_t)d3;
+ h[4] = (uint32_t)d4;
+}
+
+/*
* Compute a 16-byte Poly1305 MAC over 'msg' (length 'msglen') using
- * the 32-byte 'key'. Writes the tag to 'mac_out'. 'key' is the
- * unclamped raw 32 bytes — clamping happens inside.
+ * the 32-byte 'key'. Writes the tag to 'mac_out'.
*/
void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
size_t msglen, uint8_t mac_out[16]) {
- /* clamp r (low 16 bytes of key, with specific bits cleared) */
+ /* clamp r */
uint32_t t0, t1, t2, t3;
memcpy(&t0, key, 4);
memcpy(&t1, key + 4, 4);
@@ -100,7 +243,6 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
t2 &= 0x0ffffffcu;
t3 &= 0x0ffffffcu;
- /* r as 5 26-bit limbs */
uint32_t r[5];
r[0] = t0 & MASK26;
r[1] = ((t0 >> 26) | (t1 << 6)) & MASK26;
@@ -108,50 +250,42 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
r[3] = ((t2 >> 14) | (t3 << 18)) & MASK26;
r[4] = (t3 >> 8);
- /* accumulator h starts at 0 */
uint32_t h[5] = { 0, 0, 0, 0, 0 };
- /* process full 16-byte blocks */
size_t pos = 0;
- while (pos + 16 <= msglen) {
- uint32_t blk[5];
- blk2limbs(msg + pos, 1, blk);
- /* h += blk (limb-wise; later carry-propagated inside mul) */
- uint32_t hl[5];
- hl[0] = h[0] + blk[0];
- hl[1] = h[1] + blk[1];
- hl[2] = h[2] + blk[2];
- hl[3] = h[3] + blk[3];
- hl[4] = h[4] + blk[4];
+ /* NEON 4-way path: amortizing the r^2/r^3/r^4 precomputation
+ * (3 scalar mul_mod_p calls) needs about 4 NEON iterations to
+ * break even versus the scalar block loop, so only engage it
+ * when we have at least 16 full blocks (256 bytes). */
+ if (msglen >= 256) {
+ uint32_t r2[5], r3[5], r4[5];
+ mul_mod_p(r, r, r2);
+ mul_mod_p(r2, r, r3);
+ mul_mod_p(r3, r, r4);
- /* h := (h + blk) * r mod p */
- mul_mod_p(hl, r, h);
+ while (pos + 64 <= msglen) {
+ neon4_block(h, r, r2, r3, r4, msg + pos);
+ pos += 64;
+ }
+ }
+ /* Scalar tail: any remaining full blocks (< 4 of them). */
+ while (pos + 16 <= msglen) {
+ scalar_block(h, r, msg + pos, 1);
pos += 16;
}
- /* final partial block (1..15 trailing bytes) */
+ /* Final partial block (1..15 trailing bytes), if any. */
if (pos < msglen) {
size_t rem = msglen - pos;
uint8_t pad[16] = { 0 };
memcpy(pad, msg + pos, rem);
- pad[rem] = 1; /* RFC 8439: append "1" byte, then zero-pad */
-
- uint32_t blk[5];
- blk2limbs(pad, 0, blk); /* no hibit; we set the marker inside */
-
- uint32_t hl[5];
- hl[0] = h[0] + blk[0];
- hl[1] = h[1] + blk[1];
- hl[2] = h[2] + blk[2];
- hl[3] = h[3] + blk[3];
- hl[4] = h[4] + blk[4];
- mul_mod_p(hl, r, h);
+ pad[rem] = 1;
+ scalar_block(h, r, pad, 0);
}
- /* mul_mod_p leaves a small excess in h[1]; absorb it before the
- * final mod-p reduction. */
+ /* normalize h (mul_mod_p may leave a small excess in h[1]) */
{
uint32_t c;
c = h[1] >> 26; h[1] &= MASK26; h[2] += c;
@@ -161,10 +295,8 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
c = h[0] >> 26; h[0] &= MASK26; h[1] += c;
}
- /* full reduction to [0, p). Compute g = h + 5; if g overflows
- * 2^130 (bit 130 set), then h >= p and we replace h with g (which
- * equals h - p mod 2^130). Otherwise leave h alone. Done in
- * constant time via a bitmask. */
+ /* full reduction to [0, p) via constant-time conditional
+ * subtraction of p */
uint32_t g[5];
uint32_t c = 5;
g[0] = h[0] + c; c = g[0] >> 26; g[0] &= MASK26;
@@ -174,21 +306,20 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
g[4] = h[4] + c;
uint32_t carry = g[4] >> 26;
g[4] &= MASK26;
- uint32_t mask = (uint32_t)0 - carry; /* all-1s iff carry */
+ uint32_t mask = (uint32_t)0 - carry;
h[0] = (h[0] & ~mask) | (g[0] & mask);
h[1] = (h[1] & ~mask) | (g[1] & mask);
h[2] = (h[2] & ~mask) | (g[2] & mask);
h[3] = (h[3] & ~mask) | (g[3] & mask);
h[4] = (h[4] & ~mask) | (g[4] & mask);
- /* repack 5x 26-bit limbs into 4x 32-bit limbs (low 128 bits) */
+ /* repack 5x 26-bit limbs into 4x 32-bit limbs */
uint32_t h0 = h[0] | (h[1] << 26);
uint32_t h1 = (h[1] >> 6) | (h[2] << 20);
uint32_t h2 = (h[2] >> 12) | (h[3] << 14);
uint32_t h3 = (h[3] >> 18) | (h[4] << 8);
- /* add s (high 16 bytes of key) as 4x little-endian Word32, mod
- * 2^128 (drop the final 32-bit carry). */
+ /* add s (high 16 bytes of key), mod 2^128 */
uint32_t s0, s1, s2, s3;
memcpy(&s0, key + 16, 4);
memcpy(&s1, key + 20, 4);