poly1305

The Poly1305 message authentication code (docs.ppad.tech/poly1305).
git clone git://git.ppad.tech/poly1305.git
Log | Files | Refs | README | LICENSE

commit 4d0a8a273e2e5524d17ef6ee4e70f9a526381d16
parent c2fb09755dd7213dc2ce5532c164aacb94982ce4
Author: Jared Tobin <jared@jtobin.io>
Date:   Sat, 16 May 2026 13:17:05 -0230

lib: add NEON 4-way kernel (stage 2 of 2)

Replace the inner block loop of the scalar C kernel from the
previous commit with a NEON 4-way parallel implementation when the
message has enough blocks to amortise the precomputation cost.

Layout:

* Each 26-bit-limb position is held in a 'uint32x4_t' whose four
  lanes correspond to four consecutive message blocks
  (b0, b1, b2, b3).  The running accumulator h is folded into the
  first block (b0 += h) so the polynomial Horner expansion

      ((((h + b0) * r) + b1) * r + b2) * r + b3) * r
        = (h + b0) * r^4 + b1 * r^3 + b2 * r^2 + b3 * r

  becomes one batched sum-of-products.

* r-powers are packed in matching 'uint32x4_t' vectors:
  rv[i] = (r^4[i], r^3[i], r^2[i], r^1[i]).  r^2, r^3, r^4 are
  precomputed via the existing scalar 'mul_mod_p' before the loop.

* For each output limb position j, accumulate 5 partial products
  with 'vmull_u32' / 'vmull_high_u32' + 'vmlal_*' (one for each
  message-limb input), then horizontally sum the 4 lanes with
  'vaddvq_u64'.  Coefficients whose position wraps past 4 are
  pre-multiplied by 5 ('rv1_5'..'rv4_5') so the reduction is folded
  into the multiply step.

* The carry-propagation step at the end is identical to the scalar
  'mul_mod_p' path, packing 5 'uint64_t' partial sums back into
  five 26-bit limbs.

Threshold: only engage NEON when 'msglen >= 256' (16 blocks = 4
full NEON iterations).  Below this, the cost of the three setup
'mul_mod_p' calls outweighs the per-iteration savings, and the
scalar block loop wins.  Above it, NEON dominates and the gap
grows with message length.

Performance, M4 MacBook Air, GHC 9.10.3 + LLVM 19, '-fllvm':

  mac (114B  msg):   124 ns -> ~67 ns   (~1.9x, scalar path used)
  mac (1024B msg):  ~640 ns -> 224 ns   (~2.9x stage-1->stage-2,
                                         ~4.6x vs original)
  mac (4096B msg): ~2304 ns -> 822 ns   (~2.8x stage-1->stage-2,
                                         ~5.0x vs original)

(Stage 1 figures for 1024/4096 estimated by linear extrapolation
from the 114B baseline.)

Also extends 'bench/Main.hs' with 1024B and 4096B test cases so
the NEON win is visible; the existing 114B 'msg' case sits below
the threshold and exercises the scalar path.

