poly1305_arm.c (12697B)
#include <stddef.h>
#include <stdint.h>
#include <string.h>

#if defined(__aarch64__)
#include <arm_neon.h>
#endif

/*
 * Poly1305 one-time authenticator (RFC 8439) with an AArch64 NEON
 * fast path.
 *
 * All arithmetic is modulo p = 2^130 - 5 on values held as five
 * 26-bit limbs (limb i has weight 2^(26*i)).  Two kernels share that
 * representation:
 *
 *   1. A scalar kernel using 64-bit products.  It is compiled on
 *      every target and is used to precompute r^2..r^4, for whole
 *      messages shorter than P1305_NEON_MIN_BYTES, for the (< 4
 *      block) tail that follows the NEON loop, and for the final
 *      partial block.
 *
 *   2. A NEON 4-way kernel (AArch64 only) that consumes 64 bytes per
 *      call.  Five 'uint32x4_t' vectors hold one limb position each
 *      across four message blocks; matching vectors hold the same
 *      limb position of (r^4, r^3, r^2, r^1).  Each output limb is
 *      the lane-wise sum of five widening multiply-accumulates,
 *      followed by a horizontal add over the four lanes.
 *
 * Reduction rule used by both kernels: 2^130 == 5 (mod p), so a
 * partial product that lands at limb position pos >= 5 is folded
 * back as 5 * value into position pos - 5.
 *
 * Bytes are always read and written in little-endian order, as the
 * RFC requires, independent of the byte order of the target.
 */

#define MASK26 0x3ffffffu

enum {
    P1305_BLOCK_BYTES = 16,       /* one Poly1305 block                */
    P1305_NEON_STRIDE_BYTES = 64, /* bytes consumed per neon4_block()  */
    /*
     * The NEON path first spends three scalar multiplications on
     * r^2, r^3 and r^4; it is only engaged from 16 full blocks
     * upwards so that this setup is amortised over several
     * iterations.
     */
    P1305_NEON_MIN_BYTES = 256
};

/* Read a 32-bit little-endian word from 4 bytes at 'p'. */
static inline uint32_t load32_le(const uint8_t *p) {
    return (uint32_t)p[0]
         | ((uint32_t)p[1] << 8)
         | ((uint32_t)p[2] << 16)
         | ((uint32_t)p[3] << 24);
}

/* Write 'v' as a 32-bit little-endian word to 4 bytes at 'p'. */
static inline void store32_le(uint8_t *p, uint32_t v) {
    p[0] = (uint8_t)v;
    p[1] = (uint8_t)(v >> 8);
    p[2] = (uint8_t)(v >> 16);
    p[3] = (uint8_t)(v >> 24);
}

/*
 * Split a 128-bit value given as four 32-bit words (w0 least
 * significant) plus 'hibit' (0 or 1, placed at bit 128) into five
 * 26-bit limbs.
 */
static inline void words2limbs(uint32_t w0, uint32_t w1, uint32_t w2,
                               uint32_t w3, uint32_t hibit,
                               uint32_t l[5]) {
    l[0] = w0 & MASK26;
    l[1] = ((w0 >> 26) | (w1 << 6)) & MASK26;
    l[2] = ((w1 >> 20) | (w2 << 12)) & MASK26;
    l[3] = ((w2 >> 14) | (w3 << 18)) & MASK26;
    l[4] = (w3 >> 8) | (hibit << 24);
}

/*
 * Parse 16 little-endian bytes plus a 'hibit' value (0 or 1) at bit
 * 128 into five 26-bit limbs.
 */
static inline void blk2limbs(const uint8_t m[16], uint32_t hibit,
                             uint32_t l[5]) {
    words2limbs(load32_le(m), load32_le(m + 4),
                load32_le(m + 8), load32_le(m + 12), hibit, l);
}

/*
 * Carry-propagate five 64-bit column sums back to 26-bit limbs.
 * The carry out of limb 4 has weight 2^130 and therefore re-enters
 * limb 0 multiplied by 5.  On return every limb is < 2^26 except
 * limb 1, which may keep a small excess; that excess is absorbed by
 * the next multiplication or by the final normalisation.
 * 'd' is used as scratch and is modified.
 */
