base64

Fast Haskell base64 encoding/decoding (docs.ppad.tech/base64).
git clone git://git.ppad.tech/base64.git
Log | Files | Refs | README | LICENSE

base64_arm.c (11215B)


      1 #include <stddef.h>
      2 #include <stdint.h>
      3 
      4 #if defined(__aarch64__)
      5 
      6 #include <arm_neon.h>
      7 
      8 static const uint8_t b64_alphabet[64] =
      9     "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
     10 
     11 /*
     12  * Encode 'l' input bytes at 'src' into ((l+2)/3)*4 ASCII chars at 'dst'.
     13  *
     14  * NEON kernel processes 12 input bytes per iteration:
     15  *   - vld1q_u8 loads 16 bytes (we use the first 12; reading 4 ahead is
     16  *     safe as long as l - i >= 16)
     17  *   - vqtbl1q_u8 with a shuffle mask gathers each 4-byte output lane as
     18  *     [b1, b0, b2, b1], the order that lets a single shift+mask extract
     19  *     each 6-bit index
     20  *   - 4 vshrq_n_u32 + vandq_u32 pull out indices i0..i3 (one per lane
     21  *     byte); see comments below for the bit math
     22  *   - vqtbl4q_u8 looks up each index in the 64-byte alphabet
     23  *   - vst1q_u8 stores 16 output chars
     24  *
     25  * A scalar loop finishes any full triplet that didn't make the NEON
     26  * cut-off, and a final branch emits the 0/1/2-byte padded tail.
     27  */
     28 void base64_encode_arm(const uint8_t *src, uint8_t *dst, size_t l) {
     29     uint8x16x4_t lut;
     30     lut.val[0] = vld1q_u8(b64_alphabet);
     31     lut.val[1] = vld1q_u8(b64_alphabet + 16);
     32     lut.val[2] = vld1q_u8(b64_alphabet + 32);
     33     lut.val[3] = vld1q_u8(b64_alphabet + 48);
     34 
     35     /* For each 4-byte lane of output of vqtbl1q_u8, we want
     36      * [b1, b0, b2, b1] in memory order — viewed as a little-endian u32
     37      * lane that is (b1) | (b0 << 8) | (b2 << 16) | (b1 << 24).         */
     38     static const uint8_t shuf_enc[16] = {
     39         1, 0, 2, 1,
     40         4, 3, 5, 4,
     41         7, 6, 8, 7,
     42        10, 9,11,10,
     43     };
     44     uint8x16_t shuf = vld1q_u8(shuf_enc);
     45 
     46     size_t i = 0, o = 0;
     47     while (i + 16 <= l) {
     48         uint8x16_t in       = vld1q_u8(src + i);
     49         uint8x16_t shuffled = vqtbl1q_u8(in, shuf);
     50         uint32x4_t lane     = vreinterpretq_u32_u8(shuffled);
     51         uint32x4_t mask6    = vdupq_n_u32(0x3F);
     52 
     53         /* lane (LE) = b1 | (b0 << 8) | (b2 << 16) | (b1 << 24)
     54          *  i0 (top 6 of b0)          = (lane >> 10) & 0x3F
     55          *  i1 (lo 2 of b0|hi 4 of b1)= (lane >>  4) & 0x3F
     56          *  i2 (lo 4 of b1|hi 2 of b2)= (lane >> 22) & 0x3F  [uses b1 copy at byte 3]
     57          *  i3 (lo 6 of b2)           = (lane >> 16) & 0x3F          */
     58         uint32x4_t i0 = vandq_u32(vshrq_n_u32(lane, 10), mask6);
     59         uint32x4_t i1 = vandq_u32(vshrq_n_u32(lane,  4), mask6);
     60         uint32x4_t i2 = vandq_u32(vshrq_n_u32(lane, 22), mask6);
     61         uint32x4_t i3 = vandq_u32(vshrq_n_u32(lane, 16), mask6);
     62 
     63         /* assemble per-lane u32 = i0 | (i1 << 8) | (i2 << 16) | (i3 << 24) */
     64         uint32x4_t idx_u32 = vorrq_u32(
     65             vorrq_u32(i0, vshlq_n_u32(i1, 8)),
     66             vorrq_u32(vshlq_n_u32(i2, 16), vshlq_n_u32(i3, 24)));
     67 
     68         uint8x16_t indices = vreinterpretq_u8_u32(idx_u32);
     69         uint8x16_t chars   = vqtbl4q_u8(lut, indices);
     70         vst1q_u8(dst + o, chars);
     71 
     72         i += 12;
     73         o += 16;
     74     }
     75 
     76     /* scalar tail: full triplets */
     77     for (; i + 3 <= l; i += 3, o += 4) {
     78         uint32_t v = ((uint32_t)src[i] << 16)
     79                    | ((uint32_t)src[i + 1] << 8)
     80                    |  (uint32_t)src[i + 2];
     81         dst[o]     = b64_alphabet[(v >> 18) & 0x3F];
     82         dst[o + 1] = b64_alphabet[(v >> 12) & 0x3F];
     83         dst[o + 2] = b64_alphabet[(v >>  6) & 0x3F];
     84         dst[o + 3] = b64_alphabet[ v        & 0x3F];
     85     }
     86 
     87     /* 1- or 2-byte padded tail */
     88     if (i + 1 == l) {
     89         uint8_t b = src[i];
     90         dst[o]     = b64_alphabet[(b >> 2) & 0x3F];
     91         dst[o + 1] = b64_alphabet[(b & 0x03) << 4];
     92         dst[o + 2] = '=';
     93         dst[o + 3] = '=';
     94     } else if (i + 2 == l) {
     95         uint8_t b0 = src[i];
     96         uint8_t b1 = src[i + 1];
     97         dst[o]     = b64_alphabet[(b0 >> 2) & 0x3F];
     98         dst[o + 1] = b64_alphabet[((b0 & 0x03) << 4) | (b1 >> 4)];
     99         dst[o + 2] = b64_alphabet[(b1 & 0x0F) << 2];
    100         dst[o + 3] = '=';
    101     }
    102 }
    103 
    104 /*
    105  * Convert 16 ASCII base64 chars to 6-bit values in 'val'.
    106  * Each lane of 'bad' is 0xff if the corresponding input is not a
    107  * valid base64 char ('A'..'Z', 'a'..'z', '0'..'9', '+', '/'), else 0.
    108  * '=' is treated as invalid here; the caller handles padding.
    109  */
    110 static inline void ascii_to_b64(uint8x16_t c,
    111                                 uint8x16_t *val,
    112                                 uint8x16_t *bad) {
    113     uint8x16_t is_upper = vandq_u8(vcgeq_u8(c, vdupq_n_u8('A')),
    114                                     vcleq_u8(c, vdupq_n_u8('Z')));
    115     uint8x16_t is_lower = vandq_u8(vcgeq_u8(c, vdupq_n_u8('a')),
    116                                     vcleq_u8(c, vdupq_n_u8('z')));
    117     uint8x16_t is_digit = vandq_u8(vcgeq_u8(c, vdupq_n_u8('0')),
    118                                     vcleq_u8(c, vdupq_n_u8('9')));
    119     uint8x16_t is_plus  = vceqq_u8(c, vdupq_n_u8('+'));
    120     uint8x16_t is_slash = vceqq_u8(c, vdupq_n_u8('/'));
    121 
    122     /* Per-lane additive offset that takes c to its 6-bit value:
    123      *   'A'..'Z':  +(-65) = 0xBF mod 256   ('A' + 0xBF = 0)
    124      *   'a'..'z':  +(-71) = 0xB9
    125      *   '0'..'9':  +4
    126      *   '+':       +19
    127      *   '/':       +16
    128      * Invalid lanes get +0; 'bad' flags them.                          */
    129     uint8x16_t add = vorrq_u8(
    130         vandq_u8(is_upper, vdupq_n_u8((uint8_t)(0u - 65))),
    131         vorrq_u8(
    132             vandq_u8(is_lower, vdupq_n_u8((uint8_t)(0u - 71))),
    133             vorrq_u8(
    134                 vandq_u8(is_digit, vdupq_n_u8(4)),
    135                 vorrq_u8(
    136                     vandq_u8(is_plus,  vdupq_n_u8(19)),
    137                     vandq_u8(is_slash, vdupq_n_u8(16))))));
    138 
    139     *val = vaddq_u8(c, add);
    140 
    141     uint8x16_t any_valid = vorrq_u8(is_upper,
    142                             vorrq_u8(is_lower,
    143                               vorrq_u8(is_digit,
    144                                 vorrq_u8(is_plus, is_slash))));
    145     *bad = vmvnq_u8(any_valid);
    146 }
    147 
    148 static inline uint8_t scalar_b64(uint8_t c) {
    149     if (c >= 'A' && c <= 'Z') return (uint8_t)(c - 'A');
    150     if (c >= 'a' && c <= 'z') return (uint8_t)(c - 'a' + 26);
    151     if (c >= '0' && c <= '9') return (uint8_t)(c - '0' + 52);
    152     if (c == '+') return 62;
    153     if (c == '/') return 63;
    154     return 0x80; /* invalid sentinel */
    155 }
    156 
    157 /*
    158  * Decode 'inlen' ASCII base64 chars at 'src' into 'outlen' bytes at
    159  * 'dst'.  Returns 1 on success, 0 on any decoding error: malformed
    160  * length, malformed padding, invalid char in body, or invalid char /
    161  * non-zero non-data bits in the padded final quartet (RFC 4648 §3.5).
    162  *
    163  * Caller must allocate 'outlen' bytes at 'dst' and pass the correct
    164  * outlen for the given inlen and padding; mismatch returns 0 with
    165  * 'dst' unspecified.
    166  *
    167  * Body NEON kernel processes 16 input chars (= 4 quartets) per
    168  * iteration:
    169  *   - vld1q_u8 loads 16 chars
    170  *   - ascii_to_b64 validates each lane and yields 6-bit values
    171  *   - per u32x4 lane: build the 24-bit packed value V = (v0 << 18) |
    172  *     (v1 << 12) | (v2 << 6) | v3, whose bytes in LE are [V_low,
    173  *     V_mid, V_high, 0]
    174  *   - vqtbl1q_u8 reshuffles those bytes into [V_high, V_mid, V_low]
    175  *     per triplet, yielding 12 output bytes at the bottom of the
    176  *     output vector
    177  *   - vst1q_u8 stores 16 bytes (writing 12 valid + 4 spurious; the
    178  *     loop bound 'o + 16 <= body_outlen' keeps the overrun within
    179  *     the allocated buffer, and the spurious bytes get clobbered by
    180  *     the next iteration or by the scalar tail / final quartet)
    181  *
    182  * A scalar tail finishes any body quartets that didn't make the
    183  * NEON cut-off, then the padded final quartet is decoded explicitly.
    184  */
    185 int base64_decode_arm(const uint8_t *src, uint8_t *dst,
    186                       size_t inlen, size_t outlen) {
    187     if (inlen == 0) return outlen == 0;
    188     if (inlen & 0x3) return 0;
    189 
    190     uint8_t c_pre = src[inlen - 2];
    191     uint8_t c_end = src[inlen - 1];
    192     size_t pad = 0;
    193     if (c_end == '=') {
    194         if (c_pre == '=') pad = 2;
    195         else              pad = 1;
    196     } else if (c_pre == '=') {
    197         return 0; /* '=' at offset -2 only is malformed */
    198     }
    199 
    200     size_t nfull = inlen >> 2;
    201     if (outlen != nfull * 3 - pad) return 0;
    202 
    203     size_t body_chars  = (pad > 0) ? (inlen - 4) : inlen;
    204     size_t body_outlen = (body_chars >> 2) * 3;
    205 
    206     uint8x16_t bad_acc = vdupq_n_u8(0);
    207 
    208     static const uint8_t pack_shuf[16] = {
    209          2, 1, 0,
    210          6, 5, 4,
    211         10, 9, 8,
    212         14,13,12,
    213          0xFF, 0xFF, 0xFF, 0xFF
    214     };
    215     uint8x16_t pshuf = vld1q_u8(pack_shuf);
    216 
    217     size_t i = 0, o = 0;
    218     while (o + 16 <= body_outlen) {
    219         uint8x16_t c = vld1q_u8(src + i);
    220         uint8x16_t val, this_bad;
    221         ascii_to_b64(c, &val, &this_bad);
    222         bad_acc = vorrq_u8(bad_acc, this_bad);
    223 
    224         uint32x4_t v32   = vreinterpretq_u32_u8(val);
    225         uint32x4_t mask8 = vdupq_n_u32(0xFF);
    226 
    227         uint32x4_t p0 = vshlq_n_u32(vandq_u32(v32, mask8), 18);
    228         uint32x4_t p1 = vshlq_n_u32(
    229             vandq_u32(vshrq_n_u32(v32,  8), mask8), 12);
    230         uint32x4_t p2 = vshlq_n_u32(
    231             vandq_u32(vshrq_n_u32(v32, 16), mask8),  6);
    232         uint32x4_t p3 = vshrq_n_u32(v32, 24);
    233 
    234         uint32x4_t V       = vorrq_u32(vorrq_u32(p0, p1),
    235                                        vorrq_u32(p2, p3));
    236         uint8x16_t V_bytes = vreinterpretq_u8_u32(V);
    237         uint8x16_t packed  = vqtbl1q_u8(V_bytes, pshuf);
    238 
    239         vst1q_u8(dst + o, packed); /* 12 valid bytes + 4 spurious */
    240 
    241         i += 16;
    242         o += 12;
    243     }
    244 
    245     uint8_t tail_bad = 0;
    246 
    247     /* scalar body tail (full quartets, no '=') */
    248     while (o + 3 <= body_outlen) {
    249         uint8_t v0 = scalar_b64(src[i]);
    250         uint8_t v1 = scalar_b64(src[i + 1]);
    251         uint8_t v2 = scalar_b64(src[i + 2]);
    252         uint8_t v3 = scalar_b64(src[i + 3]);
    253         tail_bad |= (v0 | v1 | v2 | v3) & 0x80;
    254         dst[o]     = (uint8_t)((v0 << 2)         | (v1 >> 4));
    255         dst[o + 1] = (uint8_t)(((v1 & 0x0F) << 4) | (v2 >> 2));
    256         dst[o + 2] = (uint8_t)(((v2 & 0x03) << 6) | (v3 & 0x3F));
    257         i += 4;
    258         o += 3;
    259     }
    260 
    261     /* padded final quartet */
    262     if (pad > 0) {
    263         uint8_t v0 = scalar_b64(src[i]);
    264         uint8_t v1 = scalar_b64(src[i + 1]);
    265         if ((v0 | v1) & 0x80) return 0;
    266 
    267         if (pad == 2) {
    268             /* "XX==" -> 1 output byte; bottom 4 bits of v1 must be 0 */
    269             if (v1 & 0x0F) return 0;
    270             dst[o] = (uint8_t)((v0 << 2) | (v1 >> 4));
    271         } else {
    272             /* "XXX=" -> 2 output bytes; bottom 2 bits of v2 must be 0 */
    273             uint8_t v2 = scalar_b64(src[i + 2]);
    274             if (v2 & 0x80)  return 0;
    275             if (v2 & 0x03) return 0;
    276             dst[o]     = (uint8_t)((v0 << 2)        | (v1 >> 4));
    277             dst[o + 1] = (uint8_t)(((v1 & 0x0F) << 4) | (v2 >> 2));
    278         }
    279     }
    280 
    281     return (vmaxvq_u8(bad_acc) == 0) && (tail_bad == 0);
    282 }
    283 
    284 int base64_arm_available(void) {
    285     return 1;
    286 }
    287 
    288 #else
    289 
    290 /* stubs for non-aarch64 builds; never reached because dispatch is
    291  * gated on 'base64_arm_available' returning 0                     */
    292 
    293 void base64_encode_arm(const uint8_t *src, uint8_t *dst, size_t l) {
    294     (void)src; (void)dst; (void)l;
    295 }
    296 
    297 int base64_decode_arm(const uint8_t *src, uint8_t *dst,
    298                       size_t inlen, size_t outlen) {
    299     (void)src; (void)dst; (void)inlen; (void)outlen;
    300     return 0;
    301 }
    302 
    303 int base64_arm_available(void) {
    304     return 0;
    305 }
    306 
    307 #endif