ChaCha20.hs (11304B)
{-# OPTIONS_HADDOCK prune #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UnboxedTuples #-}

-- |
-- Module: Crypto.Cipher.ChaCha20
-- Copyright: (c) 2025 Jared Tobin
-- License: MIT
-- Maintainer: Jared Tobin <jared@ppad.tech>
--
-- A pure ChaCha20 implementation, as specified by
-- [RFC 8439](https://datatracker.ietf.org/doc/html/rfc8439).

module Crypto.Cipher.ChaCha20 (
    -- * ChaCha20 stream cipher
    cipher

    -- * ChaCha20 block function
  , block

    -- * Error information
  , Error(..)

    -- testing
  , ChaCha(..)
  , _chacha
  , _parse_key
  , _parse_nonce
  , _quarter
  , _quarter_pure
  , _rounds
  ) where

import Control.Monad.ST
import qualified Data.Bits as B
import Data.Bits ((.|.), (.<<.), (.^.))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BSB
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Control.Monad.Primitive (PrimMonad, PrimState)
import Data.Foldable (for_)
import qualified Data.Primitive.PrimArray as PA
import Foreign.ForeignPtr
import GHC.Exts
import GHC.Word

-- utils ----------------------------------------------------------------------

-- keystroke saver
fi :: (Integral a, Num b) => a -> b
fi = fromIntegral
{-# INLINE fi #-}

-- parse strict ByteString in LE order to Word32 (verbatim from
-- Data.Binary). bytes 0-3 are read without bounds checks, so the input
-- must hold at least four bytes.
unsafe_word32le :: BS.ByteString -> Word32
unsafe_word32le s =
  (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
  (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|.
  (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL`  8) .|.
  (fi (s `BU.unsafeIndex` 0))
{-# INLINE unsafe_word32le #-}

-- a decoded Word32 paired with the unconsumed remainder of its input
data WSPair = WSPair
  {-# UNPACK #-} !Word32
  {-# UNPACK #-} !BS.ByteString

-- variant of Data.ByteString.splitAt that behaves like an incremental
-- Word32 parser: the first four bytes are decoded (LE) and the rest is
-- returned by advancing the pointer. no length check is performed.
unsafe_parseWsPair :: BS.ByteString -> WSPair
unsafe_parseWsPair (BI.BS x l) =
  WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
{-# INLINE unsafe_parseWsPair #-}

-- chacha quarter round -------------------------------------------------------

-- RFC8439 2.2
--
-- apply the quarter round in place to the state words at indices
-- i0, i1, i2 and i3.
_quarter
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Int
  -> Int
  -> Int
  -> Int
  -> m ()
_quarter (ChaCha m) i0 i1 i2 i3 = do
  !(W32# a) <- PA.readPrimArray m i0
  !(W32# b) <- PA.readPrimArray m i1
  !(W32# c) <- PA.readPrimArray m i2
  !(W32# d) <- PA.readPrimArray m i3

  let !(# a1, b1, c1, d1 #) = quarter# a b c d

  PA.writePrimArray m i0 (W32# a1)
  PA.writePrimArray m i1 (W32# b1)
  PA.writePrimArray m i2 (W32# c1)
  PA.writePrimArray m i3 (W32# d1)
{-# INLINEABLE _quarter #-}

-- boxed wrapper around quarter#; exported for testing
_quarter_pure
  :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
_quarter_pure (W32# a) (W32# b) (W32# c) (W32# d) =
  let !(# a', b', c', d' #) = quarter# a b c d
  in  (W32# a', W32# b', W32# c', W32# d')
{-# INLINE _quarter_pure #-}

-- RFC8439 2.1
--
-- the ChaCha quarter round on unboxed words, with all additions
-- modulo 2^32:
--
--   a += b; d ^= a; d <<<= 16
--   c += d; b ^= c; b <<<= 12
--   a += b; d ^= a; d <<<= 8
--   c += d; b ^= c; b <<<= 7
quarter#
  :: Word32# -> Word32# -> Word32# -> Word32#
  -> (# Word32#, Word32#, Word32#, Word32# #)
quarter# a b c d =
  let a0 = plusWord32# a b
      d0 = xorWord32# d a0
      d1 = rotateL# d0 16#

      c0 = plusWord32# c d1
      b0 = xorWord32# b c0
      b1 = rotateL# b0 12#

      a1 = plusWord32# a0 b1
      d2 = xorWord32# d1 a1
      d3 = rotateL# d2 8#

      c1 = plusWord32# c0 d3
      b2 = xorWord32# b1 c1
      b3 = rotateL# b2 7#

  in  (# a1, b3, c1, d3 #)
{-# INLINE quarter# #-}

-- rotate a 32-bit word left by i bits (callers here pass 7, 8, 12 or
-- 16). the word is widened to a Word#, shifted both ways, OR'd, and
-- narrowed back to 32 bits. i == 0 is special-cased because it would
-- otherwise need a right shift by the full 32 bits.
rotateL# :: Word32# -> Int# -> Word32#
rotateL# w i
  | isTrue# (i ==# 0#) = w
  | otherwise = wordToWord32# (
      ((word32ToWord# w) `uncheckedShiftL#` i)
      `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i)))
{-# INLINE rotateL# #-}

-- key and nonce parsing ------------------------------------------------------

-- a 256-bit key, as eight words decoded little-endian
data Key = Key {
    k0 :: {-# UNPACK #-} !Word32
  , k1 :: {-# UNPACK #-} !Word32
  , k2 :: {-# UNPACK #-} !Word32
  , k3 :: {-# UNPACK #-} !Word32
  , k4 :: {-# UNPACK #-} !Word32
  , k5 :: {-# UNPACK #-} !Word32
  , k6 :: {-# UNPACK #-} !Word32
  , k7 :: {-# UNPACK #-} !Word32
  }
  deriving (Eq, Show)

-- parse strict 256-bit bytestring (length unchecked) to key.
--
-- callers must validate the length first: input shorter than 32 bytes
-- is read out of bounds, and input longer than 32 bytes hits the
-- 'error' branch below.
_parse_key :: BS.ByteString -> Key
_parse_key bs =
  let !(WSPair k0 t0) = unsafe_parseWsPair bs
      !(WSPair k1 t1) = unsafe_parseWsPair t0
      !(WSPair k2 t2) = unsafe_parseWsPair t1
      !(WSPair k3 t3) = unsafe_parseWsPair t2
      !(WSPair k4 t4) = unsafe_parseWsPair t3
      !(WSPair k5 t5) = unsafe_parseWsPair t4
      !(WSPair k6 t6) = unsafe_parseWsPair t5
      !(WSPair k7 t7) = unsafe_parseWsPair t6
  in  if   BS.null t7
      then Key {..}
      else error "ppad-chacha (_parse_key): internal error, bytes remaining"

-- a 96-bit nonce, as three words decoded little-endian
data Nonce = Nonce {
    n0 :: {-# UNPACK #-} !Word32
  , n1 :: {-# UNPACK #-} !Word32
  , n2 :: {-# UNPACK #-} !Word32
  }
  deriving (Eq, Show)

-- parse strict 96-bit bytestring (length unchecked) to nonce.
--
-- callers must validate the length first: input shorter than 12 bytes
-- is read out of bounds, and input longer than 12 bytes hits the
-- 'error' branch below.
_parse_nonce :: BS.ByteString -> Nonce
_parse_nonce bs =
  let !(WSPair n0 t0) = unsafe_parseWsPair bs
      !(WSPair n1 t1) = unsafe_parseWsPair t0
      !(WSPair n2 t2) = unsafe_parseWsPair t1
  in  if   BS.null t2
      then Nonce {..}
      else error "ppad-chacha (_parse_nonce): internal error, bytes remaining"

-- chacha20 block function ----------------------------------------------------

-- mutable ChaCha state (allocated with 16 words by _chacha_alloc)
newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
  deriving Eq

-- allocate a fresh state initialised from key, counter and nonce
_chacha
  :: PrimMonad m
  => Key
  -> Word32
  -> Nonce
  -> m (ChaCha (PrimState m))
_chacha key counter nonce = do
  state <- _chacha_alloc
  _chacha_set state key counter nonce
  pure state

-- allocate a new (uninitialised) chacha state
_chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
_chacha_alloc = fmap ChaCha (PA.newPrimArray 16)
{-# INLINE _chacha_alloc #-}

-- set the values of a chacha state (RFC8439 2.3):
--
--   words 0-3   constants ("expand 32-byte k", read little-endian)
--   words 4-11  key
--   word  12    block counter
--   words 13-15 nonce
_chacha_set
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Key
  -> Word32
  -> Nonce
  -> m ()
_chacha_set (ChaCha arr) Key {..} counter Nonce {..} = do
  PA.writePrimArray arr 00 0x61707865
  PA.writePrimArray arr 01 0x3320646e
  PA.writePrimArray arr 02 0x79622d32
  PA.writePrimArray arr 03 0x6b206574
  PA.writePrimArray arr 04 k0
  PA.writePrimArray arr 05 k1
  PA.writePrimArray arr 06 k2
  PA.writePrimArray arr 07 k3
  PA.writePrimArray arr 08 k4
  PA.writePrimArray arr 09 k5
  PA.writePrimArray arr 10 k6
  PA.writePrimArray arr 11 k7
  PA.writePrimArray arr 12 counter
  PA.writePrimArray arr 13 n0
  PA.writePrimArray arr 14 n1
  PA.writePrimArray arr 15 n2
{-# INLINEABLE _chacha_set #-}

-- overwrite only the block counter (word 12) of a state
_chacha_counter
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Word32
  -> m ()
_chacha_counter (ChaCha arr) counter =
  PA.writePrimArray arr 12 counter

-- two full rounds (eight quarter rounds): four column rounds followed
-- by four diagonal rounds over the state viewed as a 4x4 matrix
_rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
_rounds state = do
  _quarter state 00 04 08 12
  _quarter state 01 05 09 13
  _quarter state 02 06 10 14
  _quarter state 03 07 11 15
  _quarter state 00 05 10 15
  _quarter state 01 06 11 12
  _quarter state 02 07 08 13
  _quarter state 03 04 09 14
{-# INLINEABLE _rounds #-}

-- set the counter, run the 20-round block function in place, and
-- serialise the result as a 64-byte keystream block.
--
-- the pre-round state is snapshotted because RFC8439 2.3 adds the
-- original input words to the permuted state after the rounds.
_block
  :: PrimMonad m
  => ChaCha (PrimState m)
  -> Word32
  -> m BS.ByteString
_block state@(ChaCha s) counter = do
  _chacha_counter state counter
  i <- PA.freezePrimArray s 0 16
  for_ [1..10 :: Int] (const (_rounds state))
  for_ [0..15 :: Int] $ \idx -> do
    let iv = PA.indexPrimArray i idx
    sv <- PA.readPrimArray s idx
    PA.writePrimArray s idx (iv + sv)
  serialize state
{-# INLINEABLE _block #-}

-- | Error values.
data Error =
    InvalidKey   -- ^ the provided key was not 256 bits long
  | InvalidNonce -- ^ the provided nonce was not 96 bits long
  deriving (Eq, Show)

-- RFC8439 2.3

-- | The ChaCha20 block function. Useful for generating a keystream.
--
-- Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
-- key must be exactly 256 bits, and the nonce exactly 96 bits.
-- Otherwise 'InvalidKey' or 'InvalidNonce' is returned (the key is
-- checked first).
block
  :: BS.ByteString              -- ^ 256-bit key
  -> Word32                     -- ^ 32-bit counter
  -> BS.ByteString              -- ^ 96-bit nonce
  -> Either Error BS.ByteString -- ^ 512-bit keystream
block key@(BI.PS _ _ kl) counter nonce@(BI.PS _ _ nl)
  | kl /= 32 = Left InvalidKey
  | nl /= 12 = Left InvalidNonce
  | otherwise = pure $ runST $ do
      let k = _parse_key key
          n = _parse_nonce nonce
      state <- _chacha k counter n
      _block state counter

-- serialise all 16 state words little-endian into 64 bytes. words are
-- paired into a Word64 (even index in the low half), so emitting each
-- Word64 LE yields the words in index order, each LE.
serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
serialize (ChaCha m) = do
  w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01
  w64_1 <- w64 <$> PA.readPrimArray m 02 <*> PA.readPrimArray m 03
  w64_2 <- w64 <$> PA.readPrimArray m 04 <*> PA.readPrimArray m 05
  w64_3 <- w64 <$> PA.readPrimArray m 06 <*> PA.readPrimArray m 07
  w64_4 <- w64 <$> PA.readPrimArray m 08 <*> PA.readPrimArray m 09
  w64_5 <- w64 <$> PA.readPrimArray m 10 <*> PA.readPrimArray m 11
  w64_6 <- w64 <$> PA.readPrimArray m 12 <*> PA.readPrimArray m 13
  w64_7 <- w64 <$> PA.readPrimArray m 14 <*> PA.readPrimArray m 15
  pure . BS.toStrict . BSB.toLazyByteString . mconcat $
    [w64_0, w64_1, w64_2, w64_3, w64_4, w64_5, w64_6, w64_7]
  where
    w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32))

-- chacha20 encryption --------------------------------------------------------

-- RFC8439 2.4

-- | The ChaCha20 stream cipher. Generates a keystream and then XORs
-- the supplied input with it; use it both to encrypt plaintext and
-- decrypt ciphertext.
--
-- Per [RFC8439](https://datatracker.ietf.org/doc/html/rfc8439), the
-- key must be exactly 256 bits, and the nonce exactly 96 bits.
-- Otherwise 'InvalidKey' or 'InvalidNonce' is returned (the key is
-- checked first).
--
-- >>> :set -XOverloadedStrings
-- >>> let key = "don't tell anyone my secret key!"
-- >>> let non = "or my nonce!"
-- >>> let Right cip = cipher key 1 non "but you can share the plaintext"
-- >>> cip
-- "\192*c\248A\204\211n\130y8\197\146k\245\178Y\197=\180_\223\138\146:^\206\&0\v[\201"
-- >>> cipher key 1 non cip
-- Right "but you can share the plaintext"
cipher
  :: BS.ByteString              -- ^ 256-bit key
  -> Word32                     -- ^ 32-bit counter
  -> BS.ByteString              -- ^ 96-bit nonce
  -> BS.ByteString              -- ^ arbitrary-length plaintext
  -> Either Error BS.ByteString -- ^ ciphertext
cipher raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
  | kl /= 32 = Left InvalidKey
  | nl /= 12 = Left InvalidNonce
  | otherwise = pure $ runST $ do
      let key = _parse_key raw_key
          non = _parse_nonce raw_nonce
      _cipher key counter non plaintext

-- XOR the input with the keystream, 64 bytes at a time. the keyed
-- state is built once; for each chunk it is copied into a scratch state
-- whose counter is set to j (counter, counter + 1, ...). packZipWith
-- stops at the shorter input, so a short final chunk consumes only a
-- prefix of its keystream block.
_cipher
  :: PrimMonad m
  => Key
  -> Word32
  -> Nonce
  -> BS.ByteString
  -> m BS.ByteString
_cipher key counter nonce plaintext = do
  ChaCha initial <- _chacha key counter nonce
  state@(ChaCha s) <- _chacha_alloc

  let loop acc !j bs = case BS.splitAt 64 bs of
        (chunk@(BI.PS _ _ l), etc@(BI.PS _ _ le))
          | l == 0 && le == 0 -> pure $
              BS.toStrict (BSB.toLazyByteString acc)
          | otherwise -> do
              PA.copyMutablePrimArray s 0 initial 0 16
              stream <- _block state j
              let cip = BS.packZipWith (.^.) chunk stream
              loop (acc <> BSB.byteString cip) (j + 1) etc

  loop mempty counter plaintext
{-# INLINE _cipher #-}