commit 6e77092c32087fdb8f4754e30ff93f344baa7142
parent 31293784294afea5d1cccd63d78dacda2c81c5de
Author: Jared Tobin <jared@jtobin.io>
Date: Thu, 6 Mar 2025 10:40:43 +0400
lib: skeleton, passing to chacha20 block function
Diffstat:
6 files changed, 389 insertions(+), 5 deletions(-)
diff --git a/.ghci b/.ghci
@@ -0,0 +1,3 @@
+:set -XOverloadedStrings
+:set -XMagicHash
+:set prompt "> "
diff --git a/flake.lock b/flake.lock
@@ -34,6 +34,37 @@
"type": "github"
}
},
+ "ppad-base16": {
+ "inputs": {
+ "flake-utils": [
+ "ppad-base16",
+ "ppad-nixpkgs",
+ "flake-utils"
+ ],
+ "nixpkgs": [
+ "ppad-base16",
+ "ppad-nixpkgs",
+ "nixpkgs"
+ ],
+ "ppad-nixpkgs": [
+ "ppad-nixpkgs"
+ ]
+ },
+ "locked": {
+ "lastModified": 1740802922,
+ "narHash": "sha256-j+sxOWxnsMgX3GHyf7Z44lvAiBkrl/p0lD0eYli6Lgo=",
+ "ref": "master",
+ "rev": "043c845ae7f280ddbfdb5568ea453c9943e49cf2",
+ "revCount": 21,
+ "type": "git",
+ "url": "git://git.ppad.tech/base16.git"
+ },
+ "original": {
+ "ref": "master",
+ "type": "git",
+ "url": "git://git.ppad.tech/base16.git"
+ }
+ },
"ppad-nixpkgs": {
"inputs": {
"flake-utils": "flake-utils",
@@ -64,6 +95,7 @@
"ppad-nixpkgs",
"nixpkgs"
],
+ "ppad-base16": "ppad-base16",
"ppad-nixpkgs": "ppad-nixpkgs"
}
},
diff --git a/flake.nix b/flake.nix
@@ -2,6 +2,12 @@
description = "A pure Haskell ChaCha stream cipher.";
inputs = {
+ ppad-base16 = {
+ type = "git";
+ url = "git://git.ppad.tech/base16.git";
+ ref = "master";
+ inputs.ppad-nixpkgs.follows = "ppad-nixpkgs";
+ };
ppad-nixpkgs = {
type = "git";
url = "git://git.ppad.tech/nixpkgs.git";
@@ -11,7 +17,9 @@
nixpkgs.follows = "ppad-nixpkgs/nixpkgs";
};
- outputs = { self, nixpkgs, flake-utils, ppad-nixpkgs }:
+ outputs = { self, nixpkgs, flake-utils, ppad-nixpkgs
+ , ppad-base16
+ }:
flake-utils.lib.eachDefaultSystem (system:
let
lib = "ppad-chacha";
@@ -21,6 +29,7 @@
hpkgs = pkgs.haskell.packages.ghc981.extend (new: old: {
${lib} = old.callCabal2nixWithOptions lib ./. "--enable-profiling" {};
+ ppad-base16 = ppad-base16.packages.${system}.default;
});
cc = pkgs.stdenv.cc;
diff --git a/lib/Crypto/Cipher/ChaCha.hs b/lib/Crypto/Cipher/ChaCha.hs
@@ -1,10 +1,244 @@
+{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE UnboxedTuples #-}
module Crypto.Cipher.ChaCha where
import qualified Data.Bits as B
-import Data.Bits ((.^.))
-import Data.Word (Word32)
+import Data.Bits ((.|.))
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Builder as BSB
+import qualified Data.ByteString.Internal as BI
+import qualified Data.ByteString.Unsafe as BU
+import Control.Monad.Primitive (PrimMonad, PrimState)
+import Data.Foldable (for_)
+import qualified Data.Primitive.PrimArray as PA
+import Foreign.ForeignPtr
import GHC.Exts
+import GHC.Word
+
+-- utils ----------------------------------------------------------------------
+
+-- keystroke saver
+fi :: (Integral a, Num b) => a -> b
+fi = fromIntegral
+{-# INLINE fi #-}
+
+-- parse strict ByteString in LE order to Word32 (verbatim from
+-- Data.Binary)
+--
+-- invariant:
+-- the input bytestring is at least 32 bits in length
+unsafe_word32le :: BS.ByteString -> Word32
+unsafe_word32le s =
+ (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|.
+ (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|.
+ (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL` 8) .|.
+ (fi (s `BU.unsafeIndex` 0))
+{-# INLINE unsafe_word32le #-}
+
+data WSPair = WSPair
+ {-# UNPACK #-} !Word32
+ {-# UNPACK #-} !BS.ByteString
+
+-- variant of Data.ByteString.splitAt that behaves like an incremental
+-- Word32 parser
+--
+-- invariant:
+-- the input bytestring is at least 32 bits in length
+unsafe_parseWsPair :: BS.ByteString -> WSPair
+unsafe_parseWsPair (BI.BS x l) =
+ WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4))
+{-# INLINE unsafe_parseWsPair #-}
+
+-- chacha quarter round -------------------------------------------------------
+
+-- RFC8439 2.2
+quarter
+ :: PrimMonad m
+ => ChaCha (PrimState m)
+ -> Int
+ -> Int
+ -> Int
+ -> Int
+ -> m ()
+quarter (ChaCha m) i0 i1 i2 i3 = do
+ !(W32# a) <- PA.readPrimArray m i0
+ !(W32# b) <- PA.readPrimArray m i1
+ !(W32# c) <- PA.readPrimArray m i2
+ !(W32# d) <- PA.readPrimArray m i3
+
+ let !(# a1, b1, c1, d1 #) = quarter# a b c d
+
+ PA.writePrimArray m i0 (W32# a1)
+ PA.writePrimArray m i1 (W32# b1)
+ PA.writePrimArray m i2 (W32# c1)
+ PA.writePrimArray m i3 (W32# d1)
+
+-- for easy testing
+quarter'
+ :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32)
+quarter' (W32# a) (W32# b) (W32# c) (W32# d) =
+ let !(# a', b', c', d' #) = quarter# a b c d
+ in (W32# a', W32# b', W32# c', W32# d')
+{-# INLINE quarter' #-}
+
+-- RFC8439 2.1
+quarter#
+ :: Word32# -> Word32# -> Word32# -> Word32#
+ -> (# Word32#, Word32#, Word32#, Word32# #)
+quarter# a b c d =
+ let a0 = plusWord32# a b
+ d0 = xorWord32# d a0
+ d1 = rotateL# d0 16#
+
+ c0 = plusWord32# c d1
+ b0 = xorWord32# b c0
+ b1 = rotateL# b0 12#
+
+ a1 = plusWord32# a0 b1
+ d2 = xorWord32# d1 a1
+ d3 = rotateL# d2 8#
+
+ c1 = plusWord32# c0 d3
+ b2 = xorWord32# b1 c1
+ b3 = rotateL# b2 7#
+
+ in (# a1, b3, c1, d3 #)
+{-# INLINE quarter# #-}
+
+rotateL# :: Word32# -> Int# -> Word32#
+rotateL# w i
+ | isTrue# (i ==# 0#) = w
+ | otherwise = wordToWord32# (
+ ((word32ToWord# w) `uncheckedShiftL#` i)
+ `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i)))
+{-# INLINE rotateL# #-}
+
+-- chacha block function ------------------------------------------------------
+
+data Key = Key {
+ k0 :: {-# UNPACK #-} !Word32
+ , k1 :: {-# UNPACK #-} !Word32
+ , k2 :: {-# UNPACK #-} !Word32
+ , k3 :: {-# UNPACK #-} !Word32
+ , k4 :: {-# UNPACK #-} !Word32
+ , k5 :: {-# UNPACK #-} !Word32
+ , k6 :: {-# UNPACK #-} !Word32
+ , k7 :: {-# UNPACK #-} !Word32
+ }
+ deriving (Eq, Show)
+
+-- parse strict 256-bit bytestring (length unchecked) to key
+parse_key :: BS.ByteString -> Key
+parse_key bs =
+ let !(WSPair k0 t0) = unsafe_parseWsPair bs
+ !(WSPair k1 t1) = unsafe_parseWsPair t0
+ !(WSPair k2 t2) = unsafe_parseWsPair t1
+ !(WSPair k3 t3) = unsafe_parseWsPair t2
+ !(WSPair k4 t4) = unsafe_parseWsPair t3
+ !(WSPair k5 t5) = unsafe_parseWsPair t4
+ !(WSPair k6 t6) = unsafe_parseWsPair t5
+ !(WSPair k7 t7) = unsafe_parseWsPair t6
+ in if BS.null t7
+ then Key {..}
+ else error "ppad-chacha (parse_key): bytes remaining"
+
+data Nonce = Nonce {
+ n0 :: {-# UNPACK #-} !Word32
+ , n1 :: {-# UNPACK #-} !Word32
+ , n2 :: {-# UNPACK #-} !Word32
+ }
+ deriving (Eq, Show)
+
+parse_nonce :: BS.ByteString -> Nonce
+parse_nonce bs =
+ let !(WSPair n0 t0) = unsafe_parseWsPair bs
+ !(WSPair n1 t1) = unsafe_parseWsPair t0
+ !(WSPair n2 t2) = unsafe_parseWsPair t1
+ in if BS.null t2
+ then Nonce {..}
+ else error "ppad-chacha (parse_nonce): bytes remaining"
+
+newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32)
+ deriving Eq
+
+-- init chacha state
+chacha
+ :: PrimMonad m
+ => BS.ByteString
+ -> Word32
+ -> BS.ByteString
+ -> m (ChaCha (PrimState m))
+chacha key counter nonce = do
+ arr <- PA.newPrimArray 16
+ PA.writePrimArray arr 00 0x61707865
+ PA.writePrimArray arr 01 0x3320646e
+ PA.writePrimArray arr 02 0x79622d32
+ PA.writePrimArray arr 03 0x6b206574
+
+ let Key {..} = parse_key key
+ PA.writePrimArray arr 04 k0
+ PA.writePrimArray arr 05 k1
+ PA.writePrimArray arr 06 k2
+ PA.writePrimArray arr 07 k3
+ PA.writePrimArray arr 08 k4
+ PA.writePrimArray arr 09 k5
+ PA.writePrimArray arr 10 k6
+ PA.writePrimArray arr 11 k7
+
+ PA.writePrimArray arr 12 counter
+
+ let Nonce {..} = parse_nonce nonce
+ PA.writePrimArray arr 13 n0
+ PA.writePrimArray arr 14 n1
+ PA.writePrimArray arr 15 n2
+
+ pure (ChaCha arr)
+
+
+rounds
+ :: PrimMonad m
+ => ChaCha (PrimState m)
+ -> m ()
+rounds state = do
+ quarter state 00 04 08 12
+ quarter state 01 05 09 13
+ quarter state 02 06 10 14
+ quarter state 03 07 11 15
+ quarter state 00 05 10 15
+ quarter state 01 06 11 12
+ quarter state 02 07 08 13
+ quarter state 03 04 09 14
+
+serialize
+ :: PrimMonad m
+ => ChaCha (PrimState m)
+ -> m BS.ByteString
+serialize (ChaCha m) = do
+ let loop acc j
+ | j == 16 = pure (BS.toStrict (BSB.toLazyByteString acc))
+ | otherwise = do
+ v <- PA.readPrimArray m j
+ loop (acc <> BSB.word32LE v) (j + 1)
+ loop mempty 0
+
+chacha20_block
+ :: PrimMonad m
+ => BS.ByteString
+ -> Word32
+ -> BS.ByteString
+ -> m BS.ByteString
+chacha20_block key counter nonce = do
+ state@(ChaCha s) <- chacha key counter nonce
+ i <- PA.freezePrimArray s 0 16
+ for_ [1..10 :: Int] (const (rounds state))
+ for_ [0..15 :: Int] $ \idx -> do
+ let !iv = PA.indexPrimArray i idx
+ sv <- PA.readPrimArray s idx
+ PA.writePrimArray s idx (iv + sv)
+ serialize state
+
+
diff --git a/ppad-chacha.cabal b/ppad-chacha.cabal
@@ -27,6 +27,7 @@ library
build-depends:
base >= 4.9 && < 5
, bytestring >= 0.9 && < 0.13
+ , primitive
test-suite chacha-tests
type: exitcode-stdio-1.0
@@ -40,9 +41,11 @@ test-suite chacha-tests
build-depends:
base
, bytestring
+ , ppad-base16
, ppad-chacha
+ , primitive
, tasty
- , tasty-quickcheck
+ , tasty-hunit
benchmark chacha-bench
type: exitcode-stdio-1.0
diff --git a/test/Main.hs b/test/Main.hs
@@ -1,4 +1,107 @@
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE MagicHash #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE UnboxedTuples #-}
+
module Main where
+import qualified Crypto.Cipher.ChaCha as ChaCha
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Base16 as B16
+import Data.Foldable (for_)
+import qualified Data.Primitive.PrimArray as PA
+import Data.Word (Word32)
+import Test.Tasty
+import qualified Test.Tasty.HUnit as H
+
+
main :: IO ()
-main = pure ()
+main = defaultMain $ testGroup "ppad-chacha" [
+ quarter
+ , quarter_fullstate
+ , chacha20_block_init
+ , chacha20_rounds
+ , chacha20_block
+ ]
+
+quarter :: TestTree
+quarter = H.testCase "quarter round" $ do
+ let e = (0xea2a92f4, 0xcb1cf8ce, 0x4581472e, 0x5881c4bb)
+ o = ChaCha.quarter' 0x11111111 0x01020304 0x9b8d6f43 0x01234567
+ H.assertEqual mempty e o
+
+quarter_fullstate :: TestTree
+quarter_fullstate = H.testCase "quarter round (full chacha state)" $ do
+ let inp :: PA.PrimArray Word32
+ inp = PA.primArrayFromList [
+ 0x879531e0, 0xc5ecf37d, 0x516461b1, 0xc9a62f8a
+ , 0x44c20ef3, 0x3390af7f, 0xd9fc690b, 0x2a5f714c
+ , 0x53372767, 0xb00a5631, 0x974c541a, 0x359e9963
+ , 0x5c971061, 0x3d631689, 0x2098d9d6, 0x91dbd320
+ ]
+ hot <- PA.unsafeThawPrimArray inp
+
+ ChaCha.quarter (ChaCha.ChaCha hot) 2 7 8 13
+
+ o <- PA.unsafeFreezePrimArray hot
+
+ let e :: PA.PrimArray Word32
+ e = PA.primArrayFromList [
+ 0x879531e0, 0xc5ecf37d, 0xbdb886dc, 0xc9a62f8a
+ , 0x44c20ef3, 0x3390af7f, 0xd9fc690b, 0xcfacafd2
+ , 0xe46bea80, 0xb00a5631, 0x974c541a, 0x359e9963
+ , 0x5c971061, 0xccc07c79, 0x2098d9d6, 0x91dbd320
+ ]
+
+ H.assertEqual mempty e o
+
+block_key :: BS.ByteString
+block_key =
+ case B16.decode "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" of
+ Nothing -> error "bang"
+ Just k -> k
+
+block_non :: BS.ByteString
+block_non =
+ case B16.decode "000000090000004a00000000" of
+ Nothing -> error "bang"
+ Just n -> n
+
+chacha20_block_init :: TestTree
+chacha20_block_init = H.testCase "chacha20 state init" $ do
+ ChaCha.ChaCha foo <- ChaCha.chacha block_key 1 block_non
+ state <- PA.freezePrimArray foo 0 16
+ let ref = PA.primArrayFromList [
+ 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574
+ , 0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c
+ , 0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c
+ , 0x00000001, 0x09000000, 0x4a000000, 0x00000000
+ ]
+ H.assertEqual mempty ref state
+
+chacha20_rounds :: TestTree
+chacha20_rounds = H.testCase "chacha20 20 rounds" $ do
+ state@(ChaCha.ChaCha s) <- ChaCha.chacha block_key 1 block_non
+ for_ [1..10 :: Int] (const (ChaCha.rounds state))
+
+ out <- PA.freezePrimArray s 0 16
+
+ let ref = PA.primArrayFromList [
+ 0x837778ab, 0xe238d763, 0xa67ae21e, 0x5950bb2f
+ , 0xc4f2d0c7, 0xfc62bb2f, 0x8fa018fc, 0x3f5ec7b7
+ , 0x335271c2, 0xf29489f3, 0xeabda8fc, 0x82e46ebd
+ , 0xd19c12b4, 0xb04e16de, 0x9e83d0cb, 0x4e3c50a2
+ ]
+
+ H.assertEqual mempty ref out
+
+chacha20_block :: TestTree
+chacha20_block = H.testCase "chacha20 block function" $ do
+ o <- ChaCha.chacha20_block block_key 1 block_non
+ let raw_exp = "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4ed2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e"
+ e = case B16.decode raw_exp of
+ Nothing -> error "bang"
+ Just x -> x
+
+ H.assertEqual mempty e o
+