static inline void carry_reduce(uint64_t d[5], uint32_t out[5]) {
    uint64_t c;

    c = d[0] >> 26; d[0] &= MASK26; d[1] += c;
    c = d[1] >> 26; d[1] &= MASK26; d[2] += c;
    c = d[2] >> 26; d[2] &= MASK26; d[3] += c;
    c = d[3] >> 26; d[3] &= MASK26; d[4] += c;
    c = d[4] >> 26; d[4] &= MASK26; d[0] += c * 5;
    c = d[0] >> 26; d[0] &= MASK26; d[1] += c;

    for (int i = 0; i < 5; i++) {
        out[i] = (uint32_t)d[i];
    }
}

/*
 * out := a * b mod (2^130 - 5), all operands in 5x 26-bit limb form.
 * All column sums are formed before 'out' is written.
 */
static void mul_mod_p(const uint32_t a[5], const uint32_t b[5],
                      uint32_t out[5]) {
    const uint64_t a0 = a[0], a1 = a[1], a2 = a[2], a3 = a[3], a4 = a[4];
    const uint64_t b0 = b[0], b1 = b[1], b2 = b[2], b3 = b[3], b4 = b[4];
    uint64_t d[5];

    /* Column j collects a[i]*b[j-i]; wrapped terms carry a factor 5. */
    d[0] = a0*b0 + 5 * (a4*b1 + a3*b2 + a2*b3 + a1*b4);
    d[1] = a0*b1 + a1*b0 + 5 * (a4*b2 + a3*b3 + a2*b4);
    d[2] = a0*b2 + a1*b1 + a2*b0 + 5 * (a4*b3 + a3*b4);
    d[3] = a0*b3 + a1*b2 + a2*b1 + a3*b0 + 5 * (a4*b4);
    d[4] = a0*b4 + a1*b3 + a2*b2 + a3*b1 + a4*b0;

    carry_reduce(d, out);
}

/*
 * Absorb one 16-byte block with the scalar kernel:
 *     h := (h + m) * r mod p
 * 'hibit' is 1 for a full message block and 0 for the already padded
 * final partial block.
 */
static inline void scalar_block(uint32_t h[5], const uint32_t r[5],
                                const uint8_t m[16], uint32_t hibit) {
    uint32_t sum[5];

    blk2limbs(m, hibit, sum);
    for (int i = 0; i < 5; i++) {
        sum[i] += h[i];
    }
    mul_mod_p(sum, r, h);
}

#if defined(__aarch64__)

/*
 * Powers of r laid out for the 4-way kernel.
 *   r[i]  : limb i of (r^4, r^3, r^2, r^1), one power per lane
 *   r5[i] : 5 * r[i], used for products that wrap past limb 4
 *           (r5[0] is never needed by the kernel)
 */
typedef struct {
    uint32x4_t r[5];
    uint32x4_t r5[5];
} rpow_lanes;

/* Build the lane layout once per message from r^1..r^4. */
static void rpow_lanes_init(rpow_lanes *rp,
                            const uint32_t r1[5], const uint32_t r2[5],
                            const uint32_t r3[5], const uint32_t r4[5]) {
    for (int i = 0; i < 5; i++) {
        const uint32_t lanes[4] = { r4[i], r3[i], r2[i], r1[i] };
        rp->r[i] = vld1q_u32(lanes);
        /* 5*x as (x << 2) + x; limbs are ~26 bits so this fits. */
        rp->r5[i] = vaddq_u32(vshlq_n_u32(rp->r[i], 2), rp->r[i]);
    }
}

/*
 * One output column: sum over limbs i of m[i] * k_i, lane by lane,
 * then added across the four lanes.  The low and high vector halves
 * are accumulated separately because the widening multiplies produce
 * two 64-bit lanes at a time.
 */
