base64_arm.c (11215B)
#include <stddef.h>
#include <stdint.h>

#if defined(__aarch64__)

#include <arm_neon.h>

static const uint8_t b64_alphabet[64] =
    "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * Encode 'l' input bytes at 'src' into ((l+2)/3)*4 ASCII chars at 'dst'.
 *
 * NEON kernel processes 12 input bytes per iteration:
 *   - vld1q_u8 loads 16 bytes (we use the first 12; reading 4 ahead is
 *     safe as long as l - i >= 16)
 *   - vqtbl1q_u8 with a shuffle mask gathers each 4-byte output lane as
 *     [b1, b0, b2, b1], the order that lets a single shift+mask extract
 *     each 6-bit index
 *   - 4 vshrq_n_u32 + vandq_u32 pull out indices i0..i3 (one per lane
 *     byte); see comments below for the bit math
 *   - vqtbl4q_u8 looks up each index in the 64-byte alphabet
 *   - vst1q_u8 stores 16 output chars
 *
 * A scalar loop finishes any full triplet that didn't make the NEON
 * cut-off, and a final branch emits the 0/1/2-byte padded tail.
 */
void base64_encode_arm(const uint8_t *src, uint8_t *dst, size_t l) {
    uint8x16x4_t lut;
    lut.val[0] = vld1q_u8(b64_alphabet);
    lut.val[1] = vld1q_u8(b64_alphabet + 16);
    lut.val[2] = vld1q_u8(b64_alphabet + 32);
    lut.val[3] = vld1q_u8(b64_alphabet + 48);

    /* For each 4-byte lane of output of vqtbl1q_u8, we want
     * [b1, b0, b2, b1] in memory order — viewed as a little-endian u32
     * lane that is (b1) | (b0 << 8) | (b2 << 16) | (b1 << 24). */
    static const uint8_t shuf_enc[16] = {
         1,  0,  2,  1,
         4,  3,  5,  4,
         7,  6,  8,  7,
        10,  9, 11, 10,
    };
    uint8x16_t shuf = vld1q_u8(shuf_enc);

    size_t i = 0, o = 0;
    while (i + 16 <= l) {
        uint8x16_t in = vld1q_u8(src + i);
        uint8x16_t shuffled = vqtbl1q_u8(in, shuf);
        uint32x4_t lane = vreinterpretq_u32_u8(shuffled);
        uint32x4_t mask6 = vdupq_n_u32(0x3F);

        /* lane (LE) = b1 | (b0 << 8) | (b2 << 16) | (b1 << 24)
         * i0 (top 6 of b0)           = (lane >> 10) & 0x3F
         * i1 (lo 2 of b0|hi 4 of b1) = (lane >>  4) & 0x3F
         * i2 (lo 4 of b1|hi 2 of b2) = (lane >> 22) & 0x3F  [uses b1 copy at byte 3]
         * i3 (lo 6 of b2)            = (lane >> 16) & 0x3F */
        uint32x4_t i0 = vandq_u32(vshrq_n_u32(lane, 10), mask6);
        uint32x4_t i1 = vandq_u32(vshrq_n_u32(lane,  4), mask6);
        uint32x4_t i2 = vandq_u32(vshrq_n_u32(lane, 22), mask6);
        uint32x4_t i3 = vandq_u32(vshrq_n_u32(lane, 16), mask6);

        /* assemble per-lane u32 = i0 | (i1 << 8) | (i2 << 16) | (i3 << 24) */
        uint32x4_t idx_u32 = vorrq_u32(
            vorrq_u32(i0, vshlq_n_u32(i1, 8)),
            vorrq_u32(vshlq_n_u32(i2, 16), vshlq_n_u32(i3, 24)));

        uint8x16_t indices = vreinterpretq_u8_u32(idx_u32);
        uint8x16_t chars = vqtbl4q_u8(lut, indices);
        vst1q_u8(dst + o, chars);

        i += 12;
        o += 16;
    }

    /* scalar tail: full triplets */
    for (; i + 3 <= l; i += 3, o += 4) {
        uint32_t v = ((uint32_t)src[i] << 16)
                   | ((uint32_t)src[i + 1] << 8)
                   | (uint32_t)src[i + 2];
        dst[o]     = b64_alphabet[(v >> 18) & 0x3F];
        dst[o + 1] = b64_alphabet[(v >> 12) & 0x3F];
        dst[o + 2] = b64_alphabet[(v >>  6) & 0x3F];
        dst[o + 3] = b64_alphabet[ v        & 0x3F];
    }

    /* 1- or 2-byte padded tail */
    if (i + 1 == l) {
        uint8_t b = src[i];
        dst[o]     = b64_alphabet[(b >> 2) & 0x3F];
        dst[o + 1] = b64_alphabet[(b & 0x03) << 4];
        dst[o + 2] = '=';
        dst[o + 3] = '=';
    } else if (i + 2 == l) {
        uint8_t b0 = src[i];
        uint8_t b1 = src[i + 1];
        dst[o]     = b64_alphabet[(b0 >> 2) & 0x3F];
        dst[o + 1] = b64_alphabet[((b0 & 0x03) << 4) | (b1 >> 4)];
        dst[o + 2] = b64_alphabet[(b1 & 0x0F) << 2];
        dst[o + 3] = '=';
    }
}
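/*
 * Worked example of the encode bit math, for one lane holding the
 * triplet "Man" (b0 = 0x4D, b1 = 0x61, b2 = 0x6E):
 *
 *   lane = b1 | (b0 << 8) | (b2 << 16) | (b1 << 24) = 0x616E4D61
 *
 *   i0 = (lane >> 10) & 0x3F = 19 -> 'T'
 *   i1 = (lane >>  4) & 0x3F = 22 -> 'W'
 *   i2 = (lane >> 22) & 0x3F =  5 -> 'F'
 *   i3 = (lane >> 16) & 0x3F = 46 -> 'u'
 *
 * i.e. "Man" encodes to "TWFu", agreeing with the scalar tail's
 * (v >> 18/12/6/0) & 0x3F extraction of the same 24-bit group.
 */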
/*
 * Convert 16 ASCII base64 chars to 6-bit values in 'val'.
 * Each lane of 'bad' is 0xff if the corresponding input is not a
 * valid base64 char ('A'..'Z', 'a'..'z', '0'..'9', '+', '/'), else 0.
 * '=' is treated as invalid here; the caller handles padding.
 */
static inline void ascii_to_b64(uint8x16_t c,
                                uint8x16_t *val,
                                uint8x16_t *bad) {
    uint8x16_t is_upper = vandq_u8(vcgeq_u8(c, vdupq_n_u8('A')),
                                   vcleq_u8(c, vdupq_n_u8('Z')));
    uint8x16_t is_lower = vandq_u8(vcgeq_u8(c, vdupq_n_u8('a')),
                                   vcleq_u8(c, vdupq_n_u8('z')));
    uint8x16_t is_digit = vandq_u8(vcgeq_u8(c, vdupq_n_u8('0')),
                                   vcleq_u8(c, vdupq_n_u8('9')));
    uint8x16_t is_plus  = vceqq_u8(c, vdupq_n_u8('+'));
    uint8x16_t is_slash = vceqq_u8(c, vdupq_n_u8('/'));

    /* Per-lane additive offset that takes c to its 6-bit value:
     *   'A'..'Z': +(-65) = 0xBF mod 256 ('A' + 0xBF = 0)
     *   'a'..'z': +(-71) = 0xB9
     *   '0'..'9': +4
     *   '+':      +19
     *   '/':      +16
     * Invalid lanes get +0; 'bad' flags them. */
    uint8x16_t add = vorrq_u8(
        vandq_u8(is_upper, vdupq_n_u8((uint8_t)(0u - 65))),
        vorrq_u8(
            vandq_u8(is_lower, vdupq_n_u8((uint8_t)(0u - 71))),
            vorrq_u8(
                vandq_u8(is_digit, vdupq_n_u8(4)),
                vorrq_u8(
                    vandq_u8(is_plus, vdupq_n_u8(19)),
                    vandq_u8(is_slash, vdupq_n_u8(16))))));

    *val = vaddq_u8(c, add);

    uint8x16_t any_valid = vorrq_u8(is_upper,
                           vorrq_u8(is_lower,
                           vorrq_u8(is_digit,
                           vorrq_u8(is_plus, is_slash))));
    *bad = vmvnq_u8(any_valid);
}

static inline uint8_t scalar_b64(uint8_t c) {
    if (c >= 'A' && c <= 'Z') return (uint8_t)(c - 'A');
    if (c >= 'a' && c <= 'z') return (uint8_t)(c - 'a' + 26);
    if (c >= '0' && c <= '9') return (uint8_t)(c - '0' + 52);
    if (c == '+') return 62;
    if (c == '/') return 63;
    return 0x80; /* invalid sentinel */
}
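/*
 * Example of the additive-offset trick in ascii_to_b64 above: for
 * c = 'q' (0x71), only is_lower is all-ones, so add = 0xB9 and the
 * uint8 lane wraps: 0x71 + 0xB9 = 0x12A mod 256 = 0x2A = 42, exactly
 * scalar_b64('q') = 'q' - 'a' + 26 = 42. An invalid lane gets add = 0,
 * so its 'val' byte just echoes c; 'bad' marks it, and the decoder
 * rejects the whole input at the end (dst is unspecified on failure).
 */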
/*
 * Decode 'inlen' ASCII base64 chars at 'src' into 'outlen' bytes at
 * 'dst'. Returns 1 on success, 0 on any decoding error: malformed
 * length, malformed padding, invalid char in body, or invalid char /
 * non-zero non-data bits in the padded final quartet (RFC 4648 §3.5).
 *
 * Caller must allocate 'outlen' bytes at 'dst' and pass the correct
 * outlen for the given inlen and padding; mismatch returns 0 with
 * 'dst' unspecified.
 *
 * Body NEON kernel processes 16 input chars (= 4 quartets) per
 * iteration:
 *   - vld1q_u8 loads 16 chars
 *   - ascii_to_b64 validates each lane and yields 6-bit values
 *   - per u32x4 lane: build the 24-bit packed value V = (v0 << 18) |
 *     (v1 << 12) | (v2 << 6) | v3, whose bytes in LE are [V_low,
 *     V_mid, V_high, 0]
 *   - vqtbl1q_u8 reshuffles those bytes into [V_high, V_mid, V_low]
 *     per triplet, yielding 12 output bytes at the bottom of the
 *     output vector
 *   - vst1q_u8 stores 16 bytes (writing 12 valid + 4 spurious; the
 *     loop bound 'o + 16 <= body_outlen' keeps the overrun within
 *     the allocated buffer, and the spurious bytes get clobbered by
 *     the next iteration or by the scalar tail / final quartet)
 *
 * A scalar tail finishes any body quartets that didn't make the
 * NEON cut-off, then the padded final quartet is decoded explicitly.
 */
int base64_decode_arm(const uint8_t *src, uint8_t *dst,
                      size_t inlen, size_t outlen) {
    if (inlen == 0) return outlen == 0;
    if (inlen & 0x3) return 0;

    uint8_t c_pre = src[inlen - 2];
    uint8_t c_end = src[inlen - 1];
    size_t pad = 0;
    if (c_end == '=') {
        if (c_pre == '=') pad = 2;
        else pad = 1;
    } else if (c_pre == '=') {
        return 0; /* '=' at offset -2 only is malformed */
    }

    size_t nfull = inlen >> 2;
    if (outlen != nfull * 3 - pad) return 0;

    size_t body_chars = (pad > 0) ? (inlen - 4) : inlen;
    size_t body_outlen = (body_chars >> 2) * 3;

    uint8x16_t bad_acc = vdupq_n_u8(0);

    static const uint8_t pack_shuf[16] = {
         2,  1,  0,
         6,  5,  4,
        10,  9,  8,
        14, 13, 12,
        0xFF, 0xFF, 0xFF, 0xFF
    };
    uint8x16_t pshuf = vld1q_u8(pack_shuf);

    size_t i = 0, o = 0;
    while (o + 16 <= body_outlen) {
        uint8x16_t c = vld1q_u8(src + i);
        uint8x16_t val, this_bad;
        ascii_to_b64(c, &val, &this_bad);
        bad_acc = vorrq_u8(bad_acc, this_bad);

        uint32x4_t v32 = vreinterpretq_u32_u8(val);
        uint32x4_t mask8 = vdupq_n_u32(0xFF);

        uint32x4_t p0 = vshlq_n_u32(vandq_u32(v32, mask8), 18);
        uint32x4_t p1 = vshlq_n_u32(
            vandq_u32(vshrq_n_u32(v32, 8), mask8), 12);
        uint32x4_t p2 = vshlq_n_u32(
            vandq_u32(vshrq_n_u32(v32, 16), mask8), 6);
        uint32x4_t p3 = vshrq_n_u32(v32, 24);

        uint32x4_t V = vorrq_u32(vorrq_u32(p0, p1),
                                 vorrq_u32(p2, p3));
        uint8x16_t V_bytes = vreinterpretq_u8_u32(V);
        uint8x16_t packed = vqtbl1q_u8(V_bytes, pshuf);

        vst1q_u8(dst + o, packed); /* 12 valid bytes + 4 spurious */

        i += 16;
        o += 12;
    }

    uint8_t tail_bad = 0;

    /* scalar body tail (full quartets, no '=') */
    while (o + 3 <= body_outlen) {
        uint8_t v0 = scalar_b64(src[i]);
        uint8_t v1 = scalar_b64(src[i + 1]);
        uint8_t v2 = scalar_b64(src[i + 2]);
        uint8_t v3 = scalar_b64(src[i + 3]);
        tail_bad |= (v0 | v1 | v2 | v3) & 0x80;
        dst[o]     = (uint8_t)((v0 << 2) | (v1 >> 4));
        dst[o + 1] = (uint8_t)(((v1 & 0x0F) << 4) | (v2 >> 2));
        dst[o + 2] = (uint8_t)(((v2 & 0x03) << 6) | (v3 & 0x3F));
        i += 4;
        o += 3;
    }

    /* padded final quartet */
    if (pad > 0) {
        uint8_t v0 = scalar_b64(src[i]);
        uint8_t v1 = scalar_b64(src[i + 1]);
        if ((v0 | v1) & 0x80) return 0;

        if (pad == 2) {
            /* "XX==" -> 1 output byte; bottom 4 bits of v1 must be 0 */
            if (v1 & 0x0F) return 0;
            dst[o] = (uint8_t)((v0 << 2) | (v1 >> 4));
        } else {
            /* "XXX=" -> 2 output bytes; bottom 2 bits of v2 must be 0 */
            uint8_t v2 = scalar_b64(src[i + 2]);
            if (v2 & 0x80) return 0;
            if (v2 & 0x03) return 0;
            dst[o]     = (uint8_t)((v0 << 2) | (v1 >> 4));
            dst[o + 1] = (uint8_t)(((v1 & 0x0F) << 4) | (v2 >> 2));
        }
    }

    return (vmaxvq_u8(bad_acc) == 0) && (tail_bad == 0);
}

int base64_arm_available(void) {
    return 1;
}

#else

/* stubs for non-aarch64 builds; never reached because dispatch is
 * gated on 'base64_arm_available' returning 0 */

void base64_encode_arm(const uint8_t *src, uint8_t *dst, size_t l) {
    (void)src; (void)dst; (void)l;
}

int base64_decode_arm(const uint8_t *src, uint8_t *dst,
                      size_t inlen, size_t outlen) {
    (void)src; (void)dst; (void)inlen; (void)outlen;
    return 0;
}

int base64_arm_available(void) {
    return 0;
}

#endif
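/*
 * Minimal usage sketch / self-test, guarded by a hypothetical macro
 * (BASE64_ARM_SELFTEST is assumed here, not part of any build system):
 * round-trips the RFC 4648 §10 test vectors through the encoder and
 * decoder, showing how a caller sizes the buffers. A sketch, not a
 * test suite.
 *
 * Build (aarch64): cc -DBASE64_ARM_SELFTEST base64_arm.c -o selftest
 */
#ifdef BASE64_ARM_SELFTEST
#include <stdio.h>
#include <string.h>

int main(void) {
    static const char *vec[][2] = {
        { "",       ""         },
        { "f",      "Zg=="     },
        { "fo",     "Zm8="     },
        { "foo",    "Zm9v"     },
        { "foob",   "Zm9vYg==" },
        { "fooba",  "Zm9vYmE=" },
        { "foobar", "Zm9vYmFy" },
    };
    if (!base64_arm_available()) {
        puts("base64_arm not available on this target");
        return 0;
    }
    for (size_t t = 0; t < sizeof(vec) / sizeof(vec[0]); t++) {
        const uint8_t *plain = (const uint8_t *)vec[t][0];
        size_t plen = strlen(vec[t][0]);
        size_t elen = ((plen + 2) / 3) * 4; /* encoded size per the encode contract */
        uint8_t enc[64], dec[64];

        base64_encode_arm(plain, enc, plen);
        if (elen != strlen(vec[t][1]) || memcmp(enc, vec[t][1], elen) != 0) {
            printf("encode mismatch on vector %zu\n", t);
            return 1;
        }
        /* caller passes the exact outlen implied by inlen + padding */
        if (!base64_decode_arm(enc, dec, elen, plen) ||
            memcmp(dec, plain, plen) != 0) {
            printf("decode mismatch on vector %zu\n", t);
            return 1;
        }
    }
    puts("all vectors ok");
    return 0;
}
#endif /* BASE64_ARM_SELFTEST */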