chacha

The ChaCha20 stream cipher (docs.ppad.tech/chacha).
git clone git://git.ppad.tech/chacha.git
Log | Files | Refs | LICENSE

commit 6e77092c32087fdb8f4754e30ff93f344baa7142
parent 31293784294afea5d1cccd63d78dacda2c81c5de
Author: Jared Tobin <jared@jtobin.io>
Date:   Thu,  6 Mar 2025 10:40:43 +0400

lib: skeleton, passing to chacha20 block function

Diffstat:
A.ghci | 3+++
Mflake.lock | 32++++++++++++++++++++++++++++++++
Mflake.nix | 11++++++++++-
Mlib/Crypto/Cipher/ChaCha.hs | 238++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
Mppad-chacha.cabal | 5++++-
Mtest/Main.hs | 105++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
6 files changed, 389 insertions(+), 5 deletions(-)

diff --git a/.ghci b/.ghci @@ -0,0 +1,3 @@ +:set -XOverloadedStrings +:set -XMagicHash +:set prompt "> " diff --git a/flake.lock b/flake.lock @@ -34,6 +34,37 @@ "type": "github" } }, + "ppad-base16": { + "inputs": { + "flake-utils": [ + "ppad-base16", + "ppad-nixpkgs", + "flake-utils" + ], + "nixpkgs": [ + "ppad-base16", + "ppad-nixpkgs", + "nixpkgs" + ], + "ppad-nixpkgs": [ + "ppad-nixpkgs" + ] + }, + "locked": { + "lastModified": 1740802922, + "narHash": "sha256-j+sxOWxnsMgX3GHyf7Z44lvAiBkrl/p0lD0eYli6Lgo=", + "ref": "master", + "rev": "043c845ae7f280ddbfdb5568ea453c9943e49cf2", + "revCount": 21, + "type": "git", + "url": "git://git.ppad.tech/base16.git" + }, + "original": { + "ref": "master", + "type": "git", + "url": "git://git.ppad.tech/base16.git" + } + }, "ppad-nixpkgs": { "inputs": { "flake-utils": "flake-utils", @@ -64,6 +95,7 @@ "ppad-nixpkgs", "nixpkgs" ], + "ppad-base16": "ppad-base16", "ppad-nixpkgs": "ppad-nixpkgs" } }, diff --git a/flake.nix b/flake.nix @@ -2,6 +2,12 @@ description = "A pure Haskell ChaCha stream cipher."; inputs = { + ppad-base16 = { + type = "git"; + url = "git://git.ppad.tech/base16.git"; + ref = "master"; + inputs.ppad-nixpkgs.follows = "ppad-nixpkgs"; + }; ppad-nixpkgs = { type = "git"; url = "git://git.ppad.tech/nixpkgs.git"; @@ -11,7 +17,9 @@ nixpkgs.follows = "ppad-nixpkgs/nixpkgs"; }; - outputs = { self, nixpkgs, flake-utils, ppad-nixpkgs }: + outputs = { self, nixpkgs, flake-utils, ppad-nixpkgs + , ppad-base16 + }: flake-utils.lib.eachDefaultSystem (system: let lib = "ppad-chacha"; @@ -21,6 +29,7 @@ hpkgs = pkgs.haskell.packages.ghc981.extend (new: old: { ${lib} = old.callCabal2nixWithOptions lib ./. "--enable-profiling" {}; + ppad-base16 = ppad-base16.packages.${system}.default; }); cc = pkgs.stdenv.cc; diff --git a/lib/Crypto/Cipher/ChaCha.hs b/lib/Crypto/Cipher/ChaCha.hs @@ -1,10 +1,244 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE MagicHash #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE UnboxedTuples #-} module Crypto.Cipher.ChaCha where import qualified Data.Bits as B -import Data.Bits ((.^.)) -import Data.Word (Word32) +import Data.Bits ((.|.)) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Builder as BSB +import qualified Data.ByteString.Internal as BI +import qualified Data.ByteString.Unsafe as BU +import Control.Monad.Primitive (PrimMonad, PrimState) +import Data.Foldable (for_) +import qualified Data.Primitive.PrimArray as PA +import Foreign.ForeignPtr import GHC.Exts +import GHC.Word + +-- utils ---------------------------------------------------------------------- + +-- keystroke saver +fi :: (Integral a, Num b) => a -> b +fi = fromIntegral +{-# INLINE fi #-} + +-- parse strict ByteString in LE order to Word32 (verbatim from +-- Data.Binary) +-- +-- invariant: +-- the input bytestring is at least 32 bits in length +unsafe_word32le :: BS.ByteString -> Word32 +unsafe_word32le s = + (fi (s `BU.unsafeIndex` 3) `B.unsafeShiftL` 24) .|. + (fi (s `BU.unsafeIndex` 2) `B.unsafeShiftL` 16) .|. + (fi (s `BU.unsafeIndex` 1) `B.unsafeShiftL` 8) .|. + (fi (s `BU.unsafeIndex` 0)) +{-# INLINE unsafe_word32le #-} + +data WSPair = WSPair + {-# UNPACK #-} !Word32 + {-# UNPACK #-} !BS.ByteString + +-- variant of Data.ByteString.splitAt that behaves like an incremental +-- Word32 parser +-- +-- invariant: +-- the input bytestring is at least 32 bits in length +unsafe_parseWsPair :: BS.ByteString -> WSPair +unsafe_parseWsPair (BI.BS x l) = + WSPair (unsafe_word32le (BI.BS x 4)) (BI.BS (plusForeignPtr x 4) (l - 4)) +{-# INLINE unsafe_parseWsPair #-} + +-- chacha quarter round ------------------------------------------------------- + +-- RFC8439 2.2 +quarter + :: PrimMonad m + => ChaCha (PrimState m) + -> Int + -> Int + -> Int + -> Int + -> m () +quarter (ChaCha m) i0 i1 i2 i3 = do + !(W32# a) <- PA.readPrimArray m i0 + !(W32# b) <- PA.readPrimArray m i1 + !(W32# c) <- PA.readPrimArray m i2 + !(W32# d) <- PA.readPrimArray m i3 + + let !(# a1, b1, c1, d1 #) = quarter# a b c d + + PA.writePrimArray m i0 (W32# a1) + PA.writePrimArray m i1 (W32# b1) + PA.writePrimArray m i2 (W32# c1) + PA.writePrimArray m i3 (W32# d1) + +-- for easy testing +quarter' + :: Word32 -> Word32 -> Word32 -> Word32 -> (Word32, Word32, Word32, Word32) +quarter' (W32# a) (W32# b) (W32# c) (W32# d) = + let !(# a', b', c', d' #) = quarter# a b c d + in (W32# a', W32# b', W32# c', W32# d') +{-# INLINE quarter' #-} + +-- RFC8439 2.1 +quarter# + :: Word32# -> Word32# -> Word32# -> Word32# + -> (# Word32#, Word32#, Word32#, Word32# #) +quarter# a b c d = + let a0 = plusWord32# a b + d0 = xorWord32# d a0 + d1 = rotateL# d0 16# + + c0 = plusWord32# c d1 + b0 = xorWord32# b c0 + b1 = rotateL# b0 12# + + a1 = plusWord32# a0 b1 + d2 = xorWord32# d1 a1 + d3 = rotateL# d2 8# + + c1 = plusWord32# c0 d3 + b2 = xorWord32# b1 c1 + b3 = rotateL# b2 7# + + in (# a1, b3, c1, d3 #) +{-# INLINE quarter# #-} + +rotateL# :: Word32# -> Int# -> Word32# +rotateL# w i + | isTrue# (i ==# 0#) = w + | otherwise = wordToWord32# ( + ((word32ToWord# w) `uncheckedShiftL#` i) + `or#` ((word32ToWord# w) `uncheckedShiftRL#` (32# -# i))) +{-# INLINE rotateL# #-} + +-- chacha block function ------------------------------------------------------ + +data Key = Key { + k0 :: {-# UNPACK #-} !Word32 + , k1 :: {-# UNPACK #-} !Word32 + , k2 :: {-# UNPACK #-} !Word32 + , k3 :: {-# UNPACK #-} !Word32 + , k4 :: {-# UNPACK #-} !Word32 + , k5 :: {-# UNPACK #-} !Word32 + , k6 :: {-# UNPACK #-} !Word32 + , k7 :: {-# UNPACK #-} !Word32 + } + deriving (Eq, Show) + +-- parse strict 256-bit bytestring (length unchecked) to key +parse_key :: BS.ByteString -> Key +parse_key bs = + let !(WSPair k0 t0) = unsafe_parseWsPair bs + !(WSPair k1 t1) = unsafe_parseWsPair t0 + !(WSPair k2 t2) = unsafe_parseWsPair t1 + !(WSPair k3 t3) = unsafe_parseWsPair t2 + !(WSPair k4 t4) = unsafe_parseWsPair t3 + !(WSPair k5 t5) = unsafe_parseWsPair t4 + !(WSPair k6 t6) = unsafe_parseWsPair t5 + !(WSPair k7 t7) = unsafe_parseWsPair t6 + in if BS.null t7 + then Key {..} + else error "ppad-chacha (parse_key): bytes remaining" + +data Nonce = Nonce { + n0 :: {-# UNPACK #-} !Word32 + , n1 :: {-# UNPACK #-} !Word32 + , n2 :: {-# UNPACK #-} !Word32 + } + deriving (Eq, Show) + +parse_nonce :: BS.ByteString -> Nonce +parse_nonce bs = + let !(WSPair n0 t0) = unsafe_parseWsPair bs + !(WSPair n1 t1) = unsafe_parseWsPair t0 + !(WSPair n2 t2) = unsafe_parseWsPair t1 + in if BS.null t2 + then Nonce {..} + else error "ppad-chacha (parse_nonce): bytes remaining" + +newtype ChaCha s = ChaCha (PA.MutablePrimArray s Word32) + deriving Eq + +-- init chacha state +chacha + :: PrimMonad m + => BS.ByteString + -> Word32 + -> BS.ByteString + -> m (ChaCha (PrimState m)) +chacha key counter nonce = do + arr <- PA.newPrimArray 16 + PA.writePrimArray arr 00 0x61707865 + PA.writePrimArray arr 01 0x3320646e + PA.writePrimArray arr 02 0x79622d32 + PA.writePrimArray arr 03 0x6b206574 + + let Key {..} = parse_key key + PA.writePrimArray arr 04 k0 + PA.writePrimArray arr 05 k1 + PA.writePrimArray arr 06 k2 + PA.writePrimArray arr 07 k3 + PA.writePrimArray arr 08 k4 + PA.writePrimArray arr 09 k5 + PA.writePrimArray arr 10 k6 + PA.writePrimArray arr 11 k7 + + PA.writePrimArray arr 12 counter + + let Nonce {..} = parse_nonce nonce + PA.writePrimArray arr 13 n0 + PA.writePrimArray arr 14 n1 + PA.writePrimArray arr 15 n2 + + pure (ChaCha arr) + + +rounds + :: PrimMonad m + => ChaCha (PrimState m) + -> m () +rounds state = do + quarter state 00 04 08 12 + quarter state 01 05 09 13 + quarter state 02 06 10 14 + quarter state 03 07 11 15 + quarter state 00 05 10 15 + quarter state 01 06 11 12 + quarter state 02 07 08 13 + quarter state 03 04 09 14 + +serialize + :: PrimMonad m + => ChaCha (PrimState m) + -> m BS.ByteString +serialize (ChaCha m) = do + let loop acc j + | j == 16 = pure (BS.toStrict (BSB.toLazyByteString acc)) + | otherwise = do + v <- PA.readPrimArray m j + loop (acc <> BSB.word32LE v) (j + 1) + loop mempty 0 + +chacha20_block + :: PrimMonad m + => BS.ByteString + -> Word32 + -> BS.ByteString + -> m BS.ByteString +chacha20_block key counter nonce = do + state@(ChaCha s) <- chacha key counter nonce + i <- PA.freezePrimArray s 0 16 + for_ [1..10 :: Int] (const (rounds state)) + for_ [0..15 :: Int] $ \idx -> do + let !iv = PA.indexPrimArray i idx + sv <- PA.readPrimArray s idx + PA.writePrimArray s idx (iv + sv) + serialize state + + diff --git a/ppad-chacha.cabal b/ppad-chacha.cabal @@ -27,6 +27,7 @@ library build-depends: base >= 4.9 && < 5 , bytestring >= 0.9 && < 0.13 + , primitive test-suite chacha-tests type: exitcode-stdio-1.0 @@ -40,9 +41,11 @@ test-suite chacha-tests build-depends: base , bytestring + , ppad-base16 , ppad-chacha + , primitive , tasty - , tasty-quickcheck + , tasty-hunit benchmark chacha-bench type: exitcode-stdio-1.0 diff --git a/test/Main.hs b/test/Main.hs @@ -1,4 +1,107 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE UnboxedTuples #-} + module Main where +import qualified Crypto.Cipher.ChaCha as ChaCha +import qualified Data.ByteString as BS +import qualified Data.ByteString.Base16 as B16 +import Data.Foldable (for_) +import qualified Data.Primitive.PrimArray as PA +import Data.Word (Word32) +import Test.Tasty +import qualified Test.Tasty.HUnit as H + + main :: IO () -main = pure () +main = defaultMain $ testGroup "ppad-chacha" [ + quarter + , quarter_fullstate + , chacha20_block_init + , chacha20_rounds + , chacha20_block + ] + +quarter :: TestTree +quarter = H.testCase "quarter round" $ do + let e = (0xea2a92f4, 0xcb1cf8ce, 0x4581472e, 0x5881c4bb) + o = ChaCha.quarter' 0x11111111 0x01020304 0x9b8d6f43 0x01234567 + H.assertEqual mempty e o + +quarter_fullstate :: TestTree +quarter_fullstate = H.testCase "quarter round (full chacha state)" $ do + let inp :: PA.PrimArray Word32 + inp = PA.primArrayFromList [ + 0x879531e0, 0xc5ecf37d, 0x516461b1, 0xc9a62f8a + , 0x44c20ef3, 0x3390af7f, 0xd9fc690b, 0x2a5f714c + , 0x53372767, 0xb00a5631, 0x974c541a, 0x359e9963 + , 0x5c971061, 0x3d631689, 0x2098d9d6, 0x91dbd320 + ] + hot <- PA.unsafeThawPrimArray inp + + ChaCha.quarter (ChaCha.ChaCha hot) 2 7 8 13 + + o <- PA.unsafeFreezePrimArray hot + + let e :: PA.PrimArray Word32 + e = PA.primArrayFromList [ + 0x879531e0, 0xc5ecf37d, 0xbdb886dc, 0xc9a62f8a + , 0x44c20ef3, 0x3390af7f, 0xd9fc690b, 0xcfacafd2 + , 0xe46bea80, 0xb00a5631, 0x974c541a, 0x359e9963 + , 0x5c971061, 0xccc07c79, 0x2098d9d6, 0x91dbd320 + ] + + H.assertEqual mempty e o + +block_key :: BS.ByteString +block_key = + case B16.decode "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" of + Nothing -> error "bang" + Just k -> k + +block_non :: BS.ByteString +block_non = + case B16.decode "000000090000004a00000000" of + Nothing -> error "bang" + Just n -> n + +chacha20_block_init :: TestTree +chacha20_block_init = H.testCase "chacha20 state init" $ do + ChaCha.ChaCha foo <- ChaCha.chacha block_key 1 block_non + state <- PA.freezePrimArray foo 0 16 + let ref = PA.primArrayFromList [ + 0x61707865, 0x3320646e, 0x79622d32, 0x6b206574 + , 0x03020100, 0x07060504, 0x0b0a0908, 0x0f0e0d0c + , 0x13121110, 0x17161514, 0x1b1a1918, 0x1f1e1d1c + , 0x00000001, 0x09000000, 0x4a000000, 0x00000000 + ] + H.assertEqual mempty ref state + +chacha20_rounds :: TestTree +chacha20_rounds = H.testCase "chacha20 20 rounds" $ do + state@(ChaCha.ChaCha s) <- ChaCha.chacha block_key 1 block_non + for_ [1..10 :: Int] (const (ChaCha.rounds state)) + + out <- PA.freezePrimArray s 0 16 + + let ref = PA.primArrayFromList [ + 0x837778ab, 0xe238d763, 0xa67ae21e, 0x5950bb2f + , 0xc4f2d0c7, 0xfc62bb2f, 0x8fa018fc, 0x3f5ec7b7 + , 0x335271c2, 0xf29489f3, 0xeabda8fc, 0x82e46ebd + , 0xd19c12b4, 0xb04e16de, 0x9e83d0cb, 0x4e3c50a2 + ] + + H.assertEqual mempty ref out + +chacha20_block :: TestTree +chacha20_block = H.testCase "chacha20 block function" $ do + o <- ChaCha.chacha20_block block_key 1 block_non + let raw_exp = "10f1e7e4d13b5915500fdd1fa32071c4c7d1f4c733c068030422aa9ac3d46c4ed2826446079faa0914c2d705d98b02a2b5129cd1de164eb9cbd083e8a2503c4e" + e = case B16.decode raw_exp of + Nothing -> error "bang" + Just x -> x + + H.assertEqual mempty e o +