static inline uint64_t dot5_lanes(const uint32x4_t m[5],
                                  uint32x4_t k0, uint32x4_t k1,
                                  uint32x4_t k2, uint32x4_t k3,
                                  uint32x4_t k4) {
    uint64x2_t lo = vmull_u32(vget_low_u32(m[0]), vget_low_u32(k0));
    uint64x2_t hi = vmull_high_u32(m[0], k0);

    lo = vmlal_u32(lo, vget_low_u32(m[1]), vget_low_u32(k1));
    hi = vmlal_high_u32(hi, m[1], k1);
    lo = vmlal_u32(lo, vget_low_u32(m[2]), vget_low_u32(k2));
    hi = vmlal_high_u32(hi, m[2], k2);
    lo = vmlal_u32(lo, vget_low_u32(m[3]), vget_low_u32(k3));
    hi = vmlal_high_u32(hi, m[3], k3);
    lo = vmlal_u32(lo, vget_low_u32(m[4]), vget_low_u32(k4));
    hi = vmlal_high_u32(hi, m[4], k4);

    return vaddvq_u64(vaddq_u64(lo, hi));
}

/*
 * 4-block NEON update.  Consumes 64 bytes (blocks m_0..m_3) and
 * computes
 *     h := (h + m_0)*r^4 + m_1*r^3 + m_2*r^2 + m_3*r^1   mod p
 * which equals four consecutive scalar_block() updates.
 */
static void neon4_block(uint32_t h[5], const rpow_lanes *rp,
                        const uint8_t m[64]) {
    /* lanes[i][k] = limb i of block k (transposed for vector loads). */
    uint32_t lanes[5][4];
    for (int k = 0; k < 4; k++) {
        uint32_t limbs[5];
        blk2limbs(m + P1305_BLOCK_BYTES * k, 1, limbs);
        for (int i = 0; i < 5; i++) {
            lanes[i][k] = limbs[i];
        }
    }

    uint32x4_t mv[5];
    for (int i = 0; i < 5; i++) {
        /* The running accumulator rides along with the first block. */
        lanes[i][0] += h[i];
        mv[i] = vld1q_u32(lanes[i]);
    }

    const uint32x4_t *R = rp->r;
    const uint32x4_t *R5 = rp->r5;
    uint64_t d[5];

    /* Column j pairs mv[i] with limb (j - i); negative => 5x, +5. */
    d[0] = dot5_lanes(mv, R[0], R5[4], R5[3], R5[2], R5[1]);
    d[1] = dot5_lanes(mv, R[1], R[0],  R5[4], R5[3], R5[2]);
    d[2] = dot5_lanes(mv, R[2], R[1],  R[0],  R5[4], R5[3]);
    d[3] = dot5_lanes(mv, R[3], R[2],  R[1],  R[0],  R5[4]);
    d[4] = dot5_lanes(mv, R[4], R[3],  R[2],  R[1],  R[0]);

    carry_reduce(d, h);
}

#endif /* __aarch64__ */

/*
 * Reduce the accumulator 'h' to its canonical value in [0, p), add
 * the 128-bit value 's' (16 little-endian bytes) modulo 2^128 and
 * write the 16-byte tag.  'h' is modified.
 */
