chacha

The ChaCha20 stream cipher (docs.ppad.tech/chacha).
git clone git://git.ppad.tech/chacha.git
Log | Files | Refs | README | LICENSE

ChaCha20.hs (11510B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE BangPatterns #-}
      3 {-# LANGUAGE MagicHash #-}
      4 {-# LANGUAGE RecordWildCards #-}
      5 {-# LANGUAGE UnboxedTuples #-}
      6 
      7 -- |
      8 -- Module: Crypto.Cipher.ChaCha20
      9 -- Copyright: (c) 2025 Jared Tobin
     10 -- License: MIT
     11 -- Maintainer: Jared Tobin <jared@ppad.tech>
     12 --
     13 -- A fast ChaCha20 implementation, as specified by
     14 -- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).
     15 
     16 module Crypto.Cipher.ChaCha20 (
     17     -- * ChaCha20 stream cipher
     18     cipher
     19 
     20     -- * ChaCha20 block function
     21   , block
     22 
     23     -- * Error information
     24   , Error(..)
     25 
     26     -- testing
     27   , ChaCha(..)
     28   , _chacha
     29   , _parse_key
     30   , _parse_nonce
     31   , _quarter
     32   , _quarter_pure
     33   , _rounds
     34   ) where
     35 
     36 import Control.Monad.ST
     37 import qualified Crypto.Cipher.ChaCha20.Arm as Arm
     38 import qualified Data.Bits as B
     39 import Data.Bits ((.|.), (.<<.), (.^.))
     40 import qualified Data.ByteString as BS
     41 import qualified Data.ByteString.Builder as BSB
     42 import qualified Data.ByteString.Internal as BI
     43 import qualified Data.ByteString.Unsafe as BU
     44 import Control.Monad.Primitive (PrimMonad, PrimState)
     45 import Data.Foldable (for_)
     46 import qualified Data.Primitive.PrimArray as PA
     47 import Foreign.ForeignPtr
     48 import GHC.Exts
     49 import GHC.Word
     50 
     51 -- utils ----------------------------------------------------------------------
     52 
     53 -- keystroke saver
     54 fi :: (Integral a, Num b) => a -> b
     55 fi = fromIntegral
     56 {-# INLINE fi #-}
     57 
     58 -- parse strict ByteString in LE order to Word32 (verbatim from
     59 -- Data.Binary)
     60 unsafe_word32le :: BS.ByteString -> Word32
     61 unsafe_word32le s =
     62   (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
     63   (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|.
     64   (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL`  8) .|.
     65   (fi (s `BU.unsafeIndex` 0))
     66 {-# INLINE unsafe_word32le #-}
     67 
     68 data WSPair = WSPair
     69   {-# UNPACK #-} !Word32
     70   {-# UNPACK #-} !BS.ByteString
     71 
     72 -- variant of Data.ByteString.splitAt that behaves like an incremental
     73 -- Word32 parser
     74 unsafe_parseWsPair :: BS.ByteString -> WSPair
     75 unsafe_parseWsPair (BI.BS x l) =
     76   WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
     77 {-# INLINE unsafe_parseWsPair #-}
     78 
     79 -- chacha quarter round -------------------------------------------------------
     80 
     81 -- RFC8439 2.2
     82 _quarter
     83   :: PrimMonad m
     84   => ChaCha (PrimState m)
     85   -> Int
     86   -> Int
     87   -> Int
     88   -> Int
     89   -> m ()
     90 _quarter (ChaCha m) i0 i1 i2 i3 = do
     91   !(W32# a) <- PA.readPrimArray m i0
     92   !(W32# b) <- PA.readPrimArray m i1
     93   !(W32# c) <- PA.readPrimArray m i2
     94   !(W32# d) <- PA.readPrimArray m i3
     95 
     96   let !(# a1, b1, c1, d1 #) = quarter# a b c d
     97 
     98   PA.writePrimArray m i0 (W32# a1)
     99   PA.writePrimArray m i1 (W32# b1)
    100   PA.writePrimArray m i2 (W32# c1)
    101   PA.writePrimArray m i3 (W32# d1)
    102 {-# INLINEABLE _quarter #-}
    103 
    104 _quarter_pure
    105   :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
    106 _quarter_pure (W32# a) (W32# b) (W32# c) (W32# d) =
    107   let !(# a', b', c', d' #) = quarter# a b c d
    108   in  (W32# a', W32# b', W32# c', W32# d')
    109 {-# INLINE _quarter_pure #-}
    110 
    111 -- RFC8439 2.1
    112 quarter#
    113   :: Word32# -> Word32# -> Word32# -> Word32#
    114   -> (# Word32#, Word32#, Word32#, Word32# #)
    115 quarter# a b c d =
    116   let a0 = plusWord32# a b
    117       d0 = xorWord32# d a0
    118       d1 = rotateL# d0 16#
    119 
    120       c0 = plusWord32# c d1
    121       b0 = xorWord32# b c0
    122       b1 = rotateL# b0 12#
    123 
    124       a1 = plusWord32# a0 b1
    125       d2 = xorWord32# d1 a1
    126       d3 = rotateL# d2 8#
    127 
    128       c1 = plusWord32# c0 d3
    129       b2 = xorWord32# b1 c1
    130       b3 = rotateL# b2 7#
    131 
    132   in  (# a1, b3, c1, d3 #)
    133 {-# INLINE quarter# #-}
    134 
    135 rotateL# :: Word32# -> Int# -> Word32#
    136 rotateL# w i
    137   | isTrue# (i ==# 0#) = w
    138   | otherwise = wordToWord32# (
    139             ((word32ToWord# w) `uncheckedShiftL#` i)
    140       `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i)))
    141 {-# INLINE rotateL# #-}
    142 
    143 -- key and nonce parsing ------------------------------------------------------
    144 
    145 data Key = Key {
    146     k0 :: {-# UNPACK #-} !Word32
    147   , k1 :: {-# UNPACK #-} !Word32
    148   , k2 :: {-# UNPACK #-} !Word32
    149   , k3 :: {-# UNPACK #-} !Word32
    150   , k4 :: {-# UNPACK #-} !Word32
    151   , k5 :: {-# UNPACK #-} !Word32
    152   , k6 :: {-# UNPACK #-} !Word32
    153   , k7 :: {-# UNPACK #-} !Word32
    154   }
    155   deriving Show
    156 
    157 -- parse strict 256-bit bytestring (length unchecked) to key
    158 _parse_key :: BS.ByteString -> Key
    159 _parse_key bs =
    160   let !(WSPair k0 t0) = unsafe_parseWsPair bs
    161       !(WSPair k1 t1) = unsafe_parseWsPair t0
    162       !(WSPair k2 t2) = unsafe_parseWsPair t1
    163       !(WSPair k3 t3) = unsafe_parseWsPair t2
    164       !(WSPair k4 t4) = unsafe_parseWsPair t3
    165       !(WSPair k5 t5) = unsafe_parseWsPair t4
    166       !(WSPair k6 t6) = unsafe_parseWsPair t5
    167       !(WSPair k7 t7) = unsafe_parseWsPair t6
    168   in  if   BS.null t7
    169       then Key {..}
    170       else error "ppad-chacha (_parse_key): internal error, bytes remaining"
    171 
    172 data Nonce = Nonce {
    173     n0 :: {-# UNPACK #-} !Word32
    174   , n1 :: {-# UNPACK #-} !Word32
    175   , n2 :: {-# UNPACK #-} !Word32
    176   }
    177   deriving Show
    178 
    179 -- parse strict 96-bit bytestring (length unchecked) to nonce
    180 _parse_nonce :: BS.ByteString -> Nonce
    181 _parse_nonce bs =
    182   let !(WSPair n0 t0) = unsafe_parseWsPair bs
    183       !(WSPair n1 t1) = unsafe_parseWsPair t0
    184       !(WSPair n2 t2) = unsafe_parseWsPair t1
    185   in  if   BS.null t2
    186       then Nonce {..}
    187       else error "ppad-chacha (_parse_nonce): internal error, bytes remaining"
    188 
    189 -- chacha20 block function ----------------------------------------------------
    190 
    191 newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
    192   deriving Eq
    193 
    194 _chacha
    195   :: PrimMonad m
    196   => Key
    197   -> Word32
    198   -> Nonce
    199   -> m (ChaCha (PrimState m))
    200 _chacha key counter nonce = do
    201   state <- _chacha_alloc
    202   _chacha_set state key counter nonce
    203   pure state
    204 
    205 -- allocate a new chacha state
    206 _chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
    207 _chacha_alloc = fmap ChaCha (PA.newPrimArray 16)
    208 {-# INLINE _chacha_alloc #-}
    209 
    210 -- set the values of a chacha state
    211 _chacha_set
    212   :: PrimMonad m
    213   => ChaCha (PrimState m)
    214   -> Key
    215   -> Word32
    216   -> Nonce
    217   -> m ()
    218 _chacha_set (ChaCha arr) Key {..} counter Nonce {..}= do
    219   PA.writePrimArray arr 00 0x61707865
    220   PA.writePrimArray arr 01 0x3320646e
    221   PA.writePrimArray arr 02 0x79622d32
    222   PA.writePrimArray arr 03 0x6b206574
    223   PA.writePrimArray arr 04 k0
    224   PA.writePrimArray arr 05 k1
    225   PA.writePrimArray arr 06 k2
    226   PA.writePrimArray arr 07 k3
    227   PA.writePrimArray arr 08 k4
    228   PA.writePrimArray arr 09 k5
    229   PA.writePrimArray arr 10 k6
    230   PA.writePrimArray arr 11 k7
    231   PA.writePrimArray arr 12 counter
    232   PA.writePrimArray arr 13 n0
    233   PA.writePrimArray arr 14 n1
    234   PA.writePrimArray arr 15 n2
    235 {-# INLINEABLE _chacha_set #-}
    236 
    237 _chacha_counter
    238   :: PrimMonad m
    239   => ChaCha (PrimState m)
    240   -> Word32
    241   -> m ()
    242 _chacha_counter (ChaCha arr) counter =
    243   PA.writePrimArray arr 12 counter
    244 
    245 -- two full rounds (eight quarter rounds)
    246 _rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
    247 _rounds state = do
    248   _quarter state 00 04 08 12
    249   _quarter state 01 05 09 13
    250   _quarter state 02 06 10 14
    251   _quarter state 03 07 11 15
    252   _quarter state 00 05 10 15
    253   _quarter state 01 06 11 12
    254   _quarter state 02 07 08 13
    255   _quarter state 03 04 09 14
    256 {-# INLINEABLE _rounds #-}
    257 
    258 _block
    259   :: PrimMonad m
    260   => ChaCha (PrimState m)
    261   -> Word32
    262   -> m BS.ByteString
    263 _block state@(ChaCha s) counter = do
    264   _chacha_counter state counter
    265   i <- PA.freezePrimArray s 0 16
    266   for_ [1..10 :: Int] (const (_rounds state))
    267   for_ [0..15 :: Int] $ \idx -> do
    268     let iv = PA.indexPrimArray i idx
    269     sv <- PA.readPrimArray s idx
    270     PA.writePrimArray s idx (iv + sv)
    271   serialize state
    272 
    273 -- | Error values.
    274 data Error =
    275     InvalidKey   -- ^ the provided key was not 256 bits long
    276   | InvalidNonce -- ^ the provided nonce was none 96 bits long
    277   deriving (Eq, Show)
    278 
    279 -- RFC8439 2.3
    280 
    281 -- | The ChaCha20 block function. Useful for generating a keystream.
    282 --
    283 --   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
    284 --   key must be exactly 256 bits, and the nonce exactly 96 bits.
    285 block
    286   :: BS.ByteString    -- ^ 256-bit key
    287   -> Word32           -- ^ 32-bit counter
    288   -> BS.ByteString    -- ^ 96-bit nonce
    289   -> Either Error BS.ByteString    -- ^ 512-bit keystream
    290 block key@(BI.PS _ _ kl) counter nonce@(BI.PS _ _ nl)
    291   | kl /= 32 = Left InvalidKey
    292   | nl /= 12 = Left InvalidNonce
    293   | Arm.chacha20_arm_available =
    294       Right (Arm.block key counter nonce)
    295   | otherwise = pure $ runST $ do
    296       let k = _parse_key key
    297           n = _parse_nonce nonce
    298       state@(ChaCha s) <- _chacha k counter n
    299       i <- PA.freezePrimArray s 0 16
    300       for_ [1..10 :: Int] (const (_rounds state))
    301       for_ [0..15 :: Int] $ \idx -> do
    302         let iv = PA.indexPrimArray i idx
    303         sv <- PA.readPrimArray s idx
    304         PA.writePrimArray s idx (iv + sv)
    305       serialize state
    306 
    307 serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
    308 serialize (ChaCha m) = do
    309     w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01
    310     w64_1 <- w64 <$> PA.readPrimArray m 02 <*> PA.readPrimArray m 03
    311     w64_2 <- w64 <$> PA.readPrimArray m 04 <*> PA.readPrimArray m 05
    312     w64_3 <- w64 <$> PA.readPrimArray m 06 <*> PA.readPrimArray m 07
    313     w64_4 <- w64 <$> PA.readPrimArray m 08 <*> PA.readPrimArray m 09
    314     w64_5 <- w64 <$> PA.readPrimArray m 10 <*> PA.readPrimArray m 11
    315     w64_6 <- w64 <$> PA.readPrimArray m 12 <*> PA.readPrimArray m 13
    316     w64_7 <- w64 <$> PA.readPrimArray m 14 <*> PA.readPrimArray m 15
    317     pure . BS.toStrict . BSB.toLazyByteString . mconcat $
    318       [w64_0, w64_1, w64_2, w64_3, w64_4, w64_5, w64_6, w64_7]
    319   where
    320     w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32))
    321 
    322 -- chacha20 encryption --------------------------------------------------------
    323 
    324 -- RFC8439 2.4
    325 
    326 -- | The ChaCha20 stream cipher. Generates a keystream and then XOR's
    327 --   the supplied input with it; use it both to encrypt plaintext and
    328 --   decrypt ciphertext.
    329 --
    330 --   Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
    331 --   key must be exactly 256 bits, and the nonce exactly 96 bits.
    332 --
    333 --   >>> let key = "don't tell anyone my secret key!"
    334 --   >>> let non = "or my nonce!"
    335 --   >>> let cip = cipher key 1 non "but you can share the plaintext"
    336 --   >>> cip
    337 --   "\192*c\248A\204\211n\130y8\197\146k\245\178Y\197=\180_\223\138\146:^\206\&0\v[\201"
    338 --   >>> cipher key 1 non cip
    339 --   Right "but you can share the plaintext"
    340 cipher
    341   :: BS.ByteString    -- ^ 256-bit key
    342   -> Word32           -- ^ 32-bit counter
    343   -> BS.ByteString    -- ^ 96-bit nonce
    344   -> BS.ByteString    -- ^ arbitrary-length plaintext
    345   -> Either Error BS.ByteString    -- ^ ciphertext
    346 cipher raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
    347   | kl /= 32 = Left InvalidKey
    348   | nl /= 12 = Left InvalidNonce
    349   | Arm.chacha20_arm_available =
    350       Right (Arm.cipher raw_key counter raw_nonce plaintext)
    351   | otherwise = pure $ runST $ do
    352       let key = _parse_key raw_key
    353           non = _parse_nonce raw_nonce
    354       _cipher key counter non plaintext
    355 
    356 _cipher
    357   :: PrimMonad m
    358   => Key
    359   -> Word32
    360   -> Nonce
    361   -> BS.ByteString
    362   -> m BS.ByteString
    363 _cipher key counter nonce plaintext = do
    364   ChaCha initial <- _chacha key counter nonce
    365   state@(ChaCha s) <- _chacha_alloc
    366 
    367   let loop acc !j bs = case BS.splitAt 64 bs of
    368         (chunk@(BI.PS _ _ l), etc@(BI.PS _ _ le))
    369           | l == 0 && le == 0 -> pure $
    370               BS.toStrict (BSB.toLazyByteString acc)
    371           | otherwise -> do
    372               PA.copyMutablePrimArray s 0 initial 0 16
    373               stream <- _block state j
    374               let cip = BS.packZipWith (.^.) chunk stream
    375               loop (acc <> BSB.byteString cip) (j + 1) etc
    376 
    377   loop mempty counter plaintext
    378 {-# INLINE _cipher #-}
    379