chacha

The ChaCha20 stream cipher (docs.ppad.tech/chacha).
git clone git://git.ppad.tech/chacha.git
Log | Files | Refs | README | LICENSE

ChaCha20.hs (11109B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE MagicHash #-}
      4 {-# LANGUAGE RecordWildCards #-}
      5 {-# LANGUAGE UnboxedTuples #-}
      6 
      7 -- |
      8 -- Module: Crypto.Cipher.ChaCha20
      9 -- Copyright: (c) 2025 Jared Tobin
     10 -- License: MIT
     11 -- Maintainer: Jared Tobin <jared@ppad.tech>
     12 --
     13 -- A pure ChaCha20 implementation, as specified by
     14 -- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).
     15 
     16 module Crypto.Cipher.ChaCha20 (
     17     -- * ChaCha20 stream cipher
     18     cipher
     19 
     20     -- * ChaCha20 block function
     21   , block
     22 
     23     -- testing
     24   , ChaCha(..)
     25   , _chacha
     26   , _parse_key
     27   , _parse_nonce
     28   , _quarter
     29   , _quarter_pure
     30   , _rounds
     31   ) where
     32 
     33 import Control.Monad.ST
     34 import qualified Data.Bits as B
     35 import Data.Bits ((.|.), (.<<.), (.^.))
     36 import qualified Data.ByteString as BS
     37 import qualified Data.ByteString.Builder as BSB
     38 import qualified Data.ByteString.Internal as BI
     39 import qualified Data.ByteString.Unsafe as BU
     40 import Control.Monad.Primitive (PrimMonad, PrimState)
     41 import Data.Foldable (for_)
     42 import qualified Data.Primitive.PrimArray as PA
     43 import Foreign.ForeignPtr
     44 import GHC.Exts
     45 import GHC.Word
     46 
     47 -- utils ----------------------------------------------------------------------
     48 
     49 -- keystroke saver
     50 fi :: (Integral a, Num b) => a -> b
     51 fi = fromIntegral
     52 {-# INLINE fi #-}
     53 
     54 -- parse strict ByteString in LE order to Word32 (verbatim from
     55 -- Data.Binary)
     56 unsafe_word32le :: BS.ByteString -> Word32
     57 unsafe_word32le s =
     58   (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
     59   (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|.
     60   (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL`  8) .|.
     61   (fi (s `BU.unsafeIndex` 0))
     62 {-# INLINE unsafe_word32le #-}
     63 
     64 data WSPair = WSPair
     65   {-# UNPACK #-} !Word32
     66   {-# UNPACK #-} !BS.ByteString
     67 
     68 -- variant of Data.ByteString.splitAt that behaves like an incremental
     69 -- Word32 parser
     70 unsafe_parseWsPair :: BS.ByteString -> WSPair
     71 unsafe_parseWsPair (BI.BS x l) =
     72   WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
     73 {-# INLINE unsafe_parseWsPair #-}
     74 
     75 -- chacha quarter round -------------------------------------------------------
     76 
     77 -- RFC8439 2.2
     78 _quarter
     79   :: PrimMonad m
     80   => ChaCha (PrimState m)
     81   -> Int
     82   -> Int
     83   -> Int
     84   -> Int
     85   -> m ()
     86 _quarter (ChaCha m) i0 i1 i2 i3 = do
     87   !(W32# a) <- PA.readPrimArray m i0
     88   !(W32# b) <- PA.readPrimArray m i1
     89   !(W32# c) <- PA.readPrimArray m i2
     90   !(W32# d) <- PA.readPrimArray m i3
     91 
     92   let !(# a1, b1, c1, d1 #) = quarter# a b c d
     93 
     94   PA.writePrimArray m i0 (W32# a1)
     95   PA.writePrimArray m i1 (W32# b1)
     96   PA.writePrimArray m i2 (W32# c1)
     97   PA.writePrimArray m i3 (W32# d1)
     98 {-# INLINEABLE _quarter #-}
     99 
    100 _quarter_pure
    101   :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
    102 _quarter_pure (W32# a) (W32# b) (W32# c) (W32# d) =
    103   let !(# a', b', c', d' #) = quarter# a b c d
    104   in  (W32# a', W32# b', W32# c', W32# d')
    105 {-# INLINE _quarter_pure #-}
    106 
    107 -- RFC8439 2.1
    108 quarter#
    109   :: Word32# -> Word32# -> Word32# -> Word32#
    110   -> (# Word32#, Word32#, Word32#, Word32# #)
    111 quarter# a b c d =
    112   let a0 = plusWord32# a b
    113       d0 = xorWord32# d a0
    114       d1 = rotateL# d0 16#
    115 
    116       c0 = plusWord32# c d1
    117       b0 = xorWord32# b c0
    118       b1 = rotateL# b0 12#
    119 
    120       a1 = plusWord32# a0 b1
    121       d2 = xorWord32# d1 a1
    122       d3 = rotateL# d2 8#
    123 
    124       c1 = plusWord32# c0 d3
    125       b2 = xorWord32# b1 c1
    126       b3 = rotateL# b2 7#
    127 
    128   in  (# a1, b3, c1, d3 #)
    129 {-# INLINE quarter# #-}
    130 
    131 rotateL# :: Word32# -> Int# -> Word32#
    132 rotateL# w i
    133   | isTrue# (i ==# 0#) = w
    134   | otherwise = wordToWord32# (
    135             ((word32ToWord# w) `uncheckedShiftL#` i)
    136       `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i)))
    137 {-# INLINE rotateL# #-}
    138 
    139 -- key and nonce parsing ------------------------------------------------------
    140 
    141 data Key = Key {
    142     k0 :: {-# UNPACK #-} !Word32
    143   , k1 :: {-# UNPACK #-} !Word32
    144   , k2 :: {-# UNPACK #-} !Word32
    145   , k3 :: {-# UNPACK #-} !Word32
    146   , k4 :: {-# UNPACK #-} !Word32
    147   , k5 :: {-# UNPACK #-} !Word32
    148   , k6 :: {-# UNPACK #-} !Word32
    149   , k7 :: {-# UNPACK #-} !Word32
    150   }
    151   deriving (Eq, Show)
    152 
    153 -- parse strict 256-bit bytestring (length unchecked) to key
    154 _parse_key :: BS.ByteString -> Key
    155 _parse_key bs =
    156   let !(WSPair k0 t0) = unsafe_parseWsPair bs
    157       !(WSPair k1 t1) = unsafe_parseWsPair t0
    158       !(WSPair k2 t2) = unsafe_parseWsPair t1
    159       !(WSPair k3 t3) = unsafe_parseWsPair t2
    160       !(WSPair k4 t4) = unsafe_parseWsPair t3
    161       !(WSPair k5 t5) = unsafe_parseWsPair t4
    162       !(WSPair k6 t6) = unsafe_parseWsPair t5
    163       !(WSPair k7 t7) = unsafe_parseWsPair t6
    164   in  if   BS.null t7
    165       then Key {..}
    166       else error "ppad-chacha (_parse_key): bytes remaining"
    167 
    168 data Nonce = Nonce {
    169     n0 :: {-# UNPACK #-} !Word32
    170   , n1 :: {-# UNPACK #-} !Word32
    171   , n2 :: {-# UNPACK #-} !Word32
    172   }
    173   deriving (Eq, Show)
    174 
    175 -- parse strict 96-bit bytestring (length unchecked) to nonce
    176 _parse_nonce :: BS.ByteString -> Nonce
    177 _parse_nonce bs =
    178   let !(WSPair n0 t0) = unsafe_parseWsPair bs
    179       !(WSPair n1 t1) = unsafe_parseWsPair t0
    180       !(WSPair n2 t2) = unsafe_parseWsPair t1
    181   in  if   BS.null t2
    182       then Nonce {..}
    183       else error "ppad-chacha (_parse_nonce): bytes remaining"
    184 
    185 -- chacha20 block function ----------------------------------------------------
    186 
    187 newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
    188   deriving Eq
    189 
    190 _chacha
    191   :: PrimMonad m
    192   => Key
    193   -> Word32
    194   -> Nonce
    195   -> m (ChaCha (PrimState m))
    196 _chacha key counter nonce = do
    197   state <- _chacha_alloc
    198   _chacha_set state key counter nonce
    199   pure state
    200 
    201 -- allocate a new chacha state
    202 _chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
    203 _chacha_alloc = fmap ChaCha (PA.newPrimArray 16)
    204 {-# INLINE _chacha_alloc #-}
    205 
    206 -- set the values of a chacha state
    207 _chacha_set
    208   :: PrimMonad m
    209   => ChaCha (PrimState m)
    210   -> Key
    211   -> Word32
    212   -> Nonce
    213   -> m ()
    214 _chacha_set (ChaCha arr) Key {..} counter Nonce {..}= do
    215   PA.writePrimArray arr 00 0x61707865
    216   PA.writePrimArray arr 01 0x3320646e
    217   PA.writePrimArray arr 02 0x79622d32
    218   PA.writePrimArray arr 03 0x6b206574
    219   PA.writePrimArray arr 04 k0
    220   PA.writePrimArray arr 05 k1
    221   PA.writePrimArray arr 06 k2
    222   PA.writePrimArray arr 07 k3
    223   PA.writePrimArray arr 08 k4
    224   PA.writePrimArray arr 09 k5
    225   PA.writePrimArray arr 10 k6
    226   PA.writePrimArray arr 11 k7
    227   PA.writePrimArray arr 12 counter
    228   PA.writePrimArray arr 13 n0
    229   PA.writePrimArray arr 14 n1
    230   PA.writePrimArray arr 15 n2
    231 {-# INLINEABLE _chacha_set #-}
    232 
    233 _chacha_counter
    234   :: PrimMonad m
    235   => ChaCha (PrimState m)
    236   -> Word32
    237   -> m ()
    238 _chacha_counter (ChaCha arr) counter =
    239   PA.writePrimArray arr 12 counter
    240 
    241 -- two full rounds (eight quarter rounds)
    242 _rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
    243 _rounds state = do
    244   _quarter state 00 04 08 12
    245   _quarter state 01 05 09 13
    246   _quarter state 02 06 10 14
    247   _quarter state 03 07 11 15
    248   _quarter state 00 05 10 15
    249   _quarter state 01 06 11 12
    250   _quarter state 02 07 08 13
    251   _quarter state 03 04 09 14
    252 {-# INLINEABLE _rounds #-}
    253 
    254 _block
    255   :: PrimMonad m
    256   => ChaCha (PrimState m)
    257   -> Word32
    258   -> m BS.ByteString
    259 _block state@(ChaCha s) counter = do
    260   _chacha_counter state counter
    261   i <- PA.freezePrimArray s 0 16
    262   for_ [1..10 :: Int] (const (_rounds state))
    263   for_ [0..15 :: Int] $ \idx -> do
    264     let iv = PA.indexPrimArray i idx
    265     sv <- PA.readPrimArray s idx
    266     PA.writePrimArray s idx (iv + sv)
    267   serialize state
    268 
    269 -- RFC8439 2.3
    270 
    271 -- | The ChaCha20 block function. Useful for generating a keystream.
    272 --
    273 --   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
    274 --   key must be exactly 256 bits, and the nonce exactly 96 bits.
    275 block
    276   :: BS.ByteString    -- ^ 256-bit key
    277   -> Word32           -- ^ 32-bit counter
    278   -> BS.ByteString    -- ^ 96-bit nonce
    279   -> BS.ByteString    -- ^ 512-bit keystream
    280 block key@(BI.PS _ _ kl) counter nonce@(BI.PS _ _ nl)
    281   | kl /= 32 = error "ppad-chacha (block): invalid key"
    282   | nl /= 12 = error "ppad-chacha (block): invalid nonce"
    283   | otherwise = runST $ do
    284       let k = _parse_key key
    285           n = _parse_nonce nonce
    286       state@(ChaCha s) <- _chacha k counter n
    287       i <- PA.freezePrimArray s 0 16
    288       for_ [1..10 :: Int] (const (_rounds state))
    289       for_ [0..15 :: Int] $ \idx -> do
    290         let iv = PA.indexPrimArray i idx
    291         sv <- PA.readPrimArray s idx
    292         PA.writePrimArray s idx (iv + sv)
    293       serialize state
    294 
    295 serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
    296 serialize (ChaCha m) = do
    297     w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01
    298     w64_1 <- w64 <$> PA.readPrimArray m 02 <*> PA.readPrimArray m 03
    299     w64_2 <- w64 <$> PA.readPrimArray m 04 <*> PA.readPrimArray m 05
    300     w64_3 <- w64 <$> PA.readPrimArray m 06 <*> PA.readPrimArray m 07
    301     w64_4 <- w64 <$> PA.readPrimArray m 08 <*> PA.readPrimArray m 09
    302     w64_5 <- w64 <$> PA.readPrimArray m 10 <*> PA.readPrimArray m 11
    303     w64_6 <- w64 <$> PA.readPrimArray m 12 <*> PA.readPrimArray m 13
    304     w64_7 <- w64 <$> PA.readPrimArray m 14 <*> PA.readPrimArray m 15
    305     pure . BS.toStrict . BSB.toLazyByteString . mconcat $
    306       [w64_0, w64_1, w64_2, w64_3, w64_4, w64_5, w64_6, w64_7]
    307   where
    308     w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32))
    309 
    310 -- chacha20 encryption --------------------------------------------------------
    311 
    312 -- RFC8439 2.4
    313 
    314 -- | The ChaCha20 stream cipher. Generates a keystream and then XOR's
    315 --   the supplied input with it; use it both to encrypt plaintext and
    316 --   decrypt ciphertext.
    317 --
    318 --   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
    319 --   key must be exactly 256 bits, and the nonce exactly 96 bits.
    320 --
    321 --   >>> let key = "don't tell anyone my secret key!"
    322 --   >>> let non = "or my nonce!"
    323 --   >>> let cip = cipher key 1 non "but you can share the plaintext"
    324 --   >>> cip
    325 --   "\192*c\248A\204\211n\130y8\197\146k\245\178Y\197=\180_\223\138\146:^\206\&0\v[\201"
    326 --   >>> cipher key 1 non cip
    327 --   "but you can share the plaintext"
    328 cipher
    329   :: BS.ByteString    -- ^ 256-bit key
    330   -> Word32           -- ^ 32-bit counter
    331   -> BS.ByteString    -- ^ 96-bit nonce
    332   -> BS.ByteString    -- ^ arbitrary-length plaintext
    333   -> BS.ByteString    -- ^ ciphertext
    334 cipher raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
    335   | kl /= 32  = error "ppad-chacha (cipher): invalid key"
    336   | nl /= 12  = error "ppad-chacha (cipher): invalid nonce"
    337   | otherwise = runST $ do
    338       let key = _parse_key raw_key
    339           non = _parse_nonce raw_nonce
    340       _cipher key counter non plaintext
    341 
    342 _cipher
    343   :: PrimMonad m
    344   => Key
    345   -> Word32
    346   -> Nonce
    347   -> BS.ByteString
    348   -> m BS.ByteString
    349 _cipher key counter nonce plaintext = do
    350   ChaCha initial <- _chacha key counter nonce
    351   state@(ChaCha s) <- _chacha_alloc
    352 
    353   let loop acc !j bs = case BS.splitAt 64 bs of
    354         (chunk@(BI.PS _ _ l), etc)
    355           | l == 0 && BS.length etc == 0 -> pure $ -- XX
    356               BS.toStrict (BSB.toLazyByteString acc)
    357           | otherwise -> do
    358               PA.copyMutablePrimArray s 0 initial 0 16
    359               stream <- _block state j
    360               let cip = BS.packZipWith (.^.) chunk stream
    361               loop (acc <> BSB.byteString cip) (j + 1) etc
    362 
    363   loop mempty counter plaintext
    364 {-# INLINE _cipher #-}
    365