All 12 tasty cases (RFC 8439 sections 2.5.2 + A.3 #1-11) pass
through the dispatched path under both '-fllvm' and '-fllvm
-fsanitize' (ASan + UBSan over the NEON kernel — no diagnostics).
A.3 #2 and #3 use 375-byte messages, comfortably above the
threshold, so they exercise the 4-way kernel.

Diffstat:
Mbench/Main.hs | 8++++++++
Mcbits/poly1305_arm.c | 247++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
2 files changed, 197 insertions(+), 58 deletions(-)

diff --git a/bench/Main.hs b/bench/Main.hs @@ -35,11 +35,19 @@ key_big :: BS.ByteString key_big = fromJust . B16.decode $ "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff3" +msg_1k :: BS.ByteString +msg_1k = BS.replicate 1024 0x42 + +msg_4k :: BS.ByteString +msg_4k = BS.replicate 4096 0x42 + suite :: Benchmark suite = bgroup "ppad-poly1305" [ bench "mac (small key)" $ nf (Poly1305.mac key_small) msg , bench "mac (mid key)" $ nf (Poly1305.mac key_mid) msg , bench "mac (big key)" $ nf (Poly1305.mac key_big) msg + , bench "mac (1024B msg)" $ nf (Poly1305.mac key_big) msg_1k + , bench "mac (4096B msg)" $ nf (Poly1305.mac key_big) msg_4k ] diff --git a/cbits/poly1305_arm.c b/cbits/poly1305_arm.c @@ -4,17 +4,30 @@ #if defined(__aarch64__) +#include <arm_neon.h> + /* - * Poly1305 (RFC 8439). Stage 1 kernel: clean scalar C using the - * 26-bit limb representation that NEON Poly1305 implementations use, - * with 64-bit native arithmetic for the limb multiplications. Stage - * 2 (a separate commit) replaces the inner block loop with a NEON - * 4-way parallel kernel. + * Poly1305 (RFC 8439). ARM acceleration via two paths: + * + * 1. A scalar 26-bit-limb kernel using 64-bit native arithmetic for + * the limb multiplications. Used for setup (precomputing r^k), + * for messages shorter than 64 bytes, and for the < 4-block tail + * of longer messages. Also serves as the stage-1 reference; the + * NEON kernel below was added in a separate commit. + * + * 2. A NEON 4-way parallel kernel (when at least 4 full blocks + * remain). Layout: 5 'uint32x4_t' limb vectors hold one limb + * position each across 4 message blocks; matching r-power + * vectors hold (r^4, r^3, r^2, r^1) at the same limb position. + * For each output position d_j we accumulate 5 partial products + * across both vector halves with 'vmull'/'vmlal', then + * horizontally sum the 4 lanes with 'vaddvq_u64'. The same + * carry-propagation pattern as the scalar path reduces back to + * 26-bit limbs. * - * Layout: every 130-bit value is held as 5 uint32_t limbs of 26 bits - * each (top limb gets the spare 2 bits). The prime is p = 2^130 - 5, - * so the reduction rule for any 'spilled' limb beyond position 4 is - * 'add 5 * that limb back into the low end'. + * Reduction rule (used in both paths): 2^130 = 5 (mod 2^130 - 5), + * so any partial product whose 'position' exceeds 4 folds back as + * (5 * value) into position (pos - 5). */ #define MASK26 0x3ffffffu @@ -22,15 +35,13 @@ /* * Multiply two 130-bit values mod (2^130 - 5). Inputs in 5x 26-bit * limb form, output in 5x 26-bit limb form (each output limb < 2^26 - * except possibly limb 1, which may carry a small excess absorbed by - * the next 'mul_mod_p' or by 'normalize'). + * except possibly limb 1, which may carry a small excess absorbed + * by the next 'mul_mod_p' or by 'normalize'). */ static void mul_mod_p(const uint32_t a[5], const uint32_t b[5], uint32_t out[5]) { uint64_t d0, d1, d2, d3, d4, c; - /* 25 partial products; limbs that 'spill' beyond position 4 fold - * back as (5 * limb) thanks to 2^130 = 5 mod p. */ d0 = (uint64_t)a[0]*b[0] + 5 * ((uint64_t)a[4]*b[1] + (uint64_t)a[3]*b[2] + (uint64_t)a[2]*b[3] + (uint64_t)a[1]*b[4]); @@ -47,7 +58,6 @@ static void mul_mod_p(const uint32_t a[5], const uint32_t b[5], + (uint64_t)a[2]*b[2] + (uint64_t)a[3]*b[1] + (uint64_t)a[4]*b[0]; - /* single-pass carry propagation */ c = d0 >> 26; d0 &= MASK26; d1 += c; c = d1 >> 26; d1 &= MASK26; d2 += c; c = d2 >> 26; d2 &= MASK26; d3 += c; @@ -64,9 +74,7 @@ static void mul_mod_p(const uint32_t a[5], const uint32_t b[5], /* * Parse 16 little-endian bytes plus a 'hibit' value (0 or 1) at bit - * 128 into 5 26-bit limbs. 'hibit' = 1 for a full block (the implicit - * "+ 2^128" of Poly1305). Partial blocks set the marker byte inside - * the 16-byte buffer and pass 'hibit' = 0. + * 128 into 5 26-bit limbs. */ static inline void blk2limbs(const uint8_t m[16], uint32_t hibit, uint32_t l[5]) { @@ -83,13 +91,148 @@ static inline void blk2limbs(const uint8_t m[16], uint32_t hibit, } /* + * Process one full 16-byte block: h := (h + m) * r mod p, scalar + * implementation. + */ +static inline void scalar_block(uint32_t h[5], const uint32_t r[5], + const uint8_t m[16], uint32_t hibit) { + uint32_t blk[5]; + blk2limbs(m, hibit, blk); + + uint32_t hl[5]; + hl[0] = h[0] + blk[0]; + hl[1] = h[1] + blk[1]; + hl[2] = h[2] + blk[2]; + hl[3] = h[3] + blk[3]; + hl[4] = h[4] + blk[4]; + + mul_mod_p(hl, r, h); +} + +/* + * 4-block NEON update. Computes: + * h := (h + m_0)*r^4 + m_1*r^3 + m_2*r^2 + m_3*r^1 mod p + * + * where m_0..m_3 are four consecutive 16-byte input blocks (so the + * function consumes 64 bytes). The polynomial identity is the + * standard Horner expansion of 4 sequential block updates. + */ +static void neon4_block(uint32_t h[5], + const uint32_t r1[5], const uint32_t r2[5], + const uint32_t r3[5], const uint32_t r4[5], + const uint8_t m[64]) { + /* Limbify the four blocks. */ + uint32_t b0[5], b1[5], b2[5], b3[5]; + blk2limbs(m, 1, b0); + blk2limbs(m + 16, 1, b1); + blk2limbs(m + 32, 1, b2); + blk2limbs(m + 48, 1, b3); + + /* Fold the running accumulator into the first block. */ + b0[0] += h[0]; b0[1] += h[1]; b0[2] += h[2]; + b0[3] += h[3]; b0[4] += h[4]; + + /* Pack messages: mv[i] = (b0[i], b1[i], b2[i], b3[i]). */ + uint32x4_t mv0 = { b0[0], b1[0], b2[0], b3[0] }; + uint32x4_t mv1 = { b0[1], b1[1], b2[1], b3[1] }; + uint32x4_t mv2 = { b0[2], b1[2], b2[2], b3[2] }; + uint32x4_t mv3 = { b0[3], b1[3], b2[3], b3[3] }; + uint32x4_t mv4 = { b0[4], b1[4], b2[4], b3[4] }; + + /* Pack r-powers: rv[i] = (r^4[i], r^3[i], r^2[i], r^1[i]). */ + uint32x4_t rv0 = { r4[0], r3[0], r2[0], r1[0] }; + uint32x4_t rv1 = { r4[1], r3[1], r2[1], r1[1] }; + uint32x4_t rv2 = { r4[2], r3[2], r2[2], r1[2] }; + uint32x4_t rv3 = { r4[3], r3[3], r2[3], r1[3] }; + uint32x4_t rv4 = { r4[4], r3[4], r2[4], r1[4] }; + + /* 5 * r-powers, for partial products whose position wraps past + * 4 (mod 2^130 - 5 = 5). */ + uint32x4_t rv1_5 = vaddq_u32(vshlq_n_u32(rv1, 2), rv1); + uint32x4_t rv2_5 = vaddq_u32(vshlq_n_u32(rv2, 2), rv2); + uint32x4_t rv3_5 = vaddq_u32(vshlq_n_u32(rv3, 2), rv3); + uint32x4_t rv4_5 = vaddq_u32(vshlq_n_u32(rv4, 2), rv4); + + /* + * Output limb j = sum_i (mv[i] * appropriate r-power for shift j-i, + * with 5x multiplier when j - i is negative). Per output we sum + * across the 4 lanes ('vaddvq_u64') after collecting both vmull + * halves. + */ +#define MUL5_LO(m_, r_) vmull_u32(vget_low_u32(m_), vget_low_u32(r_)) +#define MUL5_HI(m_, r_) vmull_high_u32(m_, r_) +#define MLA_LO(acc, m_, r_) \ + vmlal_u32(acc, vget_low_u32(m_), vget_low_u32(r_)) +#define MLA_HI(acc, m_, r_) vmlal_high_u32(acc, m_, r_) + + uint64x2_t l0 = MUL5_LO(mv0, rv0); + uint64x2_t h0 = MUL5_HI(mv0, rv0); + l0 = MLA_LO(l0, mv1, rv4_5); h0 = MLA_HI(h0, mv1, rv4_5); + l0 = MLA_LO(l0, mv2, rv3_5); h0 = MLA_HI(h0, mv2, rv3_5); + l0 = MLA_LO(l0, mv3, rv2_5); h0 = MLA_HI(h0, mv3, rv2_5); + l0 = MLA_LO(l0, mv4, rv1_5); h0 = MLA_HI(h0, mv4, rv1_5); + uint64_t d0 = vaddvq_u64(l0) + vaddvq_u64(h0); + + uint64x2_t l1 = MUL5_LO(mv0, rv1); + uint64x2_t h1 = MUL5_HI(mv0, rv1); + l1 = MLA_LO(l1, mv1, rv0 ); h1 = MLA_HI(h1, mv1, rv0 ); + l1 = MLA_LO(l1, mv2, rv4_5); h1 = MLA_HI(h1, mv2, rv4_5); + l1 = MLA_LO(l1, mv3, rv3_5); h1 = MLA_HI(h1, mv3, rv3_5); + l1 = MLA_LO(l1, mv4, rv2_5); h1 = MLA_HI(h1, mv4, rv2_5); + uint64_t d1 = vaddvq_u64(l1) + vaddvq_u64(h1); + + uint64x2_t l2 = MUL5_LO(mv0, rv2); + uint64x2_t h2 = MUL5_HI(mv0, rv2); + l2 = MLA_LO(l2, mv1, rv1 ); h2 = MLA_HI(h2, mv1, rv1 ); + l2 = MLA_LO(l2, mv2, rv0 ); h2 = MLA_HI(h2, mv2, rv0 ); + l2 = MLA_LO(l2, mv3, rv4_5); h2 = MLA_HI(h2, mv3, rv4_5); + l2 = MLA_LO(l2, mv4, rv3_5); h2 = MLA_HI(h2, mv4, rv3_5); + uint64_t d2 = vaddvq_u64(l2) + vaddvq_u64(h2); + + uint64x2_t l3 = MUL5_LO(mv0, rv3); + uint64x2_t h3 = MUL5_HI(mv0, rv3); + l3 = MLA_LO(l3, mv1, rv2 ); h3 = MLA_HI(h3, mv1, rv2 ); + l3 = MLA_LO(l3, mv2, rv1 ); h3 = MLA_HI(h3, mv2, rv1 ); + l3 = MLA_LO(l3, mv3, rv0 ); h3 = MLA_HI(h3, mv3, rv0 ); + l3 = MLA_LO(l3, mv4, rv4_5); h3 = MLA_HI(h3, mv4, rv4_5); + uint64_t d3 = vaddvq_u64(l3) + vaddvq_u64(h3); + + uint64x2_t l4 = MUL5_LO(mv0, rv4); + uint64x2_t h4 = MUL5_HI(mv0, rv4); + l4 = MLA_LO(l4, mv1, rv3); h4 = MLA_HI(h4, mv1, rv3); + l4 = MLA_LO(l4, mv2, rv2); h4 = MLA_HI(h4, mv2, rv2); + l4 = MLA_LO(l4, mv3, rv1); h4 = MLA_HI(h4, mv3, rv1); + l4 = MLA_LO(l4, mv4, rv0); h4 = MLA_HI(h4, mv4, rv0); + uint64_t d4 = vaddvq_u64(l4) + vaddvq_u64(h4); + +#undef MUL5_LO +#undef MUL5_HI +#undef MLA_LO +#undef MLA_HI + + /* Carry propagation, same shape as 'mul_mod_p'. */ + uint64_t c; + c = d0 >> 26; d0 &= MASK26; d1 += c; + c = d1 >> 26; d1 &= MASK26; d2 += c; + c = d2 >> 26; d2 &= MASK26; d3 += c; + c = d3 >> 26; d3 &= MASK26; d4 += c; + c = d4 >> 26; d4 &= MASK26; d0 += c * 5; + c = d0 >> 26; d0 &= MASK26; d1 += c; + + h[0] = (uint32_t)d0; + h[1] = (uint32_t)d1; + h[2] = (uint32_t)d2; + h[3] = (uint32_t)d3; + h[4] = (uint32_t)d4; +} + +/* * Compute a 16-byte Poly1305 MAC over 'msg' (length 'msglen') using - * the 32-byte 'key'. Writes the tag to 'mac_out'. 'key' is the - * unclamped raw 32 bytes — clamping happens inside. + * the 32-byte 'key'. Writes the tag to 'mac_out'. */ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg, size_t msglen, uint8_t mac_out[16]) { - /* clamp r (low 16 bytes of key, with specific bits cleared) */ + /* clamp r */ uint32_t t0, t1, t2, t3; memcpy(&t0, key, 4); memcpy(&t1, key + 4, 4); @@ -100,7 +243,6 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg, t2 &= 0x0ffffffcu; t3 &= 0x0ffffffcu; - /* r as 5 26-bit limbs */ uint32_t r[5]; r[0] = t0 & MASK26; r[1] = ((t0 >> 26) | (t1 << 6)) & MASK26; @@ -108,50 +250,42 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg, r[3] = ((t2 >> 14) | (t3 << 18)) & MASK26; r[4] = (t3 >> 8); - /* accumulator h starts at 0 */ uint32_t h[5] = { 0, 0, 0, 0, 0 }; - /* process full 16-byte blocks */ size_t pos = 0; - while (pos + 16 <= msglen) { - uint32_t blk[5]; - blk2limbs(msg + pos, 1, blk); - /* h += blk (limb-wise; later carry-propagated inside mul) */ - uint32_t hl[5]; - hl[0] = h[0] + blk[0]; - hl[1] = h[1] + blk[1]; - hl[2] = h[2] + blk[2]; - hl[3] = h[3] + blk[3]; - hl[4] = h[4] + blk[4]; + /* NEON 4-way path: amortizing the r^2/r^3/r^4 precomputation + * (3 scalar mul_mod_p calls) needs about 4 NEON iterations to + * break even versus the scalar block loop, so only engage it + * when we have at least 16 full blocks (256 bytes). */ + if (msglen >= 256) { + uint32_t r2[5], r3[5], r4[5]; + mul_mod_p(r, r, r2); + mul_mod_p(r2, r, r3); + mul_mod_p(r3, r, r4); - /* h := (h + blk) * r mod p */ - mul_mod_p(hl, r, h); + while (pos + 64 <= msglen) { + neon4_block(h, r, r2, r3, r4, msg + pos); + pos += 64; + } + } + /* Scalar tail: any remaining full blocks (< 4 of them). */ + while (pos + 16 <= msglen) { + scalar_block(h, r, msg + pos, 1); pos += 16; } - /* final partial block (1..15 trailing bytes) */ + /* Final partial block (1..15 trailing bytes), if any. */ if (pos < msglen) { size_t rem = msglen - pos; uint8_t pad[16] = { 0 }; memcpy(pad, msg + pos, rem); - pad[rem] = 1; /* RFC 8439: append "1" byte, then zero-pad */ - - uint32_t blk[5]; - blk2limbs(pad, 0, blk); /* no hibit; we set the marker inside */ - - uint32_t hl[5]; - hl[0] = h[0] + blk[0]; - hl[1] = h[1] + blk[1]; - hl[2] = h[2] + blk[2]; - hl[3] = h[3] + blk[3]; - hl[4] = h[4] + blk[4]; - mul_mod_p(hl, r, h); + pad[rem] = 1; + scalar_block(h, r, pad, 0); } - /* mul_mod_p leaves a small excess in h[1]; absorb it before the - * final mod-p reduction. */ + /* normalize h (mul_mod_p may leave a small excess in h[1]) */ { uint32_t c; c = h[1] >> 26; h[1] &= MASK26; h[2] += c; @@ -161,10 +295,8 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg, c = h[0] >> 26; h[0] &= MASK26; h[1] += c; } - /* full reduction to [0, p). Compute g = h + 5; if g overflows - * 2^130 (bit 130 set), then h >= p and we replace h with g (which - * equals h - p mod 2^130). Otherwise leave h alone. Done in - * constant time via a bitmask. */ + /* full reduction to [0, p) via constant-time conditional + * subtraction of p */ uint32_t g[5]; uint32_t c = 5; g[0] = h[0] + c; c = g[0] >> 26; g[0] &= MASK26; @@ -174,21 +306,20 @@ void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg, g[4] = h[4] + c; uint32_t carry = g[4] >> 26; g[4] &= MASK26; - uint32_t mask = (uint32_t)0 - carry; /* all-1s iff carry */ + uint32_t mask = (uint32_t)0 - carry; h[0] = (h[0] & ~mask) | (g[0] & mask); h[1] = (h[1] & ~mask) | (g[1] & mask); h[2] = (h[2] & ~mask) | (g[2] & mask); h[3] = (h[3] & ~mask) | (g[3] & mask); h[4] = (h[4] & ~mask) | (g[4] & mask); - /* repack 5x 26-bit limbs into 4x 32-bit limbs (low 128 bits) */ + /* repack 5x 26-bit limbs into 4x 32-bit limbs */ uint32_t h0 = h[0] | (h[1] << 26); uint32_t h1 = (h[1] >> 6) | (h[2] << 20); uint32_t h2 = (h[2] >> 12) | (h[3] << 14); uint32_t h3 = (h[3] >> 18) | (h[4] << 8); - /* add s (high 16 bytes of key) as 4x little-endian Word32, mod - * 2^128 (drop the final 32-bit carry). */ + /* add s (high 16 bytes of key), mod 2^128 */ uint32_t s0, s1, s2, s3; memcpy(&s0, key + 16, 4); memcpy(&s1, key + 20, 4);