static void finish_tag(uint32_t h[5], const uint8_t s[16],
                       uint8_t mac_out[16]) {
    uint32_t c;

    /* Normalise: push the possible excess in limb 1 around once. */
    c = h[1] >> 26; h[1] &= MASK26; h[2] += c;
    c = h[2] >> 26; h[2] &= MASK26; h[3] += c;
    c = h[3] >> 26; h[3] &= MASK26; h[4] += c;
    c = h[4] >> 26; h[4] &= MASK26; h[0] += c * 5;
    c = h[0] >> 26; h[0] &= MASK26; h[1] += c;

    /*
     * g = h + 5.  A carry out of bit 130 means h >= p, in which case
     * the low 130 bits of g are h - p.  The choice between h and g
     * is made with a mask, not a branch, to stay constant-time.
     */
    uint32_t g[5];
    c = 5;
    for (int i = 0; i < 4; i++) {
        g[i] = h[i] + c;
        c = g[i] >> 26;
        g[i] &= MASK26;
    }
    g[4] = h[4] + c;
    const uint32_t take_g = (uint32_t)0 - (g[4] >> 26);
    g[4] &= MASK26;
    for (int i = 0; i < 5; i++) {
        h[i] = (h[i] & ~take_g) | (g[i] & take_g);
    }

    /* Repack 5x 26-bit limbs into 4x 32-bit words (drops bits >= 128). */
    const uint32_t w[4] = {
        h[0] | (h[1] << 26),
        (h[1] >> 6) | (h[2] << 20),
        (h[2] >> 12) | (h[3] << 14),
        (h[3] >> 18) | (h[4] << 8)
    };

    /* tag = (h + s) mod 2^128, carried word by word. */
    uint64_t acc = 0;
    for (int i = 0; i < 4; i++) {
        acc += (uint64_t)w[i] + load32_le(s + 4 * i);
        store32_le(mac_out + 4 * i, (uint32_t)acc);
        acc >>= 32;
    }
}

/*
 * Compute the 16-byte Poly1305 tag of 'msg' ('msglen' bytes) under
 * the 32-byte one-time 'key' (r = key[0..15], clamped; s =
 * key[16..31]) and write it to 'mac_out'.
 *
 * 'msg' is only dereferenced when 'msglen' is non-zero.  On AArch64
 * the NEON kernel handles the bulk of messages of at least
 * P1305_NEON_MIN_BYTES bytes; on every other target the portable
 * scalar kernel is used, so 'mac_out' is always written with the
 * correct tag.
 */
void poly1305_mac_arm(const uint8_t key[32], const uint8_t *msg,
                      size_t msglen, uint8_t mac_out[16]) {
    /* Clamp r as required by RFC 8439, then split into limbs. */
    uint32_t r[5];
    words2limbs(load32_le(key)      & 0x0fffffffu,
                load32_le(key + 4)  & 0x0ffffffcu,
                load32_le(key + 8)  & 0x0ffffffcu,
                load32_le(key + 12) & 0x0ffffffcu,
                0, r);

    uint32_t h[5] = { 0, 0, 0, 0, 0 };
    size_t left = msglen; /* bytes not yet absorbed; never wraps */

#if defined(__aarch64__)
    if (left >= P1305_NEON_MIN_BYTES) {
        uint32_t r2[5], r3[5], r4[5];
        mul_mod_p(r, r, r2);
        mul_mod_p(r2, r, r3);
        mul_mod_p(r3, r, r4);

        /* Lane layout of the r-powers is loop-invariant. */
        rpow_lanes rp;
        rpow_lanes_init(&rp, r, r2, r3, r4);

        do {
            neon4_block(h, &rp, msg);
            msg += P1305_NEON_STRIDE_BYTES;
            left -= P1305_NEON_STRIDE_BYTES;
        } while (left >= P1305_NEON_STRIDE_BYTES);
    }
#endif

    /* Remaining full blocks. */
    while (left >= P1305_BLOCK_BYTES) {
        scalar_block(h, r, msg, 1);
        msg += P1305_BLOCK_BYTES;
        left -= P1305_BLOCK_BYTES;
    }

    /*
     * Final partial block (1..15 bytes): append a single 0x01 byte,
     * zero-fill to 16 bytes, and absorb without the 2^128 bit.
     */
    if (left > 0) {
        uint8_t pad[P1305_BLOCK_BYTES] = { 0 };
        memcpy(pad, msg, left);
        pad[left] = 1;
        scalar_block(h, r, pad, 0);
    }

    finish_tag(h, key + 16, mac_out);
}

/*
 * Report whether the NEON-accelerated path is compiled in: 1 on
 * AArch64, 0 elsewhere.
 */
int poly1305_arm_available(void) {
#if defined(__aarch64__)
    return 1;
#else
    return 0;
#endif
}