chacha

The ChaCha20 stream cipher (docs.ppad.tech/chacha).
git clone git://git.ppad.tech/chacha.git
Log | Files | Refs | LICENSE

commit ba5dcfccece5eca25950893371a98ae09f1a2c45
parent e1a7822844ca174fee813c76716d16c6511a3bb0
Author: Jared Tobin <jared@jtobin.io>
Date:   Thu,  6 Mar 2025 18:21:06 +0400

lib: less writes

Diffstat:
Mlib/Crypto/Cipher/ChaCha.hs | 92+++++++++++++++++++++++++++++++++++++++++++++----------------------------------
Mtest/Main.hs | 50++++++++++++++++++++++++++------------------------
2 files changed, 79 insertions(+), 63 deletions(-)

diff --git a/lib/Crypto/Cipher/ChaCha.hs b/lib/Crypto/Cipher/ChaCha.hs @@ -27,9 +27,6 @@ fi = fromIntegral -- parse strict ByteString in LE order to Word32 (verbatim from -- Data.Binary) --- --- invariant: --- the input bytestring is at least 32 bits in length unsafe_word32le :: BS.ByteString -> Word32 unsafe_word32le s = (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|. @@ -44,9 +41,6 @@ data WSPair = WSPair -- variant of Data.ByteString.splitAt that behaves like an incremental -- Word32 parser --- --- invariant: --- the input bytestring is at least 32 bits in length unsafe_parseWsPair :: BS.ByteString -> WSPair unsafe_parseWsPair (BI.BS x l) = WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4)) @@ -167,11 +161,24 @@ parse_nonce bs = newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32) deriving Eq +chacha + :: PrimMonad m + => Key + -> Word32 + -> Nonce + -> m (ChaCha (PrimState m)) +chacha key counter nonce = do + state <- _chacha_alloc + _chacha_set state key counter nonce + pure state + -- allocate a new chacha state _chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m)) _chacha_alloc = fmap ChaCha (PA.newPrimArray 16) {-# INLINE _chacha_alloc #-} +-- XX can be optimised more (only change counter) + -- set the values of a chacha state _chacha_set :: PrimMonad m @@ -199,6 +206,14 @@ _chacha_set (ChaCha arr) Key {..} counter Nonce {..}= do PA.writePrimArray arr 15 n2 {-# INLINEABLE _chacha_set #-} +_chacha_counter + :: PrimMonad m + => ChaCha (PrimState m) + -> Word32 + -> m () +_chacha_counter (ChaCha arr) counter = + PA.writePrimArray arr 12 counter + -- two full rounds (eight quarter rounds) rounds :: PrimMonad m => ChaCha (PrimState m) -> m () rounds state = do @@ -212,6 +227,21 @@ rounds state = do quarter state 03 04 09 14 {-# INLINEABLE rounds #-} +_block + :: PrimMonad m + => ChaCha (PrimState m) + -> Word32 + -> m BS.ByteString +_block state@(ChaCha s) counter = do + _chacha_counter state counter + i <- PA.freezePrimArray s 0 16 + for_ [1..10 :: Int] (const (rounds state)) + for_ [0..15 :: Int] $ \idx -> do + let iv = PA.indexPrimArray i idx + sv <- PA.readPrimArray s idx + PA.writePrimArray s idx (iv + sv) + serialize state + serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString serialize (ChaCha m) = do w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01 @@ -227,24 +257,22 @@ serialize (ChaCha m) = do where w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32)) -_chacha20_block +-- chacha20 encryption -------------------------------------------------------- + +encrypt :: PrimMonad m - => ChaCha (PrimState m) - -> Key + => BS.ByteString -> Word32 - -> Nonce + -> BS.ByteString + -> BS.ByteString -> m BS.ByteString -_chacha20_block state@(ChaCha s) key counter nonce = do - _chacha_set state key counter nonce - i <- PA.freezePrimArray s 0 16 - for_ [1..10 :: Int] (const (rounds state)) - for_ [0..15 :: Int] $ \idx -> do - let iv = PA.indexPrimArray i idx - sv <- PA.readPrimArray s idx - PA.writePrimArray s idx (iv + sv) - serialize state - --- chacha20 encryption -------------------------------------------------------- +encrypt raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext + | kl /= 32 = error "ppad-chacha (encrypt): invalid key" + | nl /= 12 = error "ppad-chacha (encrypt): invalid nonce" + | otherwise = do + let key = parse_key raw_key + non = parse_nonce raw_nonce + _encrypt key counter non plaintext _encrypt :: PrimMonad m @@ -254,33 +282,19 @@ _encrypt -> BS.ByteString -> m BS.ByteString _encrypt key counter nonce plaintext = do - state <- _chacha_alloc - _chacha_set state key counter nonce + ChaCha initial <- chacha key counter nonce + state@(ChaCha s) <- _chacha_alloc let loop acc !j bs = case BS.splitAt 64 bs of (chunk@(BI.PS _ _ l), etc) | l == 0 && BS.length etc == 0 -> pure $ BS.toStrict (BSB.toLazyByteString acc) | otherwise -> do - stream <- _chacha20_block state key j nonce + PA.copyMutablePrimArray s 0 initial 0 16 + stream <- _block state j let cip = BS.packZipWith (.^.) chunk stream loop (acc <> BSB.byteString cip) (j + 1) etc loop mempty counter plaintext {-# INLINE _encrypt #-} -encrypt - :: PrimMonad m - => BS.ByteString - -> Word32 - -> BS.ByteString - -> BS.ByteString - -> m BS.ByteString -encrypt raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext - | kl /= 32 = error "ppad-chacha (encrypt): invalid key" - | nl /= 12 = error "ppad-chacha (encrypt): invalid nonce" - | otherwise = do - let key = parse_key raw_key - non = parse_nonce raw_nonce - _encrypt key counter non plaintext - diff --git a/test/Main.hs b/test/Main.hs @@ -9,6 +9,7 @@ import qualified Crypto.Cipher.ChaCha as ChaCha import qualified Data.ByteString as BS import qualified Data.ByteString.Base16 as B16 import Data.Foldable (for_) +import Data.Maybe (fromJust) import qualified Data.Primitive.PrimArray as PA import Data.Word (Word32) import Test.Tasty @@ -20,8 +21,8 @@ main = defaultMain $ testGroup "ppad-chacha" [ , quarter_fullstate , chacha20_block_init , chacha20_rounds - , chacha20_block - , chacha20_encrypt + -- , chacha20_block + , encrypt ] quarter :: TestTree @@ -56,20 +57,17 @@ quarter_fullstate = H.testCase "quarter round (full chacha state)" $ do H.assertEqual mempty e o block_key :: BS.ByteString -block_key = - case B16.decode "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" of - Nothing -> error "bang" - Just k -> k +block_key = fromJust $ + B16.decode "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" block_non :: BS.ByteString -block_non = - case B16.decode "000000090000004a00000000" of - Nothing -> error "bang" - Just n -> n +block_non = fromJust $ B16.decode "000000090000004a00000000" chacha20_block_init :: TestTree chacha20_block_init = H.testCase "chacha20 state init" $ do - ChaCha.ChaCha foo <- ChaCha.chacha block_key 1 block_non + let key = ChaCha.parse_key block_key + non = ChaCha.parse_nonce block_non + ChaCha.ChaCha foo <- ChaCha.chacha key 1 non state <- PA.freezePrimArray foo 0 16 let ref = PA.primArrayFromList [ 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574 @@ -81,7 +79,9 @@ chacha20_block_init = H.testCase "chacha20 state init" $ do chacha20_rounds :: TestTree chacha20_rounds = H.testCase "chacha20 20 rounds" $ do - state@(ChaCha.ChaCha s) <- ChaCha.chacha block_key 1 block_non + let key = ChaCha.parse_key block_key + non = ChaCha.parse_nonce block_non + state@(ChaCha.ChaCha s) <- ChaCha.chacha key 1 non for_ [1..10 :: Int] (const (ChaCha.rounds state)) out <- PA.freezePrimArray s 0 16 @@ -95,15 +95,17 @@ chacha20_rounds = H.testCase "chacha20 20 rounds" $ do H.assertEqual mempty ref out -chacha20_block :: TestTree -chacha20_block = H.testCase "chacha20 block function" $ do - o <- ChaCha.chacha20_block block_key 1 block_non - let raw_exp = "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4ed2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e" - e = case B16.decode raw_exp of - Nothing -> error "bang" - Just x -> x - - H.assertEqual mempty e o +-- chacha20_block :: TestTree +-- chacha20_block = H.testCase "chacha20 block function" $ do +-- let key = ChaCha.parse_key block_key +-- non = ChaCha.parse_nonce block_non +-- o <- ChaCha.chacha20_block key 1 non +-- let raw_exp = "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4ed2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e" +-- e = case B16.decode raw_exp of +-- Nothing -> error "bang" +-- Just x -> x +-- +-- H.assertEqual mempty e o crypt_plain :: BS.ByteString crypt_plain = case B16.decode "4c616469657320616e642047656e746c656d656e206f662074686520636c617373206f66202739393a204966204920636f756c64206f6666657220796f75206f6e6c79206f6e652074697020666f7220746865206675747572652c2073756e73637265656e20776f756c642062652069742e" of @@ -120,8 +122,8 @@ crypt_non = case B16.decode "000000000000004a00000000" of Nothing -> error "bang" Just x -> x -chacha20_encrypt :: TestTree -chacha20_encrypt = H.testCase "chacha20 encrypt" $ do - o <- ChaCha.chacha20_encrypt block_key 1 crypt_non crypt_plain +encrypt :: TestTree +encrypt = H.testCase "more efficient encrypt" $ do + o <- ChaCha.encrypt block_key 1 crypt_non crypt_plain H.assertEqual mempty crypt_cip o