poly1305

The Poly1305 message authentication code (docs.ppad.tech/poly1305).
git clone git://git.ppad.tech/poly1305.git
Log | Files | Refs | README | LICENSE

poly1305_arm.c (12697B)


      1 #include <stddef.h>
      2 #include <stdint.h>
      3 #include <string.h>
      4 
      5 #if defined(__aarch64__)
      6 
      7 #include <arm_neon.h>
      8 
      9 /*
     10  * Poly1305 (RFC 8439).  ARM acceleration via two paths:
     11  *
     12  *  1. A scalar 26-bit-limb kernel using 64-bit native arithmetic for
     13  *     the limb multiplications.  Used for setup (precomputing r^k),
     14  *     for messages shorter than 256 bytes (the NEON engagement
     15  *     threshold), and for the < 4-block tail of longer messages.
     16  *     The NEON kernel below was added in a separate commit.
     17  *
     18  *  2. A NEON 4-way parallel kernel (when at least 4 full blocks
     19  *     remain).  Layout: 5 'uint32x4_t' limb vectors hold one limb
     20  *     position each across 4 message blocks; matching r-power
     21  *     vectors hold (r^4, r^3, r^2, r^1) at the same limb position.
     22  *     For each output position d_j we accumulate 5 partial products
     23  *     across both vector halves with 'vmull'/'vmlal', then
     24  *     horizontally sum the 4 lanes with 'vaddvq_u64'.  The same
     25  *     carry-propagation pattern as the scalar path reduces back to
     26  *     26-bit limbs.
     27  *
     28  * Reduction rule (used in both paths): 2^130 = 5 (mod 2^130 - 5),
     29  * so any partial product whose 'position' exceeds 4 folds back as
     30  * (5 * value) into position (pos - 5).
     31  */
     32 
     33 #define MASK26 0x3ffffffu
     34 
     35 /*
     36  * Multiply two 130-bit values mod (2^130 - 5).  Inputs in 5x 26-bit
     37  * limb form, output in 5x 26-bit limb form (each output limb < 2^26
     38  * except possibly limb 1, which may carry a small excess absorbed
     39  * by the next 'mul_mod_p' or by 'normalize').
     40  */
     41 static void mul_mod_p(const uint32_t a[5], const uint32_t b[5],
     42                       uint32_t out[5]) {
     43     uint64_t d0, d1, d2, d3, d4, c;
     44 
     45     d0 = (uint64_t)a[0]*b[0]
     46        + 5 * ((uint64_t)a[4]*b[1] + (uint64_t)a[3]*b[2]
     47             + (uint64_t)a[2]*b[3] + (uint64_t)a[1]*b[4]);
     48     d1 = (uint64_t)a[0]*b[1] + (uint64_t)a[1]*b[0]
     49        + 5 * ((uint64_t)a[4]*b[2] + (uint64_t)a[3]*b[3]
     50             + (uint64_t)a[2]*b[4]);
     51     d2 = (uint64_t)a[0]*b[2] + (uint64_t)a[1]*b[1]
     52        + (uint64_t)a[2]*b[0]
     53        + 5 * ((uint64_t)a[4]*b[3] + (uint64_t)a[3]*b[4]);
     54     d3 = (uint64_t)a[0]*b[3] + (uint64_t)a[1]*b[2]
     55        + (uint64_t)a[2]*b[1] + (uint64_t)a[3]*b[0]
     56        + 5 * ((uint64_t)a[4]*b[4]);
     57     d4 = (uint64_t)a[0]*b[4] + (uint64_t)a[1]*b[3]
     58        + (uint64_t)a[2]*b[2] + (uint64_t)a[3]*b[1]
     59        + (uint64_t)a[4]*b[0];
     60 
     61     c = d0 >> 26; d0 &= MASK26; d1 += c;
     62     c = d1 >> 26; d1 &= MASK26; d2 += c;
     63     c = d2 >> 26; d2 &= MASK26; d3 += c;
     64     c = d3 >> 26; d3 &= MASK26; d4 += c;
     65     c = d4 >> 26; d4 &= MASK26; d0 += c * 5;
     66     c = d0 >> 26; d0 &= MASK26; d1 += c;
     67 
     68     out[0] = (uint32_t)d0;
     69     out[1] = (uint32_t)d1;
     70     out[2] = (uint32_t)d2;
     71     out[3] = (uint32_t)d3;
     72     out[4] = (uint32_t)d4;
     73 }
     74 
     75 /*
     76  * Parse 16 little-endian bytes plus a 'hibit' value (0 or 1) at bit
     77  * 128 into 5 26-bit limbs.
     78  */
     79 static inline void blk2limbs(const uint8_t m[16], uint32_t hibit,
     80                               uint32_t l[5]) {
     81     uint32_t t0, t1, t2, t3;
     82     memcpy(&t0, m,     4);
     83     memcpy(&t1, m + 4, 4);
     84     memcpy(&t2, m + 8, 4);
     85     memcpy(&t3, m + 12, 4);
     86     l[0] =  t0                                & MASK26;
     87     l[1] = ((t0 >> 26) | (t1 <<  6))          & MASK26;
     88     l[2] = ((t1 >> 20) | (t2 << 12))          & MASK26;
     89     l[3] = ((t2 >> 14) | (t3 << 18))          & MASK26;
     90     l[4] =  (t3 >>  8) | (hibit << 24);
     91 }
     92 
     93 /*
     94  * Process one full 16-byte block: h := (h + m) * r mod p, scalar
     95  * implementation.
     96  */
     97 static inline void scalar_block(uint32_t h[5], const uint32_t r[5],
     98                                 const uint8_t m[16], uint32_t hibit) {
     99     uint32_t blk[5];
    100     blk2limbs(m, hibit, blk);
    101 
    102     uint32_t hl[5];
    103     hl[0] = h[0] + blk[0];
    104     hl[1] = h[1] + blk[1];
    105     hl[2] = h[2] + blk[2];
    106     hl[3] = h[3] + blk[3];
    107     hl[4] = h[4] + blk[4];
    108 
    109     mul_mod_p(hl, r, h);
    110 }
    111 
    112 /*
    113  * 4-block NEON update.  Computes:
    114  *   h := (h + m_0)*r^4 + m_1*r^3 + m_2*r^2 + m_3*r^1   mod p
    115  *
    116  * where m_0..m_3 are four consecutive 16-byte input blocks (so the
    117  * function consumes 64 bytes).  The polynomial identity is the
    118  * standard Horner expansion of 4 sequential block updates.
    119  */
    120 static void neon4_block(uint32_t h[5],
    121                         const uint32_t r1[5], const uint32_t r2[5],
    122                         const uint32_t r3[5], const uint32_t r4[5],
    123                         const uint8_t m[64]) {
    124     /* Limbify the four blocks. */
    125     uint32_t b0[5], b1[5], b2[5], b3[5];
    126     blk2limbs(m,      1, b0);
    127     blk2limbs(m + 16, 1, b1);
    128     blk2limbs(m + 32, 1, b2);
    129     blk2limbs(m + 48, 1, b3);
    130 
    131     /* Fold the running accumulator into the first block. */
    132     b0[0] += h[0]; b0[1] += h[1]; b0[2] += h[2];
    133     b0[3] += h[3]; b0[4] += h[4];
    134 
    135     /* Pack messages: mv[i] = (b0[i], b1[i], b2[i], b3[i]). */
    136     uint32x4_t mv0 = { b0[0], b1[0], b2[0], b3[0] };
    137     uint32x4_t mv1 = { b0[1], b1[1], b2[1], b3[1] };
    138     uint32x4_t mv2 = { b0[2], b1[2], b2[2], b3[2] };
    139     uint32x4_t mv3 = { b0[3], b1[3], b2[3], b3[3] };
    140     uint32x4_t mv4 = { b0[4], b1[4], b2[4], b3[4] };
    141 
    142     /* Pack r-powers: rv[i] = (r^4[i], r^3[i], r^2[i], r^1[i]).      */
    143     uint32x4_t rv0 = { r4[0], r3[0], r2[0], r1[0] };
    144     uint32x4_t rv1 = { r4[1], r3[1], r2[1], r1[1] };
    145     uint32x4_t rv2 = { r4[2], r3[2], r2[2], r1[2] };
    146     uint32x4_t rv3 = { r4[3], r3[3], r2[3], r1[3] };
    147     uint32x4_t rv4 = { r4[4], r3[4], r2[4], r1[4] };
    148 
    149     /* 5 * r-powers, for partial products whose position wraps past
    150      * 4 (mod 2^130 - 5 = 5).                                       */
    151     uint32x4_t rv1_5 = vaddq_u32(vshlq_n_u32(rv1, 2), rv1);
    152     uint32x4_t rv2_5 = vaddq_u32(vshlq_n_u32(rv2, 2), rv2);
    153     uint32x4_t rv3_5 = vaddq_u32(vshlq_n_u32(rv3, 2), rv3);
    154     uint32x4_t rv4_5 = vaddq_u32(vshlq_n_u32(rv4, 2), rv4);
    155 
    156     /*
    157      * Output limb j = sum_i (mv[i] * appropriate r-power for shift j-i,
    158      * with 5x multiplier when j - i is negative).  Per output we sum
    159      * across the 4 lanes ('vaddvq_u64') after collecting both vmull
    160      * halves.
    161      */
    162 #define MUL5_LO(m_, r_) vmull_u32(vget_low_u32(m_), vget_low_u32(r_))
    163 #define MUL5_HI(m_, r_) vmull_high_u32(m_, r_)
    164 #define MLA_LO(acc, m_, r_) \
    165     vmlal_u32(acc, vget_low_u32(m_), vget_low_u32(r_))
    166 #define MLA_HI(acc, m_, r_) vmlal_high_u32(acc, m_, r_)
    167 
    168     uint64x2_t l0 = MUL5_LO(mv0, rv0);
    169     uint64x2_t h0 = MUL5_HI(mv0, rv0);
    170     l0 = MLA_LO(l0, mv1, rv4_5);  h0 = MLA_HI(h0, mv1, rv4_5);
    171     l0 = MLA_LO(l0, mv2, rv3_5);  h0 = MLA_HI(h0, mv2, rv3_5);
    172     l0 = MLA_LO(l0, mv3, rv2_5);  h0 = MLA_HI(h0, mv3, rv2_5);
    173     l0 = MLA_LO(l0, mv4, rv1_5);  h0 = MLA_HI(h0, mv4, rv1_5);
    174     uint64_t d0 = vaddvq_u64(l0) + vaddvq_u64(h0);
    175 
    176     uint64x2_t l1 = MUL5_LO(mv0, rv1);
    177     uint64x2_t h1 = MUL5_HI(mv0, rv1);
    178     l1 = MLA_LO(l1, mv1, rv0   );  h1 = MLA_HI(h1, mv1, rv0   );
    179     l1 = MLA_LO(l1, mv2, rv4_5);  h1 = MLA_HI(h1, mv2, rv4_5);
    180     l1 = MLA_LO(l1, mv3, rv3_5);  h1 = MLA_HI(h1, mv3, rv3_5);
    181     l1 = MLA_LO(l1, mv4, rv2_5);  h1 = MLA_HI(h1, mv4, rv2_5);
    182     uint64_t d1 = vaddvq_u64(l1) + vaddvq_u64(h1);
    183 
    184     uint64x2_t l2 = MUL5_LO(mv0, rv2);
    185     uint64x2_t h2 = MUL5_HI(mv0, rv2);
    186     l2 = MLA_LO(l2, mv1, rv1   );  h2 = MLA_HI(h2, mv1, rv1   );
    187     l2 = MLA_LO(l2, mv2, rv0   );  h2 = MLA_HI(h2, mv2, rv0   );
    188     l2 = MLA_LO(l2, mv3, rv4_5);  h2 = MLA_HI(h2, mv3, rv4_5);
    189     l2 = MLA_LO(l2, mv4, rv3_5);  h2 = MLA_HI(h2, mv4, rv3_5);
    190     uint64_t d2 = vaddvq_u64(l2) + vaddvq_u64(h2);
    191 
    192     uint64x2_t l3 = MUL5_LO(mv0, rv3);
    193     uint64x2_t h3 = MUL5_HI(mv0, rv3);
    194     l3 = MLA_LO(l3, mv1, rv2   );  h3 = MLA_HI(h3, mv1, rv2   );
    195     l3 = MLA_LO(l3, mv2, rv1   );  h3 = MLA_HI(h3, mv2, rv1   );
    196     l3 = MLA_LO(l3, mv3, rv0   );  h3 = MLA_HI(h3, mv3, rv0   );
    197     l3 = MLA_LO(l3, mv4, rv4_5);  h3 = MLA_HI(h3, mv4, rv4_5);
    198     uint64_t d3 = vaddvq_u64(l3) + vaddvq_u64(h3);
    199 
    200     uint64x2_t l4 = MUL5_LO(mv0, rv4);
    201     uint64x2_t h4 = MUL5_HI(mv0, rv4);
    202     l4 = MLA_LO(l4, mv1, rv3);  h4 = MLA_HI(h4, mv1, rv3);
    203     l4 = MLA_LO(l4, mv2, rv2);  h4 = MLA_HI(h4, mv2, rv2);
    204     l4 = MLA_LO(l4, mv3, rv1);  h4 = MLA_HI(h4, mv3, rv1);
    205     l4 = MLA_LO(l4, mv4, rv0);  h4 = MLA_HI(h4, mv4, rv0);
    206     uint64_t d4 = vaddvq_u64(l4) + vaddvq_u64(h4);
    207 
    208 #undef MUL5_LO
    209 #undef MUL5_HI
    210 #undef MLA_LO
    211 #undef MLA_HI
    212 
    213     /* Carry propagation, same shape as 'mul_mod_p'. */
    214     uint64_t c;
    215     c = d0 >> 26; d0 &= MASK26; d1 += c;
    216     c = d1 >> 26; d1 &= MASK26; d2 += c;
    217     c = d2 >> 26; d2 &= MASK26; d3 += c;
    218     c = d3 >> 26; d3 &= MASK26; d4 += c;
    219     c = d4 >> 26; d4 &= MASK26; d0 += c * 5;
    220     c = d0 >> 26; d0 &= MASK26; d1 += c;
    221 
    222     h[0] = (uint32_t)d0;
    223     h[1] = (uint32_t)d1;
    224     h[2] = (uint32_t)d2;
    225     h[3] = (uint32_t)d3;
    226     h[4] = (uint32_t)d4;
    227 }
    228 
    229 /*
    230  * Compute a 16-byte Poly1305 MAC over 'msg' (length 'msglen') using
    231  * the 32-byte 'key'.  Writes the tag to 'mac_out'.
    232  */
    233 void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
    234                       size_t msglen, uint8_t mac_out[16]) {
    235     /* clamp r */
    236     uint32_t t0, t1, t2, t3;
    237     memcpy(&t0, key,      4);
    238     memcpy(&t1, key + 4,  4);
    239     memcpy(&t2, key + 8,  4);
    240     memcpy(&t3, key + 12, 4);
    241     t0 &= 0x0fffffffu;
    242     t1 &= 0x0ffffffcu;
    243     t2 &= 0x0ffffffcu;
    244     t3 &= 0x0ffffffcu;
    245 
    246     uint32_t r[5];
    247     r[0] =  t0                          & MASK26;
    248     r[1] = ((t0 >> 26) | (t1 <<  6))    & MASK26;
    249     r[2] = ((t1 >> 20) | (t2 << 12))    & MASK26;
    250     r[3] = ((t2 >> 14) | (t3 << 18))    & MASK26;
    251     r[4] =  (t3 >>  8);
    252 
    253     uint32_t h[5] = { 0, 0, 0, 0, 0 };
    254 
    255     size_t pos = 0;
    256 
    257     /* NEON 4-way path: amortizing the r^2/r^3/r^4 precomputation
    258      * (3 scalar mul_mod_p calls) needs about 4 NEON iterations to
    259      * break even versus the scalar block loop, so only engage it
    260      * when we have at least 16 full blocks (256 bytes).            */
    261     if (msglen >= 256) {
    262         uint32_t r2[5], r3[5], r4[5];
    263         mul_mod_p(r,  r,  r2);
    264         mul_mod_p(r2, r,  r3);
    265         mul_mod_p(r3, r,  r4);
    266 
    267         while (pos + 64 <= msglen) {
    268             neon4_block(h, r, r2, r3, r4, msg + pos);
    269             pos += 64;
    270         }
    271     }
    272 
    273     /* Scalar tail: any remaining full blocks (< 4 of them). */
    274     while (pos + 16 <= msglen) {
    275         scalar_block(h, r, msg + pos, 1);
    276         pos += 16;
    277     }
    278 
    279     /* Final partial block (1..15 trailing bytes), if any. */
    280     if (pos < msglen) {
    281         size_t rem = msglen - pos;
    282         uint8_t pad[16] = { 0 };
    283         memcpy(pad, msg + pos, rem);
    284         pad[rem] = 1;
    285         scalar_block(h, r, pad, 0);
    286     }
    287 
    288     /* normalize h (mul_mod_p may leave a small excess in h[1]) */
    289     {
    290         uint32_t c;
    291         c = h[1] >> 26; h[1] &= MASK26; h[2] += c;
    292         c = h[2] >> 26; h[2] &= MASK26; h[3] += c;
    293         c = h[3] >> 26; h[3] &= MASK26; h[4] += c;
    294         c = h[4] >> 26; h[4] &= MASK26; h[0] += c * 5;
    295         c = h[0] >> 26; h[0] &= MASK26; h[1] += c;
    296     }
    297 
    298     /* full reduction to [0, p) via constant-time conditional
    299      * subtraction of p */
    300     uint32_t g[5];
    301     uint32_t c = 5;
    302     g[0] = h[0] + c;   c = g[0] >> 26; g[0] &= MASK26;
    303     g[1] = h[1] + c;   c = g[1] >> 26; g[1] &= MASK26;
    304     g[2] = h[2] + c;   c = g[2] >> 26; g[2] &= MASK26;
    305     g[3] = h[3] + c;   c = g[3] >> 26; g[3] &= MASK26;
    306     g[4] = h[4] + c;
    307     uint32_t carry = g[4] >> 26;
    308     g[4] &= MASK26;
    309     uint32_t mask = (uint32_t)0 - carry;
    310     h[0] = (h[0] & ~mask) | (g[0] & mask);
    311     h[1] = (h[1] & ~mask) | (g[1] & mask);
    312     h[2] = (h[2] & ~mask) | (g[2] & mask);
    313     h[3] = (h[3] & ~mask) | (g[3] & mask);
    314     h[4] = (h[4] & ~mask) | (g[4] & mask);
    315 
    316     /* repack 5x 26-bit limbs into 4x 32-bit limbs */
    317     uint32_t h0 =  h[0]        | (h[1] << 26);
    318     uint32_t h1 = (h[1] >>  6) | (h[2] << 20);
    319     uint32_t h2 = (h[2] >> 12) | (h[3] << 14);
    320     uint32_t h3 = (h[3] >> 18) | (h[4] <<  8);
    321 
    322     /* add s (high 16 bytes of key), mod 2^128 */
    323     uint32_t s0, s1, s2, s3;
    324     memcpy(&s0, key + 16, 4);
    325     memcpy(&s1, key + 20, 4);
    326     memcpy(&s2, key + 24, 4);
    327     memcpy(&s3, key + 28, 4);
    328 
    329     uint64_t a0 = (uint64_t)h0 + s0;
    330     uint64_t a1 = (uint64_t)h1 + s1 + (a0 >> 32);
    331     uint64_t a2 = (uint64_t)h2 + s2 + (a1 >> 32);
    332     uint64_t a3 = (uint64_t)h3 + s3 + (a2 >> 32);
    333 
    334     uint32_t o0 = (uint32_t)a0;
    335     uint32_t o1 = (uint32_t)a1;
    336     uint32_t o2 = (uint32_t)a2;
    337     uint32_t o3 = (uint32_t)a3;
    338 
    339     memcpy(mac_out + 0,  &o0, 4);
    340     memcpy(mac_out + 4,  &o1, 4);
    341     memcpy(mac_out + 8,  &o2, 4);
    342     memcpy(mac_out + 12, &o3, 4);
    343 }
    344 
    345 int poly1305_arm_available(void) {
    346     return 1;
    347 }
    348 
    349 #else
    350 
    351 void poly1305_mac_arm(const uint8_t *key, const uint8_t *msg,
    352                       size_t msglen, uint8_t *mac_out) {
    353     (void)key; (void)msg; (void)msglen; (void)mac_out;
    354 }
    355 
    356 int poly1305_arm_available(void) {
    357     return 0;
    358 }
    359 
    360 #endif