ChaCha20.hs (11109B)
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module: Crypto.Cipher.ChaCha20
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure ChaCha20 implementation, as specified by
-- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).

module Crypto.Cipher.ChaCha20 (
    -- * ChaCha20 stream cipher
    cipher

    -- * ChaCha20 block function
  , block

    -- testing
  , ChaCha(..)
  , _chacha
  , _parse_key
  , _parse_nonce
  , _quarter
  , _quarter_pure
  , _rounds
  ) where

import Control.Monad.ST
import qualified Data.Bits as B
import Data.Bits ((.|.), (.<<.), (.^.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Foldable (for_)
import qualified Data.Primitive.PrimArray as PA
import Foreign.ForeignPtr
import GHC.Exts
import GHC.Word

-- utils ----------------------------------------------------------------------

-- keystroke saver
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- parse strict ByteString in LE order to Word32 (verbatim from
-- Data.Binary)
--
-- bytes 0-3 are read without bounds checks; the caller must guarantee
-- the input holds at least four bytes.
unsafe_word32le :: BS.ByteString -> Word32
unsafe_word32le s =
  (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
  (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|.
  (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL`  8) .|.
  (fi (s `BU.unsafeIndex` 0))
{-# INLINE unsafe_word32le #-}

-- a parsed word paired with the remaining, unparsed input
data WSPair = WSPair
  {-# UNPACK #-} !Word32
  {-# UNPACK #-} !BS.ByteString

-- variant of Data.ByteString.splitAt that behaves like an incremental
-- Word32 parser
--
-- no length check is performed: on input shorter than four bytes this
-- reads out of bounds and yields a remainder whose length (l - 4) is
-- negative.
unsafe_parseWsPair :: BS.ByteString -> WSPair
unsafe_parseWsPair (BI.BS x l) =
  WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
{-# INLINE unsafe_parseWsPair #-}

-- chacha quarter round -------------------------------------------------------

-- RFC8439 2.2
--
-- apply the quarter round in place to the state words at indices i0, i1,
-- i2 and i3.
_quarter
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Int
  -> Int
  -> Int
  -> Int
  -> m ()
_quarter (ChaCha m) i0 i1 i2 i3 = do
  !(W32# a) <- PA.readPrimArray m i0
  !(W32# b) <- PA.readPrimArray m i1
  !(W32# c) <- PA.readPrimArray m i2
  !(W32# d) <- PA.readPrimArray m i3

  let !(# a1, b1, c1, d1 #) = quarter# a b c d

  PA.writePrimArray m i0 (W32# a1)
  PA.writePrimArray m i1 (W32# b1)
  PA.writePrimArray m i2 (W32# c1)
  PA.writePrimArray m i3 (W32# d1)
{-# INLINEABLE _quarter #-}

-- pure quarter round on four boxed words (RFC8439 2.1), exposed for
-- testing
_quarter_pure
  :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
_quarter_pure (W32# a) (W32# b) (W32# c) (W32# d) =
  let !(# a', b', c', d' #) = quarter# a b c d
  in  (W32# a', W32# b', W32# c', W32# d')
{-# INLINE _quarter_pure #-}

-- RFC8439 2.1
--
-- the ChaCha quarter round on unboxed words:
--
--   a += b; d ^= a; d <<<= 16;
--   c += d; b ^= c; b <<<= 12;
--   a += b; d ^= a; d <<<= 8;
--   c += d; b ^= c; b <<<= 7;
quarter#
  :: Word32# -> Word32# -> Word32# -> Word32#
  -> (# Word32#, Word32#, Word32#, Word32# #)
quarter# a b c d =
  let a0 = plusWord32# a b
      d0 = xorWord32# d a0
      d1 = rotateL# d0 16#

      c0 = plusWord32# c d1
      b0 = xorWord32# b c0
      b1 = rotateL# b0 12#

      a1 = plusWord32# a0 b1
      d2 = xorWord32# d1 a1
      d3 = rotateL# d2 8#

      c1 = plusWord32# c0 d3
      b2 = xorWord32# b1 c1
      b3 = rotateL# b2 7#

  in  (# a1, b3, c1, d3 #)
{-# INLINE quarter# #-}

-- rotate a 32-bit word left by i bits
--
-- the word is widened to Word# so the bits shifted out of the top can be
-- or'd back in at the bottom before narrowing back to 32 bits. a rotation
-- by zero is special-cased, avoiding a right shift by the full 32-bit
-- width. the rotations used in this module are 16, 12, 8 and 7.
rotateL# :: Word32# -> Int# -> Word32#
rotateL# w i
  | isTrue# (i ==# 0#) = w
  | otherwise = wordToWord32# (
          ((word32ToWord# w) `uncheckedShiftL#` i)
    `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i)))
{-# INLINE rotateL# #-}

-- key and nonce parsing ------------------------------------------------------

-- a 256-bit key as eight words, parsed little-endian
data Key = Key {
    k0 :: {-# UNPACK #-} !Word32
  , k1 :: {-# UNPACK #-} !Word32
  , k2 :: {-# UNPACK #-} !Word32
  , k3 :: {-# UNPACK #-} !Word32
  , k4 :: {-# UNPACK #-} !Word32
  , k5 :: {-# UNPACK #-} !Word32
  , k6 :: {-# UNPACK #-} !Word32
  , k7 :: {-# UNPACK #-} !Word32
  }
  deriving (Eq, Show)

-- parse a strict 256-bit bytestring to a key
--
-- the individual word reads are unchecked, so the length is validated
-- here: fewer than 32 bytes would read past the end of the buffer, while
-- more than 32 leaves bytes remaining after the eighth word.
_parse_key :: BS.ByteString -> Key
_parse_key bs
  | BS.length bs < 32 = error "ppad-chacha (_parse_key): insufficient bytes"
  | otherwise =
      let !(WSPair k0 t0) = unsafe_parseWsPair bs
          !(WSPair k1 t1) = unsafe_parseWsPair t0
          !(WSPair k2 t2) = unsafe_parseWsPair t1
          !(WSPair k3 t3) = unsafe_parseWsPair t2
          !(WSPair k4 t4) = unsafe_parseWsPair t3
          !(WSPair k5 t5) = unsafe_parseWsPair t4
          !(WSPair k6 t6) = unsafe_parseWsPair t5
          !(WSPair k7 t7) = unsafe_parseWsPair t6
      in  if   BS.null t7
          then Key {..}
          else error "ppad-chacha (_parse_key): bytes remaining"

-- a 96-bit nonce as three words, parsed little-endian
data Nonce = Nonce {
    n0 :: {-# UNPACK #-} !Word32
  , n1 :: {-# UNPACK #-} !Word32
  , n2 :: {-# UNPACK #-} !Word32
  }
  deriving (Eq, Show)

-- parse a strict 96-bit bytestring to a nonce
--
-- as with _parse_key, the length is validated before the unchecked word
-- reads: fewer than 12 bytes would read past the end of the buffer.
_parse_nonce :: BS.ByteString -> Nonce
_parse_nonce bs
  | BS.length bs < 12 = error "ppad-chacha (_parse_nonce): insufficient bytes"
  | otherwise =
      let !(WSPair n0 t0) = unsafe_parseWsPair bs
          !(WSPair n1 t1) = unsafe_parseWsPair t0
          !(WSPair n2 t2) = unsafe_parseWsPair t1
      in  if   BS.null t2
          then Nonce {..}
          else error "ppad-chacha (_parse_nonce): bytes remaining"

-- chacha20 block function ----------------------------------------------------

-- mutable ChaCha20 state; _chacha_alloc sizes it at 16 words
newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
  deriving Eq

-- allocate a fresh state and initialize it from the key, block counter
-- and nonce
_chacha
  :: PrimMonad m
  => Key
  -> Word32
  -> Nonce
  -> m (ChaCha (PrimState m))
_chacha key counter nonce = do
  state <- _chacha_alloc
  _chacha_set state key counter nonce
  pure state

-- allocate a new (uninitialized) chacha state
_chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
_chacha_alloc = fmap ChaCha (PA.newPrimArray 16)
{-# INLINE _chacha_alloc #-}

-- set the values of a chacha state (RFC8439 2.3)
--
-- words 0-3 are the constants spelling "expand 32-byte k" in
-- little-endian ASCII, 4-11 the key, 12 the block counter and 13-15 the
-- nonce.
_chacha_set
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Key
  -> Word32
  -> Nonce
  -> m ()
_chacha_set (ChaCha arr) Key {..} counter Nonce {..} = do
  PA.writePrimArray arr 00 0x61707865
  PA.writePrimArray arr 01 0x3320646e
  PA.writePrimArray arr 02 0x79622d32
  PA.writePrimArray arr 03 0x6b206574
  PA.writePrimArray arr 04 k0
  PA.writePrimArray arr 05 k1
  PA.writePrimArray arr 06 k2
  PA.writePrimArray arr 07 k3
  PA.writePrimArray arr 08 k4
  PA.writePrimArray arr 09 k5
  PA.writePrimArray arr 10 k6
  PA.writePrimArray arr 11 k7
  PA.writePrimArray arr 12 counter
  PA.writePrimArray arr 13 n0
  PA.writePrimArray arr 14 n1
  PA.writePrimArray arr 15 n2
{-# INLINEABLE _chacha_set #-}

-- overwrite only the block counter (word 12) of a state
_chacha_counter
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Word32
  -> m ()
_chacha_counter (ChaCha arr) counter =
  PA.writePrimArray arr 12 counter

-- two rounds (eight quarter rounds): a column round on indices
-- (0,4,8,12)..(3,7,11,15), then a diagonal round on (0,5,10,15)..
-- (3,4,9,14)
_rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
_rounds state = do
  _quarter state 00 04 08 12
  _quarter state 01 05 09 13
  _quarter state 02 06 10 14
  _quarter state 03 07 11 15
  _quarter state 00 05 10 15
  _quarter state 01 06 11 12
  _quarter state 02 07 08 13
  _quarter state 03 04 09 14
{-# INLINEABLE _rounds #-}

-- compute one 64-byte keystream block for the given counter
--
-- writes the counter into the state, snapshots it, applies 20 rounds
-- (ten calls to _rounds), adds the snapshot back in word-wise and
-- serializes the result. the state is left holding the final words, so
-- callers that reuse it must reinitialize it first.
_block
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Word32
  -> m BS.ByteString
_block state@(ChaCha s) counter = do
  _chacha_counter state counter
  i <- PA.freezePrimArray s 0 16
  for_ [1..10 :: Int] (const (_rounds state))
  for_ [0..15 :: Int] $ \idx -> do
    let iv = PA.indexPrimArray i idx
    sv <- PA.readPrimArray s idx
    PA.writePrimArray s idx (iv + sv)
  serialize state

-- RFC8439 2.3

-- | The ChaCha20 block function. Useful for generating a keystream.
--
-- Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
-- key must be exactly 256 bits, and the nonce exactly 96 bits; an
-- 'error' is thrown otherwise.
block
  :: BS.ByteString -- ^ 256-bit key
  -> Word32        -- ^ 32-bit counter
  -> BS.ByteString -- ^ 96-bit nonce
  -> BS.ByteString -- ^ 512-bit keystream
block key@(BI.PS _ _ kl) counter nonce@(BI.PS _ _ nl)
  | kl /= 32 = error "ppad-chacha (block): invalid key"
  | nl /= 12 = error "ppad-chacha (block): invalid nonce"
  | otherwise = runST $ do
      let k = _parse_key key
          n = _parse_nonce nonce
      state <- _chacha k counter n
      -- the block computation is shared with the stream cipher
      _block state counter

-- serialize the 16-word state to 64 bytes, each word little-endian.
-- adjacent words are packed into a Word64 (low word first), so writing
-- that Word64 little-endian emits both words in sequence.
serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
serialize (ChaCha m) = do
  w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01
  w64_1 <- w64 <$> PA.readPrimArray m 02 <*> PA.readPrimArray m 03
  w64_2 <- w64 <$> PA.readPrimArray m 04 <*> PA.readPrimArray m 05
  w64_3 <- w64 <$> PA.readPrimArray m 06 <*> PA.readPrimArray m 07
  w64_4 <- w64 <$> PA.readPrimArray m 08 <*> PA.readPrimArray m 09
  w64_5 <- w64 <$> PA.readPrimArray m 10 <*> PA.readPrimArray m 11
  w64_6 <- w64 <$> PA.readPrimArray m 12 <*> PA.readPrimArray m 13
  w64_7 <- w64 <$> PA.readPrimArray m 14 <*> PA.readPrimArray m 15
  pure . BS.toStrict . BSB.toLazyByteString . mconcat $
    [w64_0, w64_1, w64_2, w64_3, w64_4, w64_5, w64_6, w64_7]
  where
    w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32))

-- chacha20 encryption --------------------------------------------------------

-- RFC8439 2.4

-- | The ChaCha20 stream cipher. Generates a keystream and then XORs
-- the supplied input with it; use it both to encrypt plaintext and
-- decrypt ciphertext.
--
-- Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
-- key must be exactly 256 bits, and the nonce exactly 96 bits. An
-- 'error' is thrown if either has the wrong length, or if the input is
-- long enough that the 32-bit block counter would have to wrap past its
-- maximum value (which would repeat keystream).
--
-- >>> let key = "don't tell anyone my secret key!"
-- >>> let non = "or my nonce!"
-- >>> let cip = cipher key 1 non "but you can share the plaintext"
-- >>> cip
-- "\192*c\248A\204\211n\130y8\197\146k\245\178Y\197=\180_\223\138\146:^\206\&0\v[\201"
-- >>> cipher key 1 non cip
-- "but you can share the plaintext"
cipher
  :: BS.ByteString -- ^ 256-bit key
  -> Word32        -- ^ 32-bit counter
  -> BS.ByteString -- ^ 96-bit nonce
  -> BS.ByteString -- ^ arbitrary-length plaintext
  -> BS.ByteString -- ^ ciphertext
cipher raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
  | kl /= 32 = error "ppad-chacha (cipher): invalid key"
  | nl /= 12 = error "ppad-chacha (cipher): invalid nonce"
  | otherwise = runST $ do
      let key = _parse_key raw_key
          non = _parse_nonce raw_nonce
      _cipher key counter non plaintext

-- XOR successive 64-byte chunks of the input with the keystream blocks
-- for counter, counter + 1, ... . the final chunk may be shorter than 64
-- bytes; packZipWith stops at the shorter of its two inputs, so only the
-- needed prefix of that keystream block is used.
_cipher
  :: PrimMonad m
  => Key
  -> Word32
  -> Nonce
  -> BS.ByteString
  -> m BS.ByteString
_cipher key counter nonce plaintext = do
  -- 'initial' keeps the freshly-initialized state; 's' is a scratch copy
  -- that _block overwrites on every call
  ChaCha initial <- _chacha key counter nonce
  state@(ChaCha s) <- _chacha_alloc

  let loop acc !j bs
        | BS.null bs =
            pure (BS.toStrict (BSB.toLazyByteString acc))
        | otherwise = do
            let (chunk, etc) = BS.splitAt 64 bs
            -- the block counter is 32 bits; advancing past maxBound would
            -- wrap to zero and repeat keystream already produced for lower
            -- counter values under this key and nonce, so refuse instead
            if j == maxBound && not (BS.null etc)
              then error "ppad-chacha (cipher): counter overflow"
              else do
                PA.copyMutablePrimArray s 0 initial 0 16
                stream <- _block state j
                let cip = BS.packZipWith (.^.) chunk stream
                loop (acc <> BSB.byteString cip) (j + 1) etc

  loop mempty counter plaintext
{-# INLINE _cipher #-}