chacha

The ChaCha20 stream cipher (docs.ppad.tech/chacha).
git clone git://git.ppad.tech/chacha.git
Log | Files | Refs | README | LICENSE

ChaCha20.hs (11304B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE MagicHash #-}
      4 {-# LANGUAGE RecordWildCards #-}
      5 {-# LANGUAGE UnboxedTuples #-}
      6 
      7 -- |
      8 -- Module: Crypto.Cipher.ChaCha20
      9 -- Copyright: (c) 2025 Jared Tobin
     10 -- License: MIT
     11 -- Maintainer: Jared Tobin <jared@ppad.tech>
     12 --
     13 -- A pure ChaCha20 implementation, as specified by
     14 -- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).
     15 
     16 module Crypto.Cipher.ChaCha20 (
     17     -- * ChaCha20 stream cipher
     18     cipher
     19 
     20     -- * ChaCha20 block function
     21   , block
     22 
     23     -- * Error information
     24   , Error(..)
     25 
     26     -- testing
     27   , ChaCha(..)
     28   , _chacha
     29   , _parse_key
     30   , _parse_nonce
     31   , _quarter
     32   , _quarter_pure
     33   , _rounds
     34   ) where
     35 
     36 import Control.Monad.ST
     37 import qualified Data.Bits as B
     38 import Data.Bits ((.|.), (.<<.), (.^.))
     39 import qualified Data.ByteString as BS
     40 import qualified Data.ByteString.Builder as BSB
     41 import qualified Data.ByteString.Internal as BI
     42 import qualified Data.ByteString.Unsafe as BU
     43 import Control.Monad.Primitive (PrimMonad, PrimState)
     44 import Data.Foldable (for_)
     45 import qualified Data.Primitive.PrimArray as PA
     46 import Foreign.ForeignPtr
     47 import GHC.Exts
     48 import GHC.Word
     49 
     50 -- utils ----------------------------------------------------------------------
     51 
     52 -- keystroke saver
     53 fi :: (Integral a, Num b) => a -> b
     54 fi = fromIntegral
     55 {-# INLINE fi #-}
     56 
     57 -- parse strict ByteString in LE order to Word32 (verbatim from
     58 -- Data.Binary)
     59 unsafe_word32le :: BS.ByteString -> Word32
     60 unsafe_word32le s =
     61   (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
     62   (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|.
     63   (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL`  8) .|.
     64   (fi (s `BU.unsafeIndex` 0))
     65 {-# INLINE unsafe_word32le #-}
     66 
     67 data WSPair = WSPair
     68   {-# UNPACK #-} !Word32
     69   {-# UNPACK #-} !BS.ByteString
     70 
     71 -- variant of Data.ByteString.splitAt that behaves like an incremental
     72 -- Word32 parser
     73 unsafe_parseWsPair :: BS.ByteString -> WSPair
     74 unsafe_parseWsPair (BI.BS x l) =
     75   WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
     76 {-# INLINE unsafe_parseWsPair #-}
     77 
     78 -- chacha quarter round -------------------------------------------------------
     79 
     80 -- RFC8439 2.2
     81 _quarter
     82   :: PrimMonad m
     83   => ChaCha (PrimState m)
     84   -> Int
     85   -> Int
     86   -> Int
     87   -> Int
     88   -> m ()
     89 _quarter (ChaCha m) i0 i1 i2 i3 = do
     90   !(W32# a) <- PA.readPrimArray m i0
     91   !(W32# b) <- PA.readPrimArray m i1
     92   !(W32# c) <- PA.readPrimArray m i2
     93   !(W32# d) <- PA.readPrimArray m i3
     94 
     95   let !(# a1, b1, c1, d1 #) = quarter# a b c d
     96 
     97   PA.writePrimArray m i0 (W32# a1)
     98   PA.writePrimArray m i1 (W32# b1)
     99   PA.writePrimArray m i2 (W32# c1)
    100   PA.writePrimArray m i3 (W32# d1)
    101 {-# INLINEABLE _quarter #-}
    102 
    103 _quarter_pure
    104   :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
    105 _quarter_pure (W32# a) (W32# b) (W32# c) (W32# d) =
    106   let !(# a', b', c', d' #) = quarter# a b c d
    107   in  (W32# a', W32# b', W32# c', W32# d')
    108 {-# INLINE _quarter_pure #-}
    109 
    110 -- RFC8439 2.1
    111 quarter#
    112   :: Word32# -> Word32# -> Word32# -> Word32#
    113   -> (# Word32#, Word32#, Word32#, Word32# #)
    114 quarter# a b c d =
    115   let a0 = plusWord32# a b
    116       d0 = xorWord32# d a0
    117       d1 = rotateL# d0 16#
    118 
    119       c0 = plusWord32# c d1
    120       b0 = xorWord32# b c0
    121       b1 = rotateL# b0 12#
    122 
    123       a1 = plusWord32# a0 b1
    124       d2 = xorWord32# d1 a1
    125       d3 = rotateL# d2 8#
    126 
    127       c1 = plusWord32# c0 d3
    128       b2 = xorWord32# b1 c1
    129       b3 = rotateL# b2 7#
    130 
    131   in  (# a1, b3, c1, d3 #)
    132 {-# INLINE quarter# #-}
    133 
    134 rotateL# :: Word32# -> Int# -> Word32#
    135 rotateL# w i
    136   | isTrue# (i ==# 0#) = w
    137   | otherwise = wordToWord32# (
    138             ((word32ToWord# w) `uncheckedShiftL#` i)
    139       `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i)))
    140 {-# INLINE rotateL# #-}
    141 
    142 -- key and nonce parsing ------------------------------------------------------
    143 
    144 data Key = Key {
    145     k0 :: {-# UNPACK #-} !Word32
    146   , k1 :: {-# UNPACK #-} !Word32
    147   , k2 :: {-# UNPACK #-} !Word32
    148   , k3 :: {-# UNPACK #-} !Word32
    149   , k4 :: {-# UNPACK #-} !Word32
    150   , k5 :: {-# UNPACK #-} !Word32
    151   , k6 :: {-# UNPACK #-} !Word32
    152   , k7 :: {-# UNPACK #-} !Word32
    153   }
    154   deriving (Eq, Show)
    155 
    156 -- parse strict 256-bit bytestring (length unchecked) to key
    157 _parse_key :: BS.ByteString -> Key
    158 _parse_key bs =
    159   let !(WSPair k0 t0) = unsafe_parseWsPair bs
    160       !(WSPair k1 t1) = unsafe_parseWsPair t0
    161       !(WSPair k2 t2) = unsafe_parseWsPair t1
    162       !(WSPair k3 t3) = unsafe_parseWsPair t2
    163       !(WSPair k4 t4) = unsafe_parseWsPair t3
    164       !(WSPair k5 t5) = unsafe_parseWsPair t4
    165       !(WSPair k6 t6) = unsafe_parseWsPair t5
    166       !(WSPair k7 t7) = unsafe_parseWsPair t6
    167   in  if   BS.null t7
    168       then Key {..}
    169       else error "ppad-chacha (_parse_key): internal error, bytes remaining"
    170 
    171 data Nonce = Nonce {
    172     n0 :: {-# UNPACK #-} !Word32
    173   , n1 :: {-# UNPACK #-} !Word32
    174   , n2 :: {-# UNPACK #-} !Word32
    175   }
    176   deriving (Eq, Show)
    177 
    178 -- parse strict 96-bit bytestring (length unchecked) to nonce
    179 _parse_nonce :: BS.ByteString -> Nonce
    180 _parse_nonce bs =
    181   let !(WSPair n0 t0) = unsafe_parseWsPair bs
    182       !(WSPair n1 t1) = unsafe_parseWsPair t0
    183       !(WSPair n2 t2) = unsafe_parseWsPair t1
    184   in  if   BS.null t2
    185       then Nonce {..}
    186       else error "ppad-chacha (_parse_nonce): internal error, bytes remaining"
    187 
    188 -- chacha20 block function ----------------------------------------------------
    189 
    190 newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
    191   deriving Eq
    192 
    193 _chacha
    194   :: PrimMonad m
    195   => Key
    196   -> Word32
    197   -> Nonce
    198   -> m (ChaCha (PrimState m))
    199 _chacha key counter nonce = do
    200   state <- _chacha_alloc
    201   _chacha_set state key counter nonce
    202   pure state
    203 
    204 -- allocate a new chacha state
    205 _chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
    206 _chacha_alloc = fmap ChaCha (PA.newPrimArray 16)
    207 {-# INLINE _chacha_alloc #-}
    208 
    209 -- set the values of a chacha state
    210 _chacha_set
    211   :: PrimMonad m
    212   => ChaCha (PrimState m)
    213   -> Key
    214   -> Word32
    215   -> Nonce
    216   -> m ()
    217 _chacha_set (ChaCha arr) Key {..} counter Nonce {..}= do
    218   PA.writePrimArray arr 00 0x61707865
    219   PA.writePrimArray arr 01 0x3320646e
    220   PA.writePrimArray arr 02 0x79622d32
    221   PA.writePrimArray arr 03 0x6b206574
    222   PA.writePrimArray arr 04 k0
    223   PA.writePrimArray arr 05 k1
    224   PA.writePrimArray arr 06 k2
    225   PA.writePrimArray arr 07 k3
    226   PA.writePrimArray arr 08 k4
    227   PA.writePrimArray arr 09 k5
    228   PA.writePrimArray arr 10 k6
    229   PA.writePrimArray arr 11 k7
    230   PA.writePrimArray arr 12 counter
    231   PA.writePrimArray arr 13 n0
    232   PA.writePrimArray arr 14 n1
    233   PA.writePrimArray arr 15 n2
    234 {-# INLINEABLE _chacha_set #-}
    235 
    236 _chacha_counter
    237   :: PrimMonad m
    238   => ChaCha (PrimState m)
    239   -> Word32
    240   -> m ()
    241 _chacha_counter (ChaCha arr) counter =
    242   PA.writePrimArray arr 12 counter
    243 
    244 -- two full rounds (eight quarter rounds)
    245 _rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
    246 _rounds state = do
    247   _quarter state 00 04 08 12
    248   _quarter state 01 05 09 13
    249   _quarter state 02 06 10 14
    250   _quarter state 03 07 11 15
    251   _quarter state 00 05 10 15
    252   _quarter state 01 06 11 12
    253   _quarter state 02 07 08 13
    254   _quarter state 03 04 09 14
    255 {-# INLINEABLE _rounds #-}
    256 
    257 _block
    258   :: PrimMonad m
    259   => ChaCha (PrimState m)
    260   -> Word32
    261   -> m BS.ByteString
    262 _block state@(ChaCha s) counter = do
    263   _chacha_counter state counter
    264   i <- PA.freezePrimArray s 0 16
    265   for_ [1..10 :: Int] (const (_rounds state))
    266   for_ [0..15 :: Int] $ \idx -> do
    267     let iv = PA.indexPrimArray i idx
    268     sv <- PA.readPrimArray s idx
    269     PA.writePrimArray s idx (iv + sv)
    270   serialize state
    271 
    272 -- | Error values.
    273 data Error =
    274     InvalidKey   -- ^ the provided key was not 256 bits long
    275   | InvalidNonce -- ^ the provided nonce was none 96 bits long
    276   deriving (Eq, Show)
    277 
    278 -- RFC8439 2.3
    279 
    280 -- | The ChaCha20 block function. Useful for generating a keystream.
    281 --
    282 --   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
    283 --   key must be exactly 256 bits, and the nonce exactly 96 bits.
    284 block
    285   :: BS.ByteString    -- ^ 256-bit key
    286   -> Word32           -- ^ 32-bit counter
    287   -> BS.ByteString    -- ^ 96-bit nonce
    288   -> Either Error BS.ByteString    -- ^ 512-bit keystream
    289 block key@(BI.PS _ _ kl) counter nonce@(BI.PS _ _ nl)
    290   | kl /= 32 = Left InvalidKey
    291   | nl /= 12 = Left InvalidNonce
    292   | otherwise = pure $ runST $ do
    293       let k = _parse_key key
    294           n = _parse_nonce nonce
    295       state@(ChaCha s) <- _chacha k counter n
    296       i <- PA.freezePrimArray s 0 16
    297       for_ [1..10 :: Int] (const (_rounds state))
    298       for_ [0..15 :: Int] $ \idx -> do
    299         let iv = PA.indexPrimArray i idx
    300         sv <- PA.readPrimArray s idx
    301         PA.writePrimArray s idx (iv + sv)
    302       serialize state
    303 
    304 serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
    305 serialize (ChaCha m) = do
    306     w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01
    307     w64_1 <- w64 <$> PA.readPrimArray m 02 <*> PA.readPrimArray m 03
    308     w64_2 <- w64 <$> PA.readPrimArray m 04 <*> PA.readPrimArray m 05
    309     w64_3 <- w64 <$> PA.readPrimArray m 06 <*> PA.readPrimArray m 07
    310     w64_4 <- w64 <$> PA.readPrimArray m 08 <*> PA.readPrimArray m 09
    311     w64_5 <- w64 <$> PA.readPrimArray m 10 <*> PA.readPrimArray m 11
    312     w64_6 <- w64 <$> PA.readPrimArray m 12 <*> PA.readPrimArray m 13
    313     w64_7 <- w64 <$> PA.readPrimArray m 14 <*> PA.readPrimArray m 15
    314     pure . BS.toStrict . BSB.toLazyByteString . mconcat $
    315       [w64_0, w64_1, w64_2, w64_3, w64_4, w64_5, w64_6, w64_7]
    316   where
    317     w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32))
    318 
    319 -- chacha20 encryption --------------------------------------------------------
    320 
    321 -- RFC8439 2.4
    322 
    323 -- | The ChaCha20 stream cipher. Generates a keystream and then XOR's
    324 --   the supplied input with it; use it both to encrypt plaintext and
    325 --   decrypt ciphertext.
    326 --
    327 --   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
    328 --   key must be exactly 256 bits, and the nonce exactly 96 bits.
    329 --
    330 --   >>> let key = "don't tell anyone my secret key!"
    331 --   >>> let non = "or my nonce!"
    332 --   >>> let cip = cipher key 1 non "but you can share the plaintext"
    333 --   >>> cip
    334 --   "\192*c\248A\204\211n\130y8\197\146k\245\178Y\197=\180_\223\138\146:^\206\&0\v[\201"
    335 --   >>> cipher key 1 non cip
    336 --   Right "but you can share the plaintext"
    337 cipher
    338   :: BS.ByteString    -- ^ 256-bit key
    339   -> Word32           -- ^ 32-bit counter
    340   -> BS.ByteString    -- ^ 96-bit nonce
    341   -> BS.ByteString    -- ^ arbitrary-length plaintext
    342   -> Either Error BS.ByteString    -- ^ ciphertext
    343 cipher raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
    344   | kl /= 32  = Left InvalidKey
    345   | nl /= 12  = Left InvalidNonce
    346   | otherwise = pure $ runST $ do
    347       let key = _parse_key raw_key
    348           non = _parse_nonce raw_nonce
    349       _cipher key counter non plaintext
    350 
    351 _cipher
    352   :: PrimMonad m
    353   => Key
    354   -> Word32
    355   -> Nonce
    356   -> BS.ByteString
    357   -> m BS.ByteString
    358 _cipher key counter nonce plaintext = do
    359   ChaCha initial <- _chacha key counter nonce
    360   state@(ChaCha s) <- _chacha_alloc
    361 
    362   let loop acc !j bs = case BS.splitAt 64 bs of
    363         (chunk@(BI.PS _ _ l), etc@(BI.PS _ _ le))
    364           | l == 0 && le == 0 -> pure $
    365               BS.toStrict (BSB.toLazyByteString acc)
    366           | otherwise -> do
    367               PA.copyMutablePrimArray s 0 initial 0 16
    368               stream <- _block state j
    369               let cip = BS.packZipWith (.^.) chunk stream
    370               loop (acc <> BSB.byteString cip) (j + 1) etc
    371 
    372   loop mempty counter plaintext
    373 {-# INLINE _cipher #-}
    374