commit ba5dcfccece5eca25950893371a98ae09f1a2c45
parent e1a7822844ca174fee813c76716d16c6511a3bb0
Author: Jared Tobin <jared@jtobin.io>
Date: Thu, 6 Mar 2025 18:21:06 +0400
lib: less writes
Diffstat:
2 files changed, 79 insertions(+), 63 deletions(-)
diff --git a/lib/Crypto/Cipher/ChaCha.hs b/lib/Crypto/Cipher/ChaCha.hs
@@ -27,9 +27,6 @@ fi = fromIntegral
-- parse strict ByteString in LE order to Word32 (verbatim from
-- Data.Binary)
---
--- invariant:
--- the input bytestring is at least 32 bits in length
unsafe_word32le :: BS.ByteString -> Word32
unsafe_word32le s =
(fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
@@ -44,9 +41,6 @@ data WSPair = WSPair
-- variant of Data.ByteString.splitAt that behaves like an incremental
-- Word32 parser
---
--- invariant:
--- the input bytestring is at least 32 bits in length
unsafe_parseWsPair :: BS.ByteString -> WSPair
unsafe_parseWsPair (BI.BS x l) =
WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
@@ -167,11 +161,24 @@ parse_nonce bs =
newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
deriving Eq
+chacha
+ :: PrimMonad m
+ => Key
+ -> Word32
+ -> Nonce
+ -> m (ChaCha (PrimState m))
+chacha key counter nonce = do
+ state <- _chacha_alloc
+ _chacha_set state key counter nonce
+ pure state
+
-- allocate a new chacha state
_chacha_alloc :: PrimMonad m => m (ChaCha (PrimState m))
_chacha_alloc = fmap ChaCha (PA.newPrimArray 16)
{-# INLINE _chacha_alloc #-}
+-- XX can be optimised more (only change counter)
+
-- set the values of a chacha state
_chacha_set
:: PrimMonad m
@@ -199,6 +206,14 @@ _chacha_set (ChaCha arr) Key {..} counter Nonce {..}= do
PA.writePrimArray arr 15 n2
{-# INLINEABLE _chacha_set #-}
+_chacha_counter
+ :: PrimMonad m
+ => ChaCha (PrimState m)
+ -> Word32
+ -> m ()
+_chacha_counter (ChaCha arr) counter =
+ PA.writePrimArray arr 12 counter
+
-- two full rounds (eight quarter rounds)
rounds :: PrimMonad m => ChaCha (PrimState m) -> m ()
rounds state = do
@@ -212,6 +227,21 @@ rounds state = do
quarter state 03 04 09 14
{-# INLINEABLE rounds #-}
+_block
+ :: PrimMonad m
+ => ChaCha (PrimState m)
+ -> Word32
+ -> m BS.ByteString
+_block state@(ChaCha s) counter = do
+ _chacha_counter state counter
+ i <- PA.freezePrimArray s 0 16
+ for_ [1..10 :: Int] (const (rounds state))
+ for_ [0..15 :: Int] $ \idx -> do
+ let iv = PA.indexPrimArray i idx
+ sv <- PA.readPrimArray s idx
+ PA.writePrimArray s idx (iv + sv)
+ serialize state
+
serialize :: PrimMonad m => ChaCha (PrimState m) -> m BS.ByteString
serialize (ChaCha m) = do
w64_0 <- w64 <$> PA.readPrimArray m 00 <*> PA.readPrimArray m 01
@@ -227,24 +257,22 @@ serialize (ChaCha m) = do
where
w64 a b = BSB.word64LE (fi a .|. (fi b .<<. 32))
-_chacha20_block
+-- chacha20 encryption --------------------------------------------------------
+
+encrypt
:: PrimMonad m
- => ChaCha (PrimState m)
- -> Key
+ => BS.ByteString
-> Word32
- -> Nonce
+ -> BS.ByteString
+ -> BS.ByteString
-> m BS.ByteString
-_chacha20_block state@(ChaCha s) key counter nonce = do
- _chacha_set state key counter nonce
- i <- PA.freezePrimArray s 0 16
- for_ [1..10 :: Int] (const (rounds state))
- for_ [0..15 :: Int] $ \idx -> do
- let iv = PA.indexPrimArray i idx
- sv <- PA.readPrimArray s idx
- PA.writePrimArray s idx (iv + sv)
- serialize state
-
--- chacha20 encryption --------------------------------------------------------
+encrypt raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
+ | kl /= 32 = error "ppad-chacha (encrypt): invalid key"
+ | nl /= 12 = error "ppad-chacha (encrypt): invalid nonce"
+ | otherwise = do
+ let key = parse_key raw_key
+ non = parse_nonce raw_nonce
+ _encrypt key counter non plaintext
_encrypt
:: PrimMonad m
@@ -254,33 +282,19 @@ _encrypt
-> BS.ByteString
-> m BS.ByteString
_encrypt key counter nonce plaintext = do
- state <- _chacha_alloc
- _chacha_set state key counter nonce
+ ChaCha initial <- chacha key counter nonce
+ state@(ChaCha s) <- _chacha_alloc
let loop acc !j bs = case BS.splitAt 64 bs of
(chunk@(BI.PS _ _ l), etc)
| l == 0 && BS.length etc == 0 -> pure $
BS.toStrict (BSB.toLazyByteString acc)
| otherwise -> do
- stream <- _chacha20_block state key j nonce
+ PA.copyMutablePrimArray s 0 initial 0 16
+ stream <- _block state j
let cip = BS.packZipWith (.^.) chunk stream
loop (acc <> BSB.byteString cip) (j + 1) etc
loop mempty counter plaintext
{-# INLINE _encrypt #-}
-encrypt
- :: PrimMonad m
- => BS.ByteString
- -> Word32
- -> BS.ByteString
- -> BS.ByteString
- -> m BS.ByteString
-encrypt raw_key@(BI.PS _ _ kl) counter raw_nonce@(BI.PS _ _ nl) plaintext
- | kl /= 32 = error "ppad-chacha (encrypt): invalid key"
- | nl /= 12 = error "ppad-chacha (encrypt): invalid nonce"
- | otherwise = do
- let key = parse_key raw_key
- non = parse_nonce raw_nonce
- _encrypt key counter non plaintext
-
diff --git a/test/Main.hs b/test/Main.hs
@@ -9,6 +9,7 @@ import qualified Crypto.Cipher.ChaCha as ChaCha
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base16 as B16
import Data.Foldable (for_)
+import Data.Maybe (fromJust)
import qualified Data.Primitive.PrimArray as PA
import Data.Word (Word32)
import Test.Tasty
@@ -20,8 +21,8 @@ main = defaultMain $ testGroup "ppad-chacha" [
, quarter_fullstate
, chacha20_block_init
, chacha20_rounds
- , chacha20_block
- , chacha20_encrypt
+ -- , chacha20_block
+ , encrypt
]
quarter :: TestTree
@@ -56,20 +57,17 @@ quarter_fullstate = H.testCase "quarter round (full chacha state)" $ do
H.assertEqual mempty e o
block_key :: BS.ByteString
-block_key =
- case B16.decode "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" of
- Nothing -> error "bang"
- Just k -> k
+block_key = fromJust $
+ B16.decode "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"
block_non :: BS.ByteString
-block_non =
- case B16.decode "000000090000004a00000000" of
- Nothing -> error "bang"
- Just n -> n
+block_non = fromJust $ B16.decode "000000090000004a00000000"
chacha20_block_init :: TestTree
chacha20_block_init = H.testCase "chacha20 state init" $ do
- ChaCha.ChaCha foo <- ChaCha.chacha block_key 1 block_non
+ let key = ChaCha.parse_key block_key
+ non = ChaCha.parse_nonce block_non
+ ChaCha.ChaCha foo <- ChaCha.chacha key 1 non
state <- PA.freezePrimArray foo 0 16
let ref = PA.primArrayFromList [
0x61707865, 0x3320646e, 0x79622d32, 0x6b206574
@@ -81,7 +79,9 @@ chacha20_block_init = H.testCase "chacha20 state init" $ do
chacha20_rounds :: TestTree
chacha20_rounds = H.testCase "chacha20 20 rounds" $ do
- state@(ChaCha.ChaCha s) <- ChaCha.chacha block_key 1 block_non
+ let key = ChaCha.parse_key block_key
+ non = ChaCha.parse_nonce block_non
+ state@(ChaCha.ChaCha s) <- ChaCha.chacha key 1 non
for_ [1..10 :: Int] (const (ChaCha.rounds state))
out <- PA.freezePrimArray s 0 16
@@ -95,15 +95,17 @@ chacha20_rounds = H.testCase "chacha20 20 rounds" $ do
H.assertEqual mempty ref out
-chacha20_block :: TestTree
-chacha20_block = H.testCase "chacha20 block function" $ do
- o <- ChaCha.chacha20_block block_key 1 block_non
- let raw_exp = "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4ed2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e"
- e = case B16.decode raw_exp of
- Nothing -> error "bang"
- Just x -> x
-
- H.assertEqual mempty e o
+-- chacha20_block :: TestTree
+-- chacha20_block = H.testCase "chacha20 block function" $ do
+-- let key = ChaCha.parse_key block_key
+-- non = ChaCha.parse_nonce block_non
+-- o <- ChaCha.chacha20_block key 1 non
+-- let raw_exp = "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4ed2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e"
+-- e = case B16.decode raw_exp of
+-- Nothing -> error "bang"
+-- Just x -> x
+--
+-- H.assertEqual mempty e o
crypt_plain :: BS.ByteString
crypt_plain = case B16.decode "4c616469657320616e642047656e746c656d656e206f662074686520636c617373206f66202739393a204966204920636f756c64206f6666657220796f75206f6e6c79206f6e652074697020666f7220746865206675747572652c2073756e73637265656e20776f756c642062652069742e" of
@@ -120,8 +122,8 @@ crypt_non = case B16.decode "000000000000004a00000000" of
Nothing -> error "bang"
Just x -> x
-chacha20_encrypt :: TestTree
-chacha20_encrypt = H.testCase "chacha20 encrypt" $ do
- o <- ChaCha.chacha20_encrypt block_key 1 crypt_non crypt_plain
+encrypt :: TestTree
+encrypt = H.testCase "more efficient encrypt" $ do
+ o <- ChaCha.encrypt block_key 1 crypt_non crypt_plain
H.assertEqual mempty crypt_cip o